echen01 commited on
Commit
0513aaf
1 Parent(s): 3eee23c

add app template

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +34 -0
  2. dnnlib/__init__.py +9 -0
  3. dnnlib/__pycache__/__init__.cpython-36.pyc +0 -0
  4. dnnlib/__pycache__/__init__.cpython-38.pyc +0 -0
  5. dnnlib/__pycache__/__init__.cpython-39.pyc +0 -0
  6. dnnlib/__pycache__/util.cpython-36.pyc +0 -0
  7. dnnlib/__pycache__/util.cpython-38.pyc +0 -0
  8. dnnlib/__pycache__/util.cpython-39.pyc +0 -0
  9. dnnlib/util.py +477 -0
  10. legacy.py +408 -0
  11. torch_utils/__init__.py +9 -0
  12. torch_utils/__pycache__/__init__.cpython-36.pyc +0 -0
  13. torch_utils/__pycache__/__init__.cpython-38.pyc +0 -0
  14. torch_utils/__pycache__/__init__.cpython-39.pyc +0 -0
  15. torch_utils/__pycache__/custom_ops.cpython-36.pyc +0 -0
  16. torch_utils/__pycache__/custom_ops.cpython-38.pyc +0 -0
  17. torch_utils/__pycache__/custom_ops.cpython-39.pyc +0 -0
  18. torch_utils/__pycache__/misc.cpython-36.pyc +0 -0
  19. torch_utils/__pycache__/misc.cpython-38.pyc +0 -0
  20. torch_utils/__pycache__/misc.cpython-39.pyc +0 -0
  21. torch_utils/__pycache__/persistence.cpython-36.pyc +0 -0
  22. torch_utils/__pycache__/persistence.cpython-38.pyc +0 -0
  23. torch_utils/__pycache__/persistence.cpython-39.pyc +0 -0
  24. torch_utils/custom_ops.py +126 -0
  25. torch_utils/misc.py +332 -0
  26. torch_utils/ops/__init__.py +9 -0
  27. torch_utils/ops/__pycache__/__init__.cpython-36.pyc +0 -0
  28. torch_utils/ops/__pycache__/__init__.cpython-38.pyc +0 -0
  29. torch_utils/ops/__pycache__/__init__.cpython-39.pyc +0 -0
  30. torch_utils/ops/__pycache__/bias_act.cpython-36.pyc +0 -0
  31. torch_utils/ops/__pycache__/bias_act.cpython-38.pyc +0 -0
  32. torch_utils/ops/__pycache__/bias_act.cpython-39.pyc +0 -0
  33. torch_utils/ops/__pycache__/conv2d_gradfix.cpython-36.pyc +0 -0
  34. torch_utils/ops/__pycache__/conv2d_gradfix.cpython-38.pyc +0 -0
  35. torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc +0 -0
  36. torch_utils/ops/__pycache__/conv2d_resample.cpython-36.pyc +0 -0
  37. torch_utils/ops/__pycache__/conv2d_resample.cpython-38.pyc +0 -0
  38. torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc +0 -0
  39. torch_utils/ops/__pycache__/fma.cpython-36.pyc +0 -0
  40. torch_utils/ops/__pycache__/fma.cpython-38.pyc +0 -0
  41. torch_utils/ops/__pycache__/fma.cpython-39.pyc +0 -0
  42. torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-36.pyc +0 -0
  43. torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-38.pyc +0 -0
  44. torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc +0 -0
  45. torch_utils/ops/__pycache__/upfirdn2d.cpython-36.pyc +0 -0
  46. torch_utils/ops/__pycache__/upfirdn2d.cpython-38.pyc +0 -0
  47. torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc +0 -0
  48. torch_utils/ops/bias_act.cpp +99 -0
  49. torch_utils/ops/bias_act.cu +173 -0
  50. torch_utils/ops/bias_act.h +38 -0
app.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import utils
from PIL import Image
import torch
import math
from torchvision import transforms

# Everything runs on CPU; no CUDA is assumed anywhere below.
device = "cpu"

# One pretrained generator per decade, 1880 .. 2010 inclusive.
years = [str(y) for y in range(1880, 2020, 10)]

# Map each decade to its StyleGAN2 generator, kept in eval mode.
orig_models = {}
for year in years:
    G, w_avg = utils.load_stylegan2(f"pretrained_models/{year}.pkl", device)
    orig_models[year] = {"G": G.eval()}

# Preprocessing for incoming images: resize, to-tensor, scale to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])


# Download human-readable labels for ImageNet.
def predict(inp):
    """Placeholder inference function: echoes the input image unchanged."""
    # with torch.no_grad():
    return inp


gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    # examples=["lion.jpg", "cheetah.jpg"]
).launch()
dnnlib/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from .util import EasyDict, make_cache_dir_path
dnnlib/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (218 Bytes). View file
 
dnnlib/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (226 Bytes). View file
 
dnnlib/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (226 Bytes). View file
 
dnnlib/__pycache__/util.cpython-36.pyc ADDED
Binary file (13.6 kB). View file
 
dnnlib/__pycache__/util.cpython-38.pyc ADDED
Binary file (13.7 kB). View file
 
dnnlib/__pycache__/util.cpython-39.pyc ADDED
Binary file (13.8 kB). View file
 
dnnlib/util.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Miscellaneous utility classes and functions."""
10
+
11
+ import ctypes
12
+ import fnmatch
13
+ import importlib
14
+ import inspect
15
+ import numpy as np
16
+ import os
17
+ import shutil
18
+ import sys
19
+ import types
20
+ import io
21
+ import pickle
22
+ import re
23
+ import requests
24
+ import html
25
+ import hashlib
26
+ import glob
27
+ import tempfile
28
+ import urllib
29
+ import urllib.request
30
+ import uuid
31
+
32
+ from distutils.util import strtobool
33
+ from typing import Any, List, Tuple, Union
34
+
35
+
36
+ # Util classes
37
+ # ------------------------------------------------------------------------------------------
38
+
39
+
40
class EasyDict(dict):
    """dict subclass whose keys can also be read/written as attributes."""

    def __getattr__(self, name: str) -> Any:
        # A missing key must surface as AttributeError (not KeyError) so that
        # hasattr()/getattr() behave as expected.
        if name in self:
            return self[name]
        raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]
54
+
55
+
56
class Logger(object):
    """Mirror stderr into stdout and optionally tee all output to a file.

    While an instance is alive it installs itself as both sys.stdout and
    sys.stderr; close() restores the original streams.
    """

    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
        self.file = None
        if file_name is not None:
            self.file = open(file_name, file_mode)
        self.should_flush = should_flush
        self.stdout = sys.stdout
        self.stderr = sys.stderr
        # Install the tee: from now on all prints funnel through write().
        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: Union[str, bytes]) -> None:
        """Write text to stdout (and the file, if any) and optionally flush."""
        if isinstance(text, bytes):
            text = text.decode()
        # Workaround for a VSCode debugger bug: writing '' and flushing crashes.
        if not text:
            return
        if self.file is not None:
            self.file.write(text)
        self.stdout.write(text)
        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush the real stdout and the log file, if one is open."""
        if self.file is not None:
            self.file.flush()
        self.stdout.flush()

    def close(self) -> None:
        """Flush, restore sys.stdout/sys.stderr, and close the log file."""
        self.flush()
        # Restore only if we are still the installed stream (supports nesting
        # multiple Logger instances).
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr
        if self.file is not None:
            self.file.close()
            self.file = None
113
+
114
+
115
+ # Cache directories
116
+ # ------------------------------------------------------------------------------------------
117
+
118
# Process-wide override for the dnnlib cache root; None means "auto-detect"
# inside make_cache_dir_path().
_dnnlib_cache_dir = None


