praeclarumjj3 commited on
Commit
9eae6e7
1 Parent(s): 5a27e81

:zap: Build App

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
.DS_Store ADDED
Binary file (8.2 kB). View file
 
.gitattributes CHANGED
@@ -14,6 +14,7 @@
14
  *.ot filter=lfs diff=lfs merge=lfs -text
15
  *.parquet filter=lfs diff=lfs merge=lfs -text
16
  *.pb filter=lfs diff=lfs merge=lfs -text
 
17
  *.pt filter=lfs diff=lfs merge=lfs -text
18
  *.pth filter=lfs diff=lfs merge=lfs -text
19
  *.rar filter=lfs diff=lfs merge=lfs -text
@@ -21,7 +22,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
  *.tar.* filter=lfs diff=lfs merge=lfs -text
22
  *.tflite filter=lfs diff=lfs merge=lfs -text
23
  *.tgz filter=lfs diff=lfs merge=lfs -text
24
- *.wasm filter=lfs diff=lfs merge=lfs -text
25
  *.xz filter=lfs diff=lfs merge=lfs -text
26
  *.zip filter=lfs diff=lfs merge=lfs -text
27
  *.zstandard filter=lfs diff=lfs merge=lfs -text
 
14
  *.ot filter=lfs diff=lfs merge=lfs -text
15
  *.parquet filter=lfs diff=lfs merge=lfs -text
16
  *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pkl filter=lfs diff=lfs merge=lfs -text
18
  *.pt filter=lfs diff=lfs merge=lfs -text
19
  *.pth filter=lfs diff=lfs merge=lfs -text
20
  *.rar filter=lfs diff=lfs merge=lfs -text
 
22
  *.tar.* filter=lfs diff=lfs merge=lfs -text
23
  *.tflite filter=lfs diff=lfs merge=lfs -text
24
  *.tgz filter=lfs diff=lfs merge=lfs -text
 
25
  *.xz filter=lfs diff=lfs merge=lfs -text
26
  *.zip filter=lfs diff=lfs merge=lfs -text
27
  *.zstandard filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ web.sh
