# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use

import pdb  # NOTE(review): debug-only import, unused in this file; kept deliberately

from collections import defaultdict

from tqdm import tqdm

import torch
import torch.nn as nn


class Trainer(nn.Module):
    """Helper class to train a deep network.

    Overload `forward_backward` for your actual needs.

    Usage:
        train = Trainer(net, loader, loss, optimizer)
        for epoch in range(n_epochs):
            train()
    """

    def __init__(self, net, loader, loss, optimizer):
        """Store the training components.

        Args:
            net: the model being trained (an ``nn.Module``).
            loader: iterable yielding input batches (dicts/lists/tuples of tensors).
            loss: loss callable, kept as ``self.loss_func`` for subclasses.
            optimizer: optimizer whose ``zero_grad``/``step`` drive each update.
        """
        nn.Module.__init__(self)
        self.net = net
        self.loader = loader
        self.loss_func = loss
        self.optimizer = optimizer

    def iscuda(self):
        # True whenever the network's parameters live on a non-CPU device.
        # NOTE(review): any non-CPU device is treated as CUDA by todevice(),
        # which then calls .cuda() — would misbehave on e.g. MPS; confirm
        # CUDA-only deployment.
        return next(self.net.parameters()).device != torch.device("cpu")

    def todevice(self, x):
        """Recursively move tensors nested in dicts/lists/tuples to the net's device.

        Note: tuples are returned as lists (preserved original behavior).
        """
        if isinstance(x, dict):
            return {k: self.todevice(v) for k, v in x.items()}
        if isinstance(x, (tuple, list)):
            return [self.todevice(v) for v in x]
        if self.iscuda():
            return x.contiguous().cuda(non_blocking=True)
        else:
            return x.cpu()

    def __call__(self):
        """Run one training epoch over ``self.loader``.

        Returns the mean of the per-iteration ``details['loss']`` values.
        Raises RuntimeError as soon as the loss becomes NaN.
        """
        self.net.train()

        stats = defaultdict(list)

        # Renamed loop: the original used `iter` (shadowing the builtin) via
        # enumerate, and never used the index.
        for inputs in tqdm(self.loader):
            inputs = self.todevice(inputs)

            # compute gradient and do model update
            self.optimizer.zero_grad()
            loss, details = self.forward_backward(inputs)
            if torch.isnan(loss):
                raise RuntimeError("Loss is NaN")
            self.optimizer.step()

            for key, val in details.items():
                stats[key].append(val)

        print(" Summary of losses during this epoch:")

        def mean(lis):
            # PEP 8 (E731): a def instead of an assigned lambda.
            return sum(lis) / len(lis)

        for loss_name, vals in stats.items():
            # N = ~10% of the epoch; compare the first-N vs last-N averages
            # to show the loss trend across the epoch.
            N = 1 + len(vals) // 10
            print(f" - {loss_name:20}:", end="")
            print(
                f" {mean(vals[:N]):.3f} --> {mean(vals[-N:]):.3f} (avg: {mean(vals):.3f})"
            )

        return mean(stats["loss"])  # return average loss

    def forward_backward(self, inputs):
        """Compute the loss for `inputs`, call backward(), and return
        ``(loss, details)`` where details maps stat names to scalars.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()