def set_cache_dir(path: str) -> None:
    """Override the cache root used by make_cache_dir_path()."""
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path
123
+
124
def make_cache_dir_path(*paths: str) -> str:
    """Return a path under the dnnlib cache root, joining *paths onto it.

    The root is resolved in priority order: the explicit set_cache_dir()
    override, the DNNLIB_CACHE_DIR environment variable, ~/.cache/dnnlib
    (via HOME or USERPROFILE), and finally the system temp directory.
    """
    if _dnnlib_cache_dir is not None:
        root = _dnnlib_cache_dir
    elif 'DNNLIB_CACHE_DIR' in os.environ:
        root = os.environ['DNNLIB_CACHE_DIR']
    elif 'HOME' in os.environ:
        root = os.path.join(os.environ['HOME'], '.cache', 'dnnlib')
    elif 'USERPROFILE' in os.environ:
        root = os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib')
    else:
        root = os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib')
    return os.path.join(root, *paths)
134
+
135
+ # Small util functions
136
+ # ------------------------------------------------------------------------------------------
137
+
138
+
139
def format_time(seconds: Union[int, float]) -> str:
    """Render a duration as a short human-readable string (s/m/h/d units)."""
    total = int(np.rint(seconds))
    minute, hour, day = 60, 60 * 60, 24 * 60 * 60
    if total < minute:
        return f"{total}s"
    if total < hour:
        return f"{total // minute}m {total % minute:02}s"
    if total < day:
        return f"{total // hour}h {(total // minute) % 60:02}m {total % minute:02}s"
    return f"{total // day}d {(total // hour) % 24:02}h {(total // minute) % 60:02}m"
151
+
152
+
153
# Answers accepted by ask_yes_no(), mirroring the semantics of the former
# distutils.util.strtobool() dependency.
_YES_NO_ANSWERS = {
    "y": True, "yes": True, "t": True, "true": True, "on": True, "1": True,
    "n": False, "no": False, "f": False, "false": False, "off": False, "0": False,
}


def ask_yes_no(question: str) -> bool:
    """Ask the user the question until the user inputs a valid answer.

    The former distutils.util.strtobool() call is replaced by an equivalent
    lookup table: distutils was deprecated by PEP 632 and removed in
    Python 3.12, so this keeps the function working on modern interpreters.
    """
    while True:
        print("{0} [y/n]".format(question))
        answer = input().lower()
        if answer in _YES_NO_ANSWERS:
            return _YES_NO_ANSWERS[answer]
        # Unrecognized answer -- ask again (same retry loop as before).
161
+
162
+
163
def tuple_product(t: Tuple) -> Any:
    """Return the product of all elements of t (1 for an empty tuple)."""
    result = 1
    for element in t:
        result = result * element
    return result
171
+
172
+
173
# Name of every supported numeric type -> its ctypes counterpart.
_str_to_ctype = {
    "uint8": ctypes.c_ubyte,
    "uint16": ctypes.c_uint16,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
    "int8": ctypes.c_byte,
    "int16": ctypes.c_int16,
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "float32": ctypes.c_float,
    "float64": ctypes.c_double,
}


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Resolve type_obj to a (numpy dtype, ctypes type) pair of equal byte size.

    type_obj may be a type-name string or any object exposing __name__ or
    name; anything else raises RuntimeError.
    """
    if isinstance(type_obj, str):
        type_str = type_obj
    elif hasattr(type_obj, "__name__"):
        type_str = type_obj.__name__
    elif hasattr(type_obj, "name"):
        type_str = type_obj.name
    else:
        raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype.keys()
    my_dtype = np.dtype(type_str)
    my_ctype = _str_to_ctype[type_str]
    # Sanity check: numpy and ctypes must agree on the element size.
    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
    return my_dtype, my_ctype
208
+
209
+
210
def is_pickleable(obj: Any) -> bool:
    """Return True iff obj survives pickle.dump into an in-memory buffer.

    The former bare `except:` also swallowed KeyboardInterrupt/SystemExit;
    narrowed to Exception so those still propagate.
    """
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except Exception:
        return False
217
+
218
+
219
+ # Functionality to import modules/objects by name, and call functions by name
220
+ # ------------------------------------------------------------------------------------------
221
+
222
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Search for the module that defines the named python object.

    Returns (module, local_obj_name), where local_obj_name is obj_name with
    the module prefix removed.
    """
    # Expand convenience shorthands to their full package names.
    obj_name = re.sub("^np.", "numpy.", obj_name)
    obj_name = re.sub("^tf.", "tensorflow.", obj_name)

    # All possible (module_name, local_obj_name) splits, longest module first.
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # Accept the first split whose module imports and exposes the attribute.
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name)  # may raise ImportError
            get_obj_from_module(module, local_obj_name)  # may raise AttributeError
            return module, local_obj_name
        except:
            pass

    # Nothing worked. Maybe one of the modules itself fails to import --
    # surface that module's own error instead of a generic one.
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name)  # may raise ImportError
        except ImportError:
            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
                raise

    # Maybe the requested attribute is missing? Re-run the lookups so the
    # AttributeError propagates with full context.
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name)  # may raise ImportError
            get_obj_from_module(module, local_obj_name)  # may raise AttributeError
        except ImportError:
            pass

    # We are out of luck, but we have no idea why.
    raise ImportError(obj_name)
261
+
262
+
263
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Walk the dotted obj_name from module and return the final attribute.

    An empty obj_name returns the module itself.
    """
    if not obj_name:
        return module
    target = module
    for attr in obj_name.split("."):
        target = getattr(target, attr)
    return target
271
+
272
+
273
def get_obj_by_name(name: str) -> Any:
    """Resolve a fully-qualified name to the python object it denotes."""
    module, local_name = get_module_from_obj_name(name)
    return get_obj_from_module(module, local_name)
277
+
278
+
279
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Resolve func_name to a callable and invoke it with the given arguments."""
    assert func_name is not None
    func = get_obj_by_name(func_name)
    assert callable(func)
    return func(*args, **kwargs)
285
+
286
+
287
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Instantiate the named class with the given constructor arguments."""
    return call_func_by_name(*args, func_name=class_name, **kwargs)
290
+
291
+
292
def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Return the directory containing the module that defines obj_name."""
    module, _ = get_module_from_obj_name(obj_name)
    return os.path.dirname(inspect.getfile(module))
296
+
297
+
298
def is_top_level_function(obj: Any) -> bool:
    """True iff obj is a callable defined with 'def' at module scope."""
    if not callable(obj):
        return False
    # A top-level function's name appears in its defining module's namespace.
    return obj.__name__ in sys.modules[obj.__module__].__dict__
301
+
302
+
303
def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified 'module.func' name of a top-level function."""
    assert is_top_level_function(obj)
    module = obj.__module__
    if module == '__main__':
        # Functions defined in a directly-run script report '__main__';
        # recover the script's file stem as the module name instead.
        module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
    return module + "." + obj.__name__
310
+
311
+
312
+ # File system helpers
313
+ # ------------------------------------------------------------------------------------------
314
+
315
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """Recursively list files under dir_path, skipping fnmatch-style patterns.

    Returns (absolute_path, relative_path) tuples; with add_base_to_relative
    the relative paths are prefixed with the directory's base name.
    """
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))
    patterns = ignores if ignores is not None else []
    result = []

    for root, dirs, files in os.walk(dir_path, topdown=True):
        for pattern in patterns:
            # Prune ignored directories in-place so os.walk skips them.
            for ignored_dir in [d for d in dirs if fnmatch.fnmatch(d, pattern)]:
                dirs.remove(ignored_dir)
            # Drop ignored files from this directory's listing.
            files = [f for f in files if not fnmatch.fnmatch(f, pattern)]

        absolute_paths = [os.path.join(root, f) for f in files]
        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
        if add_base_to_relative:
            relative_paths = [os.path.join(base_name, p) for p in relative_paths]

        assert len(absolute_paths) == len(relative_paths)
        result += zip(absolute_paths, relative_paths)

    return result
346
+
347
+
348
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Copy each (src, dst) pair, creating destination directories as needed."""
    for src, dst in files:
        target_dir = os.path.dirname(dst)
        # makedirs creates all intermediate-level directories in one call.
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)
        shutil.copyfile(src, dst)
