Spaces:
Runtime error
Runtime error
Upload optimizer.py
Browse files- dnnlib/tflib/optimizer.py +372 -0
dnnlib/tflib/optimizer.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Helper wrapper for a Tensorflow optimizer."""
|
10 |
+
|
11 |
+
import platform
|
12 |
+
import numpy as np
|
13 |
+
import tensorflow as tf
|
14 |
+
|
15 |
+
from collections import OrderedDict
|
16 |
+
from typing import List, Union
|
17 |
+
|
18 |
+
from . import autosummary
|
19 |
+
from . import tfutil
|
20 |
+
from .. import util
|
21 |
+
|
22 |
+
from .tfutil import TfExpression, TfExpressionEx
|
23 |
+
|
24 |
+
# Module-level state used by Optimizer._broadcast_fallback(), i.e. when gradients
# are summed across devices with TensorFlow collective ops instead of NCCL.
_collective_ops_warning_printed = False # Set to True once the slow-fallback warning has been printed.
_collective_ops_group_key = 831766147 # Passed as group_key to collective_ops.all_reduce().
_collective_ops_instance_key = 436340067 # Passed as instance_key to all_reduce(); incremented after each fallback broadcast.
|
27 |
+
|
28 |
+
class Optimizer:
    """A Wrapper for tf.train.Optimizer.

    Automatically takes care of:
    - Gradient averaging for multi-GPU training.
    - Gradient accumulation for arbitrarily large minibatches.
    - Dynamic loss scaling and typecasts for FP16 training.
    - Ignoring corrupted gradients that contain NaNs/Infs.
    - Reporting statistics.
    - Well-chosen default settings.

    Usage, as enforced by the assertions below: call register_gradients()
    for each device, then call apply_updates() exactly once to obtain the
    training op. No gradients can be registered after apply_updates().
    """

    def __init__(self,
        name: str = "Train", # Name string that will appear in TensorFlow graph.
        tf_optimizer: str = "tf.train.AdamOptimizer", # Underlying optimizer class.
        learning_rate: TfExpressionEx = 0.001, # Learning rate. Can vary over time.
        minibatch_multiplier: TfExpressionEx = None, # Treat N consecutive minibatches as one by accumulating gradients.
        share: "Optimizer" = None, # Share internal state with a previously created optimizer?
        use_loss_scaling: bool = False, # Enable dynamic loss scaling for robust mixed-precision training?
        loss_scaling_init: float = 64.0, # Log2 of initial loss scaling factor.
        loss_scaling_inc: float = 0.0005, # Log2 of per-minibatch loss scaling increment when there is no overflow.
        loss_scaling_dec: float = 1.0, # Log2 of per-minibatch loss scaling decrement when there is an overflow.
        report_mem_usage: bool = False, # Report fine-grained memory usage statistics in TensorBoard?
        **kwargs):
        """Store the configuration.

        No variables or optimizer instances are created here; per-device state
        is created lazily by _get_device(). Extra keyword arguments are kept in
        `optimizer_kwargs` and forwarded to the underlying optimizer class when
        it is instantiated.
        """

        # Public fields.
        self.name = name
        self.learning_rate = learning_rate
        self.minibatch_multiplier = minibatch_multiplier
        self.id = self.name.replace("/", ".") # Prefix for the name scopes and summary names built below.
        self.scope = tf.get_default_graph().unique_name(self.id) # Graph-unique scope under which this optimizer's ops are placed.
        self.optimizer_class = util.get_obj_by_name(tf_optimizer) # Resolve the class from its dotted name string.
        self.optimizer_kwargs = dict(kwargs)
        self.use_loss_scaling = use_loss_scaling
        self.loss_scaling_init = loss_scaling_init
        self.loss_scaling_inc = loss_scaling_inc
        self.loss_scaling_dec = loss_scaling_dec

        # Private fields.
        self._updates_applied = False # Becomes True once apply_updates() has been called.
        self._devices = OrderedDict() # device_name => EasyDict()
        self._shared_optimizers = OrderedDict() # device_name => optimizer_class
        self._gradient_shapes = None # [shape, ...]
        self._report_mem_usage = report_mem_usage

        # Validate arguments.
        assert callable(self.optimizer_class)

        # Share internal state if requested.
        # Both wrappers must be configured identically (same class, the very same
        # learning_rate object, equal kwargs) because they end up using the same
        # per-device optimizer instances.
        if share is not None:
            assert isinstance(share, Optimizer)
            assert self.optimizer_class is share.optimizer_class
            assert self.learning_rate is share.learning_rate
            assert self.optimizer_kwargs == share.optimizer_kwargs
            self._shared_optimizers = share._shared_optimizers # pylint: disable=protected-access

    def _get_device(self, device_name: str):
        """Get internal state for the given TensorFlow device.

        The state is an EasyDict that is created on first use and cached in
        `self._devices`. Creating it also instantiates the underlying optimizer
        for the device (unless one already exists in the shared optimizer dict)
        and, when loss scaling is enabled, the loss scaling variable.
        """
        tfutil.assert_tf_initialized()
        if device_name in self._devices:
            return self._devices[device_name]

        # Initialize fields.
        device = util.EasyDict()
        device.name = device_name
        device.optimizer = None # Underlying optimizer: optimizer_class
        device.loss_scaling_var = None # Log2 of loss scaling: tf.Variable
        device.grad_raw = OrderedDict() # Raw gradients: var => [grad, ...]
        device.grad_clean = OrderedDict() # Clean gradients: var => grad
        device.grad_acc_vars = OrderedDict() # Accumulation sums: var => tf.Variable
        device.grad_acc_count = None # Accumulation counter: tf.Variable
        device.grad_acc = OrderedDict() # Accumulated gradients: var => grad

        # Setup TensorFlow objects.
        # control_dependencies(None) clears any control dependencies inherited from the caller's context.
        with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None):
            if device_name not in self._shared_optimizers:
                # The numeric suffix is the count of optimizers created so far in the (possibly shared) dict.
                optimizer_name = self.scope.replace("/", "_") + "_opt%d" % len(self._shared_optimizers)
                self._shared_optimizers[device_name] = self.optimizer_class(name=optimizer_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
            device.optimizer = self._shared_optimizers[device_name]
            if self.use_loss_scaling:
                device.loss_scaling_var = tf.Variable(np.float32(self.loss_scaling_init), trainable=False, name="loss_scaling_var")

        # Register device.
        self._devices[device_name] = device
        return device

    def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
        """Register the gradients of the given loss function with respect to the given variables.
        Intended to be called once per GPU.

        The device is taken from `loss.device`, and every variable must live on
        that same device. `trainable_vars` may be a list or a dict (its values
        are used). All calls must pass the same number of variables with the
        same shapes. Must be called before apply_updates().
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        device = self._get_device(loss.device)

        # Validate trainables.
        if isinstance(trainable_vars, dict):
            trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars
        assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
        assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
        assert all(var.device == device.name for var in trainable_vars)

        # Validate shapes.
        # The first call records the shapes; later calls are checked against them.
        if self._gradient_shapes is None:
            self._gradient_shapes = [var.shape.as_list() for var in trainable_vars]
        assert len(trainable_vars) == len(self._gradient_shapes)
        assert all(var.shape.as_list() == var_shape for var, var_shape in zip(trainable_vars, self._gradient_shapes))

        # Report memory usage if requested.
        deps = [loss]
        if self._report_mem_usage:
            self._report_mem_usage = False # Clear the flag so that only this one call adds the summary.
            try:
                with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]):
                    deps.append(autosummary.autosummary(self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30))
            except tf.errors.NotFoundError:
                # Best effort: skip the statistic. NOTE(review): presumably raised when
                # the memory_stats op is unavailable in this TF build -- confirm.
                pass

        # Compute gradients.
        # The gradient ops carry control dependencies on `deps` (the loss, plus the memory summary if added).
        with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps):
            loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
            gate = tf.train.Optimizer.GATE_NONE # disable gating to reduce memory usage
            grad_list = device.optimizer.compute_gradients(loss=loss, var_list=trainable_vars, gate_gradients=gate)

        # Register gradients.
        # A variable may collect several gradients if this method is called more than once for the same device.
        for grad, var in grad_list:
            if var not in device.grad_raw:
                device.grad_raw[var] = []
            device.grad_raw[var].append(grad)

    def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
        """Construct training op to update the registered variables based on their gradients.

        May be called only once. With `allow_no_op=True` and no registered
        devices, a plain no-op named 'TrainingOp' is returned. Otherwise the
        returned op groups, for every device: gradient accumulation bookkeeping,
        the conditional optimizer step, the loss scaling adjustment and (on the
        last device) the summary ops. Besides building the graph, this method
        also runs the initializers of the optimizer, loss scaling and
        accumulation variables.
        """
        tfutil.assert_tf_initialized()
        assert not self._updates_applied
        self._updates_applied = True
        all_ops = []

        # Check for no-op.
        if allow_no_op and len(self._devices) == 0:
            with tfutil.absolute_name_scope(self.scope):
                return tf.no_op(name='TrainingOp')

        # Clean up gradients.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name):
                for var, grad in device.grad_raw.items():

                    # Filter out disconnected gradients and convert to float32.
                    grad = [g for g in grad if g is not None]
                    grad = [tf.cast(g, tf.float32) for g in grad]

                    # Sum within the device.
                    if len(grad) == 0:
                        grad = tf.zeros(var.shape) # No gradients => zero.
                    elif len(grad) == 1:
                        grad = grad[0] # Single gradient => use as is.
                    else:
                        grad = tf.add_n(grad) # Multiple gradients => sum.

                    # Scale as needed.
                    # Divide by the number of registered gradients (counted on the unfiltered
                    # list, so disconnected entries are included) and by the number of devices,
                    # then by the minibatch multiplier, and finally undo the loss scaling.
                    scale = 1.0 / len(device.grad_raw[var]) / len(self._devices)
                    scale = tf.constant(scale, dtype=tf.float32, name="scale")
                    if self.minibatch_multiplier is not None:
                        scale /= tf.cast(self.minibatch_multiplier, tf.float32)
                    scale = self.undo_loss_scaling(scale)
                    device.grad_clean[var] = grad * scale

        # Sum gradients across devices.
        if len(self._devices) > 1:
            with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None):
                if platform.system() == "Windows": # Windows => NCCL ops are not available.
                    self._broadcast_fallback()
                elif tf.VERSION.startswith("1.15."): # TF 1.15 => NCCL ops are broken: https://github.com/tensorflow/tensorflow/issues/41539
                    self._broadcast_fallback()
                else: # Otherwise => NCCL ops are safe to use.
                    self._broadcast_nccl()

        # Apply updates separately on each device.
        for device_idx, device in enumerate(self._devices.values()):
            with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
                # pylint: disable=cell-var-from-loop
                # The lambdas below are the branch functions of tf.cond calls made within the same iteration.

                # Accumulate gradients over time.
                if self.minibatch_multiplier is None:
                    # No accumulation: every step is a complete minibatch.
                    acc_ok = tf.constant(True, name='acc_ok')
                    device.grad_acc = OrderedDict(device.grad_clean)
                else:
                    # Create variables.
                    with tf.control_dependencies(None):
                        for var in device.grad_clean.keys():
                            device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var")
                        device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count")

                    # Track counter.
                    # acc_ok is true on the step where the count reaches minibatch_multiplier;
                    # on that step the counter is reset, otherwise it is incremented.
                    count_cur = device.grad_acc_count + 1.0
                    count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur)
                    count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([]))
                    acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32))
                    all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op))

                    # Track gradients.
                    # grad_acc holds "stored sum + current gradient"; the stored sum is either
                    # updated to that value or cleared to zero when acc_ok is true.
                    for var, grad in device.grad_clean.items():
                        acc_var = device.grad_acc_vars[var]
                        acc_cur = acc_var + grad
                        device.grad_acc[var] = acc_cur
                        with tf.control_dependencies([acc_cur]):
                            acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
                            acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape))
                            all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op))

                # No overflow => apply gradients.
                # all_ok requires acc_ok and every accumulated gradient to be finite.
                all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()]))
                apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()])
                all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))

                # Adjust loss scaling.
                # Only on steps where acc_ok is true: increase if the step was applied, decrease otherwise.
                if self.use_loss_scaling:
                    ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc)
                    ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec)
                    ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op))
                    all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))

                # Last device => report statistics.
                if device_idx == len(self._devices) - 1:
                    all_ops.append(autosummary.autosummary(self.id + "/learning_rate", tf.convert_to_tensor(self.learning_rate)))
                    all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok))
                    if self.use_loss_scaling:
                        all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var))

        # Initialize variables.
        self.reset_optimizer_state()
        if self.use_loss_scaling:
            tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()])
        if self.minibatch_multiplier is not None:
            tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]])

        # Group everything into a single op.
        with tfutil.absolute_name_scope(self.scope):
            return tf.group(*all_ops, name="TrainingOp")

    def reset_optimizer_state(self) -> None:
        """Reset internal state of the underlying optimizer.

        Runs the initializer of every variable reported by `variables()` of
        each registered device's optimizer.
        """
        tfutil.assert_tf_initialized()
        tfutil.run([var.initializer for device in self._devices.values() for var in device.optimizer.variables()])

    def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
        """Get or create variable representing log2 of the current dynamic loss scaling factor.

        Creates the per-device state if it does not exist yet. Returns None
        when loss scaling is disabled, since the variable is then never created.
        """
        return self._get_device(device).loss_scaling_var

    def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Apply dynamic loss scaling for the given expression.

        Returns `value` unchanged when loss scaling is disabled; otherwise
        multiplies it by tfutil.exp2() of the loss scaling variable that
        belongs to `value.device`.
        """
        assert tfutil.is_tf_expression(value)
        if not self.use_loss_scaling:
            return value
        return value * tfutil.exp2(self.get_loss_scaling_var(value.device))

    def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
        """Undo the effect of dynamic loss scaling for the given expression.

        Returns `value` unchanged when loss scaling is disabled; otherwise
        multiplies it by tfutil.exp2() of the negated loss scaling variable
        that belongs to `value.device`.
        """
        assert tfutil.is_tf_expression(value)
        if not self.use_loss_scaling:
            return value
        return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type

    def _broadcast_nccl(self):
        """Sum gradients across devices using NCCL ops (fast path).

        Replaces each device's entry in `grad_clean` with the corresponding
        output of nccl_ops.all_sum().
        """
        from tensorflow.python.ops import nccl_ops # pylint: disable=no-name-in-module
        # Pair up the i-th variable of every device.
        # NOTE(review): assumes all devices registered their variables in the same order -- confirm with callers.
        for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]):
            if any(x.shape.num_elements() > 0 for x in all_vars): # Skip variables that have no elements.
                all_grads = [device.grad_clean[var] for device, var in zip(self._devices.values(), all_vars)]
                all_grads = nccl_ops.all_sum(all_grads)
                for device, var, grad in zip(self._devices.values(), all_vars, all_grads):
                    device.grad_clean[var] = grad

    def _broadcast_fallback(self):
        """Sum gradients across devices using TensorFlow collective ops (slow fallback path).

        On each device, all clean gradients are flattened and concatenated into
        one vector, reduced with a single all_reduce ('Add'), and then sliced
        and reshaped back into per-variable gradients. Prints a one-time
        warning and advances the module-level instance key afterwards.
        """
        from tensorflow.python.ops import collective_ops # pylint: disable=no-name-in-module
        global _collective_ops_warning_printed, _collective_ops_group_key, _collective_ops_instance_key
        # Nothing to do if every gradient on every device has zero elements.
        if all(x.shape.num_elements() == 0 for device in self._devices.values() for x in device.grad_clean.values()):
            return
        if not _collective_ops_warning_printed:
            print("------------------------------------------------------------------------")
            print("WARNING: Using slow fallback implementation for inter-GPU communication.")
            print("Please use TensorFlow 1.14 on Linux for optimal training performance.")
            print("------------------------------------------------------------------------")
            _collective_ops_warning_printed = True
        for device in self._devices.values():
            with tf.device(device.name):
                # Flatten every gradient and concatenate into a single vector.
                combo = [tf.reshape(x, [x.shape.num_elements()]) for x in device.grad_clean.values()]
                combo = tf.concat(combo, axis=0)
                combo = collective_ops.all_reduce(combo, merge_op='Add', final_op='Id',
                    group_size=len(self._devices), group_key=_collective_ops_group_key,
                    instance_key=_collective_ops_instance_key)
                # Slice the reduced vector back into tensors of the original shapes.
                cur_ofs = 0
                for var, grad_old in device.grad_clean.items():
                    grad_new = tf.reshape(combo[cur_ofs : cur_ofs + grad_old.shape.num_elements()], grad_old.shape)
                    cur_ofs += grad_old.shape.num_elements()
                    device.grad_clean[var] = grad_new
        # All devices above used the same instance key; advance it for the next call.
        _collective_ops_instance_key += 1
|
324 |
+
|
325 |
+
|
326 |
+
class SimpleAdam:
    """Bare-bones Adam optimizer offering the small API surface used by dnnlib.tflib.Optimizer.

    Described by its author as a simplified version of tf.train.AdamOptimizer
    that behaves identically when used with dnnlib.tflib.Optimizer.
    """

    def __init__(self, name="Adam", learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
        """Store the hyperparameters; state variables are created later by apply_gradients()."""
        # Hyperparameters.
        self.learning_rate = learning_rate
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon

        # Bookkeeping.
        self.name = name
        self.all_state_vars = []  # Every state variable created by apply_gradients() so far.

    def variables(self):
        """Return the list of all state variables created so far."""
        return self.all_state_vars

    def compute_gradients(self, loss, var_list, gate_gradients=tf.train.Optimizer.GATE_NONE):
        """Return [(gradient, variable), ...] for `loss` with respect to `var_list`.

        Only ungated gradient computation (GATE_NONE) is supported.
        """
        assert gate_gradients == tf.train.Optimizer.GATE_NONE
        gradients = tf.gradients(loss, var_list)
        return list(zip(gradients, var_list))

    def apply_gradients(self, grads_and_vars):
        """Build and return one grouped op that performs an Adam step.

        `grads_and_vars` is an iterable of (gradient, variable) pairs. New
        state variables are created on every call and appended to the list
        returned by variables().
        """
        with tf.name_scope(self.name):
            created_vars = []
            step_ops = []

            # Running powers of beta1/beta2, used to correct the startup bias
            # through the effective step size.
            with tf.control_dependencies(None):
                beta1_power = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
                beta2_power = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
            created_vars.extend([beta1_power, beta2_power])
            beta1_power_next = beta1_power * self.beta1
            beta2_power_next = beta2_power * self.beta2
            step_ops.append(tf.assign(beta1_power, beta1_power_next))
            step_ops.append(tf.assign(beta2_power, beta2_power_next))
            step_size = self.learning_rate * tf.sqrt(1 - beta2_power_next) / (1 - beta1_power_next)

            # Per-variable moment estimates and the resulting update.
            for gradient, variable in grads_and_vars:
                with tf.control_dependencies(None):
                    first_moment = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(variable), trainable=False)
                    second_moment = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(variable), trainable=False)
                created_vars.extend([first_moment, second_moment])
                first_moment_next = self.beta1 * first_moment + (1 - self.beta1) * gradient
                second_moment_next = self.beta2 * second_moment + (1 - self.beta2) * tf.square(gradient)
                delta = step_size * first_moment_next / (tf.sqrt(second_moment_next) + self.epsilon)
                step_ops.append(tf.assign(first_moment, first_moment_next))
                step_ops.append(tf.assign(second_moment, second_moment_next))
                step_ops.append(tf.assign_sub(variable, delta))

            # Publish the new state variables and bundle all ops.
            self.all_state_vars.extend(created_vars)
            return tf.group(*step_ops)
|