Spaces:
Build error
Build error
| import logging | |
| import torch | |
def train_epoch(model, dataloader, criterion, optimizer, device, scheduler=None):
    """Run one training epoch over ``dataloader`` and report loss and accuracy.

    Only the first label of each batch (``labels[0][0]``) is scored, and the
    prediction is the argmax over the whole (flattened) output, so the loop
    effectively assumes one sample per batch.

    Args:
        model: Callable mapping the squeezed inputs to class scores; its output
            must be at least 2-D so that ``.expand(1, -1, -1)`` succeeds.
        dataloader: Iterable of ``(inputs, labels)`` pairs.
        criterion: Loss called as ``criterion(outputs[0], labels[0])``.
        optimizer: Optimizer whose ``zero_grad``/``step`` run once per batch.
        device: Device the inputs, labels and loss accumulator are placed on.
        scheduler: Optional object whose ``step`` receives the mean batch loss
            after the epoch (skipped when the dataloader yielded nothing).

    Returns:
        ``(running_loss, pred_correct, pred_all, accuracy)``. ``running_loss`` is
        a detached 0-d tensor holding the summed batch losses; ``accuracy`` is
        ``0.0`` when no batches were seen.
    """
    pred_correct, pred_all = 0, 0
    # Accumulate detached values: summing the raw ``loss`` tensors would keep
    # every batch's autograd graph alive until the epoch ends.
    running_loss = torch.zeros((), device=device)

    for inputs, labels in dataloader:
        inputs = inputs.squeeze(0).to(device)
        labels = labels.to(device, dtype=torch.long)

        optimizer.zero_grad()
        outputs = model(inputs).expand(1, -1, -1)

        loss = criterion(outputs[0], labels[0])
        loss.backward()
        optimizer.step()
        running_loss += loss.detach()

        # Statistics: softmax is monotonic, so the argmax of the raw scores is
        # already the predicted class index.
        if int(outputs.argmax()) == int(labels[0][0]):
            pred_correct += 1
        pred_all += 1

    # ``pred_all`` counts the batches actually seen, so this is the mean loss
    # even for iterables without ``len`` and never divides by zero.
    if scheduler is not None and pred_all:
        scheduler.step(running_loss.item() / pred_all)

    accuracy = pred_correct / pred_all if pred_all else 0.0
    return running_loss, pred_correct, pred_all, accuracy
def evaluate(model, dataloader, device, print_stats=False):
    """Measure top-1 accuracy of ``model`` over ``dataloader``.

    Only the first label of each batch (``labels[0][0]``) is scored against the
    flattened argmax of the output, so one sample per batch is assumed.

    Args:
        model: Callable mapping the squeezed inputs to class scores; its output
            must be at least 2-D so that ``.expand(1, -1, -1)`` succeeds.
        dataloader: Iterable of ``(inputs, labels)`` pairs.
        device: Device the inputs and labels are moved to.
        print_stats: If true, print and log the per-label accuracy as a dict
            ``{label: accuracy}`` sorted by label.

    Returns:
        ``(pred_correct, pred_all, accuracy)``; ``accuracy`` is ``0.0`` when the
        dataloader yielded nothing.
    """
    pred_correct, pred_all = 0, 0
    # label -> [correct, seen]; entries are created on demand so any integer
    # label works (a fixed range(101) table raised KeyError for larger labels).
    stats = {}

    # Gradients are never used here; skip building autograd graphs.
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.squeeze(0).to(device)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(inputs).expand(1, -1, -1)

            # Statistics: softmax is monotonic, so argmax of the raw scores is
            # the predicted class index.
            target = int(labels[0][0])
            counts = stats.setdefault(target, [0, 0])
            if int(outputs.argmax()) == target:
                counts[0] += 1
                pred_correct += 1
            counts[1] += 1
            pred_all += 1

    if print_stats:
        # Every stored entry has been seen at least once, so no zero division.
        stats = {key: value[0] / value[1] for key, value in sorted(stats.items())}
        print("Label accuracies statistics:")
        print(str(stats) + "\n")
        logging.info("Label accuracies statistics:")
        logging.info(str(stats) + "\n")

    accuracy = pred_correct / pred_all if pred_all else 0.0
    return pred_correct, pred_all, accuracy
def evaluate_top_k(model, dataloader, device, k=5):
    """Measure top-``k`` accuracy of ``model`` over ``dataloader``.

    A batch counts as correct when its first label (``labels[0][0]``) is among
    the ``k`` highest-scoring classes of the output (one sample per batch is
    assumed).

    Args:
        model: Callable mapping the squeezed inputs to class scores; its output
            must be at least 2-D so that ``.expand(1, -1, -1)`` succeeds.
        dataloader: Iterable of ``(inputs, labels)`` pairs.
        device: Device the inputs and labels are moved to.
        k: Number of top classes considered; clamped to the number of classes.

    Returns:
        ``(pred_correct, pred_all, accuracy)``; ``accuracy`` is ``0.0`` when the
        dataloader yielded nothing.
    """
    pred_correct, pred_all = 0, 0

    # Gradients are never used here; skip building autograd graphs.
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.squeeze(0).to(device)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(inputs).expand(1, -1, -1)

            # ``outputs`` is 3-D, so topk indices are nested (e.g. [[[a, b]]]);
            # flatten them, otherwise ``int in nested_list`` is always False.
            top_classes = torch.topk(outputs, min(k, outputs.size(-1))).indices.flatten().tolist()
            if int(labels[0][0]) in top_classes:
                pred_correct += 1
            pred_all += 1

    accuracy = pred_correct / pred_all if pred_all else 0.0
    return pred_correct, pred_all, accuracy