i72sijia commited on
Commit
51794ac
1 Parent(s): 664be34

Upload tfutil.py

Browse files
Files changed (1) hide show
  1. dnnlib/tflib/tfutil.py +262 -0
dnnlib/tflib/tfutil.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Miscellaneous helper utils for Tensorflow."""
10
+
11
+ import os
12
+ import numpy as np
13
+ import tensorflow as tf
14
+
15
+ # Silence deprecation warnings from TensorFlow 1.13 onwards
16
+ import logging
17
+ logging.getLogger('tensorflow').setLevel(logging.ERROR)
18
+ import tensorflow.contrib # requires TensorFlow 1.x!
19
+ tf.contrib = tensorflow.contrib
20
+
21
+ from typing import Any, Iterable, List, Union
22
+
23
+ TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
24
+ """A type that represents a valid Tensorflow expression."""
25
+
26
+ TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
27
+ """A type that can be converted to a valid Tensorflow expression."""
28
+
29
+
30
+ def run(*args, **kwargs) -> Any:
31
+ """Run the specified ops in the default session."""
32
+ assert_tf_initialized()
33
+ return tf.get_default_session().run(*args, **kwargs)
34
+
35
+
36
+ def is_tf_expression(x: Any) -> bool:
37
+ """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation."""
38
+ return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))
39
+
40
+
41
+ def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
42
+ """Convert a Tensorflow shape to a list of ints. Retained for backwards compatibility -- use TensorShape.as_list() in new code."""
43
+ return [dim.value for dim in shape]
44
+
45
+
46
+ def flatten(x: TfExpressionEx) -> TfExpression:
47
+ """Shortcut function for flattening a tensor."""
48
+ with tf.name_scope("Flatten"):
49
+ return tf.reshape(x, [-1])
50
+
51
+
52
+ def log2(x: TfExpressionEx) -> TfExpression:
53
+ """Logarithm in base 2."""
54
+ with tf.name_scope("Log2"):
55
+ return tf.log(x) * np.float32(1.0 / np.log(2.0))
56
+
57
+
58
+ def exp2(x: TfExpressionEx) -> TfExpression:
59
+ """Exponent in base 2."""
60
+ with tf.name_scope("Exp2"):
61
+ return tf.exp(x * np.float32(np.log(2.0)))
62
+
63
+
64
+ def erfinv(y: TfExpressionEx) -> TfExpression:
65
+ """Inverse of the error function."""
66
+ # pylint: disable=no-name-in-module
67
+ from tensorflow.python.ops.distributions import special_math
68
+ return special_math.erfinv(y)
69
+
70
+
71
+ def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
72
+ """Linear interpolation."""
73
+ with tf.name_scope("Lerp"):
74
+ return a + (b - a) * t
75
+
76
+
77
+ def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
78
+ """Linear interpolation with clip."""
79
+ with tf.name_scope("LerpClip"):
80
+ return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
81
+
82
+
83
+ def absolute_name_scope(scope: str) -> tf.name_scope:
84
+ """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
85
+ return tf.name_scope(scope + "/")
86
+
87
+
88
+ def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
89
+ """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
90
+ return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)
91
+
92
+
93
+ def _sanitize_tf_config(config_dict: dict = None) -> dict:
94
+ # Defaults.
95
+ cfg = dict()
96
+ cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
97
+ cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
98
+ cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
99
+ cfg["env.HDF5_USE_FILE_LOCKING"] = "FALSE" # Disable HDF5 file locking to avoid concurrency issues with network shares.
100
+ cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
101
+ cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
102
+
103
+ # Remove defaults for environment variables that are already set.
104
+ for key in list(cfg):
105
+ fields = key.split(".")
106
+ if fields[0] == "env":
107
+ assert len(fields) == 2
108
+ if fields[1] in os.environ:
109
+ del cfg[key]
110
+
111
+ # User overrides.
112
+ if config_dict is not None:
113
+ cfg.update(config_dict)
114
+ return cfg
115
+
116
+
117
+ def init_tf(config_dict: dict = None) -> None:
118
+ """Initialize TensorFlow session using good default settings."""
119
+ # Skip if already initialized.
120
+ if tf.get_default_session() is not None:
121
+ return
122
+
123
+ # Setup config dict and random seeds.
124
+ cfg = _sanitize_tf_config(config_dict)
125
+ np_random_seed = cfg["rnd.np_random_seed"]
126
+ if np_random_seed is not None:
127
+ np.random.seed(np_random_seed)
128
+ tf_random_seed = cfg["rnd.tf_random_seed"]
129
+ if tf_random_seed == "auto":
130
+ tf_random_seed = np.random.randint(1 << 31)
131
+ if tf_random_seed is not None:
132
+ tf.set_random_seed(tf_random_seed)
133
+
134
+ # Setup environment variables.
135
+ for key, value in cfg.items():
136
+ fields = key.split(".")
137
+ if fields[0] == "env":
138
+ assert len(fields) == 2
139
+ os.environ[fields[1]] = str(value)
140
+
141
+ # Create default TensorFlow session.
142
+ create_session(cfg, force_as_default=True)
143
+
144
+
145
+ def assert_tf_initialized():
146
+ """Check that TensorFlow session has been initialized."""
147
+ if tf.get_default_session() is None:
148
+ raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
149
+
150
+
151
+ def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
152
+ """Create tf.Session based on config dict."""
153
+ # Setup TensorFlow config proto.
154
+ cfg = _sanitize_tf_config(config_dict)
155
+ config_proto = tf.ConfigProto()
156
+ for key, value in cfg.items():
157
+ fields = key.split(".")
158
+ if fields[0] not in ["rnd", "env"]:
159
+ obj = config_proto
160
+ for field in fields[:-1]:
161
+ obj = getattr(obj, field)
162
+ setattr(obj, fields[-1], value)
163
+
164
+ # Create session.
165
+ session = tf.Session(config=config_proto)
166
+ if force_as_default:
167
+ # pylint: disable=protected-access
168
+ session._default_session = session.as_default()
169
+ session._default_session.enforce_nesting = False
170
+ session._default_session.__enter__()
171
+ return session
172
+
173
+
174
+ def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
175
+ """Initialize all tf.Variables that have not already been initialized.
176
+
177
+ Equivalent to the following, but more efficient and does not bloat the tf graph:
178
+ tf.variables_initializer(tf.report_uninitialized_variables()).run()
179
+ """
180
+ assert_tf_initialized()
181
+ if target_vars is None:
182
+ target_vars = tf.global_variables()
183
+
184
+ test_vars = []
185
+ test_ops = []
186
+
187
+ with tf.control_dependencies(None): # ignore surrounding control_dependencies
188
+ for var in target_vars:
189
+ assert is_tf_expression(var)
190
+
191
+ try:
192
+ tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
193
+ except KeyError:
194
+ # Op does not exist => variable may be uninitialized.
195
+ test_vars.append(var)
196
+
197
+ with absolute_name_scope(var.name.split(":")[0]):
198
+ test_ops.append(tf.is_variable_initialized(var))
199
+
200
+ init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
201
+ run([var.initializer for var in init_vars])
202
+
203
+
204
+ def set_vars(var_to_value_dict: dict) -> None:
205
+ """Set the values of given tf.Variables.
206
+
207
+ Equivalent to the following, but more efficient and does not bloat the tf graph:
208
+ tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()]
209
+ """
210
+ assert_tf_initialized()
211
+ ops = []
212
+ feed_dict = {}
213
+
214
+ for var, value in var_to_value_dict.items():
215
+ assert is_tf_expression(var)
216
+
217
+ try:
218
+ setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op
219
+ except KeyError:
220
+ with absolute_name_scope(var.name.split(":")[0]):
221
+ with tf.control_dependencies(None): # ignore surrounding control_dependencies
222
+ setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter
223
+
224
+ ops.append(setter)
225
+ feed_dict[setter.op.inputs[1]] = value
226
+
227
+ run(ops, feed_dict)
228
+
229
+
230
+ def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
231
+ """Create tf.Variable with large initial value without bloating the tf graph."""
232
+ assert_tf_initialized()
233
+ assert isinstance(initial_value, np.ndarray)
234
+ zeros = tf.zeros(initial_value.shape, initial_value.dtype)
235
+ var = tf.Variable(zeros, *args, **kwargs)
236
+ set_vars({var: initial_value})
237
+ return var
238
+
239
+
240
+ def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
241
+ """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
242
+ Can be used as an input transformation for Network.run().
243
+ """
244
+ images = tf.cast(images, tf.float32)
245
+ if nhwc_to_nchw:
246
+ images = tf.transpose(images, [0, 3, 1, 2])
247
+ return images * ((drange[1] - drange[0]) / 255) + drange[0]
248
+
249
+
250
+ def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
251
+ """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
252
+ Can be used as an output transformation for Network.run().
253
+ """
254
+ images = tf.cast(images, tf.float32)
255
+ if shrink > 1:
256
+ ksize = [1, 1, shrink, shrink]
257
+ images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
258
+ if nchw_to_nhwc:
259
+ images = tf.transpose(images, [0, 2, 3, 1])
260
+ scale = 255 / (drange[1] - drange[0])
261
+ images = images * scale + (0.5 - drange[0] * scale)
262
+ return tf.saturate_cast(images, tf.uint8)