359
+
360
+
361
+ # URL helpers
362
+ # ------------------------------------------------------------------------------------------
363
+
364
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string.

    Uses urllib.parse directly for this purely syntactic check --
    requests.compat.urlparse/urljoin are mere aliases for urllib.parse, so
    this removes a hard third-party dependency without changing behavior.
    The former bare `except:` is narrowed to Exception.
    """
    import urllib.parse  # local import keeps the function self-contained

    if not isinstance(obj, str) or "://" not in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        # Both the URL itself and its root must carry a scheme and a dotted host.
        for candidate in (obj, urllib.parse.urljoin(obj, "/")):
            res = urllib.parse.urlparse(candidate)
            if not res.scheme or not res.netloc or "." not in res.netloc:
                return False
    except Exception:
        return False
    return True
380
+
381
+
382
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object for the data.

    Local paths and file:// URLs are passed through without network access.
    Downloads are retried up to num_attempts times and optionally cached on
    disk; return_filename=True returns the cached file's path instead of a
    file object (requires cache=True).
    """
    assert num_attempts >= 1
    assert not (return_filename and (not cache))

    # No URL scheme at all -- treat it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # file:// URLs: translate to a local path. On Windows, file:///c:/foo.txt
    # parses to '/c:/foo.txt', so the leading slash must be dropped.
    # urllib.request.url2pathname is deliberately avoided because it also
    # converts forward slashes to backslashes, causing its own problems.
    # If you touch this path, test on both Linux and Windows.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Serve from the on-disk cache when a cached copy exists.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')
    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download with retries.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be Google Drive interstitial pages
                    # rather than the actual payload.
                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                # Retry against the direct-download link.
                                url = requests.compat.urljoin(url, links[0])
                            raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except KeyboardInterrupt:
                raise
            except:
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache (write to a unique temp file, then rename into place).
    if cache:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file)  # atomic
        if return_filename:
            return cache_file

    # Return the downloaded bytes as an in-memory file object.
    assert not return_filename
    return io.BytesIO(url_data)
legacy.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import click
10
+ import pickle
11
+ import re
12
+ import copy
13
+ import numpy as np
14
+ import torch
15
+ import dnnlib
16
+ from torch_utils import misc
17
+
18
+ # ----------------------------------------------------------------------------
19
+
20
+
21
def load_network_pkl(f, force_fp16=False):
    """Load a StyleGAN2 network pickle, converting legacy TF pickles on the fly.

    Returns a dict with keys G, D, G_ema, training_set_kwargs, augment_pipe.
    With force_fp16=True each network is rebuilt with FP16 layers enabled.
    """
    data = _LegacyUnpickler(f).load()

    # A legacy TensorFlow pickle is a (G, D, Gs) triple of network stubs.
    if (
        isinstance(data, tuple)
        and len(data) == 3
        and all(isinstance(net, _TFNetworkStub) for net in data)
    ):
        tf_G, tf_D, tf_Gs = data
        data = dict(
            G=convert_tf_generator(tf_G),
            D=convert_tf_discriminator(tf_D),
            G_ema=convert_tf_generator(tf_Gs),
        )

    # Older pickles may lack these fields entirely.
    if "training_set_kwargs" not in data:
        data["training_set_kwargs"] = None
    if "augment_pipe" not in data:
        data["augment_pipe"] = None

    # Validate contents.
    assert isinstance(data["G"], torch.nn.Module)
    assert isinstance(data["D"], torch.nn.Module)
    assert isinstance(data["G_ema"], torch.nn.Module)
    assert isinstance(data["training_set_kwargs"], (dict, type(None)))
    assert isinstance(data["augment_pipe"], (torch.nn.Module, type(None)))

    # Optionally rebuild each network with FP16 enabled, preserving weights.
    if force_fp16:
        for key in ["G", "D", "G_ema"]:
            old = data[key]
            kwargs = copy.deepcopy(old.init_kwargs)
            if key.startswith("G"):
                kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get("synthesis_kwargs", {}))
                kwargs.synthesis_kwargs.num_fp16_res = 4
                kwargs.synthesis_kwargs.conv_clamp = 256
            if key.startswith("D"):
                kwargs.num_fp16_res = 4
                kwargs.conv_clamp = 256
            # Only rebuild when the kwargs actually changed.
            if kwargs != old.init_kwargs:
                new = type(old)(**kwargs).eval().requires_grad_(False)
                misc.copy_params_and_buffers(old, new, require_all=True)
                data[key] = new
    return data
68
+
69
+
70
+ # ----------------------------------------------------------------------------
71
+
72
+
73
class _TFNetworkStub(dnnlib.EasyDict):
    """Stand-in for dnnlib.tflib.network.Network when unpickling legacy files."""
75
+
76
+
77
+ class _LegacyUnpickler(pickle.Unpickler):
78
+ def find_class(self, module, name):
79
+ if module == "dnnlib.tflib.network" and name == "Network":
80
+ return _TFNetworkStub
81
+ return super().find_class(module, name)
82
+
83
+
84
+ # ----------------------------------------------------------------------------
85
+
86
+
87
+ def _collect_tf_params(tf_net):
88
+ # pylint: disable=protected-access
89
+ tf_params = dict()
90
+
91
+ def recurse(prefix, tf_net):
92
+ for name, value in tf_net.variables:
93
+ tf_params[prefix + name] = value
94
+ for name, comp in tf_net.components.items():
95
+ recurse(prefix + name + "/", comp)
96
+
97
+ recurse("", tf_net)
98
+ return tf_params
99
+
100
+
101
+ # ----------------------------------------------------------------------------
102
+
103
+
104
def _populate_module_params(module, *patterns):
    """Fill module's params/buffers from a flat (regex, value_fn, ...) list.

    Every tensor name must match some pattern; a value_fn of None means
    "matched, but leave the tensor untouched" (used for constant buffers).
    """
    for name, tensor in misc.named_params_and_buffers(module):
        found = False
        value = None
        # `patterns` alternates regex, value_fn, regex, value_fn, ...
        for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
            match = re.fullmatch(pattern, name)
            if match:
                found = True
                if value_fn is not None:
                    # Regex capture groups become the value_fn's arguments.
                    value = value_fn(*match.groups())
                break
        try:
            assert found
            if value is not None:
                tensor.copy_(torch.from_numpy(np.array(value)))
        except:
            # Print the offending tensor before propagating, to aid debugging.
            print(name, list(tensor.shape))
            raise
