i72sijia committed on
Commit
8a6a841
1 Parent(s): 51d59a1

Upload network.py

Files changed (1)
  1. dnnlib/tflib/network.py +781 -0
dnnlib/tflib/network.py ADDED
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Helper for managing networks."""

import types
import inspect
import re
import uuid
import sys
import copy
import numpy as np
import tensorflow as tf

from collections import OrderedDict
from typing import Any, List, Tuple, Union, Callable

from . import tfutil
from .. import util

from .tfutil import TfExpression, TfExpressionEx

# pylint: disable=protected-access
# pylint: disable=attribute-defined-outside-init
# pylint: disable=too-many-public-methods

_import_handlers = []           # Custom import handlers for dealing with legacy data in pickle import.
_import_module_src = dict()     # Source code for temporary modules created during pickle import.


def import_handler(handler_func):
    """Function decorator for declaring custom import handlers."""
    _import_handlers.append(handler_func)
    return handler_func
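
# A hypothetical import handler, shown only to illustrate the decorator above
# (not registered here; the field name is illustrative):
#
#     @import_handler
#     def _fill_in_missing_components(state):
#         state.setdefault("components", dict())
#         return state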


class Network:
    """Generic network abstraction.

    Acts as a convenience wrapper for a parameterized network construction
    function, providing several utility methods and convenient access to
    the inputs/outputs/weights.

    Network objects can be safely pickled and unpickled for long-term
    archival purposes. The pickling works reliably as long as the underlying
    network construction function is defined in a standalone Python module
    that has no side effects or application-specific imports.

    Args:
        name: Network name. Used to select TensorFlow name and variable scopes. Defaults to build func name if None.
        func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
        static_kwargs: Keyword arguments to be passed in to the network construction function.
    """
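
    # Minimal usage sketch, assuming a user-defined build function
    # my_module.my_net(images, **kwargs) importable on the Python path:
    #
    #     net = Network(name="MyNet", func_name="my_module.my_net", fmaps=64)
    #     print(net.input_shapes, net.output_shapes)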

    def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
        # Locate the user-specified build function.
        assert isinstance(func_name, str) or util.is_top_level_function(func_name)
        if util.is_top_level_function(func_name):
            func_name = util.get_top_level_function_name(func_name)
        module, func_name = util.get_module_from_obj_name(func_name)
        func = util.get_obj_from_module(module, func_name)

        # Dig up source code for the module containing the build function.
        module_src = _import_module_src.get(module, None)
        if module_src is None:
            module_src = inspect.getsource(module)

        # Initialize fields.
        self._init_fields(name=(name or func_name), static_kwargs=static_kwargs, build_func=func, build_func_name=func_name, build_module_src=module_src)

    def _init_fields(self, name: str, static_kwargs: dict, build_func: Callable, build_func_name: str, build_module_src: str) -> None:
        tfutil.assert_tf_initialized()
        assert isinstance(name, str)
        assert len(name) >= 1
        assert re.fullmatch(r"[A-Za-z0-9_.\-]*", name)
        assert isinstance(static_kwargs, dict)
        assert util.is_pickleable(static_kwargs)
        assert callable(build_func)
        assert isinstance(build_func_name, str)
        assert isinstance(build_module_src, str)

        # Choose TensorFlow name scope.
        with tf.name_scope(None):
            scope = tf.get_default_graph().unique_name(name, mark_as_used=True)

        # Query current TensorFlow device.
        with tfutil.absolute_name_scope(scope), tf.control_dependencies(None):
            device = tf.no_op(name="_QueryDevice").device

        # Immutable state.
        self._name = name
        self._scope = scope
        self._device = device
        self._static_kwargs = util.EasyDict(copy.deepcopy(static_kwargs))
        self._build_func = build_func
        self._build_func_name = build_func_name
        self._build_module_src = build_module_src

        # State before _init_graph().
        self._var_inits = dict()        # var_name => initial_value, set to None by _init_graph()
        self._all_inits_known = False   # Do we know for sure that _var_inits covers all the variables?
        self._components = None         # subnet_name => Network, None if the components are not known yet

        # Initialized by _init_graph().
        self._input_templates = None
        self._output_templates = None
        self._own_vars = None

        # Cached values initialized by the respective methods.
        self._input_shapes = None
        self._output_shapes = None
        self._input_names = None
        self._output_names = None
        self._vars = None
        self._trainables = None
        self._var_global_to_local = None
        self._run_cache = dict()

    def _init_graph(self) -> None:
        assert self._var_inits is not None
        assert self._input_templates is None
        assert self._output_templates is None
        assert self._own_vars is None

        # Initialize components.
        if self._components is None:
            self._components = util.EasyDict()

        # Choose build func kwargs.
        build_kwargs = dict(self.static_kwargs)
        build_kwargs["is_template_graph"] = True
        build_kwargs["components"] = self._components

        # Override scope and device, and ignore surrounding control dependencies.
        with tfutil.absolute_variable_scope(self.scope, reuse=False), tfutil.absolute_name_scope(self.scope), tf.device(self.device), tf.control_dependencies(None):
            assert tf.get_variable_scope().name == self.scope
            assert tf.get_default_graph().get_name_scope() == self.scope

            # Create input templates.
            self._input_templates = []
            for param in inspect.signature(self._build_func).parameters.values():
                if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
                    self._input_templates.append(tf.placeholder(tf.float32, name=param.name))

            # Call build func.
            out_expr = self._build_func(*self._input_templates, **build_kwargs)

        # Collect output templates and variables.
        assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        self._output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
        self._own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))

        # Check for errors.
        if len(self._input_templates) == 0:
            raise ValueError("Network build func did not list any inputs.")
        if len(self._output_templates) == 0:
            raise ValueError("Network build func did not return any outputs.")
        if any(not tfutil.is_tf_expression(t) for t in self._output_templates):
            raise ValueError("Network outputs must be TensorFlow expressions.")
        if any(t.shape.ndims is None for t in self._input_templates):
            raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
        if any(t.shape.ndims is None for t in self._output_templates):
            raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
        if any(not isinstance(comp, Network) for comp in self._components.values()):
            raise ValueError("Components of a Network must be Networks themselves.")
        if len(self._components) != len(set(comp.name for comp in self._components.values())):
            raise ValueError("Components of a Network must have unique names.")

        # Initialize variables.
        if len(self._var_inits):
            tfutil.set_vars({self._get_vars()[name]: value for name, value in self._var_inits.items() if name in self._get_vars()})
        remaining_inits = [var.initializer for name, var in self._own_vars.items() if name not in self._var_inits]
        if self._all_inits_known:
            assert len(remaining_inits) == 0
        else:
            tfutil.run(remaining_inits)
        self._var_inits = None

    @property
    def name(self):
        """User-specified name string."""
        return self._name

    @property
    def scope(self):
        """Unique TensorFlow scope containing template graph and variables, derived from the user-specified name."""
        return self._scope

    @property
    def device(self):
        """Name of the TensorFlow device that the weights of this network reside on. Determined by the current device at construction time."""
        return self._device

    @property
    def static_kwargs(self):
        """EasyDict of arguments passed to the user-supplied build func."""
        return copy.deepcopy(self._static_kwargs)

    @property
    def components(self):
        """EasyDict of sub-networks created by the build func."""
        return copy.copy(self._get_components())

    def _get_components(self):
        if self._components is None:
            self._init_graph()
        assert self._components is not None
        return self._components

    @property
    def input_shapes(self):
        """List of input tensor shapes, including minibatch dimension."""
        if self._input_shapes is None:
            self._input_shapes = [t.shape.as_list() for t in self.input_templates]
        return copy.deepcopy(self._input_shapes)

    @property
    def output_shapes(self):
        """List of output tensor shapes, including minibatch dimension."""
        if self._output_shapes is None:
            self._output_shapes = [t.shape.as_list() for t in self.output_templates]
        return copy.deepcopy(self._output_shapes)

    @property
    def input_shape(self):
        """Short-hand for input_shapes[0]."""
        return self.input_shapes[0]

    @property
    def output_shape(self):
        """Short-hand for output_shapes[0]."""
        return self.output_shapes[0]

    @property
    def num_inputs(self):
        """Number of input tensors."""
        return len(self.input_shapes)

    @property
    def num_outputs(self):
        """Number of output tensors."""
        return len(self.output_shapes)

    @property
    def input_names(self):
        """Name string for each input."""
        if self._input_names is None:
            self._input_names = [t.name.split("/")[-1].split(":")[0] for t in self.input_templates]
        return copy.copy(self._input_names)

    @property
    def output_names(self):
        """Name string for each output."""
        if self._output_names is None:
            self._output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]
        return copy.copy(self._output_names)

    @property
    def input_templates(self):
        """Input placeholders in the template graph."""
        if self._input_templates is None:
            self._init_graph()
        assert self._input_templates is not None
        return copy.copy(self._input_templates)

    @property
    def output_templates(self):
        """Output tensors in the template graph."""
        if self._output_templates is None:
            self._init_graph()
        assert self._output_templates is not None
        return copy.copy(self._output_templates)

    @property
    def own_vars(self):
        """Variables defined by this network (local_name => var), excluding sub-networks."""
        return copy.copy(self._get_own_vars())

    def _get_own_vars(self):
        if self._own_vars is None:
            self._init_graph()
        assert self._own_vars is not None
        return self._own_vars

    @property
    def vars(self):
        """All variables (local_name => var)."""
        return copy.copy(self._get_vars())

    def _get_vars(self):
        if self._vars is None:
            self._vars = OrderedDict(self._get_own_vars())
            for comp in self._get_components().values():
                self._vars.update((comp.name + "/" + name, var) for name, var in comp._get_vars().items())
        return self._vars

    @property
    def trainables(self):
        """All trainable variables (local_name => var)."""
        return copy.copy(self._get_trainables())

    def _get_trainables(self):
        if self._trainables is None:
            self._trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
        return self._trainables

    @property
    def var_global_to_local(self):
        """Mapping from variable global names to local names."""
        return copy.copy(self._get_var_global_to_local())

    def _get_var_global_to_local(self):
        if self._var_global_to_local is None:
            self._var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())
        return self._var_global_to_local

    def reset_own_vars(self) -> None:
        """Re-initialize all variables of this network, excluding sub-networks."""
        if self._var_inits is None or self._components is None:
            tfutil.run([var.initializer for var in self._get_own_vars().values()])
        else:
            self._var_inits.clear()
            self._all_inits_known = False

    def reset_vars(self) -> None:
        """Re-initialize all variables of this network, including sub-networks."""
        if self._var_inits is None:
            tfutil.run([var.initializer for var in self._get_vars().values()])
        else:
            self._var_inits.clear()
            self._all_inits_known = False
            if self._components is not None:
                for comp in self._components.values():
                    comp.reset_vars()

    def reset_trainables(self) -> None:
        """Re-initialize all trainable variables of this network, including sub-networks."""
        tfutil.run([var.initializer for var in self._get_trainables().values()])

    def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
        """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s).
        The graph is placed on the current TensorFlow device."""
        assert len(in_expr) == self.num_inputs
        assert not all(expr is None for expr in in_expr)
        self._get_vars()  # ensure that all variables have been created

        # Choose build func kwargs.
        build_kwargs = dict(self.static_kwargs)
        build_kwargs.update(dynamic_kwargs)
        build_kwargs["is_template_graph"] = False
        build_kwargs["components"] = self._components

        # Build TensorFlow graph to evaluate the network.
        with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
            assert tf.get_variable_scope().name == self.scope
            valid_inputs = [expr for expr in in_expr if expr is not None]
            final_inputs = []
            for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
                if expr is not None:
                    expr = tf.identity(expr, name=name)
                else:
                    expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
                final_inputs.append(expr)
            out_expr = self._build_func(*final_inputs, **build_kwargs)

        # Propagate input shapes back to the user-specified expressions.
        for expr, final in zip(in_expr, final_inputs):
            if isinstance(expr, tf.Tensor):
                expr.set_shape(final.shape)

        # Express outputs in the desired format.
        assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        if return_as_list:
            out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
        return out_expr
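
    # Sketch: wiring the network into a larger graph for training
    # (tensor and kwarg names are illustrative):
    #
    #     latents = tf.placeholder(tf.float32, [None] + net.input_shape[1:])
    #     out = net.get_output_for(latents, is_training=True)
    #     loss = tf.reduce_mean(out)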

    def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
        """Get the local name of a given variable, without any surrounding name scopes."""
        assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
        global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
        return self._get_var_global_to_local()[global_name]

    def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
        """Find variable by local or global name."""
        assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
        return self._get_vars()[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name

    def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
        """Get the value of a given variable as NumPy array.
        Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
        return self.find_var(var_or_local_name).eval()

    def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
        """Set the value of a given variable based on the given NumPy array.
        Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
        tfutil.set_vars({self.find_var(var_or_local_name): new_value})
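
    # Sketch of single-variable access (the variable name is illustrative):
    #
    #     w = net.get_var("Dense0/weight")        # -> np.ndarray
    #     net.set_var("Dense0/weight", w * 0.5)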

    def __getstate__(self) -> dict:
        """Pickle export."""
        state = dict()
        state["version"] = 5
        state["name"] = self.name
        state["static_kwargs"] = dict(self.static_kwargs)
        state["components"] = dict(self.components)
        state["build_module_src"] = self._build_module_src
        state["build_func_name"] = self._build_func_name
        state["variables"] = list(zip(self._get_own_vars().keys(), tfutil.run(list(self._get_own_vars().values()))))
        state["input_shapes"] = self.input_shapes
        state["output_shapes"] = self.output_shapes
        state["input_names"] = self.input_names
        state["output_names"] = self.output_names
        return state
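
    # Round-trip sketch: __getstate__()/__setstate__() make networks work with
    # standard pickle (an `import pickle` is assumed):
    #
    #     with open("net.pkl", "wb") as f:
    #         pickle.dump(net, f)
    #     with open("net.pkl", "rb") as f:
    #         net2 = pickle.load(f)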

    def __setstate__(self, state: dict) -> None:
        """Pickle import."""

        # Execute custom import handlers.
        for handler in _import_handlers:
            state = handler(state)

        # Get basic fields.
        assert state["version"] in [2, 3, 4, 5]
        name = state["name"]
        static_kwargs = state["static_kwargs"]
        build_module_src = state["build_module_src"]
        build_func_name = state["build_func_name"]

        # Create temporary module from the imported source code.
        module_name = "_tflib_network_import_" + uuid.uuid4().hex
        module = types.ModuleType(module_name)
        sys.modules[module_name] = module
        _import_module_src[module] = build_module_src
        exec(build_module_src, module.__dict__)  # pylint: disable=exec-used
        build_func = util.get_obj_from_module(module, build_func_name)

        # Initialize fields.
        self._init_fields(name=name, static_kwargs=static_kwargs, build_func=build_func, build_func_name=build_func_name, build_module_src=build_module_src)
        self._var_inits.update(copy.deepcopy(state["variables"]))
        self._all_inits_known = True
        self._components = util.EasyDict(state.get("components", {}))
        self._input_shapes = copy.deepcopy(state.get("input_shapes", None))
        self._output_shapes = copy.deepcopy(state.get("output_shapes", None))
        self._input_names = copy.deepcopy(state.get("input_names", None))
        self._output_names = copy.deepcopy(state.get("output_names", None))

    def clone(self, name: str = None, **new_static_kwargs) -> "Network":
        """Create a clone of this network with its own copy of the variables."""
        static_kwargs = dict(self.static_kwargs)
        static_kwargs.update(new_static_kwargs)
        net = object.__new__(Network)
        net._init_fields(name=(name or self.name), static_kwargs=static_kwargs, build_func=self._build_func, build_func_name=self._build_func_name, build_module_src=self._build_module_src)
        net.copy_vars_from(self)
        return net
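
    # Sketch: clone() is typically used to keep a separate long-term snapshot,
    # e.g. of a generator G (names are illustrative):
    #
    #     Gs = G.clone("Gs")   # same build func and static_kwargs, own variables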

    def copy_own_vars_from(self, src_net: "Network") -> None:
        """Copy the values of all variables from the given network, excluding sub-networks."""

        # Source has unknown variables or unknown components => init now.
        if (src_net._var_inits is not None and not src_net._all_inits_known) or src_net._components is None:
            src_net._get_vars()

        # Both networks are inited => copy directly.
        if src_net._var_inits is None and self._var_inits is None:
            names = [name for name in self._get_own_vars().keys() if name in src_net._get_own_vars()]
            tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names}))
            return

        # Read from source.
        if src_net._var_inits is None:
            value_dict = tfutil.run(src_net._get_own_vars())
        else:
            value_dict = src_net._var_inits

        # Write to destination.
        if self._var_inits is None:
            tfutil.set_vars({self._get_vars()[name]: value for name, value in value_dict.items() if name in self._get_vars()})
        else:
            self._var_inits.update(value_dict)

    def copy_vars_from(self, src_net: "Network") -> None:
        """Copy the values of all variables from the given network, including sub-networks."""

        # Source has unknown variables or unknown components => init now.
        if (src_net._var_inits is not None and not src_net._all_inits_known) or src_net._components is None:
            src_net._get_vars()

        # Source is inited, but destination components have not been created yet => set as initial values.
        if src_net._var_inits is None and self._components is None:
            self._var_inits.update(tfutil.run(src_net._get_vars()))
            return

        # Destination has unknown components => init now.
        if self._components is None:
            self._get_vars()

        # Both networks are inited => copy directly.
        if src_net._var_inits is None and self._var_inits is None:
            names = [name for name in self._get_vars().keys() if name in src_net._get_vars()]
            tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names}))
            return

        # Copy recursively, component by component.
        self.copy_own_vars_from(src_net)
        for name, src_comp in src_net._components.items():
            if name in self._components:
                self._components[name].copy_vars_from(src_comp)

    def copy_trainables_from(self, src_net: "Network") -> None:
        """Copy the values of all trainable variables from the given network, including sub-networks."""
        names = [name for name in self._get_trainables().keys() if name in src_net._get_trainables()]
        tfutil.set_vars(tfutil.run({self._get_vars()[name]: src_net._get_vars()[name] for name in names}))

    def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
        """Create new network with the given parameters, and copy all variables from this network."""
        if new_name is None:
            new_name = self.name
        static_kwargs = dict(self.static_kwargs)
        static_kwargs.update(new_static_kwargs)
        net = Network(name=new_name, func_name=new_func_name, **static_kwargs)
        net.copy_vars_from(self)
        return net
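
    # Sketch (the target function name is hypothetical): rebuild the same
    # weights under a different build function, e.g. an inference-only variant:
    #
    #     net_eval = net.convert("my_module.my_net_inference")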

    def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
        """Construct a TensorFlow op that updates the variables of this network
        to be slightly closer to those of the given network."""
        with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
            ops = []
            for name, var in self._get_vars().items():
                if name in src_net._get_vars():
                    cur_beta = beta if var.trainable else beta_nontrainable
                    new_value = tfutil.lerp(src_net._get_vars()[name], var, cur_beta)
                    ops.append(var.assign(new_value))
            return tf.group(*ops)
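
    # Sketch of the usual training-loop pattern for the op above
    # (G/Gs naming is illustrative):
    #
    #     Gs_update_op = Gs.setup_as_moving_average_of(G, beta=0.99)
    #     ...
    #     tfutil.run(Gs_update_op)   # once per training iteration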

    def run(self,
            *in_arrays: Tuple[Union[np.ndarray, None], ...],
            input_transform: dict = None,
            output_transform: dict = None,
            return_as_list: bool = False,
            print_progress: bool = False,
            minibatch_size: int = None,
            num_gpus: int = 1,
            assume_frozen: bool = False,
            **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
        """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).

        Args:
            input_transform:    A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
                                The dict must contain a 'func' field that points to a top-level function. The function is called with the input
                                TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
            output_transform:   A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
                                The dict must contain a 'func' field that points to a top-level function. The function is called with the output
                                TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
            return_as_list:     True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
            print_progress:     Print progress to the console? Useful for very large input arrays.
            minibatch_size:     Maximum minibatch size to use, None = disable batching.
            num_gpus:           Number of GPUs to use.
            assume_frozen:      Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
            dynamic_kwargs:     Additional keyword arguments to be passed into the network build function.
        """
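        # Sketch of a typical call (shapes and kwargs are illustrative):
        #
        #     latents = np.random.randn(1000, *net.input_shape[1:])
        #     images = net.run(latents, minibatch_size=32, print_progress=True)
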
        assert len(in_arrays) == self.num_inputs
        assert not all(arr is None for arr in in_arrays)
        assert input_transform is None or util.is_top_level_function(input_transform["func"])
        assert output_transform is None or util.is_top_level_function(output_transform["func"])
        output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
        num_items = in_arrays[0].shape[0]
        if minibatch_size is None:
            minibatch_size = num_items

        # Construct unique hash key from all arguments that affect the TensorFlow graph.
        key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
        def unwind_key(obj):
            if isinstance(obj, dict):
                return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
            if callable(obj):
                return util.get_top_level_function_name(obj)
            return obj
        key = repr(unwind_key(key))

        # Build graph.
        if key not in self._run_cache:
            with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
                with tf.device("/cpu:0"):
                    in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
                    in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))

                out_split = []
                for gpu in range(num_gpus):
                    with tf.device(self.device if num_gpus == 1 else "/gpu:%d" % gpu):
                        net_gpu = self.clone() if assume_frozen else self
                        in_gpu = in_split[gpu]

                        if input_transform is not None:
                            in_kwargs = dict(input_transform)
                            in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
                            in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)

                        assert len(in_gpu) == self.num_inputs
                        out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)

                        if output_transform is not None:
                            out_kwargs = dict(output_transform)
                            out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
                            out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)

                        assert len(out_gpu) == self.num_outputs
                        out_split.append(out_gpu)

                with tf.device("/cpu:0"):
                    out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
                    self._run_cache[key] = in_expr, out_expr

        # Run minibatches.
        in_expr, out_expr = self._run_cache[key]
        out_arrays = [np.empty([num_items] + expr.shape.as_list()[1:], expr.dtype.name) for expr in out_expr]

        for mb_begin in range(0, num_items, minibatch_size):
            if print_progress:
                print("\r%d / %d" % (mb_begin, num_items), end="")

            mb_end = min(mb_begin + minibatch_size, num_items)
            mb_num = mb_end - mb_begin
            mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
            mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))

            for dst, src in zip(out_arrays, mb_out):
                dst[mb_begin : mb_end] = src

        # Done.
        if print_progress:
            print("\r%d / %d" % (num_items, num_items))

        if not return_as_list:
            out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
        return out_arrays

    def list_ops(self) -> List[TfExpression]:
        _ = self.output_templates  # ensure that the template graph has been created
        include_prefix = self.scope + "/"
        exclude_prefix = include_prefix + "_"
        ops = tf.get_default_graph().get_operations()
        ops = [op for op in ops if op.name.startswith(include_prefix)]
        ops = [op for op in ops if not op.name.startswith(exclude_prefix)]
        return ops

    def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
        """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
        individual layers of the network. Mainly intended to be used for reporting."""
        layers = []

        def recurse(scope, parent_ops, parent_vars, level):
            if len(parent_ops) == 0 and len(parent_vars) == 0:
                return

            # Ignore specific patterns.
            if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
                return

            # Filter ops and vars by scope.
            global_prefix = scope + "/"
            local_prefix = global_prefix[len(self.scope) + 1:]
            cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
            cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
            if not cur_ops and not cur_vars:
                return

            # Filter out all ops related to variables.
            for var in [op for op in cur_ops if op.type.startswith("Variable")]:
                var_prefix = var.name + "/"
                cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]

            # Scope does not contain ops as immediate children => recurse deeper.
            contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type not in ["Identity", "Cast", "Transpose"] for op in cur_ops)
            if (level == 0 or not contains_direct_ops) and (len(cur_ops) != 0 or len(cur_vars) != 0):
                visited = set()
                for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
                    token = rel_name.split("/")[0]
                    if token not in visited:
                        recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
                        visited.add(token)
                return

            # Report layer.
            layer_name = scope[len(self.scope) + 1:]
            layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
            layer_trainables = [var for _name, var in cur_vars if var.trainable]
            layers.append((layer_name, layer_output, layer_trainables))

        recurse(self.scope, self.list_ops(), list(self._get_vars().items()), 0)
        return layers

    def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
        """Print a summary table of the network structure."""
        rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]]
        rows += [["---"] * 4]
        total_params = 0

        for layer_name, layer_output, layer_trainables in self.list_layers():
            num_params = sum(int(np.prod(var.shape.as_list())) for var in layer_trainables)
            weights = [var for var in layer_trainables if var.name.endswith("/weight:0")]
            weights.sort(key=lambda x: len(x.name))
            if len(weights) == 0 and len(layer_trainables) == 1:
                weights = layer_trainables
            total_params += num_params

            if not hide_layers_with_no_params or num_params != 0:
                num_params_str = str(num_params) if num_params > 0 else "-"
                output_shape_str = str(layer_output.shape)
                weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-"
                rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]

        rows += [["---"] * 4]
        rows += [["Total", str(total_params), "", ""]]

        widths = [max(len(cell) for cell in column) for column in zip(*rows)]
        print()
        for row in rows:
            print("  ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
        print()

    def setup_weight_histograms(self, title: str = None) -> None:
        """Construct summary ops to include histograms of all trainable parameters in TensorBoard."""
        if title is None:
            title = self.name

        with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
            for local_name, var in self._get_trainables().items():
                if "/" in local_name:
                    p = local_name.split("/")
                    name = title + "_" + p[-1] + "/" + "_".join(p[:-1])
                else:
                    name = title + "_toplevel/" + local_name

                tf.summary.histogram(name, var)
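
    # Sketch: the histograms land in the standard TF 1.x summary collection,
    # so a typical consumer would do (illustrative):
    #
    #     net.setup_weight_histograms()
    #     merged = tf.summary.merge_all()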

#----------------------------------------------------------------------------
# Backwards-compatible emulation of legacy output transformation in Network.run().

_print_legacy_warning = True

def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
    global _print_legacy_warning
    legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
    if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
        return output_transform, dynamic_kwargs

    if _print_legacy_warning:
        _print_legacy_warning = False
        print()
        print("WARNING: Old-style output transformations in Network.run() are deprecated.")
        print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'")
        print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.")
        print()
    assert output_transform is None

    new_kwargs = dict(dynamic_kwargs)
    new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}
    new_transform["func"] = _legacy_output_transform_func
    return new_transform, new_kwargs

def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
    if out_mul != 1.0:
        expr = [x * out_mul for x in expr]

    if out_add != 0.0:
        expr = [x + out_add for x in expr]

    if out_shrink > 1:
        ksize = [1, 1, out_shrink, out_shrink]
        expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr]

    if out_dtype is not None:
        if tf.as_dtype(out_dtype).is_integer:
            expr = [tf.round(x) for x in expr]
        expr = [tf.saturate_cast(x, out_dtype) for x in expr]
    return expr
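
#----------------------------------------------------------------------------
# Worked example (a sketch of what the emulation above does): the legacy call
#
#     net.run(latents, out_mul=127.5, out_add=127.5, out_dtype=np.uint8)
#
# is rewritten by _handle_legacy_output_transforms() into the equivalent
#
#     net.run(latents, output_transform=dict(
#         func=_legacy_output_transform_func,
#         out_mul=127.5, out_add=127.5, out_dtype=np.uint8))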