Vincentqyw
fix: roma
8b973ee
raw
history blame
2.15 kB
# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use
import pdb
from tqdm import tqdm
from collections import defaultdict
import torch
import torch.nn as nn
class Trainer(nn.Module):
    """Helper class to train a deep network, one epoch per call.

    Subclasses must override `forward_backward` with the actual forward
    pass, loss computation and backward pass.

    Usage:
        train = Trainer(net, loader, loss, optimizer)
        for epoch in range(n_epochs):
            train()
    """

    def __init__(self, net, loader, loss, optimizer):
        """Store the network, the batch iterable, the loss and the optimizer."""
        nn.Module.__init__(self)
        self.net = net
        self.loader = loader
        self.loss_func = loss
        self.optimizer = optimizer

    def iscuda(self):
        """Return True when the network's first parameter is not on the CPU.

        A network without any parameter is reported as living on the CPU
        rather than raising StopIteration.
        """
        first_param = next(self.net.parameters(), None)
        # Compare the device *type* so that an indexed CPU device is still
        # recognized as CPU.
        return first_param is not None and first_param.device.type != "cpu"

    def todevice(self, x):
        """Recursively move `x` to the side (GPU or CPU) the network is on.

        Dicts are rebuilt with the same keys, tuples and lists come back as
        lists, and tensors are moved. Any other value (str, int, None, ...)
        has no device and is returned unchanged.
        """
        if isinstance(x, dict):
            return {k: self.todevice(v) for k, v in x.items()}
        if isinstance(x, (tuple, list)):
            return [self.todevice(v) for v in x]
        if not isinstance(x, torch.Tensor):
            # Metadata produced by the collate function (names, indices...).
            return x
        if self.iscuda():
            return x.contiguous().cuda(non_blocking=True)
        return x.cpu()

    @staticmethod
    def _mean(values):
        """Return the arithmetic mean of a non-empty sequence."""
        return sum(values) / len(values)

    @staticmethod
    def _detached(value):
        """Return `value` cut from the autograd graph when it is a tensor."""
        if isinstance(value, torch.Tensor):
            return value.detach()
        return value

    def __call__(self):
        """Run one training pass over `self.loader`.

        For every batch: move it to the network's device, reset the
        gradients, call `forward_backward` and apply an optimizer step.
        A per-statistic summary is printed once all batches are processed.

        Returns:
            The average of the "loss" statistic over all batches. When
            `forward_backward` does not report a "loss" entry in its
            details, the loss value it returned is averaged instead.

        Raises:
            RuntimeError: if the loss of a batch is NaN (raised before the
                optimizer step, so the weights are not updated with it).
            ZeroDivisionError: if the loader yields no batch at all.
        """
        self.net.train()
        stats = defaultdict(list)

        for inputs in tqdm(self.loader):
            inputs = self.todevice(inputs)

            # compute gradient and do model update
            self.optimizer.zero_grad()
            loss, details = self.forward_backward(inputs)
            if torch.isnan(loss):
                raise RuntimeError("Loss is NaN")
            self.optimizer.step()

            for key, val in details.items():
                # Detach so a whole epoch of statistics does not keep every
                # batch's computation graph alive.
                stats[key].append(self._detached(val))
            if "loss" not in details:
                # The returned average relies on a "loss" entry; fall back
                # to the loss returned by `forward_backward`.
                stats["loss"].append(float(loss))

        print(" Summary of losses during this epoch:")
        for loss_name, vals in stats.items():
            # Compare the first and the last ~10% of the epoch.
            N = 1 + len(vals) // 10
            print(f" - {loss_name:20}:", end="")
            print(
                f" {self._mean(vals[:N]):.3f} --> {self._mean(vals[-N:]):.3f} (avg: {self._mean(vals):.3f})"
            )
        return self._mean(stats["loss"])  # return average loss

    def forward_backward(self, inputs):
        """Compute the loss for `inputs` and back-propagate it.

        Must be overridden. `__call__` expects a `(loss, details)` pair
        where `loss` is accepted by `torch.isnan` and `details` is a dict
        mapping statistic names to per-batch values.
        """
        raise NotImplementedError()