Spaces:
Build error
Build error
import logging | |
import torch | |
def train_epoch(model, dataloader, criterion, optimizer, device, scheduler=None):
    """Run one optimisation pass over ``dataloader`` and report loss/accuracy.

    Each item yielded by ``dataloader`` is unpacked as ``(inputs, labels)``.
    The leading dimension of ``inputs`` is squeezed away before the forward
    pass and the target class is read from ``labels[0][0]``, so the loop
    assumes one sample per batch (NOTE(review): confirm against the dataset).

    Args:
        model: Callable producing the class scores for one sample.
        dataloader: Sized iterable of ``(inputs, labels)`` pairs.
        criterion: Loss callable invoked as ``criterion(outputs[0], labels[0])``.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        device: Device the tensors are moved to.
        scheduler: Optional scheduler whose ``step`` receives the mean loss of
            the epoch (``running_loss / len(dataloader)``).

    Returns:
        Tuple ``(running_loss, pred_correct, pred_all, accuracy)``.
        ``running_loss`` is a detached 0-dim tensor holding the summed batch
        losses. For an empty ``dataloader`` the result is
        ``(tensor(0.), 0, 0, 0.0)`` and the scheduler is not stepped.
    """
    pred_correct, pred_all = 0, 0
    running_loss = 0.0

    for inputs, labels in dataloader:
        inputs = inputs.squeeze(0).to(device)
        labels = labels.to(device, dtype=torch.long)

        optimizer.zero_grad()
        outputs = model(inputs).expand(1, -1, -1)

        loss = criterion(outputs[0], labels[0])
        loss.backward()
        optimizer.step()

        # Accumulate a detached copy: summing the live loss tensor would keep
        # every batch's autograd history referenced until the epoch ends.
        running_loss = running_loss + loss.detach()

        # Statistics. Softmax is monotonic, so the arg-max of the raw scores
        # already identifies the predicted class.
        if int(torch.argmax(outputs)) == int(labels[0][0]):
            pred_correct += 1
        pred_all += 1

    if not torch.is_tensor(running_loss):
        # No batch was processed; still hand back a tensor so callers can
        # use ``.item()`` uniformly.
        running_loss = torch.tensor(running_loss)

    if scheduler is not None and pred_all:
        scheduler.step(running_loss.item() / len(dataloader))

    accuracy = pred_correct / pred_all if pred_all else 0.0
    return running_loss, pred_correct, pred_all, accuracy
def evaluate(model, dataloader, device, print_stats=False):
    """Compute top-1 accuracy of ``model`` over ``dataloader``.

    Each item yielded by ``dataloader`` is unpacked as ``(inputs, labels)``.
    The leading dimension of ``inputs`` is squeezed away before the forward
    pass and the target class is read from ``labels[0][0]``, so the loop
    assumes one sample per batch (NOTE(review): confirm against the dataset).
    The model's train/eval mode is left untouched; the caller is responsible
    for switching it.

    Args:
        model: Callable producing the class scores for one sample.
        dataloader: Iterable of ``(inputs, labels)`` pairs.
        device: Device the tensors are moved to.
        print_stats: When true, the per-label accuracies are printed and
            logged through ``logging.info``.

    Returns:
        Tuple ``(pred_correct, pred_all, accuracy)``; ``accuracy`` is ``0.0``
        when ``dataloader`` yields nothing.
    """
    pred_correct, pred_all = 0, 0
    # label id -> [correct predictions, samples seen]. Entries are created on
    # demand so any label id is accepted, not only a fixed 0..100 range.
    stats = {}

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.squeeze(0).to(device)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(inputs).expand(1, -1, -1)

            target = int(labels[0][0])
            label_counts = stats.setdefault(target, [0, 0])

            # Statistics. Softmax is monotonic, so the arg-max of the raw
            # scores already identifies the predicted class.
            if int(torch.argmax(outputs)) == target:
                label_counts[0] += 1
                pred_correct += 1

            label_counts[1] += 1
            pred_all += 1

    if print_stats:
        # Sorted by label id so the report order does not depend on the order
        # in which labels were encountered.
        stats = {key: hits / seen for key, (hits, seen) in sorted(stats.items())}
        print("Label accuracies statistics:")
        print(str(stats) + "\n")
        logging.info("Label accuracies statistics:")
        logging.info(str(stats) + "\n")

    accuracy = pred_correct / pred_all if pred_all else 0.0
    return pred_correct, pred_all, accuracy
def evaluate_top_k(model, dataloader, device, k=5):
    """Compute top-``k`` accuracy of ``model`` over ``dataloader``.

    A sample counts as correct when its target class (read from
    ``labels[0][0]``) is among the ``k`` highest-scoring entries of the
    model output. The loop assumes one sample per batch, since the scores
    are flattened before ranking (NOTE(review): confirm against the dataset).

    Args:
        model: Callable producing the class scores for one sample.
        dataloader: Iterable of ``(inputs, labels)`` pairs.
        device: Device the tensors are moved to.
        k: Number of top-ranked classes to accept; clamped to the number of
            scores the model returns.

    Returns:
        Tuple ``(pred_correct, pred_all, accuracy)``; ``accuracy`` is ``0.0``
        when ``dataloader`` yields nothing.
    """
    pred_correct, pred_all = 0, 0

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.squeeze(0).to(device)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(inputs).expand(1, -1, -1)

            # Flatten before ranking: ``outputs`` is 3-D here, so calling
            # ``.tolist()`` on its top-k indices gives a nested list in which
            # an ``int`` membership test can never succeed.
            scores = outputs.reshape(-1)
            top_classes = torch.topk(scores, min(k, scores.numel())).indices.tolist()

            if int(labels[0][0]) in top_classes:
                pred_correct += 1

            pred_all += 1

    accuracy = pred_correct / pred_all if pred_all else 0.0
    return pred_correct, pred_all, accuracy