122
+
123
+
124
+ # ----------------------------------------------------------------------------
125
+
126
+
127
def convert_tf_generator(tf_G):
    """Convert a legacy TensorFlow generator stub into a torch networks.Generator.

    Raises ValueError for unsupported pickle versions or unrecognized kwargs.
    """
    if tf_G.version < 4:
        raise ValueError("TensorFlow pickle version too low")

    # Collect kwargs.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()

    def kwarg(tf_name, default=None, none=None):
        # Track every consulted name so unknown kwargs can be rejected below;
        # a stored None falls back to `none` rather than `default`.
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        z_dim=kwarg("latent_size", 512),
        c_dim=kwarg("label_size", 0),
        w_dim=kwarg("dlatent_size", 512),
        img_resolution=kwarg("resolution", 1024),
        img_channels=kwarg("num_channels", 3),
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=kwarg("mapping_layers", 8),
            embed_features=kwarg("label_fmaps", None),
            layer_features=kwarg("mapping_fmaps", None),
            activation=kwarg("mapping_nonlinearity", "lrelu"),
            lr_multiplier=kwarg("mapping_lrmul", 0.01),
            w_avg_beta=kwarg("w_avg_beta", 0.995, none=1),
        ),
        synthesis_kwargs=dnnlib.EasyDict(
            channel_base=kwarg("fmap_base", 16384) * 2,
            channel_max=kwarg("fmap_max", 512),
            num_fp16_res=kwarg("num_fp16_res", 0),
            conv_clamp=kwarg("conv_clamp", None),
            architecture=kwarg("architecture", "skip"),
            resample_filter=kwarg("resample_kernel", [1, 3, 3, 1]),
            use_noise=kwarg("use_noise", True),
            activation=kwarg("nonlinearity", "lrelu"),
        ),
    )

    # Check for unknown kwargs (these four are known but deliberately ignored).
    kwarg("truncation_psi")
    kwarg("truncation_cutoff")
    kwarg("style_mixing_prob")
    kwarg("structure")
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError("Unknown TensorFlow kwarg", unknown_kwargs[0])

    # Collect params, remapping progressive-growing ToRGB_lod names.
    tf_params = _collect_tf_params(tf_G)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r"ToRGB_lod(\d+)/(.*)", name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f"{r}x{r}/ToRGB/{match.group(2)}"] = value
            # BUGFIX: was `kwargs.synthesis.kwargs.architecture`, which raises
            # AttributeError -- the EasyDict key is `synthesis_kwargs`.
            kwargs.synthesis_kwargs.architecture = "orig"
    # for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    from training import networks

    G = networks.Generator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    _populate_module_params(
        G,
        r"mapping\.w_avg",
        lambda: tf_params[f"dlatent_avg"],
        r"mapping\.embed\.weight",
        lambda: tf_params[f"mapping/LabelEmbed/weight"].transpose(),
        r"mapping\.embed\.bias",
        lambda: tf_params[f"mapping/LabelEmbed/bias"],
        r"mapping\.fc(\d+)\.weight",
        lambda i: tf_params[f"mapping/Dense{i}/weight"].transpose(),
        r"mapping\.fc(\d+)\.bias",
        lambda i: tf_params[f"mapping/Dense{i}/bias"],
        r"synthesis\.b4\.const",
        lambda: tf_params[f"synthesis/4x4/Const/const"][0],
        r"synthesis\.b4\.conv1\.weight",
        lambda: tf_params[f"synthesis/4x4/Conv/weight"].transpose(3, 2, 0, 1),
        r"synthesis\.b4\.conv1\.bias",
        lambda: tf_params[f"synthesis/4x4/Conv/bias"],
        r"synthesis\.b4\.conv1\.noise_const",
        lambda: tf_params[f"synthesis/noise0"][0, 0],
        r"synthesis\.b4\.conv1\.noise_strength",
        lambda: tf_params[f"synthesis/4x4/Conv/noise_strength"],
        r"synthesis\.b4\.conv1\.affine\.weight",
        lambda: tf_params[f"synthesis/4x4/Conv/mod_weight"].transpose(),
        r"synthesis\.b4\.conv1\.affine\.bias",
        lambda: tf_params[f"synthesis/4x4/Conv/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.conv0\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/weight"][::-1, ::-1].transpose(
            3, 2, 0, 1
        ),
        r"synthesis\.b(\d+)\.conv0\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/bias"],
        r"synthesis\.b(\d+)\.conv0\.noise_const",
        lambda r: tf_params[f"synthesis/noise{int(np.log2(int(r)))*2-5}"][0, 0],
        r"synthesis\.b(\d+)\.conv0\.noise_strength",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/noise_strength"],
        r"synthesis\.b(\d+)\.conv0\.affine\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/mod_weight"].transpose(),
        r"synthesis\.b(\d+)\.conv0\.affine\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv0_up/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.conv1\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/weight"].transpose(3, 2, 0, 1),
        r"synthesis\.b(\d+)\.conv1\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/bias"],
        r"synthesis\.b(\d+)\.conv1\.noise_const",
        lambda r: tf_params[f"synthesis/noise{int(np.log2(int(r)))*2-4}"][0, 0],
        r"synthesis\.b(\d+)\.conv1\.noise_strength",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/noise_strength"],
        r"synthesis\.b(\d+)\.conv1\.affine\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/mod_weight"].transpose(),
        r"synthesis\.b(\d+)\.conv1\.affine\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/Conv1/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.torgb\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/weight"].transpose(3, 2, 0, 1),
        r"synthesis\.b(\d+)\.torgb\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/bias"],
        r"synthesis\.b(\d+)\.torgb\.affine\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/mod_weight"].transpose(),
        r"synthesis\.b(\d+)\.torgb\.affine\.bias",
        lambda r: tf_params[f"synthesis/{r}x{r}/ToRGB/mod_bias"] + 1,
        r"synthesis\.b(\d+)\.skip\.weight",
        lambda r: tf_params[f"synthesis/{r}x{r}/Skip/weight"][::-1, ::-1].transpose(
            3, 2, 0, 1
        ),
        r".*\.resample_filter",
        None,
    )
    return G
259
+
260
+
261
+ # ----------------------------------------------------------------------------
262
+
263
+
264
def convert_tf_discriminator(tf_D):
    """Convert a legacy TensorFlow StyleGAN2(-ADA) discriminator into the
    native PyTorch `training.networks.Discriminator`.

    Reads `tf_D.static_kwargs` to reconstruct constructor kwargs, copies the
    TF weight tensors into the new module, and returns the frozen PyTorch
    network in eval mode. Raises ValueError for unsupported pickle versions
    or unrecognized TF kwargs.
    """
    if tf_D.version < 4:
        raise ValueError("TensorFlow pickle version too low")

    # Collect kwargs.
    tf_kwargs = tf_D.static_kwargs
    known_kwargs = set()

    # Look up a TF kwarg, recording the name so unknown leftovers can be
    # detected below.
    def kwarg(tf_name, default=None):
        known_kwargs.add(tf_name)
        return tf_kwargs.get(tf_name, default)

    # Convert kwargs.
    # NOTE(review): fmap_base is doubled — presumably the PyTorch channel_base
    # convention counts both halves of the pipeline; confirm against
    # training.networks.Discriminator.
    kwargs = dnnlib.EasyDict(
        c_dim=kwarg("label_size", 0),
        img_resolution=kwarg("resolution", 1024),
        img_channels=kwarg("num_channels", 3),
        architecture=kwarg("architecture", "resnet"),
        channel_base=kwarg("fmap_base", 16384) * 2,
        channel_max=kwarg("fmap_max", 512),
        num_fp16_res=kwarg("num_fp16_res", 0),
        conv_clamp=kwarg("conv_clamp", None),
        cmap_dim=kwarg("mapping_fmaps", None),
        block_kwargs=dnnlib.EasyDict(
            activation=kwarg("nonlinearity", "lrelu"),
            resample_filter=kwarg("resample_kernel", [1, 3, 3, 1]),
            freeze_layers=kwarg("freeze_layers", 0),
        ),
        mapping_kwargs=dnnlib.EasyDict(
            num_layers=kwarg("mapping_layers", 0),
            embed_features=kwarg("mapping_fmaps", None),
            layer_features=kwarg("mapping_fmaps", None),
            activation=kwarg("nonlinearity", "lrelu"),
            lr_multiplier=kwarg("mapping_lrmul", 0.1),
        ),
        epilogue_kwargs=dnnlib.EasyDict(
            mbstd_group_size=kwarg("mbstd_group_size", None),
            mbstd_num_channels=kwarg("mbstd_num_features", 1),
            activation=kwarg("nonlinearity", "lrelu"),
        ),
    )

    # Check for unknown kwargs.
    kwarg("structure")  # consume without using: accepted but ignored
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError("Unknown TensorFlow kwarg", unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_D)
    # Progressive-growing checkpoints store per-lod FromRGB layers; remap them
    # to resolution-keyed names and fall back to the "orig" architecture.
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r"FromRGB_lod(\d+)/(.*)", name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f"{r}x{r}/FromRGB/{match.group(2)}"] = value
            kwargs.architecture = "orig"
    # for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    from training import networks

    D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
    # Each (regex, lambda) pair maps a PyTorch tensor name to its TF source;
    # conv weights go HWIO -> OIHW via transpose(3, 2, 0, 1), dense weights
    # are plainly transposed. `None` means "leave at the module's default".
    # pylint: disable=unnecessary-lambda
    _populate_module_params(
        D,
        r"b(\d+)\.fromrgb\.weight",
        lambda r: tf_params[f"{r}x{r}/FromRGB/weight"].transpose(3, 2, 0, 1),
        r"b(\d+)\.fromrgb\.bias",
        lambda r: tf_params[f"{r}x{r}/FromRGB/bias"],
        r"b(\d+)\.conv(\d+)\.weight",
        lambda r, i: tf_params[
            f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'
        ].transpose(3, 2, 0, 1),
        r"b(\d+)\.conv(\d+)\.bias",
        lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
        r"b(\d+)\.skip\.weight",
        lambda r: tf_params[f"{r}x{r}/Skip/weight"].transpose(3, 2, 0, 1),
        r"mapping\.embed\.weight",
        lambda: tf_params[f"LabelEmbed/weight"].transpose(),
        r"mapping\.embed\.bias",
        lambda: tf_params[f"LabelEmbed/bias"],
        r"mapping\.fc(\d+)\.weight",
        lambda i: tf_params[f"Mapping{i}/weight"].transpose(),
        r"mapping\.fc(\d+)\.bias",
        lambda i: tf_params[f"Mapping{i}/bias"],
        r"b4\.conv\.weight",
        lambda: tf_params[f"4x4/Conv/weight"].transpose(3, 2, 0, 1),
        r"b4\.conv\.bias",
        lambda: tf_params[f"4x4/Conv/bias"],
        r"b4\.fc\.weight",
        lambda: tf_params[f"4x4/Dense0/weight"].transpose(),
        r"b4\.fc\.bias",
        lambda: tf_params[f"4x4/Dense0/bias"],
        r"b4\.out\.weight",
        lambda: tf_params[f"Output/weight"].transpose(),
        r"b4\.out\.bias",
        lambda: tf_params[f"Output/bias"],
        r".*\.resample_filter",
        None,
    )
    return D
