Matyáš Boháček
Init commit
a001524
raw
history blame contribute delete
No virus
2.35 kB
import logging
import torch
def train_epoch(model, dataloader, criterion, optimizer, device, scheduler=None):
    """Run one training pass over ``dataloader`` and report loss and accuracy.

    Every item yielded by ``dataloader`` is unpacked as ``(inputs, labels)``.
    ``inputs`` has ``squeeze(0)`` applied before the forward pass, the model
    output is expanded to three dimensions, and the loss is computed between
    ``outputs[0]`` and ``labels[0]``.
    NOTE(review): this presumably expects one sample per item with labels
    shaped ``(1, 1)`` and logits expandable to ``(1, 1, num_classes)`` --
    confirm against the dataset and model.

    Args:
        model: Callable mapping the prepared inputs to class scores.
        dataloader: Iterable of ``(inputs, labels)`` tensor pairs.
        criterion: Loss callable invoked as ``criterion(outputs[0], labels[0])``.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        device: Device the tensors are moved to.
        scheduler: Optional object whose ``step(value)`` is called once with
            the mean loss of the epoch.

    Returns:
        Tuple ``(running_loss, pred_correct, pred_all, accuracy)`` where
        ``running_loss`` is a 0-dim tensor holding the summed loss. For an
        empty ``dataloader`` the result is ``(tensor(0.), 0, 0, 0.0)`` and the
        scheduler is not stepped.
    """
    pred_correct, pred_all = 0, 0
    running_loss = 0.0  # turns into a 0-dim tensor after the first item

    for inputs, labels in dataloader:
        inputs = inputs.squeeze(0).to(device)
        labels = labels.to(device, dtype=torch.long)

        optimizer.zero_grad()
        outputs = model(inputs).expand(1, -1, -1)

        loss = criterion(outputs[0], labels[0])
        loss.backward()
        optimizer.step()

        # Detach so the running total does not carry autograd history from
        # every iteration; callers can still call ``.item()`` on the result.
        running_loss += loss.detach()

        # Statistics: softmax is monotonic, so the argmax of the raw scores
        # selects the same class without the extra op.
        if int(torch.argmax(outputs)) == int(labels[0][0]):
            pred_correct += 1
        pred_all += 1

    if pred_all == 0:
        # Nothing was processed: avoid dividing by zero below.
        return torch.zeros((), device=device), 0, 0, 0.0

    if scheduler is not None:
        # One loop iteration per dataloader item, so this is the mean loss.
        scheduler.step(running_loss.item() / pred_all)

    return running_loss, pred_correct, pred_all, (pred_correct / pred_all)
def evaluate(model, dataloader, device, print_stats=False):
    """Compute top-1 accuracy of ``model`` over ``dataloader``.

    Every item yielded by ``dataloader`` is unpacked as ``(inputs, labels)``;
    ``inputs`` has ``squeeze(0)`` applied, the model output is expanded to
    three dimensions, and the flattened argmax is compared to
    ``labels[0][0]``.
    NOTE(review): presumably one sample per item with logits expandable to
    ``(1, 1, num_classes)`` -- confirm against the dataset and model.

    Args:
        model: Callable mapping the prepared inputs to class scores.
        dataloader: Iterable of ``(inputs, labels)`` tensor pairs.
        device: Device the tensors are moved to.
        print_stats: When true, print and log the per-label accuracy of
            every label that occurred at least once.

    Returns:
        Tuple ``(pred_correct, pred_all, accuracy)``; ``accuracy`` is ``0.0``
        when ``dataloader`` yields nothing.
    """
    pred_correct, pred_all = 0, 0
    # label id -> [correct, total]. Entries are created on demand so any
    # label id is accepted rather than only a fixed range of classes.
    stats = {}

    # Only predictions are needed here, so skip autograd bookkeeping.
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.squeeze(0).to(device)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(inputs).expand(1, -1, -1)

            target = int(labels[0][0])
            counts = stats.setdefault(target, [0, 0])

            # Statistics: softmax is monotonic, so the argmax of the raw
            # scores selects the same class.
            if int(torch.argmax(outputs)) == target:
                counts[0] += 1
                pred_correct += 1

            counts[1] += 1
            pred_all += 1

    if print_stats:
        # Sorted by label id so the report order is stable.
        accuracies = {key: stats[key][0] / stats[key][1] for key in sorted(stats)}
        print("Label accuracies statistics:")
        print(str(accuracies) + "\n")

        logging.info("Label accuracies statistics:")
        logging.info(str(accuracies) + "\n")

    accuracy = pred_correct / pred_all if pred_all else 0.0
    return pred_correct, pred_all, accuracy
def evaluate_top_k(model, dataloader, device, k=5):
    """Compute top-``k`` accuracy of ``model`` over ``dataloader``.

    Every item yielded by ``dataloader`` is unpacked as ``(inputs, labels)``;
    ``inputs`` has ``squeeze(0)`` applied and the model output is expanded to
    three dimensions. A prediction counts as correct when ``labels[0][0]`` is
    among the ``k`` highest-scoring indices along the last dimension.
    NOTE(review): presumably one sample per item with logits expandable to
    ``(1, 1, num_classes)`` -- confirm against the dataset and model.

    Args:
        model: Callable mapping the prepared inputs to class scores.
        dataloader: Iterable of ``(inputs, labels)`` tensor pairs.
        device: Device the tensors are moved to.
        k: Number of top-scoring classes to consider; values larger than the
            size of the last output dimension are clamped to that size.

    Returns:
        Tuple ``(pred_correct, pred_all, accuracy)``; ``accuracy`` is ``0.0``
        when ``dataloader`` yields nothing.
    """
    pred_correct, pred_all = 0, 0

    # Only predictions are needed here, so skip autograd bookkeeping.
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.squeeze(0).to(device)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(inputs).expand(1, -1, -1)

            # torch.topk raises if asked for more entries than exist.
            top = min(k, outputs.size(-1))
            # The indices come back with the same rank as ``outputs``;
            # flatten them so the membership test compares ints with ints
            # instead of an int with nested lists (which never matches).
            top_indices = torch.topk(outputs, top).indices.flatten().tolist()

            if int(labels[0][0]) in top_indices:
                pred_correct += 1

            pred_all += 1

    accuracy = pred_correct / pred_all if pred_all else 0.0
    return pred_correct, pred_all, accuracy