# Copyright 2022-present NAVER Corp.
# CC BY-NC-SA 4.0
# Available only for non-commercial use

import pdb; bb = pdb.set_trace  # debugging shortcut: call bb() to drop into pdb
from tqdm import tqdm
from collections import defaultdict

import torch
import torch.nn as nn
from torch.nn import DataParallel

from .common import todevice


class Trainer (nn.Module):
    """ Helper class to train a deep network.
        Overload this class `forward_backward` for your actual needs.
    
    Usage: 
        train = Trainer(net, loss, optimizer)
        for epoch in range(n_epochs):
            train()
    """
    def __init__(self, net, loss, optimizer, epoch=0):
        super().__init__()
        self.net = net
        self.loss = loss
        self.optimizer = optimizer
        self.epoch = epoch

    @property
    def device(self):
        return next(self.net.parameters()).device

    @property
    def model(self):
        return self.net.module if isinstance(self.net, DataParallel) else self.net

    def distribute(self):
        self.net = DataParallel(self.net) # DistributedDataParallel not implemented yet

    def __call__(self, data_loader):
        print(f'>> Training (epoch {self.epoch} --> {self.epoch+1})')
        self.net.train()

        stats = defaultdict(list)

        for batch in tqdm(data_loader):
            batch = todevice(batch, self.device)
            
            # compute gradient and do model update
            self.optimizer.zero_grad()
            details = self.forward_backward(batch)
            self.optimizer.step()

            for key, val in details.items():
                stats[key].append( val )

        self.epoch += 1

        print("   Summary of losses during this epoch:")
        for loss_name, vals in stats.items():
            N = 1 + len(vals)//10  # compare the first ~10% and last ~10% of the epoch
            print(f"    - {loss_name:10}: {avg(vals[:N]):.3f} --> {avg(vals[-N:]):.3f} (avg: {avg(vals):.3f})")

    def forward_backward(self, inputs):
        raise NotImplementedError()

    def save(self, path):
        print(f"\n>> Saving model to {path}")

        data = {'model': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'loss': self.loss.state_dict(),
                'epoch': self.epoch}

        torch.save(data, path)

    def load(self, path, resume=True):
        print(f">> Loading weights from {path} ...")
        checkpoint = torch.load(path, map_location='cpu')
        assert isinstance(checkpoint, dict)

        self.net.load_state_dict(checkpoint['model'])
        if resume:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.loss.load_state_dict(checkpoint['loss'])
            self.epoch = checkpoint['epoch']
            print(f"   Resuming training at Epoch {self.epoch}!")


def get_loss( loss ):
    """ returns a tuple (loss, dictionary of loss details)
    """
    assert isinstance(loss, dict)
    grads = None

    k,l = next(iter(loss.items())) # first item is assumed to be the main loss
    if isinstance(l, tuple):
        l, grads = l
        loss[k] = l

    return (l, grads), {k:float(v) for k,v in loss.items()}


def backward( loss ):
    if isinstance(loss, tuple):
        loss, grads = loss
    else:
        loss, grads = (loss, None)

    assert loss == loss, 'loss is NaN'  # NaN is the only value not equal to itself

    if grads is None:
        loss.backward()
    else:
        # grads is an iterable of (variable, gradient) pairs, one per separate subgraph
        for var, grad in grads:
            var.backward(grad)
    return float(loss)


def avg( lis ):
    return sum(lis) / len(lis)
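

# ----------------------------------------------------------------------------
# Example only (not part of the original trainer): a minimal sketch of how
# `forward_backward` could be overloaded. The batch keys ('img', 'label') and
# the loss interface (a dict of named losses, main loss first) are assumptions.

class ExampleTrainer (Trainer):
    def forward_backward(self, inputs):
        # forward pass through the network (hypothetical batch layout)
        output = self.net(inputs['img'])
        loss_dict = self.loss(output, inputs['label'])  # e.g. {'total': ..., 'aux': ...}

        # split the main loss from the per-loss details and backpropagate
        loss, details = get_loss(loss_dict)
        backward(loss)  # gradients are applied by optimizer.step() in Trainer.__call__
        return details  # dict of floats, accumulated into the epoch statistics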