2
+ *__pycache__
3
+ test_512_old/
app.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+ import dnnlib
3
+ from PIL import Image
4
+ import numpy as np
5
+ import torch
6
+ import legacy
7
+ import cv2
8
+ import paddlehub as hub
9
+
10
+ u2net = hub.Module(name='U2Net')
11
+
12
+ # gradio app imports
13
+ import gradio as gr
14
+ from torchvision.transforms import ToTensor, ToPILImage
15
+ image_to_tensor = ToTensor()
16
+ tensor_to_image = ToPILImage()
17
+
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ class_idx = None
20
+ truncation_psi = 0.1
21
+
22
+ def create_model(network_pkl):
23
+ print('Loading networks from "%s"...' % network_pkl)
24
+ with dnnlib.util.open_url(network_pkl) as f:
25
+ G = legacy.load_network_pkl(f)['G_ema'] # type: ignore
26
+
27
+ G = G.eval().to(device)
28
+ netG_params = sum(p.numel() for p in G.parameters())
29
+ print("Generator Params: {} M".format(netG_params/1e6))
30
+ return G
31
+
32
+ def fcf_inpaint(G, org_img, erased_img, mask):
33
+ label = torch.zeros([1, G.c_dim], device=device)
34
+ if G.c_dim != 0:
35
+ if class_idx is None:
36
+ ValueError("class_idx can't be None.")
37
+ label[:, class_idx] = 1
38
+ else:
39
+ if class_idx is not None:
40
+ print ('warn: --class=lbl ignored when running on an unconditional network')
41
+
42
+ pred_img = G(img=torch.cat([0.5 - mask, erased_img], dim=1), c=label, truncation_psi=truncation_psi, noise_mode='const')
43
+ comp_img = mask.to(device) * pred_img + (1 - mask).to(device) * org_img.to(device)
44
+ return comp_img
45
+
46
+ def show_images(img):
47
+ """ Display a batch of images inline. """
48
+ return Image.fromarray(img)
49
+
50
+ def denorm(img):
51
+ img = np.asarray(img[0].cpu(), dtype=np.float32).transpose(1, 2, 0)
52
+ img = (img +1) * 127.5
53
+ img = np.rint(img).clip(0, 255).astype(np.uint8)
54
+ return img
55
+
56
+ def pil_to_numpy(pil_img: Image) -> Tuple[torch.Tensor, torch.Tensor]:
57
+ img = np.array(pil_img)
58
+ return torch.from_numpy(img)[None].permute(0, 3, 1, 2).float() / 127.5 - 1
59
+
60
+ def inpaint(input_img, mask, option):
61
+ width, height = input_img.size
62
+
63
+ if option == "Automatic":
64
+ result = u2net.Segmentation(
65
+ images=[cv2.cvtColor(np.array(input_img), cv2.COLOR_RGB2BGR)],
66
+ paths=None,
67
+ batch_size=1,
68
+ input_size=320,
69
+ output_dir='output',
70
+ visualization=True)
71
+ mask = Image.fromarray(result[0]['mask'])
72
+ else:
73
+ mask = mask.resize((width,height))
74
+
75
+ mask = mask.convert('L')
76
+ mask = np.array(mask) / 255.
77
+ mask = cv2.resize(mask,
78
+ (512, 512), interpolation=cv2.INTER_NEAREST)
79
+ mask_tensor = torch.from_numpy(mask).to(torch.float32)
80
+ mask_tensor = mask_tensor.unsqueeze(0)
81
+ mask_tensor = mask_tensor.unsqueeze(0).to(device)
82
+
83
+ rgb = input_img.convert('RGB')
84
+ rgb = np.array(rgb)
85
+ rgb = cv2.resize(rgb,
86
+ (512, 512), interpolation=cv2.INTER_AREA)
87
+ rgb = rgb.transpose(2,0,1)
88
+ rgb = torch.from_numpy(rgb.astype(np.float32)).unsqueeze(0)
89
+ rgb = (rgb.to(torch.float32) / 127.5 - 1).to(device)
90
+ rgb_erased = rgb.clone()
91
+ rgb_erased = rgb_erased * (1 - mask_tensor) # erase rgb
92
+ rgb_erased = rgb_erased.to(torch.float32)
93
+
94
+ # model = create_model("models/places_512.pkl")
95
+ # comp_img = fcf_inpaint(G=model, org_img=rgb.to(torch.float32), erased_img=rgb_erased.to(torch.float32), mask=mask_tensor.to(torch.float32))
96
+ rgb_erased = denorm(rgb_erased)
97
+ # comp_img = denorm(comp_img)
98
+
99
+ return show_images(rgb_erased), show_images(rgb_erased)
100
+
101
+ gradio_inputs = [gr.inputs.Image(type='pil',
102
+ tool=None,
103
+ label="Input Image"),
104
+ gr.inputs.Image(type='pil',source="canvas", label="Mask", invert_colors=True),
105
+ gr.inputs.Radio(choices=["Automatic", "Manual"], type="value", default="Manual", label="Masking Choice")
106
+ # gr.inputs.Image(type='pil',
107
+ # tool=None,
108
+ # label="Mask")]
109
+ ]
110
+
111
+ # gradio_outputs = [gr.outputs.Image(label='Auto-Detected Mask (From drawn black pixels)')]
112
+
113
+ gradio_outputs = [gr.outputs.Image(label='Image with Hole'),
114
+ gr.outputs.Image(label='Inpainted Image')]
115
+
116
+ examples = [['test_512/person512.png', 'test_512/mask_auto.png', 'Automatic'],
117
+ ['test_512/a_org.png', 'test_512/a_mask.png', 'Manual'],
118
+ ['test_512/c_org.png', 'test_512/b_mask.png', 'Manual'],
119
+ ['test_512/b_org.png', 'test_512/c_mask.png', 'Manual'],
120
+ ['test_512/d_org.png', 'test_512/d_mask.png', 'Manual'],
121
+ ['test_512/e_org.png', 'test_512/e_mask.png', 'Manual'],
122
+ ['test_512/f_org.png', 'test_512/f_mask.png', 'Manual'],
123
+ ['test_512/g_org.png', 'test_512/g_mask.png', 'Manual'],
124
+ ['test_512/h_org.png', 'test_512/h_mask.png', 'Manual'],
125
+ ['test_512/i_org.png', 'test_512/i_mask.png', 'Manual']]
126
+
127
+ title = "FcF-Inpainting"
128
+ description = "[Note: Queue time may take upto 20 seconds! The image and mask are resized to 512x512 before inpainting.] To use FcF-Inpainting: \n \
129
+ (1) Upload an Image; \n \
130
+ (2) Draw (Manual) a Mask on the White Canvas or Generate a mask using U2Net by selecting the Automatic option; \n \
131
+ (3) Click on Submit and witness the MAGIC! 🪄 ✨ ✨"
132
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.10741' target='_blank'> Keys to Better Image Inpainting: Structure and Texture Go Hand in Hand</a> | <a href='https://github.com/SHI-Labs/FcF-Inpainting' target='_blank'>Github Repo</a></p>"
133
+
134
+ css = ".image-preview {height: 32rem; width: auto;} .output-image {height: 32rem; width: auto;} .panel-buttons { display: flex; flex-direction: row;}"
135
+
136
+ iface = gr.Interface(fn=inpaint, inputs=gradio_inputs,
137
+ outputs=gradio_outputs,
138
+ css=css,
139
+ layout="vertical",
140
+ examples_per_page=5,
141
+ thumbnail="fcf_gan.png",
142
+ allow_flagging="never",
143
+ examples=examples, title=title,
144
+ description=description, article=article)
145
+ iface.launch(enable_queue=True,
146
+ share=True, server_name="0.0.0.0")
dnnlib/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from .util import EasyDict, make_cache_dir_path
dnnlib/util.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Miscellaneous utility classes and functions."""
10
+
11
+ import ctypes
12
+ import fnmatch
13
+ import importlib
14
+ import inspect
15
+ import numpy as np
16
+ import os
17
+ import shutil
18
+ import sys
19
+ import types
20
+ import io
21
+ import pickle
22
+ import re
23
+ import requests
24
+ import html
25
+ import hashlib
26
+ import glob
27
+ import tempfile
28
+ import urllib
29
+ import urllib.request
30
+ import uuid
31
+
32
+ from distutils.util import strtobool
33
+ from typing import Any, List, Tuple, Union
34
+
35
+
36
+ # Util classes
37
+ # ------------------------------------------------------------------------------------------
38
+
39
+
40
+ class EasyDict(dict):
41
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
42
+
43
+ def __getattr__(self, name: str) -> Any:
44
+ try:
45
+ return self[name]
46
+ except KeyError:
47
+ raise AttributeError(name)
48
+
49
+ def __setattr__(self, name: str, value: Any) -> None:
50
+ self[name] = value
51
+
52
+ def __delattr__(self, name: str) -> None:
53
+ del self[name]
54
+
55
+
56
+ class Logger(object):
57
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
58
+
59
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
60
+ self.file = None
61
+
62
+ if file_name is not None:
63
+ self.file = open(file_name, file_mode)
64
+
65
+ self.should_flush = should_flush
66
+ self.stdout = sys.stdout
67
+ self.stderr = sys.stderr
68
+
69
+ sys.stdout = self
70
+ sys.stderr = self
71
+
72
+ def __enter__(self) -> "Logger":
73
+ return self
74
+
75
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
76
+ self.close()
77
+
78
+ def write(self, text: Union[str, bytes]) -> None:
79
+ """Write text to stdout (and a file) and optionally flush."""
80
+ if isinstance(text, bytes):
81
+ text = text.decode()
82
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
83
+ return
84
+
85
+ if self.file is not None:
86
+ self.file.write(text)
87
+
88
+ self.stdout.write(text)
89
+
90
+ if self.should_flush:
91
+ self.flush()
92
+
93
+ def flush(self) -> None:
94
+ """Flush written text to both stdout and a file, if open."""
95
+ if self.file is not None:
96
+ self.file.flush()
97
+
98
+ self.stdout.flush()
99
+
100
+ def close(self) -> None:
101
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
102
+ self.flush()
103
+
104
+ # if using multiple loggers, prevent closing in wrong order
105
+ if sys.stdout is self:
106
+ sys.stdout = self.stdout
107
+ if sys.stderr is self:
108
+ sys.stderr = self.stderr
109
+
110
+ if self.file is not None:
111
+ self.file.close()
112
+ self.file = None
113
+
114
+
115
+ # Cache directories
116
+ # ------------------------------------------------------------------------------------------
117
+
118
+ _dnnlib_cache_dir = None
119
+
120
+ def set_cache_dir(path: str) -> None:
121
+ global _dnnlib_cache_dir
122
+ _dnnlib_cache_dir = path
123
+
124
+ def make_cache_dir_path(*paths: str) -> str:
125
+ if _dnnlib_cache_dir is not None:
126
+ return os.path.join(_dnnlib_cache_dir, *paths)
127
+ if 'DNNLIB_CACHE_DIR' in os.environ:
128
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
129
+ if 'HOME' in os.environ:
130
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
131
+ if 'USERPROFILE' in os.environ:
132
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
133
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
134
+
135
+ # Small util functions
136
+ # ------------------------------------------------------------------------------------------
137
+
138
+
139
+ def format_time(seconds: Union[int, float]) -> str:
140
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
141
+ s = int(np.rint(seconds))
142
+
143
+ if s < 60:
144
+ return "{0}s".format(s)
145
+ elif s < 60 * 60:
146
+ return "{0}m {1:02}s".format(s // 60, s % 60)
147
+ elif s < 24 * 60 * 60:
148
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
149
+ else:
150
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
151
+
152
+
153
+ def ask_yes_no(question: str) -> bool:
154
+ """Ask the user the question until the user inputs a valid answer."""
155
+ while True:
156
+ try:
157
+ print("{0} [y/n]".format(question))
158
+ return strtobool(input().lower())
159
+ except ValueError:
160
+ pass
161
+
162
+
163
+ def tuple_product(t: Tuple) -> Any:
164
+ """Calculate the product of the tuple elements."""
165
+ result = 1
166
+
167
+ for v in t:
168
+ result *= v
169
+
170
+ return result
171
+
172
+
173
+ _str_to_ctype = {
174
+ "uint8": ctypes.c_ubyte,
175
+ "uint16": ctypes.c_uint16,
176
+ "uint32": ctypes.c_uint32,
177
+ "uint64": ctypes.c_uint64,
178
+ "int8": ctypes.c_byte,
179
+ "int16": ctypes.c_int16,
180
+ "int32": ctypes.c_int32,
181
+ "int64": ctypes.c_int64,
182
+ "float32": ctypes.c_float,
183
+ "float64": ctypes.c_double
184
+ }
185
+
186
+
187
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
188
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
189
+ type_str = None
190
+
191
+ if isinstance(type_obj, str):
192
+ type_str = type_obj
193
+ elif hasattr(type_obj, "__name__"):
194
+ type_str = type_obj.__name__
195
+ elif hasattr(type_obj, "name"):
196
+ type_str = type_obj.name
197
+ else:
198
+ raise RuntimeError("Cannot infer type name from input")
199
+
200
+ assert type_str in _str_to_ctype.keys()
201
+
202
+ my_dtype = np.dtype(type_str)
203
+ my_ctype = _str_to_ctype[type_str]
204
+
205
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
206
+
207
+ return my_dtype, my_ctype
208
+
209
+
210
+ def is_pickleable(obj: Any) -> bool:
211
+ try:
212
+ with io.BytesIO() as stream:
213
+ pickle.dump(obj, stream)
214
+ return True
215
+ except:
216
+ return False
217
+
218
+
219
+ # Functionality to import modules/objects by name, and call functions by name
220
+ # ------------------------------------------------------------------------------------------
221
+
222
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
223
+ """Searches for the underlying module behind the name to some python object.
224
+ Returns the module and the object name (original name with module part removed)."""
225
+
226
+ # allow convenience shorthands, substitute them by full names
227
+ obj_name = re.sub("^np.", "numpy.", obj_name)
228
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
229
+
230
+ # list alternatives for (module_name, local_obj_name)
231
+ parts = obj_name.split(".")
232
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
233
+
234
+ # try each alternative in turn
235
+ for module_name, local_obj_name in name_pairs:
236
+ try:
237
+ module = importlib.import_module(module_name) # may raise ImportError
238
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
239
+ return module, local_obj_name
240
+ except:
241
+ pass
242
+
243
+ # maybe some of the modules themselves contain errors?
244
+ for module_name, _local_obj_name in name_pairs:
245
+ try:
246
+ importlib.import_module(module_name) # may raise ImportError
247
+ except ImportError:
248
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
249
+ raise
250
+
251
+ # maybe the requested attribute is missing?
252
+ for module_name, local_obj_name in name_pairs:
253
+ try:
254
+ module = importlib.import_module(module_name) # may raise ImportError
255
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
256
+ except ImportError:
257
+ pass
258
+
259
+ # we are out of luck, but we have no idea why
260
+ raise ImportError(obj_name)
261
+
262
+
263
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
264
+ """Traverses the object name and returns the last (rightmost) python object."""
265
+ if obj_name == '':
266
+ return module
267
+ obj = module
268
+ for part in obj_name.split("."):
269
+ obj = getattr(obj, part)
270
+ return obj
271
+
272
+
273
+ def get_obj_by_name(name: str) -> Any:
274
+ """Finds the python object with the given name."""
275
+ module, obj_name = get_module_from_obj_name(name)
276
+ return get_obj_from_module(module, obj_name)
277
+
278
+
279
+ def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
280
+ """Finds the python object with the given name and calls it as a function."""
281
+ assert func_name is not None
282
+ func_obj = get_obj_by_name(func_name)
283
+ assert callable(func_obj)
284
+ return func_obj(*args, **kwargs)
285
+
286
+
287
+ def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
288
+ """Finds the python class with the given name and constructs it with the given arguments."""
289
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
290
+
291
+
292
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
293
+ """Get the directory path of the module containing the given object name."""
294
+ module, _ = get_module_from_obj_name(obj_name)
295
+ return os.path.dirname(inspect.getfile(module))
296
+
297
+
298
+ def is_top_level_function(obj: Any) -> bool:
299
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
300
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
301
+
302
+
303
+ def get_top_level_function_name(obj: Any) -> str:
304
+ """Return the fully-qualified name of a top-level function."""
305
+ assert is_top_level_function(obj)
306
+ module = obj.__module__
307
+ if module == '__main__':
308
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
309
+ return module + "." + obj.__name__
310
+
311
+
312
+ # File system helpers
313
+ # ------------------------------------------------------------------------------------------
314
+
315
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
316
+ """List all files recursively in a given directory while ignoring given file and directory names.
317
+ Returns list of tuples containing both absolute and relative paths."""
318
+ assert os.path.isdir(dir_path)
319
+ base_name = os.path.basename(os.path.normpath(dir_path))
320
+
321
+ if ignores is None:
322
+ ignores = []
323
+
324
+ result = []
325
+
326
+ for root, dirs, files in os.walk(dir_path, topdown=True):
327
+ for ignore_ in ignores:
328
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
329
+
330
+ # dirs need to be edited in-place
331
+ for d in dirs_to_remove:
332
+ dirs.remove(d)
333
+
334
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
335
+
336
+ absolute_paths = [os.path.join(root, f) for f in files]
337
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
338
+
339
+ if add_base_to_relative:
340
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
341
+
342
+ assert len(absolute_paths) == len(relative_paths)
343
+ result += zip(absolute_paths, relative_paths)
344
+
345
+ return result
346
+
347
+
348
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
349
+ """Takes in a list of tuples of (src, dst) paths and copies files.
350
+ Will create all necessary directories."""
351
+ for file in files:
352
+ target_dir_name = os.path.dirname(file[1])
353
+
354
+ # will create all intermediate-level directories
355
+ if not os.path.exists(target_dir_name):
356
+ os.makedirs(target_dir_name)
357
+
358
+ shutil.copyfile(file[0], file[1])
359
+
360
+
361
+ # URL helpers
362
+ # ------------------------------------------------------------------------------------------
363
+
364
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
365
+ """Determine whether the given object is a valid URL string."""
366
+ if not isinstance(obj, str) or not "://" in obj:
367
+ return False
368
+ if allow_file_urls and obj.startswith('file://'):
369
+ return True
370
+ try:
371
+ res = requests.compat.urlparse(obj)
372
+ if not res.scheme or not res.netloc or not "." in res.netloc:
373
+ return False
374
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
375
+ if not res.scheme or not res.netloc or not "." in res.netloc:
376
+ return False
377
+ except:
378
+ return False
379
+ return True
380
+
381
+
382
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
383
+ """Download the given URL and return a binary-mode file object to access the data."""
384
+ assert num_attempts >= 1
385
+ assert not (return_filename and (not cache))
386
+
387
+ # Doesn't look like an URL scheme so interpret it as a local filename.
388
+ if not re.match('^[a-z]+://', url):
389
+ return url if return_filename else open(url, "rb")
390
+
391
+ # Handle file URLs. This code handles unusual file:// patterns that
392
+ # arise on Windows:
393
+ #
394
+ # file:///c:/foo.txt
395
+ #
396
+ # which would translate to a local '/c:/foo.txt' filename that's
397
+ # invalid. Drop the forward slash for such pathnames.
398
+ #
399
+ # If you touch this code path, you should test it on both Linux and
400
+ # Windows.
401
+ #
402
+ # Some internet resources suggest using urllib.request.url2pathname() but
403
+ # but that converts forward slashes to backslashes and this causes
404
+ # its own set of problems.
405
+ if url.startswith('file://'):
406
+ filename = urllib.parse.urlparse(url).path
407
+ if re.match(r'^/[a-zA-Z]:', filename):
408
+ filename = filename[1:]
409
+ return filename if return_filename else open(filename, "rb")
410
+
411
+ assert is_url(url)
412
+
413
+ # Lookup from cache.
414
+ if cache_dir is None:
415
+ cache_dir = make_cache_dir_path('downloads')
416
+
417
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
418
+ if cache:
419
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
420
+ if len(cache_files) == 1:
421
+ filename = cache_files[0]
422
+ return filename if return_filename else open(filename, "rb")
423
+
424
+ # Download.
425
+ url_name = None
426
+ url_data = None
427
+ with requests.Session() as session:
428
+ if verbose:
429
+ print("Downloading %s ..." % url, end="", flush=True)
430
+ for attempts_left in reversed(range(num_attempts)):
431
+ try:
432
+ with session.get(url) as res:
433
+ res.raise_for_status()
434
+ if len(res.content) == 0:
435
+ raise IOError("No data received")
436
+
437
+ if len(res.content) < 8192:
438
+ content_str = res.content.decode("utf-8")
439
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
440
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
441
+ if len(links) == 1:
442
+ url = requests.compat.urljoin(url, links[0])
443
+ raise IOError("Google Drive virus checker nag")
444
+ if "Google Drive - Quota exceeded" in content_str:
445
+ raise IOError("Google Drive download quota exceeded -- please try again later")
446
+
447
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
448
+ url_name = match[1] if match else url
449
+ url_data = res.content
450
+ if verbose:
451
+ print(" done")
452
+ break
453
+ except KeyboardInterrupt:
454
+ raise
455
+ except:
456
+ if not attempts_left:
457
+ if verbose:
458
+ print(" failed")
459
+ raise
460
+ if verbose:
461
+ print(".", end="", flush=True)
462
+
463
+ # Save to cache.
464
+ if cache:
465
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
466
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
467
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
468
+ os.makedirs(cache_dir, exist_ok=True)
469
+ with open(temp_file, "wb") as f:
470
+ f.write(url_data)
471
+ os.replace(temp_file, cache_file) # atomic
472
+ if return_filename:
473
+ return cache_file
474
+
475
+ # Return data as file object.
476
+ assert not return_filename
477
+ return io.BytesIO(url_data)
fcf_gan.png ADDED
legacy.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import click
10
+ import pickle
11
+ import re
12
+ import copy
13
+ import numpy as np
14
+ import torch
15
+ import dnnlib
16
+ from torch_utils import misc
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ def load_network_pkl(f, force_fp16=False):
21
+ data = _LegacyUnpickler(f).load()
22
+
23
+ # Legacy TensorFlow pickle => convert.
24
+ if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
25
+ tf_G, tf_D, tf_Gs = data
26
+ G = convert_tf_generator(tf_G)
27
+ D = convert_tf_discriminator(tf_D)
28
+ G_ema = convert_tf_generator(tf_Gs)
29
+ data = dict(G=G, D=D, G_ema=G_ema)
30
+
31
+ # Add missing fields.
32
+ if 'training_set_kwargs' not in data:
33
+ data['training_set_kwargs'] = None
34
+ if 'augment_pipe' not in data:
35
+ data['augment_pipe'] = None
36
+
37
+ # Validate contents.
38
+ assert isinstance(data['G'], torch.nn.Module)
39
+ assert isinstance(data['D'], torch.nn.Module)
40
+ assert isinstance(data['G_ema'], torch.nn.Module)
41
+ assert isinstance(data['training_set_kwargs'], (dict, type(None)))
42
+ assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
43
+
44
+ # Force FP16.
45
+ if force_fp16:
46
+ for key in ['G', 'D', 'G_ema']:
47
+ old = data[key]
48
+ kwargs = copy.deepcopy(old.init_kwargs)
49
+ if key.startswith('G'):
50
+ kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {}))
51
+ kwargs.synthesis_kwargs.num_fp16_res = 4
52
+ kwargs.synthesis_kwargs.conv_clamp = 256
53
+ if key.startswith('D'):
54
+ kwargs.num_fp16_res = 4
55
+ kwargs.conv_clamp = 256
56
+ if kwargs != old.init_kwargs:
57
+ new = type(old)(**kwargs).eval().requires_grad_(False)
58
+ misc.copy_params_and_buffers(old, new, require_all=True)
59
+ data[key] = new
60
+ return data
61
+
62
+ #----------------------------------------------------------------------------
63
+
64
+ class _TFNetworkStub(dnnlib.EasyDict):
65
+ pass
66
+
67
+ class _LegacyUnpickler(pickle.Unpickler):
68
+ def find_class(self, module, name):
69
+ if module == 'dnnlib.tflib.network' and name == 'Network':
70
+ return _TFNetworkStub
71
+ return super().find_class(module, name)
72
+
73
+ #----------------------------------------------------------------------------
74
+
75
+ def _collect_tf_params(tf_net):
76
+ # pylint: disable=protected-access
77
+ tf_params = dict()
78
+ def recurse(prefix, tf_net):
79
+ for name, value in tf_net.variables:
80
+ tf_params[prefix + name] = value
81
+ for name, comp in tf_net.components.items():
82
+ recurse(prefix + name + '/', comp)
83
+ recurse('', tf_net)
84
+ return tf_params
85
+
86
+ #----------------------------------------------------------------------------
87
+
88
+ def _populate_module_params(module, *patterns):
89
+ for name, tensor in misc.named_params_and_buffers(module):
90
+ found = False
91
+ value = None
92
+ for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
93
+ match = re.fullmatch(pattern, name)
94
+ if match:
95
+ found = True
96
+ if value_fn is not None:
97
+ value = value_fn(*match.groups())
98
+ break
99
+ try:
100
+ assert found
101
+ if value is not None:
102
+ tensor.copy_(torch.from_numpy(np.array(value)))
103
+ except:
104
+ print(name, list(tensor.shape))
105
+ raise
106
+
107
+ #----------------------------------------------------------------------------
108
+
109
+ def convert_tf_generator(tf_G):
110
+ if tf_G.version < 4:
111
+ raise ValueError('TensorFlow pickle version too low')
112
+
113
+ # Collect kwargs.
114
+ tf_kwargs = tf_G.static_kwargs
115
+ known_kwargs = set()
116
+ def kwarg(tf_name, default=None, none=None):
117
+ known_kwargs.add(tf_name)
118
+ val = tf_kwargs.get(tf_name, default)
119
+ return val if val is not None else none
120
+
121
+ # Convert kwargs.
122
+ kwargs = dnnlib.EasyDict(
123
+ z_dim = kwarg('latent_size', 512),
124
+ c_dim = kwarg('label_size', 0),
125
+ w_dim = kwarg('dlatent_size', 512),
126
+ img_resolution = kwarg('resolution', 1024),
127
+ img_channels = kwarg('num_channels', 3),
128
+ mapping_kwargs = dnnlib.EasyDict(
129
+ num_layers = kwarg('mapping_layers', 8),
130
+ embed_features = kwarg('label_fmaps', None),
131
+ layer_features = kwarg('mapping_fmaps', None),
132
+ activation = kwarg('mapping_nonlinearity', 'lrelu'),
133
+ lr_multiplier = kwarg('mapping_lrmul', 0.01),
134
+ w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
135
+ ),
136
+ synthesis_kwargs = dnnlib.EasyDict(
137
+ channel_base = kwarg('fmap_base', 16384) * 2,
138
+ channel_max = kwarg('fmap_max', 512),
139
+ num_fp16_res = kwarg('num_fp16_res', 0),
140
+ conv_clamp = kwarg('conv_clamp', None),
141
+ architecture = kwarg('architecture', 'skip'),
142
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
143
+ use_noise = kwarg('use_noise', True),
144
+ activation = kwarg('nonlinearity', 'lrelu'),
145
+ ),
146
+ )
147
+
148
+ # Check for unknown kwargs.
149
+ kwarg('truncation_psi')
150
+ kwarg('truncation_cutoff')
151
+ kwarg('style_mixing_prob')
152
+ kwarg('structure')
153
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
154
+ if len(unknown_kwargs) > 0:
155
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
156
+
157
+ # Collect params.
158
+ tf_params = _collect_tf_params(tf_G)
159
+ for name, value in list(tf_params.items()):
160
+ match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
161
+ if match:
162
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
163
+ tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
164
+ kwargs.synthesis.kwargs.architecture = 'orig'
165
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
166
+
167
+ # Convert params.
168
+ from training import networks
169
+ G = networks.Generator(**kwargs).eval().requires_grad_(False)
170
+ # pylint: disable=unnecessary-lambda
171
+ _populate_module_params(G,
172
+ r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
173
+ r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
174
+ r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
175
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
176
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
177
+ r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
178
+ r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
179
+ r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
180
+ r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
181
+ r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
182
+ r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
183
+ r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
184
+ r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
185
+ r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
186
+ r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
187
+ r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
188
+ r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
189
+ r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
190
+ r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
191
+ r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
192
+ r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
193
+ r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
194
+ r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
195
+ r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
196
+ r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
197
+ r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
198
+ r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
199
+ r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
200
+ r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
201
+ r'.*\.resample_filter', None,
202
+ )
203
+ return G
204
+
205
+ #----------------------------------------------------------------------------
206
+
207
+ def convert_tf_discriminator(tf_D):
208
+ if tf_D.version < 4:
209
+ raise ValueError('TensorFlow pickle version too low')
210
+
211
+ # Collect kwargs.
212
+ tf_kwargs = tf_D.static_kwargs
213
+ known_kwargs = set()
214
+ def kwarg(tf_name, default=None):
215
+ known_kwargs.add(tf_name)
216
+ return tf_kwargs.get(tf_name, default)
217
+
218
+ # Convert kwargs.
219
+ kwargs = dnnlib.EasyDict(
220
+ c_dim = kwarg('label_size', 0),
221
+ img_resolution = kwarg('resolution', 1024),
222
+ img_channels = kwarg('num_channels', 3),
223
+ architecture = kwarg('architecture', 'resnet'),
224
+ channel_base = kwarg('fmap_base', 16384) * 2,
225
+ channel_max = kwarg('fmap_max', 512),
226
+ num_fp16_res = kwarg('num_fp16_res', 0),
227
+ conv_clamp = kwarg('conv_clamp', None),
228
+ cmap_dim = kwarg('mapping_fmaps', None),
229
+ block_kwargs = dnnlib.EasyDict(
230
+ activation = kwarg('nonlinearity', 'lrelu'),
231
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
232
+ freeze_layers = kwarg('freeze_layers', 0),
233
+ ),
234
+ mapping_kwargs = dnnlib.EasyDict(
235
+ num_layers = kwarg('mapping_layers', 0),
236
+ embed_features = kwarg('mapping_fmaps', None),
237
+ layer_features = kwarg('mapping_fmaps', None),
238
+ activation = kwarg('nonlinearity', 'lrelu'),
239
+ lr_multiplier = kwarg('mapping_lrmul', 0.1),
240
+ ),
241
+ epilogue_kwargs = dnnlib.EasyDict(
242
+ mbstd_group_size = kwarg('mbstd_group_size', None),
243
+ mbstd_num_channels = kwarg('mbstd_num_features', 1),
244
+ activation = kwarg('nonlinearity', 'lrelu'),
245
+ ),
246
+ )
247
+
248
+ # Check for unknown kwargs.
249
+ kwarg('structure')
250
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
251
+ if len(unknown_kwargs) > 0:
252
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
253
+
254
+ # Collect params.
255
+ tf_params = _collect_tf_params(tf_D)
256
+ for name, value in list(tf_params.items()):
257
+ match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
258
+ if match:
259
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
260
+ tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
261
+ kwargs.architecture = 'orig'
262
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
263
+
264
+ # Convert params.
265
+ from training import networks
266
+ D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
267
+ # pylint: disable=unnecessary-lambda
268
+ _populate_module_params(D,
269
+ r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
270
+ r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
271
+ r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
272
+ r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
273
+ r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
274
+ r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
275
+ r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
276
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
277
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
278
+ r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
279
+ r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
280
+ r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
281
+ r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
282
+ r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
283
+ r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
284
+ r'.*\.resample_filter', None,
285
+ )
286
+ return D
287
+
288
+ #----------------------------------------------------------------------------
289
+
290
+ @click.command()
291
+ @click.option('--source', help='Input pickle', required=True, metavar='PATH')
292
+ @click.option('--dest', help='Output pickle', required=True, metavar='PATH')
293
+ @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
294
+ def convert_network_pickle(source, dest, force_fp16):
295
+ """Convert legacy network pickle into the native PyTorch format.
296
+
297
+ The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
298
+ It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
299
+
300
+ Example:
301
+
302
+ \b
303
+ python legacy.py \\
304
+ --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
305
+ --dest=stylegan2-cat-config-f.pkl
306
+ """
307
+ print(f'Loading "{source}"...')
308
+ with dnnlib.util.open_url(source) as f:
309
+ data = load_network_pkl(f, force_fp16=force_fp16)
310
+ print(f'Saving "{dest}"...')
311
+ with open(dest, 'wb') as f:
312
+ pickle.dump(data, f)
313
+ print('Done.')
314
+
315
+ #----------------------------------------------------------------------------
316
+
317
+ if __name__ == "__main__":
318
+ convert_network_pickle() # pylint: disable=no-value-for-parameter
319
+
320
+ #----------------------------------------------------------------------------
output/result_0.png ADDED
output/result_mask_0.png ADDED
requirements.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ icecream
2
+ psutil
3
+ click
4
+ requests
5
+ matplotlib
6
+ tqdm
7
+ ninja
8
+ imageio-ffmpeg==0.4.3
9
+ scipy
10
+ termcolor>=1.1
11
+ colorama
12
+ cvbase
13
+ opencv-python
14
+ etaprogress
15
+ scikit-learn
16
+ pandas
17
+ tensorboard
18
+ pydrive2
19
+ pandas
20
+ easydict
21
+ kornia==0.5.0
22
+ gradio
23
+ ipython
24
+ Jinja2
25
+ paddlepaddle
26
+ paddlehub
setup.sh ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ #!/bin/sh
2
+ eval "$(conda shell.bash hook)"
3
+ conda create --name fcf -y python=3.7
4
+ conda activate fcf
5
+ conda env list
6
+ conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=11.0 -c pytorch
7
+ pip3 install -r requirements.txt
test_512/.DS_Store ADDED
Binary file (6.15 kB). View file
 
test_512/a_mask.png ADDED
test_512/a_org.png ADDED
test_512/b_mask.png ADDED
test_512/b_org.png ADDED
test_512/c_mask.png ADDED
test_512/c_org.png ADDED
test_512/d_mask.png ADDED
test_512/d_org.png ADDED
test_512/e_mask.png ADDED
test_512/e_org.png ADDED
test_512/f_mask.png ADDED
test_512/f_org.png ADDED
test_512/g_mask.png ADDED
test_512/g_org.png ADDED
test_512/h_mask.png ADDED
test_512/h_org.png ADDED
test_512/i_mask.png ADDED
test_512/i_org.png ADDED
test_512/mask_auto.png ADDED
test_512/person512.png ADDED
torch_utils/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
torch_utils/custom_ops.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import glob
11
+ import torch
12
+ import torch.utils.cpp_extension
13
+ import importlib
14
+ import hashlib
15
+ import shutil
16
+ from pathlib import Path
17
+
18
+ from torch.utils.file_baton import FileBaton
19
+
20
+ #----------------------------------------------------------------------------
21
+ # Global options.
22
+
23
+ verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
24
+
25
+ #----------------------------------------------------------------------------
26
+ # Internal helper funcs.
27
+
28
+ def _find_compiler_bindir():
29
+ patterns = [
30
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
31
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
32
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
33
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
34
+ ]
35
+ for pattern in patterns:
36
+ matches = sorted(glob.glob(pattern))
37
+ if len(matches):
38
+ return matches[-1]
39
+ return None
40
+
41
+ #----------------------------------------------------------------------------
42
+ # Main entry point for compiling and loading C++/CUDA plugins.
43
+
44
+ _cached_plugins = dict()
45
+
46
+ def get_plugin(module_name, sources, **build_kwargs):
47
+ assert verbosity in ['none', 'brief', 'full']
48
+
49
+ # Already cached?
50
+ if module_name in _cached_plugins:
51
+ return _cached_plugins[module_name]
52
+
53
+ # Print status.
54
+ if verbosity == 'full':
55
+ print(f'Setting up PyTorch plugin "{module_name}"...')
56
+ elif verbosity == 'brief':
57
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
58
+
59
+ try: # pylint: disable=too-many-nested-blocks
60
+ # Make sure we can find the necessary compiler binaries.
61
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
62
+ compiler_bindir = _find_compiler_bindir()
63
+ if compiler_bindir is None:
64
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
65
+ os.environ['PATH'] += ';' + compiler_bindir
66
+
67
+ # Compile and load.
68
+ verbose_build = (verbosity == 'full')
69
+
70
+ # Incremental build md5sum trickery. Copies all the input source files
71
+ # into a cached build directory under a combined md5 digest of the input
72
+ # source files. Copying is done only if the combined digest has changed.
73
+ # This keeps input file timestamps and filenames the same as in previous
74
+ # extension builds, allowing for fast incremental rebuilds.
75
+ #
76
+ # This optimization is done only in case all the source files reside in
77
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
78
+ # environment variable is set (we take this as a signal that the user
79
+ # actually cares about this.)
80
+ source_dirs_set = set(os.path.dirname(source) for source in sources)
81
+ if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
82
+ all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
83
+
84
+ # Compute a combined hash digest for all source files in the same
85
+ # custom op directory (usually .cu, .cpp, .py and .h files).
86
+ hash_md5 = hashlib.md5()
87
+ for src in all_source_files:
88
+ with open(src, 'rb') as f:
89
+ hash_md5.update(f.read())
90
+ build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
91
+ digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
92
+
93
+ if not os.path.isdir(digest_build_dir):
94
+ os.makedirs(digest_build_dir, exist_ok=True)
95
+ baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
96
+ if baton.try_acquire():
97
+ try:
98
+ for src in all_source_files:
99
+ shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
100
+ finally:
101
+ baton.release()
102
+ else:
103
+ # Someone else is copying source files under the digest dir,
104
+ # wait until done and continue.
105
+ baton.wait()
106
+ digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
107
+ torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
108
+ verbose=verbose_build, sources=digest_sources, **build_kwargs)
109
+ else:
110
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
111
+ module = importlib.import_module(module_name)
112
+
113
+ except:
114
+ if verbosity == 'brief':
115
+ print('Failed!')
116
+ raise
117
+
118
+ # Print status and add to cache.
119
+ if verbosity == 'full':
120
+ print(f'Done setting up PyTorch plugin "{module_name}".')
121
+ elif verbosity == 'brief':
122
+ print('Done.')
123
+ _cached_plugins[module_name] = module
124
+ return module
125
+
126
+ #----------------------------------------------------------------------------
torch_utils/misc.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import re
10
+ import contextlib
11
+ import numpy as np
12
+ import torch
13
+ import warnings
14
+ import dnnlib
15
+
16
+ #----------------------------------------------------------------------------
17
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
18
+ # same constant is used multiple times.
19
+
20
+ _constant_cache = dict()
21
+
22
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
23
+ value = np.asarray(value)
24
+ if shape is not None:
25
+ shape = tuple(shape)
26
+ if dtype is None:
27
+ dtype = torch.get_default_dtype()
28
+ if device is None:
29
+ device = torch.device('cpu')
30
+ if memory_format is None:
31
+ memory_format = torch.contiguous_format
32
+
33
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
34
+ tensor = _constant_cache.get(key, None)
35
+ if tensor is None:
36
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
37
+ if shape is not None:
38
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
39
+ tensor = tensor.contiguous(memory_format=memory_format)
40
+ _constant_cache[key] = tensor
41
+ return tensor
42
+
43
+ #----------------------------------------------------------------------------
44
+ # Replace NaN/Inf with specified numerical values.
45
+
46
+ try:
47
+ nan_to_num = torch.nan_to_num # 1.8.0a0
48
+ except AttributeError:
49
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
50
+ assert isinstance(input, torch.Tensor)
51
+ input = input.to(torch.float32)
52
+ if posinf is None:
53
+ posinf = torch.finfo(input.dtype).max
54
+ if neginf is None:
55
+ neginf = torch.finfo(input.dtype).min
56
+ assert nan == 0
57
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
58
+
59
+ #----------------------------------------------------------------------------
60
+ # Symbolic assert.
61
+
62
+ try:
63
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
64
+ except AttributeError:
65
+ symbolic_assert = torch.Assert # 1.7.0
66
+
67
+ #----------------------------------------------------------------------------
68
+ # Context manager to suppress known warnings in torch.jit.trace().
69
+
70
+ class suppress_tracer_warnings(warnings.catch_warnings):
71
+ def __enter__(self):
72
+ super().__enter__()
73
+ warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
74
+ return self
75
+
76
+ #----------------------------------------------------------------------------
77
+ # Assert that the shape of a tensor matches the given list of integers.
78
+ # None indicates that the size of a dimension is allowed to vary.
79
+ # Performs symbolic assertion when used in torch.jit.trace().
80
+
81
+ def assert_shape(tensor, ref_shape):
82
+ if tensor.ndim != len(ref_shape):
83
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
84
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
85
+ if ref_size is None:
86
+ pass
87
+ elif isinstance(ref_size, torch.Tensor):
88
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
89
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
90
+ elif isinstance(size, torch.Tensor):
91
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
92
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
93
+ elif size != ref_size:
94
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
95
+
96
+ #----------------------------------------------------------------------------
97
+ # Function decorator that calls torch.autograd.profiler.record_function().
98
+
99
+ def profiled_function(fn):
100
+ def decorator(*args, **kwargs):
101
+ with torch.autograd.profiler.record_function(fn.__name__):
102
+ return fn(*args, **kwargs)
103
+ decorator.__name__ = fn.__name__
104
+ return decorator
105
+
106
+ #----------------------------------------------------------------------------
107
+ # Sampler for torch.utils.data.DataLoader that loops over the dataset
108
+ # indefinitely, shuffling items as it goes.
109
+
110
+ class InfiniteSampler(torch.utils.data.Sampler):
111
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
112
+ assert len(dataset) > 0
113
+ assert num_replicas > 0
114
+ assert 0 <= rank < num_replicas
115
+ assert 0 <= window_size <= 1
116
+ super().__init__(dataset)
117
+ self.dataset = dataset
118
+ self.rank = rank
119
+ self.num_replicas = num_replicas
120
+ self.shuffle = shuffle
121
+ self.seed = seed
122
+ self.window_size = window_size
123
+
124
+ def __iter__(self):
125
+ order = np.arange(len(self.dataset))
126
+ rnd = None
127
+ window = 0
128
+ if self.shuffle:
129
+ rnd = np.random.RandomState(self.seed)
130
+ rnd.shuffle(order)
131
+ window = int(np.rint(order.size * self.window_size))
132
+
133
+ idx = 0
134
+ while True:
135
+ i = idx % order.size
136
+ if idx % self.num_replicas == self.rank:
137
+ yield order[i]
138
+ if window >= 2:
139
+ j = (i - rnd.randint(window)) % order.size
140
+ order[i], order[j] = order[j], order[i]
141
+ idx += 1
142
+
143
+ #----------------------------------------------------------------------------
144
+ # Utilities for operating with torch.nn.Module parameters and buffers.
145
+
146
+ def params_and_buffers(module):
147
+ assert isinstance(module, torch.nn.Module)
148
+ return list(module.parameters()) + list(module.buffers())
149
+
150
+ def named_params_and_buffers(module):
151
+ assert isinstance(module, torch.nn.Module)
152
+ return list(module.named_parameters()) + list(module.named_buffers())
153
+
154
+ def copy_params_and_buffers(src_module, dst_module, require_all=False):
155
+ assert isinstance(src_module, torch.nn.Module)
156
+ assert isinstance(dst_module, torch.nn.Module)
157
+ src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
158
+ for name, tensor in named_params_and_buffers(dst_module):
159
+ assert (name in src_tensors) or (not require_all)
160
+ if name in src_tensors:
161
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
162
+
163
+ #----------------------------------------------------------------------------
164
+ # Context manager for easily enabling/disabling DistributedDataParallel
165
+ # synchronization.
166
+
167
+ @contextlib.contextmanager
168
+ def ddp_sync(module, sync):
169
+ assert isinstance(module, torch.nn.Module)
170
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
171
+ yield
172
+ else:
173
+ with module.no_sync():
174
+ yield
175
+
176
+ #----------------------------------------------------------------------------
177
+ # Check DistributedDataParallel consistency across processes.
178
+
179
+ def check_ddp_consistency(module, ignore_regex=None):
180
+ assert isinstance(module, torch.nn.Module)
181
+ for name, tensor in named_params_and_buffers(module):
182
+ fullname = type(module).__name__ + '.' + name
183
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
184
+ continue
185
+ tensor = tensor.detach()
186
+ other = tensor.clone()
187
+ torch.distributed.broadcast(tensor=other, src=0)
188
+ assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
189
+
190
+ #----------------------------------------------------------------------------
191
+ # Print summary table of module hierarchy.
192
+
193
+ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
194
+ assert isinstance(module, torch.nn.Module)
195
+ assert not isinstance(module, torch.jit.ScriptModule)
196
+ assert isinstance(inputs, (tuple, list))
197
+
198
+ # Register hooks.
199
+ entries = []
200
+ nesting = [0]
201
+ def pre_hook(_mod, _inputs):
202
+ nesting[0] += 1
203
+ def post_hook(mod, _inputs, outputs):
204
+ nesting[0] -= 1
205
+ if nesting[0] <= max_nesting:
206
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
207
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
208
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
209
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
210
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
211
+
212
+ # Run module.
213
+ outputs = module(*inputs)
214
+ for hook in hooks:
215
+ hook.remove()
216
+
217
+ # Identify unique outputs, parameters, and buffers.
218
+ tensors_seen = set()
219
+ for e in entries:
220
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
221
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
222
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
223
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
224
+
225
+ # Filter out redundant entries.
226
+ if skip_redundant:
227
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
228
+
229
+ # Construct table.
230
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
231
+ rows += [['---'] * len(rows[0])]
232
+ param_total = 0
233
+ buffer_total = 0
234
+ submodule_names = {mod: name for name, mod in module.named_modules()}
235
+ for e in entries:
236
+ name = '<top-level>' if e.mod is module else submodule_names[e.mod]
237
+ param_size = sum(t.numel() for t in e.unique_params)
238
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
239
+ output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs]
240
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
241
+ rows += [[
242
+ name + (':0' if len(e.outputs) >= 2 else ''),
243
+ str(param_size) if param_size else '-',
244
+ str(buffer_size) if buffer_size else '-',
245
+ (output_shapes + ['-'])[0],
246
+ (output_dtypes + ['-'])[0],
247
+ ]]
248
+ for idx in range(1, len(e.outputs)):
249
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
250
+ param_total += param_size
251
+ buffer_total += buffer_size
252
+ rows += [['---'] * len(rows[0])]
253
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
254
+
255
+ # Print table.
256
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
257
+ print()
258
+ for row in rows:
259
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
260
+ print()
261
+ return outputs
262
+
263
+ #----------------------------------------------------------------------------
torch_utils/ops/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
torch_utils/ops/bias_act.cpp ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "bias_act.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17
+ {
18
+ if (x.dim() != y.dim())
19
+ return false;
20
+ for (int64_t i = 0; i < x.dim(); i++)
21
+ {
22
+ if (x.size(i) != y.size(i))
23
+ return false;
24
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25
+ return false;
26
+ }
27
+ return true;
28
+ }
29
+
30
+ //------------------------------------------------------------------------
31
+
32
+ static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33
+ {
34
+ // Validate arguments.
35
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
45
+
46
+ // Validate layout.
47
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52
+
53
+ // Create output tensor.
54
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55
+ torch::Tensor y = torch::empty_like(x);
56
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57
+
58
+ // Initialize CUDA kernel parameters.
59
+ bias_act_kernel_params p;
60
+ p.x = x.data_ptr();
61
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
62
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65
+ p.y = y.data_ptr();
66
+ p.grad = grad;
67
+ p.act = act;
68
+ p.alpha = alpha;
69
+ p.gain = gain;
70
+ p.clamp = clamp;
71
+ p.sizeX = (int)x.numel();
72
+ p.sizeB = (int)b.numel();
73
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74
+
75
+ // Choose CUDA kernel.
76
+ void* kernel;
77
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78
+ {
79
+ kernel = choose_bias_act_kernel<scalar_t>(p);
80
+ });
81
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82
+
83
+ // Launch CUDA kernel.
84
+ p.loopX = 4;
85
+ int blockSize = 4 * 32;
86
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87
+ void* args[] = {&p};
88
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89
+ return y;
90
+ }
91
+
92
+ //------------------------------------------------------------------------
93
+
94
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95
+ {
96
+ m.def("bias_act", &bias_act);
97
+ }
98
+
99
+ //------------------------------------------------------------------------
torch_utils/ops/bias_act.cu ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "bias_act.h"
11
+
12
+ //------------------------------------------------------------------------
13
+ // Helpers.
14
+
15
+ template <class T> struct InternalType;
16
+ template <> struct InternalType<double> { typedef double scalar_t; };
17
+ template <> struct InternalType<float> { typedef float scalar_t; };
18
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
+
20
+ //------------------------------------------------------------------------
21
+ // CUDA kernel.
22
+
23
+ template <class T, int A>
24
+ __global__ void bias_act_kernel(bias_act_kernel_params p)
25
+ {
26
+ typedef typename InternalType<T>::scalar_t scalar_t;
27
+ int G = p.grad;
28
+ scalar_t alpha = (scalar_t)p.alpha;
29
+ scalar_t gain = (scalar_t)p.gain;
30
+ scalar_t clamp = (scalar_t)p.clamp;
31
+ scalar_t one = (scalar_t)1;
32
+ scalar_t two = (scalar_t)2;
33
+ scalar_t expRange = (scalar_t)80;
34
+ scalar_t halfExpRange = (scalar_t)40;
35
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37
+
38
+ // Loop over elements.
39
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41
+ {
42
+ // Load.
43
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
44
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
49
+ scalar_t y = 0;
50
+
51
+ // Apply bias.
52
+ ((G == 0) ? x : xref) += b;
53
+
54
+ // linear
55
+ if (A == 1)
56
+ {
57
+ if (G == 0) y = x;
58
+ if (G == 1) y = x;
59
+ }
60
+
61
+ // relu
62
+ if (A == 2)
63
+ {
64
+ if (G == 0) y = (x > 0) ? x : 0;
65
+ if (G == 1) y = (yy > 0) ? x : 0;
66
+ }
67
+
68
+ // lrelu
69
+ if (A == 3)
70
+ {
71
+ if (G == 0) y = (x > 0) ? x : x * alpha;
72
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
73
+ }
74
+
75
+ // tanh
76
+ if (A == 4)
77
+ {
78
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79
+ if (G == 1) y = x * (one - yy * yy);
80
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81
+ }
82
+
83
+ // sigmoid
84
+ if (A == 5)
85
+ {
86
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87
+ if (G == 1) y = x * yy * (one - yy);
88
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89
+ }
90
+
91
+ // elu
92
+ if (A == 6)
93
+ {
94
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97
+ }
98
+
99
+ // selu
100
+ if (A == 7)
101
+ {
102
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105
+ }
106
+
107
+ // softplus
108
+ if (A == 8)
109
+ {
110
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111
+ if (G == 1) y = x * (one - exp(-yy));
112
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113
+ }
114
+
115
+ // swish
116
+ if (A == 9)
117
+ {
118
+ if (G == 0)
119
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120
+ else
121
+ {
122
+ scalar_t c = exp(xref);
123
+ scalar_t d = c + one;
124
+ if (G == 1)
125
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126
+ else
127
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129
+ }
130
+ }
131
+
132
+ // Apply gain.
133
+ y *= gain * dy;
134
+
135
+ // Clamp.
136
+ if (clamp >= 0)
137
+ {
138
+ if (G == 0)
139
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140
+ else
141
+ y = (yref > -clamp & yref < clamp) ? y : 0;
142
+ }
143
+
144
+ // Store.
145
+ ((T*)p.y)[xi] = (T)y;
146
+ }
147
+ }
148
+
149
+ //------------------------------------------------------------------------
150
+ // CUDA kernel selection.
151
+
152
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153
+ {
154
+ if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155
+ if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156
+ if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157
+ if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158
+ if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159
+ if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160
+ if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161
+ if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162
+ if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163
+ return NULL;
164
+ }
165
+
166
+ //------------------------------------------------------------------------
167
+ // Template specializations.
168
+
169
+ template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
170
+ template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
171
+ template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172
+
173
+ //------------------------------------------------------------------------
torch_utils/ops/bias_act.h ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ //------------------------------------------------------------------------
10
+ // CUDA kernel parameters.
11
+
12
+ struct bias_act_kernel_params
13
+ {
14
+ const void* x; // [sizeX]
15
+ const void* b; // [sizeB] or NULL
16
+ const void* xref; // [sizeX] or NULL
17
+ const void* yref; // [sizeX] or NULL
18
+ const void* dy; // [sizeX] or NULL
19
+ void* y; // [sizeX]
20
+
21
+ int grad;
22
+ int act;
23
+ float alpha;
24
+ float gain;
25
+ float clamp;
26
+
27
+ int sizeX;
28
+ int sizeB;
29
+ int stepB;
30
+ int loopX;
31
+ };
32
+
33
+ //------------------------------------------------------------------------
34
+ // CUDA kernel selection.
35
+
36
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37
+
38
+ //------------------------------------------------------------------------
torch_utils/ops/bias_act.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom PyTorch ops for efficient bias and activation."""
10
+
11
+ import os
12
+ import warnings
13
+ import numpy as np
14
+ import torch
15
+ import dnnlib
16
+ import traceback
17
+
18
+ from .. import custom_ops
19
+ from .. import misc
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
+ activation_funcs = {
24
+ 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
25
+ 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
26
+ 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
27
+ 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
28
+ 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
29
+ 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
30
+ 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
31
+ 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
32
+ 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
33
+ }
34
+
35
+ #----------------------------------------------------------------------------
36
+
37
+ _inited = False
38
+ _plugin = None
39
+ _null_tensor = torch.empty([0])
40
+
41
+ def _init():
42
+ global _inited, _plugin
43
+ if not _inited:
44
+ _inited = True
45
+ sources = ['bias_act.cpp', 'bias_act.cu']
46
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
47
+ try:
48
+ _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
49
+ except:
50
+ warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
51
+ return _plugin is not None
52
+
53
+ #----------------------------------------------------------------------------
54
+
55
+ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
56
+ r"""Fused bias and activation function.
57
+
58
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
59
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
60
+ the fused op is considerably more efficient than performing the same calculation
61
+ using standard PyTorch ops. It supports first and second order gradients,
62
+ but not third order gradients.
63
+
64
+ Args:
65
+ x: Input activation tensor. Can be of any shape.
66
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
67
+ as `x`. The shape must be known, and it must match the dimension of `x`
68
+ corresponding to `dim`.
69
+ dim: The dimension in `x` corresponding to the elements of `b`.
70
+ The value of `dim` is ignored if `b` is not specified.
71
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
72
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
73
+ See `activation_funcs` for a full list. `None` is not allowed.
74
+ alpha: Shape parameter for the activation function, or `None` to use the default.
75
+ gain: Scaling factor for the output tensor, or `None` to use default.
76
+ See `activation_funcs` for the default scaling of each activation function.
77
+ If unsure, consider specifying 1.
78
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
79
+ the clamping (default).
80
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
81
+
82
+ Returns:
83
+ Tensor of the same shape and datatype as `x`.
84
+ """
85
+ assert isinstance(x, torch.Tensor)
86
+ assert impl in ['ref', 'cuda']
87
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
88
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
89
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
90
+
91
+ #----------------------------------------------------------------------------
92
+
93
+ @misc.profiled_function
94
+ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
95
+ """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
96
+ """
97
+ assert isinstance(x, torch.Tensor)
98
+ assert clamp is None or clamp >= 0
99
+ spec = activation_funcs[act]
100
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
101
+ gain = float(gain if gain is not None else spec.def_gain)
102
+ clamp = float(clamp if clamp is not None else -1)
103
+
104
+ # Add bias.
105
+ if b is not None:
106
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
107
+ assert 0 <= dim < x.ndim
108
+ assert b.shape[0] == x.shape[dim]
109
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
110
+
111
+ # Evaluate activation function.
112
+ alpha = float(alpha)
113
+ x = spec.func(x, alpha=alpha)
114
+
115
+ # Scale by gain.
116
+ gain = float(gain)
117
+ if gain != 1:
118
+ x = x * gain
119
+
120
+ # Clamp.
121
+ if clamp >= 0:
122
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
123
+ return x
124
+
125
+ #----------------------------------------------------------------------------
126
+
127
+ _bias_act_cuda_cache = dict()
128
+
129
+ def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
130
+ """Fast CUDA implementation of `bias_act()` using custom ops.
131
+ """
132
+ # Parse arguments.
133
+ assert clamp is None or clamp >= 0
134
+ spec = activation_funcs[act]
135
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
136
+ gain = float(gain if gain is not None else spec.def_gain)
137
+ clamp = float(clamp if clamp is not None else -1)
138
+
139
+ # Lookup from cache.
140
+ key = (dim, act, alpha, gain, clamp)
141
+ if key in _bias_act_cuda_cache:
142
+ return _bias_act_cuda_cache[key]
143
+
144
+ # Forward op.
145
+ class BiasActCuda(torch.autograd.Function):
146
+ @staticmethod
147
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
148
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
149
+ x = x.contiguous(memory_format=ctx.memory_format)
150
+ b = b.contiguous() if b is not None else _null_tensor
151
+ y = x
152
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
153
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
154
+ ctx.save_for_backward(
155
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
156
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
157
+ y if 'y' in spec.ref else _null_tensor)
158
+ return y
159
+
160
+ @staticmethod
161
+ def backward(ctx, dy): # pylint: disable=arguments-differ
162
+ dy = dy.contiguous(memory_format=ctx.memory_format)
163
+ x, b, y = ctx.saved_tensors
164
+ dx = None
165
+ db = None
166
+
167
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
168
+ dx = dy
169
+ if act != 'linear' or gain != 1 or clamp >= 0:
170
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
171
+
172
+ if ctx.needs_input_grad[1]:
173
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
174
+
175
+ return dx, db
176
+
177
+ # Backward op.
178
+ class BiasActCudaGrad(torch.autograd.Function):
179
+ @staticmethod
180
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
181
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
182
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
183
+ ctx.save_for_backward(
184
+ dy if spec.has_2nd_grad else _null_tensor,
185
+ x, b, y)
186
+ return dx
187
+
188
+ @staticmethod
189
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
190
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
191
+ dy, x, b, y = ctx.saved_tensors
192
+ d_dy = None
193
+ d_x = None
194
+ d_b = None
195
+ d_y = None
196
+
197
+ if ctx.needs_input_grad[0]:
198
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
199
+
200
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
201
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
202
+
203
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
204
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
205
+
206
+ return d_dy, d_x, d_b, d_y
207
+
208
+ # Add to cache.
209
+ _bias_act_cuda_cache[key] = BiasActCuda
210
+ return BiasActCuda
211
+
212
+ #----------------------------------------------------------------------------
torch_utils/ops/conv2d_gradfix.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom replacement for `torch.nn.functional.conv2d` that supports
10
+ arbitrarily high order gradients with zero performance penalty."""
11
+
12
+ import warnings
13
+ import contextlib
14
+ import torch
15
+
16
+ # pylint: disable=redefined-builtin
17
+ # pylint: disable=arguments-differ
18
+ # pylint: disable=protected-access
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ enabled = False # Enable the custom op by setting this to true.
23
+ weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
24
+
25
+ @contextlib.contextmanager
26
+ def no_weight_gradients():
27
+ global weight_gradients_disabled
28
+ old = weight_gradients_disabled
29
+ weight_gradients_disabled = True
30
+ yield
31
+ weight_gradients_disabled = old
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
36
+ if _should_use_custom_op(input):
37
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
38
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
39
+
40
+ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
41
+ if _should_use_custom_op(input):
42
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
43
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
44
+
45
+ #----------------------------------------------------------------------------
46
+
47
+ def _should_use_custom_op(input):
48
+ assert isinstance(input, torch.Tensor)
49
+ if (not enabled) or (not torch.backends.cudnn.enabled):
50
+ return False
51
+ if input.device.type != 'cuda':
52
+ return False
53
+ if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
54
+ return True
55
+ warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
56
+ return False
57
+
58
+ def _tuple_of_ints(xs, ndim):
59
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
60
+ assert len(xs) == ndim
61
+ assert all(isinstance(x, int) for x in xs)
62
+ return xs
63
+
64
+ #----------------------------------------------------------------------------
65
+
66
+ _conv2d_gradfix_cache = dict()
67
+
68
+ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
69
+ # Parse arguments.
70
+ ndim = 2
71
+ weight_shape = tuple(weight_shape)
72
+ stride = _tuple_of_ints(stride, ndim)
73
+ padding = _tuple_of_ints(padding, ndim)
74
+ output_padding = _tuple_of_ints(output_padding, ndim)
75
+ dilation = _tuple_of_ints(dilation, ndim)
76
+
77
+ # Lookup from cache.
78
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
79
+ if key in _conv2d_gradfix_cache:
80
+ return _conv2d_gradfix_cache[key]
81
+
82
+ # Validate arguments.
83
+ assert groups >= 1
84
+ assert len(weight_shape) == ndim + 2
85
+ assert all(stride[i] >= 1 for i in range(ndim))
86
+ assert all(padding[i] >= 0 for i in range(ndim))
87
+ assert all(dilation[i] >= 0 for i in range(ndim))
88
+ if not transpose:
89
+ assert all(output_padding[i] == 0 for i in range(ndim))
90
+ else: # transpose
91
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
92
+
93
+ # Helpers.
94
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
95
+ def calc_output_padding(input_shape, output_shape):
96
+ if transpose:
97
+ return [0, 0]
98
+ return [
99
+ input_shape[i + 2]
100
+ - (output_shape[i + 2] - 1) * stride[i]
101
+ - (1 - 2 * padding[i])
102
+ - dilation[i] * (weight_shape[i + 2] - 1)
103
+ for i in range(ndim)
104
+ ]
105
+
106
+ # Forward & backward.
107
+ class Conv2d(torch.autograd.Function):
108
+ @staticmethod
109
+ def forward(ctx, input, weight, bias):
110
+ assert weight.shape == weight_shape
111
+ if not transpose:
112
+ output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
113
+ else: # transpose
114
+ output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
115
+ ctx.save_for_backward(input, weight)
116
+ return output
117
+
118
+ @staticmethod
119
+ def backward(ctx, grad_output):
120
+ input, weight = ctx.saved_tensors
121
+ grad_input = None
122
+ grad_weight = None
123
+ grad_bias = None
124
+
125
+ if ctx.needs_input_grad[0]:
126
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
127
+ grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
128
+ assert grad_input.shape == input.shape
129
+
130
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
131
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
132
+ assert grad_weight.shape == weight_shape
133
+
134
+ if ctx.needs_input_grad[2]:
135
+ grad_bias = grad_output.sum([0, 2, 3])
136
+
137
+ return grad_input, grad_weight, grad_bias
138
+
139
+ # Gradient with respect to the weights.
140
+ class Conv2dGradWeight(torch.autograd.Function):
141
+ @staticmethod
142
+ def forward(ctx, grad_output, input):
143
+ op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
144
+ flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
145
+ grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
146
+ assert grad_weight.shape == weight_shape
147
+ ctx.save_for_backward(grad_output, input)
148
+ return grad_weight
149
+
150
+ @staticmethod
151
+ def backward(ctx, grad2_grad_weight):
152
+ grad_output, input = ctx.saved_tensors
153
+ grad2_grad_output = None
154
+ grad2_input = None
155
+
156
+ if ctx.needs_input_grad[0]:
157
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
158
+ assert grad2_grad_output.shape == grad_output.shape
159
+
160
+ if ctx.needs_input_grad[1]:
161
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
162
+ grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
163
+ assert grad2_input.shape == input.shape
164
+
165
+ return grad2_grad_output, grad2_input
166
+
167
+ _conv2d_gradfix_cache[key] = Conv2d
168
+ return Conv2d
169
+
170
+ #----------------------------------------------------------------------------
torch_utils/ops/conv2d_resample.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """2D convolution with optional up/downsampling."""
10
+
11
+ import torch
12
+
13
+ from .. import misc
14
+ from . import conv2d_gradfix
15
+ from . import upfirdn2d
16
+ from .upfirdn2d import _parse_padding
17
+ from .upfirdn2d import _get_filter_size
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ def _get_weight_shape(w):
22
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
23
+ shape = [int(sz) for sz in w.shape]
24
+ misc.assert_shape(w, shape)
25
+ return shape
26
+
27
+ #----------------------------------------------------------------------------
28
+
29
+ def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
30
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
31
+ """
32
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
33
+
34
+ # Flip weight if requested.
35
+ if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
36
+ w = w.flip([2, 3])
37
+
38
+ # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
39
+ # 1x1 kernel + memory_format=channels_last + less than 64 channels.
40
+ if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
41
+ if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
42
+ if out_channels <= 4 and groups == 1:
43
+ in_shape = x.shape
44
+ x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
45
+ x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
46
+ else:
47
+ x = x.to(memory_format=torch.contiguous_format)
48
+ w = w.to(memory_format=torch.contiguous_format)
49
+ x = conv2d_gradfix.conv2d(x, w, groups=groups)
50
+ return x.to(memory_format=torch.channels_last)
51
+
52
+ # Otherwise => execute using conv2d_gradfix.
53
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
54
+ return op(x, w, stride=stride, padding=padding, groups=groups)
55
+
56
+ #----------------------------------------------------------------------------
57
+
58
+ @misc.profiled_function
59
+ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
60
+ r"""2D convolution with optional up/downsampling.
61
+
62
+ Padding is performed only once at the beginning, not between the operations.
63
+
64
+ Args:
65
+ x: Input tensor of shape
66
+ `[batch_size, in_channels, in_height, in_width]`.
67
+ w: Weight tensor of shape
68
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
69
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
70
+ calling upfirdn2d.setup_filter(). None = identity (default).
71
+ up: Integer upsampling factor (default: 1).
72
+ down: Integer downsampling factor (default: 1).
73
+ padding: Padding with respect to the upsampled image. Can be a single number
74
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
75
+ (default: 0).
76
+ groups: Split input channels into N groups (default: 1).
77
+ flip_weight: False = convolution, True = correlation (default: True).
78
+ flip_filter: False = convolution, True = correlation (default: False).
79
+
80
+ Returns:
81
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
82
+ """
83
+ # Validate arguments.
84
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
85
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
86
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
87
+ assert isinstance(up, int) and (up >= 1)
88
+ assert isinstance(down, int) and (down >= 1)
89
+ assert isinstance(groups, int) and (groups >= 1)
90
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
91
+ fw, fh = _get_filter_size(f)
92
+ px0, px1, py0, py1 = _parse_padding(padding)
93
+
94
+ # Adjust padding to account for up/downsampling.
95
+ if up > 1:
96
+ px0 += (fw + up - 1) // 2
97
+ px1 += (fw - up) // 2
98
+ py0 += (fh + up - 1) // 2
99
+ py1 += (fh - up) // 2
100
+ if down > 1:
101
+ px0 += (fw - down + 1) // 2
102
+ px1 += (fw - down) // 2
103
+ py0 += (fh - down + 1) // 2
104
+ py1 += (fh - down) // 2
105
+
106
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
107
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
108
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
109
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
110
+ return x
111
+
112
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
113
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
114
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
115
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
116
+ return x
117
+
118
+ # Fast path: downsampling only => use strided convolution.
119
+ if down > 1 and up == 1:
120
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
121
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
122
+ return x
123
+
124
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
125
+ if up > 1:
126
+ if groups == 1:
127
+ w = w.transpose(0, 1)
128
+ else:
129
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
130
+ w = w.transpose(1, 2)
131
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
132
+ px0 -= kw - 1
133
+ px1 -= kw - up
134
+ py0 -= kh - 1
135
+ py1 -= kh - up
136
+ pxt = max(min(-px0, -px1), 0)
137
+ pyt = max(min(-py0, -py1), 0)
138
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
139
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
140
+ if down > 1:
141
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
142
+ return x
143
+
144
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
145
+ if up == 1 and down == 1:
146
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
147
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
148
+
149
+ # Fallback: Generic reference implementation.
150
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
151
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
152
+ if down > 1:
153
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
154
+ return x
155
+
156
+ #----------------------------------------------------------------------------
torch_utils/ops/fma.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
10
+
11
+ import torch
12
+
13
+ #----------------------------------------------------------------------------
14
+
15
+ def fma(a, b, c): # => a * b + c
16
+ return _FusedMultiplyAdd.apply(a, b, c)
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
21
+ @staticmethod
22
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
23
+ out = torch.addcmul(c, a, b)
24
+ ctx.save_for_backward(a, b)
25
+ ctx.c_shape = c.shape
26
+ return out
27
+
28
+ @staticmethod
29
+ def backward(ctx, dout): # pylint: disable=arguments-differ
30
+ a, b = ctx.saved_tensors
31
+ c_shape = ctx.c_shape
32
+ da = None
33
+ db = None
34
+ dc = None
35
+
36
+ if ctx.needs_input_grad[0]:
37
+ da = _unbroadcast(dout * b, a.shape)
38
+
39
+ if ctx.needs_input_grad[1]:
40
+ db = _unbroadcast(dout * a, b.shape)
41
+
42
+ if ctx.needs_input_grad[2]:
43
+ dc = _unbroadcast(dout, c_shape)
44
+
45
+ return da, db, dc
46
+
47
+ #----------------------------------------------------------------------------
48
+
49
+ def _unbroadcast(x, shape):
50
+ extra_dims = x.ndim - len(shape)
51
+ assert extra_dims >= 0
52
+ dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
53
+ if len(dim):
54
+ x = x.sum(dim=dim, keepdim=True)
55
+ if extra_dims:
56
+ x = x.reshape(-1, *x.shape[extra_dims+1:])
57
+ assert x.shape == shape
58
+ return x
59
+
60
+ #----------------------------------------------------------------------------
torch_utils/ops/grid_sample_gradfix.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom replacement for `torch.nn.functional.grid_sample` that
10
+ supports arbitrarily high order gradients between the input and output.
11
+ Only works on 2D images and assumes
12
+ `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
13
+
14
+ import warnings
15
+ import torch
16
+
17
+ # pylint: disable=redefined-builtin
18
+ # pylint: disable=arguments-differ
19
+ # pylint: disable=protected-access
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
+ enabled = False # Enable the custom op by setting this to true.
24
+
25
+ #----------------------------------------------------------------------------
26
+
27
+ def grid_sample(input, grid):
28
+ if _should_use_custom_op():
29
+ return _GridSample2dForward.apply(input, grid)
30
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
31
+
32
+ #----------------------------------------------------------------------------
33
+
34
+ def _should_use_custom_op():
35
+ if not enabled:
36
+ return False
37
+ if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
38
+ return True
39
+ warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
40
+ return False
41
+
42
+ #----------------------------------------------------------------------------
43
+
44
+ class _GridSample2dForward(torch.autograd.Function):
45
+ @staticmethod
46
+ def forward(ctx, input, grid):
47
+ assert input.ndim == 4
48
+ assert grid.ndim == 4
49
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
50
+ ctx.save_for_backward(input, grid)
51
+ return output
52
+
53
+ @staticmethod
54
+ def backward(ctx, grad_output):
55
+ input, grid = ctx.saved_tensors
56
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
57
+ return grad_input, grad_grid
58
+
59
+ #----------------------------------------------------------------------------
60
+
61
+ class _GridSample2dBackward(torch.autograd.Function):
62
+ @staticmethod
63
+ def forward(ctx, grad_output, input, grid):
64
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
65
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
66
+ ctx.save_for_backward(grid)
67
+ return grad_input, grad_grid
68
+
69
+ @staticmethod
70
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
71
+ _ = grad2_grad_grid # unused
72
+ grid, = ctx.saved_tensors
73
+ grad2_grad_output = None
74
+ grad2_input = None
75
+ grad2_grid = None
76
+
77
+ if ctx.needs_input_grad[0]:
78
+ grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
79
+
80
+ assert not ctx.needs_input_grad[2]
81
+ return grad2_grad_output, grad2_input, grad2_grid
82
+
83
+ #----------------------------------------------------------------------------
torch_utils/ops/upfirdn2d.cpp ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "upfirdn2d.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17
+ {
18
+ // Validate arguments.
19
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20
+ TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21
+ TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23
+ TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
25
+ TORCH_CHECK(f.dim() == 2, "f must be rank 2");
26
+ TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
27
+ TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
28
+ TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
29
+
30
+ // Create output tensor.
31
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
32
+ int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
33
+ int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
34
+ TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
35
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
36
+ TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
37
+
38
+ // Initialize CUDA kernel parameters.
39
+ upfirdn2d_kernel_params p;
40
+ p.x = x.data_ptr();
41
+ p.f = f.data_ptr<float>();
42
+ p.y = y.data_ptr();
43
+ p.up = make_int2(upx, upy);
44
+ p.down = make_int2(downx, downy);
45
+ p.pad0 = make_int2(padx0, pady0);
46
+ p.flip = (flip) ? 1 : 0;
47
+ p.gain = gain;
48
+ p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
49
+ p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
50
+ p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
51
+ p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
52
+ p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
53
+ p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
54
+ p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
55
+ p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
56
+
57
+ // Choose CUDA kernel.
58
+ upfirdn2d_kernel_spec spec;
59
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
60
+ {
61
+ spec = choose_upfirdn2d_kernel<scalar_t>(p);
62
+ });
63
+
64
+ // Set looping options.
65
+ p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
66
+ p.loopMinor = spec.loopMinor;
67
+ p.loopX = spec.loopX;
68
+ p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
69
+ p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
70
+
71
+ // Compute grid size.
72
+ dim3 blockSize, gridSize;
73
+ if (spec.tileOutW < 0) // large
74
+ {
75
+ blockSize = dim3(4, 32, 1);
76
+ gridSize = dim3(
77
+ ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
78
+ (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
79
+ p.launchMajor);
80
+ }
81
+ else // small
82
+ {
83
+ blockSize = dim3(256, 1, 1);
84
+ gridSize = dim3(
85
+ ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
86
+ (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
87
+ p.launchMajor);
88
+ }
89
+
90
+ // Launch CUDA kernel.
91
+ void* args[] = {&p};
92
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
93
+ return y;
94
+ }
95
+
96
+ //------------------------------------------------------------------------
97
+
98
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
99
+ {
100
+ m.def("upfirdn2d", &upfirdn2d);
101
+ }
102
+
103
+ //------------------------------------------------------------------------
torch_utils/ops/upfirdn2d.cu ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "upfirdn2d.h"
11
+
12
+ //------------------------------------------------------------------------
13
+ // Helpers.
14
+
15
+ template <class T> struct InternalType;
16
+ template <> struct InternalType<double> { typedef double scalar_t; };
17
+ template <> struct InternalType<float> { typedef float scalar_t; };
18
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
+
20
+ static __device__ __forceinline__ int floor_div(int a, int b)
21
+ {
22
+ int t = 1 - a / b;
23
+ return (a + t * b) / b - t;
24
+ }
25
+
26
+ //------------------------------------------------------------------------
27
+ // Generic CUDA implementation for large filters.
28
+
29
+ template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
30
+ {
31
+ typedef typename InternalType<T>::scalar_t scalar_t;
32
+
33
+ // Calculate thread index.
34
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
35
+ int outY = minorBase / p.launchMinor;
36
+ minorBase -= outY * p.launchMinor;
37
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
38
+ int majorBase = blockIdx.z * p.loopMajor;
39
+ if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
40
+ return;
41
+
42
+ // Setup Y receptive field.
43
+ int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
44
+ int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
45
+ int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
46
+ int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
47
+ if (p.flip)
48
+ filterY = p.filterSize.y - 1 - filterY;
49
+
50
+ // Loop over major, minor, and X.
51
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
52
+ for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
53
+ {
54
+ int nc = major * p.sizeMinor + minor;
55
+ int n = nc / p.inSize.z;
56
+ int c = nc - n * p.inSize.z;
57
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
58
+ {
59
+ // Setup X receptive field.
60
+ int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
61
+ int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
62
+ int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
63
+ int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
64
+ if (p.flip)
65
+ filterX = p.filterSize.x - 1 - filterX;
66
+
67
+ // Initialize pointers.
68
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
69
+ const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
70
+ int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
71
+ int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
72
+
73
+ // Inner loop.
74
+ scalar_t v = 0;
75
+ for (int y = 0; y < h; y++)
76
+ {
77
+ for (int x = 0; x < w; x++)
78
+ {
79
+ v += (scalar_t)(*xp) * (scalar_t)(*fp);
80
+ xp += p.inStride.x;
81
+ fp += filterStepX;
82
+ }
83
+ xp += p.inStride.y - w * p.inStride.x;
84
+ fp += filterStepY - w * filterStepX;
85
+ }
86
+
87
+ // Store result.
88
+ v *= p.gain;
89
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
90
+ }
91
+ }
92
+ }
93
+
94
+ //------------------------------------------------------------------------
95
+ // Specialized CUDA implementation for small filters.
96
+
97
+ template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
98
+ static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
99
+ {
100
+ typedef typename InternalType<T>::scalar_t scalar_t;
101
+ const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
102
+ const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
103
+ __shared__ volatile scalar_t sf[filterH][filterW];
104
+ __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
105
+
106
+ // Calculate tile index.
107
+ int minorBase = blockIdx.x;
108
+ int tileOutY = minorBase / p.launchMinor;
109
+ minorBase -= tileOutY * p.launchMinor;
110
+ minorBase *= loopMinor;
111
+ tileOutY *= tileOutH;
112
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
113
+ int majorBase = blockIdx.z * p.loopMajor;
114
+ if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
115
+ return;
116
+
117
+ // Load filter (flipped).
118
+ for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
119
+ {
120
+ int fy = tapIdx / filterW;
121
+ int fx = tapIdx - fy * filterW;
122
+ scalar_t v = 0;
123
+ if (fx < p.filterSize.x & fy < p.filterSize.y)
124
+ {
125
+ int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
126
+ int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
127
+ v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
128
+ }
129
+ sf[fy][fx] = v;
130
+ }
131
+
132
+ // Loop over major and X.
133
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
134
+ {
135
+ int baseNC = major * p.sizeMinor + minorBase;
136
+ int n = baseNC / p.inSize.z;
137
+ int baseC = baseNC - n * p.inSize.z;
138
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
139
+ {
140
+ // Load input pixels.
141
+ int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
142
+ int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
143
+ int tileInX = floor_div(tileMidX, upx);
144
+ int tileInY = floor_div(tileMidY, upy);
145
+ __syncthreads();
146
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
147
+ {
148
+ int relC = inIdx;
149
+ int relInX = relC / loopMinor;
150
+ int relInY = relInX / tileInW;
151
+ relC -= relInX * loopMinor;
152
+ relInX -= relInY * tileInW;
153
+ int c = baseC + relC;
154
+ int inX = tileInX + relInX;
155
+ int inY = tileInY + relInY;
156
+ scalar_t v = 0;
157
+ if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
158
+ v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
159
+ sx[relInY][relInX][relC] = v;
160
+ }
161
+
162
+ // Loop over output pixels.
163
+ __syncthreads();
164
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
165
+ {
166
+ int relC = outIdx;
167
+ int relOutX = relC / loopMinor;
168
+ int relOutY = relOutX / tileOutW;
169
+ relC -= relOutX * loopMinor;
170
+ relOutX -= relOutY * tileOutW;
171
+ int c = baseC + relC;
172
+ int outX = tileOutX + relOutX;
173
+ int outY = tileOutY + relOutY;
174
+
175
+ // Setup receptive field.
176
+ int midX = tileMidX + relOutX * downx;
177
+ int midY = tileMidY + relOutY * downy;
178
+ int inX = floor_div(midX, upx);
179
+ int inY = floor_div(midY, upy);
180
+ int relInX = inX - tileInX;
181
+ int relInY = inY - tileInY;
182
+ int filterX = (inX + 1) * upx - midX - 1; // flipped
183
+ int filterY = (inY + 1) * upy - midY - 1; // flipped
184
+
185
+ // Inner loop.
186
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
187
+ {
188
+ scalar_t v = 0;
189
+ #pragma unroll
190
+ for (int y = 0; y < filterH / upy; y++)
191
+ #pragma unroll
192
+ for (int x = 0; x < filterW / upx; x++)
193
+ v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
194
+ v *= p.gain;
195
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
196
+ }
197
+ }
198
+ }
199
+ }
200
+ }
201
+
202
+ //------------------------------------------------------------------------
203
+ // CUDA kernel selection.
204
+
205
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
206
+ {
207
+ int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
208
+
209
+ upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
210
+ if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
211
+
212
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
213
+ {
214
+ if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
215
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
216
+ if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
217
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
218
+ if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
219
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
220
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
221
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
222
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
223
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
224
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
225
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
226
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
227
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
228
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
229
+ }
230
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
231
+ {
232
+ if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
233
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
234
+ if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
235
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
236
+ if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
237
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
238
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
239
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
240
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
241
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
242
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
243
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
244
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
245
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
246
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
247
+ }
248
+ if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
249
+ {
250
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
251
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
252
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
253
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
254
+ }
255
+ if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
256
+ {
257
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
258
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
259
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
260
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
261
+ }
262
+ if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
263
+ {
264
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
265
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
266
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
267
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
268
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
269
+ }
270
+ if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
271
+ {
272
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
273
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
274
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
275
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
276
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
277
+ }
278
+ if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
279
+ {
280
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
281
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
282
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
283
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
284
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
285
+ }
286
+ if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
287
+ {
288
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
289
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
290
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
291
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
292
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
293
+ }
294
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
295
+ {
296
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
297
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
298
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
299
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
300
+ }
301
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
302
+ {
303
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
304
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
305
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
306
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
307
+ }
308
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
309
+ {
310
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
311
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
312
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
313
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
314
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
315
+ }
316
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
317
+ {
318
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
319
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
320
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
321
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
322
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
323
+ }
324
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
325
+ {
326
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
327
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
328
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
329
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
330
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
331
+ }
332
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
333
+ {
334
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
335
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
336
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
337
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
338
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
339
+ }
340
+ return spec;
341
+ }
342
+
343
+ //------------------------------------------------------------------------
344
+ // Template specializations.
345
+
346
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
347
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
348
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
349
+
350
+ //------------------------------------------------------------------------
torch_utils/ops/upfirdn2d.h ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <cuda_runtime.h>
10
+
11
+ //------------------------------------------------------------------------
12
+ // CUDA kernel parameters.
13
+
14
+ struct upfirdn2d_kernel_params
15
+ {
16
+ const void* x;
17
+ const float* f;
18
+ void* y;
19
+
20
+ int2 up;
21
+ int2 down;
22
+ int2 pad0;
23
+ int flip;
24
+ float gain;
25
+
26
+ int4 inSize; // [width, height, channel, batch]
27
+ int4 inStride;
28
+ int2 filterSize; // [width, height]
29
+ int2 filterStride;
30
+ int4 outSize; // [width, height, channel, batch]
31
+ int4 outStride;
32
+ int sizeMinor;
33
+ int sizeMajor;
34
+
35
+ int loopMinor;
36
+ int loopMajor;
37
+ int loopX;
38
+ int launchMinor;
39
+ int launchMajor;
40
+ };
41
+
42
+ //------------------------------------------------------------------------
43
+ // CUDA kernel specialization.
44
+
45
+ struct upfirdn2d_kernel_spec
46
+ {
47
+ void* kernel;
48
+ int tileOutW;
49
+ int tileOutH;
50
+ int loopMinor;
51
+ int loopX;
52
+ };
53
+
54
+ //------------------------------------------------------------------------
55
+ // CUDA kernel selection.
56
+
57
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58
+
59
+ //------------------------------------------------------------------------
torch_utils/ops/upfirdn2d.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom PyTorch ops for efficient resampling of 2D images."""
10
+
11
+ import os
12
+ import warnings
13
+ import numpy as np
14
+ import torch
15
+ import traceback
16
+
17
+ from .. import custom_ops
18
+ from .. import misc
19
+ from . import conv2d_gradfix
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
+ _inited = False
24
+ _plugin = None
25
+
26
+ def _init():
27
+ global _inited, _plugin
28
+ if not _inited:
29
+ sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
30
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
31
+ try:
32
+ _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
33
+ except:
34
+ warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
35
+ return _plugin is not None
36
+
37
+ def _parse_scaling(scaling):
38
+ if isinstance(scaling, int):
39
+ scaling = [scaling, scaling]
40
+ assert isinstance(scaling, (list, tuple))
41
+ assert all(isinstance(x, int) for x in scaling)
42
+ sx, sy = scaling
43
+ assert sx >= 1 and sy >= 1
44
+ return sx, sy
45
+
46
+ def _parse_padding(padding):
47
+ if isinstance(padding, int):
48
+ padding = [padding, padding]
49
+ assert isinstance(padding, (list, tuple))
50
+ assert all(isinstance(x, int) for x in padding)
51
+ if len(padding) == 2:
52
+ padx, pady = padding
53
+ padding = [padx, padx, pady, pady]
54
+ padx0, padx1, pady0, pady1 = padding
55
+ return padx0, padx1, pady0, pady1
56
+
57
+ def _get_filter_size(f):
58
+ if f is None:
59
+ return 1, 1
60
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
61
+ fw = f.shape[-1]
62
+ fh = f.shape[0]
63
+ with misc.suppress_tracer_warnings():
64
+ fw = int(fw)
65
+ fh = int(fh)
66
+ misc.assert_shape(f, [fh, fw][:f.ndim])
67
+ assert fw >= 1 and fh >= 1
68
+ return fw, fh
69
+
70
+ #----------------------------------------------------------------------------
71
+
72
+ def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
73
+ r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
74
+
75
+ Args:
76
+ f: Torch tensor, numpy array, or python list of the shape
77
+ `[filter_height, filter_width]` (non-separable),
78
+ `[filter_taps]` (separable),
79
+ `[]` (impulse), or
80
+ `None` (identity).
81
+ device: Result device (default: cpu).
82
+ normalize: Normalize the filter so that it retains the magnitude
83
+ for constant input signal (DC)? (default: True).
84
+ flip_filter: Flip the filter? (default: False).
85
+ gain: Overall scaling factor for signal magnitude (default: 1).
86
+ separable: Return a separable filter? (default: select automatically).
87
+
88
+ Returns:
89
+ Float32 tensor of the shape
90
+ `[filter_height, filter_width]` (non-separable) or
91
+ `[filter_taps]` (separable).
92
+ """
93
+ # Validate.
94
+ if f is None:
95
+ f = 1
96
+ f = torch.as_tensor(f, dtype=torch.float32)
97
+ assert f.ndim in [0, 1, 2]
98
+ assert f.numel() > 0
99
+ if f.ndim == 0:
100
+ f = f[np.newaxis]
101
+
102
+ # Separable?
103
+ if separable is None:
104
+ separable = (f.ndim == 1 and f.numel() >= 8)
105
+ if f.ndim == 1 and not separable:
106
+ f = f.ger(f)
107
+ assert f.ndim == (1 if separable else 2)
108
+
109
+ # Apply normalize, flip, gain, and device.
110
+ if normalize:
111
+ f /= f.sum()
112
+ if flip_filter:
113
+ f = f.flip(list(range(f.ndim)))
114
+ f = f * (gain ** (f.ndim / 2))
115
+ f = f.to(device=device)
116
+ return f
117
+
118
+ #----------------------------------------------------------------------------
119
+
120
+ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
121
+ r"""Pad, upsample, filter, and downsample a batch of 2D images.
122
+
123
+ Performs the following sequence of operations for each channel:
124
+
125
+ 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
126
+
127
+ 2. Pad the image with the specified number of zeros on each side (`padding`).
128
+ Negative padding corresponds to cropping the image.
129
+
130
+ 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
131
+ so that the footprint of all output pixels lies within the input image.
132
+
133
+ 4. Downsample the image by keeping every Nth pixel (`down`).
134
+
135
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
136
+ The fused op is considerably more efficient than performing the same calculation
137
+ using standard PyTorch ops. It supports gradients of arbitrary order.
138
+
139
+ Args:
140
+ x: Float32/float64/float16 input tensor of the shape
141
+ `[batch_size, num_channels, in_height, in_width]`.
142
+ f: Float32 FIR filter of the shape
143
+ `[filter_height, filter_width]` (non-separable),
144
+ `[filter_taps]` (separable), or
145
+ `None` (identity).
146
+ up: Integer upsampling factor. Can be a single int or a list/tuple
147
+ `[x, y]` (default: 1).
148
+ down: Integer downsampling factor. Can be a single int or a list/tuple
149
+ `[x, y]` (default: 1).
150
+ padding: Padding with respect to the upsampled image. Can be a single number
151
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
152
+ (default: 0).
153
+ flip_filter: False = convolution, True = correlation (default: False).
154
+ gain: Overall scaling factor for signal magnitude (default: 1).
155
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
156
+
157
+ Returns:
158
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
159
+ """
160
+ assert isinstance(x, torch.Tensor)
161
+ assert impl in ['ref', 'cuda']
162
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
163
+ return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
164
+ return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
165
+
166
+ #----------------------------------------------------------------------------
167
+
168
+ @misc.profiled_function
169
+ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
170
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
171
+ """
172
+ # Validate arguments.
173
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
174
+ if f is None:
175
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
176
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
177
+ assert f.dtype == torch.float32 and not f.requires_grad
178
+ batch_size, num_channels, in_height, in_width = x.shape
179
+ upx, upy = _parse_scaling(up)
180
+ downx, downy = _parse_scaling(down)
181
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
182
+
183
+ # Upsample by inserting zeros.
184
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
185
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
186
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
187
+
188
+ # Pad or crop.
189
+ x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
190
+ x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
191
+
192
+ # Setup filter.
193
+ f = f * (gain ** (f.ndim / 2))
194
+ f = f.to(x.dtype)
195
+ if not flip_filter:
196
+ f = f.flip(list(range(f.ndim)))
197
+
198
+ # Convolve with the filter.
199
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
200
+ if f.ndim == 4:
201
+ x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
202
+ else:
203
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
204
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
205
+
206
+ # Downsample by throwing away pixels.
207
+ x = x[:, :, ::downy, ::downx]
208
+ return x
209
+
210
+ #----------------------------------------------------------------------------
211
+
212
+ _upfirdn2d_cuda_cache = dict()
213
+
214
+ def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
215
+ """Fast CUDA implementation of `upfirdn2d()` using custom ops.
216
+ """
217
+ # Parse arguments.
218
+ upx, upy = _parse_scaling(up)
219
+ downx, downy = _parse_scaling(down)
220
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
221
+
222
+ # Lookup from cache.
223
+ key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
224
+ if key in _upfirdn2d_cuda_cache:
225
+ return _upfirdn2d_cuda_cache[key]
226
+
227
+ # Forward op.
228
+ class Upfirdn2dCuda(torch.autograd.Function):
229
+ @staticmethod
230
+ def forward(ctx, x, f): # pylint: disable=arguments-differ
231
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
232
+ if f is None:
233
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
234
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
235
+ y = x
236
+ if f.ndim == 2:
237
+ y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
238
+ else:
239
+ y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
240
+ y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
241
+ ctx.save_for_backward(f)
242
+ ctx.x_shape = x.shape
243
+ return y
244
+
245
+ @staticmethod
246
+ def backward(ctx, dy): # pylint: disable=arguments-differ
247
+ f, = ctx.saved_tensors
248
+ _, _, ih, iw = ctx.x_shape
249
+ _, _, oh, ow = dy.shape
250
+ fw, fh = _get_filter_size(f)
251
+ p = [
252
+ fw - padx0 - 1,
253
+ iw * upx - ow * downx + padx0 - upx + 1,
254
+ fh - pady0 - 1,
255
+ ih * upy - oh * downy + pady0 - upy + 1,
256
+ ]
257
+ dx = None
258
+ df = None
259
+
260
+ if ctx.needs_input_grad[0]:
261
+ dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
262
+
263
+ assert not ctx.needs_input_grad[1]
264
+ return dx, df
265
+
266
+ # Add to cache.
267
+ _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
268
+ return Upfirdn2dCuda
269
+
270
+ #----------------------------------------------------------------------------
271
+
272
+ def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
273
+ r"""Filter a batch of 2D images using the given 2D FIR filter.
274
+
275
+ By default, the result is padded so that its shape matches the input.
276
+ User-specified padding is applied on top of that, with negative values
277
+ indicating cropping. Pixels outside the image are assumed to be zero.
278
+
279
+ Args:
280
+ x: Float32/float64/float16 input tensor of the shape
281
+ `[batch_size, num_channels, in_height, in_width]`.
282
+ f: Float32 FIR filter of the shape
283
+ `[filter_height, filter_width]` (non-separable),
284
+ `[filter_taps]` (separable), or
285
+ `None` (identity).
286
+ padding: Padding with respect to the output. Can be a single number or a
287
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
288
+ (default: 0).
289
+ flip_filter: False = convolution, True = correlation (default: False).
290
+ gain: Overall scaling factor for signal magnitude (default: 1).
291
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
292
+
293
+ Returns:
294
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
295
+ """
296
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
297
+ fw, fh = _get_filter_size(f)
298
+ p = [
299
+ padx0 + fw // 2,
300
+ padx1 + (fw - 1) // 2,
301
+ pady0 + fh // 2,
302
+ pady1 + (fh - 1) // 2,
303
+ ]
304
+ return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
305
+
306
+ #----------------------------------------------------------------------------
307
+
308
+ def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
309
+ r"""Upsample a batch of 2D images using the given 2D FIR filter.
310
+
311
+ By default, the result is padded so that its shape is a multiple of the input.
312
+ User-specified padding is applied on top of that, with negative values
313
+ indicating cropping. Pixels outside the image are assumed to be zero.
314
+
315
+ Args:
316
+ x: Float32/float64/float16 input tensor of the shape
317
+ `[batch_size, num_channels, in_height, in_width]`.
318
+ f: Float32 FIR filter of the shape
319
+ `[filter_height, filter_width]` (non-separable),
320
+ `[filter_taps]` (separable), or
321
+ `None` (identity).
322
+ up: Integer upsampling factor. Can be a single int or a list/tuple
323
+ `[x, y]` (default: 1).
324
+ padding: Padding with respect to the output. Can be a single number or a
325
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
326
+ (default: 0).
327
+ flip_filter: False = convolution, True = correlation (default: False).
328
+ gain: Overall scaling factor for signal magnitude (default: 1).
329
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
330
+
331
+ Returns:
332
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
333
+ """
334
+ upx, upy = _parse_scaling(up)
335
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
336
+ fw, fh = _get_filter_size(f)
337
+ p = [
338
+ padx0 + (fw + upx - 1) // 2,
339
+ padx1 + (fw - upx) // 2,
340
+ pady0 + (fh + upy - 1) // 2,
341
+ pady1 + (fh - upy) // 2,
342
+ ]
343
+ return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
344
+
345
+ #----------------------------------------------------------------------------
346
+
347
+ def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
348
+ r"""Downsample a batch of 2D images using the given 2D FIR filter.
349
+
350
+ By default, the result is padded so that its shape is a fraction of the input.
351
+ User-specified padding is applied on top of that, with negative values
352
+ indicating cropping. Pixels outside the image are assumed to be zero.
353
+
354
+ Args:
355
+ x: Float32/float64/float16 input tensor of the shape
356
+ `[batch_size, num_channels, in_height, in_width]`.
357
+ f: Float32 FIR filter of the shape
358
+ `[filter_height, filter_width]` (non-separable),
359
+ `[filter_taps]` (separable), or
360
+ `None` (identity).
361
+ down: Integer downsampling factor. Can be a single int or a list/tuple
362
+ `[x, y]` (default: 1).
363
+ padding: Padding with respect to the input. Can be a single number or a
364
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
365
+ (default: 0).
366
+ flip_filter: False = convolution, True = correlation (default: False).
367
+ gain: Overall scaling factor for signal magnitude (default: 1).
368
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
369
+
370
+ Returns:
371
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
372
+ """
373
+ downx, downy = _parse_scaling(down)
374
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
375
+ fw, fh = _get_filter_size(f)
376
+ p = [
377
+ padx0 + (fw - downx + 1) // 2,
378
+ padx1 + (fw - downx) // 2,
379
+ pady0 + (fh - downy + 1) // 2,
380
+ pady1 + (fh - downy) // 2,
381
+ ]
382
+ return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
383
+
384
+ #----------------------------------------------------------------------------
torch_utils/persistence.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Facilities for pickling Python code alongside other data.
10
+
11
+ The pickled code is automatically imported into a separate Python module
12
+ during unpickling. This way, any previously exported pickles will remain
13
+ usable even if the original code is no longer available, or if the current
14
+ version of the code is not consistent with what was originally pickled."""
15
+
16
+ import sys
17
+ import pickle
18
+ import io
19
+ import inspect
20
+ import copy
21
+ import uuid
22
+ import types
23
+ import dnnlib
24
+
25
+ #----------------------------------------------------------------------------
26
+
27
+ _version = 6 # internal version number
28
+ _decorators = set() # {decorator_class, ...}
29
+ _import_hooks = [] # [hook_function, ...]
30
+ _module_to_src_dict = dict() # {module: src, ...}
31
+ _src_to_module_dict = dict() # {src: module, ...}
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ def persistent_class(orig_class):
36
+ r"""Class decorator that extends a given class to save its source code
37
+ when pickled.
38
+
39
+ Example:
40
+
41
+ from torch_utils import persistence
42
+
43
+ @persistence.persistent_class
44
+ class MyNetwork(torch.nn.Module):
45
+ def __init__(self, num_inputs, num_outputs):
46
+ super().__init__()
47
+ self.fc = MyLayer(num_inputs, num_outputs)
48
+ ...
49
+
50
+ @persistence.persistent_class
51
+ class MyLayer(torch.nn.Module):
52
+ ...
53
+
54
+ When pickled, any instance of `MyNetwork` and `MyLayer` will save its
55
+ source code alongside other internal state (e.g., parameters, buffers,
56
+ and submodules). This way, any previously exported pickle will remain
57
+ usable even if the class definitions have been modified or are no
58
+ longer available.
59
+
60
+ The decorator saves the source code of the entire Python module
61
+ containing the decorated class. It does *not* save the source code of
62
+ any imported modules. Thus, the imported modules must be available
63
+ during unpickling, also including `torch_utils.persistence` itself.
64
+
65
+ It is ok to call functions defined in the same module from the
66
+ decorated class. However, if the decorated class depends on other
67
+ classes defined in the same module, they must be decorated as well.
68
+ This is illustrated in the above example in the case of `MyLayer`.
69
+
70
+ It is also possible to employ the decorator just-in-time before
71
+ calling the constructor. For example:
72
+
73
+ cls = MyLayer
74
+ if want_to_make_it_persistent:
75
+ cls = persistence.persistent_class(cls)
76
+ layer = cls(num_inputs, num_outputs)
77
+
78
+ As an additional feature, the decorator also keeps track of the
79
+ arguments that were used to construct each instance of the decorated
80
+ class. The arguments can be queried via `obj.init_args` and
81
+ `obj.init_kwargs`, and they are automatically pickled alongside other
82
+ object state. A typical use case is to first unpickle a previous
83
+ instance of a persistent class, and then upgrade it to use the latest
84
+ version of the source code:
85
+
86
+ with open('old_pickle.pkl', 'rb') as f:
87
+ old_net = pickle.load(f)
88
+ new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
89
+ misc.copy_params_and_buffers(old_net, new_net, require_all=True)
90
+ """
91
+ assert isinstance(orig_class, type)
92
+ if is_persistent(orig_class):
93
+ return orig_class
94
+
95
+ assert orig_class.__module__ in sys.modules
96
+ orig_module = sys.modules[orig_class.__module__]
97
+ orig_module_src = _module_to_src(orig_module)
98
+
99
+ class Decorator(orig_class):
100
+ _orig_module_src = orig_module_src
101
+ _orig_class_name = orig_class.__name__
102
+
103
+ def __init__(self, *args, **kwargs):
104
+ super().__init__(*args, **kwargs)
105
+ self._init_args = copy.deepcopy(args)
106
+ self._init_kwargs = copy.deepcopy(kwargs)
107
+ assert orig_class.__name__ in orig_module.__dict__
108
+ _check_pickleable(self.__reduce__())
109
+
110
+ @property
111
+ def init_args(self):
112
+ return copy.deepcopy(self._init_args)
113
+
114
+ @property
115
+ def init_kwargs(self):
116
+ return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
117
+
118
+ def __reduce__(self):
119
+ fields = list(super().__reduce__())
120
+ fields += [None] * max(3 - len(fields), 0)
121
+ if fields[0] is not _reconstruct_persistent_obj:
122
+ meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
123
+ fields[0] = _reconstruct_persistent_obj # reconstruct func
124
+ fields[1] = (meta,) # reconstruct args
125
+ fields[2] = None # state dict
126
+ return tuple(fields)
127
+
128
+ Decorator.__name__ = orig_class.__name__
129
+ _decorators.add(Decorator)
130
+ return Decorator
131
+
132
+ #----------------------------------------------------------------------------
133
+
134
+ def is_persistent(obj):
135
+ r"""Test whether the given object or class is persistent, i.e.,
136
+ whether it will save its source code when pickled.
137
+ """
138
+ try:
139
+ if obj in _decorators:
140
+ return True
141
+ except TypeError:
142
+ pass
143
+ return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
144
+
145
+ #----------------------------------------------------------------------------
146
+
147
+ def import_hook(hook):
148
+ r"""Register an import hook that is called whenever a persistent object
149
+ is being unpickled. A typical use case is to patch the pickled source
150
+ code to avoid errors and inconsistencies when the API of some imported
151
+ module has changed.
152
+
153
+ The hook should have the following signature:
154
+
155
+ hook(meta) -> modified meta
156
+
157
+ `meta` is an instance of `dnnlib.EasyDict` with the following fields:
158
+
159
+ type: Type of the persistent object, e.g. `'class'`.
160
+ version: Internal version number of `torch_utils.persistence`.
161
+ module_src Original source code of the Python module.
162
+ class_name: Class name in the original Python module.
163
+ state: Internal state of the object.
164
+
165
+ Example:
166
+
167
+ @persistence.import_hook
168
+ def wreck_my_network(meta):
169
+ if meta.class_name == 'MyNetwork':
170
+ print('MyNetwork is being imported. I will wreck it!')
171
+ meta.module_src = meta.module_src.replace("True", "False")
172
+ return meta
173
+ """
174
+ assert callable(hook)
175
+ _import_hooks.append(hook)
176
+
177
+ #----------------------------------------------------------------------------
178
+
179
+ def _reconstruct_persistent_obj(meta):
180
+ r"""Hook that is called internally by the `pickle` module to unpickle
181
+ a persistent object.
182
+ """
183
+ meta = dnnlib.EasyDict(meta)
184
+ meta.state = dnnlib.EasyDict(meta.state)
185
+ for hook in _import_hooks:
186
+ meta = hook(meta)
187
+ assert meta is not None
188
+
189
+ assert meta.version == _version
190
+ module = _src_to_module(meta.module_src)
191
+
192
+ assert meta.type == 'class'
193
+ orig_class = module.__dict__[meta.class_name]
194
+ decorator_class = persistent_class(orig_class)
195
+ obj = decorator_class.__new__(decorator_class)
196
+
197
+ setstate = getattr(obj, '__setstate__', None)
198
+ if callable(setstate):
199
+ setstate(meta.state) # pylint: disable=not-callable
200
+ else:
201
+ obj.__dict__.update(meta.state)
202
+ return obj
203
+
204
+ #----------------------------------------------------------------------------
205
+
206
+ def _module_to_src(module):
207
+ r"""Query the source code of a given Python module.
208
+ """
209
+ src = _module_to_src_dict.get(module, None)
210
+ if src is None:
211
+ src = inspect.getsource(module)
212
+ _module_to_src_dict[module] = src
213
+ _src_to_module_dict[src] = module
214
+ return src
215
+
216
+ def _src_to_module(src):
217
+ r"""Get or create a Python module for the given source code.
218
+ """
219
+ module = _src_to_module_dict.get(src, None)
220
+ if module is None:
221
+ module_name = "_imported_module_" + uuid.uuid4().hex
222
+ module = types.ModuleType(module_name)
223
+ sys.modules[module_name] = module
224
+ _module_to_src_dict[module] = src
225
+ _src_to_module_dict[src] = module
226
+ exec(src, module.__dict__) # pylint: disable=exec-used
227
+ return module
228
+
229
+ #----------------------------------------------------------------------------
230
+
231
+ def _check_pickleable(obj):
232
+ r"""Check that the given object is pickleable, raising an exception if
233
+ it is not. This function is expected to be considerably more efficient
234
+ than actually pickling the object.
235
+ """
236
+ def recurse(obj):
237
+ if isinstance(obj, (list, tuple, set)):
238
+ return [recurse(x) for x in obj]
239
+ if isinstance(obj, dict):
240
+ return [[recurse(x), recurse(y)] for x, y in obj.items()]
241
+ if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
242
+ return None # Python primitive types are pickleable.
243
+ if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
244
+ return None # NumPy arrays and PyTorch tensors are pickleable.
245
+ if is_persistent(obj):
246
+ return None # Persistent objects are pickleable, by virtue of the constructor check.
247
+ return obj
248
+ with io.BytesIO() as f:
249
+ pickle.dump(recurse(obj), f)
250
+
251
+ #----------------------------------------------------------------------------