365
+
366
+
367
+ # ----------------------------------------------------------------------------
368
+
369
+
370
# CLI entry point: load a (possibly legacy TF) pickle and re-save it in the
# native PyTorch format. The docstring below doubles as the click help text,
# so it is left exactly as-is.
@click.command()
@click.option("--source", help="Input pickle", required=True, metavar="PATH")
@click.option("--dest", help="Output pickle", required=True, metavar="PATH")
@click.option(
    "--force-fp16",
    help="Force the networks to use FP16",
    type=bool,
    default=False,
    metavar="BOOL",
    show_default=True,
)
def convert_network_pickle(source, dest, force_fp16):
    """Convert legacy network pickle into the native PyTorch format.

    The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
    It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.

    Example:

    \b
    python legacy.py \\
        --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
        --dest=stylegan2-cat-config-f.pkl
    """
    print(f'Loading "{source}"...')
    # open_url handles both local paths and remote URLs.
    with dnnlib.util.open_url(source) as f:
        data = load_network_pkl(f, force_fp16=force_fp16)
    print(f'Saving "{dest}"...')
    with open(dest, "wb") as f:
        pickle.dump(data, f)
    print("Done.")
401
+
402
+
403
+ # ----------------------------------------------------------------------------
404
+
405
+ if __name__ == "__main__":
406
+ convert_network_pickle() # pylint: disable=no-value-for-parameter
407
+
408
+ # ----------------------------------------------------------------------------
torch_utils/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
torch_utils/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (156 Bytes). View file
 
torch_utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (164 Bytes). View file
 
torch_utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (164 Bytes). View file
 
torch_utils/__pycache__/custom_ops.cpython-36.pyc ADDED
Binary file (3.2 kB). View file
 
torch_utils/__pycache__/custom_ops.cpython-38.pyc ADDED
Binary file (3.22 kB). View file
 
torch_utils/__pycache__/custom_ops.cpython-39.pyc ADDED
Binary file (3.21 kB). View file
 
torch_utils/__pycache__/misc.cpython-36.pyc ADDED
Binary file (9.77 kB). View file
 
torch_utils/__pycache__/misc.cpython-38.pyc ADDED
Binary file (9.9 kB). View file
 
torch_utils/__pycache__/misc.cpython-39.pyc ADDED
Binary file (9.84 kB). View file
 
torch_utils/__pycache__/persistence.cpython-36.pyc ADDED
Binary file (8.6 kB). View file
 
torch_utils/__pycache__/persistence.cpython-38.pyc ADDED
Binary file (8.65 kB). View file
 
torch_utils/__pycache__/persistence.cpython-39.pyc ADDED
Binary file (8.63 kB). View file
 
torch_utils/custom_ops.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import glob
11
+ import torch
12
+ import torch.utils.cpp_extension
13
+ import importlib
14
+ import hashlib
15
+ import shutil
16
+ from pathlib import Path
17
+
18
+ from torch.utils.file_baton import FileBaton
19
+
20
+ #----------------------------------------------------------------------------
21
+ # Global options.
22
+
23
+ verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
24
+
25
+ #----------------------------------------------------------------------------
26
+ # Internal helper funcs.
27
+
28
+ def _find_compiler_bindir():
29
+ patterns = [
30
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
31
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
32
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
33
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
34
+ ]
35
+ for pattern in patterns:
36
+ matches = sorted(glob.glob(pattern))
37
+ if len(matches):
38
+ return matches[-1]
39
+ return None
40
+
41
+ #----------------------------------------------------------------------------
42
+ # Main entry point for compiling and loading C++/CUDA plugins.
43
+
44
# Process-wide cache of already-loaded extension modules, keyed by name.
_cached_plugins = dict()

def get_plugin(module_name, sources, **build_kwargs):
    """Compile (if needed), load, import, and cache a C++/CUDA extension.

    Returns the imported Python module for `module_name`, building it from
    `sources` via torch.utils.cpp_extension.load on first use. Extra
    `build_kwargs` are forwarded to the build. Re-raises any build/import
    error after printing a status line per the module-level `verbosity`.
    """
    assert verbosity in ['none', 'brief', 'full']

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)

    try: # pylint: disable=too-many-nested-blocks
        # Make sure we can find the necessary compiler binaries.
        # On Windows, if cl.exe is not already on PATH, locate MSVC and
        # append it so the torch extension builder can find it.
        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
            os.environ['PATH'] += ';' + compiler_bindir

        # Compile and load.
        verbose_build = (verbosity == 'full')

        # Incremental build md5sum trickery. Copies all the input source files
        # into a cached build directory under a combined md5 digest of the input
        # source files. Copying is done only if the combined digest has changed.
        # This keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
        # environment variable is set (we take this as a signal that the user
        # actually cares about this.)
        source_dirs_set = set(os.path.dirname(source) for source in sources)
        if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
            all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))

            # Compute a combined hash digest for all source files in the same
            # custom op directory (usually .cu, .cpp, .py and .h files).
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, 'rb') as f:
                    hash_md5.update(f.read())
            build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
            digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())

            if not os.path.isdir(digest_build_dir):
                os.makedirs(digest_build_dir, exist_ok=True)
                # FileBaton guards against concurrent processes copying into
                # the same digest directory.
                baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
                if baton.try_acquire():
                    try:
                        for src in all_source_files:
                            shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
                    finally:
                        baton.release()
                else:
                    # Someone else is copying source files under the digest dir,
                    # wait until done and continue.
                    baton.wait()
            digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
                verbose=verbose_build, sources=digest_sources, **build_kwargs)
        else:
            # Simple path: let torch manage the build directory entirely.
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
        module = importlib.import_module(module_name)

    except:
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Print status and add to cache.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module
125
+
126
+ #----------------------------------------------------------------------------
torch_utils/misc.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import re
10
+ import contextlib
11
+ import numpy as np
12
+ import torch
13
+ import warnings
14
+ import dnnlib
15
+
16
+ # ----------------------------------------------------------------------------
17
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
18
+ # same constant is used multiple times.
19
+
20
# Cache of previously-built constant tensors, keyed by value bytes plus all
# construction options. Avoids repeated CPU=>GPU uploads of the same constant.
_constant_cache = dict()


