i72sijia commited on
Commit
664be34
1 Parent(s): 8a6a841

Upload optimizer.py

Browse files
Files changed (1) hide show
  1. dnnlib/tflib/optimizer.py +372 -0
dnnlib/tflib/optimizer.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Helper wrapper for a Tensorflow optimizer."""
10
+
11
+ import platform
12
+ import numpy as np
13
+ import tensorflow as tf
14
+
15
+ from collections import OrderedDict
16
+ from typing import List, Union
17
+
18
+ from . import autosummary
19
+ from . import tfutil
20
+ from .. import util
21
+
22
+ from .tfutil import TfExpression, TfExpressionEx
23
+
24
+ _collective_ops_warning_printed = False
25
+ _collective_ops_group_key = 831766147
26
+ _collective_ops_instance_key = 436340067
27
+
28
+ class Optimizer:
29
+ """A Wrapper for tf.train.Optimizer.
30
+
31
+ Automatically takes care of:
32
+ - Gradient averaging for multi-GPU training.
33
+ - Gradient accumulation for arbitrarily large minibatches.
34
+ - Dynamic loss scaling and typecasts for FP16 training.
35
+ - Ignoring corrupted gradients that contain NaNs/Infs.
36
+ - Reporting statistics.
37
+ - Well-chosen default settings.
38
+ """
39
+
40
+ def __init__(self,
41
+ name: str = "Train", # Name string that will appear in TensorFlow graph.
42
+ tf_optimizer: str = "tf.train.AdamOptimizer", # Underlying optimizer class.
43
+ learning_rate: TfExpressionEx = 0.001, # Learning rate. Can vary over time.
44
+ minibatch_multiplier: TfExpressionEx = None, # Treat N consecutive minibatches as one by accumulating gradients.
45
+ share: "Optimizer" = None, # Share internal state with a previously created optimizer?
46
+ use_loss_scaling: bool = False, # Enable dynamic loss scaling for robust mixed-precision training?
47
+ loss_scaling_init: float = 64.0, # Log2 of initial loss scaling factor.
48
+ loss_scaling_inc: float = 0.0005, # Log2 of per-minibatch loss scaling increment when there is no overflow.
49
+ loss_scaling_dec: float = 1.0, # Log2 of per-minibatch loss scaling decrement when there is an overflow.
50
+ report_mem_usage: bool = False, # Report fine-grained memory usage statistics in TensorBoard?
51
+ **kwargs):
52
+
53
+ # Public fields.
54
+ self.name = name
55
+ self.learning_rate = learning_rate
56
+ self.minibatch_multiplier = minibatch_multiplier
57
+ self.id = self.name.replace("/", ".")
58
+ self.scope = tf.get_default_graph().unique_name(self.id)
59
+ self.optimizer_class = util.get_obj_by_name(tf_optimizer)
60
+ self.optimizer_kwargs = dict(kwargs)
61
+ self.use_loss_scaling = use_loss_scaling
62
+ self.loss_scaling_init = loss_scaling_init
63
+ self.loss_scaling_inc = loss_scaling_inc
64
+ self.loss_scaling_dec = loss_scaling_dec
65
+
66
+ # Private fields.
67
+ self._updates_applied = False
68
+ self._devices = OrderedDict() # device_name => EasyDict()
69
+ self._shared_optimizers = OrderedDict() # device_name => optimizer_class
70
+ self._gradient_shapes = None # [shape, ...]
71
+ self._report_mem_usage = report_mem_usage
72
+
73
+ # Validate arguments.
74
+ assert callable(self.optimizer_class)
75
+
76
+ # Share internal state if requested.
77
+ if share is not None:
78
+ assert isinstance(share, Optimizer)
79
+ assert self.optimizer_class is share.optimizer_class
80
+ assert self.learning_rate is share.learning_rate
81
+ assert self.optimizer_kwargs == share.optimizer_kwargs
82
+ self._shared_optimizers = share._shared_optimizers # pylint: disable=protected-access
83
+
84
+ def _get_device(self, device_name: str):
85
+ """Get internal state for the given TensorFlow device."""
86
+ tfutil.assert_tf_initialized()
87
+ if device_name in self._devices:
88
+ return self._devices[device_name]
89
+
90
+ # Initialize fields.
91
+ device = util.EasyDict()
92
+ device.name = device_name
93
+ device.optimizer = None # Underlying optimizer: optimizer_class
94
+ device.loss_scaling_var = None # Log2 of loss scaling: tf.Variable
95
+ device.grad_raw = OrderedDict() # Raw gradients: var => [grad, ...]
96
+ device.grad_clean = OrderedDict() # Clean gradients: var => grad
97
+ device.grad_acc_vars = OrderedDict() # Accumulation sums: var => tf.Variable
98
+ device.grad_acc_count = None # Accumulation counter: tf.Variable
99
+ device.grad_acc = OrderedDict() # Accumulated gradients: var => grad
100
+
101
+ # Setup TensorFlow objects.
102
+ with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None):
103
+ if device_name not in self._shared_optimizers:
104
+ optimizer_name = self.scope.replace("/", "_") + "_opt%d" % len(self._shared_optimizers)
105
+ self._shared_optimizers[device_name] = self.optimizer_class(name=optimizer_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
106
+ device.optimizer = self._shared_optimizers[device_name]
107
+ if self.use_loss_scaling:
108
+ device.loss_scaling_var = tf.Variable(np.float32(self.loss_scaling_init), trainable=False, name="loss_scaling_var")
109
+
110
+ # Register device.
111
+ self._devices[device_name] = device
112
+ return device
113
+
114
+ def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
115
+ """Register the gradients of the given loss function with respect to the given variables.
116
+ Intended to be called once per GPU."""
117
+ tfutil.assert_tf_initialized()
118
+ assert not self._updates_applied
119
+ device = self._get_device(loss.device)
120
+
121
+ # Validate trainables.
122
+ if isinstance(trainable_vars, dict):
123
+ trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars
124
+ assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
125
+ assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
126
+ assert all(var.device == device.name for var in trainable_vars)
127
+
128
+ # Validate shapes.
129
+ if self._gradient_shapes is None:
130
+ self._gradient_shapes = [var.shape.as_list() for var in trainable_vars]
131
+ assert len(trainable_vars) == len(self._gradient_shapes)
132
+ assert all(var.shape.as_list() == var_shape for var, var_shape in zip(trainable_vars, self._gradient_shapes))
133
+
134
+ # Report memory usage if requested.
135
+ deps = [loss]
136
+ if self._report_mem_usage:
137
+ self._report_mem_usage = False
138
+ try:
139
+ with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]):
140
+ deps.append(autosummary.autosummary(self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30))
141
+ except tf.errors.NotFoundError:
142
+ pass
143
+
144
+ # Compute gradients.
145
+ with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps):
146
+ loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
147
+ gate = tf.train.Optimizer.GATE_NONE # disable gating to reduce memory usage
148
+ grad_list = device.optimizer.compute_gradients(loss=loss, var_list=trainable_vars, gate_gradients=gate)
149
+
150
+ # Register gradients.
151
+ for grad, var in grad_list:
152
+ if var not in device.grad_raw:
153
+ device.grad_raw[var] = []
154
+ device.grad_raw[var].append(grad)
155
+
156
+ def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
157
+ """Construct training op to update the registered variables based on their gradients."""
158
+ tfutil.assert_tf_initialized()
159
+ assert not self._updates_applied
160
+ self._updates_applied = True
161
+ all_ops = []
162
+
163
+ # Check for no-op.
164
+ if allow_no_op and len(self._devices) == 0:
165
+ with tfutil.absolute_name_scope(self.scope):
166
+ return tf.no_op(name='TrainingOp')
167
+
168
+ # Clean up gradients.
169
+ for device_idx, device in enumerate(self._devices.values()):
170
+ with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name):
171
+ for var, grad in device.grad_raw.items():
172
+
173
+ # Filter out disconnected gradients and convert to float32.
174
+ grad = [g for g in grad if g is not None]
175
+ grad = [tf.cast(g, tf.float32) for g in grad]
176
+
177
+ # Sum within the device.
178
+ if len(grad) == 0:
179
+ grad = tf.zeros(var.shape) # No gradients => zero.
180
+ elif len(grad) == 1:
181
+ grad = grad[0] # Single gradient => use as is.
182
+ else:
183
+ grad = tf.add_n(grad) # Multiple gradients => sum.
184
+
185
+ # Scale as needed.
186
+ scale = 1.0 / len(device.grad_raw[var]) / len(self._devices)
187
+ scale = tf.constant(scale, dtype=tf.float32, name="scale")
188
+ if self.minibatch_multiplier is not None:
189
+ scale /= tf.cast(self.minibatch_multiplier, tf.float32)
190
+ scale = self.undo_loss_scaling(scale)
191
+ device.grad_clean[var] = grad * scale
192
+
193
+ # Sum gradients across devices.
194
+ if len(self._devices) > 1:
195
+ with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None):
196
+ if platform.system() == "Windows": # Windows => NCCL ops are not available.
197
+ self._broadcast_fallback()
198
+ elif tf.VERSION.startswith("1.15."): # TF 1.15 => NCCL ops are broken: https://github.com/tensorflow/tensorflow/issues/41539
199
+ self._broadcast_fallback()
200
+ else: # Otherwise => NCCL ops are safe to use.
201
+ self._broadcast_nccl()
202
+
203
+ # Apply updates separately on each device.
204
+ for device_idx, device in enumerate(self._devices.values()):
205
+ with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
206
+ # pylint: disable=cell-var-from-loop
207
+
208
+ # Accumulate gradients over time.
209
+ if self.minibatch_multiplier is None:
210
+ acc_ok = tf.constant(True, name='acc_ok')
211
+ device.grad_acc = OrderedDict(device.grad_clean)
212
+ else:
213
+ # Create variables.
214
+ with tf.control_dependencies(None):
215
+ for var in device.grad_clean.keys():
216
+ device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var")
217
+ device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count")
218
+
219
+ # Track counter.
220
+ count_cur = device.grad_acc_count + 1.0
221
+ count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur)
222
+ count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([]))
223
+ acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32))
224
+ all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op))
225
+
226
+ # Track gradients.
227
+ for var, grad in device.grad_clean.items():
228
+ acc_var = device.grad_acc_vars[var]
229
+ acc_cur = acc_var + grad
230
+ device.grad_acc[var] = acc_cur
231
+ with tf.control_dependencies([acc_cur]):
232
+ acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
233
+ acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape))
234
+ all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op))
235
+
236
+ # No overflow => apply gradients.
237
+ all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()]))
238
+ apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()])
239
+ all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))
240
+
241
+ # Adjust loss scaling.
242
+ if self.use_loss_scaling:
243
+ ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc)
244
+ ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec)
245
+ ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op))
246
+ all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))
247
+
248
+ # Last device => report statistics.
249
+ if device_idx == len(self._devices) - 1:
250
+ all_ops.append(autosummary.autosummary(self.id + "/learning_rate", tf.convert_to_tensor(self.learning_rate)))
251
+ all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok))
252
+ if self.use_loss_scaling:
253
+ all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var))
254
+
255
+ # Initialize variables.
256
+ self.reset_optimizer_state()
257
+ if self.use_loss_scaling:
258
+ tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()])
259
+ if self.minibatch_multiplier is not None:
260
+ tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]])
261
+
262
+ # Group everything into a single op.
263
+ with tfutil.absolute_name_scope(self.scope):
264
+ return tf.group(*all_ops, name="TrainingOp")
265
+
266
+ def reset_optimizer_state(self) -> None:
267
+ """Reset internal state of the underlying optimizer."""
268
+ tfutil.assert_tf_initialized()
269
+ tfutil.run([var.initializer for device in self._devices.values() for var in device.optimizer.variables()])
270
+
271
+ def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
272
+ """Get or create variable representing log2 of the current dynamic loss scaling factor."""
273
+ return self._get_device(device).loss_scaling_var
274
+
275
+ def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
276
+ """Apply dynamic loss scaling for the given expression."""
277
+ assert tfutil.is_tf_expression(value)
278
+ if not self.use_loss_scaling:
279
+ return value
280
+ return value * tfutil.exp2(self.get_loss_scaling_var(value.device))
281
+
282
+ def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
283
+ """Undo the effect of dynamic loss scaling for the given expression."""
284
+ assert tfutil.is_tf_expression(value)
285
+ if not self.use_loss_scaling:
286
+ return value
287
+ return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type
288
+
289
+ def _broadcast_nccl(self):
290
+ """Sum gradients across devices using NCCL ops (fast path)."""
291
+ from tensorflow.python.ops import nccl_ops # pylint: disable=no-name-in-module
292
+ for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]):
293
+ if any(x.shape.num_elements() > 0 for x in all_vars):
294
+ all_grads = [device.grad_clean[var] for device, var in zip(self._devices.values(), all_vars)]
295
+ all_grads = nccl_ops.all_sum(all_grads)
296
+ for device, var, grad in zip(self._devices.values(), all_vars, all_grads):
297
+ device.grad_clean[var] = grad
298
+
299
+ def _broadcast_fallback(self):
300
+ """Sum gradients across devices using TensorFlow collective ops (slow fallback path)."""
301
+ from tensorflow.python.ops import collective_ops # pylint: disable=no-name-in-module
302
+ global _collective_ops_warning_printed, _collective_ops_group_key, _collective_ops_instance_key
303
+ if all(x.shape.num_elements() == 0 for device in self._devices.values() for x in device.grad_clean.values()):
304
+ return
305
+ if not _collective_ops_warning_printed:
306
+ print("------------------------------------------------------------------------")
307
+ print("WARNING: Using slow fallback implementation for inter-GPU communication.")
308
+ print("Please use TensorFlow 1.14 on Linux for optimal training performance.")
309
+ print("------------------------------------------------------------------------")
310
+ _collective_ops_warning_printed = True
311
+ for device in self._devices.values():
312
+ with tf.device(device.name):
313
+ combo = [tf.reshape(x, [x.shape.num_elements()]) for x in device.grad_clean.values()]
314
+ combo = tf.concat(combo, axis=0)
315
+ combo = collective_ops.all_reduce(combo, merge_op='Add', final_op='Id',
316
+ group_size=len(self._devices), group_key=_collective_ops_group_key,
317
+ instance_key=_collective_ops_instance_key)
318
+ cur_ofs = 0
319
+ for var, grad_old in device.grad_clean.items():
320
+ grad_new = tf.reshape(combo[cur_ofs : cur_ofs + grad_old.shape.num_elements()], grad_old.shape)
321
+ cur_ofs += grad_old.shape.num_elements()
322
+ device.grad_clean[var] = grad_new
323
+ _collective_ops_instance_key += 1
324
+
325
+
326
+ class SimpleAdam:
327
+ """Simplified version of tf.train.AdamOptimizer that behaves identically when used with dnnlib.tflib.Optimizer."""
328
+
329
+ def __init__(self, name="Adam", learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
330
+ self.name = name
331
+ self.learning_rate = learning_rate
332
+ self.beta1 = beta1
333
+ self.beta2 = beta2
334
+ self.epsilon = epsilon
335
+ self.all_state_vars = []
336
+
337
+ def variables(self):
338
+ return self.all_state_vars
339
+
340
+ def compute_gradients(self, loss, var_list, gate_gradients=tf.train.Optimizer.GATE_NONE):
341
+ assert gate_gradients == tf.train.Optimizer.GATE_NONE
342
+ return list(zip(tf.gradients(loss, var_list), var_list))
343
+
344
+ def apply_gradients(self, grads_and_vars):
345
+ with tf.name_scope(self.name):
346
+ state_vars = []
347
+ update_ops = []
348
+
349
+ # Adjust learning rate to deal with startup bias.
350
+ with tf.control_dependencies(None):
351
+ b1pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
352
+ b2pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
353
+ state_vars += [b1pow_var, b2pow_var]
354
+ b1pow_new = b1pow_var * self.beta1
355
+ b2pow_new = b2pow_var * self.beta2
356
+ update_ops += [tf.assign(b1pow_var, b1pow_new), tf.assign(b2pow_var, b2pow_new)]
357
+ lr_new = self.learning_rate * tf.sqrt(1 - b2pow_new) / (1 - b1pow_new)
358
+
359
+ # Construct ops to update each variable.
360
+ for grad, var in grads_and_vars:
361
+ with tf.control_dependencies(None):
362
+ m_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
363
+ v_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
364
+ state_vars += [m_var, v_var]
365
+ m_new = self.beta1 * m_var + (1 - self.beta1) * grad
366
+ v_new = self.beta2 * v_var + (1 - self.beta2) * tf.square(grad)
367
+ var_delta = lr_new * m_new / (tf.sqrt(v_new) + self.epsilon)
368
+ update_ops += [tf.assign(m_var, m_new), tf.assign(v_var, v_new), tf.assign_sub(var, var_delta)]
369
+
370
+ # Group everything together.
371
+ self.all_state_vars += state_vars
372
+ return tf.group(*update_ops)