File size: 5,006 Bytes
2fbcf51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import timeit
import numpy as np
import os
import os.path as osp
import shutil
import copy
import torch
import torch.nn as nn
import torch.distributed as dist
from .cfg_holder import cfg_unique_holder as cfguh
from . import sync

# When True, processes whose local rank is not 0 return from print_log
# without printing anything.
print_console_local_rank0_only = True

def print_log(*console_info):
    """Print a message and, on local rank 0, append it to the log file.

    Every positional argument is converted with ``str`` and the pieces are
    joined with single spaces. Console output is skipped on non-zero local
    ranks when ``print_console_local_rank0_only`` is True. Only local rank 0
    ever appends to the log file.

    The log file path is taken from ``cfg.train.log_file``. If that lookup
    fails, it is taken from ``cfg.eval.log_file``. If both lookups fail, or
    the resolved path is None, the message is only printed.
    """
    local_rank = sync.get_rank('local')
    if print_console_local_rank0_only and local_rank != 0:
        return
    console_info = ' '.join(str(i) for i in console_info)
    print(console_info)

    # Even when every rank prints to the console, only local rank 0 writes
    # to the shared log file.
    if local_rank != 0:
        return

    log_file = None
    # Best-effort lookup: a missing train (or eval) section must not break
    # logging. Catch Exception rather than using a bare except, so that
    # KeyboardInterrupt / SystemExit still propagate.
    try:
        log_file = cfguh().cfg.train.log_file
    except Exception:
        try:
            log_file = cfguh().cfg.eval.log_file
        except Exception:
            return
    if log_file is not None:
        with open(log_file, 'a') as f:
            f.write(console_info + '\n')

