doevent commited on
Commit
c74fb4f
1 Parent(s): 7a69f1b

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +278 -0
utils.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
3
+ """Decay the learning rate"""
4
+ lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr
5
+ for param_group in optimizer.param_groups:
6
+ param_group['lr'] = lr
7
+
8
+ def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
9
+ """Warmup the learning rate"""
10
+ lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step)
11
+ for param_group in optimizer.param_groups:
12
+ param_group['lr'] = lr
13
+
14
+ def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
15
+ """Decay the learning rate"""
16
+ lr = max(min_lr, init_lr * (decay_rate**epoch))
17
+ for param_group in optimizer.param_groups:
18
+ param_group['lr'] = lr
19
+
20
+ import numpy as np
21
+ import io
22
+ import os
23
+ import time
24
+ from collections import defaultdict, deque
25
+ import datetime
26
+
27
+ import torch
28
+ import torch.distributed as dist
29
+
30
+ class SmoothedValue(object):
31
+ """Track a series of values and provide access to smoothed values over a
32
+ window or the global series average.
33
+ """
34
+
35
+ def __init__(self, window_size=20, fmt=None):
36
+ if fmt is None:
37
+ fmt = "{median:.4f} ({global_avg:.4f})"
38
+ self.deque = deque(maxlen=window_size)
39
+ self.total = 0.0
40
+ self.count = 0
41
+ self.fmt = fmt
42
+
43
+ def update(self, value, n=1):
44
+ self.deque.append(value)
45
+ self.count += n
46
+ self.total += value * n
47
+
48
+ def synchronize_between_processes(self):
49
+ """
50
+ Warning: does not synchronize the deque!
51
+ """
52
+ if not is_dist_avail_and_initialized():
53
+ return
54
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
55
+ dist.barrier()
56
+ dist.all_reduce(t)
57
+ t = t.tolist()
58
+ self.count = int(t[0])
59
+ self.total = t[1]
60
+
61
+ @property
62
+ def median(self):
63
+ d = torch.tensor(list(self.deque))
64
+ return d.median().item()
65
+
66
+ @property
67
+ def avg(self):
68
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
69
+ return d.mean().item()
70
+
71
+ @property
72
+ def global_avg(self):
73
+ return self.total / self.count
74
+
75
+ @property
76
+ def max(self):
77
+ return max(self.deque)
78
+
79
+ @property
80
+ def value(self):
81
+ return self.deque[-1]
82
+
83
+ def __str__(self):
84
+ return self.fmt.format(
85
+ median=self.median,
86
+ avg=self.avg,
87
+ global_avg=self.global_avg,
88
+ max=self.max,
89
+ value=self.value)
90
+
91
+
92
+ class MetricLogger(object):
93
+ def __init__(self, delimiter="\t"):
94
+ self.meters = defaultdict(SmoothedValue)
95
+ self.delimiter = delimiter
96
+
97
+ def update(self, **kwargs):
98
+ for k, v in kwargs.items():
99
+ if isinstance(v, torch.Tensor):
100
+ v = v.item()
101
+ assert isinstance(v, (float, int))
102
+ self.meters[k].update(v)
103
+
104
+ def __getattr__(self, attr):
105
+ if attr in self.meters:
106
+ return self.meters[attr]
107
+ if attr in self.__dict__:
108
+ return self.__dict__[attr]
109
+ raise AttributeError("'{}' object has no attribute '{}'".format(
110
+ type(self).__name__, attr))
111
+
112
+ def __str__(self):
113
+ loss_str = []
114
+ for name, meter in self.meters.items():
115
+ loss_str.append(
116
+ "{}: {}".format(name, str(meter))
117
+ )
118
+ return self.delimiter.join(loss_str)
119
+
120
+ def global_avg(self):
121
+ loss_str = []
122
+ for name, meter in self.meters.items():
123
+ loss_str.append(
124
+ "{}: {:.4f}".format(name, meter.global_avg)
125
+ )
126
+ return self.delimiter.join(loss_str)
127
+
128
+ def synchronize_between_processes(self):
129
+ for meter in self.meters.values():
130
+ meter.synchronize_between_processes()
131
+
132
+ def add_meter(self, name, meter):
133
+ self.meters[name] = meter
134
+
135
+ def log_every(self, iterable, print_freq, header=None):
136
+ i = 0
137
+ if not header:
138
+ header = ''
139
+ start_time = time.time()
140
+ end = time.time()
141
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
142
+ data_time = SmoothedValue(fmt='{avg:.4f}')
143
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
144
+ log_msg = [
145
+ header,
146
+ '[{0' + space_fmt + '}/{1}]',
147
+ 'eta: {eta}',
148
+ '{meters}',
149
+ 'time: {time}',
150
+ 'data: {data}'
151
+ ]
152
+ if torch.cuda.is_available():
153
+ log_msg.append('max mem: {memory:.0f}')
154
+ log_msg = self.delimiter.join(log_msg)
155
+ MB = 1024.0 * 1024.0
156
+ for obj in iterable:
157
+ data_time.update(time.time() - end)
158
+ yield obj
159
+ iter_time.update(time.time() - end)
160
+ if i % print_freq == 0 or i == len(iterable) - 1:
161
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
162
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
163
+ if torch.cuda.is_available():
164
+ print(log_msg.format(
165
+ i, len(iterable), eta=eta_string,
166
+ meters=str(self),
167
+ time=str(iter_time), data=str(data_time),
168
+ memory=torch.cuda.max_memory_allocated() / MB))
169
+ else:
170
+ print(log_msg.format(
171
+ i, len(iterable), eta=eta_string,
172
+ meters=str(self),
173
+ time=str(iter_time), data=str(data_time)))
174
+ i += 1
175
+ end = time.time()
176
+ total_time = time.time() - start_time
177
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
178
+ print('{} Total time: {} ({:.4f} s / it)'.format(
179
+ header, total_time_str, total_time / len(iterable)))
180
+
181
+
182
+ class AttrDict(dict):
183
+ def __init__(self, *args, **kwargs):
184
+ super(AttrDict, self).__init__(*args, **kwargs)
185
+ self.__dict__ = self
186
+
187
+
188
+ def compute_acc(logits, label, reduction='mean'):
189
+ ret = (torch.argmax(logits, dim=1) == label).float()
190
+ if reduction == 'none':
191
+ return ret.detach()
192
+ elif reduction == 'mean':
193
+ return ret.mean().item()
194
+
195
+ def compute_n_params(model, return_str=True):
196
+ tot = 0
197
+ for p in model.parameters():
198
+ w = 1
199
+ for x in p.shape:
200
+ w *= x
201
+ tot += w
202
+ if return_str:
203
+ if tot >= 1e6:
204
+ return '{:.1f}M'.format(tot / 1e6)
205
+ else:
206
+ return '{:.1f}K'.format(tot / 1e3)
207
+ else:
208
+ return tot
209
+
210
+ def setup_for_distributed(is_master):
211
+ """
212
+ This function disables printing when not in master process
213
+ """
214
+ import builtins as __builtin__
215
+ builtin_print = __builtin__.print
216
+
217
+ def print(*args, **kwargs):
218
+ force = kwargs.pop('force', False)
219
+ if is_master or force:
220
+ builtin_print(*args, **kwargs)
221
+
222
+ __builtin__.print = print
223
+
224
+
225
+ def is_dist_avail_and_initialized():
226
+ if not dist.is_available():
227
+ return False
228
+ if not dist.is_initialized():
229
+ return False
230
+ return True
231
+
232
+
233
+ def get_world_size():
234
+ if not is_dist_avail_and_initialized():
235
+ return 1
236
+ return dist.get_world_size()
237
+
238
+
239
+ def get_rank():
240
+ if not is_dist_avail_and_initialized():
241
+ return 0
242
+ return dist.get_rank()
243
+
244
+
245
+ def is_main_process():
246
+ return get_rank() == 0
247
+
248
+
249
+ def save_on_master(*args, **kwargs):
250
+ if is_main_process():
251
+ torch.save(*args, **kwargs)
252
+
253
+
254
+ def init_distributed_mode(args):
255
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
256
+ args.rank = int(os.environ["RANK"])
257
+ args.world_size = int(os.environ['WORLD_SIZE'])
258
+ args.gpu = int(os.environ['LOCAL_RANK'])
259
+ elif 'SLURM_PROCID' in os.environ:
260
+ args.rank = int(os.environ['SLURM_PROCID'])
261
+ args.gpu = args.rank % torch.cuda.device_count()
262
+ else:
263
+ print('Not using distributed mode')
264
+ args.distributed = False
265
+ return
266
+
267
+ args.distributed = True
268
+
269
+ torch.cuda.set_device(args.gpu)
270
+ args.dist_backend = 'nccl'
271
+ print('| distributed init (rank {}, word {}): {}'.format(
272
+ args.rank, args.world_size, args.dist_url), flush=True)
273
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
274
+ world_size=args.world_size, rank=args.rank)
275
+ torch.distributed.barrier()
276
+ setup_for_distributed(args.rank == 0)
277
+
278
+