Spaces:
Runtime error
Runtime error
Upload tfutil.py
Browse files- dnnlib/tflib/tfutil.py +262 -0
dnnlib/tflib/tfutil.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Miscellaneous helper utils for Tensorflow."""
|
10 |
+
|
11 |
+
import os
|
12 |
+
import numpy as np
|
13 |
+
import tensorflow as tf
|
14 |
+
|
15 |
+
# Silence deprecation warnings from TensorFlow 1.13 onwards
|
16 |
+
import logging
|
17 |
+
logging.getLogger('tensorflow').setLevel(logging.ERROR)
|
18 |
+
import tensorflow.contrib # requires TensorFlow 1.x!
|
19 |
+
tf.contrib = tensorflow.contrib
|
20 |
+
|
21 |
+
from typing import Any, Iterable, List, Union
|
22 |
+
|
23 |
+
TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
|
24 |
+
"""A type that represents a valid Tensorflow expression."""
|
25 |
+
|
26 |
+
TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
|
27 |
+
"""A type that can be converted to a valid Tensorflow expression."""
|
28 |
+
|
29 |
+
|
30 |
+
def run(*args, **kwargs) -> Any:
|
31 |
+
"""Run the specified ops in the default session."""
|
32 |
+
assert_tf_initialized()
|
33 |
+
return tf.get_default_session().run(*args, **kwargs)
|
34 |
+
|
35 |
+
|
36 |
+
def is_tf_expression(x: Any) -> bool:
|
37 |
+
"""Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation."""
|
38 |
+
return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))
|
39 |
+
|
40 |
+
|
41 |
+
def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
|
42 |
+
"""Convert a Tensorflow shape to a list of ints. Retained for backwards compatibility -- use TensorShape.as_list() in new code."""
|
43 |
+
return [dim.value for dim in shape]
|
44 |
+
|
45 |
+
|
46 |
+
def flatten(x: TfExpressionEx) -> TfExpression:
|
47 |
+
"""Shortcut function for flattening a tensor."""
|
48 |
+
with tf.name_scope("Flatten"):
|
49 |
+
return tf.reshape(x, [-1])
|
50 |
+
|
51 |
+
|
52 |
+
def log2(x: TfExpressionEx) -> TfExpression:
|
53 |
+
"""Logarithm in base 2."""
|
54 |
+
with tf.name_scope("Log2"):
|
55 |
+
return tf.log(x) * np.float32(1.0 / np.log(2.0))
|
56 |
+
|
57 |
+
|
58 |
+
def exp2(x: TfExpressionEx) -> TfExpression:
|
59 |
+
"""Exponent in base 2."""
|
60 |
+
with tf.name_scope("Exp2"):
|
61 |
+
return tf.exp(x * np.float32(np.log(2.0)))
|
62 |
+
|
63 |
+
|
64 |
+
def erfinv(y: TfExpressionEx) -> TfExpression:
|
65 |
+
"""Inverse of the error function."""
|
66 |
+
# pylint: disable=no-name-in-module
|
67 |
+
from tensorflow.python.ops.distributions import special_math
|
68 |
+
return special_math.erfinv(y)
|
69 |
+
|
70 |
+
|
71 |
+
def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
|
72 |
+
"""Linear interpolation."""
|
73 |
+
with tf.name_scope("Lerp"):
|
74 |
+
return a + (b - a) * t
|
75 |
+
|
76 |
+
|
77 |
+
def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
|
78 |
+
"""Linear interpolation with clip."""
|
79 |
+
with tf.name_scope("LerpClip"):
|
80 |
+
return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
|
81 |
+
|
82 |
+
|
83 |
+
def absolute_name_scope(scope: str) -> tf.name_scope:
|
84 |
+
"""Forcefully enter the specified name scope, ignoring any surrounding scopes."""
|
85 |
+
return tf.name_scope(scope + "/")
|
86 |
+
|
87 |
+
|
88 |
+
def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
|
89 |
+
"""Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
|
90 |
+
return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)
|
91 |
+
|
92 |
+
|
93 |
+
def _sanitize_tf_config(config_dict: dict = None) -> dict:
|
94 |
+
# Defaults.
|
95 |
+
cfg = dict()
|
96 |
+
cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
|
97 |
+
cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
|
98 |
+
cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
|
99 |
+
cfg["env.HDF5_USE_FILE_LOCKING"] = "FALSE" # Disable HDF5 file locking to avoid concurrency issues with network shares.
|
100 |
+
cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
|
101 |
+
cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
|
102 |
+
|
103 |
+
# Remove defaults for environment variables that are already set.
|
104 |
+
for key in list(cfg):
|
105 |
+
fields = key.split(".")
|
106 |
+
if fields[0] == "env":
|
107 |
+
assert len(fields) == 2
|
108 |
+
if fields[1] in os.environ:
|
109 |
+
del cfg[key]
|
110 |
+
|
111 |
+
# User overrides.
|
112 |
+
if config_dict is not None:
|
113 |
+
cfg.update(config_dict)
|
114 |
+
return cfg
|
115 |
+
|
116 |
+
|
117 |
+
def init_tf(config_dict: dict = None) -> None:
|
118 |
+
"""Initialize TensorFlow session using good default settings."""
|
119 |
+
# Skip if already initialized.
|
120 |
+
if tf.get_default_session() is not None:
|
121 |
+
return
|
122 |
+
|
123 |
+
# Setup config dict and random seeds.
|
124 |
+
cfg = _sanitize_tf_config(config_dict)
|
125 |
+
np_random_seed = cfg["rnd.np_random_seed"]
|
126 |
+
if np_random_seed is not None:
|
127 |
+
np.random.seed(np_random_seed)
|
128 |
+
tf_random_seed = cfg["rnd.tf_random_seed"]
|
129 |
+
if tf_random_seed == "auto":
|
130 |
+
tf_random_seed = np.random.randint(1 << 31)
|
131 |
+
if tf_random_seed is not None:
|
132 |
+
tf.set_random_seed(tf_random_seed)
|
133 |
+
|
134 |
+
# Setup environment variables.
|
135 |
+
for key, value in cfg.items():
|
136 |
+
fields = key.split(".")
|
137 |
+
if fields[0] == "env":
|
138 |
+
assert len(fields) == 2
|
139 |
+
os.environ[fields[1]] = str(value)
|
140 |
+
|
141 |
+
# Create default TensorFlow session.
|
142 |
+
create_session(cfg, force_as_default=True)
|
143 |
+
|
144 |
+
|
145 |
+
def assert_tf_initialized():
|
146 |
+
"""Check that TensorFlow session has been initialized."""
|
147 |
+
if tf.get_default_session() is None:
|
148 |
+
raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
|
149 |
+
|
150 |
+
|
151 |
+
def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
|
152 |
+
"""Create tf.Session based on config dict."""
|
153 |
+
# Setup TensorFlow config proto.
|
154 |
+
cfg = _sanitize_tf_config(config_dict)
|
155 |
+
config_proto = tf.ConfigProto()
|
156 |
+
for key, value in cfg.items():
|
157 |
+
fields = key.split(".")
|
158 |
+
if fields[0] not in ["rnd", "env"]:
|
159 |
+
obj = config_proto
|
160 |
+
for field in fields[:-1]:
|
161 |
+
obj = getattr(obj, field)
|
162 |
+
setattr(obj, fields[-1], value)
|
163 |
+
|
164 |
+
# Create session.
|
165 |
+
session = tf.Session(config=config_proto)
|
166 |
+
if force_as_default:
|
167 |
+
# pylint: disable=protected-access
|
168 |
+
session._default_session = session.as_default()
|
169 |
+
session._default_session.enforce_nesting = False
|
170 |
+
session._default_session.__enter__()
|
171 |
+
return session
|
172 |
+
|
173 |
+
|
174 |
+
def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
|
175 |
+
"""Initialize all tf.Variables that have not already been initialized.
|
176 |
+
|
177 |
+
Equivalent to the following, but more efficient and does not bloat the tf graph:
|
178 |
+
tf.variables_initializer(tf.report_uninitialized_variables()).run()
|
179 |
+
"""
|
180 |
+
assert_tf_initialized()
|
181 |
+
if target_vars is None:
|
182 |
+
target_vars = tf.global_variables()
|
183 |
+
|
184 |
+
test_vars = []
|
185 |
+
test_ops = []
|
186 |
+
|
187 |
+
with tf.control_dependencies(None): # ignore surrounding control_dependencies
|
188 |
+
for var in target_vars:
|
189 |
+
assert is_tf_expression(var)
|
190 |
+
|
191 |
+
try:
|
192 |
+
tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
|
193 |
+
except KeyError:
|
194 |
+
# Op does not exist => variable may be uninitialized.
|
195 |
+
test_vars.append(var)
|
196 |
+
|
197 |
+
with absolute_name_scope(var.name.split(":")[0]):
|
198 |
+
test_ops.append(tf.is_variable_initialized(var))
|
199 |
+
|
200 |
+
init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
|
201 |
+
run([var.initializer for var in init_vars])
|
202 |
+
|
203 |
+
|
204 |
+
def set_vars(var_to_value_dict: dict) -> None:
|
205 |
+
"""Set the values of given tf.Variables.
|
206 |
+
|
207 |
+
Equivalent to the following, but more efficient and does not bloat the tf graph:
|
208 |
+
tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()]
|
209 |
+
"""
|
210 |
+
assert_tf_initialized()
|
211 |
+
ops = []
|
212 |
+
feed_dict = {}
|
213 |
+
|
214 |
+
for var, value in var_to_value_dict.items():
|
215 |
+
assert is_tf_expression(var)
|
216 |
+
|
217 |
+
try:
|
218 |
+
setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op
|
219 |
+
except KeyError:
|
220 |
+
with absolute_name_scope(var.name.split(":")[0]):
|
221 |
+
with tf.control_dependencies(None): # ignore surrounding control_dependencies
|
222 |
+
setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter
|
223 |
+
|
224 |
+
ops.append(setter)
|
225 |
+
feed_dict[setter.op.inputs[1]] = value
|
226 |
+
|
227 |
+
run(ops, feed_dict)
|
228 |
+
|
229 |
+
|
230 |
+
def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
|
231 |
+
"""Create tf.Variable with large initial value without bloating the tf graph."""
|
232 |
+
assert_tf_initialized()
|
233 |
+
assert isinstance(initial_value, np.ndarray)
|
234 |
+
zeros = tf.zeros(initial_value.shape, initial_value.dtype)
|
235 |
+
var = tf.Variable(zeros, *args, **kwargs)
|
236 |
+
set_vars({var: initial_value})
|
237 |
+
return var
|
238 |
+
|
239 |
+
|
240 |
+
def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
|
241 |
+
"""Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
|
242 |
+
Can be used as an input transformation for Network.run().
|
243 |
+
"""
|
244 |
+
images = tf.cast(images, tf.float32)
|
245 |
+
if nhwc_to_nchw:
|
246 |
+
images = tf.transpose(images, [0, 3, 1, 2])
|
247 |
+
return images * ((drange[1] - drange[0]) / 255) + drange[0]
|
248 |
+
|
249 |
+
|
250 |
+
def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
|
251 |
+
"""Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
|
252 |
+
Can be used as an output transformation for Network.run().
|
253 |
+
"""
|
254 |
+
images = tf.cast(images, tf.float32)
|
255 |
+
if shrink > 1:
|
256 |
+
ksize = [1, 1, shrink, shrink]
|
257 |
+
images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
|
258 |
+
if nchw_to_nhwc:
|
259 |
+
images = tf.transpose(images, [0, 2, 3, 1])
|
260 |
+
scale = 255 / (drange[1] - drange[0])
|
261 |
+
images = images * scale + (0.5 - drange[0] * scale)
|
262 |
+
return tf.saturate_cast(images, tf.uint8)
|