class distributed_log_manager(object):
    """Accumulate per-item training statistics and report them.

    Values passed to ``accumulate`` are stored as count-weighted sums, so
    ``get_mean_value_dict`` can return per-item means. Under DDP those means
    are all-reduced (summed) and divided by ``self.world_size``. On local
    rank 0, scalars are optionally mirrored to TensorBoard through
    tensorboardX, enabled by ``cfg.train.log_tensorboard``.
    """

    def __init__(self):
        self.sum = {}
        self.cnt = {}
        self.time_check = timeit.default_timer()

        cfgt = cfguh().cfg.train
        use_tensorboard = getattr(cfgt, 'log_tensorboard', False)

        self.ddp = sync.is_ddp()
        self.rank = sync.get_rank('local')
        self.world_size = sync.get_world_size('local')

        self.tb = None
        if use_tensorboard and (self.rank == 0):
            # Imported lazily, so tensorboardX is only required when it is used.
            import tensorboardX
            monitoring_dir = osp.join(cfgt.log_dir, 'tensorboard')
            self.tb = tensorboardX.SummaryWriter(monitoring_dir)

    def accumulate(self, n, **data):
        """Add ``n`` samples with per-sample values ``data`` to the running sums.

        Each keyword ``item=value`` adds ``value * n`` to that item's sum and
        ``n`` to its count.

        Raises:
            ValueError: if ``n`` is negative.
        """
        if n < 0:
            raise ValueError(
                'accumulate() expects a non-negative sample count, '
                'got {}'.format(n))

        for itemn, di in data.items():
            if itemn in self.sum:
                self.sum[itemn] += di * n
                self.cnt[itemn] += n
            else:
                self.sum[itemn] = di * n
                self.cnt[itemn] = n

    def get_mean_value_dict(self):
        """Return ``{item: mean}`` for everything accumulated since the last clear.

        Under DDP, the per-rank means are summed with ``all_reduce`` and
        divided by ``self.world_size``, so every rank's mean gets equal weight
        regardless of its sample count. All ranks must hold the same set of
        items, because the values are reduced position by position in sorted
        key order.
        """
        itemns = sorted(self.sum.keys())
        # float() also handles 0-d tensors, including CUDA tensors.
        value_gather = [
            float(self.sum[itemn] / self.cnt[itemn]) for itemn in itemns]

        # Use this process's GPU when CUDA exists (collectives such as NCCL
        # all_reduce need CUDA tensors); otherwise stay on CPU instead of
        # failing on a CPU-only machine.
        device = self.rank if torch.cuda.is_available() else 'cpu'
        value_gather_tensor = torch.tensor(
            value_gather, dtype=torch.float32, device=device)
        if self.ddp:
            dist.all_reduce(value_gather_tensor, op=dist.ReduceOp.SUM)
            # NOTE(review): all_reduce uses the default process group, but the
            # divisor is the *local* world size. Confirm these agree for
            # multi-node runs.
            value_gather_tensor /= self.world_size

        return {
            itemn: value_gather_tensor[idx].item()
            for idx, itemn in enumerate(itemns)}

    def tensorboard_log(self, step, data, mode='train', **extra):
        """Write scalars to TensorBoard. Does nothing when TensorBoard is disabled.

        In ``'train'`` mode:
        - ``extra['epochn']`` and ``extra['lr']`` are logged under ``other/``
          when present and not None.
        - Items of ``data`` whose name starts with ``'loss'`` go under ``loss/``.
        - The item named ``'Loss'`` is logged at top level.
        - All other items go under ``other/``.

        In ``'eval'`` mode, a dict is logged item by item under ``eval/``, and
        any other value is logged as ``eval``. Other modes are ignored.
        """
        if self.tb is None:
            return
        if mode == 'train':
            epochn = extra.get('epochn')
            if epochn is not None:
                self.tb.add_scalar('other/epochn', epochn, step)
            # train_summary always forwards lr, possibly as None. tensorboardX
            # cannot log a None scalar, so skip it.
            lr = extra.get('lr')
            if lr is not None:
                self.tb.add_scalar('other/lr', lr, step)
            for itemn, di in data.items():
                if itemn.startswith('loss'):
                    self.tb.add_scalar('loss/' + itemn, di, step)
                elif itemn == 'Loss':
                    self.tb.add_scalar('Loss', di, step)
                else:
                    self.tb.add_scalar('other/' + itemn, di, step)
        elif mode == 'eval':
            if isinstance(data, dict):
                for itemn, di in data.items():
                    self.tb.add_scalar('eval/' + itemn, di, step)
            else:
                self.tb.add_scalar('eval', data, step)
        return

    def train_summary(self, itern, epochn, samplen, lr, tbstep=None):
        """Log the current means to TensorBoard and return a one-line summary.

        The returned string joins its fields with ``' , '``:
        - iteration, epoch and sample count;
        - the learning rate, when ``lr`` is not None;
        - ``Loss``, when it was accumulated;
        - every item whose name starts with ``'loss'``;
        - the seconds elapsed since construction or the last ``clear``.

        ``tbstep`` defaults to ``itern``.
        """
        console_info = [
            'Iter:{}'.format(itern),
            'Epoch:{}'.format(epochn),
            'Sample:{}'.format(samplen),
        ]

        if lr is not None:
            console_info.append('LR:{:.4E}'.format(lr))

        mean = self.get_mean_value_dict()

        tbstep = itern if tbstep is None else tbstep
        self.tensorboard_log(
            tbstep, mean, mode='train',
            itern=itern, epochn=epochn, lr=lr)

        # The total 'Loss' is printed first. Tolerate it being absent instead
        # of raising KeyError.
        loss = mean.pop('Loss', None)
        if loss is not None:
            console_info.append('Loss:{:.4f}'.format(loss))
        console_info += [
            '{}:{:.4f}'.format(itemn, mean[itemn])
            for itemn in sorted(mean) if itemn.startswith('loss')]
        console_info.append('Time:{:.2f}s'.format(
            timeit.default_timer() - self.time_check))
        return ' , '.join(console_info)

    def clear(self):
        """Drop all accumulated sums and counts and restart the timer."""
        self.sum = {}
        self.cnt = {}
        self.time_check = timeit.default_timer()

    def tensorboard_close(self):
        """Close the TensorBoard writer if one was created."""
        if self.tb is not None:
            self.tb.close()

# ----- also include some small utils -----

def torch_to_numpy(*argv):
    """Recursively convert torch tensors to numpy arrays.

    With several positional arguments, they are treated as one list. With a
    single argument, that argument itself is converted. Conversion rules:
    - tensors are detached and copied to CPU before conversion;
    - lists and tuples become lists of converted elements;
    - dicts keep their keys and have their values converted;
    - any other value is returned unchanged.
    """
    data = list(argv) if len(argv) > 1 else argv[0]

    if isinstance(data, torch.Tensor):
        # Detach before the device copy so the copy is not recorded in the
        # autograd graph.
        return data.detach().cpu().numpy()

    if isinstance(data, (list, tuple)):
        return [torch_to_numpy(di) for di in data]

    if isinstance(data, dict):
        return {ni: torch_to_numpy(di) for ni, di in data.items()}

    return data