i72sijia commited on
Commit
d44bc96
1 Parent(s): 6778286

Upload training_stats.py

Browse files
Files changed (1) hide show
  1. torch_utils/training_stats.py +268 -0
torch_utils/training_stats.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Facilities for reporting and collecting training statistics across
10
+ multiple processes and devices. The interface is designed to minimize
11
+ synchronization overhead as well as the amount of boilerplate in user
12
+ code."""
13
+
14
+ import re
15
+ import numpy as np
16
+ import torch
17
+ import dnnlib
18
+
19
+ from . import misc
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
+ _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
24
+ _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
25
+ _counter_dtype = torch.float64 # Data type to use for the internal counters.
26
+ _rank = 0 # Rank of the current process.
27
+ _sync_device = None # Device to use for multiprocess communication. None = single-process.
28
+ _sync_called = False # Has _sync() been called yet?
29
+ _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
30
+ _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
31
+
32
+ #----------------------------------------------------------------------------
33
+
34
+ def init_multiprocessing(rank, sync_device):
35
+ r"""Initializes `torch_utils.training_stats` for collecting statistics
36
+ across multiple processes.
37
+
38
+ This function must be called after
39
+ `torch.distributed.init_process_group()` and before `Collector.update()`.
40
+ The call is not necessary if multi-process collection is not needed.
41
+
42
+ Args:
43
+ rank: Rank of the current process.
44
+ sync_device: PyTorch device to use for inter-process
45
+ communication, or None to disable multi-process
46
+ collection. Typically `torch.device('cuda', rank)`.
47
+ """
48
+ global _rank, _sync_device
49
+ assert not _sync_called
50
+ _rank = rank
51
+ _sync_device = sync_device
52
+
53
+ #----------------------------------------------------------------------------
54
+
55
+ @misc.profiled_function
56
+ def report(name, value):
57
+ r"""Broadcasts the given set of scalars to all interested instances of
58
+ `Collector`, across device and process boundaries.
59
+
60
+ This function is expected to be extremely cheap and can be safely
61
+ called from anywhere in the training loop, loss function, or inside a
62
+ `torch.nn.Module`.
63
+
64
+ Warning: The current implementation expects the set of unique names to
65
+ be consistent across processes. Please make sure that `report()` is
66
+ called at least once for each unique name by each process, and in the
67
+ same order. If a given process has no scalars to broadcast, it can do
68
+ `report(name, [])` (empty list).
69
+
70
+ Args:
71
+ name: Arbitrary string specifying the name of the statistic.
72
+ Averages are accumulated separately for each unique name.
73
+ value: Arbitrary set of scalars. Can be a list, tuple,
74
+ NumPy array, PyTorch tensor, or Python scalar.
75
+
76
+ Returns:
77
+ The same `value` that was passed in.
78
+ """
79
+ if name not in _counters:
80
+ _counters[name] = dict()
81
+
82
+ elems = torch.as_tensor(value)
83
+ if elems.numel() == 0:
84
+ return value
85
+
86
+ elems = elems.detach().flatten().to(_reduce_dtype)
87
+ moments = torch.stack([
88
+ torch.ones_like(elems).sum(),
89
+ elems.sum(),
90
+ elems.square().sum(),
91
+ ])
92
+ assert moments.ndim == 1 and moments.shape[0] == _num_moments
93
+ moments = moments.to(_counter_dtype)
94
+
95
+ device = moments.device
96
+ if device not in _counters[name]:
97
+ _counters[name][device] = torch.zeros_like(moments)
98
+ _counters[name][device].add_(moments)
99
+ return value
100
+
101
+ #----------------------------------------------------------------------------
102
+
103
+ def report0(name, value):
104
+ r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
105
+ but ignores any scalars provided by the other processes.
106
+ See `report()` for further details.
107
+ """
108
+ report(name, value if _rank == 0 else [])
109
+ return value
110
+
111
+ #----------------------------------------------------------------------------
112
+
113
+ class Collector:
114
+ r"""Collects the scalars broadcasted by `report()` and `report0()` and
115
+ computes their long-term averages (mean and standard deviation) over
116
+ user-defined periods of time.
117
+
118
+ The averages are first collected into internal counters that are not
119
+ directly visible to the user. They are then copied to the user-visible
120
+ state as a result of calling `update()` and can then be queried using
121
+ `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
122
+ internal counters for the next round, so that the user-visible state
123
+ effectively reflects averages collected between the last two calls to
124
+ `update()`.
125
+
126
+ Args:
127
+ regex: Regular expression defining which statistics to
128
+ collect. The default is to collect everything.
129
+ keep_previous: Whether to retain the previous averages if no
130
+ scalars were collected on a given round
131
+ (default: True).
132
+ """
133
+ def __init__(self, regex='.*', keep_previous=True):
134
+ self._regex = re.compile(regex)
135
+ self._keep_previous = keep_previous
136
+ self._cumulative = dict()
137
+ self._moments = dict()
138
+ self.update()
139
+ self._moments.clear()
140
+
141
+ def names(self):
142
+ r"""Returns the names of all statistics broadcasted so far that
143
+ match the regular expression specified at construction time.
144
+ """
145
+ return [name for name in _counters if self._regex.fullmatch(name)]
146
+
147
+ def update(self):
148
+ r"""Copies current values of the internal counters to the
149
+ user-visible state and resets them for the next round.
150
+
151
+ If `keep_previous=True` was specified at construction time, the
152
+ operation is skipped for statistics that have received no scalars
153
+ since the last update, retaining their previous averages.
154
+
155
+ This method performs a number of GPU-to-CPU transfers and one
156
+ `torch.distributed.all_reduce()`. It is intended to be called
157
+ periodically in the main training loop, typically once every
158
+ N training steps.
159
+ """
160
+ if not self._keep_previous:
161
+ self._moments.clear()
162
+ for name, cumulative in _sync(self.names()):
163
+ if name not in self._cumulative:
164
+ self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
165
+ delta = cumulative - self._cumulative[name]
166
+ self._cumulative[name].copy_(cumulative)
167
+ if float(delta[0]) != 0:
168
+ self._moments[name] = delta
169
+
170
+ def _get_delta(self, name):
171
+ r"""Returns the raw moments that were accumulated for the given
172
+ statistic between the last two calls to `update()`, or zero if
173
+ no scalars were collected.
174
+ """
175
+ assert self._regex.fullmatch(name)
176
+ if name not in self._moments:
177
+ self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
178
+ return self._moments[name]
179
+
180
+ def num(self, name):
181
+ r"""Returns the number of scalars that were accumulated for the given
182
+ statistic between the last two calls to `update()`, or zero if
183
+ no scalars were collected.
184
+ """
185
+ delta = self._get_delta(name)
186
+ return int(delta[0])
187
+
188
+ def mean(self, name):
189
+ r"""Returns the mean of the scalars that were accumulated for the
190
+ given statistic between the last two calls to `update()`, or NaN if
191
+ no scalars were collected.
192
+ """
193
+ delta = self._get_delta(name)
194
+ if int(delta[0]) == 0:
195
+ return float('nan')
196
+ return float(delta[1] / delta[0])
197
+
198
+ def std(self, name):
199
+ r"""Returns the standard deviation of the scalars that were
200
+ accumulated for the given statistic between the last two calls to
201
+ `update()`, or NaN if no scalars were collected.
202
+ """
203
+ delta = self._get_delta(name)
204
+ if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
205
+ return float('nan')
206
+ if int(delta[0]) == 1:
207
+ return float(0)
208
+ mean = float(delta[1] / delta[0])
209
+ raw_var = float(delta[2] / delta[0])
210
+ return np.sqrt(max(raw_var - np.square(mean), 0))
211
+
212
+ def as_dict(self):
213
+ r"""Returns the averages accumulated between the last two calls to
214
+ `update()` as an `dnnlib.EasyDict`. The contents are as follows:
215
+
216
+ dnnlib.EasyDict(
217
+ NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
218
+ ...
219
+ )
220
+ """
221
+ stats = dnnlib.EasyDict()
222
+ for name in self.names():
223
+ stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
224
+ return stats
225
+
226
+ def __getitem__(self, name):
227
+ r"""Convenience getter.
228
+ `collector[name]` is a synonym for `collector.mean(name)`.
229
+ """
230
+ return self.mean(name)
231
+
232
+ #----------------------------------------------------------------------------
233
+
234
+ def _sync(names):
235
+ r"""Synchronize the global cumulative counters across devices and
236
+ processes. Called internally by `Collector.update()`.
237
+ """
238
+ if len(names) == 0:
239
+ return []
240
+ global _sync_called
241
+ _sync_called = True
242
+
243
+ # Collect deltas within current rank.
244
+ deltas = []
245
+ device = _sync_device if _sync_device is not None else torch.device('cpu')
246
+ for name in names:
247
+ delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
248
+ for counter in _counters[name].values():
249
+ delta.add_(counter.to(device))
250
+ counter.copy_(torch.zeros_like(counter))
251
+ deltas.append(delta)
252
+ deltas = torch.stack(deltas)
253
+
254
+ # Sum deltas across ranks.
255
+ if _sync_device is not None:
256
+ torch.distributed.all_reduce(deltas)
257
+
258
+ # Update cumulative values.
259
+ deltas = deltas.cpu()
260
+ for idx, name in enumerate(names):
261
+ if name not in _cumulative:
262
+ _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
263
+ _cumulative[name].add_(deltas[idx])
264
+
265
+ # Return name-value pairs.
266
+ return [(name, _cumulative[name]) for name in names]
267
+
268
+ #----------------------------------------------------------------------------