Spaces:
Running
Running
# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use
import pdb | |
from tqdm import tqdm | |
from collections import defaultdict | |
import torch | |
import torch.nn as nn | |
class Trainer(nn.Module):
    """Helper class to train a deep network for one epoch per call.

    Overload this class `forward_backward` for your actual needs.

    Calling the trainer (``trainer()``) runs one full pass over ``loader``;
    ``__call__`` is overridden, so it does NOT dispatch to ``forward`` as a
    regular ``nn.Module`` would.

    Usage:
        train = Trainer(net, loader, loss, optimizer)
        for epoch in range(n_epochs):
            train()
    """

    def __init__(self, net, loader, loss, optimizer):
        """Store the training components.

        Args:
            net: network being trained; its first parameter decides on which
                device the inputs are placed (see `iscuda`).
            loader: iterable yielding one batch of inputs per iteration.
            loss: loss callable, stored as ``self.loss_func`` for use by
                `forward_backward` overrides.
            optimizer: object exposing ``zero_grad()`` and ``step()``.
        """
        nn.Module.__init__(self)
        self.net = net
        self.loader = loader
        self.loss_func = loss
        self.optimizer = optimizer

    def iscuda(self):
        """Return True when the network's first parameter is not on the CPU.

        A network without any parameter is reported as living on the CPU
        instead of leaking a bare ``StopIteration`` to the caller.
        """
        first_param = next(self.net.parameters(), None)
        if first_param is None:
            return False
        return first_param.device != torch.device("cpu")

    def todevice(self, x):
        """Recursively move the tensors contained in `x` next to the network.

        Args:
            x: a tensor, or an arbitrarily nested dict / tuple / list of them.

        Returns:
            The same structure with every tensor moved (``.cuda()`` when
            `iscuda` is true, ``.cpu()`` otherwise). Dicts are rebuilt as
            dicts; tuples and lists both come back as lists. Leaves that are
            not tensors (strings, numbers, None, ...) are returned unchanged.
        """
        if isinstance(x, dict):
            return {k: self.todevice(v) for k, v in x.items()}
        if isinstance(x, (tuple, list)):
            return [self.todevice(v) for v in x]
        if not isinstance(x, torch.Tensor):
            # Batch metadata has no .cpu()/.cuda(); nothing to transfer.
            return x

        if self.iscuda():
            return x.contiguous().cuda(non_blocking=True)
        return x.cpu()

    def __call__(self):
        """Train for one epoch and print a per-loss summary.

        For every batch: move it to the network's device, reset the
        gradients, call `forward_backward`, then step the optimizer. This
        method never calls ``backward()`` itself, so the `forward_backward`
        override is expected to do it.

        Returns:
            The average of the ``"loss"`` entries reported in the `details`
            dicts over the epoch.

        Raises:
            RuntimeError: if a batch produces a NaN loss (raised before the
                optimizer step for that batch).
            ZeroDivisionError: if no ``"loss"`` entry was recorded, i.e. the
                loader was empty or `details` never contained a "loss" key.
        """

        def mean(values):
            return sum(values) / len(values)

        self.net.train()
        stats = defaultdict(list)

        for inputs in tqdm(self.loader):
            inputs = self.todevice(inputs)

            # compute gradient and do model update
            self.optimizer.zero_grad()
            loss, details = self.forward_backward(inputs)
            if torch.isnan(loss):
                raise RuntimeError("Loss is NaN")
            self.optimizer.step()

            for key, val in details.items():
                stats[key].append(val)

        print(" Summary of losses during this epoch:")
        for loss_name, vals in stats.items():
            # Compare the first and last ~10% of the iterations (at least one
            # value each) to show the trend over the epoch.
            N = 1 + len(vals) // 10
            print(f" - {loss_name:20}:", end="")
            print(
                f" {mean(vals[:N]):.3f} --> {mean(vals[-N:]):.3f} (avg: {mean(vals):.3f})"
            )

        return mean(stats["loss"])  # return average loss

    def forward_backward(self, inputs):
        """Process one batch; must be overridden by subclasses.

        Args:
            inputs: one batch from the loader, already passed through
                `todevice`.

        Returns:
            A ``(loss, details)`` pair as consumed by `__call__`: `loss` is
            checked with ``torch.isnan`` and `details` is a mapping from loss
            name to value, from which the ``"loss"`` entry is averaged.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()