def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a cached constant tensor for `value`.

    The tensor is built once per unique combination of (value contents,
    shape, dtype, device, memory_format) and reused on later calls. When
    `shape` is given, the value is broadcast up to that shape.
    """
    arr = np.asarray(value)
    shape = tuple(shape) if shape is not None else None
    dtype = dtype if dtype is not None else torch.get_default_dtype()
    device = device if device is not None else torch.device("cpu")
    memory_format = (
        memory_format if memory_format is not None else torch.contiguous_format
    )

    key = (arr.shape, arr.dtype, arr.tobytes(), shape, dtype, device, memory_format)
    cached = _constant_cache.get(key)
    if cached is None:
        cached = torch.as_tensor(arr.copy(), dtype=dtype, device=device)
        if shape is not None:
            # Broadcast against an uninitialized tensor of the target shape;
            # only the broadcasted copy of `cached` is kept.
            cached, _ = torch.broadcast_tensors(cached, torch.empty(shape))
        cached = cached.contiguous(memory_format=memory_format)
        _constant_cache[key] = cached
    return cached
51
+
52
+
53
+ # ----------------------------------------------------------------------------
54
+ # Replace NaN/Inf with specified numerical values.
55
+
56
# Use the native implementation when available (torch >= 1.8), otherwise fall
# back to an equivalent composite.
try:
    nan_to_num = torch.nan_to_num # 1.8.0a0
except AttributeError:

    def nan_to_num(
        input, nan=0.0, posinf=None, neginf=None, *, out=None
    ):  # pylint: disable=redefined-builtin
        # Fallback: NaNs become 0 via nansum over a length-1 leading dim,
        # then +/-inf are clamped to the dtype's representable range.
        # Only nan=0.0 is supported (enforced below), unlike the native op.
        assert isinstance(input, torch.Tensor)
        if posinf is None:
            posinf = torch.finfo(input.dtype).max
        if neginf is None:
            neginf = torch.finfo(input.dtype).min
        assert nan == 0
        return torch.clamp(
            input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out
        )
72
+
73
+
74
+ # ----------------------------------------------------------------------------
75
+ # Symbolic assert.
76
+
77
# Pick whichever trace-compatible assert primitive the installed torch
# provides; the name was changed between releases.
try:
    symbolic_assert = torch._assert  # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert  # 1.7.0
81
+
82
+ # ----------------------------------------------------------------------------
83
+ # Context manager to suppress known warnings in torch.jit.trace().
84
+
85
+
86
class suppress_tracer_warnings(warnings.catch_warnings):
    """Context manager that silences torch.jit.TracerWarning inside its scope.

    Subclasses warnings.catch_warnings, so the pre-existing warning filter
    state is snapshotted on entry and restored by the inherited __exit__.
    """

    def __enter__(self):
        super().__enter__()  # snapshot current filters (restored on exit)
        warnings.simplefilter("ignore", category=torch.jit.TracerWarning)
        return self
91
+
92
+
93
+ # ----------------------------------------------------------------------------
94
+ # Assert that the shape of a tensor matches the given list of integers.
95
+ # None indicates that the size of a dimension is allowed to vary.
96
+ # Performs symbolic assertion when used in torch.jit.trace().
97
+
98
+
99
def assert_shape(tensor, ref_shape):
    """Assert that `tensor` has the shape given by `ref_shape`.

    `ref_shape` is a sequence of per-dimension sizes; an entry of None means
    "any size is accepted" for that dimension. When either side of a
    comparison is a traced tensor (inside torch.jit.trace), the check is
    emitted as a symbolic assert instead of a plain Python AssertionError.
    """
    if tensor.ndim != len(ref_shape):
        raise AssertionError(
            f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}"
        )
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            # Wildcard dimension: any size allowed.
            pass
        elif isinstance(ref_size, torch.Tensor):
            # Reference size is dynamic (a traced tensor): compare symbolically.
            with suppress_tracer_warnings():  # as_tensor results are registered as constants
                symbolic_assert(
                    torch.equal(torch.as_tensor(size), ref_size),
                    f"Wrong size for dimension {idx}",
                )
        elif isinstance(size, torch.Tensor):
            # Actual size is dynamic (we are inside torch.jit.trace).
            with suppress_tracer_warnings():  # as_tensor results are registered as constants
                symbolic_assert(
                    torch.equal(size, torch.as_tensor(ref_size)),
                    f"Wrong size for dimension {idx}: expected {ref_size}",
                )
        elif size != ref_size:
            raise AssertionError(
                f"Wrong size for dimension {idx}: got {size}, expected {ref_size}"
            )
123
+
124
+
125
+ # ----------------------------------------------------------------------------
126
+ # Function decorator that calls torch.autograd.profiler.record_function().
127
+
128
+
129
def profiled_function(fn):
    """Decorator that records each call to `fn` as a named range in the
    torch autograd profiler (the range is labeled with fn.__name__).

    Uses functools.wraps so the wrapper preserves the target's full
    metadata (__name__, __doc__, __qualname__, __module__, __wrapped__) —
    the original hand-copied only __name__, which broke introspection of
    docstrings and `inspect.unwrap`.
    """
    import functools  # local import: keeps the module's import surface unchanged

    @functools.wraps(fn)
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)

    return decorator
136
+
137
+
138
+ # ----------------------------------------------------------------------------
139
+ # Sampler for torch.utils.data.DataLoader that loops over the dataset
140
+ # indefinitely, shuffling items as it goes.
141
+
142
+
143
class InfiniteSampler(torch.utils.data.Sampler):
    """Sampler that loops over the dataset forever, shuffling as it goes.

    Supports multi-process data loading: each replica (identified by `rank`
    out of `num_replicas`) yields a disjoint, interleaved subsequence of the
    same deterministic stream (seeded by `seed`). `window_size` controls how
    much local reshuffling happens as iteration proceeds: after visiting an
    index, it may be swapped with another index drawn from a trailing window
    of that fraction of the dataset.
    """

    def __init__(
        self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5
    ):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank                  # which replica this sampler feeds
        self.num_replicas = num_replicas  # total number of replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size    # fraction of dataset used for local swaps

    def __iter__(self):
        order = np.arange(len(self.dataset))
        rnd = None
        window = 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            # Swap window measured in items; < 2 disables swapping below.
            window = int(np.rint(order.size * self.window_size))

        idx = 0
        while True:
            i = idx % order.size
            # Interleave: every num_replicas-th position belongs to this rank.
            if idx % self.num_replicas == self.rank:
                yield order[i]
            if window >= 2:
                # Continuously reshuffle by swapping the current slot with a
                # random earlier slot inside the window (wraps around).
                j = (i - rnd.randint(window)) % order.size
                order[i], order[j] = order[j], order[i]
            idx += 1
177
+
178
+
179
+ # ----------------------------------------------------------------------------
180
+ # Utilities for operating with torch.nn.Module parameters and buffers.
181
+
182
+
183
def params_and_buffers(module):
    """Return all parameters of `module` followed by all its buffers, as one list."""
    assert isinstance(module, torch.nn.Module)
    tensors = list(module.parameters())
    tensors.extend(module.buffers())
    return tensors
186
+
187
+
188
def named_params_and_buffers(module):
    """Return (name, tensor) pairs: all named parameters, then all named buffers."""
    assert isinstance(module, torch.nn.Module)
    pairs = list(module.named_parameters())
    pairs.extend(module.named_buffers())
    return pairs
191
+
192
+
193
def copy_params_and_buffers(src_module, dst_module, require_all=False):
    """Copy same-named parameters/buffers from `src_module` into `dst_module`.

    Destination tensors whose name has no source counterpart are left
    untouched unless `require_all` is True, in which case that is an error.
    Each copy preserves the destination tensor's requires_grad flag.
    """
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = dict(named_params_and_buffers(src_module))
    for name, dst in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all)
        if name in src_tensors:
            dst.copy_(src_tensors[name].detach()).requires_grad_(dst.requires_grad)
205
+
206
+
207
+ # ----------------------------------------------------------------------------
208
+ # Context manager for easily enabling/disabling DistributedDataParallel
209
+ # synchronization.
210
+
211
+
212
@contextlib.contextmanager
def ddp_sync(module, sync):
    """Context manager toggling DistributedDataParallel gradient sync.

    When `sync` is falsy and `module` is wrapped in DDP, the body runs under
    `module.no_sync()` so gradients are accumulated locally instead of being
    all-reduced. For non-DDP modules (or sync=True) it is a no-op.
    """
    assert isinstance(module, torch.nn.Module)
    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with module.no_sync():
            yield
220
+
221
+
222
+ # ----------------------------------------------------------------------------
223
+ # Check DistributedDataParallel consistency across processes.
224
+
225
+
226
def check_ddp_consistency(module, ignore_regex=None):
    """Assert that `module`'s params/buffers are identical across DDP ranks.

    Broadcasts rank 0's copy of each tensor and compares it elementwise
    against the local copy (NaN/Inf are normalized first so NaN == NaN).
    Tensors whose full name matches `ignore_regex` are skipped. Requires
    torch.distributed to be initialized.
    """
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + "." + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        # nan_to_num on both sides so NaNs compare equal rather than poisoning
        # the elementwise comparison.
        assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
236
+
237
+
238
+ # ----------------------------------------------------------------------------
239
+ # Print summary table of module hierarchy.
240
+
241
+
242
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    """Run `module(*inputs)` and print a table summarizing the module tree.

    One row per submodule (up to `max_nesting` levels deep): parameter count,
    buffer count, and the shape/dtype of each tensor output, plus totals.
    With `skip_redundant`, submodules that contribute no new parameters,
    buffers, or outputs are omitted. Returns the forward pass's outputs.

    Bug fix vs. the original: the per-output shape column used
    `e.outputs[0].shape` inside the loop over `e.outputs`, so every output of
    a multi-output submodule was reported with the FIRST output's shape.
    It now uses each output's own shape (`t.shape`).
    """
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks to record each submodule's tensor outputs during the
    # forward pass, tracking nesting depth so deep internals can be skipped.
    entries = []
    nesting = [0]

    def pre_hook(_mod, _inputs):
        nesting[0] += 1

    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))

    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module.
    outputs = module(*inputs)
    for hook in hooks:
        hook.remove()

    # Identify unique outputs, parameters, and buffers (each tensor is
    # attributed to the first entry that mentions it).
    tensors_seen = set()
    for e in entries:
        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {
            id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs
        }

    # Filter out redundant entries.
    if skip_redundant:
        entries = [
            e
            for e in entries
            if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)
        ]

    # Construct table.
    rows = [
        [type(module).__name__, "Parameters", "Buffers", "Output shape", "Datatype"]
    ]
    rows += [["---"] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = "<top-level>" if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        # Fixed: use each output's own shape, not e.outputs[0].shape.
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split(".")[-1] for t in e.outputs]
        rows += [
            [
                # Multi-output modules get one row per output, suffixed :0, :1, ...
                name + (":0" if len(e.outputs) >= 2 else ""),
                str(param_size) if param_size else "-",
                str(buffer_size) if buffer_size else "-",
                (output_shapes + ["-"])[0],
                (output_dtypes + ["-"])[0],
            ]
        ]
        for idx in range(1, len(e.outputs)):
            rows += [
                [name + f":{idx}", "-", "-", output_shapes[idx], output_dtypes[idx]]
            ]
        param_total += param_size
        buffer_total += buffer_size
    rows += [["---"] * len(rows[0])]
    rows += [["Total", str(param_total), str(buffer_total), "-", "-"]]

    # Print table with columns padded to their widest cell.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print(
            "  ".join(
                cell + " " * (width - len(cell)) for cell, width in zip(row, widths)
            )
        )
    print()
    return outputs
330
+
331
+
332
+ # ----------------------------------------------------------------------------
torch_utils/ops/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
torch_utils/ops/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (160 Bytes). View file
 
torch_utils/ops/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (168 Bytes). View file
 
torch_utils/ops/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (168 Bytes). View file
 
torch_utils/ops/__pycache__/bias_act.cpython-36.pyc ADDED
Binary file (8.73 kB). View file
 
torch_utils/ops/__pycache__/bias_act.cpython-38.pyc ADDED
Binary file (8.7 kB). View file
 
torch_utils/ops/__pycache__/bias_act.cpython-39.pyc ADDED
Binary file (8.65 kB). View file
 
torch_utils/ops/__pycache__/conv2d_gradfix.cpython-36.pyc ADDED
Binary file (6.57 kB). View file
 
torch_utils/ops/__pycache__/conv2d_gradfix.cpython-38.pyc ADDED
Binary file (6.5 kB). View file
 
torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc ADDED
Binary file (6.44 kB). View file
 
torch_utils/ops/__pycache__/conv2d_resample.cpython-36.pyc ADDED
Binary file (4.77 kB). View file
 
torch_utils/ops/__pycache__/conv2d_resample.cpython-38.pyc ADDED
Binary file (4.81 kB). View file
 
torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc ADDED
Binary file (4.81 kB). View file
 
torch_utils/ops/__pycache__/fma.cpython-36.pyc ADDED
Binary file (1.71 kB). View file
 
torch_utils/ops/__pycache__/fma.cpython-38.pyc ADDED
Binary file (1.74 kB). View file
 
torch_utils/ops/__pycache__/fma.cpython-39.pyc ADDED
Binary file (1.71 kB). View file
 
torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-36.pyc ADDED
Binary file (2.78 kB). View file
 
torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-38.pyc ADDED
Binary file (2.77 kB). View file
 
torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc ADDED
Binary file (2.74 kB). View file
 
torch_utils/ops/__pycache__/upfirdn2d.cpython-36.pyc ADDED
Binary file (14.5 kB). View file
 
torch_utils/ops/__pycache__/upfirdn2d.cpython-38.pyc ADDED
Binary file (14.5 kB). View file
 
torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc ADDED
Binary file (14.4 kB). View file
 
torch_utils/ops/bias_act.cpp ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "bias_act.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
// Return true if x and y have identical shapes and matching strides on every
// dimension of size >= 2 (strides of size-0/1 dims are irrelevant, since those
// dims contribute at most one element along that axis).
static bool has_same_layout(torch::Tensor x, torch::Tensor y)
{
    if (x.dim() != y.dim())
        return false;
    for (int64_t i = 0; i < x.dim(); i++)
    {
        if (x.size(i) != y.size(i))
            return false;
        if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
            return false;
    }
    return true;
}
29
+
30
+ //------------------------------------------------------------------------
31
+
32
// Fused bias + activation (and its gradients) on CUDA.
//
// x:    input tensor (any rank); b: per-channel bias along `dim` (may be empty).
// xref/yref/dy: reference input/output and incoming gradient for the backward
// passes (empty when unused); grad selects forward (0) vs. gradient order;
// act/alpha/gain/clamp select and parameterize the activation.
// Returns a new tensor y with the same layout as x.
static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
{
    // Validate arguments.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
    TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
    TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
    TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); // kernel indexes with 32-bit ints
    TORCH_CHECK(b.dim() == 1, "b must have rank 1");
    TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
    TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
    TORCH_CHECK(grad >= 0, "grad must be non-negative");

    // Validate layout. The kernel iterates raw memory, so all operands must
    // share one dense layout.
    TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
    TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
    TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
    TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
    TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");

    // Create output tensor on x's device.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    torch::Tensor y = torch::empty_like(x);
    TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");

    // Initialize CUDA kernel parameters. Optional tensors pass NULL when empty.
    bias_act_kernel_params p;
    p.x = x.data_ptr();
    p.b = (b.numel()) ? b.data_ptr() : NULL;
    p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
    p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
    p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
    p.y = y.data_ptr();
    p.grad = grad;
    p.act = act;
    p.alpha = alpha;
    p.gain = gain;
    p.clamp = clamp;
    p.sizeX = (int)x.numel();
    p.sizeB = (int)b.numel();
    p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; // elements between bias increments

    // Choose CUDA kernel specialization for x's scalar type.
    void* kernel;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
    {
        kernel = choose_bias_act_kernel<scalar_t>(p);
    });
    TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");

    // Launch CUDA kernel: each thread handles loopX elements.
    p.loopX = 4;
    int blockSize = 4 * 32;
    int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
    void* args[] = {&p};
    AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
    return y;
}
91
+
92
+ //------------------------------------------------------------------------
93
+
94
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95
+ {
96
+ m.def("bias_act", &bias_act);
97
+ }
98
+
99
+ //------------------------------------------------------------------------
torch_utils/ops/bias_act.cu ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "bias_act.h"
11
+
12
+ //------------------------------------------------------------------------
13
+ // Helpers.
14
+
15
// Maps the tensor element type T to the scalar type used for the in-kernel
// arithmetic: double/float compute natively, while half-precision inputs are
// promoted to float to preserve accuracy (values are cast back to T on store).
template <class T> struct InternalType;
template <> struct InternalType<double> { typedef double scalar_t; };
template <> struct InternalType<float> { typedef float scalar_t; };
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
+
20
+ //------------------------------------------------------------------------
21
+ // CUDA kernel.
22
+
23
// Fused bias-add + activation kernel.
//
// Template parameters:
//   T - element type of the buffers (double / float / c10::Half).
//   A - compile-time activation id (1=linear, 2=relu, 3=lrelu, 4=tanh,
//       5=sigmoid, 6=elu, 7=selu, 8=softplus, 9=swish), matching the
//       dispatch in choose_bias_act_kernel.
//
// p.grad (G below) selects the computed quantity: G==0 evaluates the
// activation itself; G==1 and G==2 evaluate first/second gradient
// expressions, using the saved forward tensors xref/yref where needed.
template <class T, int A>
__global__ void bias_act_kernel(bias_act_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;
    int G = p.grad;
    scalar_t alpha = (scalar_t)p.alpha;
    scalar_t gain = (scalar_t)p.gain;
    scalar_t clamp = (scalar_t)p.clamp;
    scalar_t one = (scalar_t)1;
    scalar_t two = (scalar_t)2;
    scalar_t expRange = (scalar_t)80;      // |x| beyond this saturates exp-based activations
    scalar_t halfExpRange = (scalar_t)40;
    scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
    scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;

    // Loop over elements.
    // Each thread handles up to p.loopX elements, consecutive threads in a
    // block touching adjacent indices (stride blockDim.x between iterations).
    int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
    for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
    {
        // Load. Optional buffers (b, xref, yref, dy) read as neutral values
        // when NULL; dy defaults to one so the forward pass is unscaled.
        scalar_t x = (scalar_t)((const T*)p.x)[xi];
        scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
        scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
        scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
        scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
        scalar_t yy = (gain != 0) ? yref / gain : 0;  // reference output with the gain factor undone
        scalar_t y = 0;

        // Apply bias.
        // The conditional expression is used as an lvalue: the forward pass
        // (G==0) adds the bias to the live input x, gradient passes add it
        // to the saved reference input xref instead.
        ((G == 0) ? x : xref) += b;

        // linear
        if (A == 1)
        {
            if (G == 0) y = x;
            if (G == 1) y = x;  // derivative of identity is 1, so gradient passes through
        }

        // relu
        if (A == 2)
        {
            if (G == 0) y = (x > 0) ? x : 0;
            if (G == 1) y = (yy > 0) ? x : 0;  // gate the incoming gradient by the saved output sign
        }

        // lrelu
        if (A == 3)
        {
            if (G == 0) y = (x > 0) ? x : x * alpha;
            if (G == 1) y = (yy > 0) ? x : x * alpha;
        }

        // tanh
        if (A == 4)
        {
            // Forward computes (e^x - e^-x) / (e^x + e^-x) with saturation
            // outside +-expRange to avoid overflow in exp().
            if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
            if (G == 1) y = x * (one - yy * yy);                  // tanh'(x) = 1 - tanh(x)^2, from saved output
            if (G == 2) y = x * (one - yy * yy) * (-two * yy);
        }

        // sigmoid
        if (A == 5)
        {
            if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
            if (G == 1) y = x * yy * (one - yy);                  // sigmoid'(x) = s * (1 - s)
            if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
        }

        // elu
        if (A == 6)
        {
            if (G == 0) y = (x >= 0) ? x : exp(x) - one;
            if (G == 1) y = (yy >= 0) ? x : x * (yy + one);       // elu'(x) = elu(x) + 1 on the negative branch
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
        }

        // selu
        if (A == 7)
        {
            if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
            if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
        }

        // softplus
        if (A == 8)
        {
            if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);  // large x: softplus(x) ~= x
            if (G == 1) y = x * (one - exp(-yy));                    // softplus'(x) = sigmoid(x) = 1 - e^-softplus(x)
            if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
        }

        // swish
        if (A == 9)
        {
            if (G == 0)
                y = (x < -expRange) ? 0 : x / (exp(-x) + one);
            else
            {
                // Gradients are expressed in terms of the saved *input* xref
                // rather than the output; beyond halfExpRange the derivative
                // saturates to 1 (G==1) / 0 (G==2).
                scalar_t c = exp(xref);
                scalar_t d = c + one;
                if (G == 1)
                    y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
                else
                    y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
                // Recompute the forward output locally so the clamp test
                // below can use it without a separate yref buffer.
                yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
            }
        }

        // Apply gain.
        // dy is the incoming upstream gradient (one in the forward pass),
        // so this doubles as the chain-rule multiplication.
        y *= gain * dy;

        // Clamp.
        // NOTE(review): bitwise '&' on the bool comparisons is used instead
        // of '&&' — presumably to avoid short-circuit branching; result is
        // identical for bools.
        if (clamp >= 0)
        {
            if (G == 0)
                y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
            else
                y = (yref > -clamp & yref < clamp) ? y : 0;  // gradient is zero where the forward output was clamped
        }

        // Store.
        ((T*)p.y)[xi] = (T)y;
    }
}
148
+
149
+ //------------------------------------------------------------------------
150
+ // CUDA kernel selection.
151
+
152
// Select the kernel instantiation matching the runtime activation id.
// Returns the address of bias_act_kernel<T, p.act>, or NULL when p.act is
// outside the supported range [1, 9].
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
{
    switch (p.act)
    {
        case 1: return (void*)bias_act_kernel<T, 1>;  // linear
        case 2: return (void*)bias_act_kernel<T, 2>;  // relu
        case 3: return (void*)bias_act_kernel<T, 3>;  // lrelu
        case 4: return (void*)bias_act_kernel<T, 4>;  // tanh
        case 5: return (void*)bias_act_kernel<T, 5>;  // sigmoid
        case 6: return (void*)bias_act_kernel<T, 6>;  // elu
        case 7: return (void*)bias_act_kernel<T, 7>;  // selu
        case 8: return (void*)bias_act_kernel<T, 8>;  // softplus
        case 9: return (void*)bias_act_kernel<T, 9>;  // swish
        default: return NULL;  // unknown activation id
    }
}

//------------------------------------------------------------------------
// Template specializations.

template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172
+
173
+ //------------------------------------------------------------------------
torch_utils/ops/bias_act.h ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ //------------------------------------------------------------------------
10
+ // CUDA kernel parameters.
11
+
12
// Parameter bundle passed by value to bias_act_kernel. Buffers are untyped
// void* because the element type (double/float/half) is resolved by the
// kernel's template parameter at dispatch time.
struct bias_act_kernel_params
{
    const void* x;      // [sizeX] input tensor
    const void* b;      // [sizeB] bias vector, or NULL when no bias is applied
    const void* xref;   // [sizeX] saved forward input for gradient passes, or NULL
    const void* yref;   // [sizeX] saved forward output for gradient passes, or NULL
    const void* dy;     // [sizeX] incoming upstream gradient, or NULL (treated as all-ones)
    void*       y;      // [sizeX] output tensor

    int         grad;   // 0 = forward, 1 = first gradient, 2 = second gradient
    int         act;    // activation id, see choose_bias_act_kernel (1..9)
    float       alpha;  // activation parameter (e.g. lrelu negative slope)
    float       gain;   // output scale factor
    float       clamp;  // clamp outputs to [-clamp, +clamp]; negative disables clamping
    int         sizeX;  // total number of elements in x / y
    int         sizeB;  // number of bias elements (0 = no bias)
    int         stepB;  // element stride between consecutive bias applications along the bias dimension
    int         loopX;  // elements processed per thread
};

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);

//------------------------------------------------------------------------
37
+
38
+ //------------------------------------------------------------------------