feng2022 committed on
Commit
f60473f
1 Parent(s): 3d38d42
torch_utils/__init__.py DELETED
@@ -1,11 +0,0 @@
1
- # Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # NVIDIA CORPORATION and its licensors retain all intellectual property
6
- # and proprietary rights in and to this software, related documentation
7
- # and any modifications thereto. Any use, reproduction, disclosure or
8
- # distribution of this software and related documentation without an express
9
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
-
11
- # empty
 
torch_utils/custom_ops.py DELETED
@@ -1,238 +0,0 @@
1
- # Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # NVIDIA CORPORATION and its licensors retain all intellectual property
6
- # and proprietary rights in and to this software, related documentation
7
- # and any modifications thereto. Any use, reproduction, disclosure or
8
- # distribution of this software and related documentation without an express
9
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
-
11
- import os
12
- import glob
13
- import torch
14
- import torch.utils.cpp_extension
15
- import importlib
16
- import hashlib
17
- import shutil
18
- from pathlib import Path
19
- import re
20
- import uuid
21
-
22
- from torch.utils.file_baton import FileBaton
23
-
24
- #----------------------------------------------------------------------------
25
- # Global options.
26
-
27
- verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
28
-
29
- #----------------------------------------------------------------------------
30
- # Internal helper funcs.
31
-
32
- def _find_compiler_bindir():
33
- patterns = [
34
- 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
35
- 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
36
- 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
37
- 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
38
- ]
39
- for pattern in patterns:
40
- matches = sorted(glob.glob(pattern))
41
- if len(matches):
42
- return matches[-1]
43
- return None
44
-
45
- def _get_mangled_gpu_name():
46
- name = torch.cuda.get_device_name().lower()
47
- out = []
48
- for c in name:
49
- if re.match('[a-z0-9_-]+', c):
50
- out.append(c)
51
- else:
52
- out.append('-')
53
- return ''.join(out)
54
-
55
-
56
- #----------------------------------------------------------------------------
57
- # Main entry point for compiling and loading C++/CUDA plugins.
58
-
59
- _cached_plugins = dict()
60
-
61
- def get_plugin(module_name, sources, **build_kwargs):
62
- assert verbosity in ['none', 'brief', 'full']
63
-
64
- # Already cached?
65
- if module_name in _cached_plugins:
66
- return _cached_plugins[module_name]
67
-
68
- # Print status.
69
- if verbosity == 'full':
70
- print(f'Setting up PyTorch plugin "{module_name}"...')
71
- elif verbosity == 'brief':
72
- print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
73
-
74
- try: # pylint: disable=too-many-nested-blocks
75
- # Make sure we can find the necessary compiler binaries.
76
- if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
77
- compiler_bindir = _find_compiler_bindir()
78
- if compiler_bindir is None:
79
- raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
80
- os.environ['PATH'] += ';' + compiler_bindir
81
-
82
- # Compile and load.
83
- verbose_build = (verbosity == 'full')
84
-
85
- # Incremental build md5sum trickery. Copies all the input source files
86
- # into a cached build directory under a combined md5 digest of the input
87
- # source files. Copying is done only if the combined digest has changed.
88
- # This keeps input file timestamps and filenames the same as in previous
89
- # extension builds, allowing for fast incremental rebuilds.
90
- #
91
- # This optimization is done only in case all the source files reside in
92
- # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
93
- # environment variable is set (we take this as a signal that the user
94
- # actually cares about this.)
95
- source_dirs_set = set(os.path.dirname(source) for source in sources)
96
- if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
97
- all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
98
-
99
- # Compute a combined hash digest for all source files in the same
100
- # custom op directory (usually .cu, .cpp, .py and .h files).
101
- hash_md5 = hashlib.md5()
102
- for src in all_source_files:
103
- with open(src, 'rb') as f:
104
- hash_md5.update(f.read())
105
- build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
106
- digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
107
-
108
- if not os.path.isdir(digest_build_dir):
109
- os.makedirs(digest_build_dir, exist_ok=True)
110
- baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
111
- if baton.try_acquire():
112
- try:
113
- for src in all_source_files:
114
- shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
115
- finally:
116
- baton.release()
117
- else:
118
- # Someone else is copying source files under the digest dir,
119
- # wait until done and continue.
120
- baton.wait()
121
- digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
122
- torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
123
- verbose=verbose_build, sources=digest_sources, **build_kwargs)
124
- else:
125
- torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
126
- module = importlib.import_module(module_name)
127
-
128
- except:
129
- if verbosity == 'brief':
130
- print('Failed!')
131
- raise
132
-
133
- # Print status and add to cache.
134
- if verbosity == 'full':
135
- print(f'Done setting up PyTorch plugin "{module_name}".')
136
- elif verbosity == 'brief':
137
- print('Done.')
138
- _cached_plugins[module_name] = module
139
- return module
140
-
141
- #----------------------------------------------------------------------------
142
- def get_plugin_v3(module_name, sources, headers=None, source_dir=None, **build_kwargs):
143
- assert verbosity in ['none', 'brief', 'full']
144
- if headers is None:
145
- headers = []
146
- if source_dir is not None:
147
- sources = [os.path.join(source_dir, fname) for fname in sources]
148
- headers = [os.path.join(source_dir, fname) for fname in headers]
149
-
150
- # Already cached?
151
- if module_name in _cached_plugins:
152
- return _cached_plugins[module_name]
153
-
154
- # Print status.
155
- if verbosity == 'full':
156
- print(f'Setting up PyTorch plugin "{module_name}"...')
157
- elif verbosity == 'brief':
158
- print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
159
- verbose_build = (verbosity == 'full')
160
-
161
- # Compile and load.
162
- try: # pylint: disable=too-many-nested-blocks
163
- # Make sure we can find the necessary compiler binaries.
164
- if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
165
- compiler_bindir = _find_compiler_bindir()
166
- if compiler_bindir is None:
167
- raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
168
- os.environ['PATH'] += ';' + compiler_bindir
169
-
170
- # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
171
- # break the build or unnecessarily restrict what's available to nvcc.
172
- # Unset it to let nvcc decide based on what's available on the
173
- # machine.
174
- os.environ['TORCH_CUDA_ARCH_LIST'] = ''
175
-
176
- # Incremental build md5sum trickery. Copies all the input source files
177
- # into a cached build directory under a combined md5 digest of the input
178
- # source files. Copying is done only if the combined digest has changed.
179
- # This keeps input file timestamps and filenames the same as in previous
180
- # extension builds, allowing for fast incremental rebuilds.
181
- #
182
- # This optimization is done only in case all the source files reside in
183
- # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
184
- # environment variable is set (we take this as a signal that the user
185
- # actually cares about this.)
186
- #
187
- # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work
188
- # around the *.cu dependency bug in ninja config.
189
- #
190
- all_source_files = sorted(sources + headers)
191
- all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
192
- if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
193
-
194
- # Compute combined hash digest for all source files.
195
- hash_md5 = hashlib.md5()
196
- for src in all_source_files:
197
- with open(src, 'rb') as f:
198
- hash_md5.update(f.read())
199
-
200
- # Select cached build directory name.
201
- source_digest = hash_md5.hexdigest()
202
- build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
203
- cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
204
-
205
- if not os.path.isdir(cached_build_dir):
206
- tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
207
- os.makedirs(tmpdir)
208
- for src in all_source_files:
209
- shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
210
- try:
211
- os.replace(tmpdir, cached_build_dir) # atomic
212
- except OSError:
213
- # source directory already exists, delete tmpdir and its contents.
214
- shutil.rmtree(tmpdir)
215
- if not os.path.isdir(cached_build_dir): raise
216
-
217
- # Compile.
218
- cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
219
- torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
220
- verbose=verbose_build, sources=cached_sources, **build_kwargs)
221
- else:
222
- torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
223
-
224
- # Load.
225
- module = importlib.import_module(module_name)
226
-
227
- except:
228
- if verbosity == 'brief':
229
- print('Failed!')
230
- raise
231
-
232
- # Print status and add to cache dict.
233
- if verbosity == 'full':
234
- print(f'Done setting up PyTorch plugin "{module_name}".')
235
- elif verbosity == 'brief':
236
- print('Done.')
237
- _cached_plugins[module_name] = module
238
- return module
 
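For context, a minimal sketch of how the deleted custom_ops.get_plugin helper is typically called; the module name, source file names, and extra build flag are illustrative placeholders, and a working CUDA toolchain plus the rest of the torch_utils package are assumed.

# Hypothetical usage of the deleted helper; names below are placeholders.
from torch_utils import custom_ops

custom_ops.verbosity = 'brief'  # one of 'none', 'brief', 'full'

# Compiles the C++/CUDA sources via torch.utils.cpp_extension.load(),
# imports the resulting module, and returns the cached module on later calls.
plugin = custom_ops.get_plugin(
    module_name='example_plugin',              # placeholder
    sources=['example.cpp', 'example.cu'],     # placeholder source files
    extra_cuda_cflags=['--use_fast_math'],     # forwarded through **build_kwargs
)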
torch_utils/misc.py DELETED
@@ -1,264 +0,0 @@
1
- # Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # NVIDIA CORPORATION and its licensors retain all intellectual property
6
- # and proprietary rights in and to this software, related documentation
7
- # and any modifications thereto. Any use, reproduction, disclosure or
8
- # distribution of this software and related documentation without an express
9
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
-
11
- import re
12
- import contextlib
13
- import numpy as np
14
- import torch
15
- import warnings
16
- import dnnlib
17
-
18
- #----------------------------------------------------------------------------
19
- # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
20
- # same constant is used multiple times.
21
-
22
- _constant_cache = dict()
23
-
24
- def constant(value, shape=None, dtype=None, device=None, memory_format=None):
25
- value = np.asarray(value)
26
- if shape is not None:
27
- shape = tuple(shape)
28
- if dtype is None:
29
- dtype = torch.get_default_dtype()
30
- if device is None:
31
- device = torch.device('cpu')
32
- if memory_format is None:
33
- memory_format = torch.contiguous_format
34
-
35
- key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
36
- tensor = _constant_cache.get(key, None)
37
- if tensor is None:
38
- tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
39
- if shape is not None:
40
- tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
41
- tensor = tensor.contiguous(memory_format=memory_format)
42
- _constant_cache[key] = tensor
43
- return tensor
44
-
45
- #----------------------------------------------------------------------------
46
- # Replace NaN/Inf with specified numerical values.
47
-
48
- try:
49
- nan_to_num = torch.nan_to_num # 1.8.0a0
50
- except AttributeError:
51
- def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
52
- assert isinstance(input, torch.Tensor)
53
- if posinf is None:
54
- posinf = torch.finfo(input.dtype).max
55
- if neginf is None:
56
- neginf = torch.finfo(input.dtype).min
57
- assert nan == 0
58
- return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
59
-
60
- #----------------------------------------------------------------------------
61
- # Symbolic assert.
62
-
63
- try:
64
- symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
65
- except AttributeError:
66
- symbolic_assert = torch.Assert # 1.7.0
67
-
68
- #----------------------------------------------------------------------------
69
- # Context manager to suppress known warnings in torch.jit.trace().
70
-
71
- class suppress_tracer_warnings(warnings.catch_warnings):
72
- def __enter__(self):
73
- super().__enter__()
74
- warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
75
- return self
76
-
77
- #----------------------------------------------------------------------------
78
- # Assert that the shape of a tensor matches the given list of integers.
79
- # None indicates that the size of a dimension is allowed to vary.
80
- # Performs symbolic assertion when used in torch.jit.trace().
81
-
82
- def assert_shape(tensor, ref_shape):
83
- if tensor.ndim != len(ref_shape):
84
- raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
85
- for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
86
- if ref_size is None:
87
- pass
88
- elif isinstance(ref_size, torch.Tensor):
89
- with suppress_tracer_warnings(): # as_tensor results are registered as constants
90
- symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
91
- elif isinstance(size, torch.Tensor):
92
- with suppress_tracer_warnings(): # as_tensor results are registered as constants
93
- symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
94
- elif size != ref_size:
95
- raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
96
-
97
- #----------------------------------------------------------------------------
98
- # Function decorator that calls torch.autograd.profiler.record_function().
99
-
100
- def profiled_function(fn):
101
- def decorator(*args, **kwargs):
102
- with torch.autograd.profiler.record_function(fn.__name__):
103
- return fn(*args, **kwargs)
104
- decorator.__name__ = fn.__name__
105
- return decorator
106
-
107
- #----------------------------------------------------------------------------
108
- # Sampler for torch.utils.data.DataLoader that loops over the dataset
109
- # indefinitely, shuffling items as it goes.
110
-
111
- class InfiniteSampler(torch.utils.data.Sampler):
112
- def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
113
- assert len(dataset) > 0
114
- assert num_replicas > 0
115
- assert 0 <= rank < num_replicas
116
- assert 0 <= window_size <= 1
117
- super().__init__(dataset)
118
- self.dataset = dataset
119
- self.rank = rank
120
- self.num_replicas = num_replicas
121
- self.shuffle = shuffle
122
- self.seed = seed
123
- self.window_size = window_size
124
-
125
- def __iter__(self):
126
- order = np.arange(len(self.dataset))
127
- rnd = None
128
- window = 0
129
- if self.shuffle:
130
- rnd = np.random.RandomState(self.seed)
131
- rnd.shuffle(order)
132
- window = int(np.rint(order.size * self.window_size))
133
-
134
- idx = 0
135
- while True:
136
- i = idx % order.size
137
- if idx % self.num_replicas == self.rank:
138
- yield order[i]
139
- if window >= 2:
140
- j = (i - rnd.randint(window)) % order.size
141
- order[i], order[j] = order[j], order[i]
142
- idx += 1
143
-
144
- #----------------------------------------------------------------------------
145
- # Utilities for operating with torch.nn.Module parameters and buffers.
146
-
147
- def params_and_buffers(module):
148
- assert isinstance(module, torch.nn.Module)
149
- return list(module.parameters()) + list(module.buffers())
150
-
151
- def named_params_and_buffers(module):
152
- assert isinstance(module, torch.nn.Module)
153
- return list(module.named_parameters()) + list(module.named_buffers())
154
-
155
- def copy_params_and_buffers(src_module, dst_module, require_all=False):
156
- assert isinstance(src_module, torch.nn.Module)
157
- assert isinstance(dst_module, torch.nn.Module)
158
- src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
159
- for name, tensor in named_params_and_buffers(dst_module):
160
- assert (name in src_tensors) or (not require_all)
161
- if name in src_tensors:
162
- tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
163
-
164
- #----------------------------------------------------------------------------
165
- # Context manager for easily enabling/disabling DistributedDataParallel
166
- # synchronization.
167
-
168
- @contextlib.contextmanager
169
- def ddp_sync(module, sync):
170
- assert isinstance(module, torch.nn.Module)
171
- if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
172
- yield
173
- else:
174
- with module.no_sync():
175
- yield
176
-
177
- #----------------------------------------------------------------------------
178
- # Check DistributedDataParallel consistency across processes.
179
-
180
- def check_ddp_consistency(module, ignore_regex=None):
181
- assert isinstance(module, torch.nn.Module)
182
- for name, tensor in named_params_and_buffers(module):
183
- fullname = type(module).__name__ + '.' + name
184
- if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
185
- continue
186
- tensor = tensor.detach()
187
- other = tensor.clone()
188
- torch.distributed.broadcast(tensor=other, src=0)
189
- assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
190
-
191
- #----------------------------------------------------------------------------
192
- # Print summary table of module hierarchy.
193
-
194
- def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
195
- assert isinstance(module, torch.nn.Module)
196
- assert not isinstance(module, torch.jit.ScriptModule)
197
- assert isinstance(inputs, (tuple, list))
198
-
199
- # Register hooks.
200
- entries = []
201
- nesting = [0]
202
- def pre_hook(_mod, _inputs):
203
- nesting[0] += 1
204
- def post_hook(mod, _inputs, outputs):
205
- nesting[0] -= 1
206
- if nesting[0] <= max_nesting:
207
- outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
208
- outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
209
- entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
210
- hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
211
- hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
212
-
213
- # Run module.
214
- outputs = module(*inputs)
215
- for hook in hooks:
216
- hook.remove()
217
-
218
- # Identify unique outputs, parameters, and buffers.
219
- tensors_seen = set()
220
- for e in entries:
221
- e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
222
- e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
223
- e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
224
- tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
225
-
226
- # Filter out redundant entries.
227
- if skip_redundant:
228
- entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
229
-
230
- # Construct table.
231
- rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
232
- rows += [['---'] * len(rows[0])]
233
- param_total = 0
234
- buffer_total = 0
235
- submodule_names = {mod: name for name, mod in module.named_modules()}
236
- for e in entries:
237
- name = '<top-level>' if e.mod is module else submodule_names[e.mod]
238
- param_size = sum(t.numel() for t in e.unique_params)
239
- buffer_size = sum(t.numel() for t in e.unique_buffers)
240
- output_shapes = [str(list(t.shape)) for t in e.outputs]
241
- output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
242
- rows += [[
243
- name + (':0' if len(e.outputs) >= 2 else ''),
244
- str(param_size) if param_size else '-',
245
- str(buffer_size) if buffer_size else '-',
246
- (output_shapes + ['-'])[0],
247
- (output_dtypes + ['-'])[0],
248
- ]]
249
- for idx in range(1, len(e.outputs)):
250
- rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
251
- param_total += param_size
252
- buffer_total += buffer_size
253
- rows += [['---'] * len(rows[0])]
254
- rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
255
-
256
- # Print table.
257
- widths = [max(len(cell) for cell in column) for column in zip(*rows)]
258
- print()
259
- for row in rows:
260
- print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
261
- print()
262
- return outputs
263
-
264
- #----------------------------------------------------------------------------
 
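A short usage sketch (not part of this commit) of two of the deleted misc.py helpers, assert_shape and InfiniteSampler; the toy tensor and dataset are assumptions made purely for illustration.

# Illustrative only; toy data, not from this repository.
import torch
import torch.utils.data
from torch_utils import misc

x = torch.randn(4, 3, 256, 256)
misc.assert_shape(x, [None, 3, 256, 256])   # None lets that dimension vary

dataset = torch.utils.data.TensorDataset(torch.arange(10))
sampler = misc.InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True, seed=0)
loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=2)
batch = next(iter(loader))                  # the sampler loops indefinitely, so iteration never exhausts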
torch_utils/models.py DELETED
@@ -1,756 +0,0 @@
1
- # Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- # https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
4
-
5
- import math
6
- import random
7
- import functools
8
- import operator
9
-
10
- import torch
11
- from torch import nn
12
- from torch.nn import functional as F
13
- import torch.nn.init as init
14
- from torch.autograd import Function
15
-
16
- from .op_edit import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
17
-
18
-
19
- class PixelNorm(nn.Module):
20
- def __init__(self):
21
- super().__init__()
22
-
23
- def forward(self, input):
24
- return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
25
-
26
-
27
- def make_kernel(k):
28
- k = torch.tensor(k, dtype=torch.float32)
29
- if k.ndim == 1:
30
- k = k[None, :] * k[:, None]
31
- k /= k.sum()
32
- return k
33
-
34
-
35
- class Upsample(nn.Module):
36
- def __init__(self, kernel, factor=2):
37
- super().__init__()
38
-
39
- self.factor = factor
40
- kernel = make_kernel(kernel) * (factor ** 2)
41
- self.register_buffer("kernel", kernel)
42
-
43
- p = kernel.shape[0] - factor
44
-
45
- pad0 = (p + 1) // 2 + factor - 1
46
- pad1 = p // 2
47
-
48
- self.pad = (pad0, pad1)
49
-
50
- def forward(self, input):
51
- out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
52
- return out
53
-
54
-
55
- class Downsample(nn.Module):
56
- def __init__(self, kernel, factor=2):
57
- super().__init__()
58
-
59
- self.factor = factor
60
- kernel = make_kernel(kernel)
61
- self.register_buffer("kernel", kernel)
62
-
63
- p = kernel.shape[0] - factor
64
-
65
- pad0 = (p + 1) // 2
66
- pad1 = p // 2
67
-
68
- self.pad = (pad0, pad1)
69
-
70
- def forward(self, input):
71
- out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
72
- return out
73
-
74
-
75
- class Blur(nn.Module):
76
- def __init__(self, kernel, pad, upsample_factor=1):
77
- super().__init__()
78
-
79
- kernel = make_kernel(kernel)
80
-
81
- if upsample_factor > 1:
82
- kernel = kernel * (upsample_factor ** 2)
83
-
84
- self.register_buffer("kernel", kernel)
85
-
86
- self.pad = pad
87
-
88
- def forward(self, input):
89
- out = upfirdn2d(input, self.kernel, pad=self.pad)
90
- return out
91
-
92
-
93
- class EqualConv2d(nn.Module):
94
- def __init__(
95
- self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
96
- ):
97
- super().__init__()
98
-
99
- self.weight = nn.Parameter(
100
- torch.randn(out_channel, in_channel, kernel_size, kernel_size)
101
- )
102
- self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
103
-
104
- self.stride = stride
105
- self.padding = padding
106
-
107
- if bias:
108
- self.bias = nn.Parameter(torch.zeros(out_channel))
109
-
110
- else:
111
- self.bias = None
112
-
113
- def forward(self, input):
114
- out = F.conv2d(
115
- input,
116
- self.weight * self.scale,
117
- bias=self.bias,
118
- stride=self.stride,
119
- padding=self.padding,
120
- )
121
- return out
122
-
123
- def __repr__(self):
124
- return (
125
- f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
126
- f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
127
- )
128
-
129
-
130
- class EqualLinear(nn.Module):
131
- def __init__(
132
- self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
133
- ):
134
- super().__init__()
135
-
136
- self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
137
-
138
- if bias:
139
- self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
140
- else:
141
- self.bias = None
142
-
143
- self.activation = activation
144
-
145
- self.scale = (1 / math.sqrt(in_dim)) * lr_mul
146
- self.lr_mul = lr_mul
147
-
148
- def forward(self, input):
149
- if self.activation:
150
- out = F.linear(input, self.weight * self.scale)
151
- out = fused_leaky_relu(out, self.bias * self.lr_mul)
152
- else:
153
- out = F.linear(
154
- input, self.weight * self.scale, bias=self.bias * self.lr_mul
155
- )
156
- return out
157
-
158
- def __repr__(self):
159
- return (
160
- f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
161
- )
162
-
163
-
164
- class ScaledLeakyReLU(nn.Module):
165
- def __init__(self, negative_slope=0.2):
166
- super().__init__()
167
- self.negative_slope = negative_slope
168
-
169
- def forward(self, input):
170
- out = F.leaky_relu(input, negative_slope=self.negative_slope)
171
- return out * math.sqrt(2)
172
-
173
-
174
- class ModulatedConv2d(nn.Module):
175
- def __init__(
176
- self,
177
- in_channel,
178
- out_channel,
179
- kernel_size,
180
- style_dim,
181
- demodulate=True,
182
- upsample=False,
183
- downsample=False,
184
- blur_kernel=[1, 3, 3, 1],
185
- ):
186
- super().__init__()
187
-
188
- self.eps = 1e-8
189
- self.kernel_size = kernel_size
190
- self.in_channel = in_channel
191
- self.out_channel = out_channel
192
- self.upsample = upsample
193
- self.downsample = downsample
194
-
195
- if upsample:
196
- factor = 2
197
- p = (len(blur_kernel) - factor) - (kernel_size - 1)
198
- pad0 = (p + 1) // 2 + factor - 1
199
- pad1 = p // 2 + 1
200
- self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
201
-
202
- if downsample:
203
- factor = 2
204
- p = (len(blur_kernel) - factor) + (kernel_size - 1)
205
- pad0 = (p + 1) // 2
206
- pad1 = p // 2
207
- self.blur = Blur(blur_kernel, pad=(pad0, pad1))
208
-
209
- fan_in = in_channel * kernel_size ** 2
210
- self.scale = 1 / math.sqrt(fan_in)
211
- self.padding = kernel_size // 2
212
- self.weight = nn.Parameter(
213
- torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
214
- )
215
- self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
216
- self.demodulate = demodulate
217
-
218
- def __repr__(self):
219
- return (
220
- f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
221
- f"upsample={self.upsample}, downsample={self.downsample})"
222
- )
223
-
224
- def forward(self, input, style):
225
- batch, in_channel, height, width = input.shape
226
-
227
- style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
228
- weight = self.scale * self.weight * style
229
-
230
- if self.demodulate:
231
- demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
232
- weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
233
-
234
- weight = weight.view(
235
- batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
236
- )
237
-
238
- if self.upsample:
239
- input = input.view(1, batch * in_channel, height, width)
240
- weight = weight.view(
241
- batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
242
- )
243
- weight = weight.transpose(1, 2).reshape(
244
- batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
245
- )
246
- out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
247
- _, _, height, width = out.shape
248
- out = out.view(batch, self.out_channel, height, width)
249
- out = self.blur(out)
250
-
251
- elif self.downsample:
252
- input = self.blur(input)
253
- _, _, height, width = input.shape
254
- input = input.view(1, batch * in_channel, height, width)
255
- out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
256
- _, _, height, width = out.shape
257
- out = out.view(batch, self.out_channel, height, width)
258
-
259
- else:
260
- input = input.view(1, batch * in_channel, height, width)
261
- out = F.conv2d(input, weight, padding=self.padding, groups=batch)
262
- _, _, height, width = out.shape
263
- out = out.view(batch, self.out_channel, height, width)
264
-
265
- return out
266
-
267
-
268
- class NoiseInjection(nn.Module):
269
- def __init__(self):
270
- super().__init__()
271
- self.weight = nn.Parameter(torch.zeros(1))
272
-
273
- def forward(self, image, noise=None):
274
- if noise is None:
275
- batch, _, height, width = image.shape
276
- noise = image.new_empty(batch, 1, height, width).normal_()
277
- return image + self.weight * noise
278
-
279
-
280
- class ConstantInput(nn.Module):
281
- def __init__(self, channel, size=4):
282
- super().__init__()
283
- self.input = nn.Parameter(torch.randn(1, channel, size, size // 2))
284
-
285
- def forward(self, input):
286
- batch = input.shape[0]
287
- out = self.input.repeat(batch, 1, 1, 1)
288
- return out
289
-
290
-
291
- class StyledConv(nn.Module):
292
- def __init__(
293
- self,
294
- in_channel,
295
- out_channel,
296
- kernel_size,
297
- style_dim,
298
- upsample=False,
299
- blur_kernel=[1, 3, 3, 1],
300
- demodulate=True,
301
- ):
302
- super().__init__()
303
- self.conv = ModulatedConv2d(
304
- in_channel,
305
- out_channel,
306
- kernel_size,
307
- style_dim,
308
- upsample=upsample,
309
- blur_kernel=blur_kernel,
310
- demodulate=demodulate,
311
- )
312
- self.noise = NoiseInjection()
313
- self.activate = FusedLeakyReLU(out_channel)
314
-
315
- def forward(self, input, style, noise=None):
316
- out = self.conv(input, style)
317
- out = self.noise(out, noise=noise)
318
- out = self.activate(out)
319
- return out
320
-
321
-
322
- class ToRGB(nn.Module):
323
- def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
324
- super().__init__()
325
- if upsample:
326
- self.upsample = Upsample(blur_kernel)
327
-
328
- self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
329
- self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
330
-
331
- def forward(self, input, style, skip=None):
332
- out = self.conv(input, style)
333
- out = out + self.bias
334
-
335
- if skip is not None:
336
- skip = self.upsample(skip)
337
- out = out + skip
338
-
339
- return out
340
-
341
-
342
- class Generator(nn.Module):
343
- def __init__(
344
- self,
345
- size,
346
- style_dim,
347
- n_mlp,
348
- channel_multiplier=1,
349
- blur_kernel=[1, 3, 3, 1],
350
- lr_mlp=0.01,
351
- small=False,
352
- small_isaac=False,
353
- ):
354
- super().__init__()
355
-
356
- self.size = size
357
-
358
- if small and size > 64:
359
- raise ValueError("small only works for sizes <= 64")
360
-
361
- self.style_dim = style_dim
362
- layers = [PixelNorm()]
363
-
364
- for i in range(n_mlp):
365
- layers.append(
366
- EqualLinear(
367
- style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
368
- )
369
- )
370
-
371
- self.style = nn.Sequential(*layers)
372
-
373
- if small:
374
- self.channels = {
375
- 4: 64 * channel_multiplier,
376
- 8: 64 * channel_multiplier,
377
- 16: 64 * channel_multiplier,
378
- 32: 64 * channel_multiplier,
379
- 64: 64 * channel_multiplier,
380
- }
381
- elif small_isaac:
382
- self.channels = {4: 256, 8: 256, 16: 256, 32: 256, 64: 128, 128: 128}
383
- else:
384
- self.channels = {
385
- 4: 512,
386
- 8: 512,
387
- 16: 512,
388
- 32: 512,
389
- 64: 256 * channel_multiplier,
390
- 128: 128 * channel_multiplier,
391
- 256: 64 * channel_multiplier,
392
- 512: 32 * channel_multiplier,
393
- 1024: 16 * channel_multiplier,
394
- }
395
-
396
- self.input = ConstantInput(self.channels[4])
397
- self.conv1 = StyledConv(
398
- self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
399
- )
400
- self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
401
-
402
- self.log_size = int(math.log(size, 2))
403
- self.num_layers = (self.log_size - 2) * 2 + 1
404
-
405
- self.convs = nn.ModuleList()
406
- self.upsamples = nn.ModuleList()
407
- self.to_rgbs = nn.ModuleList()
408
- self.noises = nn.Module()
409
-
410
- in_channel = self.channels[4]
411
-
412
- for layer_idx in range(self.num_layers):
413
- res = (layer_idx + 5) // 2
414
- shape = [1, 1, 2 ** res, 2 ** res // 2]
415
- self.noises.register_buffer(
416
- "noise_{}".format(layer_idx), torch.randn(*shape)
417
- )
418
-
419
- for i in range(3, self.log_size + 1):
420
- out_channel = self.channels[2 ** i]
421
-
422
- self.convs.append(
423
- StyledConv(
424
- in_channel,
425
- out_channel,
426
- 3,
427
- style_dim,
428
- upsample=True,
429
- blur_kernel=blur_kernel,
430
- )
431
- )
432
-
433
- self.convs.append(
434
- StyledConv(
435
- out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
436
- )
437
- )
438
-
439
- self.to_rgbs.append(ToRGB(out_channel, style_dim))
440
- in_channel = out_channel
441
-
442
- self.n_latent = self.log_size * 2 - 2
443
-
444
- def make_noise(self):
445
- device = self.input.input.device
446
-
447
- noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2 // 2, device=device)]
448
-
449
- for i in range(3, self.log_size + 1):
450
- for _ in range(2):
451
- noises.append(torch.randn(1, 1, 2 ** i, 2 ** i // 2, device=device))
452
-
453
- return noises
454
-
455
- def mean_latent(self, n_latent):
456
- latent_in = torch.randn(
457
- n_latent, self.style_dim, device=self.input.input.device
458
- )
459
- latent = self.style(latent_in).mean(0, keepdim=True)
460
-
461
- return latent
462
-
463
- def get_latent(self, input):
464
- return self.style(input)
465
-
466
- def forward(
467
- self,
468
- styles,
469
- return_latents=False,
470
- return_features=False,
471
- inject_index=None,
472
- truncation=1,
473
- truncation_latent=None,
474
- input_is_latent=False,
475
- noise=None,
476
- randomize_noise=True,
477
- real=False,
478
- ):
479
- if not input_is_latent:
480
- styles = [self.style(s) for s in styles]
481
- if noise is None:
482
- if randomize_noise:
483
- noise = [None] * self.num_layers
484
- else:
485
- noise = [
486
- getattr(self.noises, "noise_{}".format(i))
487
- for i in range(self.num_layers)
488
- ]
489
-
490
- if truncation < 1:
491
- # print('truncation_latent: ', truncation_latent.shape)
492
- if not real: #if type(styles) == list:
493
- style_t = []
494
- for style in styles:
495
- style_t.append(
496
- truncation_latent + truncation * (style - truncation_latent)
497
- ) # (-1.1162e-03-(-1.0914e-01))*0.8+(-1.0914e-01)
498
- styles = style_t
499
- else: # styles are latent (tensor: 1,18,512), for real PTI output
500
- truncation_latent = truncation_latent.repeat(18,1).unsqueeze(0) # (1,512) --> (1,18,512)
501
- styles = torch.add(truncation_latent,torch.mul(torch.sub(styles,truncation_latent),truncation))
502
- # print('now styles after truncation : ', styles)
503
- #if type(styles) == list and len(styles) < 2: # this if for input as list of [(1,512)]
504
- if not real:
505
- if len(styles) < 2:
506
- inject_index = self.n_latent
507
- if styles[0].ndim < 3:
508
- latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
509
- else:
510
- latent = styles[0]
511
- elif type(styles) == list:
512
- if inject_index is None:
513
- inject_index = 4
514
-
515
- latent = styles[0].unsqueeze(0)
516
- if latent.shape[1] == 1:
517
- latent = latent.repeat(1, inject_index, 1)
518
- else:
519
- latent = latent[:, :inject_index, :]
520
- latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
521
- latent = torch.cat([latent, latent2], 1)
522
- else: # input is tensor of size with torch.Size([1, 18, 512]), for real PTI output
523
- latent = styles
524
-
525
- # print(f'processed latent: {latent.shape}')
526
-
527
- features = {}
528
- out = self.input(latent)
529
- features["out_0"] = out
530
- out = self.conv1(out, latent[:, 0], noise=noise[0])
531
- features["conv1_0"] = out
532
-
533
- skip = self.to_rgb1(out, latent[:, 1])
534
- features["skip_0"] = skip
535
- i = 1
536
- for conv1, conv2, noise1, noise2, to_rgb in zip(
537
- self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
538
- ):
539
- out = conv1(out, latent[:, i], noise=noise1)
540
- features["conv1_{}".format(i)] = out
541
- out = conv2(out, latent[:, i + 1], noise=noise2)
542
- features["conv2_{}".format(i)] = out
543
- skip = to_rgb(out, latent[:, i + 2], skip)
544
- features["skip_{}".format(i)] = skip
545
-
546
- i += 2
547
-
548
- image = skip
549
-
550
- if return_latents:
551
- return image, latent
552
- elif return_features:
553
- return image, features
554
- else:
555
- return image, None
556
-
557
-
558
- class ConvLayer(nn.Sequential):
559
- def __init__(
560
- self,
561
- in_channel,
562
- out_channel,
563
- kernel_size,
564
- downsample=False,
565
- blur_kernel=[1, 3, 3, 1],
566
- bias=True,
567
- activate=True,
568
- ):
569
- layers = []
570
-
571
- if downsample:
572
- factor = 2
573
- p = (len(blur_kernel) - factor) + (kernel_size - 1)
574
- pad0 = (p + 1) // 2
575
- pad1 = p // 2
576
-
577
- layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
578
-
579
- stride = 2
580
- self.padding = 0
581
-
582
- else:
583
- stride = 1
584
- self.padding = kernel_size // 2
585
-
586
- layers.append(
587
- EqualConv2d(
588
- in_channel,
589
- out_channel,
590
- kernel_size,
591
- padding=self.padding,
592
- stride=stride,
593
- bias=bias and not activate,
594
- )
595
- )
596
-
597
- if activate:
598
- if bias:
599
- layers.append(FusedLeakyReLU(out_channel))
600
- else:
601
- layers.append(ScaledLeakyReLU(0.2))
602
-
603
- super().__init__(*layers)
604
-
605
-
606
- class ResBlock(nn.Module):
607
- def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
608
- super().__init__()
609
-
610
- self.conv1 = ConvLayer(in_channel, in_channel, 3)
611
- self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
612
-
613
- self.skip = ConvLayer(
614
- in_channel, out_channel, 1, downsample=True, activate=False, bias=False
615
- )
616
-
617
- def forward(self, input):
618
- out = self.conv1(input)
619
- out = self.conv2(out)
620
-
621
- skip = self.skip(input)
622
- out = (out + skip) / math.sqrt(2)
623
-
624
- return out
625
-
626
-
627
- class StyleDiscriminator(nn.Module):
628
- def __init__(
629
- self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], small=False
630
- ):
631
- super().__init__()
632
-
633
- if small:
634
- channels = {4: 64, 8: 64, 16: 64, 32: 64, 64: 64}
635
-
636
- else:
637
- channels = {
638
- 4: 512,
639
- 8: 512,
640
- 16: 512,
641
- 32: 512,
642
- 64: 256 * channel_multiplier,
643
- 128: 128 * channel_multiplier,
644
- 256: 64 * channel_multiplier,
645
- 512: 32 * channel_multiplier,
646
- 1024: 16 * channel_multiplier,
647
- }
648
-
649
- convs = [ConvLayer(3, channels[size], 1)]
650
-
651
- log_size = int(math.log(size, 2))
652
- in_channel = channels[size]
653
-
654
- for i in range(log_size, 2, -1):
655
- out_channel = channels[2 ** (i - 1)]
656
-
657
- convs.append(ResBlock(in_channel, out_channel, blur_kernel))
658
-
659
- in_channel = out_channel
660
-
661
- self.convs = nn.Sequential(*convs)
662
-
663
- self.stddev_group = 4
664
- self.stddev_feat = 1
665
-
666
- self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
667
- self.final_linear = nn.Sequential(
668
- EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
669
- EqualLinear(channels[4], 1),
670
- )
671
-
672
-
673
- def forward(self, input):
674
- h = input
675
- h_list = []
676
-
677
- for index, blocklist in enumerate(self.convs):
678
- h = blocklist(h)
679
- h_list.append(h)
680
-
681
- out = h
682
- batch, channel, height, width = out.shape
683
- group = min(batch, self.stddev_group)
684
- stddev = out.view(
685
- group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
686
- )
687
- stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
688
- stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
689
- stddev = stddev.repeat(group, 1, height, width)
690
- out = torch.cat([out, stddev], 1)
691
-
692
- out = self.final_conv(out)
693
- h_list.append(out)
694
-
695
- out = out.view(batch, -1)
696
- out = self.final_linear(out)
697
-
698
- return out, h_list
699
-
700
-
701
- class StyleEncoder(nn.Module):
702
- def __init__(self, size, w_dim=512):
703
- super().__init__()
704
-
705
- channels = {
706
- 4: 512,
707
- 8: 512,
708
- 16: 512,
709
- 32: 512,
710
- 64: 256,
711
- 128: 128,
712
- 256: 64,
713
- 512: 32,
714
- 1024: 16
715
- }
716
-
717
- self.w_dim = w_dim
718
- log_size = int(math.log(size, 2))
719
- convs = [ConvLayer(3, channels[size], 1)]
720
-
721
- in_channel = channels[size]
722
- for i in range(log_size, 2, -1):
723
- out_channel = channels[2 ** (i - 1)]
724
- convs.append(ResBlock(in_channel, out_channel))
725
- in_channel = out_channel
726
-
727
- convs.append(EqualConv2d(in_channel,2*self.w_dim, 4, padding=0, bias=False))
728
-
729
- self.convs = nn.Sequential(*convs)
730
-
731
- def forward(self, input):
732
- out = self.convs(input)
733
- # return out.view(len(input), self.n_latents, self.w_dim)
734
- reshaped = out.view(len(input), 2*self.w_dim)
735
- return reshaped[:,:self.w_dim], reshaped[:,self.w_dim:]
736
-
737
- def kaiming_init(m):
738
- if isinstance(m, (nn.Linear, nn.Conv2d)):
739
- init.kaiming_normal_(m.weight)
740
- if m.bias is not None:
741
- m.bias.data.fill_(0)
742
- elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
743
- m.weight.data.fill_(1)
744
- if m.bias is not None:
745
- m.bias.data.fill_(0)
746
-
747
-
748
- def normal_init(m):
749
- if isinstance(m, (nn.Linear, nn.Conv2d)):
750
- init.normal_(m.weight, 0, 0.02)
751
- if m.bias is not None:
752
- m.bias.data.fill_(0)
753
- elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
754
- m.weight.data.fill_(1)
755
- if m.bias is not None:
756
- m.bias.data.fill_(0)
 
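Below is a small sketch, not taken from the commit, of how the deleted models.Generator is typically constructed and sampled; the resolution and latent size are assumptions, and importing the module requires the fused CUDA ops under op_edit to build. Note that ConstantInput in this file starts from a size x size//2 base, so a size=256 generator emits 256x128 images.

# Illustrative sketch; resolution and latent dimensionality are assumptions.
import torch
from torch_utils.models import Generator

g = Generator(size=256, style_dim=512, n_mlp=8, channel_multiplier=2).eval()
z = torch.randn(1, 512)                     # a single latent code

with torch.no_grad():
    img, _ = g([z], truncation=1, randomize_noise=True)
# img has shape (1, 3, 256, 128): height == size, width == size // 2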
torch_utils/models_face.py DELETED
@@ -1,809 +0,0 @@
1
- # Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- import math
4
- import random
5
- import functools
6
- import operator
7
-
8
- import torch
9
- from torch import nn
10
- from torch.nn import functional as F
11
- import torch.nn.init as init
12
- from torch.autograd import Function
13
-
14
- from .op_edit import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
15
-
16
-
17
- class PixelNorm(nn.Module):
18
- def __init__(self):
19
- super().__init__()
20
-
21
- def forward(self, input):
22
- return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
23
-
24
-
25
- def make_kernel(k):
26
- k = torch.tensor(k, dtype=torch.float32)
27
-
28
- if k.ndim == 1:
29
- k = k[None, :] * k[:, None]
30
-
31
- k /= k.sum()
32
-
33
- return k
34
-
35
-
36
- class Upsample(nn.Module):
37
- def __init__(self, kernel, factor=2):
38
- super().__init__()
39
-
40
- self.factor = factor
41
- kernel = make_kernel(kernel) * (factor ** 2)
42
- self.register_buffer("kernel", kernel)
43
-
44
- p = kernel.shape[0] - factor
45
-
46
- pad0 = (p + 1) // 2 + factor - 1
47
- pad1 = p // 2
48
-
49
- self.pad = (pad0, pad1)
50
-
51
- def forward(self, input):
52
- out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
53
-
54
- return out
55
-
56
-
57
- class Downsample(nn.Module):
58
- def __init__(self, kernel, factor=2):
59
- super().__init__()
60
-
61
- self.factor = factor
62
- kernel = make_kernel(kernel)
63
- self.register_buffer("kernel", kernel)
64
-
65
- p = kernel.shape[0] - factor
66
-
67
- pad0 = (p + 1) // 2
68
- pad1 = p // 2
69
-
70
- self.pad = (pad0, pad1)
71
-
72
- def forward(self, input):
73
- out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
74
-
75
- return out
76
-
77
-
78
- class Blur(nn.Module):
79
- def __init__(self, kernel, pad, upsample_factor=1):
80
- super().__init__()
81
-
82
- kernel = make_kernel(kernel)
83
-
84
- if upsample_factor > 1:
85
- kernel = kernel * (upsample_factor ** 2)
86
-
87
- self.register_buffer("kernel", kernel)
88
-
89
- self.pad = pad
90
-
91
- def forward(self, input):
92
- out = upfirdn2d(input, self.kernel, pad=self.pad)
93
-
94
- return out
95
-
96
-
97
- class EqualConv2d(nn.Module):
98
- def __init__(
99
- self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
100
- ):
101
- super().__init__()
102
-
103
- self.weight = nn.Parameter(
104
- torch.randn(out_channel, in_channel, kernel_size, kernel_size)
105
- )
106
- self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
107
-
108
- self.stride = stride
109
- self.padding = padding
110
-
111
- if bias:
112
- self.bias = nn.Parameter(torch.zeros(out_channel))
113
-
114
- else:
115
- self.bias = None
116
-
117
- def forward(self, input):
118
- out = F.conv2d(
119
- input,
120
- self.weight * self.scale,
121
- bias=self.bias,
122
- stride=self.stride,
123
- padding=self.padding,
124
- )
125
-
126
- return out
127
-
128
- def __repr__(self):
129
- return (
130
- f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
131
- f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
132
- )
133
-
134
-
135
- class EqualLinear(nn.Module):
136
- def __init__(
137
- self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
138
- ):
139
- super().__init__()
140
-
141
- self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
142
-
143
- if bias:
144
- self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
145
-
146
- else:
147
- self.bias = None
148
-
149
- self.activation = activation
150
-
151
- self.scale = (1 / math.sqrt(in_dim)) * lr_mul
152
- self.lr_mul = lr_mul
153
-
154
- def forward(self, input):
155
- if self.activation:
156
- out = F.linear(input, self.weight * self.scale)
157
- out = fused_leaky_relu(out, self.bias * self.lr_mul)
158
-
159
- else:
160
- out = F.linear(
161
- input, self.weight * self.scale, bias=self.bias * self.lr_mul
162
- )
163
-
164
- return out
165
-
166
- def __repr__(self):
167
- return (
168
- f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
169
- )
170
-
171
-
172
- class ScaledLeakyReLU(nn.Module):
173
- def __init__(self, negative_slope=0.2):
174
- super().__init__()
175
-
176
- self.negative_slope = negative_slope
177
-
178
- def forward(self, input):
179
- out = F.leaky_relu(input, negative_slope=self.negative_slope)
180
-
181
- return out * math.sqrt(2)
182
-
183
-
184
- class ModulatedConv2d(nn.Module):
185
- def __init__(
186
- self,
187
- in_channel,
188
- out_channel,
189
- kernel_size,
190
- style_dim,
191
- demodulate=True,
192
- upsample=False,
193
- downsample=False,
194
- blur_kernel=[1, 3, 3, 1],
195
- ):
196
- super().__init__()
197
-
198
- self.eps = 1e-8
199
- self.kernel_size = kernel_size
200
- self.in_channel = in_channel
201
- self.out_channel = out_channel
202
- self.upsample = upsample
203
- self.downsample = downsample
204
-
205
- if upsample:
206
- factor = 2
207
- p = (len(blur_kernel) - factor) - (kernel_size - 1)
208
- pad0 = (p + 1) // 2 + factor - 1
209
- pad1 = p // 2 + 1
210
-
211
- self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
212
-
213
- if downsample:
214
- factor = 2
215
- p = (len(blur_kernel) - factor) + (kernel_size - 1)
216
- pad0 = (p + 1) // 2
217
- pad1 = p // 2
218
-
219
- self.blur = Blur(blur_kernel, pad=(pad0, pad1))
220
-
221
- fan_in = in_channel * kernel_size ** 2
222
- self.scale = 1 / math.sqrt(fan_in)
223
- self.padding = kernel_size // 2
224
-
225
- self.weight = nn.Parameter(
226
- torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
227
- )
228
-
229
- self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
230
-
231
- self.demodulate = demodulate
232
-
233
- def __repr__(self):
234
- return (
235
- f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
236
- f"upsample={self.upsample}, downsample={self.downsample})"
237
- )
238
-
239
- def forward(self, input, style):
240
- batch, in_channel, height, width = input.shape
241
-
242
- style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
243
- weight = self.scale * self.weight * style
244
-
245
- if self.demodulate:
246
- demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
247
- weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
248
-
249
- weight = weight.view(
250
- batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
251
- )
252
-
253
- if self.upsample:
254
- input = input.view(1, batch * in_channel, height, width)
255
- weight = weight.view(
256
- batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
257
- )
258
- weight = weight.transpose(1, 2).reshape(
259
- batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
260
- )
261
- out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
262
- _, _, height, width = out.shape
263
- out = out.view(batch, self.out_channel, height, width)
264
- out = self.blur(out)
265
-
266
- elif self.downsample:
267
- input = self.blur(input)
268
- _, _, height, width = input.shape
269
- input = input.view(1, batch * in_channel, height, width)
270
- out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
271
- _, _, height, width = out.shape
272
- out = out.view(batch, self.out_channel, height, width)
273
-
274
- else:
275
- input = input.view(1, batch * in_channel, height, width)
276
- out = F.conv2d(input, weight, padding=self.padding, groups=batch)
277
- _, _, height, width = out.shape
278
- out = out.view(batch, self.out_channel, height, width)
279
-
280
- return out
281
-
282
-
283
- class NoiseInjection(nn.Module):
284
- def __init__(self):
285
- super().__init__()
286
-
287
- self.weight = nn.Parameter(torch.zeros(1))
288
-
289
- def forward(self, image, noise=None):
290
- if noise is None:
291
- batch, _, height, width = image.shape
292
- noise = image.new_empty(batch, 1, height, width).normal_()
293
-
294
- return image + self.weight * noise
295
-
296
-
297
- class ConstantInput(nn.Module):
298
- def __init__(self, channel, size=4):
299
- super().__init__()
300
-
301
- self.input = nn.Parameter(torch.randn(1, channel, size, size))
302
-
303
- def forward(self, input):
304
- batch = input.shape[0]
305
- out = self.input.repeat(batch, 1, 1, 1)
306
-
307
- return out
308
-
309
-
310
- class StyledConv(nn.Module):
311
- def __init__(
312
- self,
313
- in_channel,
314
- out_channel,
315
- kernel_size,
316
- style_dim,
317
- upsample=False,
318
- blur_kernel=[1, 3, 3, 1],
319
- demodulate=True,
320
- ):
321
- super().__init__()
322
-
323
- self.conv = ModulatedConv2d(
324
- in_channel,
325
- out_channel,
326
- kernel_size,
327
- style_dim,
328
- upsample=upsample,
329
- blur_kernel=blur_kernel,
330
- demodulate=demodulate,
331
- )
332
-
333
- self.noise = NoiseInjection()
334
- # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
335
- # self.activate = ScaledLeakyReLU(0.2)
336
- self.activate = FusedLeakyReLU(out_channel)
337
-
338
- def forward(self, input, style, noise=None):
339
- out = self.conv(input, style)
340
- out = self.noise(out, noise=noise)
341
- # out = out + self.bias
342
- out = self.activate(out)
343
-
344
- return out
345
-
346
-
347
- class ToRGB(nn.Module):
348
- def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
349
- super().__init__()
350
-
351
- if upsample:
352
- self.upsample = Upsample(blur_kernel)
353
-
354
- self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
355
- self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
356
-
357
- def forward(self, input, style, skip=None):
358
- out = self.conv(input, style)
359
- out = out + self.bias
360
-
361
- if skip is not None:
362
- skip = self.upsample(skip)
363
-
364
- out = out + skip
365
-
366
- return out
367
-
368
-
369
- class Generator(nn.Module):
370
- def __init__(
371
- self,
372
- size,
373
- style_dim,
374
- n_mlp,
375
- channel_multiplier=1,
376
- blur_kernel=[1, 3, 3, 1],
377
- lr_mlp=0.01,
378
- small=False,
379
- small_isaac=False,
380
- ):
381
- super().__init__()
382
-
383
- self.size = size
384
-
385
- if small and size > 64:
386
- raise ValueError("small only works for sizes <= 64")
387
-
388
- self.style_dim = style_dim
389
-
390
- layers = [PixelNorm()]
391
-
392
- for i in range(n_mlp):
393
- layers.append(
394
- EqualLinear(
395
- style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
396
- )
397
- )
398
-
399
- self.style = nn.Sequential(*layers)
400
-
401
- if small:
402
- self.channels = {
403
- 4: 64 * channel_multiplier,
404
- 8: 64 * channel_multiplier,
405
- 16: 64 * channel_multiplier,
406
- 32: 64 * channel_multiplier,
407
- 64: 64 * channel_multiplier,
408
- }
409
- elif small_isaac:
410
- self.channels = {4: 256, 8: 256, 16: 256, 32: 256, 64: 128, 128: 128}
411
- else:
412
- self.channels = {
413
- 4: 512,
414
- 8: 512,
415
- 16: 512,
416
- 32: 512,
417
- 64: 256 * channel_multiplier,
418
- 128: 128 * channel_multiplier,
419
- 256: 64 * channel_multiplier,
420
- 512: 32 * channel_multiplier,
421
- 1024: 16 * channel_multiplier,
422
- }
423
-
424
- self.input = ConstantInput(self.channels[4])
425
- self.conv1 = StyledConv(
426
- self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
427
- )
428
- self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
429
-
430
- self.log_size = int(math.log(size, 2))
431
- self.num_layers = (self.log_size - 2) * 2 + 1
432
-
433
- self.convs = nn.ModuleList()
434
- self.upsamples = nn.ModuleList()
435
- self.to_rgbs = nn.ModuleList()
436
- self.noises = nn.Module()
437
-
438
- in_channel = self.channels[4]
439
-
440
- for layer_idx in range(self.num_layers):
441
- res = (layer_idx + 5) // 2
442
- shape = [1, 1, 2 ** res, 2 ** res]
443
- self.noises.register_buffer(
444
- "noise_{}".format(layer_idx), torch.randn(*shape)
445
- )
446
-
447
- for i in range(3, self.log_size + 1):
448
- out_channel = self.channels[2 ** i]
449
-
450
- self.convs.append(
451
- StyledConv(
452
- in_channel,
453
- out_channel,
454
- 3,
455
- style_dim,
456
- upsample=True,
457
- blur_kernel=blur_kernel,
458
- )
459
- )
460
-
461
- self.convs.append(
462
- StyledConv(
463
- out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
464
- )
465
- )
466
-
467
- self.to_rgbs.append(ToRGB(out_channel, style_dim))
468
-
469
- in_channel = out_channel
470
-
471
- self.n_latent = self.log_size * 2 - 2
472
-
473
- def make_noise(self):
474
- device = self.input.input.device
475
-
476
- noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
477
-
478
- for i in range(3, self.log_size + 1):
479
- for _ in range(2):
480
- noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
481
-
482
- return noises
483
-
484
- def mean_latent(self, n_latent):
485
- latent_in = torch.randn(
486
- n_latent, self.style_dim, device=self.input.input.device
487
- )
488
- latent = self.style(latent_in).mean(0, keepdim=True)
489
-
490
- return latent
491
-
492
- def get_latent(self, input):
493
- return self.style(input)
494
-
495
- def forward(
496
- self,
497
- styles,
498
- return_latents=False,
499
- return_features=False,
500
- inject_index=None,
501
- truncation=1,
502
- truncation_latent=None,
503
- input_is_latent=False,
504
- noise=None,
505
- randomize_noise=True,
506
- ):
507
- if not input_is_latent:
508
- # print("haha")
509
- styles = [self.style(s) for s in styles]
510
- if noise is None:
511
- if randomize_noise:
512
- noise = [None] * self.num_layers
513
- else:
514
- noise = [
515
- getattr(self.noises, "noise_{}".format(i))
516
- for i in range(self.num_layers)
517
- ]
518
-
519
- if truncation < 1:
520
- style_t = []
521
-
522
- for style in styles:
523
- style_t.append(
524
- truncation_latent + truncation * (style - truncation_latent)
525
- )
526
-
527
- styles = style_t
528
- # print(styles)
529
- if len(styles) < 2:
530
- inject_index = self.n_latent
531
-
532
- if styles[0].ndim < 3:
533
- latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
534
- # print("a")
535
- else:
536
- # print(len(styles))
537
- latent = styles[0]
538
- # print("b", latent.shape)
539
-
540
- else:
541
- # print("c")
542
- if inject_index is None:
543
- inject_index = 4
544
-
545
- latent = styles[0].unsqueeze(0)
546
- if latent.shape[1] == 1:
547
- latent = latent.repeat(1, inject_index, 1)
548
- else:
549
- latent = latent[:, :inject_index, :]
550
- latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
551
-
552
- latent = torch.cat([latent, latent2], 1)
553
-
554
- features = {}
555
- out = self.input(latent)
556
- features["out_0"] = out
557
- out = self.conv1(out, latent[:, 0], noise=noise[0])
558
- features["conv1_0"] = out
559
-
560
- skip = self.to_rgb1(out, latent[:, 1])
561
- features["skip_0"] = skip
562
- i = 1
563
- for conv1, conv2, noise1, noise2, to_rgb in zip(
564
- self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
565
- ):
566
- out = conv1(out, latent[:, i], noise=noise1)
567
- features["conv1_{}".format(i)] = out
568
- out = conv2(out, latent[:, i + 1], noise=noise2)
569
- features["conv2_{}".format(i)] = out
570
- skip = to_rgb(out, latent[:, i + 2], skip)
571
- features["skip_{}".format(i)] = skip
572
-
573
- i += 2
574
-
575
- image = skip
576
-
577
- if return_latents:
578
- return image, latent
579
- elif return_features:
580
- return image, features
581
- else:
582
- return image, None
583
-
584
-
585
- class ConvLayer(nn.Sequential):
586
- def __init__(
587
- self,
588
- in_channel,
589
- out_channel,
590
- kernel_size,
591
- downsample=False,
592
- blur_kernel=[1, 3, 3, 1],
593
- bias=True,
594
- activate=True,
595
- ):
596
- layers = []
597
-
598
- if downsample:
599
- factor = 2
600
- p = (len(blur_kernel) - factor) + (kernel_size - 1)
601
- pad0 = (p + 1) // 2
602
- pad1 = p // 2
603
-
604
- layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
605
-
606
- stride = 2
607
- self.padding = 0
608
-
609
- else:
610
- stride = 1
611
- self.padding = kernel_size // 2
612
-
613
- layers.append(
614
- EqualConv2d(
615
- in_channel,
616
- out_channel,
617
- kernel_size,
618
- padding=self.padding,
619
- stride=stride,
620
- bias=bias and not activate,
621
- )
622
- )
623
-
624
- if activate:
625
- if bias:
626
- layers.append(FusedLeakyReLU(out_channel))
627
-
628
- else:
629
- layers.append(ScaledLeakyReLU(0.2))
630
-
631
- super().__init__(*layers)
632
-
633
-
634
- class ResBlock(nn.Module):
635
- def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
636
- super().__init__()
637
-
638
- self.conv1 = ConvLayer(in_channel, in_channel, 3)
639
- self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
640
-
641
- self.skip = ConvLayer(
642
- in_channel, out_channel, 1, downsample=True, activate=False, bias=False
643
- )
644
-
645
- def forward(self, input):
646
- out = self.conv1(input)
647
- out = self.conv2(out)
648
-
649
- skip = self.skip(input)
650
- out = (out + skip) / math.sqrt(2)
651
-
652
- return out
653
-
654
-
655
- class StyleDiscriminator(nn.Module):
656
- def __init__(
657
- self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], small=False
658
- ):
659
- super().__init__()
660
-
661
- if small:
662
- channels = {4: 64, 8: 64, 16: 64, 32: 64, 64: 64}
663
-
664
- else:
665
- channels = {
666
- 4: 512,
667
- 8: 512,
668
- 16: 512,
669
- 32: 512,
670
- 64: 256 * channel_multiplier,
671
- 128: 128 * channel_multiplier,
672
- 256: 64 * channel_multiplier,
673
- 512: 32 * channel_multiplier,
674
- 1024: 16 * channel_multiplier,
675
- }
676
-
677
- convs = [ConvLayer(3, channels[size], 1)]
678
-
679
- log_size = int(math.log(size, 2))
680
-
681
- in_channel = channels[size]
682
-
683
- for i in range(log_size, 2, -1):
684
- out_channel = channels[2 ** (i - 1)]
685
-
686
- convs.append(ResBlock(in_channel, out_channel, blur_kernel))
687
-
688
- in_channel = out_channel
689
-
690
- self.convs = nn.Sequential(*convs)
691
-
692
- self.stddev_group = 4
693
- self.stddev_feat = 1
694
-
695
- self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
696
- self.final_linear = nn.Sequential(
697
- EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
698
- EqualLinear(channels[4], 1),
699
- )
700
-
701
- # def forward(self, input):
702
- # out = self.convs(input)
703
-
704
- # batch, channel, height, width = out.shape
705
- # group = min(batch, self.stddev_group)
706
- # stddev = out.view(
707
- # group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
708
- # )
709
- # stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
710
- # stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
711
- # stddev = stddev.repeat(group, 1, height, width)
712
- # out = torch.cat([out, stddev], 1)
713
-
714
- # out = self.final_conv(out)
715
-
716
- # out = out.view(batch, -1)
717
- # out = self.final_linear(out)
718
-
719
- # return out
720
-
721
- def forward(self, input):
722
- h = input
723
- h_list = []
724
-
725
- for index, blocklist in enumerate(self.convs):
726
- h = blocklist(h)
727
- h_list.append(h)
728
-
729
- out = h
730
- batch, channel, height, width = out.shape
731
- group = min(batch, self.stddev_group)
732
- stddev = out.view(
733
- group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
734
- )
735
- stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
736
- stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
737
- stddev = stddev.repeat(group, 1, height, width)
738
- out = torch.cat([out, stddev], 1)
739
-
740
- out = self.final_conv(out)
741
- h_list.append(out)
742
-
743
- out = out.view(batch, -1)
744
- out = self.final_linear(out)
745
-
746
- return out, h_list
747
-
748
-
749
- class StyleEncoder(nn.Module):
750
- def __init__(self, size, w_dim=512):
751
- super().__init__()
752
-
753
- channels = {
754
- 4: 512,
755
- 8: 512,
756
- 16: 512,
757
- 32: 512,
758
- 64: 256,
759
- 128: 128,
760
- 256: 64,
761
- 512: 32,
762
- 1024: 16
763
- }
764
-
765
- self.w_dim = w_dim
766
- log_size = int(math.log(size, 2))
767
-
768
- # self.n_latents = log_size*2 - 2
769
-
770
- convs = [ConvLayer(3, channels[size], 1)]
771
-
772
- in_channel = channels[size]
773
- for i in range(log_size, 2, -1):
774
- out_channel = channels[2 ** (i - 1)]
775
- convs.append(ResBlock(in_channel, out_channel))
776
- in_channel = out_channel
777
-
778
- # convs.append(EqualConv2d(in_channel, self.n_latents*self.w_dim, 4, padding=0, bias=False))
779
- convs.append(EqualConv2d(in_channel,2*self.w_dim, 4, padding=0, bias=False))
780
-
781
-
782
- self.convs = nn.Sequential(*convs)
783
-
784
- def forward(self, input):
785
- out = self.convs(input)
786
- # return out.view(len(input), self.n_latents, self.w_dim)
787
- reshaped = out.view(len(input), 2*self.w_dim)
788
- return reshaped[:,:self.w_dim], reshaped[:,self.w_dim:]
789
-
790
- def kaiming_init(m):
791
- if isinstance(m, (nn.Linear, nn.Conv2d)):
792
- init.kaiming_normal_(m.weight)
793
- if m.bias is not None:
794
- m.bias.data.fill_(0)
795
- elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
796
- m.weight.data.fill_(1)
797
- if m.bias is not None:
798
- m.bias.data.fill_(0)
799
-
800
-
801
- def normal_init(m):
802
- if isinstance(m, (nn.Linear, nn.Conv2d)):
803
- init.normal_(m.weight, 0, 0.02)
804
- if m.bias is not None:
805
- m.bias.data.fill_(0)
806
- elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
807
- m.weight.data.fill_(1)
808
- if m.bias is not None:
809
- m.bias.data.fill_(0)
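
The block above closes out the StyleGAN2-style model definitions (Generator, StyledConv, ToRGB, StyleDiscriminator, StyleEncoder and the init helpers). A minimal usage sketch, assuming the deleted file is importable and the CUDA/CPU op modules it depends on can be loaded; the import path is a hypothetical illustration, not something the diff states:

import torch
# from model import Generator, StyleDiscriminator   # hypothetical import path

g = Generator(size=256, style_dim=512, n_mlp=8, channel_multiplier=2)
d = StyleDiscriminator(size=256, channel_multiplier=2)

z = torch.randn(4, 512)            # one latent code per sample
img, _ = g([z])                    # styles are passed as a list of z codes
score, feats = d(img)              # this variant also returns per-block features
print(img.shape, score.shape, len(feats))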
torch_utils/op_edit/__init__.py DELETED
@@ -1,4 +0,0 @@
1
- # Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- from .fused_act import FusedLeakyReLU, fused_leaky_relu
4
- from .upfirdn2d import upfirdn2d
 
 
 
 
 
torch_utils/op_edit/fused_act.py DELETED
@@ -1,99 +0,0 @@
1
- # Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- import os
4
-
5
- import torch
6
- from torch import nn
7
- from torch.nn import functional as F
8
- from torch.autograd import Function
9
- from torch.utils.cpp_extension import load
10
-
11
-
12
- module_path = os.path.dirname(__file__)
13
- fused = load(
14
- "fused",
15
- sources=[
16
- os.path.join(module_path, "fused_bias_act.cpp"),
17
- os.path.join(module_path, "fused_bias_act_kernel.cu"),
18
- ],
19
- )
20
-
21
-
22
- class FusedLeakyReLUFunctionBackward(Function):
23
- @staticmethod
24
- def forward(ctx, grad_output, out, negative_slope, scale):
25
- ctx.save_for_backward(out)
26
- ctx.negative_slope = negative_slope
27
- ctx.scale = scale
28
-
29
- empty = grad_output.new_empty(0)
30
-
31
- grad_input = fused.fused_bias_act(
32
- grad_output, empty, out, 3, 1, negative_slope, scale
33
- )
34
-
35
- dim = [0]
36
-
37
- if grad_input.ndim > 2:
38
- dim += list(range(2, grad_input.ndim))
39
-
40
- grad_bias = grad_input.sum(dim).detach()
41
-
42
- return grad_input, grad_bias
43
-
44
- @staticmethod
45
- def backward(ctx, gradgrad_input, gradgrad_bias):
46
- (out,) = ctx.saved_tensors
47
- gradgrad_out = fused.fused_bias_act(
48
- gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
49
- )
50
-
51
- return gradgrad_out, None, None, None
52
-
53
-
54
- class FusedLeakyReLUFunction(Function):
55
- @staticmethod
56
- def forward(ctx, input, bias, negative_slope, scale):
57
- empty = input.new_empty(0)
58
- out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
59
- ctx.save_for_backward(out)
60
- ctx.negative_slope = negative_slope
61
- ctx.scale = scale
62
-
63
- return out
64
-
65
- @staticmethod
66
- def backward(ctx, grad_output):
67
- (out,) = ctx.saved_tensors
68
-
69
- grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
70
- grad_output, out, ctx.negative_slope, ctx.scale
71
- )
72
-
73
- return grad_input, grad_bias, None, None
74
-
75
-
76
- class FusedLeakyReLU(nn.Module):
77
- def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
78
- super().__init__()
79
-
80
- self.bias = nn.Parameter(torch.zeros(channel))
81
- self.negative_slope = negative_slope
82
- self.scale = scale
83
-
84
- def forward(self, input):
85
- return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
86
-
87
-
88
- def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
89
- if input.device.type == "cpu":
90
- rest_dim = [1] * (input.ndim - bias.ndim - 1)
91
- return (
92
- F.leaky_relu(
93
- input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
94
- )
95
- * scale
96
- )
97
-
98
- else:
99
- return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
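
fused_leaky_relu above dispatches to the CUDA extension on GPU and to a plain PyTorch expression on CPU. A small reference sketch of that CPU path, useful for sanity checks (note the original CPU branch hard-codes the 0.2 slope inside F.leaky_relu):

import torch
import torch.nn.functional as F

def fused_leaky_relu_ref(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    # broadcast the per-channel bias over dim 1, apply leaky ReLU, then the gain
    rest = [1] * (x.ndim - 2)
    return F.leaky_relu(x + bias.view(1, -1, *rest), negative_slope) * scale

x = torch.randn(2, 8, 4, 4)
y = fused_leaky_relu_ref(x, torch.zeros(8))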
torch_utils/op_edit/fused_bias_act.cpp DELETED
@@ -1,23 +0,0 @@
1
- // Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- #include <torch/extension.h>
4
-
5
-
6
- torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
7
- int act, int grad, float alpha, float scale);
8
-
9
- #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
10
- #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
11
- #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
12
-
13
- torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
14
- int act, int grad, float alpha, float scale) {
15
- CHECK_CUDA(input);
16
- CHECK_CUDA(bias);
17
-
18
- return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
19
- }
20
-
21
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22
- m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
23
- }
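
The binding above exposes a single entry point, fused_bias_act(input, bias, refer, act, grad, alpha, scale), which fused_act.py builds via torch.utils.cpp_extension.load. A direct-call sketch (requires a CUDA device and a successful JIT build; `fused` stands for the module object returned by load(), as in fused_act.py):

import torch
# fused = torch.utils.cpp_extension.load("fused", sources=[...])  # built as in fused_act.py

x = torch.randn(2, 8, 4, 4, device="cuda")
b = torch.zeros(8, device="cuda")
empty = x.new_empty(0)                      # unused `refer` tensor for the forward pass
y = fused.fused_bias_act(x, b, empty, 3, 0, 0.2, 2 ** 0.5)  # act=3 (lrelu), grad=0 (forward)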
torch_utils/op_edit/fused_bias_act_kernel.cu DELETED
@@ -1,101 +0,0 @@
1
- // Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
- //
5
- // This work is made available under the Nvidia Source Code License-NC.
6
- // To view a copy of this license, visit
7
- // https://nvlabs.github.io/stylegan2/license.html
8
-
9
- #include <torch/types.h>
10
-
11
- #include <ATen/ATen.h>
12
- #include <ATen/AccumulateType.h>
13
- #include <ATen/cuda/CUDAContext.h>
14
- #include <ATen/cuda/CUDAApplyUtils.cuh>
15
-
16
- #include <cuda.h>
17
- #include <cuda_runtime.h>
18
-
19
-
20
- template <typename scalar_t>
21
- static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
22
- int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
23
- int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
24
-
25
- scalar_t zero = 0.0;
26
-
27
- for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
28
- scalar_t x = p_x[xi];
29
-
30
- if (use_bias) {
31
- x += p_b[(xi / step_b) % size_b];
32
- }
33
-
34
- scalar_t ref = use_ref ? p_ref[xi] : zero;
35
-
36
- scalar_t y;
37
-
38
- switch (act * 10 + grad) {
39
- default:
40
- case 10: y = x; break;
41
- case 11: y = x; break;
42
- case 12: y = 0.0; break;
43
-
44
- case 30: y = (x > 0.0) ? x : x * alpha; break;
45
- case 31: y = (ref > 0.0) ? x : x * alpha; break;
46
- case 32: y = 0.0; break;
47
- }
48
-
49
- out[xi] = y * scale;
50
- }
51
- }
52
-
53
-
54
- torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
55
- int act, int grad, float alpha, float scale) {
56
- int curDevice = -1;
57
- cudaGetDevice(&curDevice);
58
- cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
59
-
60
- auto x = input.contiguous();
61
- auto b = bias.contiguous();
62
- auto ref = refer.contiguous();
63
-
64
- int use_bias = b.numel() ? 1 : 0;
65
- int use_ref = ref.numel() ? 1 : 0;
66
-
67
- int size_x = x.numel();
68
- int size_b = b.numel();
69
- int step_b = 1;
70
-
71
- for (int i = 1 + 1; i < x.dim(); i++) {
72
- step_b *= x.size(i);
73
- }
74
-
75
- int loop_x = 4;
76
- int block_size = 4 * 32;
77
- int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
78
-
79
- auto y = torch::empty_like(x);
80
-
81
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
82
- fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
83
- y.data_ptr<scalar_t>(),
84
- x.data_ptr<scalar_t>(),
85
- b.data_ptr<scalar_t>(),
86
- ref.data_ptr<scalar_t>(),
87
- act,
88
- grad,
89
- alpha,
90
- scale,
91
- loop_x,
92
- size_x,
93
- step_b,
94
- size_b,
95
- use_bias,
96
- use_ref
97
- );
98
- });
99
-
100
- return y;
101
- }
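
The switch on act * 10 + grad in the kernel covers linear (act 1) and leaky ReLU (act 3) for the forward pass, first derivative, and second derivative. An elementwise reference for the lrelu case, written as a plain-PyTorch sketch:

import torch

def fused_bias_act_lrelu_ref(x, b, ref=None, grad=0, alpha=0.2, scale=2 ** 0.5):
    if b.numel():                                  # use_bias: broadcast bias over dim 1
        x = x + b.view(1, -1, *([1] * (x.ndim - 2)))
    if grad == 0:                                  # case 30: forward
        y = torch.where(x > 0, x, x * alpha)
    elif grad == 1:                                # case 31: gradient, gated by the saved output `ref`
        y = torch.where(ref > 0, x, x * alpha)
    else:                                          # case 32: second derivative of lrelu is zero
        y = torch.zeros_like(x)
    return y * scale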
torch_utils/op_edit/upfirdn2d.cpp DELETED
@@ -1,25 +0,0 @@
1
- // Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- #include <torch/extension.h>
4
-
5
-
6
- torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
7
- int up_x, int up_y, int down_x, int down_y,
8
- int pad_x0, int pad_x1, int pad_y0, int pad_y1);
9
-
10
- #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11
- #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
12
- #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
13
-
14
- torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
15
- int up_x, int up_y, int down_x, int down_y,
16
- int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
17
- CHECK_CUDA(input);
18
- CHECK_CUDA(kernel);
19
-
20
- return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
21
- }
22
-
23
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
24
- m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
25
torch_utils/op_edit/upfirdn2d.py DELETED
@@ -1,202 +0,0 @@
1
- # Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- import os
4
-
5
- import torch
6
- from torch.nn import functional as F
7
- from torch.autograd import Function
8
- from torch.utils.cpp_extension import load
9
-
10
-
11
- module_path = os.path.dirname(__file__)
12
- upfirdn2d_op = load(
13
- "upfirdn2d",
14
- sources=[
15
- os.path.join(module_path, "upfirdn2d.cpp"),
16
- os.path.join(module_path, "upfirdn2d_kernel.cu"),
17
- ],
18
- )
19
-
20
-
21
- class UpFirDn2dBackward(Function):
22
- @staticmethod
23
- def forward(
24
- ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
25
- ):
26
-
27
- up_x, up_y = up
28
- down_x, down_y = down
29
- g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
30
-
31
- grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
32
-
33
- grad_input = upfirdn2d_op.upfirdn2d(
34
- grad_output,
35
- grad_kernel,
36
- down_x,
37
- down_y,
38
- up_x,
39
- up_y,
40
- g_pad_x0,
41
- g_pad_x1,
42
- g_pad_y0,
43
- g_pad_y1,
44
- )
45
- grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
46
-
47
- ctx.save_for_backward(kernel)
48
-
49
- pad_x0, pad_x1, pad_y0, pad_y1 = pad
50
-
51
- ctx.up_x = up_x
52
- ctx.up_y = up_y
53
- ctx.down_x = down_x
54
- ctx.down_y = down_y
55
- ctx.pad_x0 = pad_x0
56
- ctx.pad_x1 = pad_x1
57
- ctx.pad_y0 = pad_y0
58
- ctx.pad_y1 = pad_y1
59
- ctx.in_size = in_size
60
- ctx.out_size = out_size
61
-
62
- return grad_input
63
-
64
- @staticmethod
65
- def backward(ctx, gradgrad_input):
66
- (kernel,) = ctx.saved_tensors
67
-
68
- gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
69
-
70
- gradgrad_out = upfirdn2d_op.upfirdn2d(
71
- gradgrad_input,
72
- kernel,
73
- ctx.up_x,
74
- ctx.up_y,
75
- ctx.down_x,
76
- ctx.down_y,
77
- ctx.pad_x0,
78
- ctx.pad_x1,
79
- ctx.pad_y0,
80
- ctx.pad_y1,
81
- )
82
- # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
83
- gradgrad_out = gradgrad_out.view(
84
- ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
85
- )
86
-
87
- return gradgrad_out, None, None, None, None, None, None, None, None
88
-
89
-
90
- class UpFirDn2d(Function):
91
- @staticmethod
92
- def forward(ctx, input, kernel, up, down, pad):
93
- up_x, up_y = up
94
- down_x, down_y = down
95
- pad_x0, pad_x1, pad_y0, pad_y1 = pad
96
-
97
- kernel_h, kernel_w = kernel.shape
98
- batch, channel, in_h, in_w = input.shape
99
- ctx.in_size = input.shape
100
-
101
- input = input.reshape(-1, in_h, in_w, 1)
102
-
103
- ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
104
-
105
- out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
106
- out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
107
- ctx.out_size = (out_h, out_w)
108
-
109
- ctx.up = (up_x, up_y)
110
- ctx.down = (down_x, down_y)
111
- ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
112
-
113
- g_pad_x0 = kernel_w - pad_x0 - 1
114
- g_pad_y0 = kernel_h - pad_y0 - 1
115
- g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
116
- g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
117
-
118
- ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
119
-
120
- out = upfirdn2d_op.upfirdn2d(
121
- input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
122
- )
123
- # out = out.view(major, out_h, out_w, minor)
124
- out = out.view(-1, channel, out_h, out_w)
125
-
126
- return out
127
-
128
- @staticmethod
129
- def backward(ctx, grad_output):
130
- kernel, grad_kernel = ctx.saved_tensors
131
-
132
- grad_input = UpFirDn2dBackward.apply(
133
- grad_output,
134
- kernel,
135
- grad_kernel,
136
- ctx.up,
137
- ctx.down,
138
- ctx.pad,
139
- ctx.g_pad,
140
- ctx.in_size,
141
- ctx.out_size,
142
- )
143
-
144
- return grad_input, None, None, None, None
145
-
146
-
147
- def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
148
- if input.device.type == "cpu":
149
- out = upfirdn2d_native(
150
- input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
151
- )
152
-
153
- else:
154
- out = UpFirDn2d.apply(
155
- input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
156
- )
157
-
158
- return out
159
-
160
-
161
- def upfirdn2d_native(
162
- input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
163
- ):
164
- _, channel, in_h, in_w = input.shape
165
- input = input.reshape(-1, in_h, in_w, 1)
166
-
167
- _, in_h, in_w, minor = input.shape
168
- kernel_h, kernel_w = kernel.shape
169
-
170
- out = input.view(-1, in_h, 1, in_w, 1, minor)
171
- out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
172
- out = out.view(-1, in_h * up_y, in_w * up_x, minor)
173
-
174
- out = F.pad(
175
- out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
176
- )
177
- out = out[
178
- :,
179
- max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
180
- max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
181
- :,
182
- ]
183
-
184
- out = out.permute(0, 3, 1, 2)
185
- out = out.reshape(
186
- [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
187
- )
188
- w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
189
- out = F.conv2d(out, w)
190
- out = out.reshape(
191
- -1,
192
- minor,
193
- in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
194
- in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
195
- )
196
- out = out.permute(0, 2, 3, 1)
197
- out = out[:, ::down_y, ::down_x, :]
198
-
199
- out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
200
- out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
201
-
202
- return out.view(-1, channel, out_h, out_w)
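
In the StyleGAN2 layers, upfirdn2d is driven through Blur/Upsample-style wrappers with a separable [1, 3, 3, 1] kernel. A usage sketch (the function is re-exported by op_edit/__init__.py; the CPU branch falls through to upfirdn2d_native, although importing the module still JIT-builds the CUDA extension; the padding values below are the usual StyleGAN2 choices, stated here as an assumption):

import torch
from torch_utils.op_edit import upfirdn2d   # re-exported by op_edit/__init__.py

k1d = torch.tensor([1.0, 3.0, 3.0, 1.0])
kernel = k1d[None, :] * k1d[:, None]
kernel = kernel / kernel.sum()

x = torch.randn(1, 3, 32, 32)
blurred = upfirdn2d(x, kernel, up=1, down=1, pad=(2, 1))        # same-size FIR blur
upsampled = upfirdn2d(x, kernel * 4, up=2, down=1, pad=(2, 1))  # 2x upsample + smoothing -> 64x64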
torch_utils/op_edit/upfirdn2d_kernel.cu DELETED
@@ -1,371 +0,0 @@
1
- // Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
- //
5
- // This work is made available under the Nvidia Source Code License-NC.
6
- // To view a copy of this license, visit
7
- // https://nvlabs.github.io/stylegan2/license.html
8
-
9
- #include <torch/types.h>
10
-
11
- #include <ATen/ATen.h>
12
- #include <ATen/AccumulateType.h>
13
- #include <ATen/cuda/CUDAApplyUtils.cuh>
14
- #include <ATen/cuda/CUDAContext.h>
15
-
16
- #include <cuda.h>
17
- #include <cuda_runtime.h>
18
-
19
- static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
20
- int c = a / b;
21
-
22
- if (c * b > a) {
23
- c--;
24
- }
25
-
26
- return c;
27
- }
28
-
29
- struct UpFirDn2DKernelParams {
30
- int up_x;
31
- int up_y;
32
- int down_x;
33
- int down_y;
34
- int pad_x0;
35
- int pad_x1;
36
- int pad_y0;
37
- int pad_y1;
38
-
39
- int major_dim;
40
- int in_h;
41
- int in_w;
42
- int minor_dim;
43
- int kernel_h;
44
- int kernel_w;
45
- int out_h;
46
- int out_w;
47
- int loop_major;
48
- int loop_x;
49
- };
50
-
51
- template <typename scalar_t>
52
- __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
53
- const scalar_t *kernel,
54
- const UpFirDn2DKernelParams p) {
55
- int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
56
- int out_y = minor_idx / p.minor_dim;
57
- minor_idx -= out_y * p.minor_dim;
58
- int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
59
- int major_idx_base = blockIdx.z * p.loop_major;
60
-
61
- if (out_x_base >= p.out_w || out_y >= p.out_h ||
62
- major_idx_base >= p.major_dim) {
63
- return;
64
- }
65
-
66
- int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
67
- int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
68
- int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
69
- int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
70
-
71
- for (int loop_major = 0, major_idx = major_idx_base;
72
- loop_major < p.loop_major && major_idx < p.major_dim;
73
- loop_major++, major_idx++) {
74
- for (int loop_x = 0, out_x = out_x_base;
75
- loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
76
- int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
77
- int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
78
- int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
79
- int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
80
-
81
- const scalar_t *x_p =
82
- &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
83
- minor_idx];
84
- const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
85
- int x_px = p.minor_dim;
86
- int k_px = -p.up_x;
87
- int x_py = p.in_w * p.minor_dim;
88
- int k_py = -p.up_y * p.kernel_w;
89
-
90
- scalar_t v = 0.0f;
91
-
92
- for (int y = 0; y < h; y++) {
93
- for (int x = 0; x < w; x++) {
94
- v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
95
- x_p += x_px;
96
- k_p += k_px;
97
- }
98
-
99
- x_p += x_py - w * x_px;
100
- k_p += k_py - w * k_px;
101
- }
102
-
103
- out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
104
- minor_idx] = v;
105
- }
106
- }
107
- }
108
-
109
- template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
110
- int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
111
- __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
112
- const scalar_t *kernel,
113
- const UpFirDn2DKernelParams p) {
114
- const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
115
- const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
116
-
117
- __shared__ volatile float sk[kernel_h][kernel_w];
118
- __shared__ volatile float sx[tile_in_h][tile_in_w];
119
-
120
- int minor_idx = blockIdx.x;
121
- int tile_out_y = minor_idx / p.minor_dim;
122
- minor_idx -= tile_out_y * p.minor_dim;
123
- tile_out_y *= tile_out_h;
124
- int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
125
- int major_idx_base = blockIdx.z * p.loop_major;
126
-
127
- if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
128
- major_idx_base >= p.major_dim) {
129
- return;
130
- }
131
-
132
- for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
133
- tap_idx += blockDim.x) {
134
- int ky = tap_idx / kernel_w;
135
- int kx = tap_idx - ky * kernel_w;
136
- scalar_t v = 0.0;
137
-
138
- if (kx < p.kernel_w & ky < p.kernel_h) {
139
- v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
140
- }
141
-
142
- sk[ky][kx] = v;
143
- }
144
-
145
- for (int loop_major = 0, major_idx = major_idx_base;
146
- loop_major < p.loop_major & major_idx < p.major_dim;
147
- loop_major++, major_idx++) {
148
- for (int loop_x = 0, tile_out_x = tile_out_x_base;
149
- loop_x < p.loop_x & tile_out_x < p.out_w;
150
- loop_x++, tile_out_x += tile_out_w) {
151
- int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
152
- int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
153
- int tile_in_x = floor_div(tile_mid_x, up_x);
154
- int tile_in_y = floor_div(tile_mid_y, up_y);
155
-
156
- __syncthreads();
157
-
158
- for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
159
- in_idx += blockDim.x) {
160
- int rel_in_y = in_idx / tile_in_w;
161
- int rel_in_x = in_idx - rel_in_y * tile_in_w;
162
- int in_x = rel_in_x + tile_in_x;
163
- int in_y = rel_in_y + tile_in_y;
164
-
165
- scalar_t v = 0.0;
166
-
167
- if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
168
- v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
169
- p.minor_dim +
170
- minor_idx];
171
- }
172
-
173
- sx[rel_in_y][rel_in_x] = v;
174
- }
175
-
176
- __syncthreads();
177
- for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
178
- out_idx += blockDim.x) {
179
- int rel_out_y = out_idx / tile_out_w;
180
- int rel_out_x = out_idx - rel_out_y * tile_out_w;
181
- int out_x = rel_out_x + tile_out_x;
182
- int out_y = rel_out_y + tile_out_y;
183
-
184
- int mid_x = tile_mid_x + rel_out_x * down_x;
185
- int mid_y = tile_mid_y + rel_out_y * down_y;
186
- int in_x = floor_div(mid_x, up_x);
187
- int in_y = floor_div(mid_y, up_y);
188
- int rel_in_x = in_x - tile_in_x;
189
- int rel_in_y = in_y - tile_in_y;
190
- int kernel_x = (in_x + 1) * up_x - mid_x - 1;
191
- int kernel_y = (in_y + 1) * up_y - mid_y - 1;
192
-
193
- scalar_t v = 0.0;
194
-
195
- #pragma unroll
196
- for (int y = 0; y < kernel_h / up_y; y++)
197
- #pragma unroll
198
- for (int x = 0; x < kernel_w / up_x; x++)
199
- v += sx[rel_in_y + y][rel_in_x + x] *
200
- sk[kernel_y + y * up_y][kernel_x + x * up_x];
201
-
202
- if (out_x < p.out_w & out_y < p.out_h) {
203
- out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
204
- minor_idx] = v;
205
- }
206
- }
207
- }
208
- }
209
- }
210
-
211
- torch::Tensor upfirdn2d_op(const torch::Tensor &input,
212
- const torch::Tensor &kernel, int up_x, int up_y,
213
- int down_x, int down_y, int pad_x0, int pad_x1,
214
- int pad_y0, int pad_y1) {
215
- int curDevice = -1;
216
- cudaGetDevice(&curDevice);
217
- cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
218
-
219
- UpFirDn2DKernelParams p;
220
-
221
- auto x = input.contiguous();
222
- auto k = kernel.contiguous();
223
-
224
- p.major_dim = x.size(0);
225
- p.in_h = x.size(1);
226
- p.in_w = x.size(2);
227
- p.minor_dim = x.size(3);
228
- p.kernel_h = k.size(0);
229
- p.kernel_w = k.size(1);
230
- p.up_x = up_x;
231
- p.up_y = up_y;
232
- p.down_x = down_x;
233
- p.down_y = down_y;
234
- p.pad_x0 = pad_x0;
235
- p.pad_x1 = pad_x1;
236
- p.pad_y0 = pad_y0;
237
- p.pad_y1 = pad_y1;
238
-
239
- p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
240
- p.down_y;
241
- p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
242
- p.down_x;
243
-
244
- auto out =
245
- at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
246
-
247
- int mode = -1;
248
-
249
- int tile_out_h = -1;
250
- int tile_out_w = -1;
251
-
252
- if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
253
- p.kernel_h <= 4 && p.kernel_w <= 4) {
254
- mode = 1;
255
- tile_out_h = 16;
256
- tile_out_w = 64;
257
- }
258
-
259
- if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
260
- p.kernel_h <= 3 && p.kernel_w <= 3) {
261
- mode = 2;
262
- tile_out_h = 16;
263
- tile_out_w = 64;
264
- }
265
-
266
- if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
267
- p.kernel_h <= 4 && p.kernel_w <= 4) {
268
- mode = 3;
269
- tile_out_h = 16;
270
- tile_out_w = 64;
271
- }
272
-
273
- if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
274
- p.kernel_h <= 2 && p.kernel_w <= 2) {
275
- mode = 4;
276
- tile_out_h = 16;
277
- tile_out_w = 64;
278
- }
279
-
280
- if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
281
- p.kernel_h <= 4 && p.kernel_w <= 4) {
282
- mode = 5;
283
- tile_out_h = 8;
284
- tile_out_w = 32;
285
- }
286
-
287
- if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
288
- p.kernel_h <= 2 && p.kernel_w <= 2) {
289
- mode = 6;
290
- tile_out_h = 8;
291
- tile_out_w = 32;
292
- }
293
-
294
- dim3 block_size;
295
- dim3 grid_size;
296
-
297
- if (tile_out_h > 0 && tile_out_w > 0) {
298
- p.loop_major = (p.major_dim - 1) / 16384 + 1;
299
- p.loop_x = 1;
300
- block_size = dim3(32 * 8, 1, 1);
301
- grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
302
- (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
303
- (p.major_dim - 1) / p.loop_major + 1);
304
- } else {
305
- p.loop_major = (p.major_dim - 1) / 16384 + 1;
306
- p.loop_x = 4;
307
- block_size = dim3(4, 32, 1);
308
- grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
309
- (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
310
- (p.major_dim - 1) / p.loop_major + 1);
311
- }
312
-
313
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
314
- switch (mode) {
315
- case 1:
316
- upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
317
- <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
318
- x.data_ptr<scalar_t>(),
319
- k.data_ptr<scalar_t>(), p);
320
-
321
- break;
322
-
323
- case 2:
324
- upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
325
- <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
326
- x.data_ptr<scalar_t>(),
327
- k.data_ptr<scalar_t>(), p);
328
-
329
- break;
330
-
331
- case 3:
332
- upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
333
- <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
334
- x.data_ptr<scalar_t>(),
335
- k.data_ptr<scalar_t>(), p);
336
-
337
- break;
338
-
339
- case 4:
340
- upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
341
- <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
342
- x.data_ptr<scalar_t>(),
343
- k.data_ptr<scalar_t>(), p);
344
-
345
- break;
346
-
347
- case 5:
348
- upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
349
- <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
350
- x.data_ptr<scalar_t>(),
351
- k.data_ptr<scalar_t>(), p);
352
-
353
- break;
354
-
355
- case 6:
356
- upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
357
- <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
358
- x.data_ptr<scalar_t>(),
359
- k.data_ptr<scalar_t>(), p);
360
-
361
- break;
362
-
363
- default:
364
- upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
365
- out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
366
- k.data_ptr<scalar_t>(), p);
367
- }
368
- });
369
-
370
- return out;
371
- }
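
The output resolution computed by upfirdn2d_op above can be restated as a small helper, handy when reasoning about the pad values used by the Python wrappers (a sketch, not part of the original sources):

def upfirdn2d_out_size(in_h, in_w, kernel_h, kernel_w, up=1, down=1, pad=(0, 0, 0, 0)):
    # mirrors: out = (in * up + pad0 + pad1 - kernel + down) / down (integer division)
    pad_x0, pad_x1, pad_y0, pad_y1 = pad
    out_h = (in_h * up + pad_y0 + pad_y1 - kernel_h + down) // down
    out_w = (in_w * up + pad_x0 + pad_x1 - kernel_w + down) // down
    return out_h, out_w

assert upfirdn2d_out_size(32, 32, 4, 4, up=2, down=1, pad=(2, 1, 2, 1)) == (64, 64)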
torch_utils/ops/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- # Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- #empty
 
 
 
 
torch_utils/ops/bias_act.cpp DELETED
@@ -1,101 +0,0 @@
1
- // Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
- //
5
- // NVIDIA CORPORATION and its licensors retain all intellectual property
6
- // and proprietary rights in and to this software, related documentation
7
- // and any modifications thereto. Any use, reproduction, disclosure or
8
- // distribution of this software and related documentation without an express
9
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
10
-
11
- #include <torch/extension.h>
12
- #include <ATen/cuda/CUDAContext.h>
13
- #include <c10/cuda/CUDAGuard.h>
14
- #include "bias_act.h"
15
-
16
- //------------------------------------------------------------------------
17
-
18
- static bool has_same_layout(torch::Tensor x, torch::Tensor y)
19
- {
20
- if (x.dim() != y.dim())
21
- return false;
22
- for (int64_t i = 0; i < x.dim(); i++)
23
- {
24
- if (x.size(i) != y.size(i))
25
- return false;
26
- if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
27
- return false;
28
- }
29
- return true;
30
- }
31
-
32
- //------------------------------------------------------------------------
33
-
34
- static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
35
- {
36
- // Validate arguments.
37
- TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
38
- TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
39
- TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
40
- TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
41
- TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
42
- TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
43
- TORCH_CHECK(b.dim() == 1, "b must have rank 1");
44
- TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
45
- TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
46
- TORCH_CHECK(grad >= 0, "grad must be non-negative");
47
-
48
- // Validate layout.
49
- TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
50
- TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
51
- TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
52
- TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
53
- TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
54
-
55
- // Create output tensor.
56
- const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
57
- torch::Tensor y = torch::empty_like(x);
58
- TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
59
-
60
- // Initialize CUDA kernel parameters.
61
- bias_act_kernel_params p;
62
- p.x = x.data_ptr();
63
- p.b = (b.numel()) ? b.data_ptr() : NULL;
64
- p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
65
- p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
66
- p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
67
- p.y = y.data_ptr();
68
- p.grad = grad;
69
- p.act = act;
70
- p.alpha = alpha;
71
- p.gain = gain;
72
- p.clamp = clamp;
73
- p.sizeX = (int)x.numel();
74
- p.sizeB = (int)b.numel();
75
- p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
76
-
77
- // Choose CUDA kernel.
78
- void* kernel;
79
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
80
- {
81
- kernel = choose_bias_act_kernel<scalar_t>(p);
82
- });
83
- TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
84
-
85
- // Launch CUDA kernel.
86
- p.loopX = 4;
87
- int blockSize = 4 * 32;
88
- int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
89
- void* args[] = {&p};
90
- AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
91
- return y;
92
- }
93
-
94
- //------------------------------------------------------------------------
95
-
96
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
97
- {
98
- m.def("bias_act", &bias_act);
99
- }
100
-
101
- //------------------------------------------------------------------------
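
The stepB/sizeB bookkeeping above encodes how the kernel looks up the bias for a flat element index: b[(i / stepB) % sizeB] with stepB = x.stride(dim) and sizeB = b.numel() (checked to equal x.size(dim)), which is simply broadcasting b over `dim`. A quick equivalence check in PyTorch, as a sketch:

import torch

x = torch.randn(2, 8, 4, 4)
b = torch.randn(8)
dim = 1
idx = torch.arange(x.numel())
biased = x.reshape(-1) + b[(idx // x.stride(dim)) % x.size(dim)]
assert torch.allclose(biased.view_as(x), x + b.view(1, -1, 1, 1))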
torch_utils/ops/bias_act.cu DELETED
@@ -1,175 +0,0 @@
1
- // Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
- //
5
- // NVIDIA CORPORATION and its licensors retain all intellectual property
6
- // and proprietary rights in and to this software, related documentation
7
- // and any modifications thereto. Any use, reproduction, disclosure or
8
- // distribution of this software and related documentation without an express
9
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
10
-
11
- #include <c10/util/Half.h>
12
- #include "bias_act.h"
13
-
14
- //------------------------------------------------------------------------
15
- // Helpers.
16
-
17
- template <class T> struct InternalType;
18
- template <> struct InternalType<double> { typedef double scalar_t; };
19
- template <> struct InternalType<float> { typedef float scalar_t; };
20
- template <> struct InternalType<c10::Half> { typedef float scalar_t; };
21
-
22
- //------------------------------------------------------------------------
23
- // CUDA kernel.
24
-
25
- template <class T, int A>
26
- __global__ void bias_act_kernel(bias_act_kernel_params p)
27
- {
28
- typedef typename InternalType<T>::scalar_t scalar_t;
29
- int G = p.grad;
30
- scalar_t alpha = (scalar_t)p.alpha;
31
- scalar_t gain = (scalar_t)p.gain;
32
- scalar_t clamp = (scalar_t)p.clamp;
33
- scalar_t one = (scalar_t)1;
34
- scalar_t two = (scalar_t)2;
35
- scalar_t expRange = (scalar_t)80;
36
- scalar_t halfExpRange = (scalar_t)40;
37
- scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
38
- scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
39
-
40
- // Loop over elements.
41
- int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
42
- for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
43
- {
44
- // Load.
45
- scalar_t x = (scalar_t)((const T*)p.x)[xi];
46
- scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
47
- scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
48
- scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
49
- scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
50
- scalar_t yy = (gain != 0) ? yref / gain : 0;
51
- scalar_t y = 0;
52
-
53
- // Apply bias.
54
- ((G == 0) ? x : xref) += b;
55
-
56
- // linear
57
- if (A == 1)
58
- {
59
- if (G == 0) y = x;
60
- if (G == 1) y = x;
61
- }
62
-
63
- // relu
64
- if (A == 2)
65
- {
66
- if (G == 0) y = (x > 0) ? x : 0;
67
- if (G == 1) y = (yy > 0) ? x : 0;
68
- }
69
-
70
- // lrelu
71
- if (A == 3)
72
- {
73
- if (G == 0) y = (x > 0) ? x : x * alpha;
74
- if (G == 1) y = (yy > 0) ? x : x * alpha;
75
- }
76
-
77
- // tanh
78
- if (A == 4)
79
- {
80
- if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
81
- if (G == 1) y = x * (one - yy * yy);
82
- if (G == 2) y = x * (one - yy * yy) * (-two * yy);
83
- }
84
-
85
- // sigmoid
86
- if (A == 5)
87
- {
88
- if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
89
- if (G == 1) y = x * yy * (one - yy);
90
- if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
91
- }
92
-
93
- // elu
94
- if (A == 6)
95
- {
96
- if (G == 0) y = (x >= 0) ? x : exp(x) - one;
97
- if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
98
- if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
99
- }
100
-
101
- // selu
102
- if (A == 7)
103
- {
104
- if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
105
- if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
106
- if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
107
- }
108
-
109
- // softplus
110
- if (A == 8)
111
- {
112
- if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
113
- if (G == 1) y = x * (one - exp(-yy));
114
- if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
115
- }
116
-
117
- // swish
118
- if (A == 9)
119
- {
120
- if (G == 0)
121
- y = (x < -expRange) ? 0 : x / (exp(-x) + one);
122
- else
123
- {
124
- scalar_t c = exp(xref);
125
- scalar_t d = c + one;
126
- if (G == 1)
127
- y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
128
- else
129
- y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
130
- yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
131
- }
132
- }
133
-
134
- // Apply gain.
135
- y *= gain * dy;
136
-
137
- // Clamp.
138
- if (clamp >= 0)
139
- {
140
- if (G == 0)
141
- y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
142
- else
143
- y = (yref > -clamp & yref < clamp) ? y : 0;
144
- }
145
-
146
- // Store.
147
- ((T*)p.y)[xi] = (T)y;
148
- }
149
- }
150
-
151
- //------------------------------------------------------------------------
152
- // CUDA kernel selection.
153
-
154
- template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
155
- {
156
- if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
157
- if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
158
- if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
159
- if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
160
- if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
161
- if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
162
- if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
163
- if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
164
- if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
165
- return NULL;
166
- }
167
-
168
- //------------------------------------------------------------------------
169
- // Template specializations.
170
-
171
- template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
172
- template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
173
- template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
174
-
175
- //------------------------------------------------------------------------
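
For the lrelu branch (A == 3) the kernel composes bias, activation, gain and clamp in a single pass. A forward-pass (G == 0) reference in plain PyTorch, as a sketch for checking the CUDA path (bias over dim 1 is an assumption consistent with how bias_act.py calls it):

import torch

def bias_act_lrelu_forward_ref(x, b, alpha=0.2, gain=2 ** 0.5, clamp=-1):
    x = x + b.view(1, -1, *([1] * (x.ndim - 2)))   # apply bias over dim 1
    y = torch.where(x > 0, x, x * alpha) * gain    # lrelu + gain
    if clamp >= 0:                                  # clamp < 0 disables clamping
        y = y.clamp(-clamp, clamp)
    return y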
torch_utils/ops/bias_act.h DELETED
@@ -1,40 +0,0 @@
1
- // Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
- //
5
- // NVIDIA CORPORATION and its licensors retain all intellectual property
6
- // and proprietary rights in and to this software, related documentation
7
- // and any modifications thereto. Any use, reproduction, disclosure or
8
- // distribution of this software and related documentation without an express
9
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
10
-
11
- //------------------------------------------------------------------------
12
- // CUDA kernel parameters.
13
-
14
- struct bias_act_kernel_params
15
- {
16
- const void* x; // [sizeX]
17
- const void* b; // [sizeB] or NULL
18
- const void* xref; // [sizeX] or NULL
19
- const void* yref; // [sizeX] or NULL
20
- const void* dy; // [sizeX] or NULL
21
- void* y; // [sizeX]
22
-
23
- int grad;
24
- int act;
25
- float alpha;
26
- float gain;
27
- float clamp;
28
-
29
- int sizeX;
30
- int sizeB;
31
- int stepB;
32
- int loopX;
33
- };
34
-
35
- //------------------------------------------------------------------------
36
- // CUDA kernel selection.
37
-
38
- template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
39
-
40
- //------------------------------------------------------------------------
torch_utils/ops/bias_act.py DELETED
@@ -1,214 +0,0 @@
1
- # Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # NVIDIA CORPORATION and its licensors retain all intellectual property
6
- # and proprietary rights in and to this software, related documentation
7
- # and any modifications thereto. Any use, reproduction, disclosure or
8
- # distribution of this software and related documentation without an express
9
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
-
11
- """Custom PyTorch ops for efficient bias and activation."""
12
-
13
- import os
14
- import warnings
15
- import numpy as np
16
- import torch
17
- import dnnlib
18
- import traceback
19
-
20
- from .. import custom_ops
21
- from .. import misc
22
-
23
- #----------------------------------------------------------------------------
24
-
25
- activation_funcs = {
26
- 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
27
- 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
28
- 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
29
- 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
30
- 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
31
- 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
32
- 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
33
- 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
34
- 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
35
- }
36
-
37
- #----------------------------------------------------------------------------
38
-
39
- _inited = False
40
- _plugin = None
41
- _null_tensor = torch.empty([0])
42
-
43
- def _init():
44
- global _inited, _plugin
45
- if not _inited:
46
- _inited = True
47
- sources = ['bias_act.cpp', 'bias_act.cu']
48
- sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
49
- try:
50
- _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
51
- except:
52
- warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
53
- return _plugin is not None
54
-
55
- #----------------------------------------------------------------------------
56
-
57
- def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
58
- r"""Fused bias and activation function.
59
-
60
- Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
61
- and scales the result by `gain`. Each of the steps is optional. In most cases,
62
- the fused op is considerably more efficient than performing the same calculation
63
- using standard PyTorch ops. It supports first and second order gradients,
64
- but not third order gradients.
65
-
66
- Args:
67
- x: Input activation tensor. Can be of any shape.
68
- b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
69
- as `x`. The shape must be known, and it must match the dimension of `x`
70
- corresponding to `dim`.
71
- dim: The dimension in `x` corresponding to the elements of `b`.
72
- The value of `dim` is ignored if `b` is not specified.
73
- act: Name of the activation function to evaluate, or `"linear"` to disable.
74
- Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
75
- See `activation_funcs` for a full list. `None` is not allowed.
76
- alpha: Shape parameter for the activation function, or `None` to use the default.
77
- gain: Scaling factor for the output tensor, or `None` to use default.
78
- See `activation_funcs` for the default scaling of each activation function.
79
- If unsure, consider specifying 1.
80
- clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
81
- the clamping (default).
82
- impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
83
-
84
- Returns:
85
- Tensor of the same shape and datatype as `x`.
86
- """
87
- assert isinstance(x, torch.Tensor)
88
- assert impl in ['ref', 'cuda']
89
- if impl == 'cuda' and x.device.type == 'cuda' and _init():
90
- return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
91
- return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
92
-
93
- #----------------------------------------------------------------------------
94
-
95
- @misc.profiled_function
96
- def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
97
- """Slow reference implementation of `bias_act()` using standard PyTorch ops.
98
- """
99
- assert isinstance(x, torch.Tensor)
100
- assert clamp is None or clamp >= 0
101
- spec = activation_funcs[act]
102
- alpha = float(alpha if alpha is not None else spec.def_alpha)
103
- gain = float(gain if gain is not None else spec.def_gain)
104
- clamp = float(clamp if clamp is not None else -1)
105
-
106
- # Add bias.
107
- if b is not None:
108
- assert isinstance(b, torch.Tensor) and b.ndim == 1
109
- assert 0 <= dim < x.ndim
110
- assert b.shape[0] == x.shape[dim]
111
- x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
112
-
113
- # Evaluate activation function.
114
- alpha = float(alpha)
115
- x = spec.func(x, alpha=alpha)
116
-
117
- # Scale by gain.
118
- gain = float(gain)
119
- if gain != 1:
120
- x = x * gain
121
-
122
- # Clamp.
123
- if clamp >= 0:
124
- x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
125
- return x
126
-
127
- #----------------------------------------------------------------------------
128
-
129
- _bias_act_cuda_cache = dict()
130
-
131
- def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
132
- """Fast CUDA implementation of `bias_act()` using custom ops.
133
- """
134
- # Parse arguments.
135
- assert clamp is None or clamp >= 0
136
- spec = activation_funcs[act]
137
- alpha = float(alpha if alpha is not None else spec.def_alpha)
138
- gain = float(gain if gain is not None else spec.def_gain)
139
- clamp = float(clamp if clamp is not None else -1)
140
-
141
- # Lookup from cache.
142
- key = (dim, act, alpha, gain, clamp)
143
- if key in _bias_act_cuda_cache:
144
- return _bias_act_cuda_cache[key]
145
-
146
- # Forward op.
147
- class BiasActCuda(torch.autograd.Function):
148
- @staticmethod
149
- def forward(ctx, x, b): # pylint: disable=arguments-differ
150
- ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
151
- x = x.contiguous(memory_format=ctx.memory_format)
152
- b = b.contiguous() if b is not None else _null_tensor
153
- y = x
154
- if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
155
- y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
156
- ctx.save_for_backward(
157
- x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
158
- b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
159
- y if 'y' in spec.ref else _null_tensor)
160
- return y
161
-
162
- @staticmethod
163
- def backward(ctx, dy): # pylint: disable=arguments-differ
164
- dy = dy.contiguous(memory_format=ctx.memory_format)
165
- x, b, y = ctx.saved_tensors
166
- dx = None
167
- db = None
168
-
169
- if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
170
- dx = dy
171
- if act != 'linear' or gain != 1 or clamp >= 0:
172
- dx = BiasActCudaGrad.apply(dy, x, b, y)
173
-
174
- if ctx.needs_input_grad[1]:
175
- db = dx.sum([i for i in range(dx.ndim) if i != dim])
176
-
177
- return dx, db
178
-
179
- # Backward op.
180
- class BiasActCudaGrad(torch.autograd.Function):
181
- @staticmethod
182
- def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
183
- ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
184
- dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
185
- ctx.save_for_backward(
186
- dy if spec.has_2nd_grad else _null_tensor,
187
- x, b, y)
188
- return dx
189
-
190
- @staticmethod
191
- def backward(ctx, d_dx): # pylint: disable=arguments-differ
192
- d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
193
- dy, x, b, y = ctx.saved_tensors
194
- d_dy = None
195
- d_x = None
196
- d_b = None
197
- d_y = None
198
-
199
- if ctx.needs_input_grad[0]:
200
- d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
201
-
202
- if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
203
- d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
204
-
205
- if spec.has_2nd_grad and ctx.needs_input_grad[2]:
206
- d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
207
-
208
- return d_dy, d_x, d_b, d_y
209
-
210
- # Add to cache.
211
- _bias_act_cuda_cache[key] = BiasActCuda
212
- return BiasActCuda
213
-
214
- #----------------------------------------------------------------------------
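
For reference, a minimal usage sketch of the bias_act() entry point deleted above. The import path torch_utils.ops.bias_act is an assumption based on this commit's file layout, and impl='ref' forces the slow PyTorch fallback so the sketch does not need the compiled bias_act_plugin:

import torch
from torch_utils.ops import bias_act  # path assumed from this commit's file layout

x = torch.randn(4, 512, 16, 16)            # activations [N, C, H, W]
b = torch.zeros(512)                       # per-channel bias, must match x.shape[dim]
# Fused bias + leaky ReLU + clamp; impl='ref' skips the custom CUDA kernel.
y = bias_act.bias_act(x, b, dim=1, act='lrelu', gain=1.0, clamp=256, impl='ref')
assert y.shape == x.shape and y.dtype == x.dtype

With impl='cuda' the same call routes through the plugin built by custom_ops.get_plugin() whenever the input lives on a CUDA device and the kernels compile.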
 
torch_utils/ops/conv2d_gradfix.py DELETED
@@ -1,172 +0,0 @@
1
- # Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # NVIDIA CORPORATION and its licensors retain all intellectual property
6
- # and proprietary rights in and to this software, related documentation
7
- # and any modifications thereto. Any use, reproduction, disclosure or
8
- # distribution of this software and related documentation without an express
9
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
-
11
- """Custom replacement for `torch.nn.functional.conv2d` that supports
12
- arbitrarily high order gradients with zero performance penalty."""
13
-
14
- import warnings
15
- import contextlib
16
- import torch
17
-
18
- # pylint: disable=redefined-builtin
19
- # pylint: disable=arguments-differ
20
- # pylint: disable=protected-access
21
-
22
- #----------------------------------------------------------------------------
23
-
24
- enabled = False # Enable the custom op by setting this to true.
25
- weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
26
-
27
- @contextlib.contextmanager
28
- def no_weight_gradients():
29
- global weight_gradients_disabled
30
- old = weight_gradients_disabled
31
- weight_gradients_disabled = True
32
- yield
33
- weight_gradients_disabled = old
34
-
35
- #----------------------------------------------------------------------------
36
-
37
- def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
38
- if _should_use_custom_op(input):
39
- return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
40
- return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
41
-
42
- def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
43
- if _should_use_custom_op(input):
44
- return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
45
- return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
46
-
47
- #----------------------------------------------------------------------------
48
-
49
- def _should_use_custom_op(input):
50
- assert isinstance(input, torch.Tensor)
51
- if (not enabled) or (not torch.backends.cudnn.enabled):
52
- return False
53
- if input.device.type != 'cuda':
54
- return False
55
- if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
56
- return True
57
- warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
58
- return False
59
-
60
- def _tuple_of_ints(xs, ndim):
61
- xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
62
- assert len(xs) == ndim
63
- assert all(isinstance(x, int) for x in xs)
64
- return xs
65
-
66
- #----------------------------------------------------------------------------
67
-
68
- _conv2d_gradfix_cache = dict()
69
-
70
- def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
71
- # Parse arguments.
72
- ndim = 2
73
- weight_shape = tuple(weight_shape)
74
- stride = _tuple_of_ints(stride, ndim)
75
- padding = _tuple_of_ints(padding, ndim)
76
- output_padding = _tuple_of_ints(output_padding, ndim)
77
- dilation = _tuple_of_ints(dilation, ndim)
78
-
79
- # Lookup from cache.
80
- key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
81
- if key in _conv2d_gradfix_cache:
82
- return _conv2d_gradfix_cache[key]
83
-
84
- # Validate arguments.
85
- assert groups >= 1
86
- assert len(weight_shape) == ndim + 2
87
- assert all(stride[i] >= 1 for i in range(ndim))
88
- assert all(padding[i] >= 0 for i in range(ndim))
89
- assert all(dilation[i] >= 0 for i in range(ndim))
90
- if not transpose:
91
- assert all(output_padding[i] == 0 for i in range(ndim))
92
- else: # transpose
93
- assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
94
-
95
- # Helpers.
96
- common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
97
- def calc_output_padding(input_shape, output_shape):
98
- if transpose:
99
- return [0, 0]
100
- return [
101
- input_shape[i + 2]
102
- - (output_shape[i + 2] - 1) * stride[i]
103
- - (1 - 2 * padding[i])
104
- - dilation[i] * (weight_shape[i + 2] - 1)
105
- for i in range(ndim)
106
- ]
107
-
108
- # Forward & backward.
109
- class Conv2d(torch.autograd.Function):
110
- @staticmethod
111
- def forward(ctx, input, weight, bias):
112
- assert weight.shape == weight_shape
113
- if not transpose:
114
- output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
115
- else: # transpose
116
- output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
117
- ctx.save_for_backward(input, weight)
118
- return output
119
-
120
- @staticmethod
121
- def backward(ctx, grad_output):
122
- input, weight = ctx.saved_tensors
123
- grad_input = None
124
- grad_weight = None
125
- grad_bias = None
126
-
127
- if ctx.needs_input_grad[0]:
128
- p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
129
- grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
130
- assert grad_input.shape == input.shape
131
-
132
- if ctx.needs_input_grad[1] and not weight_gradients_disabled:
133
- grad_weight = Conv2dGradWeight.apply(grad_output, input)
134
- assert grad_weight.shape == weight_shape
135
-
136
- if ctx.needs_input_grad[2]:
137
- grad_bias = grad_output.sum([0, 2, 3])
138
-
139
- return grad_input, grad_weight, grad_bias
140
-
141
- # Gradient with respect to the weights.
142
- class Conv2dGradWeight(torch.autograd.Function):
143
- @staticmethod
144
- def forward(ctx, grad_output, input):
145
- op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
146
- flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
147
- grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
148
- assert grad_weight.shape == weight_shape
149
- ctx.save_for_backward(grad_output, input)
150
- return grad_weight
151
-
152
- @staticmethod
153
- def backward(ctx, grad2_grad_weight):
154
- grad_output, input = ctx.saved_tensors
155
- grad2_grad_output = None
156
- grad2_input = None
157
-
158
- if ctx.needs_input_grad[0]:
159
- grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
160
- assert grad2_grad_output.shape == grad_output.shape
161
-
162
- if ctx.needs_input_grad[1]:
163
- p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
164
- grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
165
- assert grad2_input.shape == input.shape
166
-
167
- return grad2_grad_output, grad2_input
168
-
169
- _conv2d_gradfix_cache[key] = Conv2d
170
- return Conv2d
171
-
172
- #----------------------------------------------------------------------------
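
A short sketch of how the deleted conv2d_gradfix module is typically driven (import path assumed from this commit's layout). The custom op only activates on CUDA with a supported PyTorch version; otherwise the calls transparently fall back to torch.nn.functional:

import torch
from torch_utils.ops import conv2d_gradfix  # path assumed from this commit's file layout

conv2d_gradfix.enabled = True               # opt in; unsupported setups fall back automatically
x = torch.randn(2, 3, 32, 32, requires_grad=True)
w = torch.randn(8, 3, 3, 3, requires_grad=True)
y = conv2d_gradfix.conv2d(x, w, padding=1)

# Suppress the (expensive) gradient w.r.t. the weights, e.g. during a regularization pass.
# Note the skip only takes effect when the custom-op path is active.
with conv2d_gradfix.no_weight_gradients():
    y.sum().backward()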
 
torch_utils/ops/conv2d_resample.py DELETED
@@ -1,158 +0,0 @@
1
- # Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # NVIDIA CORPORATION and its licensors retain all intellectual property
6
- # and proprietary rights in and to this software, related documentation
7
- # and any modifications thereto. Any use, reproduction, disclosure or
8
- # distribution of this software and related documentation without an express
9
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
-
11
- """2D convolution with optional up/downsampling."""
12
-
13
- import torch
14
-
15
- from .. import misc
16
- from . import conv2d_gradfix
17
- from . import upfirdn2d
18
- from .upfirdn2d import _parse_padding
19
- from .upfirdn2d import _get_filter_size
20
-
21
- #----------------------------------------------------------------------------
22
-
23
- def _get_weight_shape(w):
24
- with misc.suppress_tracer_warnings(): # this value will be treated as a constant
25
- shape = [int(sz) for sz in w.shape]
26
- misc.assert_shape(w, shape)
27
- return shape
28
-
29
- #----------------------------------------------------------------------------
30
-
31
- def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
32
- """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
33
- """
34
- out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
35
-
36
- # Flip weight if requested.
37
- if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
38
- w = w.flip([2, 3])
39
-
40
- # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
41
- # 1x1 kernel + memory_format=channels_last + less than 64 channels.
42
- if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
43
- if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
44
- if out_channels <= 4 and groups == 1:
45
- in_shape = x.shape
46
- x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
47
- x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
48
- else:
49
- x = x.to(memory_format=torch.contiguous_format)
50
- w = w.to(memory_format=torch.contiguous_format)
51
- x = conv2d_gradfix.conv2d(x, w, groups=groups)
52
- return x.to(memory_format=torch.channels_last)
53
-
54
- # Otherwise => execute using conv2d_gradfix.
55
- op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
56
- return op(x, w, stride=stride, padding=padding, groups=groups)
57
-
58
- #----------------------------------------------------------------------------
59
-
60
- @misc.profiled_function
61
- def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
62
- r"""2D convolution with optional up/downsampling.
63
-
64
- Padding is performed only once at the beginning, not between the operations.
65
-
66
- Args:
67
- x: Input tensor of shape
68
- `[batch_size, in_channels, in_height, in_width]`.
69
- w: Weight tensor of shape
70
- `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
71
- f: Low-pass filter for up/downsampling. Must be prepared beforehand by
72
- calling upfirdn2d.setup_filter(). None = identity (default).
73
- up: Integer upsampling factor (default: 1).
74
- down: Integer downsampling factor (default: 1).
75
- padding: Padding with respect to the upsampled image. Can be a single number
76
- or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
77
- (default: 0).
78
- groups: Split input channels into N groups (default: 1).
79
- flip_weight: False = convolution, True = correlation (default: True).
80
- flip_filter: False = convolution, True = correlation (default: False).
81
-
82
- Returns:
83
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
84
- """
85
- # Validate arguments.
86
- assert isinstance(x, torch.Tensor) and (x.ndim == 4)
87
- assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
88
- assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
89
- assert isinstance(up, int) and (up >= 1)
90
- assert isinstance(down, int) and (down >= 1)
91
- assert isinstance(groups, int) and (groups >= 1)
92
- out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
93
- fw, fh = _get_filter_size(f)
94
- px0, px1, py0, py1 = _parse_padding(padding)
95
-
96
- # Adjust padding to account for up/downsampling.
97
- if up > 1:
98
- px0 += (fw + up - 1) // 2
99
- px1 += (fw - up) // 2
100
- py0 += (fh + up - 1) // 2
101
- py1 += (fh - up) // 2
102
- if down > 1:
103
- px0 += (fw - down + 1) // 2
104
- px1 += (fw - down) // 2
105
- py0 += (fh - down + 1) // 2
106
- py1 += (fh - down) // 2
107
-
108
- # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
109
- if kw == 1 and kh == 1 and (down > 1 and up == 1):
110
- x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
111
- x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
112
- return x
113
-
114
- # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
115
- if kw == 1 and kh == 1 and (up > 1 and down == 1):
116
- x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
117
- x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
118
- return x
119
-
120
- # Fast path: downsampling only => use strided convolution.
121
- if down > 1 and up == 1:
122
- x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
123
- x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
124
- return x
125
-
126
- # Fast path: upsampling with optional downsampling => use transpose strided convolution.
127
- if up > 1:
128
- if groups == 1:
129
- w = w.transpose(0, 1)
130
- else:
131
- w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
132
- w = w.transpose(1, 2)
133
- w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
134
- px0 -= kw - 1
135
- px1 -= kw - up
136
- py0 -= kh - 1
137
- py1 -= kh - up
138
- pxt = max(min(-px0, -px1), 0)
139
- pyt = max(min(-py0, -py1), 0)
140
- x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
141
- x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
142
- if down > 1:
143
- x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
144
- return x
145
-
146
- # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
147
- if up == 1 and down == 1:
148
- if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
149
- return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
150
-
151
- # Fallback: Generic reference implementation.
152
- x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
153
- x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
154
- if down > 1:
155
- x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
156
- return x
157
-
158
- #----------------------------------------------------------------------------
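
A sketch of the fused upsample-and-convolve path in the deleted conv2d_resample module (import paths assumed from this commit's layout; upfirdn2d.setup_filter() prepares the low-pass filter exactly as the docstring above requires):

import torch
from torch_utils.ops import conv2d_resample, upfirdn2d  # paths assumed from this commit

x = torch.randn(1, 64, 32, 32)
w = torch.randn(64, 64, 3, 3)                 # [out_channels, in_channels//groups, kh, kw]
f = upfirdn2d.setup_filter([1, 3, 3, 1])      # separable low-pass filter for resampling

# 2x upsampling fused with a 3x3 convolution; padding is w.r.t. the upsampled image.
y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
print(y.shape)                                # expected [1, 64, 64, 64] for this configuration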
 
torch_utils/ops/filtered_lrelu.cpp DELETED
@@ -1,300 +0,0 @@
1
- // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- //
3
- // NVIDIA CORPORATION and its licensors retain all intellectual property
4
- // and proprietary rights in and to this software, related documentation
5
- // and any modifications thereto. Any use, reproduction, disclosure or
6
- // distribution of this software and related documentation without an express
7
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- #include <torch/extension.h>
10
- #include <ATen/cuda/CUDAContext.h>
11
- #include <c10/cuda/CUDAGuard.h>
12
- #include "filtered_lrelu.h"
13
-
14
- //------------------------------------------------------------------------
15
-
16
- static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
17
- torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
18
- int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
19
- {
20
- // Set CUDA device.
21
- TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
22
- const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
23
-
24
- // Validate arguments.
25
- TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
26
- TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
27
- TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
28
- TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
29
- TORCH_CHECK(x.dim() == 4, "x must be rank 4");
30
- TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
31
- TORCH_CHECK(x.numel() > 0, "x is empty");
32
- TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
33
- TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
34
- TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
35
- TORCH_CHECK(fu.numel() > 0, "fu is empty");
36
- TORCH_CHECK(fd.numel() > 0, "fd is empty");
37
- TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
38
- TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
39
-
40
- // Figure out how much shared memory is available on the device.
41
- int maxSharedBytes = 0;
42
- AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
43
- int sharedKB = maxSharedBytes >> 10;
44
-
45
- // Populate enough launch parameters to check if a CUDA kernel exists.
46
- filtered_lrelu_kernel_params p;
47
- p.up = up;
48
- p.down = down;
49
- p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
50
- p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
51
- filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
52
- if (!test_spec.exec)
53
- {
54
- // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
55
- return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
56
- }
57
-
58
- // Input/output element size.
59
- int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
60
-
61
- // Input sizes.
62
- int64_t xw = (int)x.size(3);
63
- int64_t xh = (int)x.size(2);
64
- int64_t fut_w = (int)fu.size(-1) - 1;
65
- int64_t fut_h = (int)fu.size(0) - 1;
66
- int64_t fdt_w = (int)fd.size(-1) - 1;
67
- int64_t fdt_h = (int)fd.size(0) - 1;
68
-
69
- // Logical size of upsampled buffer.
70
- int64_t cw = xw * up + (px0 + px1) - fut_w;
71
- int64_t ch = xh * up + (py0 + py1) - fut_h;
72
- TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
73
- TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
74
-
75
- // Compute output size and allocate.
76
- int64_t yw = (cw - fdt_w + (down - 1)) / down;
77
- int64_t yh = (ch - fdt_h + (down - 1)) / down;
78
- TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
79
- TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
80
- torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
81
-
82
- // Allocate sign tensor.
83
- torch::Tensor so;
84
- torch::Tensor s = si;
85
- bool readSigns = !!s.numel();
86
- int64_t sw_active = 0; // Active width of sign tensor.
87
- if (writeSigns)
88
- {
89
- sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
90
- int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
91
- int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
92
- TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
93
- s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
94
- }
95
- else if (readSigns)
96
- sw_active = s.size(3) << 2;
97
-
98
- // Validate sign tensor if in use.
99
- if (readSigns || writeSigns)
100
- {
101
- TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
102
- TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
103
- TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
104
- TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
105
- TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
106
- TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
107
- }
108
-
109
- // Populate rest of CUDA kernel parameters.
110
- p.x = x.data_ptr();
111
- p.y = y.data_ptr();
112
- p.b = b.data_ptr();
113
- p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
114
- p.fu = fu.data_ptr<float>();
115
- p.fd = fd.data_ptr<float>();
116
- p.pad0 = make_int2(px0, py0);
117
- p.gain = gain;
118
- p.slope = slope;
119
- p.clamp = clamp;
120
- p.flip = (flip_filters) ? 1 : 0;
121
- p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
122
- p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
123
- p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
124
- p.sOfs = make_int2(sx, sy);
125
- p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
126
-
127
- // x, y, b strides are in bytes.
128
- p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
129
- p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
130
- p.bStride = sz * b.stride(0);
131
-
132
- // fu, fd strides are in elements.
133
- p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
134
- p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
135
-
136
- // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
137
- bool index64b = false;
138
- if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
139
- if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
140
- if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
141
- if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
142
- if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
143
- if (s.numel() > INT_MAX) index64b = true;
144
-
145
- // Choose CUDA kernel.
146
- filtered_lrelu_kernel_spec spec = { 0 };
147
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
148
- {
149
- if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
150
- {
151
- // Choose kernel based on index type, datatype and sign read/write modes.
152
- if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true, false>(p, sharedKB);
153
- else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true >(p, sharedKB);
154
- else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
155
- else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true, false>(p, sharedKB);
156
- else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true >(p, sharedKB);
157
- else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
158
- }
159
- });
160
- TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists.
161
-
162
- // Launch CUDA kernel.
163
- void* args[] = {&p};
164
- int bx = spec.numWarps * 32;
165
- int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
166
- int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
167
- int gz = p.yShape.z * p.yShape.w;
168
-
169
- // Repeat multiple horizontal tiles in a CTA?
170
- if (spec.xrep)
171
- {
172
- p.tilesXrep = spec.xrep;
173
- p.tilesXdim = gx;
174
-
175
- gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
176
- std::swap(gx, gy);
177
- }
178
- else
179
- {
180
- p.tilesXrep = 0;
181
- p.tilesXdim = 0;
182
- }
183
-
184
- // Launch filter setup kernel.
185
- AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
186
-
187
- // Copy kernels to constant memory.
188
- if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true, false>(at::cuda::getCurrentCUDAStream())));
189
- else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters<false, true >(at::cuda::getCurrentCUDAStream())));
190
- else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));
191
-
192
- // Set cache and shared memory configurations for main kernel.
193
- AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
194
- if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
195
- AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
196
- AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
197
-
198
- // Launch main kernel.
199
- const int maxSubGz = 65535; // CUDA maximum for block z dimension.
200
- for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
201
- {
202
- p.blockZofs = zofs;
203
- int subGz = std::min(maxSubGz, gz - zofs);
204
- AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
205
- }
206
-
207
- // Done.
208
- return std::make_tuple(y, so, 0);
209
- }
210
-
211
- //------------------------------------------------------------------------
212
-
213
- static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
214
- {
215
- // Set CUDA device.
216
- TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
217
- const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
218
-
219
- // Validate arguments.
220
- TORCH_CHECK(x.dim() == 4, "x must be rank 4");
221
- TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
222
- TORCH_CHECK(x.numel() > 0, "x is empty");
223
- TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
224
-
225
- // Output signs if we don't have sign input.
226
- torch::Tensor so;
227
- torch::Tensor s = si;
228
- bool readSigns = !!s.numel();
229
- if (writeSigns)
230
- {
231
- int64_t sw = x.size(3);
232
- sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
233
- s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
234
- }
235
-
236
- // Validate sign tensor if in use.
237
- if (readSigns || writeSigns)
238
- {
239
- TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
240
- TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
241
- TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
242
- TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
243
- TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
244
- TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
245
- }
246
-
247
- // Initialize CUDA kernel parameters.
248
- filtered_lrelu_act_kernel_params p;
249
- p.x = x.data_ptr();
250
- p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
251
- p.gain = gain;
252
- p.slope = slope;
253
- p.clamp = clamp;
254
- p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
255
- p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
256
- p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
257
- p.sOfs = make_int2(sx, sy);
258
-
259
- // Choose CUDA kernel.
260
- void* func = 0;
261
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
262
- {
263
- if (writeSigns)
264
- func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
265
- else if (readSigns)
266
- func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
267
- else
268
- func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
269
- });
270
- TORCH_CHECK(func, "internal error - CUDA kernel not found");
271
-
272
- // Launch CUDA kernel.
273
- void* args[] = {&p};
274
- int bx = 128; // 4 warps per block.
275
-
276
- // Logical size of launch = writeSigns ? p.s : p.x
277
- uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
278
- uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
279
- uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
280
- gx = (gx - 1) / bx + 1;
281
-
282
- // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
283
- const uint32_t gmax = 65535;
284
- gy = std::min(gy, gmax);
285
- gz = std::min(gz, gmax);
286
-
287
- // Launch.
288
- AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
289
- return so;
290
- }
291
-
292
- //------------------------------------------------------------------------
293
-
294
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
295
- {
296
- m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
297
- m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
298
- }
299
-
300
- //------------------------------------------------------------------------
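
A hedged Python-side sketch of driving the binding above once the extension has been compiled (for example with the get_plugin() helper used for bias_act earlier in this commit). The wrapper name and default argument values are illustrative; only the positional signature mirrors the C++ function:

import torch

def filtered_lrelu_cuda(plugin, x, fu, fd, b, up=2, down=2, pad=(0, 0, 0, 0),
                        gain=2.0 ** 0.5, slope=0.2, clamp=256.0):
    # `plugin` stands for the module built from filtered_lrelu.cpp/.cu; `si` is left
    # empty so the kernel allocates and fills the sign tensor itself (writeSigns=True).
    si = torch.empty([0], dtype=torch.uint8, device=x.device)
    y, signs, ret = plugin.filtered_lrelu(
        x, fu, fd, b, si,
        up, down, pad[0], pad[1], pad[2], pad[3],  # px0, px1, py0, py1
        0, 0,                                      # sx, sy: offset into the sign tensor
        gain, slope, clamp,
        False,                                     # flip_filters
        True)                                      # writeSigns
    if ret < 0:                                    # no kernel exists for this filter configuration
        raise RuntimeError('filtered_lrelu: no suitable CUDA kernel found')
    return y, signs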
 
torch_utils/ops/filtered_lrelu.cu DELETED
@@ -1,1284 +0,0 @@
1
- // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- //
3
- // NVIDIA CORPORATION and its licensors retain all intellectual property
4
- // and proprietary rights in and to this software, related documentation
5
- // and any modifications thereto. Any use, reproduction, disclosure or
6
- // distribution of this software and related documentation without an express
7
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- #include <c10/util/Half.h>
10
- #include "filtered_lrelu.h"
11
- #include <cstdint>
12
-
13
- //------------------------------------------------------------------------
14
- // Helpers.
15
-
16
- enum // Filter modes.
17
- {
18
- MODE_SUSD = 0, // Separable upsampling, separable downsampling.
19
- MODE_FUSD = 1, // Full upsampling, separable downsampling.
20
- MODE_SUFD = 2, // Separable upsampling, full downsampling.
21
- MODE_FUFD = 3, // Full upsampling, full downsampling.
22
- };
23
-
24
- template <class T> struct InternalType;
25
- template <> struct InternalType<double>
26
- {
27
- typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
28
- __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
29
- __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
30
- __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
31
- };
32
- template <> struct InternalType<float>
33
- {
34
- typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
35
- __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
36
- __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
37
- __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
38
- };
39
- template <> struct InternalType<c10::Half>
40
- {
41
- typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
42
- __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
43
- __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
44
- __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
45
- };
46
-
47
- #define MIN(A, B) ((A) < (B) ? (A) : (B))
48
- #define MAX(A, B) ((A) > (B) ? (A) : (B))
49
- #define CEIL_DIV(A, B) (((B)==1) ? (A) : \
50
- ((B)==2) ? ((int)((A)+1) >> 1) : \
51
- ((B)==4) ? ((int)((A)+3) >> 2) : \
52
- (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))
53
-
54
- // This works only up to blocks of size 256 x 256 and for all N that are powers of two.
55
- template <int N> __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
56
- {
57
- if ((N & (N-1)) && N <= 256)
58
- y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
59
- else
60
- y = i/N;
61
-
62
- x = i - y*N;
63
- }
64
-
65
- // Type cast stride before reading it.
66
- template <class T> __device__ __forceinline__ T get_stride(const int64_t& x)
67
- {
68
- return *reinterpret_cast<const T*>(&x);
69
- }
70
-
71
- //------------------------------------------------------------------------
72
- // Filters, setup kernel, copying function.
73
-
74
- #define MAX_FILTER_SIZE 32
75
-
76
- // Combined up/down filter buffers so that transfer can be done with one copy.
77
- __device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
78
- __device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.
79
-
80
- // Accessors to combined buffers to index up/down filters individually.
81
- #define c_fu (c_fbuf)
82
- #define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
83
- #define g_fu (g_fbuf)
84
- #define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
85
-
86
- // Set up filters into global memory buffer.
87
- static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
88
- {
89
- for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
90
- {
91
- int x, y;
92
- fast_div_mod<MAX_FILTER_SIZE>(x, y, idx);
93
-
94
- int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
95
- int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
96
- if (p.fuShape.y > 0)
97
- g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
98
- else
99
- g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
100
-
101
- int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
102
- int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
103
- if (p.fdShape.y > 0)
104
- g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
105
- else
106
- g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
107
- }
108
- }
109
-
110
- // Host function to copy filters written by setup kernel into constant buffer for main kernel.
111
- template <bool, bool> static cudaError_t copy_filters(cudaStream_t stream)
112
- {
113
- void* src = 0;
114
- cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
115
- if (err) return err;
116
- return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
117
- }
118
-
119
- //------------------------------------------------------------------------
120
- // Coordinate spaces:
121
- // - Relative to input tensor: inX, inY, tileInX, tileInY
122
- // - Relative to input tile: relInX, relInY, tileInW, tileInH
123
- // - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
124
- // - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH
125
- // - Relative to output tensor: outX, outY, tileOutX, tileOutY
126
- //
127
- // Relationships between coordinate spaces:
128
- // - inX = tileInX + relInX
129
- // - inY = tileInY + relInY
130
- // - relUpX = relInX * up + phaseInX
131
- // - relUpY = relInY * up + phaseInY
132
- // - relUpX = relOutX * down
133
- // - relUpY = relOutY * down
134
- // - outX = tileOutX + relOutX
135
- // - outY = tileOutY + relOutY
136
-
137
- extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.
138
-
139
- template <class T, class index_t, int sharedKB, bool signWrite, bool signRead, int filterMode, int up, int fuSize, int down, int fdSize, int tileOutW, int tileOutH, int threadsPerBlock, bool enableXrep, bool enableWriteSkip>
140
- static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
141
- {
142
- // Check that we don't try to support non-existing filter modes.
143
- static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported");
144
- static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
145
- static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor");
146
- static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
147
- static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor");
148
- static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor");
149
- static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
150
- static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
151
- static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
152
- static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
153
- static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");
154
-
155
- // Static definitions.
156
- typedef typename InternalType<T>::scalar_t scalar_t;
157
- typedef typename InternalType<T>::vec2_t vec2_t;
158
- typedef typename InternalType<T>::vec4_t vec4_t;
159
- const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
160
- const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height.
161
- const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width.
162
- const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height.
163
- const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up.
164
- const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4.
165
-
166
- // Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
167
- const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));
168
-
169
- // Sizes of logical buffers.
170
- const int szIn = tileInH_up * tileInW;
171
- const int szUpX = tileInH_up * tileUpW;
172
- const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
173
- const int szDownX = tileUpH * tileOutW;
174
-
175
- // Sizes for shared memory arrays.
176
- const int s_buf0_size_base =
177
- (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
178
- (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
179
- (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
180
- (filterMode == MODE_FUFD) ? szIn :
181
- -1;
182
- const int s_buf1_size_base =
183
- (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
184
- (filterMode == MODE_FUSD) ? szUpXY :
185
- (filterMode == MODE_SUFD) ? szUpX :
186
- (filterMode == MODE_FUFD) ? szUpXY :
187
- -1;
188
-
189
- // Ensure U128 alignment.
190
- const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
191
- const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
192
-
193
- // Check at compile time that we don't use too much shared memory.
194
- static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");
195
-
196
- // Declare shared memory arrays.
197
- scalar_t* s_buf0;
198
- scalar_t* s_buf1;
199
- if (sharedKB <= 48)
200
- {
201
- // Allocate shared memory arrays here.
202
- __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.
203
- s_buf0 = s_buf0_st;
204
- s_buf1 = s_buf0 + s_buf0_size;
205
- }
206
- else
207
- {
208
- // Use the dynamically allocated shared memory array.
209
- s_buf0 = (scalar_t*)s_buf_raw;
210
- s_buf1 = s_buf0 + s_buf0_size;
211
- }
212
-
213
- // Pointers to the buffers.
214
- scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY]
215
- scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX]
216
- scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX]
217
- scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX]
218
- if (filterMode == MODE_SUSD)
219
- {
220
- s_tileIn = s_buf0;
221
- s_tileUpX = s_buf1;
222
- s_tileUpXY = s_buf0;
223
- s_tileDownX = s_buf1;
224
- }
225
- else if (filterMode == MODE_FUSD)
226
- {
227
- s_tileIn = s_buf0;
228
- s_tileUpXY = s_buf1;
229
- s_tileDownX = s_buf0;
230
- }
231
- else if (filterMode == MODE_SUFD)
232
- {
233
- s_tileIn = s_buf0;
234
- s_tileUpX = s_buf1;
235
- s_tileUpXY = s_buf0;
236
- }
237
- else if (filterMode == MODE_FUFD)
238
- {
239
- s_tileIn = s_buf0;
240
- s_tileUpXY = s_buf1;
241
- }
242
-
243
- // Allow large grids in z direction via per-launch offset.
244
- int channelIdx = blockIdx.z + p.blockZofs;
245
- int batchIdx = channelIdx / p.yShape.z;
246
- channelIdx -= batchIdx * p.yShape.z;
247
-
248
- // Offset to output feature map. In bytes.
249
- index_t mapOfsOut = channelIdx * get_stride<index_t>(p.yStride.z) + batchIdx * get_stride<index_t>(p.yStride.w);
250
-
251
- // Sign shift amount.
252
- uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
253
-
254
- // Inner tile loop.
255
- #pragma unroll 1
256
- for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)
257
- {
258
- // Locate output tile.
259
- int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
260
- int tileOutX = tileX * tileOutW;
261
- int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
262
-
263
- // Locate input tile.
264
- int tmpX = tileOutX * down - p.pad0.x;
265
- int tmpY = tileOutY * down - p.pad0.y;
266
- int tileInX = CEIL_DIV(tmpX, up);
267
- int tileInY = CEIL_DIV(tmpY, up);
268
- const int phaseInX = tileInX * up - tmpX;
269
- const int phaseInY = tileInY * up - tmpY;
270
-
271
- // Extra sync if input and output buffers are the same and we are not on first tile.
272
- if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))
273
- __syncthreads();
274
-
275
- // Load input tile & apply bias. Unrolled.
276
- scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride<index_t>(p.bStride)));
277
- index_t mapOfsIn = channelIdx * get_stride<index_t>(p.xStride.z) + batchIdx * get_stride<index_t>(p.xStride.w);
278
- int idx = threadIdx.x;
279
- const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
280
- #pragma unroll
281
- for (int loop = 0; loop < loopCountIN; loop++)
282
- {
283
- int relInX, relInY;
284
- fast_div_mod<tileInW>(relInX, relInY, idx);
285
- int inX = tileInX + relInX;
286
- int inY = tileInY + relInY;
287
- scalar_t v = 0;
288
-
289
- if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
290
- v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride<index_t>(p.xStride.x) + inY * get_stride<index_t>(p.xStride.y) + mapOfsIn))) + b;
291
-
292
- bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);
293
- if (!skip)
294
- s_tileIn[idx] = v;
295
-
296
- idx += threadsPerBlock;
297
- }
298
-
299
- if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.
300
- {
301
- // Horizontal upsampling.
302
- __syncthreads();
303
- if (up == 4)
304
- {
305
- for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
306
- {
307
- int relUpX0, relInY;
308
- fast_div_mod<tileUpW>(relUpX0, relInY, idx);
309
- int relInX0 = relUpX0 / up;
310
- int src0 = relInX0 + tileInW * relInY;
311
- int dst = relInY * tileUpW + relUpX0;
312
- vec4_t v = InternalType<T>::zero_vec4();
313
- scalar_t a = s_tileIn[src0];
314
- if (phaseInX == 0)
315
- {
316
- #pragma unroll
317
- for (int step = 0; step < fuSize / up; step++)
318
- {
319
- v.x += a * (scalar_t)c_fu[step * up + 0];
320
- a = s_tileIn[src0 + step + 1];
321
- v.y += a * (scalar_t)c_fu[step * up + 3];
322
- v.z += a * (scalar_t)c_fu[step * up + 2];
323
- v.w += a * (scalar_t)c_fu[step * up + 1];
324
- }
325
- }
326
- else if (phaseInX == 1)
327
- {
328
- #pragma unroll
329
- for (int step = 0; step < fuSize / up; step++)
330
- {
331
- v.x += a * (scalar_t)c_fu[step * up + 1];
332
- v.y += a * (scalar_t)c_fu[step * up + 0];
333
- a = s_tileIn[src0 + step + 1];
334
- v.z += a * (scalar_t)c_fu[step * up + 3];
335
- v.w += a * (scalar_t)c_fu[step * up + 2];
336
- }
337
- }
338
- else if (phaseInX == 2)
339
- {
340
- #pragma unroll
341
- for (int step = 0; step < fuSize / up; step++)
342
- {
343
- v.x += a * (scalar_t)c_fu[step * up + 2];
344
- v.y += a * (scalar_t)c_fu[step * up + 1];
345
- v.z += a * (scalar_t)c_fu[step * up + 0];
346
- a = s_tileIn[src0 + step + 1];
347
- v.w += a * (scalar_t)c_fu[step * up + 3];
348
- }
349
- }
350
- else // (phaseInX == 3)
351
- {
352
- #pragma unroll
353
- for (int step = 0; step < fuSize / up; step++)
354
- {
355
- v.x += a * (scalar_t)c_fu[step * up + 3];
356
- v.y += a * (scalar_t)c_fu[step * up + 2];
357
- v.z += a * (scalar_t)c_fu[step * up + 1];
358
- v.w += a * (scalar_t)c_fu[step * up + 0];
359
- a = s_tileIn[src0 + step + 1];
360
- }
361
- }
362
- s_tileUpX[dst+0] = v.x;
363
- s_tileUpX[dst+1] = v.y;
364
- s_tileUpX[dst+2] = v.z;
365
- s_tileUpX[dst+3] = v.w;
366
- }
367
- }
368
- else if (up == 2)
369
- {
370
- bool p0 = (phaseInX == 0);
371
- for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
372
- {
373
- int relUpX0, relInY;
374
- fast_div_mod<tileUpW>(relUpX0, relInY, idx);
375
- int relInX0 = relUpX0 / up;
376
- int src0 = relInX0 + tileInW * relInY;
377
- int dst = relInY * tileUpW + relUpX0;
378
- vec2_t v = InternalType<T>::zero_vec2();
379
- scalar_t a = s_tileIn[src0];
380
- if (p0) // (phaseInX == 0)
381
- {
382
- #pragma unroll
383
- for (int step = 0; step < fuSize / up; step++)
384
- {
385
- v.x += a * (scalar_t)c_fu[step * up + 0];
386
- a = s_tileIn[src0 + step + 1];
387
- v.y += a * (scalar_t)c_fu[step * up + 1];
388
- }
389
- }
390
- else // (phaseInX == 1)
391
- {
392
- #pragma unroll
393
- for (int step = 0; step < fuSize / up; step++)
394
- {
395
- v.x += a * (scalar_t)c_fu[step * up + 1];
396
- v.y += a * (scalar_t)c_fu[step * up + 0];
397
- a = s_tileIn[src0 + step + 1];
398
- }
399
- }
400
- s_tileUpX[dst+0] = v.x;
401
- s_tileUpX[dst+1] = v.y;
402
- }
403
- }
404
-
405
- // Vertical upsampling & nonlinearity.
406
-
407
- __syncthreads();
408
- int groupMask = 15 << ((threadIdx.x & 31) & ~3);
409
- int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
410
- int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.
411
- if (up == 4)
412
- {
413
- minY -= 3; // Adjust according to block height.
414
- for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
415
- {
416
- int relUpX, relInY0;
417
- fast_div_mod<tileUpW>(relUpX, relInY0, idx);
418
- int relUpY0 = relInY0 * up;
419
- int src0 = relInY0 * tileUpW + relUpX;
420
- int dst = relUpY0 * tileUpW + relUpX;
421
- vec4_t v = InternalType<T>::zero_vec4();
422
-
423
- scalar_t a = s_tileUpX[src0];
424
- if (phaseInY == 0)
425
- {
426
- #pragma unroll
427
- for (int step = 0; step < fuSize / up; step++)
428
- {
429
- v.x += a * (scalar_t)c_fu[step * up + 0];
430
- a = s_tileUpX[src0 + (step + 1) * tileUpW];
431
- v.y += a * (scalar_t)c_fu[step * up + 3];
432
- v.z += a * (scalar_t)c_fu[step * up + 2];
433
- v.w += a * (scalar_t)c_fu[step * up + 1];
434
- }
435
- }
436
- else if (phaseInY == 1)
437
- {
438
- #pragma unroll
439
- for (int step = 0; step < fuSize / up; step++)
440
- {
441
- v.x += a * (scalar_t)c_fu[step * up + 1];
442
- v.y += a * (scalar_t)c_fu[step * up + 0];
443
- a = s_tileUpX[src0 + (step + 1) * tileUpW];
444
- v.z += a * (scalar_t)c_fu[step * up + 3];
445
- v.w += a * (scalar_t)c_fu[step * up + 2];
446
- }
447
- }
448
- else if (phaseInY == 2)
449
- {
450
- #pragma unroll
451
- for (int step = 0; step < fuSize / up; step++)
452
- {
453
- v.x += a * (scalar_t)c_fu[step * up + 2];
454
- v.y += a * (scalar_t)c_fu[step * up + 1];
455
- v.z += a * (scalar_t)c_fu[step * up + 0];
456
- a = s_tileUpX[src0 + (step + 1) * tileUpW];
457
- v.w += a * (scalar_t)c_fu[step * up + 3];
458
- }
459
- }
460
- else // (phaseInY == 3)
461
- {
462
- #pragma unroll
463
- for (int step = 0; step < fuSize / up; step++)
464
- {
465
- v.x += a * (scalar_t)c_fu[step * up + 3];
466
- v.y += a * (scalar_t)c_fu[step * up + 2];
467
- v.z += a * (scalar_t)c_fu[step * up + 1];
468
- v.w += a * (scalar_t)c_fu[step * up + 0];
469
- a = s_tileUpX[src0 + (step + 1) * tileUpW];
470
- }
471
- }
472
-
473
- int x = tileOutX * down + relUpX;
474
- int y = tileOutY * down + relUpY0;
475
- int signX = x + p.sOfs.x;
476
- int signY = y + p.sOfs.y;
477
- int signZ = blockIdx.z + p.blockZofs;
478
- int signXb = signX >> 2;
479
- index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
480
- index_t si1 = si0 + p.sShape.x;
481
- index_t si2 = si0 + p.sShape.x * 2;
482
- index_t si3 = si0 + p.sShape.x * 3;
483
-
484
- v.x *= (scalar_t)((float)up * (float)up * p.gain);
485
- v.y *= (scalar_t)((float)up * (float)up * p.gain);
486
- v.z *= (scalar_t)((float)up * (float)up * p.gain);
487
- v.w *= (scalar_t)((float)up * (float)up * p.gain);
488
-
489
- if (signWrite)
490
- {
491
- if (!enableWriteSkip)
492
- {
493
- // Determine and write signs.
494
- int sx = __float_as_uint(v.x) >> 31 << 0;
495
- int sy = __float_as_uint(v.y) >> 31 << 8;
496
- int sz = __float_as_uint(v.z) >> 31 << 16;
497
- int sw = __float_as_uint(v.w) >> 31 << 24;
498
- if (sx) v.x *= p.slope;
499
- if (sy) v.y *= p.slope;
500
- if (sz) v.z *= p.slope;
501
- if (sw) v.w *= p.slope;
502
- if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
503
- if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
504
- if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
505
- if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
506
-
507
- if ((uint32_t)signXb < p.swLimit && signY >= minY)
508
- {
509
- // Combine signs.
510
- uint32_t s = sx + sy + sw + sz;
511
- s <<= (signX & 3) << 1;
512
- s |= __shfl_xor_sync(groupMask, s, 1);
513
- s |= __shfl_xor_sync(groupMask, s, 2);
514
-
515
- // Write signs.
516
- if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
517
- if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
518
- if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
519
- if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
520
- }
521
- }
522
- else
523
- {
524
- // Determine and write signs.
525
- if ((uint32_t)signXb < p.swLimit && signY >= minY)
526
- {
527
- int sx = __float_as_uint(v.x) >> 31 << 0;
528
- int sy = __float_as_uint(v.y) >> 31 << 8;
529
- int sz = __float_as_uint(v.z) >> 31 << 16;
530
- int sw = __float_as_uint(v.w) >> 31 << 24;
531
- if (sx) v.x *= p.slope;
532
- if (sy) v.y *= p.slope;
533
- if (sz) v.z *= p.slope;
534
- if (sw) v.w *= p.slope;
535
- if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
536
- if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
537
- if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
538
- if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
539
-
540
- // Combine signs.
541
- uint32_t s = sx + sy + sw + sz;
542
- s <<= (signX & 3) << 1;
543
- s |= __shfl_xor_sync(groupMask, s, 1);
544
- s |= __shfl_xor_sync(groupMask, s, 2);
545
-
546
- // Write signs.
547
- if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
548
- if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
549
- if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
550
- if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
551
- }
552
- else
553
- {
554
- // Just compute the values.
555
- if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
556
- if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
557
- if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
558
- if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
559
- }
560
- }
561
- }
562
- else if (signRead) // Read signs and apply.
563
- {
564
- if ((uint32_t)signXb < p.swLimit)
565
- {
566
- int ss = (signX & 3) << 1;
567
- if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
568
- if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
569
- if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; }
570
- if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; }
571
- }
572
- }
573
- else // Forward pass with no sign write.
574
- {
575
- if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
576
- if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
577
- if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
578
- if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
579
- }
580
-
581
- s_tileUpXY[dst + 0 * tileUpW] = v.x;
582
- if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;
583
- if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;
584
- if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;
585
- }
586
- }
587
- else if (up == 2)
588
- {
589
- minY -= 1; // Adjust according to block height.
590
- for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
591
- {
592
- int relUpX, relInY0;
593
- fast_div_mod<tileUpW>(relUpX, relInY0, idx);
594
- int relUpY0 = relInY0 * up;
595
- int src0 = relInY0 * tileUpW + relUpX;
596
- int dst = relUpY0 * tileUpW + relUpX;
597
- vec2_t v = InternalType<T>::zero_vec2();
598
-
599
- scalar_t a = s_tileUpX[src0];
600
- if (phaseInY == 0)
601
- {
602
- #pragma unroll
603
- for (int step = 0; step < fuSize / up; step++)
604
- {
605
- v.x += a * (scalar_t)c_fu[step * up + 0];
606
- a = s_tileUpX[src0 + (step + 1) * tileUpW];
607
- v.y += a * (scalar_t)c_fu[step * up + 1];
608
- }
609
- }
610
- else // (phaseInY == 1)
611
- {
612
- #pragma unroll
613
- for (int step = 0; step < fuSize / up; step++)
614
- {
615
- v.x += a * (scalar_t)c_fu[step * up + 1];
616
- v.y += a * (scalar_t)c_fu[step * up + 0];
617
- a = s_tileUpX[src0 + (step + 1) * tileUpW];
618
- }
619
- }
620
-
621
- int x = tileOutX * down + relUpX;
622
- int y = tileOutY * down + relUpY0;
623
- int signX = x + p.sOfs.x;
624
- int signY = y + p.sOfs.y;
625
- int signZ = blockIdx.z + p.blockZofs;
626
- int signXb = signX >> 2;
627
- index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
628
- index_t si1 = si0 + p.sShape.x;
629
-
630
- v.x *= (scalar_t)((float)up * (float)up * p.gain);
631
- v.y *= (scalar_t)((float)up * (float)up * p.gain);
632
-
633
- if (signWrite)
634
- {
635
- if (!enableWriteSkip)
636
- {
637
- // Determine and write signs.
638
- int sx = __float_as_uint(v.x) >> 31 << 0;
639
- int sy = __float_as_uint(v.y) >> 31 << 8;
640
- if (sx) v.x *= p.slope;
641
- if (sy) v.y *= p.slope;
642
- if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
643
- if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
644
-
645
- if ((uint32_t)signXb < p.swLimit && signY >= minY)
646
- {
647
- // Combine signs.
648
- int s = sx + sy;
649
- s <<= signXo;
650
- s |= __shfl_xor_sync(groupMask, s, 1);
651
- s |= __shfl_xor_sync(groupMask, s, 2);
652
-
653
- // Write signs.
654
- if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
655
- if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
656
- }
657
- }
658
- else
659
- {
660
- // Determine and write signs.
661
- if ((uint32_t)signXb < p.swLimit && signY >= minY)
662
- {
663
- int sx = __float_as_uint(v.x) >> 31 << 0;
664
- int sy = __float_as_uint(v.y) >> 31 << 8;
665
- if (sx) v.x *= p.slope;
666
- if (sy) v.y *= p.slope;
667
- if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
668
- if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
669
-
670
- // Combine signs.
671
- int s = sx + sy;
672
- s <<= signXo;
673
- s |= __shfl_xor_sync(groupMask, s, 1);
674
- s |= __shfl_xor_sync(groupMask, s, 2);
675
-
676
- // Write signs.
677
- if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
678
- if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
679
- }
680
- else
681
- {
682
- // Just compute the values.
683
- if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
684
- if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
685
- }
686
- }
687
- }
688
- else if (signRead) // Read signs and apply.
689
- {
690
- if ((uint32_t)signXb < p.swLimit)
691
- {
692
- if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
693
- if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
694
- }
695
- }
696
- else // Forward pass with no sign write.
697
- {
698
- if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
699
- if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
700
- }
701
-
702
- if (!downInline)
703
- {
704
- // Write into temporary buffer.
705
- s_tileUpXY[dst] = v.x;
706
- if (relUpY0 < tileUpH - 1)
707
- s_tileUpXY[dst + tileUpW] = v.y;
708
- }
709
- else
710
- {
711
- // Write directly into output buffer.
712
- if ((uint32_t)x < p.yShape.x)
713
- {
714
- int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down);
715
- index_t ofs = x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut;
716
- if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
717
- if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]);
718
- }
719
- }
720
- }
721
- }
722
- }
723
- else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD)
724
- {
725
- // Full upsampling filter.
726
-
727
- if (up == 2)
728
- {
729
- // 2 x 2-wide.
730
- __syncthreads();
731
- int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs.
732
- for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4)
733
- {
734
- int relUpX0, relUpY0;
735
- fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
736
- int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up);
737
- int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up);
738
- int src0 = relInX0 + tileInW * relInY0;
739
- int tap0y = (relInY0 * up + phaseInY - relUpY0);
740
-
741
- #define X_LOOP(TAPY, PX) \
742
- for (int sx = 0; sx < fuSize / up; sx++) \
743
- { \
744
- v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
745
- v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
746
- v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
747
- v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
748
- }
749
-
750
- vec4_t v = InternalType<T>::zero_vec4();
751
- if (tap0y == 0 && phaseInX == 0)
752
- #pragma unroll
753
- for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
754
- #pragma unroll
755
- X_LOOP(0, 0) }
756
- if (tap0y == 0 && phaseInX == 1)
757
- #pragma unroll
758
- for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
759
- #pragma unroll
760
- X_LOOP(0, 1) }
761
- if (tap0y == 1 && phaseInX == 0)
762
- #pragma unroll
763
- for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
764
- #pragma unroll
765
- X_LOOP(1, 0) }
766
- if (tap0y == 1 && phaseInX == 1)
767
- #pragma unroll
768
- for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
769
- #pragma unroll
770
- X_LOOP(1, 1) }
771
-
772
- #undef X_LOOP
773
-
774
- int x = tileOutX * down + relUpX0;
775
- int y = tileOutY * down + relUpY0;
776
- int signX = x + p.sOfs.x;
777
- int signY = y + p.sOfs.y;
778
- int signZ = blockIdx.z + p.blockZofs;
779
- int signXb = signX >> 2;
780
- index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
781
-
782
- v.x *= (scalar_t)((float)up * (float)up * p.gain);
783
- v.y *= (scalar_t)((float)up * (float)up * p.gain);
784
- v.z *= (scalar_t)((float)up * (float)up * p.gain);
785
- v.w *= (scalar_t)((float)up * (float)up * p.gain);
786
-
787
- if (signWrite)
788
- {
789
- if (!enableWriteSkip)
790
- {
791
- // Determine and write signs.
792
- int sx = __float_as_uint(v.x) >> 31;
793
- int sy = __float_as_uint(v.y) >> 31;
794
- int sz = __float_as_uint(v.z) >> 31;
795
- int sw = __float_as_uint(v.w) >> 31;
796
- if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
797
- if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
798
- if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
799
- if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
800
-
801
- if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
802
- {
803
- p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
804
- }
805
- }
806
- else
807
- {
808
- // Determine and write signs.
809
- if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
810
- {
811
- int sx = __float_as_uint(v.x) >> 31;
812
- int sy = __float_as_uint(v.y) >> 31;
813
- int sz = __float_as_uint(v.z) >> 31;
814
- int sw = __float_as_uint(v.w) >> 31;
815
- if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
816
- if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
817
- if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
818
- if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
819
-
820
- p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
821
- }
822
- else
823
- {
824
- // Just compute the values.
825
- if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
826
- if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
827
- if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
828
- if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
829
- }
830
- }
831
- }
832
- else if (signRead) // Read sign and apply.
833
- {
834
- if ((uint32_t)signY < p.sShape.y)
835
- {
836
- int s = 0;
837
- if ((uint32_t)signXb < p.swLimit) s = p.s[si];
838
- if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8;
839
- s >>= (signX & 3) << 1;
840
- if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f;
841
- if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f;
842
- if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f;
843
- if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f;
844
- }
845
- }
846
- else // Forward pass with no sign write.
847
- {
848
- if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
849
- if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
850
- if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
851
- if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
852
- }
853
-
854
- s_tileUpXY[idx + 0] = v.x;
855
- s_tileUpXY[idx + 1] = v.y;
856
- s_tileUpXY[idx + 2] = v.z;
857
- s_tileUpXY[idx + 3] = v.w;
858
- }
859
- }
860
- else if (up == 1)
861
- {
862
- __syncthreads();
863
- uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3);
864
- int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
865
- for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x)
866
- {
867
- int relUpX0, relUpY0;
868
- fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
869
- scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter.
870
-
871
- int x = tileOutX * down + relUpX0;
872
- int y = tileOutY * down + relUpY0;
873
- int signX = x + p.sOfs.x;
874
- int signY = y + p.sOfs.y;
875
- int signZ = blockIdx.z + p.blockZofs;
876
- int signXb = signX >> 2;
877
- index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
878
- v *= (scalar_t)((float)up * (float)up * p.gain);
879
-
880
- if (signWrite)
881
- {
882
- if (!enableWriteSkip)
883
- {
884
- // Determine and write sign.
885
- uint32_t s = 0;
886
- uint32_t signXbit = (1u << signXo);
887
- if (v < 0.f)
888
- {
889
- s = signXbit;
890
- v *= p.slope;
891
- }
892
- if (fabsf(v) > p.clamp)
893
- {
894
- s = signXbit * 2;
895
- v = InternalType<T>::clamp(v, p.clamp);
896
- }
897
- if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
898
- {
899
- s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
900
- s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
901
- p.s[si] = s; // Write.
902
- }
903
- }
904
- else
905
- {
906
- // Determine and write sign.
907
- if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
908
- {
909
- uint32_t s = 0;
910
- uint32_t signXbit = (1u << signXo);
911
- if (v < 0.f)
912
- {
913
- s = signXbit;
914
- v *= p.slope;
915
- }
916
- if (fabsf(v) > p.clamp)
917
- {
918
- s = signXbit * 2;
919
- v = InternalType<T>::clamp(v, p.clamp);
920
- }
921
- s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
922
- s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
923
- p.s[si] = s; // Write.
924
- }
925
- else
926
- {
927
- // Just compute the value.
928
- if (v < 0.f) v *= p.slope;
929
- v = InternalType<T>::clamp(v, p.clamp);
930
- }
931
- }
932
- }
933
- else if (signRead)
934
- {
935
- // Read sign and apply if within sign tensor bounds.
936
- if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y)
937
- {
938
- int s = p.s[si];
939
- s >>= signXo;
940
- if (s & 1) v *= p.slope;
941
- if (s & 2) v = 0.f;
942
- }
943
- }
944
- else // Forward pass with no sign write.
945
- {
946
- if (v < 0.f) v *= p.slope;
947
- v = InternalType<T>::clamp(v, p.clamp);
948
- }
949
-
950
- if (!downInline) // Write into temporary buffer.
951
- s_tileUpXY[idx] = v;
952
- else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer
953
- *((T*)((char*)p.y + (x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
954
- }
955
- }
956
- }
957
-
958
- // Downsampling.
959
- if (filterMode == MODE_SUSD || filterMode == MODE_FUSD)
960
- {
961
- // Horizontal downsampling.
962
- __syncthreads();
963
- if (down == 4 && tileOutW % 4 == 0)
964
- {
965
- // Calculate 4 pixels at a time.
966
- for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4)
967
- {
968
- int relOutX0, relUpY;
969
- fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
970
- int relUpX0 = relOutX0 * down;
971
- int src0 = relUpY * tileUpW + relUpX0;
972
- vec4_t v = InternalType<T>::zero_vec4();
973
- #pragma unroll
974
- for (int step = 0; step < fdSize; step++)
975
- {
976
- v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
977
- v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step];
978
- v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step];
979
- v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step];
980
- }
981
- s_tileDownX[idx+0] = v.x;
982
- s_tileDownX[idx+1] = v.y;
983
- s_tileDownX[idx+2] = v.z;
984
- s_tileDownX[idx+3] = v.w;
985
- }
986
- }
987
- else if ((down == 2 || down == 4) && (tileOutW % 2 == 0))
988
- {
989
- // Calculate 2 pixels at a time.
990
- for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2)
991
- {
992
- int relOutX0, relUpY;
993
- fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
994
- int relUpX0 = relOutX0 * down;
995
- int src0 = relUpY * tileUpW + relUpX0;
996
- vec2_t v = InternalType<T>::zero_vec2();
997
- #pragma unroll
998
- for (int step = 0; step < fdSize; step++)
999
- {
1000
- v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
1001
- v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step];
1002
- }
1003
- s_tileDownX[idx+0] = v.x;
1004
- s_tileDownX[idx+1] = v.y;
1005
- }
1006
- }
1007
- else
1008
- {
1009
- // Calculate 1 pixel at a time.
1010
- for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x)
1011
- {
1012
- int relOutX0, relUpY;
1013
- fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
1014
- int relUpX0 = relOutX0 * down;
1015
- int src = relUpY * tileUpW + relUpX0;
1016
- scalar_t v = 0.f;
1017
- #pragma unroll
1018
- for (int step = 0; step < fdSize; step++)
1019
- v += s_tileUpXY[src + step] * (scalar_t)c_fd[step];
1020
- s_tileDownX[idx] = v;
1021
- }
1022
- }
1023
-
1024
- // Vertical downsampling & store output tile.
1025
- __syncthreads();
1026
- for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
1027
- {
1028
- int relOutX, relOutY0;
1029
- fast_div_mod<tileOutW>(relOutX, relOutY0, idx);
1030
- int relUpY0 = relOutY0 * down;
1031
- int src0 = relUpY0 * tileOutW + relOutX;
1032
- scalar_t v = 0;
1033
- #pragma unroll
1034
- for (int step = 0; step < fdSize; step++)
1035
- v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step];
1036
-
1037
- int outX = tileOutX + relOutX;
1038
- int outY = tileOutY + relOutY0;
1039
-
1040
- if (outX < p.yShape.x & outY < p.yShape.y)
1041
- *((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
1042
- }
1043
- }
1044
- else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD)
1045
- {
1046
- // Full downsampling filter.
1047
- if (down == 2)
1048
- {
1049
- // 2-wide.
1050
- __syncthreads();
1051
- for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2)
1052
- {
1053
- int relOutX0, relOutY0;
1054
- fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
1055
- int relUpX0 = relOutX0 * down;
1056
- int relUpY0 = relOutY0 * down;
1057
- int src0 = relUpY0 * tileUpW + relUpX0;
1058
- vec2_t v = InternalType<T>::zero_vec2();
1059
- #pragma unroll
1060
- for (int sy = 0; sy < fdSize; sy++)
1061
- #pragma unroll
1062
- for (int sx = 0; sx < fdSize; sx++)
1063
- {
1064
- v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
1065
- v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
1066
- }
1067
-
1068
- int outX = tileOutX + relOutX0;
1069
- int outY = tileOutY + relOutY0;
1070
- if ((uint32_t)outY < p.yShape.y)
1071
- {
1072
- index_t ofs = outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut;
1073
- if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x;
1074
- if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.x))) = (T)v.y;
1075
- }
1076
- }
1077
- }
1078
- else if (down == 1 && !downInline)
1079
- {
1080
- // Thread per pixel.
1081
- __syncthreads();
1082
- for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
1083
- {
1084
- int relOutX0, relOutY0;
1085
- fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
1086
- scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter.
1087
-
1088
- int outX = tileOutX + relOutX0;
1089
- int outY = tileOutY + relOutY0;
1090
- if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)
1091
- *((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
1092
- }
1093
- }
1094
- }
1095
-
1096
- if (!enableXrep)
1097
- break;
1098
- }
1099
- }
1100
-
1101
- //------------------------------------------------------------------------
1102
- // Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant.
1103
- // Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used.
1104
-
1105
- template <class T, bool signWrite, bool signRead>
1106
- static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p)
1107
- {
1108
- typedef typename InternalType<T>::scalar_t scalar_t;
1109
-
1110
- // Indexing.
1111
- int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
1112
- int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;
1113
- int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.
1114
-
1115
- // Loop to accommodate oversized tensors.
1116
- for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
1117
- for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y)
1118
- {
1119
- // Extract z and w (channel, minibatch index).
1120
- int32_t w = q / p.xShape.z;
1121
- int32_t z = q - w * p.xShape.z;
1122
-
1123
- // Choose behavior based on sign read/write mode.
1124
- if (signWrite)
1125
- {
1126
- // Process value if in p.x.
1127
- uint32_t s = 0;
1128
- if (x < p.xShape.x && y < p.xShape.y)
1129
- {
1130
- int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1131
- T* pv = ((T*)p.x) + ix;
1132
- scalar_t v = (scalar_t)(*pv);
1133
-
1134
- // Gain, LReLU, clamp.
1135
- v *= p.gain;
1136
- if (v < 0.f)
1137
- {
1138
- v *= p.slope;
1139
- s = 1; // Sign.
1140
- }
1141
- if (fabsf(v) > p.clamp)
1142
- {
1143
- v = InternalType<T>::clamp(v, p.clamp);
1144
- s = 2; // Clamp.
1145
- }
1146
-
1147
- *pv = (T)v; // Write value.
1148
- }
1149
-
1150
- // Coalesce into threads 0 and 16 of warp.
1151
- uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
1152
- s <<= ((threadIdx.x & 15) << 1); // Shift into place.
1153
- s |= __shfl_xor_sync(m, s, 1); // Distribute.
1154
- s |= __shfl_xor_sync(m, s, 2);
1155
- s |= __shfl_xor_sync(m, s, 4);
1156
- s |= __shfl_xor_sync(m, s, 8);
1157
-
1158
- // Write signs if leader and in p.s.
1159
- if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
1160
- {
1161
- uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
1162
- ((uint32_t*)p.s)[is >> 4] = s;
1163
- }
1164
- }
1165
- else if (signRead)
1166
- {
1167
- // Process value if in p.x.
1168
- if (x < p.xShape.x) // y is always in.
1169
- {
1170
- int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1171
- T* pv = ((T*)p.x) + ix;
1172
- scalar_t v = (scalar_t)(*pv);
1173
- v *= p.gain;
1174
-
1175
- // Apply sign buffer offset.
1176
- uint32_t sx = x + p.sOfs.x;
1177
- uint32_t sy = y + p.sOfs.y;
1178
-
1179
- // Read and apply signs if we land inside valid region of sign buffer.
1180
- if (sx < p.sShape.x && sy < p.sShape.y)
1181
- {
1182
- uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous.
1183
- unsigned char s = p.s[is];
1184
- s >>= (sx & 3) << 1; // Shift into place.
1185
- if (s & 1) // Sign?
1186
- v *= p.slope;
1187
- if (s & 2) // Clamp?
1188
- v = 0.f;
1189
- }
1190
-
1191
- *pv = (T)v; // Write value.
1192
- }
1193
- }
1194
- else
1195
- {
1196
- // Forward pass with no sign write. Process value if in p.x.
1197
- if (x < p.xShape.x) // y is always in.
1198
- {
1199
- int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1200
- T* pv = ((T*)p.x) + ix;
1201
- scalar_t v = (scalar_t)(*pv);
1202
- v *= p.gain;
1203
- if (v < 0.f)
1204
- v *= p.slope;
1205
- if (fabsf(v) > p.clamp)
1206
- v = InternalType<T>::clamp(v, p.clamp);
1207
- *pv = (T)v; // Write value.
1208
- }
1209
- }
1210
- }
1211
- }
1212
-
1213
- template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void)
1214
- {
1215
- return (void*)filtered_lrelu_act_kernel<T, signWrite, signRead>;
1216
- }
1217
-
1218
- //------------------------------------------------------------------------
1219
- // CUDA kernel selection.
1220
-
1221
- template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB)
1222
- {
1223
- filtered_lrelu_kernel_spec s = { 0 };
1224
-
1225
- // Return the first matching kernel.
1226
- #define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \
1227
- if (sharedKB >= SH) \
1228
- if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \
1229
- if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \
1230
- if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \
1231
- { \
1232
- static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \
1233
- static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \
1234
- static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \
1235
- s.setup = (void*)setup_filters_kernel; \
1236
- s.exec = (void*)filtered_lrelu_kernel<T, index_t, SH, signWrite, signRead, MODE, U, FU, D, FD, TW, TH, W*32, !!XR, !!WS>; \
1237
- s.tileOut = make_int2(TW, TH); \
1238
- s.numWarps = W; \
1239
- s.xrep = XR; \
1240
- s.dynamicSharedKB = (SH == 48) ? 0 : SH; \
1241
- return s; \
1242
- }
1243
-
1244
- // Launch parameters for various kernel specializations.
1245
- // Small filters must be listed before large filters, otherwise the kernel for the larger filter will always match first.
1246
- // Kernels that use more shared memory must be listed before those that use less, for the same reason.
1247
-
1248
- CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1
1249
- CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1
1250
- CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2
1251
- CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2
1252
- CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2
1253
- CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2
1254
- CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2
1255
- CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2
1256
- CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4
1257
- CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4
1258
- CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1
1259
- CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2
1260
- CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2
1261
- CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2
1262
- CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2
1263
- CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2
1264
- CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2
1265
- CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4
1266
- CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB
1267
- CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4
1268
- CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1
1269
- CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2
1270
- CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2
1271
- CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2
1272
- CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2
1273
- CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2
1274
- CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2
1275
- CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB
1276
- CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4
1277
- CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB
1278
- CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4
1279
-
1280
- #undef CASE
1281
- return s; // No kernel found.
1282
- }
1283
-
1284
- //------------------------------------------------------------------------
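
The element-wise step shared by filtered_lrelu_kernel and filtered_lrelu_act_kernel above is: scale by the gain, apply the leaky ReLU, then clamp, while recording a 2-bit code per value for the backward pass (1 = input was negative and the slope was applied, 2 = the value was clamped, so its gradient is zero). A minimal PyTorch sketch of that step, for illustration only — the function name is hypothetical, and the codes are kept one per element here rather than bit-packed four per byte as in the deleted kernels:

import torch

def lrelu_clamp_with_signs(v, gain=1.0, slope=0.2, clamp=float('inf')):
    # Gain.
    v = v * gain
    # Leaky ReLU: multiply negative values by the slope.
    neg = v < 0
    v = torch.where(neg, v * slope, v)
    # Clamp to [-clamp, clamp].
    over = v.abs() > clamp
    v = v.clamp(-clamp, clamp)
    # 2-bit sign code per element; clamping overrides the negative code,
    # mirroring the sign handling in the kernels above.
    s = torch.zeros_like(v, dtype=torch.uint8)
    s[neg] = 1
    s[over] = 2
    return v, s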
 
torch_utils/ops/filtered_lrelu.h DELETED
@@ -1,90 +0,0 @@
1
- // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- //
3
- // NVIDIA CORPORATION and its licensors retain all intellectual property
4
- // and proprietary rights in and to this software, related documentation
5
- // and any modifications thereto. Any use, reproduction, disclosure or
6
- // distribution of this software and related documentation without an express
7
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- #include <cuda_runtime.h>
10
-
11
- //------------------------------------------------------------------------
12
- // CUDA kernel parameters.
13
-
14
- struct filtered_lrelu_kernel_params
15
- {
16
- // These parameters decide which kernel to use.
17
- int up; // upsampling ratio (1, 2, 4)
18
- int down; // downsampling ratio (1, 2, 4)
19
- int2 fuShape; // [size, 1] | [size, size]
20
- int2 fdShape; // [size, 1] | [size, size]
21
-
22
- int _dummy; // Alignment.
23
-
24
- // Rest of the parameters.
25
- const void* x; // Input tensor.
26
- void* y; // Output tensor.
27
- const void* b; // Bias tensor.
28
- unsigned char* s; // Sign tensor in/out. NULL if unused.
29
- const float* fu; // Upsampling filter.
30
- const float* fd; // Downsampling filter.
31
-
32
- int2 pad0; // Left/top padding.
33
- float gain; // Additional gain factor.
34
- float slope; // Leaky ReLU slope on negative side.
35
- float clamp; // Clamp after nonlinearity.
36
- int flip; // Filter kernel flip for gradient computation.
37
-
38
- int tilesXdim; // Original number of horizontal output tiles.
39
- int tilesXrep; // Number of horizontal tiles per CTA.
40
- int blockZofs; // Block z offset to support large minibatch, channel dimensions.
41
-
42
- int4 xShape; // [width, height, channel, batch]
43
- int4 yShape; // [width, height, channel, batch]
44
- int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
45
- int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
46
- int swLimit; // Active width of sign tensor in bytes.
47
-
48
- longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
49
- longlong4 yStride; //
50
- int64_t bStride; //
51
- longlong3 fuStride; //
52
- longlong3 fdStride; //
53
- };
54
-
55
- struct filtered_lrelu_act_kernel_params
56
- {
57
- void* x; // Input/output, modified in-place.
58
- unsigned char* s; // Sign tensor in/out. NULL if unused.
59
-
60
- float gain; // Additional gain factor.
61
- float slope; // Leaky ReLU slope on negative side.
62
- float clamp; // Clamp after nonlinearity.
63
-
64
- int4 xShape; // [width, height, channel, batch]
65
- longlong4 xStride; // Input/output tensor strides, same order as in shape.
66
- int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
67
- int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
68
- };
69
-
70
- //------------------------------------------------------------------------
71
- // CUDA kernel specialization.
72
-
73
- struct filtered_lrelu_kernel_spec
74
- {
75
- void* setup; // Function for filter kernel setup.
76
- void* exec; // Function for main operation.
77
- int2 tileOut; // Width/height of launch tile.
78
- int numWarps; // Number of warps per thread block, determines launch block size.
79
- int xrep; // For processing multiple horizontal tiles per thread block.
80
- int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
81
- };
82
-
83
- //------------------------------------------------------------------------
84
- // CUDA kernel selection.
85
-
86
- template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
87
- template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
88
- template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
89
-
90
- //------------------------------------------------------------------------
 
torch_utils/ops/filtered_lrelu.py DELETED
@@ -1,282 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- import os
10
- import numpy as np
11
- import torch
12
- import warnings
13
-
14
- from .. import custom_ops
15
- from .. import misc
16
- from . import upfirdn2d
17
- from . import bias_act
18
-
19
- #----------------------------------------------------------------------------
20
-
21
- _plugin = None
22
-
23
- def _init():
24
- global _plugin
25
- if _plugin is None:
26
-
27
- # sources=['filtered_lrelu.h', 'filtered_lrelu.cu', 'filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu']
28
- # sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
29
- # try:
30
- # _plugin = custom_ops.get_plugin('filtered_lrelu_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'])
31
- # except:
32
- # warnings.warn('Failed to build CUDA kernels for filtered_lrelu_plugin. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
33
-
34
- _plugin = custom_ops.get_plugin_v3(
35
- module_name='filtered_lrelu_plugin',
36
- sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
37
- headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
38
- source_dir=os.path.dirname(__file__),
39
- extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
40
- )
41
- return True
42
-
43
- def _get_filter_size(f):
44
- if f is None:
45
- return 1, 1
46
- assert isinstance(f, torch.Tensor)
47
- assert 1 <= f.ndim <= 2
48
- return f.shape[-1], f.shape[0] # width, height
49
-
50
- def _parse_padding(padding):
51
- if isinstance(padding, int):
52
- padding = [padding, padding]
53
- assert isinstance(padding, (list, tuple))
54
- assert all(isinstance(x, (int, np.integer)) for x in padding)
55
- padding = [int(x) for x in padding]
56
- if len(padding) == 2:
57
- px, py = padding
58
- padding = [px, px, py, py]
59
- px0, px1, py0, py1 = padding
60
- return px0, px1, py0, py1
61
-
62
- #----------------------------------------------------------------------------
63
-
64
- def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
65
- r"""Filtered leaky ReLU for a batch of 2D images.
66
-
67
- Performs the following sequence of operations for each channel:
68
-
69
- 1. Add channel-specific bias if provided (`b`).
70
-
71
- 2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
72
-
73
- 3. Pad the image with the specified number of zeros on each side (`padding`).
74
- Negative padding corresponds to cropping the image.
75
-
76
- 4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
77
- so that the footprint of all output pixels lies within the input image.
78
-
79
- 5. Multiply each value by the provided gain factor (`gain`).
80
-
81
- 6. Apply leaky ReLU activation function to each value.
82
-
83
- 7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.
84
-
85
- 8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
86
- it so that the footprint of all output pixels lies within the input image.
87
-
88
- 9. Downsample the image by keeping every Nth pixel (`down`).
89
-
90
- The fused op is considerably more efficient than performing the same calculation
91
- using standard PyTorch ops. It supports gradients of arbitrary order.
92
-
93
- Args:
94
- x: Float32/float16/float64 input tensor of the shape
95
- `[batch_size, num_channels, in_height, in_width]`.
96
- fu: Float32 upsampling FIR filter of the shape
97
- `[filter_height, filter_width]` (non-separable),
98
- `[filter_taps]` (separable), or
99
- `None` (identity).
100
- fd: Float32 downsampling FIR filter of the shape
101
- `[filter_height, filter_width]` (non-separable),
102
- `[filter_taps]` (separable), or
103
- `None` (identity).
104
- b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
105
- as `x`. The length of the vector must match the channel dimension of `x`.
106
- up: Integer upsampling factor (default: 1).
107
- down: Integer downsampling factor. (default: 1).
108
- padding: Padding with respect to the upsampled image. Can be a single number
109
- or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
110
- (default: 0).
111
- gain: Overall scaling factor for signal magnitude (default: sqrt(2)).
112
- slope: Slope on the negative side of leaky ReLU (default: 0.2).
113
- clamp: Maximum magnitude for leaky ReLU output (default: None).
114
- flip_filter: False = convolution, True = correlation (default: False).
115
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
116
-
117
- Returns:
118
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
119
- """
120
- assert isinstance(x, torch.Tensor)
121
- assert impl in ['ref', 'cuda']
122
- if impl == 'cuda' and x.device.type == 'cuda' and _init():
123
- return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
124
- return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
125
-
126
- #----------------------------------------------------------------------------
127
-
128
- @misc.profiled_function
129
- def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
130
- """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
131
- existing `upfirdn2d()` and `bias_act()` ops.
132
- """
133
- assert isinstance(x, torch.Tensor) and x.ndim == 4
134
- fu_w, fu_h = _get_filter_size(fu)
135
- fd_w, fd_h = _get_filter_size(fd)
136
- if b is not None:
137
- assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
138
- misc.assert_shape(b, [x.shape[1]])
139
- assert isinstance(up, int) and up >= 1
140
- assert isinstance(down, int) and down >= 1
141
- px0, px1, py0, py1 = _parse_padding(padding)
142
- assert gain == float(gain) and gain > 0
143
- assert slope == float(slope) and slope >= 0
144
- assert clamp is None or (clamp == float(clamp) and clamp >= 0)
145
-
146
- # Calculate output size.
147
- batch_size, channels, in_h, in_w = x.shape
148
- in_dtype = x.dtype
149
- out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
150
- out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down
151
-
152
- # Compute using existing ops.
153
- x = bias_act.bias_act(x=x, b=b) # Apply bias.
154
- x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
155
- x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
156
- x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.
157
-
158
- # Check output shape & dtype.
159
- misc.assert_shape(x, [batch_size, channels, out_h, out_w])
160
- assert x.dtype == in_dtype
161
- return x
162
-
163
- #----------------------------------------------------------------------------
164
-
165
- _filtered_lrelu_cuda_cache = dict()
166
-
167
- def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
168
- """Fast CUDA implementation of `filtered_lrelu()` using custom ops.
169
- """
170
- assert isinstance(up, int) and up >= 1
171
- assert isinstance(down, int) and down >= 1
172
- px0, px1, py0, py1 = _parse_padding(padding)
173
- assert gain == float(gain) and gain > 0
174
- gain = float(gain)
175
- assert slope == float(slope) and slope >= 0
176
- slope = float(slope)
177
- assert clamp is None or (clamp == float(clamp) and clamp >= 0)
178
- clamp = float(clamp if clamp is not None else 'inf')
179
-
180
- # Lookup from cache.
181
- key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
182
- if key in _filtered_lrelu_cuda_cache:
183
- return _filtered_lrelu_cuda_cache[key]
184
-
185
- # Forward op.
186
- class FilteredLReluCuda(torch.autograd.Function):
187
- @staticmethod
188
- def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
189
- assert isinstance(x, torch.Tensor) and x.ndim == 4
190
-
191
- # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
192
- if fu is None:
193
- fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
194
- if fd is None:
195
- fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
196
- assert 1 <= fu.ndim <= 2
197
- assert 1 <= fd.ndim <= 2
198
-
199
- # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
200
- if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
201
- fu = fu.square()[None]
202
- if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
203
- fd = fd.square()[None]
204
-
205
- # Missing sign input tensor.
206
- if si is None:
207
- si = torch.empty([0])
208
-
209
- # Missing bias tensor.
210
- if b is None:
211
- b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)
212
-
213
- # Construct internal sign tensor only if gradients are needed.
214
- write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)
215
-
216
- # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
217
- strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
218
- if any(a < b for a, b in zip(strides[:-1], strides[1:])):
219
- warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)
220
-
221
- # Call C++/Cuda plugin if datatype is supported.
222
- if x.dtype in [torch.float16, torch.float32]:
223
- if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
224
- warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
225
- y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
226
- else:
227
- return_code = -1
228
-
229
- # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
230
- # only the bit-packed sign tensor is retained for gradient computation.
231
- if return_code < 0:
232
- warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
233
-
234
- y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
235
- y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
236
- so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
237
- y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.
238
-
239
- # Prepare for gradient computation.
240
- ctx.save_for_backward(fu, fd, (si if si.numel() else so))
241
- ctx.x_shape = x.shape
242
- ctx.y_shape = y.shape
243
- ctx.s_ofs = sx, sy
244
- return y
245
-
246
- @staticmethod
247
- def backward(ctx, dy): # pylint: disable=arguments-differ
248
- fu, fd, si = ctx.saved_tensors
249
- _, _, xh, xw = ctx.x_shape
250
- _, _, yh, yw = ctx.y_shape
251
- sx, sy = ctx.s_ofs
252
- dx = None # 0
253
- dfu = None; assert not ctx.needs_input_grad[1]
254
- dfd = None; assert not ctx.needs_input_grad[2]
255
- db = None # 3
256
- dsi = None; assert not ctx.needs_input_grad[4]
257
- dsx = None; assert not ctx.needs_input_grad[5]
258
- dsy = None; assert not ctx.needs_input_grad[6]
259
-
260
- if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
261
- pp = [
262
- (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
263
- xw * up - yw * down + px0 - (up - 1),
264
- (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
265
- xh * up - yh * down + py0 - (up - 1),
266
- ]
267
- gg = gain * (up ** 2) / (down ** 2)
268
- ff = (not flip_filter)
269
- sx = sx - (fu.shape[-1] - 1) + px0
270
- sy = sy - (fu.shape[0] - 1) + py0
271
- dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)
272
-
273
- if ctx.needs_input_grad[3]:
274
- db = dx.sum([0, 2, 3])
275
-
276
- return dx, dfu, dfd, db, dsi, dsx, dsy
277
-
278
- # Add to cache.
279
- _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
280
- return FilteredLReluCuda
281
-
282
- #----------------------------------------------------------------------------
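
The filtered_lrelu() docstring above spells out the fused pipeline (bias, zero-insertion upsampling, padding, upsampling FIR, gain, leaky ReLU, clamp, downsampling FIR, downsampling), and _filtered_lrelu_ref shows the same composition built from upfirdn2d() and bias_act(). A short usage sketch, assuming a checkout in which torch_utils/ops/filtered_lrelu.py is still present; the filter taps and shapes are illustrative:

import numpy as np
import torch
from torch_utils.ops import filtered_lrelu

x = torch.randn(4, 8, 32, 32)            # [batch, channels, height, width]
b = torch.zeros(8)                       # per-channel bias, same dtype as x
f = torch.tensor([1., 3., 3., 1.]) / 8.  # separable 4-tap FIR filter
y = filtered_lrelu.filtered_lrelu(
    x, fu=f, fd=f, b=b, up=2, down=2, padding=1,
    gain=np.sqrt(2), slope=0.2, clamp=256, impl='ref')  # 'ref' skips the CUDA plugin
# Output size follows the formula in _filtered_lrelu_ref:
# (32*2 + 1+1 - (4-1) - (4-1) + (2-1)) // 2 = 30, so y is [4, 8, 30, 30].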
 
torch_utils/ops/filtered_lrelu_ns.cu DELETED
@@ -1,27 +0,0 @@
1
- // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- //
3
- // NVIDIA CORPORATION and its licensors retain all intellectual property
4
- // and proprietary rights in and to this software, related documentation
5
- // and any modifications thereto. Any use, reproduction, disclosure or
6
- // distribution of this software and related documentation without an express
7
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- #include "filtered_lrelu.cu"
10
-
11
- // Template/kernel specializations for no signs mode (no gradients required).
12
-
13
- // Full op, 32-bit indexing.
14
- template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
- template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
-
17
- // Full op, 64-bit indexing.
18
- template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
- template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
-
21
- // Activation/signs only for generic variant. 64-bit indexing.
22
- template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
23
- template void* choose_filtered_lrelu_act_kernel<float, false, false>(void);
24
- template void* choose_filtered_lrelu_act_kernel<double, false, false>(void);
25
-
26
- // Copy filters to constant memory.
27
- template cudaError_t copy_filters<false, false>(cudaStream_t stream);
 
 
torch_utils/ops/filtered_lrelu_rd.cu DELETED
@@ -1,27 +0,0 @@
1
- // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- //
3
- // NVIDIA CORPORATION and its licensors retain all intellectual property
4
- // and proprietary rights in and to this software, related documentation
5
- // and any modifications thereto. Any use, reproduction, disclosure or
6
- // distribution of this software and related documentation without an express
7
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- #include "filtered_lrelu.cu"
10
-
11
- // Template/kernel specializations for sign read mode.
12
-
13
- // Full op, 32-bit indexing.
14
- template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
- template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
-
17
- // Full op, 64-bit indexing.
18
- template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
- template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
-
21
- // Activation/signs only for generic variant. 64-bit indexing.
22
- template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
23
- template void* choose_filtered_lrelu_act_kernel<float, false, true>(void);
24
- template void* choose_filtered_lrelu_act_kernel<double, false, true>(void);
25
-
26
- // Copy filters to constant memory.
27
- template cudaError_t copy_filters<false, true>(cudaStream_t stream);
 
 
torch_utils/ops/filtered_lrelu_wr.cu DELETED
@@ -1,27 +0,0 @@
1
- // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- //
3
- // NVIDIA CORPORATION and its licensors retain all intellectual property
4
- // and proprietary rights in and to this software, related documentation
5
- // and any modifications thereto. Any use, reproduction, disclosure or
6
- // distribution of this software and related documentation without an express
7
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- #include "filtered_lrelu.cu"
10
-
11
- // Template/kernel specializations for sign write mode.
12
-
13
- // Full op, 32-bit indexing.
14
- template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
- template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
-
17
- // Full op, 64-bit indexing.
18
- template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
- template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
-
21
- // Activation/signs only for generic variant. 64-bit indexing.
22
- template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
23
- template void* choose_filtered_lrelu_act_kernel<float, true, false>(void);
24
- template void* choose_filtered_lrelu_act_kernel<double, true, false>(void);
25
-
26
- // Copy filters to constant memory.
27
- template cudaError_t copy_filters<true, false>(cudaStream_t stream);
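
Taken together, the three thin specialization files above (_ns, _rd, _wr) differ only in the two boolean template arguments. Assuming the parameter order is (signWrite, signRead), which is consistent with the per-file comments, the mapping can be summarized as follows; the dict is for reference only and does not appear in the deleted sources.

    # Boolean template arguments instantiated by each specialization file
    # (assumed order: signWrite, signRead).
    SIGN_MODES = {
        'filtered_lrelu_ns.cu': (False, False),  # no signs: neither written nor read (no gradients required)
        'filtered_lrelu_wr.cu': (True,  False),  # sign write mode: signs are retained for later gradient use
        'filtered_lrelu_rd.cu': (False, True),   # sign read mode: previously written signs are consumed,
                                                 # e.g. in the backward pass shown earlier
    }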
 
 
torch_utils/ops/fma.py DELETED
@@ -1,62 +0,0 @@
1
- # Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # NVIDIA CORPORATION and its licensors retain all intellectual property
6
- # and proprietary rights in and to this software, related documentation
7
- # and any modifications thereto. Any use, reproduction, disclosure or
8
- # distribution of this software and related documentation without an express
9
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
-
11
- """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
12
-
13
- import torch
14
-
15
- #----------------------------------------------------------------------------
16
-
17
- def fma(a, b, c): # => a * b + c
18
- return _FusedMultiplyAdd.apply(a, b, c)
19
-
20
- #----------------------------------------------------------------------------
21
-
22
- class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
23
- @staticmethod
24
- def forward(ctx, a, b, c): # pylint: disable=arguments-differ
25
- out = torch.addcmul(c, a, b)
26
- ctx.save_for_backward(a, b)
27
- ctx.c_shape = c.shape
28
- return out
29
-
30
- @staticmethod
31
- def backward(ctx, dout): # pylint: disable=arguments-differ
32
- a, b = ctx.saved_tensors
33
- c_shape = ctx.c_shape
34
- da = None
35
- db = None
36
- dc = None
37
-
38
- if ctx.needs_input_grad[0]:
39
- da = _unbroadcast(dout * b, a.shape)
40
-
41
- if ctx.needs_input_grad[1]:
42
- db = _unbroadcast(dout * a, b.shape)
43
-
44
- if ctx.needs_input_grad[2]:
45
- dc = _unbroadcast(dout, c_shape)
46
-
47
- return da, db, dc
48
-
49
- #----------------------------------------------------------------------------
50
-
51
- def _unbroadcast(x, shape):
52
- extra_dims = x.ndim - len(shape)
53
- assert extra_dims >= 0
54
- dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
55
- if len(dim):
56
- x = x.sum(dim=dim, keepdim=True)
57
- if extra_dims:
58
- x = x.reshape(-1, *x.shape[extra_dims+1:])
59
- assert x.shape == shape
60
- return x
61
-
62
- #----------------------------------------------------------------------------
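
A small sanity-check sketch for the deleted fma helper, assuming the torch_utils package layout of this commit. It only verifies the stated contract a * b + c and that broadcasting of c is undone in the backward pass.

    import torch
    from torch_utils.ops.fma import fma  # deleted module shown above (assumed import path)

    a = torch.randn(4, 3, requires_grad=True)
    b = torch.randn(4, 3, requires_grad=True)
    c = torch.randn(3, requires_grad=True)          # broadcast along the batch dimension

    out = fma(a, b, c)
    assert torch.allclose(out, a * b + c)

    out.sum().backward()
    assert c.grad.shape == c.shape                  # _unbroadcast() reduces dc back to c's shape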
 
 
torch_utils/ops/grid_sample_gradfix.py DELETED
@@ -1,85 +0,0 @@
1
- # Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # NVIDIA CORPORATION and its licensors retain all intellectual property
6
- # and proprietary rights in and to this software, related documentation
7
- # and any modifications thereto. Any use, reproduction, disclosure or
8
- # distribution of this software and related documentation without an express
9
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
-
11
- """Custom replacement for `torch.nn.functional.grid_sample` that
12
- supports arbitrarily high order gradients between the input and output.
13
- Only works on 2D images and assumes
14
- `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
15
-
16
- import warnings
17
- import torch
18
-
19
- # pylint: disable=redefined-builtin
20
- # pylint: disable=arguments-differ
21
- # pylint: disable=protected-access
22
-
23
- #----------------------------------------------------------------------------
24
-
25
- enabled = False # Enable the custom op by setting this to true.
26
-
27
- #----------------------------------------------------------------------------
28
-
29
- def grid_sample(input, grid):
30
- if _should_use_custom_op():
31
- return _GridSample2dForward.apply(input, grid)
32
- return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
33
-
34
- #----------------------------------------------------------------------------
35
-
36
- def _should_use_custom_op():
37
- if not enabled:
38
- return False
39
- if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
40
- return True
41
- warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
42
- return False
43
-
44
- #----------------------------------------------------------------------------
45
-
46
- class _GridSample2dForward(torch.autograd.Function):
47
- @staticmethod
48
- def forward(ctx, input, grid):
49
- assert input.ndim == 4
50
- assert grid.ndim == 4
51
- output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
52
- ctx.save_for_backward(input, grid)
53
- return output
54
-
55
- @staticmethod
56
- def backward(ctx, grad_output):
57
- input, grid = ctx.saved_tensors
58
- grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
59
- return grad_input, grad_grid
60
-
61
- #----------------------------------------------------------------------------
62
-
63
- class _GridSample2dBackward(torch.autograd.Function):
64
- @staticmethod
65
- def forward(ctx, grad_output, input, grid):
66
- op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
67
- grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
68
- ctx.save_for_backward(grid)
69
- return grad_input, grad_grid
70
-
71
- @staticmethod
72
- def backward(ctx, grad2_grad_input, grad2_grad_grid):
73
- _ = grad2_grad_grid # unused
74
- grid, = ctx.saved_tensors
75
- grad2_grad_output = None
76
- grad2_input = None
77
- grad2_grid = None
78
-
79
- if ctx.needs_input_grad[0]:
80
- grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
81
-
82
- assert not ctx.needs_input_grad[2]
83
- return grad2_grad_output, grad2_input, grad2_grid
84
-
85
- #----------------------------------------------------------------------------
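
A hedged usage sketch of the deleted grid_sample_gradfix wrapper: the custom op must be switched on via the module-level `enabled` flag and, per `_should_use_custom_op()`, only kicks in on PyTorch 1.7-1.9 (otherwise it warns and falls back to `torch.nn.functional.grid_sample`). The tensors and the stand-in conv weight below are illustrative only.

    import torch
    from torch_utils.ops import grid_sample_gradfix  # deleted module shown above (assumed import path)

    grid_sample_gradfix.enabled = True               # off by default

    x    = torch.randn(1, 3, 8, 8, requires_grad=True)
    grid = torch.rand(1, 8, 8, 2) * 2 - 1            # normalized sampling coords in [-1, 1]
    w    = torch.randn(1, 3, 3, 3, requires_grad=True)   # stand-in "network" weight

    y     = grid_sample_gradfix.grid_sample(x, grid)
    score = torch.nn.functional.conv2d(y, w).sum()

    # The point of the custom op: gradient penalties need grad-of-grad through grid_sample.
    (g,) = torch.autograd.grad(score, x, create_graph=True)
    g.square().sum().backward()                      # reaches w through the double backward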
 
 
torch_utils/ops/upfirdn2d.cpp DELETED
@@ -1,105 +0,0 @@
1
- // Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
- //
5
- // NVIDIA CORPORATION and its licensors retain all intellectual property
6
- // and proprietary rights in and to this software, related documentation
7
- // and any modifications thereto. Any use, reproduction, disclosure or
8
- // distribution of this software and related documentation without an express
9
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
10
-
11
- #include <torch/extension.h>
12
- #include <ATen/cuda/CUDAContext.h>
13
- #include <c10/cuda/CUDAGuard.h>
14
- #include "upfirdn2d.h"
15
-
16
- //------------------------------------------------------------------------
17
-
18
- static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
19
- {
20
- // Validate arguments.
21
- TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
22
- TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
23
- TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
24
- TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
25
- TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
26
- TORCH_CHECK(x.dim() == 4, "x must be rank 4");
27
- TORCH_CHECK(f.dim() == 2, "f must be rank 2");
28
- TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
29
- TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
30
- TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
31
-
32
- // Create output tensor.
33
- const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
34
- int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
35
- int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
36
- TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
37
- torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
38
- TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
39
-
40
- // Initialize CUDA kernel parameters.
41
- upfirdn2d_kernel_params p;
42
- p.x = x.data_ptr();
43
- p.f = f.data_ptr<float>();
44
- p.y = y.data_ptr();
45
- p.up = make_int2(upx, upy);
46
- p.down = make_int2(downx, downy);
47
- p.pad0 = make_int2(padx0, pady0);
48
- p.flip = (flip) ? 1 : 0;
49
- p.gain = gain;
50
- p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
51
- p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
52
- p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
53
- p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
54
- p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
55
- p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
56
- p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
57
- p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
58
-
59
- // Choose CUDA kernel.
60
- upfirdn2d_kernel_spec spec;
61
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
62
- {
63
- spec = choose_upfirdn2d_kernel<scalar_t>(p);
64
- });
65
-
66
- // Set looping options.
67
- p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
68
- p.loopMinor = spec.loopMinor;
69
- p.loopX = spec.loopX;
70
- p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
71
- p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
72
-
73
- // Compute grid size.
74
- dim3 blockSize, gridSize;
75
- if (spec.tileOutW < 0) // large
76
- {
77
- blockSize = dim3(4, 32, 1);
78
- gridSize = dim3(
79
- ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
80
- (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
81
- p.launchMajor);
82
- }
83
- else // small
84
- {
85
- blockSize = dim3(256, 1, 1);
86
- gridSize = dim3(
87
- ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
88
- (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
89
- p.launchMajor);
90
- }
91
-
92
- // Launch CUDA kernel.
93
- void* args[] = {&p};
94
- AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
95
- return y;
96
- }
97
-
98
- //------------------------------------------------------------------------
99
-
100
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
101
- {
102
- m.def("upfirdn2d", &upfirdn2d);
103
- }
104
-
105
- //------------------------------------------------------------------------
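
The output resolution computed by the binding above follows a simple closed form. Here is a small worked check in Python; the example numbers are arbitrary and mirror the padding that `upsample2d()` in upfirdn2d.py would pass for an 8-tap filter at 2x.

    def upfirdn2d_out_size(in_w, in_h, up, down, pad, fw, fh):
        # Mirrors the outW/outH computation in the C++ binding above:
        #   out = (in * up + pad_before + pad_after - filter_size + down) // down
        upx, upy = up
        downx, downy = down
        padx0, padx1, pady0, pady1 = pad
        out_w = (in_w * upx + padx0 + padx1 - fw + downx) // downx
        out_h = (in_h * upy + pady0 + pady1 - fh + downy) // downy
        return out_w, out_h

    # 2x upsampling of a 64x64 image with an 8-tap filter and the padding used by upsample2d():
    print(upfirdn2d_out_size(64, 64, up=(2, 2), down=(1, 1), pad=(4, 3, 4, 3), fw=8, fh=8))  # -> (128, 128)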
 
 
torch_utils/ops/upfirdn2d.cu DELETED
@@ -1,352 +0,0 @@
1
- // Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
- //
5
- // NVIDIA CORPORATION and its licensors retain all intellectual property
6
- // and proprietary rights in and to this software, related documentation
7
- // and any modifications thereto. Any use, reproduction, disclosure or
8
- // distribution of this software and related documentation without an express
9
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
10
-
11
- #include <c10/util/Half.h>
12
- #include "upfirdn2d.h"
13
-
14
- //------------------------------------------------------------------------
15
- // Helpers.
16
-
17
- template <class T> struct InternalType;
18
- template <> struct InternalType<double> { typedef double scalar_t; };
19
- template <> struct InternalType<float> { typedef float scalar_t; };
20
- template <> struct InternalType<c10::Half> { typedef float scalar_t; };
21
-
22
- static __device__ __forceinline__ int floor_div(int a, int b)
23
- {
24
- int t = 1 - a / b;
25
- return (a + t * b) / b - t;
26
- }
27
-
28
- //------------------------------------------------------------------------
29
- // Generic CUDA implementation for large filters.
30
-
31
- template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
32
- {
33
- typedef typename InternalType<T>::scalar_t scalar_t;
34
-
35
- // Calculate thread index.
36
- int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
37
- int outY = minorBase / p.launchMinor;
38
- minorBase -= outY * p.launchMinor;
39
- int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
40
- int majorBase = blockIdx.z * p.loopMajor;
41
- if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
42
- return;
43
-
44
- // Setup Y receptive field.
45
- int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
46
- int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
47
- int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
48
- int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
49
- if (p.flip)
50
- filterY = p.filterSize.y - 1 - filterY;
51
-
52
- // Loop over major, minor, and X.
53
- for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
54
- for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
55
- {
56
- int nc = major * p.sizeMinor + minor;
57
- int n = nc / p.inSize.z;
58
- int c = nc - n * p.inSize.z;
59
- for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
60
- {
61
- // Setup X receptive field.
62
- int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
63
- int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
64
- int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
65
- int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
66
- if (p.flip)
67
- filterX = p.filterSize.x - 1 - filterX;
68
-
69
- // Initialize pointers.
70
- const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
71
- const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
72
- int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
73
- int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
74
-
75
- // Inner loop.
76
- scalar_t v = 0;
77
- for (int y = 0; y < h; y++)
78
- {
79
- for (int x = 0; x < w; x++)
80
- {
81
- v += (scalar_t)(*xp) * (scalar_t)(*fp);
82
- xp += p.inStride.x;
83
- fp += filterStepX;
84
- }
85
- xp += p.inStride.y - w * p.inStride.x;
86
- fp += filterStepY - w * filterStepX;
87
- }
88
-
89
- // Store result.
90
- v *= p.gain;
91
- ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
92
- }
93
- }
94
- }
95
-
96
- //------------------------------------------------------------------------
97
- // Specialized CUDA implementation for small filters.
98
-
99
- template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
100
- static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
101
- {
102
- typedef typename InternalType<T>::scalar_t scalar_t;
103
- const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
104
- const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
105
- __shared__ volatile scalar_t sf[filterH][filterW];
106
- __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
107
-
108
- // Calculate tile index.
109
- int minorBase = blockIdx.x;
110
- int tileOutY = minorBase / p.launchMinor;
111
- minorBase -= tileOutY * p.launchMinor;
112
- minorBase *= loopMinor;
113
- tileOutY *= tileOutH;
114
- int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
115
- int majorBase = blockIdx.z * p.loopMajor;
116
- if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
117
- return;
118
-
119
- // Load filter (flipped).
120
- for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
121
- {
122
- int fy = tapIdx / filterW;
123
- int fx = tapIdx - fy * filterW;
124
- scalar_t v = 0;
125
- if (fx < p.filterSize.x & fy < p.filterSize.y)
126
- {
127
- int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
128
- int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
129
- v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
130
- }
131
- sf[fy][fx] = v;
132
- }
133
-
134
- // Loop over major and X.
135
- for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
136
- {
137
- int baseNC = major * p.sizeMinor + minorBase;
138
- int n = baseNC / p.inSize.z;
139
- int baseC = baseNC - n * p.inSize.z;
140
- for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
141
- {
142
- // Load input pixels.
143
- int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
144
- int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
145
- int tileInX = floor_div(tileMidX, upx);
146
- int tileInY = floor_div(tileMidY, upy);
147
- __syncthreads();
148
- for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
149
- {
150
- int relC = inIdx;
151
- int relInX = relC / loopMinor;
152
- int relInY = relInX / tileInW;
153
- relC -= relInX * loopMinor;
154
- relInX -= relInY * tileInW;
155
- int c = baseC + relC;
156
- int inX = tileInX + relInX;
157
- int inY = tileInY + relInY;
158
- scalar_t v = 0;
159
- if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
160
- v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
161
- sx[relInY][relInX][relC] = v;
162
- }
163
-
164
- // Loop over output pixels.
165
- __syncthreads();
166
- for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
167
- {
168
- int relC = outIdx;
169
- int relOutX = relC / loopMinor;
170
- int relOutY = relOutX / tileOutW;
171
- relC -= relOutX * loopMinor;
172
- relOutX -= relOutY * tileOutW;
173
- int c = baseC + relC;
174
- int outX = tileOutX + relOutX;
175
- int outY = tileOutY + relOutY;
176
-
177
- // Setup receptive field.
178
- int midX = tileMidX + relOutX * downx;
179
- int midY = tileMidY + relOutY * downy;
180
- int inX = floor_div(midX, upx);
181
- int inY = floor_div(midY, upy);
182
- int relInX = inX - tileInX;
183
- int relInY = inY - tileInY;
184
- int filterX = (inX + 1) * upx - midX - 1; // flipped
185
- int filterY = (inY + 1) * upy - midY - 1; // flipped
186
-
187
- // Inner loop.
188
- if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
189
- {
190
- scalar_t v = 0;
191
- #pragma unroll
192
- for (int y = 0; y < filterH / upy; y++)
193
- #pragma unroll
194
- for (int x = 0; x < filterW / upx; x++)
195
- v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
196
- v *= p.gain;
197
- ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
198
- }
199
- }
200
- }
201
- }
202
- }
203
-
204
- //------------------------------------------------------------------------
205
- // CUDA kernel selection.
206
-
207
- template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
208
- {
209
- int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
210
-
211
- upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
212
- if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
213
-
214
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
215
- {
216
- if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
217
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
218
- if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
219
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
220
- if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
221
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
222
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
223
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
224
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
225
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
226
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
227
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
228
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
229
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
230
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
231
- }
232
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
233
- {
234
- if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
235
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
236
- if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
237
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
238
- if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
239
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
240
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
241
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
242
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
243
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
244
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
245
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
246
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
247
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
248
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
249
- }
250
- if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
251
- {
252
- if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
253
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
254
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
255
- if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
256
- }
257
- if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
258
- {
259
- if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
260
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
261
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
262
- if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
263
- }
264
- if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
265
- {
266
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
267
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
268
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
269
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
270
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
271
- }
272
- if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
273
- {
274
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
275
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
276
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
277
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
278
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
279
- }
280
- if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
281
- {
282
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
283
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
284
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
285
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
286
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
287
- }
288
- if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
289
- {
290
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
291
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
292
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
293
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
294
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
295
- }
296
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
297
- {
298
- if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
299
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
300
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
301
- if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
302
- }
303
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
304
- {
305
- if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
306
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
307
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
308
- if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
309
- }
310
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
311
- {
312
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
313
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
314
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
315
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
316
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
317
- }
318
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
319
- {
320
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
321
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
322
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
323
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
324
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
325
- }
326
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
327
- {
328
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
329
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
330
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
331
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
332
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
333
- }
334
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
335
- {
336
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
337
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
338
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
339
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
340
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
341
- }
342
- return spec;
343
- }
344
-
345
- //------------------------------------------------------------------------
346
- // Template specializations.
347
-
348
- template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
349
- template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
350
- template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
351
-
352
- //------------------------------------------------------------------------
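
The kernel selector above branches on s = p.inStride.z, i.e. the channel stride of the input tensor, to distinguish NCHW-contiguous from channels_last layouts. A quick illustration of why stride(1) == 1 identifies channels_last:

    import torch

    x = torch.randn(2, 3, 16, 16)
    print(x.stride(1))                                            # 256 (= 16*16) for NCHW-contiguous
    print(x.to(memory_format=torch.channels_last).stride(1))      # 1 -> the "channels_last" branches above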
 
 
torch_utils/ops/upfirdn2d.h DELETED
@@ -1,61 +0,0 @@
1
- // Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
- //
5
- // NVIDIA CORPORATION and its licensors retain all intellectual property
6
- // and proprietary rights in and to this software, related documentation
7
- // and any modifications thereto. Any use, reproduction, disclosure or
8
- // distribution of this software and related documentation without an express
9
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
10
-
11
- #include <cuda_runtime.h>
12
-
13
- //------------------------------------------------------------------------
14
- // CUDA kernel parameters.
15
-
16
- struct upfirdn2d_kernel_params
17
- {
18
- const void* x;
19
- const float* f;
20
- void* y;
21
-
22
- int2 up;
23
- int2 down;
24
- int2 pad0;
25
- int flip;
26
- float gain;
27
-
28
- int4 inSize; // [width, height, channel, batch]
29
- int4 inStride;
30
- int2 filterSize; // [width, height]
31
- int2 filterStride;
32
- int4 outSize; // [width, height, channel, batch]
33
- int4 outStride;
34
- int sizeMinor;
35
- int sizeMajor;
36
-
37
- int loopMinor;
38
- int loopMajor;
39
- int loopX;
40
- int launchMinor;
41
- int launchMajor;
42
- };
43
-
44
- //------------------------------------------------------------------------
45
- // CUDA kernel specialization.
46
-
47
- struct upfirdn2d_kernel_spec
48
- {
49
- void* kernel;
50
- int tileOutW;
51
- int tileOutH;
52
- int loopMinor;
53
- int loopX;
54
- };
55
-
56
- //------------------------------------------------------------------------
57
- // CUDA kernel selection.
58
-
59
- template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
60
-
61
- //------------------------------------------------------------------------
 
 
torch_utils/ops/upfirdn2d.py DELETED
@@ -1,386 +0,0 @@
1
- # Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # NVIDIA CORPORATION and its licensors retain all intellectual property
6
- # and proprietary rights in and to this software, related documentation
7
- # and any modifications thereto. Any use, reproduction, disclosure or
8
- # distribution of this software and related documentation without an express
9
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
-
11
- """Custom PyTorch ops for efficient resampling of 2D images."""
12
-
13
- import os
14
- import warnings
15
- import numpy as np
16
- import torch
17
- import traceback
18
-
19
- from .. import custom_ops
20
- from .. import misc
21
- from . import conv2d_gradfix
22
-
23
- #----------------------------------------------------------------------------
24
-
25
- _inited = False
26
- _plugin = None
27
-
28
- def _init():
29
- global _inited, _plugin
30
- if not _inited:
31
- sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
32
- sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
33
- try:
34
- _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
35
- except:
36
- warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
37
- return _plugin is not None
38
-
39
- def _parse_scaling(scaling):
40
- if isinstance(scaling, int):
41
- scaling = [scaling, scaling]
42
- assert isinstance(scaling, (list, tuple))
43
- assert all(isinstance(x, int) for x in scaling)
44
- sx, sy = scaling
45
- assert sx >= 1 and sy >= 1
46
- return sx, sy
47
-
48
- def _parse_padding(padding):
49
- if isinstance(padding, int):
50
- padding = [padding, padding]
51
- assert isinstance(padding, (list, tuple))
52
- assert all(isinstance(x, int) for x in padding)
53
- if len(padding) == 2:
54
- padx, pady = padding
55
- padding = [padx, padx, pady, pady]
56
- padx0, padx1, pady0, pady1 = padding
57
- return padx0, padx1, pady0, pady1
58
-
59
- def _get_filter_size(f):
60
- if f is None:
61
- return 1, 1
62
- assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
63
- fw = f.shape[-1]
64
- fh = f.shape[0]
65
- with misc.suppress_tracer_warnings():
66
- fw = int(fw)
67
- fh = int(fh)
68
- misc.assert_shape(f, [fh, fw][:f.ndim])
69
- assert fw >= 1 and fh >= 1
70
- return fw, fh
71
-
72
- #----------------------------------------------------------------------------
73
-
74
- def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
75
- r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
76
-
77
- Args:
78
- f: Torch tensor, numpy array, or python list of the shape
79
- `[filter_height, filter_width]` (non-separable),
80
- `[filter_taps]` (separable),
81
- `[]` (impulse), or
82
- `None` (identity).
83
- device: Result device (default: cpu).
84
- normalize: Normalize the filter so that it retains the magnitude
85
- for constant input signal (DC)? (default: True).
86
- flip_filter: Flip the filter? (default: False).
87
- gain: Overall scaling factor for signal magnitude (default: 1).
88
- separable: Return a separable filter? (default: select automatically).
89
-
90
- Returns:
91
- Float32 tensor of the shape
92
- `[filter_height, filter_width]` (non-separable) or
93
- `[filter_taps]` (separable).
94
- """
95
- # Validate.
96
- if f is None:
97
- f = 1
98
- f = torch.as_tensor(f, dtype=torch.float32)
99
- assert f.ndim in [0, 1, 2]
100
- assert f.numel() > 0
101
- if f.ndim == 0:
102
- f = f[np.newaxis]
103
-
104
- # Separable?
105
- if separable is None:
106
- separable = (f.ndim == 1 and f.numel() >= 8)
107
- if f.ndim == 1 and not separable:
108
- f = f.ger(f)
109
- assert f.ndim == (1 if separable else 2)
110
-
111
- # Apply normalize, flip, gain, and device.
112
- if normalize:
113
- f /= f.sum()
114
- if flip_filter:
115
- f = f.flip(list(range(f.ndim)))
116
- f = f * (gain ** (f.ndim / 2))
117
- f = f.to(device=device)
118
- return f
119
-
120
- #----------------------------------------------------------------------------
121
-
122
- def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
123
- r"""Pad, upsample, filter, and downsample a batch of 2D images.
124
-
125
- Performs the following sequence of operations for each channel:
126
-
127
- 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
128
-
129
- 2. Pad the image with the specified number of zeros on each side (`padding`).
130
- Negative padding corresponds to cropping the image.
131
-
132
- 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
133
- so that the footprint of all output pixels lies within the input image.
134
-
135
- 4. Downsample the image by keeping every Nth pixel (`down`).
136
-
137
- This sequence of operations bears close resemblance to scipy.signal.upfirdn().
138
- The fused op is considerably more efficient than performing the same calculation
139
- using standard PyTorch ops. It supports gradients of arbitrary order.
140
-
141
- Args:
142
- x: Float32/float64/float16 input tensor of the shape
143
- `[batch_size, num_channels, in_height, in_width]`.
144
- f: Float32 FIR filter of the shape
145
- `[filter_height, filter_width]` (non-separable),
146
- `[filter_taps]` (separable), or
147
- `None` (identity).
148
- up: Integer upsampling factor. Can be a single int or a list/tuple
149
- `[x, y]` (default: 1).
150
- down: Integer downsampling factor. Can be a single int or a list/tuple
151
- `[x, y]` (default: 1).
152
- padding: Padding with respect to the upsampled image. Can be a single number
153
- or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
154
- (default: 0).
155
- flip_filter: False = convolution, True = correlation (default: False).
156
- gain: Overall scaling factor for signal magnitude (default: 1).
157
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
158
-
159
- Returns:
160
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
161
- """
162
- assert isinstance(x, torch.Tensor)
163
- assert impl in ['ref', 'cuda']
164
- if impl == 'cuda' and x.device.type == 'cuda' and _init():
165
- return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
166
- return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
167
-
168
- #----------------------------------------------------------------------------
169
-
170
- @misc.profiled_function
171
- def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
172
- """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
173
- """
174
- # Validate arguments.
175
- assert isinstance(x, torch.Tensor) and x.ndim == 4
176
- if f is None:
177
- f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
178
- assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
179
- assert f.dtype == torch.float32 and not f.requires_grad
180
- batch_size, num_channels, in_height, in_width = x.shape
181
- upx, upy = _parse_scaling(up)
182
- downx, downy = _parse_scaling(down)
183
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
184
-
185
- # Upsample by inserting zeros.
186
- x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
187
- x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
188
- x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
189
-
190
- # Pad or crop.
191
- x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
192
- x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
193
-
194
- # Setup filter.
195
- f = f * (gain ** (f.ndim / 2))
196
- f = f.to(x.dtype)
197
- if not flip_filter:
198
- f = f.flip(list(range(f.ndim)))
199
-
200
- # Convolve with the filter.
201
- f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
202
- if f.ndim == 4:
203
- x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
204
- else:
205
- x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
206
- x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
207
-
208
- # Downsample by throwing away pixels.
209
- x = x[:, :, ::downy, ::downx]
210
- return x
211
-
212
- #----------------------------------------------------------------------------
213
-
214
- _upfirdn2d_cuda_cache = dict()
215
-
216
- def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
217
- """Fast CUDA implementation of `upfirdn2d()` using custom ops.
218
- """
219
- # Parse arguments.
220
- upx, upy = _parse_scaling(up)
221
- downx, downy = _parse_scaling(down)
222
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
223
-
224
- # Lookup from cache.
225
- key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
226
- if key in _upfirdn2d_cuda_cache:
227
- return _upfirdn2d_cuda_cache[key]
228
-
229
- # Forward op.
230
- class Upfirdn2dCuda(torch.autograd.Function):
231
- @staticmethod
232
- def forward(ctx, x, f): # pylint: disable=arguments-differ
233
- assert isinstance(x, torch.Tensor) and x.ndim == 4
234
- if f is None:
235
- f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
236
- assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
237
- y = x
238
- if f.ndim == 2:
239
- y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
240
- else:
241
- y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
242
- y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
243
- ctx.save_for_backward(f)
244
- ctx.x_shape = x.shape
245
- return y
246
-
247
- @staticmethod
248
- def backward(ctx, dy): # pylint: disable=arguments-differ
249
- f, = ctx.saved_tensors
250
- _, _, ih, iw = ctx.x_shape
251
- _, _, oh, ow = dy.shape
252
- fw, fh = _get_filter_size(f)
253
- p = [
254
- fw - padx0 - 1,
255
- iw * upx - ow * downx + padx0 - upx + 1,
256
- fh - pady0 - 1,
257
- ih * upy - oh * downy + pady0 - upy + 1,
258
- ]
259
- dx = None
260
- df = None
261
-
262
- if ctx.needs_input_grad[0]:
263
- dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
264
-
265
- assert not ctx.needs_input_grad[1]
266
- return dx, df
267
-
268
- # Add to cache.
269
- _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
270
- return Upfirdn2dCuda
271
-
272
- #----------------------------------------------------------------------------
273
-
274
- def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
275
- r"""Filter a batch of 2D images using the given 2D FIR filter.
276
-
277
- By default, the result is padded so that its shape matches the input.
278
- User-specified padding is applied on top of that, with negative values
279
- indicating cropping. Pixels outside the image are assumed to be zero.
280
-
281
- Args:
282
- x: Float32/float64/float16 input tensor of the shape
283
- `[batch_size, num_channels, in_height, in_width]`.
284
- f: Float32 FIR filter of the shape
285
- `[filter_height, filter_width]` (non-separable),
286
- `[filter_taps]` (separable), or
287
- `None` (identity).
288
- padding: Padding with respect to the output. Can be a single number or a
289
- list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
290
- (default: 0).
291
- flip_filter: False = convolution, True = correlation (default: False).
292
- gain: Overall scaling factor for signal magnitude (default: 1).
293
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
294
-
295
- Returns:
296
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
297
- """
298
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
299
- fw, fh = _get_filter_size(f)
300
- p = [
301
- padx0 + fw // 2,
302
- padx1 + (fw - 1) // 2,
303
- pady0 + fh // 2,
304
- pady1 + (fh - 1) // 2,
305
- ]
306
- return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
307
-
308
- #----------------------------------------------------------------------------
309
-
310
- def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
311
- r"""Upsample a batch of 2D images using the given 2D FIR filter.
312
-
313
- By default, the result is padded so that its shape is a multiple of the input.
314
- User-specified padding is applied on top of that, with negative values
315
- indicating cropping. Pixels outside the image are assumed to be zero.
316
-
317
- Args:
318
- x: Float32/float64/float16 input tensor of the shape
319
- `[batch_size, num_channels, in_height, in_width]`.
320
- f: Float32 FIR filter of the shape
321
- `[filter_height, filter_width]` (non-separable),
322
- `[filter_taps]` (separable), or
323
- `None` (identity).
324
- up: Integer upsampling factor. Can be a single int or a list/tuple
325
- `[x, y]` (default: 1).
326
- padding: Padding with respect to the output. Can be a single number or a
327
- list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
328
- (default: 0).
329
- flip_filter: False = convolution, True = correlation (default: False).
330
- gain: Overall scaling factor for signal magnitude (default: 1).
331
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
332
-
333
- Returns:
334
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
335
- """
336
- upx, upy = _parse_scaling(up)
337
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
338
- fw, fh = _get_filter_size(f)
339
- p = [
340
- padx0 + (fw + upx - 1) // 2,
341
- padx1 + (fw - upx) // 2,
342
- pady0 + (fh + upy - 1) // 2,
343
- pady1 + (fh - upy) // 2,
344
- ]
345
- return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
346
-
347
- #----------------------------------------------------------------------------
348
-
349
- def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
350
- r"""Downsample a batch of 2D images using the given 2D FIR filter.
351
-
352
- By default, the result is padded so that its shape is a fraction of the input.
353
- User-specified padding is applied on top of that, with negative values
354
- indicating cropping. Pixels outside the image are assumed to be zero.
355
-
356
- Args:
357
- x: Float32/float64/float16 input tensor of the shape
358
- `[batch_size, num_channels, in_height, in_width]`.
359
- f: Float32 FIR filter of the shape
360
- `[filter_height, filter_width]` (non-separable),
361
- `[filter_taps]` (separable), or
362
- `None` (identity).
363
- down: Integer downsampling factor. Can be a single int or a list/tuple
364
- `[x, y]` (default: 1).
365
- padding: Padding with respect to the input. Can be a single number or a
366
- list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
367
- (default: 0).
368
- flip_filter: False = convolution, True = correlation (default: False).
369
- gain: Overall scaling factor for signal magnitude (default: 1).
370
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
371
-
372
- Returns:
373
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
374
- """
375
- downx, downy = _parse_scaling(down)
376
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
377
- fw, fh = _get_filter_size(f)
378
- p = [
379
- padx0 + (fw - downx + 1) // 2,
380
- padx1 + (fw - downx) // 2,
381
- pady0 + (fh - downy + 1) // 2,
382
- pady1 + (fh - downy) // 2,
383
- ]
384
- return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
385
-
386
- #----------------------------------------------------------------------------
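
A short usage sketch for the module above, kept on the pure-PyTorch path (impl='ref') so it does not require building the CUDA extension; it assumes the torch_utils package layout of this commit.

    import torch
    from torch_utils.ops import upfirdn2d  # deleted module shown above (assumed import path)

    x = torch.randn(1, 3, 64, 64)
    f = upfirdn2d.setup_filter([1, 3, 3, 1])                       # low-pass filter, normalized to sum 1

    y_up   = upfirdn2d.upsample2d(x, f, up=2, impl='ref')          # -> [1, 3, 128, 128]
    y_down = upfirdn2d.downsample2d(y_up, f, down=2, impl='ref')   # -> [1, 3, 64, 64]
    print(y_up.shape, y_down.shape)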
 
 
torch_utils/persistence.py DELETED
@@ -1,253 +0,0 @@
1
- # Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # NVIDIA CORPORATION and its licensors retain all intellectual property
6
- # and proprietary rights in and to this software, related documentation
7
- # and any modifications thereto. Any use, reproduction, disclosure or
8
- # distribution of this software and related documentation without an express
9
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
-
11
- """Facilities for pickling Python code alongside other data.
12
-
13
- The pickled code is automatically imported into a separate Python module
14
- during unpickling. This way, any previously exported pickles will remain
15
- usable even if the original code is no longer available, or if the current
16
- version of the code is not consistent with what was originally pickled."""
17
-
18
- import sys
19
- import pickle
20
- import io
21
- import inspect
22
- import copy
23
- import uuid
24
- import types
25
- import dnnlib
26
-
27
- #----------------------------------------------------------------------------
28
-
29
- _version = 6 # internal version number
30
- _decorators = set() # {decorator_class, ...}
31
- _import_hooks = [] # [hook_function, ...]
32
- _module_to_src_dict = dict() # {module: src, ...}
33
- _src_to_module_dict = dict() # {src: module, ...}
34
-
35
- #----------------------------------------------------------------------------
36
-
37
- def persistent_class(orig_class):
38
- r"""Class decorator that extends a given class to save its source code
39
- when pickled.
40
-
41
- Example:
42
-
43
- from torch_utils import persistence
44
-
45
- @persistence.persistent_class
46
- class MyNetwork(torch.nn.Module):
47
- def __init__(self, num_inputs, num_outputs):
48
- super().__init__()
49
- self.fc = MyLayer(num_inputs, num_outputs)
50
- ...
51
-
52
- @persistence.persistent_class
53
- class MyLayer(torch.nn.Module):
54
- ...
55
-
56
- When pickled, any instance of `MyNetwork` and `MyLayer` will save its
57
- source code alongside other internal state (e.g., parameters, buffers,
58
- and submodules). This way, any previously exported pickle will remain
59
- usable even if the class definitions have been modified or are no
60
- longer available.
61
-
62
- The decorator saves the source code of the entire Python module
63
- containing the decorated class. It does *not* save the source code of
64
- any imported modules. Thus, the imported modules must be available
65
- during unpickling, also including `torch_utils.persistence` itself.
66
-
67
- It is ok to call functions defined in the same module from the
68
- decorated class. However, if the decorated class depends on other
69
- classes defined in the same module, they must be decorated as well.
70
- This is illustrated in the above example in the case of `MyLayer`.
71
-
72
- It is also possible to employ the decorator just-in-time before
73
- calling the constructor. For example:
74
-
75
- cls = MyLayer
76
- if want_to_make_it_persistent:
77
- cls = persistence.persistent_class(cls)
78
- layer = cls(num_inputs, num_outputs)
79
-
80
- As an additional feature, the decorator also keeps track of the
81
- arguments that were used to construct each instance of the decorated
82
- class. The arguments can be queried via `obj.init_args` and
83
- `obj.init_kwargs`, and they are automatically pickled alongside other
84
- object state. A typical use case is to first unpickle a previous
85
- instance of a persistent class, and then upgrade it to use the latest
86
- version of the source code:
87
-
88
- with open('old_pickle.pkl', 'rb') as f:
89
- old_net = pickle.load(f)
90
- new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
91
- misc.copy_params_and_buffers(old_net, new_net, require_all=True)
92
- """
93
- assert isinstance(orig_class, type)
94
- if is_persistent(orig_class):
95
- return orig_class
96
-
97
- assert orig_class.__module__ in sys.modules
98
- orig_module = sys.modules[orig_class.__module__]
99
- orig_module_src = _module_to_src(orig_module)
100
-
101
- class Decorator(orig_class):
102
- _orig_module_src = orig_module_src
103
- _orig_class_name = orig_class.__name__
104
-
105
- def __init__(self, *args, **kwargs):
106
- super().__init__(*args, **kwargs)
107
- self._init_args = copy.deepcopy(args)
108
- self._init_kwargs = copy.deepcopy(kwargs)
109
- assert orig_class.__name__ in orig_module.__dict__
110
- _check_pickleable(self.__reduce__())
111
-
112
- @property
113
- def init_args(self):
114
- return copy.deepcopy(self._init_args)
115
-
116
- @property
117
- def init_kwargs(self):
118
- return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
119
-
120
- def __reduce__(self):
121
- fields = list(super().__reduce__())
122
- fields += [None] * max(3 - len(fields), 0)
123
- if fields[0] is not _reconstruct_persistent_obj:
124
- meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
125
- fields[0] = _reconstruct_persistent_obj # reconstruct func
126
- fields[1] = (meta,) # reconstruct args
127
- fields[2] = None # state dict
128
- return tuple(fields)
129
-
130
- Decorator.__name__ = orig_class.__name__
131
- _decorators.add(Decorator)
132
- return Decorator
133
-
134
- #----------------------------------------------------------------------------
135
-
136
- def is_persistent(obj):
137
- r"""Test whether the given object or class is persistent, i.e.,
138
- whether it will save its source code when pickled.
139
- """
140
- try:
141
- if obj in _decorators:
142
- return True
143
- except TypeError:
144
- pass
145
- return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
146
-
147
- #----------------------------------------------------------------------------
148
-
149
- def import_hook(hook):
150
- r"""Register an import hook that is called whenever a persistent object
151
- is being unpickled. A typical use case is to patch the pickled source
152
- code to avoid errors and inconsistencies when the API of some imported
153
- module has changed.
154
-
155
- The hook should have the following signature:
156
-
157
- hook(meta) -> modified meta
158
-
159
- `meta` is an instance of `dnnlib.EasyDict` with the following fields:
160
-
161
- type: Type of the persistent object, e.g. `'class'`.
162
- version: Internal version number of `torch_utils.persistence`.
163
- module_src: Original source code of the Python module.
164
- class_name: Class name in the original Python module.
165
- state: Internal state of the object.
166
-
167
- Example:
168
-
169
- @persistence.import_hook
170
- def wreck_my_network(meta):
171
- if meta.class_name == 'MyNetwork':
172
- print('MyNetwork is being imported. I will wreck it!')
173
- meta.module_src = meta.module_src.replace("True", "False")
174
- return meta
175
- """
176
- assert callable(hook)
177
- _import_hooks.append(hook)
178
-
179
- #----------------------------------------------------------------------------
180
-
181
- def _reconstruct_persistent_obj(meta):
182
- r"""Hook that is called internally by the `pickle` module to unpickle
183
- a persistent object.
184
- """
185
- meta = dnnlib.EasyDict(meta)
186
- meta.state = dnnlib.EasyDict(meta.state)
187
- for hook in _import_hooks:
188
- meta = hook(meta)
189
- assert meta is not None
190
-
191
- assert meta.version == _version
192
- module = _src_to_module(meta.module_src)
193
-
194
- assert meta.type == 'class'
195
- orig_class = module.__dict__[meta.class_name]
196
- decorator_class = persistent_class(orig_class)
197
- obj = decorator_class.__new__(decorator_class)
198
-
199
- setstate = getattr(obj, '__setstate__', None)
200
- if callable(setstate):
201
- setstate(meta.state) # pylint: disable=not-callable
202
- else:
203
- obj.__dict__.update(meta.state)
204
- return obj
205
-
206
- #----------------------------------------------------------------------------
207
-
208
- def _module_to_src(module):
209
- r"""Query the source code of a given Python module.
210
- """
211
- src = _module_to_src_dict.get(module, None)
212
- if src is None:
213
- src = inspect.getsource(module)
214
- _module_to_src_dict[module] = src
215
- _src_to_module_dict[src] = module
216
- return src
217
-
218
- def _src_to_module(src):
219
- r"""Get or create a Python module for the given source code.
220
- """
221
- module = _src_to_module_dict.get(src, None)
222
- if module is None:
223
- module_name = "_imported_module_" + uuid.uuid4().hex
224
- module = types.ModuleType(module_name)
225
- sys.modules[module_name] = module
226
- _module_to_src_dict[module] = src
227
- _src_to_module_dict[src] = module
228
- exec(src, module.__dict__) # pylint: disable=exec-used
229
- return module
230
-
231
- #----------------------------------------------------------------------------
232
-
233
- def _check_pickleable(obj):
234
- r"""Check that the given object is pickleable, raising an exception if
235
- it is not. This function is expected to be considerably more efficient
236
- than actually pickling the object.
237
- """
238
- def recurse(obj):
239
- if isinstance(obj, (list, tuple, set)):
240
- return [recurse(x) for x in obj]
241
- if isinstance(obj, dict):
242
- return [[recurse(x), recurse(y)] for x, y in obj.items()]
243
- if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
244
- return None # Python primitive types are pickleable.
245
- if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
246
- return None # NumPy arrays and PyTorch tensors are pickleable.
247
- if is_persistent(obj):
248
- return None # Persistent objects are pickleable, by virtue of the constructor check.
249
- return obj
250
- with io.BytesIO() as f:
251
- pickle.dump(recurse(obj), f)
252
-
253
- #----------------------------------------------------------------------------
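
For context, a minimal round-trip sketch of the persistence machinery deleted above. `TinyNet` is a hypothetical class used only for illustration; the example assumes `torch_utils.persistence` and its `dnnlib` dependency are importable, and that the class lives in a regular file-based module so `inspect.getsource()` can read its source.

    import pickle
    import torch
    from torch_utils import persistence

    @persistence.persistent_class
    class TinyNet(torch.nn.Module):          # hypothetical, for illustration only
        def __init__(self, width=4):
            super().__init__()
            self.fc = torch.nn.Linear(width, width)

    net = TinyNet(width=4)
    blob = pickle.dumps(net)                 # embeds the defining module's source code
    restored = pickle.loads(blob)            # usable even if TinyNet is later edited
    print(restored.init_kwargs)              # {'width': 4}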
 
torch_utils/training_stats.py DELETED
@@ -1,270 +0,0 @@
1
- # Copyright (c) SenseTime Research. All rights reserved.
2
-
3
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # NVIDIA CORPORATION and its licensors retain all intellectual property
6
- # and proprietary rights in and to this software, related documentation
7
- # and any modifications thereto. Any use, reproduction, disclosure or
8
- # distribution of this software and related documentation without an express
9
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
-
11
- """Facilities for reporting and collecting training statistics across
12
- multiple processes and devices. The interface is designed to minimize
13
- synchronization overhead as well as the amount of boilerplate in user
14
- code."""
15
-
16
- import re
17
- import numpy as np
18
- import torch
19
- import dnnlib
20
-
21
- from . import misc
22
-
23
- #----------------------------------------------------------------------------
24
-
25
- _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
26
- _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
27
- _counter_dtype = torch.float64 # Data type to use for the internal counters.
28
- _rank = 0 # Rank of the current process.
29
- _sync_device = None # Device to use for multiprocess communication. None = single-process.
30
- _sync_called = False # Has _sync() been called yet?
31
- _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
32
- _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
33
-
34
- #----------------------------------------------------------------------------
35
-
36
- def init_multiprocessing(rank, sync_device):
37
- r"""Initializes `torch_utils.training_stats` for collecting statistics
38
- across multiple processes.
39
-
40
- This function must be called after
41
- `torch.distributed.init_process_group()` and before `Collector.update()`.
42
- The call is not necessary if multi-process collection is not needed.
43
-
44
- Args:
45
- rank: Rank of the current process.
46
- sync_device: PyTorch device to use for inter-process
47
- communication, or None to disable multi-process
48
- collection. Typically `torch.device('cuda', rank)`.
49
- """
50
- global _rank, _sync_device
51
- assert not _sync_called
52
- _rank = rank
53
- _sync_device = sync_device
54
-
55
- #----------------------------------------------------------------------------
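
As a hedged illustration of the call order described in the docstring above, a per-process setup sketch for a job launched with `torchrun`; the environment-variable handling and the `'nccl'` backend are assumptions, not part of the deleted file.

    import os
    import torch
    from torch_utils import training_stats

    rank = int(os.environ.get('RANK', '0'))
    world_size = int(os.environ.get('WORLD_SIZE', '1'))
    if world_size > 1:
        torch.distributed.init_process_group(backend='nccl')  # torchrun supplies rank/addr
        training_stats.init_multiprocessing(rank=rank, sync_device=torch.device('cuda', rank))
    # With a single process, no call is needed; collection stays local.
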
56
-
57
- @misc.profiled_function
58
- def report(name, value):
59
- r"""Broadcasts the given set of scalars to all interested instances of
60
- `Collector`, across device and process boundaries.
61
-
62
- This function is expected to be extremely cheap and can be safely
63
- called from anywhere in the training loop, loss function, or inside a
64
- `torch.nn.Module`.
65
-
66
- Warning: The current implementation expects the set of unique names to
67
- be consistent across processes. Please make sure that `report()` is
68
- called at least once for each unique name by each process, and in the
69
- same order. If a given process has no scalars to broadcast, it can do
70
- `report(name, [])` (empty list).
71
-
72
- Args:
73
- name: Arbitrary string specifying the name of the statistic.
74
- Averages are accumulated separately for each unique name.
75
- value: Arbitrary set of scalars. Can be a list, tuple,
76
- NumPy array, PyTorch tensor, or Python scalar.
77
-
78
- Returns:
79
- The same `value` that was passed in.
80
- """
81
- if name not in _counters:
82
- _counters[name] = dict()
83
-
84
- elems = torch.as_tensor(value)
85
- if elems.numel() == 0:
86
- return value
87
-
88
- elems = elems.detach().flatten().to(_reduce_dtype)
89
- moments = torch.stack([
90
- torch.ones_like(elems).sum(),
91
- elems.sum(),
92
- elems.square().sum(),
93
- ])
94
- assert moments.ndim == 1 and moments.shape[0] == _num_moments
95
- moments = moments.to(_counter_dtype)
96
-
97
- device = moments.device
98
- if device not in _counters[name]:
99
- _counters[name][device] = torch.zeros_like(moments)
100
- _counters[name][device].add_(moments)
101
- return value
102
-
103
- #----------------------------------------------------------------------------
104
-
105
- def report0(name, value):
106
- r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
107
- but ignores any scalars provided by the other processes.
108
- See `report()` for further details.
109
- """
110
- report(name, value if _rank == 0 else [])
111
- return value
112
-
113
- #----------------------------------------------------------------------------
114
-
115
- class Collector:
116
- r"""Collects the scalars broadcasted by `report()` and `report0()` and
117
- computes their long-term averages (mean and standard deviation) over
118
- user-defined periods of time.
119
-
120
- The averages are first collected into internal counters that are not
121
- directly visible to the user. They are then copied to the user-visible
122
- state as a result of calling `update()` and can then be queried using
123
- `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
124
- internal counters for the next round, so that the user-visible state
125
- effectively reflects averages collected between the last two calls to
126
- `update()`.
127
-
128
- Args:
129
- regex: Regular expression defining which statistics to
130
- collect. The default is to collect everything.
131
- keep_previous: Whether to retain the previous averages if no
132
- scalars were collected on a given round
133
- (default: True).
134
- """
135
- def __init__(self, regex='.*', keep_previous=True):
136
- self._regex = re.compile(regex)
137
- self._keep_previous = keep_previous
138
- self._cumulative = dict()
139
- self._moments = dict()
140
- self.update()
141
- self._moments.clear()
142
-
143
- def names(self):
144
- r"""Returns the names of all statistics broadcasted so far that
145
- match the regular expression specified at construction time.
146
- """
147
- return [name for name in _counters if self._regex.fullmatch(name)]
148
-
149
- def update(self):
150
- r"""Copies current values of the internal counters to the
151
- user-visible state and resets them for the next round.
152
-
153
- If `keep_previous=True` was specified at construction time, the
154
- operation is skipped for statistics that have received no scalars
155
- since the last update, retaining their previous averages.
156
-
157
- This method performs a number of GPU-to-CPU transfers and one
158
- `torch.distributed.all_reduce()`. It is intended to be called
159
- periodically in the main training loop, typically once every
160
- N training steps.
161
- """
162
- if not self._keep_previous:
163
- self._moments.clear()
164
- for name, cumulative in _sync(self.names()):
165
- if name not in self._cumulative:
166
- self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
167
- delta = cumulative - self._cumulative[name]
168
- self._cumulative[name].copy_(cumulative)
169
- if float(delta[0]) != 0:
170
- self._moments[name] = delta
171
-
172
- def _get_delta(self, name):
173
- r"""Returns the raw moments that were accumulated for the given
174
- statistic between the last two calls to `update()`, or zero if
175
- no scalars were collected.
176
- """
177
- assert self._regex.fullmatch(name)
178
- if name not in self._moments:
179
- self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
180
- return self._moments[name]
181
-
182
- def num(self, name):
183
- r"""Returns the number of scalars that were accumulated for the given
184
- statistic between the last two calls to `update()`, or zero if
185
- no scalars were collected.
186
- """
187
- delta = self._get_delta(name)
188
- return int(delta[0])
189
-
190
- def mean(self, name):
191
- r"""Returns the mean of the scalars that were accumulated for the
192
- given statistic between the last two calls to `update()`, or NaN if
193
- no scalars were collected.
194
- """
195
- delta = self._get_delta(name)
196
- if int(delta[0]) == 0:
197
- return float('nan')
198
- return float(delta[1] / delta[0])
199
-
200
- def std(self, name):
201
- r"""Returns the standard deviation of the scalars that were
202
- accumulated for the given statistic between the last two calls to
203
- `update()`, or NaN if no scalars were collected.
204
- """
205
- delta = self._get_delta(name)
206
- if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
207
- return float('nan')
208
- if int(delta[0]) == 1:
209
- return float(0)
210
- mean = float(delta[1] / delta[0])
211
- raw_var = float(delta[2] / delta[0])
212
- return np.sqrt(max(raw_var - np.square(mean), 0))
213
-
214
- def as_dict(self):
215
- r"""Returns the averages accumulated between the last two calls to
216
- `update()` as a `dnnlib.EasyDict`. The contents are as follows:
217
-
218
- dnnlib.EasyDict(
219
- NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
220
- ...
221
- )
222
- """
223
- stats = dnnlib.EasyDict()
224
- for name in self.names():
225
- stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
226
- return stats
227
-
228
- def __getitem__(self, name):
229
- r"""Convenience getter.
230
- `collector[name]` is a synonym for `collector.mean(name)`.
231
- """
232
- return self.mean(name)
233
-
234
- #----------------------------------------------------------------------------
235
-
236
- def _sync(names):
237
- r"""Synchronize the global cumulative counters across devices and
238
- processes. Called internally by `Collector.update()`.
239
- """
240
- if len(names) == 0:
241
- return []
242
- global _sync_called
243
- _sync_called = True
244
-
245
- # Collect deltas within current rank.
246
- deltas = []
247
- device = _sync_device if _sync_device is not None else torch.device('cpu')
248
- for name in names:
249
- delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
250
- for counter in _counters[name].values():
251
- delta.add_(counter.to(device))
252
- counter.copy_(torch.zeros_like(counter))
253
- deltas.append(delta)
254
- deltas = torch.stack(deltas)
255
-
256
- # Sum deltas across ranks.
257
- if _sync_device is not None:
258
- torch.distributed.all_reduce(deltas)
259
-
260
- # Update cumulative values.
261
- deltas = deltas.cpu()
262
- for idx, name in enumerate(names):
263
- if name not in _cumulative:
264
- _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
265
- _cumulative[name].add_(deltas[idx])
266
-
267
- # Return name-value pairs.
268
- return [(name, _cumulative[name]) for name in names]
269
-
270
- #----------------------------------------------------------------------------
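
Finally, a single-process sketch of the `report()` / `Collector` flow implemented by the deleted module; the statistic name and the fabricated loss values are illustrative only, and the example assumes the `torch_utils` package (including `dnnlib` and `torch_utils.misc`) is importable.

    import torch
    from torch_utils import training_stats

    collector = training_stats.Collector(regex='Loss/.*')
    for step in range(100):
        fake_loss = torch.rand([])                  # stand-in for a real training loss
        training_stats.report('Loss/G/total', fake_loss)

    collector.update()                              # fold accumulated moments into user-visible state
    print(collector.num('Loss/G/total'))            # 100
    print(collector.mean('Loss/G/total'))           # ~0.5 for uniform samples
    print(collector.as_dict())                      # {'Loss/G/total': {'num': ..., 'mean': ..., 'std': ...}}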