- torch_utils/__init__.py +0 -11
- torch_utils/custom_ops.py +0 -238
- torch_utils/misc.py +0 -264
- torch_utils/models.py +0 -756
- torch_utils/models_face.py +0 -809
- torch_utils/op_edit/__init__.py +0 -4
- torch_utils/op_edit/fused_act.py +0 -99
- torch_utils/op_edit/fused_bias_act.cpp +0 -23
- torch_utils/op_edit/fused_bias_act_kernel.cu +0 -101
- torch_utils/op_edit/upfirdn2d.cpp +0 -25
- torch_utils/op_edit/upfirdn2d.py +0 -202
- torch_utils/op_edit/upfirdn2d_kernel.cu +0 -371
- torch_utils/ops/__init__.py +0 -3
- torch_utils/ops/bias_act.cpp +0 -101
- torch_utils/ops/bias_act.cu +0 -175
- torch_utils/ops/bias_act.h +0 -40
- torch_utils/ops/bias_act.py +0 -214
- torch_utils/ops/conv2d_gradfix.py +0 -172
- torch_utils/ops/conv2d_resample.py +0 -158
- torch_utils/ops/filtered_lrelu.cpp +0 -300
- torch_utils/ops/filtered_lrelu.cu +0 -1284
- torch_utils/ops/filtered_lrelu.h +0 -90
- torch_utils/ops/filtered_lrelu.py +0 -282
- torch_utils/ops/filtered_lrelu_ns.cu +0 -27
- torch_utils/ops/filtered_lrelu_rd.cu +0 -27
- torch_utils/ops/filtered_lrelu_wr.cu +0 -27
- torch_utils/ops/fma.py +0 -62
- torch_utils/ops/grid_sample_gradfix.py +0 -85
- torch_utils/ops/upfirdn2d.cpp +0 -105
- torch_utils/ops/upfirdn2d.cu +0 -352
- torch_utils/ops/upfirdn2d.h +0 -61
- torch_utils/ops/upfirdn2d.py +0 -386
- torch_utils/persistence.py +0 -253
- torch_utils/training_stats.py +0 -270
torch_utils/__init__.py
DELETED
@@ -1,11 +0,0 @@
|
|
1 |
-
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
#
|
5 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
-
# and proprietary rights in and to this software, related documentation
|
7 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
-
# distribution of this software and related documentation without an express
|
9 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
-
|
11 |
-
# empty
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/custom_ops.py
DELETED
@@ -1,238 +0,0 @@
|
|
1 |
-
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
#
|
5 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
-
# and proprietary rights in and to this software, related documentation
|
7 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
-
# distribution of this software and related documentation without an express
|
9 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
-
|
11 |
-
import os
|
12 |
-
import glob
|
13 |
-
import torch
|
14 |
-
import torch.utils.cpp_extension
|
15 |
-
import importlib
|
16 |
-
import hashlib
|
17 |
-
import shutil
|
18 |
-
from pathlib import Path
|
19 |
-
import re
|
20 |
-
import uuid
|
21 |
-
|
22 |
-
from torch.utils.file_baton import FileBaton
|
23 |
-
|
24 |
-
#----------------------------------------------------------------------------
|
25 |
-
# Global options.
|
26 |
-
|
27 |
-
verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
|
28 |
-
|
29 |
-
#----------------------------------------------------------------------------
|
30 |
-
# Internal helper funcs.
|
31 |
-
|
32 |
-
def _find_compiler_bindir():
|
33 |
-
patterns = [
|
34 |
-
'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
35 |
-
'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
36 |
-
'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
37 |
-
'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
|
38 |
-
]
|
39 |
-
for pattern in patterns:
|
40 |
-
matches = sorted(glob.glob(pattern))
|
41 |
-
if len(matches):
|
42 |
-
return matches[-1]
|
43 |
-
return None
|
44 |
-
|
45 |
-
def _get_mangled_gpu_name():
|
46 |
-
name = torch.cuda.get_device_name().lower()
|
47 |
-
out = []
|
48 |
-
for c in name:
|
49 |
-
if re.match('[a-z0-9_-]+', c):
|
50 |
-
out.append(c)
|
51 |
-
else:
|
52 |
-
out.append('-')
|
53 |
-
return ''.join(out)
|
54 |
-
|
55 |
-
|
56 |
-
#----------------------------------------------------------------------------
|
57 |
-
# Main entry point for compiling and loading C++/CUDA plugins.
|
58 |
-
|
59 |
-
_cached_plugins = dict()
|
60 |
-
|
61 |
-
def get_plugin(module_name, sources, **build_kwargs):
|
62 |
-
assert verbosity in ['none', 'brief', 'full']
|
63 |
-
|
64 |
-
# Already cached?
|
65 |
-
if module_name in _cached_plugins:
|
66 |
-
return _cached_plugins[module_name]
|
67 |
-
|
68 |
-
# Print status.
|
69 |
-
if verbosity == 'full':
|
70 |
-
print(f'Setting up PyTorch plugin "{module_name}"...')
|
71 |
-
elif verbosity == 'brief':
|
72 |
-
print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
|
73 |
-
|
74 |
-
try: # pylint: disable=too-many-nested-blocks
|
75 |
-
# Make sure we can find the necessary compiler binaries.
|
76 |
-
if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
|
77 |
-
compiler_bindir = _find_compiler_bindir()
|
78 |
-
if compiler_bindir is None:
|
79 |
-
raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
|
80 |
-
os.environ['PATH'] += ';' + compiler_bindir
|
81 |
-
|
82 |
-
# Compile and load.
|
83 |
-
verbose_build = (verbosity == 'full')
|
84 |
-
|
85 |
-
# Incremental build md5sum trickery. Copies all the input source files
|
86 |
-
# into a cached build directory under a combined md5 digest of the input
|
87 |
-
# source files. Copying is done only if the combined digest has changed.
|
88 |
-
# This keeps input file timestamps and filenames the same as in previous
|
89 |
-
# extension builds, allowing for fast incremental rebuilds.
|
90 |
-
#
|
91 |
-
# This optimization is done only in case all the source files reside in
|
92 |
-
# a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
|
93 |
-
# environment variable is set (we take this as a signal that the user
|
94 |
-
# actually cares about this.)
|
95 |
-
source_dirs_set = set(os.path.dirname(source) for source in sources)
|
96 |
-
if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
|
97 |
-
all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
|
98 |
-
|
99 |
-
# Compute a combined hash digest for all source files in the same
|
100 |
-
# custom op directory (usually .cu, .cpp, .py and .h files).
|
101 |
-
hash_md5 = hashlib.md5()
|
102 |
-
for src in all_source_files:
|
103 |
-
with open(src, 'rb') as f:
|
104 |
-
hash_md5.update(f.read())
|
105 |
-
build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
|
106 |
-
digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
|
107 |
-
|
108 |
-
if not os.path.isdir(digest_build_dir):
|
109 |
-
os.makedirs(digest_build_dir, exist_ok=True)
|
110 |
-
baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
|
111 |
-
if baton.try_acquire():
|
112 |
-
try:
|
113 |
-
for src in all_source_files:
|
114 |
-
shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
|
115 |
-
finally:
|
116 |
-
baton.release()
|
117 |
-
else:
|
118 |
-
# Someone else is copying source files under the digest dir,
|
119 |
-
# wait until done and continue.
|
120 |
-
baton.wait()
|
121 |
-
digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
|
122 |
-
torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
|
123 |
-
verbose=verbose_build, sources=digest_sources, **build_kwargs)
|
124 |
-
else:
|
125 |
-
torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
|
126 |
-
module = importlib.import_module(module_name)
|
127 |
-
|
128 |
-
except:
|
129 |
-
if verbosity == 'brief':
|
130 |
-
print('Failed!')
|
131 |
-
raise
|
132 |
-
|
133 |
-
# Print status and add to cache.
|
134 |
-
if verbosity == 'full':
|
135 |
-
print(f'Done setting up PyTorch plugin "{module_name}".')
|
136 |
-
elif verbosity == 'brief':
|
137 |
-
print('Done.')
|
138 |
-
_cached_plugins[module_name] = module
|
139 |
-
return module
|
140 |
-
|
141 |
-
#----------------------------------------------------------------------------
|
142 |
-
def get_plugin_v3(module_name, sources, headers=None, source_dir=None, **build_kwargs):
|
143 |
-
assert verbosity in ['none', 'brief', 'full']
|
144 |
-
if headers is None:
|
145 |
-
headers = []
|
146 |
-
if source_dir is not None:
|
147 |
-
sources = [os.path.join(source_dir, fname) for fname in sources]
|
148 |
-
headers = [os.path.join(source_dir, fname) for fname in headers]
|
149 |
-
|
150 |
-
# Already cached?
|
151 |
-
if module_name in _cached_plugins:
|
152 |
-
return _cached_plugins[module_name]
|
153 |
-
|
154 |
-
# Print status.
|
155 |
-
if verbosity == 'full':
|
156 |
-
print(f'Setting up PyTorch plugin "{module_name}"...')
|
157 |
-
elif verbosity == 'brief':
|
158 |
-
print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
|
159 |
-
verbose_build = (verbosity == 'full')
|
160 |
-
|
161 |
-
# Compile and load.
|
162 |
-
try: # pylint: disable=too-many-nested-blocks
|
163 |
-
# Make sure we can find the necessary compiler binaries.
|
164 |
-
if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
|
165 |
-
compiler_bindir = _find_compiler_bindir()
|
166 |
-
if compiler_bindir is None:
|
167 |
-
raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
|
168 |
-
os.environ['PATH'] += ';' + compiler_bindir
|
169 |
-
|
170 |
-
# Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
|
171 |
-
# break the build or unnecessarily restrict what's available to nvcc.
|
172 |
-
# Unset it to let nvcc decide based on what's available on the
|
173 |
-
# machine.
|
174 |
-
os.environ['TORCH_CUDA_ARCH_LIST'] = ''
|
175 |
-
|
176 |
-
# Incremental build md5sum trickery. Copies all the input source files
|
177 |
-
# into a cached build directory under a combined md5 digest of the input
|
178 |
-
# source files. Copying is done only if the combined digest has changed.
|
179 |
-
# This keeps input file timestamps and filenames the same as in previous
|
180 |
-
# extension builds, allowing for fast incremental rebuilds.
|
181 |
-
#
|
182 |
-
# This optimization is done only in case all the source files reside in
|
183 |
-
# a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
|
184 |
-
# environment variable is set (we take this as a signal that the user
|
185 |
-
# actually cares about this.)
|
186 |
-
#
|
187 |
-
# EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work
|
188 |
-
# around the *.cu dependency bug in ninja config.
|
189 |
-
#
|
190 |
-
all_source_files = sorted(sources + headers)
|
191 |
-
all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
|
192 |
-
if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
|
193 |
-
|
194 |
-
# Compute combined hash digest for all source files.
|
195 |
-
hash_md5 = hashlib.md5()
|
196 |
-
for src in all_source_files:
|
197 |
-
with open(src, 'rb') as f:
|
198 |
-
hash_md5.update(f.read())
|
199 |
-
|
200 |
-
# Select cached build directory name.
|
201 |
-
source_digest = hash_md5.hexdigest()
|
202 |
-
build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
|
203 |
-
cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
|
204 |
-
|
205 |
-
if not os.path.isdir(cached_build_dir):
|
206 |
-
tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
|
207 |
-
os.makedirs(tmpdir)
|
208 |
-
for src in all_source_files:
|
209 |
-
shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
|
210 |
-
try:
|
211 |
-
os.replace(tmpdir, cached_build_dir) # atomic
|
212 |
-
except OSError:
|
213 |
-
# source directory already exists, delete tmpdir and its contents.
|
214 |
-
shutil.rmtree(tmpdir)
|
215 |
-
if not os.path.isdir(cached_build_dir): raise
|
216 |
-
|
217 |
-
# Compile.
|
218 |
-
cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
|
219 |
-
torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
|
220 |
-
verbose=verbose_build, sources=cached_sources, **build_kwargs)
|
221 |
-
else:
|
222 |
-
torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
|
223 |
-
|
224 |
-
# Load.
|
225 |
-
module = importlib.import_module(module_name)
|
226 |
-
|
227 |
-
except:
|
228 |
-
if verbosity == 'brief':
|
229 |
-
print('Failed!')
|
230 |
-
raise
|
231 |
-
|
232 |
-
# Print status and add to cache dict.
|
233 |
-
if verbosity == 'full':
|
234 |
-
print(f'Done setting up PyTorch plugin "{module_name}".')
|
235 |
-
elif verbosity == 'brief':
|
236 |
-
print('Done.')
|
237 |
-
_cached_plugins[module_name] = module
|
238 |
-
return module
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/misc.py
DELETED
@@ -1,264 +0,0 @@
|
|
1 |
-
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
#
|
5 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
-
# and proprietary rights in and to this software, related documentation
|
7 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
-
# distribution of this software and related documentation without an express
|
9 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
-
|
11 |
-
import re
|
12 |
-
import contextlib
|
13 |
-
import numpy as np
|
14 |
-
import torch
|
15 |
-
import warnings
|
16 |
-
import dnnlib
|
17 |
-
|
18 |
-
#----------------------------------------------------------------------------
|
19 |
-
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
|
20 |
-
# same constant is used multiple times.
|
21 |
-
|
22 |
-
_constant_cache = dict()
|
23 |
-
|
24 |
-
def constant(value, shape=None, dtype=None, device=None, memory_format=None):
|
25 |
-
value = np.asarray(value)
|
26 |
-
if shape is not None:
|
27 |
-
shape = tuple(shape)
|
28 |
-
if dtype is None:
|
29 |
-
dtype = torch.get_default_dtype()
|
30 |
-
if device is None:
|
31 |
-
device = torch.device('cpu')
|
32 |
-
if memory_format is None:
|
33 |
-
memory_format = torch.contiguous_format
|
34 |
-
|
35 |
-
key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
|
36 |
-
tensor = _constant_cache.get(key, None)
|
37 |
-
if tensor is None:
|
38 |
-
tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
|
39 |
-
if shape is not None:
|
40 |
-
tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
|
41 |
-
tensor = tensor.contiguous(memory_format=memory_format)
|
42 |
-
_constant_cache[key] = tensor
|
43 |
-
return tensor
|
44 |
-
|
45 |
-
#----------------------------------------------------------------------------
|
46 |
-
# Replace NaN/Inf with specified numerical values.
|
47 |
-
|
48 |
-
try:
|
49 |
-
nan_to_num = torch.nan_to_num # 1.8.0a0
|
50 |
-
except AttributeError:
|
51 |
-
def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
|
52 |
-
assert isinstance(input, torch.Tensor)
|
53 |
-
if posinf is None:
|
54 |
-
posinf = torch.finfo(input.dtype).max
|
55 |
-
if neginf is None:
|
56 |
-
neginf = torch.finfo(input.dtype).min
|
57 |
-
assert nan == 0
|
58 |
-
return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
|
59 |
-
|
60 |
-
#----------------------------------------------------------------------------
|
61 |
-
# Symbolic assert.
|
62 |
-
|
63 |
-
try:
|
64 |
-
symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
|
65 |
-
except AttributeError:
|
66 |
-
symbolic_assert = torch.Assert # 1.7.0
|
67 |
-
|
68 |
-
#----------------------------------------------------------------------------
|
69 |
-
# Context manager to suppress known warnings in torch.jit.trace().
|
70 |
-
|
71 |
-
class suppress_tracer_warnings(warnings.catch_warnings):
|
72 |
-
def __enter__(self):
|
73 |
-
super().__enter__()
|
74 |
-
warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
|
75 |
-
return self
|
76 |
-
|
77 |
-
#----------------------------------------------------------------------------
|
78 |
-
# Assert that the shape of a tensor matches the given list of integers.
|
79 |
-
# None indicates that the size of a dimension is allowed to vary.
|
80 |
-
# Performs symbolic assertion when used in torch.jit.trace().
|
81 |
-
|
82 |
-
def assert_shape(tensor, ref_shape):
|
83 |
-
if tensor.ndim != len(ref_shape):
|
84 |
-
raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
|
85 |
-
for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
|
86 |
-
if ref_size is None:
|
87 |
-
pass
|
88 |
-
elif isinstance(ref_size, torch.Tensor):
|
89 |
-
with suppress_tracer_warnings(): # as_tensor results are registered as constants
|
90 |
-
symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
|
91 |
-
elif isinstance(size, torch.Tensor):
|
92 |
-
with suppress_tracer_warnings(): # as_tensor results are registered as constants
|
93 |
-
symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
|
94 |
-
elif size != ref_size:
|
95 |
-
raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
|
96 |
-
|
97 |
-
#----------------------------------------------------------------------------
|
98 |
-
# Function decorator that calls torch.autograd.profiler.record_function().
|
99 |
-
|
100 |
-
def profiled_function(fn):
|
101 |
-
def decorator(*args, **kwargs):
|
102 |
-
with torch.autograd.profiler.record_function(fn.__name__):
|
103 |
-
return fn(*args, **kwargs)
|
104 |
-
decorator.__name__ = fn.__name__
|
105 |
-
return decorator
|
106 |
-
|
107 |
-
#----------------------------------------------------------------------------
|
108 |
-
# Sampler for torch.utils.data.DataLoader that loops over the dataset
|
109 |
-
# indefinitely, shuffling items as it goes.
|
110 |
-
|
111 |
-
class InfiniteSampler(torch.utils.data.Sampler):
|
112 |
-
def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
|
113 |
-
assert len(dataset) > 0
|
114 |
-
assert num_replicas > 0
|
115 |
-
assert 0 <= rank < num_replicas
|
116 |
-
assert 0 <= window_size <= 1
|
117 |
-
super().__init__(dataset)
|
118 |
-
self.dataset = dataset
|
119 |
-
self.rank = rank
|
120 |
-
self.num_replicas = num_replicas
|
121 |
-
self.shuffle = shuffle
|
122 |
-
self.seed = seed
|
123 |
-
self.window_size = window_size
|
124 |
-
|
125 |
-
def __iter__(self):
|
126 |
-
order = np.arange(len(self.dataset))
|
127 |
-
rnd = None
|
128 |
-
window = 0
|
129 |
-
if self.shuffle:
|
130 |
-
rnd = np.random.RandomState(self.seed)
|
131 |
-
rnd.shuffle(order)
|
132 |
-
window = int(np.rint(order.size * self.window_size))
|
133 |
-
|
134 |
-
idx = 0
|
135 |
-
while True:
|
136 |
-
i = idx % order.size
|
137 |
-
if idx % self.num_replicas == self.rank:
|
138 |
-
yield order[i]
|
139 |
-
if window >= 2:
|
140 |
-
j = (i - rnd.randint(window)) % order.size
|
141 |
-
order[i], order[j] = order[j], order[i]
|
142 |
-
idx += 1
|
143 |
-
|
144 |
-
#----------------------------------------------------------------------------
|
145 |
-
# Utilities for operating with torch.nn.Module parameters and buffers.
|
146 |
-
|
147 |
-
def params_and_buffers(module):
|
148 |
-
assert isinstance(module, torch.nn.Module)
|
149 |
-
return list(module.parameters()) + list(module.buffers())
|
150 |
-
|
151 |
-
def named_params_and_buffers(module):
|
152 |
-
assert isinstance(module, torch.nn.Module)
|
153 |
-
return list(module.named_parameters()) + list(module.named_buffers())
|
154 |
-
|
155 |
-
def copy_params_and_buffers(src_module, dst_module, require_all=False):
|
156 |
-
assert isinstance(src_module, torch.nn.Module)
|
157 |
-
assert isinstance(dst_module, torch.nn.Module)
|
158 |
-
src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
|
159 |
-
for name, tensor in named_params_and_buffers(dst_module):
|
160 |
-
assert (name in src_tensors) or (not require_all)
|
161 |
-
if name in src_tensors:
|
162 |
-
tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
|
163 |
-
|
164 |
-
#----------------------------------------------------------------------------
|
165 |
-
# Context manager for easily enabling/disabling DistributedDataParallel
|
166 |
-
# synchronization.
|
167 |
-
|
168 |
-
@contextlib.contextmanager
|
169 |
-
def ddp_sync(module, sync):
|
170 |
-
assert isinstance(module, torch.nn.Module)
|
171 |
-
if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
|
172 |
-
yield
|
173 |
-
else:
|
174 |
-
with module.no_sync():
|
175 |
-
yield
|
176 |
-
|
177 |
-
#----------------------------------------------------------------------------
|
178 |
-
# Check DistributedDataParallel consistency across processes.
|
179 |
-
|
180 |
-
def check_ddp_consistency(module, ignore_regex=None):
|
181 |
-
assert isinstance(module, torch.nn.Module)
|
182 |
-
for name, tensor in named_params_and_buffers(module):
|
183 |
-
fullname = type(module).__name__ + '.' + name
|
184 |
-
if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
|
185 |
-
continue
|
186 |
-
tensor = tensor.detach()
|
187 |
-
other = tensor.clone()
|
188 |
-
torch.distributed.broadcast(tensor=other, src=0)
|
189 |
-
assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
|
190 |
-
|
191 |
-
#----------------------------------------------------------------------------
|
192 |
-
# Print summary table of module hierarchy.
|
193 |
-
|
194 |
-
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
|
195 |
-
assert isinstance(module, torch.nn.Module)
|
196 |
-
assert not isinstance(module, torch.jit.ScriptModule)
|
197 |
-
assert isinstance(inputs, (tuple, list))
|
198 |
-
|
199 |
-
# Register hooks.
|
200 |
-
entries = []
|
201 |
-
nesting = [0]
|
202 |
-
def pre_hook(_mod, _inputs):
|
203 |
-
nesting[0] += 1
|
204 |
-
def post_hook(mod, _inputs, outputs):
|
205 |
-
nesting[0] -= 1
|
206 |
-
if nesting[0] <= max_nesting:
|
207 |
-
outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
|
208 |
-
outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
|
209 |
-
entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
|
210 |
-
hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
|
211 |
-
hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
|
212 |
-
|
213 |
-
# Run module.
|
214 |
-
outputs = module(*inputs)
|
215 |
-
for hook in hooks:
|
216 |
-
hook.remove()
|
217 |
-
|
218 |
-
# Identify unique outputs, parameters, and buffers.
|
219 |
-
tensors_seen = set()
|
220 |
-
for e in entries:
|
221 |
-
e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
|
222 |
-
e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
|
223 |
-
e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
|
224 |
-
tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
|
225 |
-
|
226 |
-
# Filter out redundant entries.
|
227 |
-
if skip_redundant:
|
228 |
-
entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
|
229 |
-
|
230 |
-
# Construct table.
|
231 |
-
rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
|
232 |
-
rows += [['---'] * len(rows[0])]
|
233 |
-
param_total = 0
|
234 |
-
buffer_total = 0
|
235 |
-
submodule_names = {mod: name for name, mod in module.named_modules()}
|
236 |
-
for e in entries:
|
237 |
-
name = '<top-level>' if e.mod is module else submodule_names[e.mod]
|
238 |
-
param_size = sum(t.numel() for t in e.unique_params)
|
239 |
-
buffer_size = sum(t.numel() for t in e.unique_buffers)
|
240 |
-
output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs]
|
241 |
-
output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
|
242 |
-
rows += [[
|
243 |
-
name + (':0' if len(e.outputs) >= 2 else ''),
|
244 |
-
str(param_size) if param_size else '-',
|
245 |
-
str(buffer_size) if buffer_size else '-',
|
246 |
-
(output_shapes + ['-'])[0],
|
247 |
-
(output_dtypes + ['-'])[0],
|
248 |
-
]]
|
249 |
-
for idx in range(1, len(e.outputs)):
|
250 |
-
rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
|
251 |
-
param_total += param_size
|
252 |
-
buffer_total += buffer_size
|
253 |
-
rows += [['---'] * len(rows[0])]
|
254 |
-
rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
|
255 |
-
|
256 |
-
# Print table.
|
257 |
-
widths = [max(len(cell) for cell in column) for column in zip(*rows)]
|
258 |
-
print()
|
259 |
-
for row in rows:
|
260 |
-
print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
|
261 |
-
print()
|
262 |
-
return outputs
|
263 |
-
|
264 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/models.py
DELETED
@@ -1,756 +0,0 @@
|
|
1 |
-
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
# https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
|
4 |
-
|
5 |
-
import math
|
6 |
-
import random
|
7 |
-
import functools
|
8 |
-
import operator
|
9 |
-
|
10 |
-
import torch
|
11 |
-
from torch import nn
|
12 |
-
from torch.nn import functional as F
|
13 |
-
import torch.nn.init as init
|
14 |
-
from torch.autograd import Function
|
15 |
-
|
16 |
-
from .op_edit import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
|
17 |
-
|
18 |
-
|
19 |
-
class PixelNorm(nn.Module):
|
20 |
-
def __init__(self):
|
21 |
-
super().__init__()
|
22 |
-
|
23 |
-
def forward(self, input):
|
24 |
-
return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
|
25 |
-
|
26 |
-
|
27 |
-
def make_kernel(k):
|
28 |
-
k = torch.tensor(k, dtype=torch.float32)
|
29 |
-
if k.ndim == 1:
|
30 |
-
k = k[None, :] * k[:, None]
|
31 |
-
k /= k.sum()
|
32 |
-
return k
|
33 |
-
|
34 |
-
|
35 |
-
class Upsample(nn.Module):
|
36 |
-
def __init__(self, kernel, factor=2):
|
37 |
-
super().__init__()
|
38 |
-
|
39 |
-
self.factor = factor
|
40 |
-
kernel = make_kernel(kernel) * (factor ** 2)
|
41 |
-
self.register_buffer("kernel", kernel)
|
42 |
-
|
43 |
-
p = kernel.shape[0] - factor
|
44 |
-
|
45 |
-
pad0 = (p + 1) // 2 + factor - 1
|
46 |
-
pad1 = p // 2
|
47 |
-
|
48 |
-
self.pad = (pad0, pad1)
|
49 |
-
|
50 |
-
def forward(self, input):
|
51 |
-
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
|
52 |
-
return out
|
53 |
-
|
54 |
-
|
55 |
-
class Downsample(nn.Module):
|
56 |
-
def __init__(self, kernel, factor=2):
|
57 |
-
super().__init__()
|
58 |
-
|
59 |
-
self.factor = factor
|
60 |
-
kernel = make_kernel(kernel)
|
61 |
-
self.register_buffer("kernel", kernel)
|
62 |
-
|
63 |
-
p = kernel.shape[0] - factor
|
64 |
-
|
65 |
-
pad0 = (p + 1) // 2
|
66 |
-
pad1 = p // 2
|
67 |
-
|
68 |
-
self.pad = (pad0, pad1)
|
69 |
-
|
70 |
-
def forward(self, input):
|
71 |
-
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
|
72 |
-
return out
|
73 |
-
|
74 |
-
|
75 |
-
class Blur(nn.Module):
|
76 |
-
def __init__(self, kernel, pad, upsample_factor=1):
|
77 |
-
super().__init__()
|
78 |
-
|
79 |
-
kernel = make_kernel(kernel)
|
80 |
-
|
81 |
-
if upsample_factor > 1:
|
82 |
-
kernel = kernel * (upsample_factor ** 2)
|
83 |
-
|
84 |
-
self.register_buffer("kernel", kernel)
|
85 |
-
|
86 |
-
self.pad = pad
|
87 |
-
|
88 |
-
def forward(self, input):
|
89 |
-
out = upfirdn2d(input, self.kernel, pad=self.pad)
|
90 |
-
return out
|
91 |
-
|
92 |
-
|
93 |
-
class EqualConv2d(nn.Module):
|
94 |
-
def __init__(
|
95 |
-
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
|
96 |
-
):
|
97 |
-
super().__init__()
|
98 |
-
|
99 |
-
self.weight = nn.Parameter(
|
100 |
-
torch.randn(out_channel, in_channel, kernel_size, kernel_size)
|
101 |
-
)
|
102 |
-
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
|
103 |
-
|
104 |
-
self.stride = stride
|
105 |
-
self.padding = padding
|
106 |
-
|
107 |
-
if bias:
|
108 |
-
self.bias = nn.Parameter(torch.zeros(out_channel))
|
109 |
-
|
110 |
-
else:
|
111 |
-
self.bias = None
|
112 |
-
|
113 |
-
def forward(self, input):
|
114 |
-
out = F.conv2d(
|
115 |
-
input,
|
116 |
-
self.weight * self.scale,
|
117 |
-
bias=self.bias,
|
118 |
-
stride=self.stride,
|
119 |
-
padding=self.padding,
|
120 |
-
)
|
121 |
-
return out
|
122 |
-
|
123 |
-
def __repr__(self):
|
124 |
-
return (
|
125 |
-
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
|
126 |
-
f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
|
127 |
-
)
|
128 |
-
|
129 |
-
|
130 |
-
class EqualLinear(nn.Module):
|
131 |
-
def __init__(
|
132 |
-
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
|
133 |
-
):
|
134 |
-
super().__init__()
|
135 |
-
|
136 |
-
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
137 |
-
|
138 |
-
if bias:
|
139 |
-
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
140 |
-
else:
|
141 |
-
self.bias = None
|
142 |
-
|
143 |
-
self.activation = activation
|
144 |
-
|
145 |
-
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
146 |
-
self.lr_mul = lr_mul
|
147 |
-
|
148 |
-
def forward(self, input):
|
149 |
-
if self.activation:
|
150 |
-
out = F.linear(input, self.weight * self.scale)
|
151 |
-
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
152 |
-
else:
|
153 |
-
out = F.linear(
|
154 |
-
input, self.weight * self.scale, bias=self.bias * self.lr_mul
|
155 |
-
)
|
156 |
-
return out
|
157 |
-
|
158 |
-
def __repr__(self):
|
159 |
-
return (
|
160 |
-
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
|
161 |
-
)
|
162 |
-
|
163 |
-
|
164 |
-
class ScaledLeakyReLU(nn.Module):
|
165 |
-
def __init__(self, negative_slope=0.2):
|
166 |
-
super().__init__()
|
167 |
-
self.negative_slope = negative_slope
|
168 |
-
|
169 |
-
def forward(self, input):
|
170 |
-
out = F.leaky_relu(input, negative_slope=self.negative_slope)
|
171 |
-
return out * math.sqrt(2)
|
172 |
-
|
173 |
-
|
174 |
-
class ModulatedConv2d(nn.Module):
|
175 |
-
def __init__(
|
176 |
-
self,
|
177 |
-
in_channel,
|
178 |
-
out_channel,
|
179 |
-
kernel_size,
|
180 |
-
style_dim,
|
181 |
-
demodulate=True,
|
182 |
-
upsample=False,
|
183 |
-
downsample=False,
|
184 |
-
blur_kernel=[1, 3, 3, 1],
|
185 |
-
):
|
186 |
-
super().__init__()
|
187 |
-
|
188 |
-
self.eps = 1e-8
|
189 |
-
self.kernel_size = kernel_size
|
190 |
-
self.in_channel = in_channel
|
191 |
-
self.out_channel = out_channel
|
192 |
-
self.upsample = upsample
|
193 |
-
self.downsample = downsample
|
194 |
-
|
195 |
-
if upsample:
|
196 |
-
factor = 2
|
197 |
-
p = (len(blur_kernel) - factor) - (kernel_size - 1)
|
198 |
-
pad0 = (p + 1) // 2 + factor - 1
|
199 |
-
pad1 = p // 2 + 1
|
200 |
-
self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
|
201 |
-
|
202 |
-
if downsample:
|
203 |
-
factor = 2
|
204 |
-
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
205 |
-
pad0 = (p + 1) // 2
|
206 |
-
pad1 = p // 2
|
207 |
-
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
|
208 |
-
|
209 |
-
fan_in = in_channel * kernel_size ** 2
|
210 |
-
self.scale = 1 / math.sqrt(fan_in)
|
211 |
-
self.padding = kernel_size // 2
|
212 |
-
self.weight = nn.Parameter(
|
213 |
-
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
|
214 |
-
)
|
215 |
-
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
|
216 |
-
self.demodulate = demodulate
|
217 |
-
|
218 |
-
def __repr__(self):
|
219 |
-
return (
|
220 |
-
f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
|
221 |
-
f"upsample={self.upsample}, downsample={self.downsample})"
|
222 |
-
)
|
223 |
-
|
224 |
-
def forward(self, input, style):
|
225 |
-
batch, in_channel, height, width = input.shape
|
226 |
-
|
227 |
-
style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
|
228 |
-
weight = self.scale * self.weight * style
|
229 |
-
|
230 |
-
if self.demodulate:
|
231 |
-
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
|
232 |
-
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
|
233 |
-
|
234 |
-
weight = weight.view(
|
235 |
-
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
236 |
-
)
|
237 |
-
|
238 |
-
if self.upsample:
|
239 |
-
input = input.view(1, batch * in_channel, height, width)
|
240 |
-
weight = weight.view(
|
241 |
-
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
242 |
-
)
|
243 |
-
weight = weight.transpose(1, 2).reshape(
|
244 |
-
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
|
245 |
-
)
|
246 |
-
out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
|
247 |
-
_, _, height, width = out.shape
|
248 |
-
out = out.view(batch, self.out_channel, height, width)
|
249 |
-
out = self.blur(out)
|
250 |
-
|
251 |
-
elif self.downsample:
|
252 |
-
input = self.blur(input)
|
253 |
-
_, _, height, width = input.shape
|
254 |
-
input = input.view(1, batch * in_channel, height, width)
|
255 |
-
out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
|
256 |
-
_, _, height, width = out.shape
|
257 |
-
out = out.view(batch, self.out_channel, height, width)
|
258 |
-
|
259 |
-
else:
|
260 |
-
input = input.view(1, batch * in_channel, height, width)
|
261 |
-
out = F.conv2d(input, weight, padding=self.padding, groups=batch)
|
262 |
-
_, _, height, width = out.shape
|
263 |
-
out = out.view(batch, self.out_channel, height, width)
|
264 |
-
|
265 |
-
return out
|
266 |
-
|
267 |
-
|
268 |
-
class NoiseInjection(nn.Module):
|
269 |
-
def __init__(self):
|
270 |
-
super().__init__()
|
271 |
-
self.weight = nn.Parameter(torch.zeros(1))
|
272 |
-
|
273 |
-
def forward(self, image, noise=None):
|
274 |
-
if noise is None:
|
275 |
-
batch, _, height, width = image.shape
|
276 |
-
noise = image.new_empty(batch, 1, height, width).normal_()
|
277 |
-
return image + self.weight * noise
|
278 |
-
|
279 |
-
|
280 |
-
class ConstantInput(nn.Module):
|
281 |
-
def __init__(self, channel, size=4):
|
282 |
-
super().__init__()
|
283 |
-
self.input = nn.Parameter(torch.randn(1, channel, size, size // 2))
|
284 |
-
|
285 |
-
def forward(self, input):
|
286 |
-
batch = input.shape[0]
|
287 |
-
out = self.input.repeat(batch, 1, 1, 1)
|
288 |
-
return out
|
289 |
-
|
290 |
-
|
291 |
-
class StyledConv(nn.Module):
|
292 |
-
def __init__(
|
293 |
-
self,
|
294 |
-
in_channel,
|
295 |
-
out_channel,
|
296 |
-
kernel_size,
|
297 |
-
style_dim,
|
298 |
-
upsample=False,
|
299 |
-
blur_kernel=[1, 3, 3, 1],
|
300 |
-
demodulate=True,
|
301 |
-
):
|
302 |
-
super().__init__()
|
303 |
-
self.conv = ModulatedConv2d(
|
304 |
-
in_channel,
|
305 |
-
out_channel,
|
306 |
-
kernel_size,
|
307 |
-
style_dim,
|
308 |
-
upsample=upsample,
|
309 |
-
blur_kernel=blur_kernel,
|
310 |
-
demodulate=demodulate,
|
311 |
-
)
|
312 |
-
self.noise = NoiseInjection()
|
313 |
-
self.activate = FusedLeakyReLU(out_channel)
|
314 |
-
|
315 |
-
def forward(self, input, style, noise=None):
|
316 |
-
out = self.conv(input, style)
|
317 |
-
out = self.noise(out, noise=noise)
|
318 |
-
out = self.activate(out)
|
319 |
-
return out
|
320 |
-
|
321 |
-
|
322 |
-
class ToRGB(nn.Module):
|
323 |
-
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
|
324 |
-
super().__init__()
|
325 |
-
if upsample:
|
326 |
-
self.upsample = Upsample(blur_kernel)
|
327 |
-
|
328 |
-
self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
|
329 |
-
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
330 |
-
|
331 |
-
def forward(self, input, style, skip=None):
|
332 |
-
out = self.conv(input, style)
|
333 |
-
out = out + self.bias
|
334 |
-
|
335 |
-
if skip is not None:
|
336 |
-
skip = self.upsample(skip)
|
337 |
-
out = out + skip
|
338 |
-
|
339 |
-
return out
|
340 |
-
|
341 |
-
|
342 |
-
class Generator(nn.Module):
|
343 |
-
def __init__(
|
344 |
-
self,
|
345 |
-
size,
|
346 |
-
style_dim,
|
347 |
-
n_mlp,
|
348 |
-
channel_multiplier=1,
|
349 |
-
blur_kernel=[1, 3, 3, 1],
|
350 |
-
lr_mlp=0.01,
|
351 |
-
small=False,
|
352 |
-
small_isaac=False,
|
353 |
-
):
|
354 |
-
super().__init__()
|
355 |
-
|
356 |
-
self.size = size
|
357 |
-
|
358 |
-
if small and size > 64:
|
359 |
-
raise ValueError("small only works for sizes <= 64")
|
360 |
-
|
361 |
-
self.style_dim = style_dim
|
362 |
-
layers = [PixelNorm()]
|
363 |
-
|
364 |
-
for i in range(n_mlp):
|
365 |
-
layers.append(
|
366 |
-
EqualLinear(
|
367 |
-
style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
|
368 |
-
)
|
369 |
-
)
|
370 |
-
|
371 |
-
self.style = nn.Sequential(*layers)
|
372 |
-
|
373 |
-
if small:
|
374 |
-
self.channels = {
|
375 |
-
4: 64 * channel_multiplier,
|
376 |
-
8: 64 * channel_multiplier,
|
377 |
-
16: 64 * channel_multiplier,
|
378 |
-
32: 64 * channel_multiplier,
|
379 |
-
64: 64 * channel_multiplier,
|
380 |
-
}
|
381 |
-
elif small_isaac:
|
382 |
-
self.channels = {4: 256, 8: 256, 16: 256, 32: 256, 64: 128, 128: 128}
|
383 |
-
else:
|
384 |
-
self.channels = {
|
385 |
-
4: 512,
|
386 |
-
8: 512,
|
387 |
-
16: 512,
|
388 |
-
32: 512,
|
389 |
-
64: 256 * channel_multiplier,
|
390 |
-
128: 128 * channel_multiplier,
|
391 |
-
256: 64 * channel_multiplier,
|
392 |
-
512: 32 * channel_multiplier,
|
393 |
-
1024: 16 * channel_multiplier,
|
394 |
-
}
|
395 |
-
|
396 |
-
self.input = ConstantInput(self.channels[4])
|
397 |
-
self.conv1 = StyledConv(
|
398 |
-
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
|
399 |
-
)
|
400 |
-
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
|
401 |
-
|
402 |
-
self.log_size = int(math.log(size, 2))
|
403 |
-
self.num_layers = (self.log_size - 2) * 2 + 1
|
404 |
-
|
405 |
-
self.convs = nn.ModuleList()
|
406 |
-
self.upsamples = nn.ModuleList()
|
407 |
-
self.to_rgbs = nn.ModuleList()
|
408 |
-
self.noises = nn.Module()
|
409 |
-
|
410 |
-
in_channel = self.channels[4]
|
411 |
-
|
412 |
-
for layer_idx in range(self.num_layers):
|
413 |
-
res = (layer_idx + 5) // 2
|
414 |
-
shape = [1, 1, 2 ** res, 2 ** res // 2]
|
415 |
-
self.noises.register_buffer(
|
416 |
-
"noise_{}".format(layer_idx), torch.randn(*shape)
|
417 |
-
)
|
418 |
-
|
419 |
-
for i in range(3, self.log_size + 1):
|
420 |
-
out_channel = self.channels[2 ** i]
|
421 |
-
|
422 |
-
self.convs.append(
|
423 |
-
StyledConv(
|
424 |
-
in_channel,
|
425 |
-
out_channel,
|
426 |
-
3,
|
427 |
-
style_dim,
|
428 |
-
upsample=True,
|
429 |
-
blur_kernel=blur_kernel,
|
430 |
-
)
|
431 |
-
)
|
432 |
-
|
433 |
-
self.convs.append(
|
434 |
-
StyledConv(
|
435 |
-
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
|
436 |
-
)
|
437 |
-
)
|
438 |
-
|
439 |
-
self.to_rgbs.append(ToRGB(out_channel, style_dim))
|
440 |
-
in_channel = out_channel
|
441 |
-
|
442 |
-
self.n_latent = self.log_size * 2 - 2
|
443 |
-
|
444 |
-
def make_noise(self):
|
445 |
-
device = self.input.input.device
|
446 |
-
|
447 |
-
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2 // 2, device=device)]
|
448 |
-
|
449 |
-
for i in range(3, self.log_size + 1):
|
450 |
-
for _ in range(2):
|
451 |
-
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i // 2, device=device))
|
452 |
-
|
453 |
-
return noises
|
454 |
-
|
455 |
-
def mean_latent(self, n_latent):
|
456 |
-
latent_in = torch.randn(
|
457 |
-
n_latent, self.style_dim, device=self.input.input.device
|
458 |
-
)
|
459 |
-
latent = self.style(latent_in).mean(0, keepdim=True)
|
460 |
-
|
461 |
-
return latent
|
462 |
-
|
463 |
-
def get_latent(self, input):
|
464 |
-
return self.style(input)
|
465 |
-
|
466 |
-
def forward(
|
467 |
-
self,
|
468 |
-
styles,
|
469 |
-
return_latents=False,
|
470 |
-
return_features=False,
|
471 |
-
inject_index=None,
|
472 |
-
truncation=1,
|
473 |
-
truncation_latent=None,
|
474 |
-
input_is_latent=False,
|
475 |
-
noise=None,
|
476 |
-
randomize_noise=True,
|
477 |
-
real=False,
|
478 |
-
):
|
479 |
-
if not input_is_latent:
|
480 |
-
styles = [self.style(s) for s in styles]
|
481 |
-
if noise is None:
|
482 |
-
if randomize_noise:
|
483 |
-
noise = [None] * self.num_layers
|
484 |
-
else:
|
485 |
-
noise = [
|
486 |
-
getattr(self.noises, "noise_{}".format(i))
|
487 |
-
for i in range(self.num_layers)
|
488 |
-
]
|
489 |
-
|
490 |
-
if truncation < 1:
|
491 |
-
# print('truncation_latent: ', truncation_latent.shape)
|
492 |
-
if not real: #if type(styles) == list:
|
493 |
-
style_t = []
|
494 |
-
for style in styles:
|
495 |
-
style_t.append(
|
496 |
-
truncation_latent + truncation * (style - truncation_latent)
|
497 |
-
) # (-1.1162e-03-(-1.0914e-01))*0.8+(-1.0914e-01)
|
498 |
-
styles = style_t
|
499 |
-
else: # styles are latent (tensor: 1,18,512), for real PTI output
|
500 |
-
truncation_latent = truncation_latent.repeat(18,1).unsqueeze(0) # (1,512) --> (1,18,512)
|
501 |
-
styles = torch.add(truncation_latent,torch.mul(torch.sub(styles,truncation_latent),truncation))
|
502 |
-
# print('now styles after truncation : ', styles)
|
503 |
-
#if type(styles) == list and len(styles) < 2: # this if for input as list of [(1,512)]
|
504 |
-
if not real:
|
505 |
-
if len(styles) < 2:
|
506 |
-
inject_index = self.n_latent
|
507 |
-
if styles[0].ndim < 3:
|
508 |
-
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
509 |
-
else:
|
510 |
-
latent = styles[0]
|
511 |
-
elif type(styles) == list:
|
512 |
-
if inject_index is None:
|
513 |
-
inject_index = 4
|
514 |
-
|
515 |
-
latent = styles[0].unsqueeze(0)
|
516 |
-
if latent.shape[1] == 1:
|
517 |
-
latent = latent.repeat(1, inject_index, 1)
|
518 |
-
else:
|
519 |
-
latent = latent[:, :inject_index, :]
|
520 |
-
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
|
521 |
-
latent = torch.cat([latent, latent2], 1)
|
522 |
-
else: # input is tensor of size with torch.Size([1, 18, 512]), for real PTI output
|
523 |
-
latent = styles
|
524 |
-
|
525 |
-
# print(f'processed latent: {latent.shape}')
|
526 |
-
|
527 |
-
features = {}
|
528 |
-
out = self.input(latent)
|
529 |
-
features["out_0"] = out
|
530 |
-
out = self.conv1(out, latent[:, 0], noise=noise[0])
|
531 |
-
features["conv1_0"] = out
|
532 |
-
|
533 |
-
skip = self.to_rgb1(out, latent[:, 1])
|
534 |
-
features["skip_0"] = skip
|
535 |
-
i = 1
|
536 |
-
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
537 |
-
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
|
538 |
-
):
|
539 |
-
out = conv1(out, latent[:, i], noise=noise1)
|
540 |
-
features["conv1_{}".format(i)] = out
|
541 |
-
out = conv2(out, latent[:, i + 1], noise=noise2)
|
542 |
-
features["conv2_{}".format(i)] = out
|
543 |
-
skip = to_rgb(out, latent[:, i + 2], skip)
|
544 |
-
features["skip_{}".format(i)] = skip
|
545 |
-
|
546 |
-
i += 2
|
547 |
-
|
548 |
-
image = skip
|
549 |
-
|
550 |
-
if return_latents:
|
551 |
-
return image, latent
|
552 |
-
elif return_features:
|
553 |
-
return image, features
|
554 |
-
else:
|
555 |
-
return image, None
|
556 |
-
|
557 |
-
|
558 |
-
class ConvLayer(nn.Sequential):
|
559 |
-
def __init__(
|
560 |
-
self,
|
561 |
-
in_channel,
|
562 |
-
out_channel,
|
563 |
-
kernel_size,
|
564 |
-
downsample=False,
|
565 |
-
blur_kernel=[1, 3, 3, 1],
|
566 |
-
bias=True,
|
567 |
-
activate=True,
|
568 |
-
):
|
569 |
-
layers = []
|
570 |
-
|
571 |
-
if downsample:
|
572 |
-
factor = 2
|
573 |
-
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
574 |
-
pad0 = (p + 1) // 2
|
575 |
-
pad1 = p // 2
|
576 |
-
|
577 |
-
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
578 |
-
|
579 |
-
stride = 2
|
580 |
-
self.padding = 0
|
581 |
-
|
582 |
-
else:
|
583 |
-
stride = 1
|
584 |
-
self.padding = kernel_size // 2
|
585 |
-
|
586 |
-
layers.append(
|
587 |
-
EqualConv2d(
|
588 |
-
in_channel,
|
589 |
-
out_channel,
|
590 |
-
kernel_size,
|
591 |
-
padding=self.padding,
|
592 |
-
stride=stride,
|
593 |
-
bias=bias and not activate,
|
594 |
-
)
|
595 |
-
)
|
596 |
-
|
597 |
-
if activate:
|
598 |
-
if bias:
|
599 |
-
layers.append(FusedLeakyReLU(out_channel))
|
600 |
-
else:
|
601 |
-
layers.append(ScaledLeakyReLU(0.2))
|
602 |
-
|
603 |
-
super().__init__(*layers)
|
604 |
-
|
605 |
-
|
606 |
-
class ResBlock(nn.Module):
|
607 |
-
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
|
608 |
-
super().__init__()
|
609 |
-
|
610 |
-
self.conv1 = ConvLayer(in_channel, in_channel, 3)
|
611 |
-
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
|
612 |
-
|
613 |
-
self.skip = ConvLayer(
|
614 |
-
in_channel, out_channel, 1, downsample=True, activate=False, bias=False
|
615 |
-
)
|
616 |
-
|
617 |
-
def forward(self, input):
|
618 |
-
out = self.conv1(input)
|
619 |
-
out = self.conv2(out)
|
620 |
-
|
621 |
-
skip = self.skip(input)
|
622 |
-
out = (out + skip) / math.sqrt(2)
|
623 |
-
|
624 |
-
return out
|
625 |
-
|
626 |
-
|
627 |
-
class StyleDiscriminator(nn.Module):
|
628 |
-
def __init__(
|
629 |
-
self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], small=False
|
630 |
-
):
|
631 |
-
super().__init__()
|
632 |
-
|
633 |
-
if small:
|
634 |
-
channels = {4: 64, 8: 64, 16: 64, 32: 64, 64: 64}
|
635 |
-
|
636 |
-
else:
|
637 |
-
channels = {
|
638 |
-
4: 512,
|
639 |
-
8: 512,
|
640 |
-
16: 512,
|
641 |
-
32: 512,
|
642 |
-
64: 256 * channel_multiplier,
|
643 |
-
128: 128 * channel_multiplier,
|
644 |
-
256: 64 * channel_multiplier,
|
645 |
-
512: 32 * channel_multiplier,
|
646 |
-
1024: 16 * channel_multiplier,
|
647 |
-
}
|
648 |
-
|
649 |
-
convs = [ConvLayer(3, channels[size], 1)]
|
650 |
-
|
651 |
-
log_size = int(math.log(size, 2))
|
652 |
-
in_channel = channels[size]
|
653 |
-
|
654 |
-
for i in range(log_size, 2, -1):
|
655 |
-
out_channel = channels[2 ** (i - 1)]
|
656 |
-
|
657 |
-
convs.append(ResBlock(in_channel, out_channel, blur_kernel))
|
658 |
-
|
659 |
-
in_channel = out_channel
|
660 |
-
|
661 |
-
self.convs = nn.Sequential(*convs)
|
662 |
-
|
663 |
-
self.stddev_group = 4
|
664 |
-
self.stddev_feat = 1
|
665 |
-
|
666 |
-
self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
|
667 |
-
self.final_linear = nn.Sequential(
|
668 |
-
EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
|
669 |
-
EqualLinear(channels[4], 1),
|
670 |
-
)
|
671 |
-
|
672 |
-
|
673 |
-
def forward(self, input):
|
674 |
-
h = input
|
675 |
-
h_list = []
|
676 |
-
|
677 |
-
for index, blocklist in enumerate(self.convs):
|
678 |
-
h = blocklist(h)
|
679 |
-
h_list.append(h)
|
680 |
-
|
681 |
-
out = h
|
682 |
-
batch, channel, height, width = out.shape
|
683 |
-
group = min(batch, self.stddev_group)
|
684 |
-
stddev = out.view(
|
685 |
-
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
|
686 |
-
)
|
687 |
-
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
688 |
-
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
689 |
-
stddev = stddev.repeat(group, 1, height, width)
|
690 |
-
out = torch.cat([out, stddev], 1)
|
691 |
-
|
692 |
-
out = self.final_conv(out)
|
693 |
-
h_list.append(out)
|
694 |
-
|
695 |
-
out = out.view(batch, -1)
|
696 |
-
out = self.final_linear(out)
|
697 |
-
|
698 |
-
return out, h_list
|
699 |
-
|
700 |
-
|
701 |
-
class StyleEncoder(nn.Module):
|
702 |
-
def __init__(self, size, w_dim=512):
|
703 |
-
super().__init__()
|
704 |
-
|
705 |
-
channels = {
|
706 |
-
4: 512,
|
707 |
-
8: 512,
|
708 |
-
16: 512,
|
709 |
-
32: 512,
|
710 |
-
64: 256,
|
711 |
-
128: 128,
|
712 |
-
256: 64,
|
713 |
-
512: 32,
|
714 |
-
1024: 16
|
715 |
-
}
|
716 |
-
|
717 |
-
self.w_dim = w_dim
|
718 |
-
log_size = int(math.log(size, 2))
|
719 |
-
convs = [ConvLayer(3, channels[size], 1)]
|
720 |
-
|
721 |
-
in_channel = channels[size]
|
722 |
-
for i in range(log_size, 2, -1):
|
723 |
-
out_channel = channels[2 ** (i - 1)]
|
724 |
-
convs.append(ResBlock(in_channel, out_channel))
|
725 |
-
in_channel = out_channel
|
726 |
-
|
727 |
-
convs.append(EqualConv2d(in_channel,2*self.w_dim, 4, padding=0, bias=False))
|
728 |
-
|
729 |
-
self.convs = nn.Sequential(*convs)
|
730 |
-
|
731 |
-
def forward(self, input):
|
732 |
-
out = self.convs(input)
|
733 |
-
# return out.view(len(input), self.n_latents, self.w_dim)
|
734 |
-
reshaped = out.view(len(input), 2*self.w_dim)
|
735 |
-
return reshaped[:,:self.w_dim], reshaped[:,self.w_dim:]
|
736 |
-
|
737 |
-
def kaiming_init(m):
|
738 |
-
if isinstance(m, (nn.Linear, nn.Conv2d)):
|
739 |
-
init.kaiming_normal_(m.weight)
|
740 |
-
if m.bias is not None:
|
741 |
-
m.bias.data.fill_(0)
|
742 |
-
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
|
743 |
-
m.weight.data.fill_(1)
|
744 |
-
if m.bias is not None:
|
745 |
-
m.bias.data.fill_(0)
|
746 |
-
|
747 |
-
|
748 |
-
def normal_init(m):
|
749 |
-
if isinstance(m, (nn.Linear, nn.Conv2d)):
|
750 |
-
init.normal_(m.weight, 0, 0.02)
|
751 |
-
if m.bias is not None:
|
752 |
-
m.bias.data.fill_(0)
|
753 |
-
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
|
754 |
-
m.weight.data.fill_(1)
|
755 |
-
if m.bias is not None:
|
756 |
-
m.bias.data.fill_(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/models_face.py
DELETED
@@ -1,809 +0,0 @@
|
|
1 |
-
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
import math
|
4 |
-
import random
|
5 |
-
import functools
|
6 |
-
import operator
|
7 |
-
|
8 |
-
import torch
|
9 |
-
from torch import nn
|
10 |
-
from torch.nn import functional as F
|
11 |
-
import torch.nn.init as init
|
12 |
-
from torch.autograd import Function
|
13 |
-
|
14 |
-
from .op_edit import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
|
15 |
-
|
16 |
-
|
17 |
-
class PixelNorm(nn.Module):
|
18 |
-
def __init__(self):
|
19 |
-
super().__init__()
|
20 |
-
|
21 |
-
def forward(self, input):
|
22 |
-
return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
|
23 |
-
|
24 |
-
|
25 |
-
def make_kernel(k):
|
26 |
-
k = torch.tensor(k, dtype=torch.float32)
|
27 |
-
|
28 |
-
if k.ndim == 1:
|
29 |
-
k = k[None, :] * k[:, None]
|
30 |
-
|
31 |
-
k /= k.sum()
|
32 |
-
|
33 |
-
return k
|
34 |
-
|
35 |
-
|
36 |
-
class Upsample(nn.Module):
|
37 |
-
def __init__(self, kernel, factor=2):
|
38 |
-
super().__init__()
|
39 |
-
|
40 |
-
self.factor = factor
|
41 |
-
kernel = make_kernel(kernel) * (factor ** 2)
|
42 |
-
self.register_buffer("kernel", kernel)
|
43 |
-
|
44 |
-
p = kernel.shape[0] - factor
|
45 |
-
|
46 |
-
pad0 = (p + 1) // 2 + factor - 1
|
47 |
-
pad1 = p // 2
|
48 |
-
|
49 |
-
self.pad = (pad0, pad1)
|
50 |
-
|
51 |
-
def forward(self, input):
|
52 |
-
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
|
53 |
-
|
54 |
-
return out
|
55 |
-
|
56 |
-
|
57 |
-
class Downsample(nn.Module):
|
58 |
-
def __init__(self, kernel, factor=2):
|
59 |
-
super().__init__()
|
60 |
-
|
61 |
-
self.factor = factor
|
62 |
-
kernel = make_kernel(kernel)
|
63 |
-
self.register_buffer("kernel", kernel)
|
64 |
-
|
65 |
-
p = kernel.shape[0] - factor
|
66 |
-
|
67 |
-
pad0 = (p + 1) // 2
|
68 |
-
pad1 = p // 2
|
69 |
-
|
70 |
-
self.pad = (pad0, pad1)
|
71 |
-
|
72 |
-
def forward(self, input):
|
73 |
-
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
|
74 |
-
|
75 |
-
return out
|
76 |
-
|
77 |
-
|
78 |
-
class Blur(nn.Module):
|
79 |
-
def __init__(self, kernel, pad, upsample_factor=1):
|
80 |
-
super().__init__()
|
81 |
-
|
82 |
-
kernel = make_kernel(kernel)
|
83 |
-
|
84 |
-
if upsample_factor > 1:
|
85 |
-
kernel = kernel * (upsample_factor ** 2)
|
86 |
-
|
87 |
-
self.register_buffer("kernel", kernel)
|
88 |
-
|
89 |
-
self.pad = pad
|
90 |
-
|
91 |
-
def forward(self, input):
|
92 |
-
out = upfirdn2d(input, self.kernel, pad=self.pad)
|
93 |
-
|
94 |
-
return out
|
95 |
-
|
96 |
-
|
97 |
-
class EqualConv2d(nn.Module):
|
98 |
-
def __init__(
|
99 |
-
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
|
100 |
-
):
|
101 |
-
super().__init__()
|
102 |
-
|
103 |
-
self.weight = nn.Parameter(
|
104 |
-
torch.randn(out_channel, in_channel, kernel_size, kernel_size)
|
105 |
-
)
|
106 |
-
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
|
107 |
-
|
108 |
-
self.stride = stride
|
109 |
-
self.padding = padding
|
110 |
-
|
111 |
-
if bias:
|
112 |
-
self.bias = nn.Parameter(torch.zeros(out_channel))
|
113 |
-
|
114 |
-
else:
|
115 |
-
self.bias = None
|
116 |
-
|
117 |
-
def forward(self, input):
|
118 |
-
out = F.conv2d(
|
119 |
-
input,
|
120 |
-
self.weight * self.scale,
|
121 |
-
bias=self.bias,
|
122 |
-
stride=self.stride,
|
123 |
-
padding=self.padding,
|
124 |
-
)
|
125 |
-
|
126 |
-
return out
|
127 |
-
|
128 |
-
def __repr__(self):
|
129 |
-
return (
|
130 |
-
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
|
131 |
-
f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
|
132 |
-
)
|
133 |
-
|
134 |
-
|
135 |
-
class EqualLinear(nn.Module):
|
136 |
-
def __init__(
|
137 |
-
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
|
138 |
-
):
|
139 |
-
super().__init__()
|
140 |
-
|
141 |
-
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
142 |
-
|
143 |
-
if bias:
|
144 |
-
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
145 |
-
|
146 |
-
else:
|
147 |
-
self.bias = None
|
148 |
-
|
149 |
-
self.activation = activation
|
150 |
-
|
151 |
-
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
152 |
-
self.lr_mul = lr_mul
|
153 |
-
|
154 |
-
def forward(self, input):
|
155 |
-
if self.activation:
|
156 |
-
out = F.linear(input, self.weight * self.scale)
|
157 |
-
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
158 |
-
|
159 |
-
else:
|
160 |
-
out = F.linear(
|
161 |
-
input, self.weight * self.scale, bias=self.bias * self.lr_mul
|
162 |
-
)
|
163 |
-
|
164 |
-
return out
|
165 |
-
|
166 |
-
def __repr__(self):
|
167 |
-
return (
|
168 |
-
f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
|
169 |
-
)
|
170 |
-
|
171 |
-
|
172 |
-
class ScaledLeakyReLU(nn.Module):
|
173 |
-
def __init__(self, negative_slope=0.2):
|
174 |
-
super().__init__()
|
175 |
-
|
176 |
-
self.negative_slope = negative_slope
|
177 |
-
|
178 |
-
def forward(self, input):
|
179 |
-
out = F.leaky_relu(input, negative_slope=self.negative_slope)
|
180 |
-
|
181 |
-
return out * math.sqrt(2)
|
182 |
-
|
183 |
-
|
184 |
-
class ModulatedConv2d(nn.Module):
|
185 |
-
def __init__(
|
186 |
-
self,
|
187 |
-
in_channel,
|
188 |
-
out_channel,
|
189 |
-
kernel_size,
|
190 |
-
style_dim,
|
191 |
-
demodulate=True,
|
192 |
-
upsample=False,
|
193 |
-
downsample=False,
|
194 |
-
blur_kernel=[1, 3, 3, 1],
|
195 |
-
):
|
196 |
-
super().__init__()
|
197 |
-
|
198 |
-
self.eps = 1e-8
|
199 |
-
self.kernel_size = kernel_size
|
200 |
-
self.in_channel = in_channel
|
201 |
-
self.out_channel = out_channel
|
202 |
-
self.upsample = upsample
|
203 |
-
self.downsample = downsample
|
204 |
-
|
205 |
-
if upsample:
|
206 |
-
factor = 2
|
207 |
-
p = (len(blur_kernel) - factor) - (kernel_size - 1)
|
208 |
-
pad0 = (p + 1) // 2 + factor - 1
|
209 |
-
pad1 = p // 2 + 1
|
210 |
-
|
211 |
-
self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
|
212 |
-
|
213 |
-
if downsample:
|
214 |
-
factor = 2
|
215 |
-
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
216 |
-
pad0 = (p + 1) // 2
|
217 |
-
pad1 = p // 2
|
218 |
-
|
219 |
-
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
|
220 |
-
|
221 |
-
fan_in = in_channel * kernel_size ** 2
|
222 |
-
self.scale = 1 / math.sqrt(fan_in)
|
223 |
-
self.padding = kernel_size // 2
|
224 |
-
|
225 |
-
self.weight = nn.Parameter(
|
226 |
-
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
|
227 |
-
)
|
228 |
-
|
229 |
-
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
|
230 |
-
|
231 |
-
self.demodulate = demodulate
|
232 |
-
|
233 |
-
def __repr__(self):
|
234 |
-
return (
|
235 |
-
f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
|
236 |
-
f"upsample={self.upsample}, downsample={self.downsample})"
|
237 |
-
)
|
238 |
-
|
239 |
-
def forward(self, input, style):
|
240 |
-
batch, in_channel, height, width = input.shape
|
241 |
-
|
242 |
-
style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
|
243 |
-
weight = self.scale * self.weight * style
|
244 |
-
|
245 |
-
if self.demodulate:
|
246 |
-
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
|
247 |
-
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
|
248 |
-
|
249 |
-
weight = weight.view(
|
250 |
-
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
251 |
-
)
|
252 |
-
|
253 |
-
if self.upsample:
|
254 |
-
input = input.view(1, batch * in_channel, height, width)
|
255 |
-
weight = weight.view(
|
256 |
-
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
257 |
-
)
|
258 |
-
weight = weight.transpose(1, 2).reshape(
|
259 |
-
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
|
260 |
-
)
|
261 |
-
out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
|
262 |
-
_, _, height, width = out.shape
|
263 |
-
out = out.view(batch, self.out_channel, height, width)
|
264 |
-
out = self.blur(out)
|
265 |
-
|
266 |
-
elif self.downsample:
|
267 |
-
input = self.blur(input)
|
268 |
-
_, _, height, width = input.shape
|
269 |
-
input = input.view(1, batch * in_channel, height, width)
|
270 |
-
out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
|
271 |
-
_, _, height, width = out.shape
|
272 |
-
out = out.view(batch, self.out_channel, height, width)
|
273 |
-
|
274 |
-
else:
|
275 |
-
input = input.view(1, batch * in_channel, height, width)
|
276 |
-
out = F.conv2d(input, weight, padding=self.padding, groups=batch)
|
277 |
-
_, _, height, width = out.shape
|
278 |
-
out = out.view(batch, self.out_channel, height, width)
|
279 |
-
|
280 |
-
return out
|
281 |
-
|
282 |
-
|
283 |
-
class NoiseInjection(nn.Module):
|
284 |
-
def __init__(self):
|
285 |
-
super().__init__()
|
286 |
-
|
287 |
-
self.weight = nn.Parameter(torch.zeros(1))
|
288 |
-
|
289 |
-
def forward(self, image, noise=None):
|
290 |
-
if noise is None:
|
291 |
-
batch, _, height, width = image.shape
|
292 |
-
noise = image.new_empty(batch, 1, height, width).normal_()
|
293 |
-
|
294 |
-
return image + self.weight * noise
|
295 |
-
|
296 |
-
|
297 |
-
class ConstantInput(nn.Module):
|
298 |
-
def __init__(self, channel, size=4):
|
299 |
-
super().__init__()
|
300 |
-
|
301 |
-
self.input = nn.Parameter(torch.randn(1, channel, size, size))
|
302 |
-
|
303 |
-
def forward(self, input):
|
304 |
-
batch = input.shape[0]
|
305 |
-
out = self.input.repeat(batch, 1, 1, 1)
|
306 |
-
|
307 |
-
return out
|
308 |
-
|
309 |
-
|
310 |
-
class StyledConv(nn.Module):
|
311 |
-
def __init__(
|
312 |
-
self,
|
313 |
-
in_channel,
|
314 |
-
out_channel,
|
315 |
-
kernel_size,
|
316 |
-
style_dim,
|
317 |
-
upsample=False,
|
318 |
-
blur_kernel=[1, 3, 3, 1],
|
319 |
-
demodulate=True,
|
320 |
-
):
|
321 |
-
super().__init__()
|
322 |
-
|
323 |
-
self.conv = ModulatedConv2d(
|
324 |
-
in_channel,
|
325 |
-
out_channel,
|
326 |
-
kernel_size,
|
327 |
-
style_dim,
|
328 |
-
upsample=upsample,
|
329 |
-
blur_kernel=blur_kernel,
|
330 |
-
demodulate=demodulate,
|
331 |
-
)
|
332 |
-
|
333 |
-
self.noise = NoiseInjection()
|
334 |
-
# self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
|
335 |
-
# self.activate = ScaledLeakyReLU(0.2)
|
336 |
-
self.activate = FusedLeakyReLU(out_channel)
|
337 |
-
|
338 |
-
def forward(self, input, style, noise=None):
|
339 |
-
out = self.conv(input, style)
|
340 |
-
out = self.noise(out, noise=noise)
|
341 |
-
# out = out + self.bias
|
342 |
-
out = self.activate(out)
|
343 |
-
|
344 |
-
return out
|
345 |
-
|
346 |
-
|
347 |
-
class ToRGB(nn.Module):
|
348 |
-
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
|
349 |
-
super().__init__()
|
350 |
-
|
351 |
-
if upsample:
|
352 |
-
self.upsample = Upsample(blur_kernel)
|
353 |
-
|
354 |
-
self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
|
355 |
-
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
356 |
-
|
357 |
-
def forward(self, input, style, skip=None):
|
358 |
-
out = self.conv(input, style)
|
359 |
-
out = out + self.bias
|
360 |
-
|
361 |
-
if skip is not None:
|
362 |
-
skip = self.upsample(skip)
|
363 |
-
|
364 |
-
out = out + skip
|
365 |
-
|
366 |
-
return out
|
367 |
-
|
368 |
-
|
369 |
-
class Generator(nn.Module):
|
370 |
-
def __init__(
|
371 |
-
self,
|
372 |
-
size,
|
373 |
-
style_dim,
|
374 |
-
n_mlp,
|
375 |
-
channel_multiplier=1,
|
376 |
-
blur_kernel=[1, 3, 3, 1],
|
377 |
-
lr_mlp=0.01,
|
378 |
-
small=False,
|
379 |
-
small_isaac=False,
|
380 |
-
):
|
381 |
-
super().__init__()
|
382 |
-
|
383 |
-
self.size = size
|
384 |
-
|
385 |
-
if small and size > 64:
|
386 |
-
raise ValueError("small only works for sizes <= 64")
|
387 |
-
|
388 |
-
self.style_dim = style_dim
|
389 |
-
|
390 |
-
layers = [PixelNorm()]
|
391 |
-
|
392 |
-
for i in range(n_mlp):
|
393 |
-
layers.append(
|
394 |
-
EqualLinear(
|
395 |
-
style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
|
396 |
-
)
|
397 |
-
)
|
398 |
-
|
399 |
-
self.style = nn.Sequential(*layers)
|
400 |
-
|
401 |
-
if small:
|
402 |
-
self.channels = {
|
403 |
-
4: 64 * channel_multiplier,
|
404 |
-
8: 64 * channel_multiplier,
|
405 |
-
16: 64 * channel_multiplier,
|
406 |
-
32: 64 * channel_multiplier,
|
407 |
-
64: 64 * channel_multiplier,
|
408 |
-
}
|
409 |
-
elif small_isaac:
|
410 |
-
self.channels = {4: 256, 8: 256, 16: 256, 32: 256, 64: 128, 128: 128}
|
411 |
-
else:
|
412 |
-
self.channels = {
|
413 |
-
4: 512,
|
414 |
-
8: 512,
|
415 |
-
16: 512,
|
416 |
-
32: 512,
|
417 |
-
64: 256 * channel_multiplier,
|
418 |
-
128: 128 * channel_multiplier,
|
419 |
-
256: 64 * channel_multiplier,
|
420 |
-
512: 32 * channel_multiplier,
|
421 |
-
1024: 16 * channel_multiplier,
|
422 |
-
}
|
423 |
-
|
424 |
-
self.input = ConstantInput(self.channels[4])
|
425 |
-
self.conv1 = StyledConv(
|
426 |
-
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
|
427 |
-
)
|
428 |
-
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
|
429 |
-
|
430 |
-
self.log_size = int(math.log(size, 2))
|
431 |
-
self.num_layers = (self.log_size - 2) * 2 + 1
|
432 |
-
|
433 |
-
self.convs = nn.ModuleList()
|
434 |
-
self.upsamples = nn.ModuleList()
|
435 |
-
self.to_rgbs = nn.ModuleList()
|
436 |
-
self.noises = nn.Module()
|
437 |
-
|
438 |
-
in_channel = self.channels[4]
|
439 |
-
|
440 |
-
for layer_idx in range(self.num_layers):
|
441 |
-
res = (layer_idx + 5) // 2
|
442 |
-
shape = [1, 1, 2 ** res, 2 ** res]
|
443 |
-
self.noises.register_buffer(
|
444 |
-
"noise_{}".format(layer_idx), torch.randn(*shape)
|
445 |
-
)
|
446 |
-
|
447 |
-
for i in range(3, self.log_size + 1):
|
448 |
-
out_channel = self.channels[2 ** i]
|
449 |
-
|
450 |
-
self.convs.append(
|
451 |
-
StyledConv(
|
452 |
-
in_channel,
|
453 |
-
out_channel,
|
454 |
-
3,
|
455 |
-
style_dim,
|
456 |
-
upsample=True,
|
457 |
-
blur_kernel=blur_kernel,
|
458 |
-
)
|
459 |
-
)
|
460 |
-
|
461 |
-
self.convs.append(
|
462 |
-
StyledConv(
|
463 |
-
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
|
464 |
-
)
|
465 |
-
)
|
466 |
-
|
467 |
-
self.to_rgbs.append(ToRGB(out_channel, style_dim))
|
468 |
-
|
469 |
-
in_channel = out_channel
|
470 |
-
|
471 |
-
self.n_latent = self.log_size * 2 - 2
|
472 |
-
|
473 |
-
def make_noise(self):
|
474 |
-
device = self.input.input.device
|
475 |
-
|
476 |
-
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
|
477 |
-
|
478 |
-
for i in range(3, self.log_size + 1):
|
479 |
-
for _ in range(2):
|
480 |
-
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
|
481 |
-
|
482 |
-
return noises
|
483 |
-
|
484 |
-
def mean_latent(self, n_latent):
|
485 |
-
latent_in = torch.randn(
|
486 |
-
n_latent, self.style_dim, device=self.input.input.device
|
487 |
-
)
|
488 |
-
latent = self.style(latent_in).mean(0, keepdim=True)
|
489 |
-
|
490 |
-
return latent
|
491 |
-
|
492 |
-
def get_latent(self, input):
|
493 |
-
return self.style(input)
|
494 |
-
|
495 |
-
def forward(
|
496 |
-
self,
|
497 |
-
styles,
|
498 |
-
return_latents=False,
|
499 |
-
return_features=False,
|
500 |
-
inject_index=None,
|
501 |
-
truncation=1,
|
502 |
-
truncation_latent=None,
|
503 |
-
input_is_latent=False,
|
504 |
-
noise=None,
|
505 |
-
randomize_noise=True,
|
506 |
-
):
|
507 |
-
if not input_is_latent:
|
508 |
-
# print("haha")
|
509 |
-
styles = [self.style(s) for s in styles]
|
510 |
-
if noise is None:
|
511 |
-
if randomize_noise:
|
512 |
-
noise = [None] * self.num_layers
|
513 |
-
else:
|
514 |
-
noise = [
|
515 |
-
getattr(self.noises, "noise_{}".format(i))
|
516 |
-
for i in range(self.num_layers)
|
517 |
-
]
|
518 |
-
|
519 |
-
if truncation < 1:
|
520 |
-
style_t = []
|
521 |
-
|
522 |
-
for style in styles:
|
523 |
-
style_t.append(
|
524 |
-
truncation_latent + truncation * (style - truncation_latent)
|
525 |
-
)
|
526 |
-
|
527 |
-
styles = style_t
|
528 |
-
# print(styles)
|
529 |
-
if len(styles) < 2:
|
530 |
-
inject_index = self.n_latent
|
531 |
-
|
532 |
-
if styles[0].ndim < 3:
|
533 |
-
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
534 |
-
# print("a")
|
535 |
-
else:
|
536 |
-
# print(len(styles))
|
537 |
-
latent = styles[0]
|
538 |
-
# print("b", latent.shape)
|
539 |
-
|
540 |
-
else:
|
541 |
-
# print("c")
|
542 |
-
if inject_index is None:
|
543 |
-
inject_index = 4
|
544 |
-
|
545 |
-
latent = styles[0].unsqueeze(0)
|
546 |
-
if latent.shape[1] == 1:
|
547 |
-
latent = latent.repeat(1, inject_index, 1)
|
548 |
-
else:
|
549 |
-
latent = latent[:, :inject_index, :]
|
550 |
-
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
|
551 |
-
|
552 |
-
latent = torch.cat([latent, latent2], 1)
|
553 |
-
|
554 |
-
features = {}
|
555 |
-
out = self.input(latent)
|
556 |
-
features["out_0"] = out
|
557 |
-
out = self.conv1(out, latent[:, 0], noise=noise[0])
|
558 |
-
features["conv1_0"] = out
|
559 |
-
|
560 |
-
skip = self.to_rgb1(out, latent[:, 1])
|
561 |
-
features["skip_0"] = skip
|
562 |
-
i = 1
|
563 |
-
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
564 |
-
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
|
565 |
-
):
|
566 |
-
out = conv1(out, latent[:, i], noise=noise1)
|
567 |
-
features["conv1_{}".format(i)] = out
|
568 |
-
out = conv2(out, latent[:, i + 1], noise=noise2)
|
569 |
-
features["conv2_{}".format(i)] = out
|
570 |
-
skip = to_rgb(out, latent[:, i + 2], skip)
|
571 |
-
features["skip_{}".format(i)] = skip
|
572 |
-
|
573 |
-
i += 2
|
574 |
-
|
575 |
-
image = skip
|
576 |
-
|
577 |
-
if return_latents:
|
578 |
-
return image, latent
|
579 |
-
elif return_features:
|
580 |
-
return image, features
|
581 |
-
else:
|
582 |
-
return image, None
|
583 |
-
|
584 |
-
|
585 |
-
class ConvLayer(nn.Sequential):
|
586 |
-
def __init__(
|
587 |
-
self,
|
588 |
-
in_channel,
|
589 |
-
out_channel,
|
590 |
-
kernel_size,
|
591 |
-
downsample=False,
|
592 |
-
blur_kernel=[1, 3, 3, 1],
|
593 |
-
bias=True,
|
594 |
-
activate=True,
|
595 |
-
):
|
596 |
-
layers = []
|
597 |
-
|
598 |
-
if downsample:
|
599 |
-
factor = 2
|
600 |
-
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
601 |
-
pad0 = (p + 1) // 2
|
602 |
-
pad1 = p // 2
|
603 |
-
|
604 |
-
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
605 |
-
|
606 |
-
stride = 2
|
607 |
-
self.padding = 0
|
608 |
-
|
609 |
-
else:
|
610 |
-
stride = 1
|
611 |
-
self.padding = kernel_size // 2
|
612 |
-
|
613 |
-
layers.append(
|
614 |
-
EqualConv2d(
|
615 |
-
in_channel,
|
616 |
-
out_channel,
|
617 |
-
kernel_size,
|
618 |
-
padding=self.padding,
|
619 |
-
stride=stride,
|
620 |
-
bias=bias and not activate,
|
621 |
-
)
|
622 |
-
)
|
623 |
-
|
624 |
-
if activate:
|
625 |
-
if bias:
|
626 |
-
layers.append(FusedLeakyReLU(out_channel))
|
627 |
-
|
628 |
-
else:
|
629 |
-
layers.append(ScaledLeakyReLU(0.2))
|
630 |
-
|
631 |
-
super().__init__(*layers)
|
632 |
-
|
633 |
-
|
634 |
-
class ResBlock(nn.Module):
|
635 |
-
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
|
636 |
-
super().__init__()
|
637 |
-
|
638 |
-
self.conv1 = ConvLayer(in_channel, in_channel, 3)
|
639 |
-
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
|
640 |
-
|
641 |
-
self.skip = ConvLayer(
|
642 |
-
in_channel, out_channel, 1, downsample=True, activate=False, bias=False
|
643 |
-
)
|
644 |
-
|
645 |
-
def forward(self, input):
|
646 |
-
out = self.conv1(input)
|
647 |
-
out = self.conv2(out)
|
648 |
-
|
649 |
-
skip = self.skip(input)
|
650 |
-
out = (out + skip) / math.sqrt(2)
|
651 |
-
|
652 |
-
return out
|
653 |
-
|
654 |
-
|
655 |
-
class StyleDiscriminator(nn.Module):
|
656 |
-
def __init__(
|
657 |
-
self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], small=False
|
658 |
-
):
|
659 |
-
super().__init__()
|
660 |
-
|
661 |
-
if small:
|
662 |
-
channels = {4: 64, 8: 64, 16: 64, 32: 64, 64: 64}
|
663 |
-
|
664 |
-
else:
|
665 |
-
channels = {
|
666 |
-
4: 512,
|
667 |
-
8: 512,
|
668 |
-
16: 512,
|
669 |
-
32: 512,
|
670 |
-
64: 256 * channel_multiplier,
|
671 |
-
128: 128 * channel_multiplier,
|
672 |
-
256: 64 * channel_multiplier,
|
673 |
-
512: 32 * channel_multiplier,
|
674 |
-
1024: 16 * channel_multiplier,
|
675 |
-
}
|
676 |
-
|
677 |
-
convs = [ConvLayer(3, channels[size], 1)]
|
678 |
-
|
679 |
-
log_size = int(math.log(size, 2))
|
680 |
-
|
681 |
-
in_channel = channels[size]
|
682 |
-
|
683 |
-
for i in range(log_size, 2, -1):
|
684 |
-
out_channel = channels[2 ** (i - 1)]
|
685 |
-
|
686 |
-
convs.append(ResBlock(in_channel, out_channel, blur_kernel))
|
687 |
-
|
688 |
-
in_channel = out_channel
|
689 |
-
|
690 |
-
self.convs = nn.Sequential(*convs)
|
691 |
-
|
692 |
-
self.stddev_group = 4
|
693 |
-
self.stddev_feat = 1
|
694 |
-
|
695 |
-
self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
|
696 |
-
self.final_linear = nn.Sequential(
|
697 |
-
EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
|
698 |
-
EqualLinear(channels[4], 1),
|
699 |
-
)
|
700 |
-
|
701 |
-
# def forward(self, input):
|
702 |
-
# out = self.convs(input)
|
703 |
-
|
704 |
-
# batch, channel, height, width = out.shape
|
705 |
-
# group = min(batch, self.stddev_group)
|
706 |
-
# stddev = out.view(
|
707 |
-
# group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
|
708 |
-
# )
|
709 |
-
# stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
710 |
-
# stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
711 |
-
# stddev = stddev.repeat(group, 1, height, width)
|
712 |
-
# out = torch.cat([out, stddev], 1)
|
713 |
-
|
714 |
-
# out = self.final_conv(out)
|
715 |
-
|
716 |
-
# out = out.view(batch, -1)
|
717 |
-
# out = self.final_linear(out)
|
718 |
-
|
719 |
-
# return out
|
720 |
-
|
721 |
-
def forward(self, input):
|
722 |
-
h = input
|
723 |
-
h_list = []
|
724 |
-
|
725 |
-
for index, blocklist in enumerate(self.convs):
|
726 |
-
h = blocklist(h)
|
727 |
-
h_list.append(h)
|
728 |
-
|
729 |
-
out = h
|
730 |
-
batch, channel, height, width = out.shape
|
731 |
-
group = min(batch, self.stddev_group)
|
732 |
-
stddev = out.view(
|
733 |
-
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
|
734 |
-
)
|
735 |
-
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
736 |
-
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
737 |
-
stddev = stddev.repeat(group, 1, height, width)
|
738 |
-
out = torch.cat([out, stddev], 1)
|
739 |
-
|
740 |
-
out = self.final_conv(out)
|
741 |
-
h_list.append(out)
|
742 |
-
|
743 |
-
out = out.view(batch, -1)
|
744 |
-
out = self.final_linear(out)
|
745 |
-
|
746 |
-
return out, h_list
|
747 |
-
|
748 |
-
|
749 |
-
class StyleEncoder(nn.Module):
|
750 |
-
def __init__(self, size, w_dim=512):
|
751 |
-
super().__init__()
|
752 |
-
|
753 |
-
channels = {
|
754 |
-
4: 512,
|
755 |
-
8: 512,
|
756 |
-
16: 512,
|
757 |
-
32: 512,
|
758 |
-
64: 256,
|
759 |
-
128: 128,
|
760 |
-
256: 64,
|
761 |
-
512: 32,
|
762 |
-
1024: 16
|
763 |
-
}
|
764 |
-
|
765 |
-
self.w_dim = w_dim
|
766 |
-
log_size = int(math.log(size, 2))
|
767 |
-
|
768 |
-
# self.n_latents = log_size*2 - 2
|
769 |
-
|
770 |
-
convs = [ConvLayer(3, channels[size], 1)]
|
771 |
-
|
772 |
-
in_channel = channels[size]
|
773 |
-
for i in range(log_size, 2, -1):
|
774 |
-
out_channel = channels[2 ** (i - 1)]
|
775 |
-
convs.append(ResBlock(in_channel, out_channel))
|
776 |
-
in_channel = out_channel
|
777 |
-
|
778 |
-
# convs.append(EqualConv2d(in_channel, self.n_latents*self.w_dim, 4, padding=0, bias=False))
|
779 |
-
convs.append(EqualConv2d(in_channel,2*self.w_dim, 4, padding=0, bias=False))
|
780 |
-
|
781 |
-
|
782 |
-
self.convs = nn.Sequential(*convs)
|
783 |
-
|
784 |
-
def forward(self, input):
|
785 |
-
out = self.convs(input)
|
786 |
-
# return out.view(len(input), self.n_latents, self.w_dim)
|
787 |
-
reshaped = out.view(len(input), 2*self.w_dim)
|
788 |
-
return reshaped[:,:self.w_dim], reshaped[:,self.w_dim:]
|
789 |
-
|
790 |
-
def kaiming_init(m):
|
791 |
-
if isinstance(m, (nn.Linear, nn.Conv2d)):
|
792 |
-
init.kaiming_normal_(m.weight)
|
793 |
-
if m.bias is not None:
|
794 |
-
m.bias.data.fill_(0)
|
795 |
-
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
|
796 |
-
m.weight.data.fill_(1)
|
797 |
-
if m.bias is not None:
|
798 |
-
m.bias.data.fill_(0)
|
799 |
-
|
800 |
-
|
801 |
-
def normal_init(m):
|
802 |
-
if isinstance(m, (nn.Linear, nn.Conv2d)):
|
803 |
-
init.normal_(m.weight, 0, 0.02)
|
804 |
-
if m.bias is not None:
|
805 |
-
m.bias.data.fill_(0)
|
806 |
-
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
|
807 |
-
m.weight.data.fill_(1)
|
808 |
-
if m.bias is not None:
|
809 |
-
m.bias.data.fill_(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/op_edit/__init__.py
DELETED
@@ -1,4 +0,0 @@
|
|
1 |
-
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
from .fused_act import FusedLeakyReLU, fused_leaky_relu
|
4 |
-
from .upfirdn2d import upfirdn2d
|
|
|
|
|
|
|
|
|
|
torch_utils/op_edit/fused_act.py
DELETED
@@ -1,99 +0,0 @@
|
|
1 |
-
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
import os
|
4 |
-
|
5 |
-
import torch
|
6 |
-
from torch import nn
|
7 |
-
from torch.nn import functional as F
|
8 |
-
from torch.autograd import Function
|
9 |
-
from torch.utils.cpp_extension import load
|
10 |
-
|
11 |
-
|
12 |
-
module_path = os.path.dirname(__file__)
|
13 |
-
fused = load(
|
14 |
-
"fused",
|
15 |
-
sources=[
|
16 |
-
os.path.join(module_path, "fused_bias_act.cpp"),
|
17 |
-
os.path.join(module_path, "fused_bias_act_kernel.cu"),
|
18 |
-
],
|
19 |
-
)
|
20 |
-
|
21 |
-
|
22 |
-
class FusedLeakyReLUFunctionBackward(Function):
|
23 |
-
@staticmethod
|
24 |
-
def forward(ctx, grad_output, out, negative_slope, scale):
|
25 |
-
ctx.save_for_backward(out)
|
26 |
-
ctx.negative_slope = negative_slope
|
27 |
-
ctx.scale = scale
|
28 |
-
|
29 |
-
empty = grad_output.new_empty(0)
|
30 |
-
|
31 |
-
grad_input = fused.fused_bias_act(
|
32 |
-
grad_output, empty, out, 3, 1, negative_slope, scale
|
33 |
-
)
|
34 |
-
|
35 |
-
dim = [0]
|
36 |
-
|
37 |
-
if grad_input.ndim > 2:
|
38 |
-
dim += list(range(2, grad_input.ndim))
|
39 |
-
|
40 |
-
grad_bias = grad_input.sum(dim).detach()
|
41 |
-
|
42 |
-
return grad_input, grad_bias
|
43 |
-
|
44 |
-
@staticmethod
|
45 |
-
def backward(ctx, gradgrad_input, gradgrad_bias):
|
46 |
-
(out,) = ctx.saved_tensors
|
47 |
-
gradgrad_out = fused.fused_bias_act(
|
48 |
-
gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
|
49 |
-
)
|
50 |
-
|
51 |
-
return gradgrad_out, None, None, None
|
52 |
-
|
53 |
-
|
54 |
-
class FusedLeakyReLUFunction(Function):
|
55 |
-
@staticmethod
|
56 |
-
def forward(ctx, input, bias, negative_slope, scale):
|
57 |
-
empty = input.new_empty(0)
|
58 |
-
out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
|
59 |
-
ctx.save_for_backward(out)
|
60 |
-
ctx.negative_slope = negative_slope
|
61 |
-
ctx.scale = scale
|
62 |
-
|
63 |
-
return out
|
64 |
-
|
65 |
-
@staticmethod
|
66 |
-
def backward(ctx, grad_output):
|
67 |
-
(out,) = ctx.saved_tensors
|
68 |
-
|
69 |
-
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
|
70 |
-
grad_output, out, ctx.negative_slope, ctx.scale
|
71 |
-
)
|
72 |
-
|
73 |
-
return grad_input, grad_bias, None, None
|
74 |
-
|
75 |
-
|
76 |
-
class FusedLeakyReLU(nn.Module):
|
77 |
-
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
|
78 |
-
super().__init__()
|
79 |
-
|
80 |
-
self.bias = nn.Parameter(torch.zeros(channel))
|
81 |
-
self.negative_slope = negative_slope
|
82 |
-
self.scale = scale
|
83 |
-
|
84 |
-
def forward(self, input):
|
85 |
-
return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
|
86 |
-
|
87 |
-
|
88 |
-
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
|
89 |
-
if input.device.type == "cpu":
|
90 |
-
rest_dim = [1] * (input.ndim - bias.ndim - 1)
|
91 |
-
return (
|
92 |
-
F.leaky_relu(
|
93 |
-
input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
|
94 |
-
)
|
95 |
-
* scale
|
96 |
-
)
|
97 |
-
|
98 |
-
else:
|
99 |
-
return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/op_edit/fused_bias_act.cpp
DELETED
@@ -1,23 +0,0 @@
|
|
1 |
-
// Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
#include <torch/extension.h>
|
4 |
-
|
5 |
-
|
6 |
-
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
|
7 |
-
int act, int grad, float alpha, float scale);
|
8 |
-
|
9 |
-
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
10 |
-
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
11 |
-
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
12 |
-
|
13 |
-
torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
|
14 |
-
int act, int grad, float alpha, float scale) {
|
15 |
-
CHECK_CUDA(input);
|
16 |
-
CHECK_CUDA(bias);
|
17 |
-
|
18 |
-
return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
|
19 |
-
}
|
20 |
-
|
21 |
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
22 |
-
m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
|
23 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/op_edit/fused_bias_act_kernel.cu
DELETED
@@ -1,101 +0,0 @@
|
|
1 |
-
// Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
4 |
-
//
|
5 |
-
// This work is made available under the Nvidia Source Code License-NC.
|
6 |
-
// To view a copy of this license, visit
|
7 |
-
// https://nvlabs.github.io/stylegan2/license.html
|
8 |
-
|
9 |
-
#include <torch/types.h>
|
10 |
-
|
11 |
-
#include <ATen/ATen.h>
|
12 |
-
#include <ATen/AccumulateType.h>
|
13 |
-
#include <ATen/cuda/CUDAContext.h>
|
14 |
-
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
15 |
-
|
16 |
-
#include <cuda.h>
|
17 |
-
#include <cuda_runtime.h>
|
18 |
-
|
19 |
-
|
20 |
-
template <typename scalar_t>
|
21 |
-
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
|
22 |
-
int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
|
23 |
-
int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
|
24 |
-
|
25 |
-
scalar_t zero = 0.0;
|
26 |
-
|
27 |
-
for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
|
28 |
-
scalar_t x = p_x[xi];
|
29 |
-
|
30 |
-
if (use_bias) {
|
31 |
-
x += p_b[(xi / step_b) % size_b];
|
32 |
-
}
|
33 |
-
|
34 |
-
scalar_t ref = use_ref ? p_ref[xi] : zero;
|
35 |
-
|
36 |
-
scalar_t y;
|
37 |
-
|
38 |
-
switch (act * 10 + grad) {
|
39 |
-
default:
|
40 |
-
case 10: y = x; break;
|
41 |
-
case 11: y = x; break;
|
42 |
-
case 12: y = 0.0; break;
|
43 |
-
|
44 |
-
case 30: y = (x > 0.0) ? x : x * alpha; break;
|
45 |
-
case 31: y = (ref > 0.0) ? x : x * alpha; break;
|
46 |
-
case 32: y = 0.0; break;
|
47 |
-
}
|
48 |
-
|
49 |
-
out[xi] = y * scale;
|
50 |
-
}
|
51 |
-
}
|
52 |
-
|
53 |
-
|
54 |
-
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
|
55 |
-
int act, int grad, float alpha, float scale) {
|
56 |
-
int curDevice = -1;
|
57 |
-
cudaGetDevice(&curDevice);
|
58 |
-
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
|
59 |
-
|
60 |
-
auto x = input.contiguous();
|
61 |
-
auto b = bias.contiguous();
|
62 |
-
auto ref = refer.contiguous();
|
63 |
-
|
64 |
-
int use_bias = b.numel() ? 1 : 0;
|
65 |
-
int use_ref = ref.numel() ? 1 : 0;
|
66 |
-
|
67 |
-
int size_x = x.numel();
|
68 |
-
int size_b = b.numel();
|
69 |
-
int step_b = 1;
|
70 |
-
|
71 |
-
for (int i = 1 + 1; i < x.dim(); i++) {
|
72 |
-
step_b *= x.size(i);
|
73 |
-
}
|
74 |
-
|
75 |
-
int loop_x = 4;
|
76 |
-
int block_size = 4 * 32;
|
77 |
-
int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
|
78 |
-
|
79 |
-
auto y = torch::empty_like(x);
|
80 |
-
|
81 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
|
82 |
-
fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
|
83 |
-
y.data_ptr<scalar_t>(),
|
84 |
-
x.data_ptr<scalar_t>(),
|
85 |
-
b.data_ptr<scalar_t>(),
|
86 |
-
ref.data_ptr<scalar_t>(),
|
87 |
-
act,
|
88 |
-
grad,
|
89 |
-
alpha,
|
90 |
-
scale,
|
91 |
-
loop_x,
|
92 |
-
size_x,
|
93 |
-
step_b,
|
94 |
-
size_b,
|
95 |
-
use_bias,
|
96 |
-
use_ref
|
97 |
-
);
|
98 |
-
});
|
99 |
-
|
100 |
-
return y;
|
101 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/op_edit/upfirdn2d.cpp
DELETED
@@ -1,25 +0,0 @@
|
|
1 |
-
// Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
#include <torch/extension.h>
|
4 |
-
|
5 |
-
|
6 |
-
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
|
7 |
-
int up_x, int up_y, int down_x, int down_y,
|
8 |
-
int pad_x0, int pad_x1, int pad_y0, int pad_y1);
|
9 |
-
|
10 |
-
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
11 |
-
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
12 |
-
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
13 |
-
|
14 |
-
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
|
15 |
-
int up_x, int up_y, int down_x, int down_y,
|
16 |
-
int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
|
17 |
-
CHECK_CUDA(input);
|
18 |
-
CHECK_CUDA(kernel);
|
19 |
-
|
20 |
-
return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
|
21 |
-
}
|
22 |
-
|
23 |
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
24 |
-
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
|
25 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/op_edit/upfirdn2d.py
DELETED
@@ -1,202 +0,0 @@
|
|
1 |
-
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
import os
|
4 |
-
|
5 |
-
import torch
|
6 |
-
from torch.nn import functional as F
|
7 |
-
from torch.autograd import Function
|
8 |
-
from torch.utils.cpp_extension import load
|
9 |
-
|
10 |
-
|
11 |
-
module_path = os.path.dirname(__file__)
|
12 |
-
upfirdn2d_op = load(
|
13 |
-
"upfirdn2d",
|
14 |
-
sources=[
|
15 |
-
os.path.join(module_path, "upfirdn2d.cpp"),
|
16 |
-
os.path.join(module_path, "upfirdn2d_kernel.cu"),
|
17 |
-
],
|
18 |
-
)
|
19 |
-
|
20 |
-
|
21 |
-
class UpFirDn2dBackward(Function):
|
22 |
-
@staticmethod
|
23 |
-
def forward(
|
24 |
-
ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
|
25 |
-
):
|
26 |
-
|
27 |
-
up_x, up_y = up
|
28 |
-
down_x, down_y = down
|
29 |
-
g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
|
30 |
-
|
31 |
-
grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
|
32 |
-
|
33 |
-
grad_input = upfirdn2d_op.upfirdn2d(
|
34 |
-
grad_output,
|
35 |
-
grad_kernel,
|
36 |
-
down_x,
|
37 |
-
down_y,
|
38 |
-
up_x,
|
39 |
-
up_y,
|
40 |
-
g_pad_x0,
|
41 |
-
g_pad_x1,
|
42 |
-
g_pad_y0,
|
43 |
-
g_pad_y1,
|
44 |
-
)
|
45 |
-
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
|
46 |
-
|
47 |
-
ctx.save_for_backward(kernel)
|
48 |
-
|
49 |
-
pad_x0, pad_x1, pad_y0, pad_y1 = pad
|
50 |
-
|
51 |
-
ctx.up_x = up_x
|
52 |
-
ctx.up_y = up_y
|
53 |
-
ctx.down_x = down_x
|
54 |
-
ctx.down_y = down_y
|
55 |
-
ctx.pad_x0 = pad_x0
|
56 |
-
ctx.pad_x1 = pad_x1
|
57 |
-
ctx.pad_y0 = pad_y0
|
58 |
-
ctx.pad_y1 = pad_y1
|
59 |
-
ctx.in_size = in_size
|
60 |
-
ctx.out_size = out_size
|
61 |
-
|
62 |
-
return grad_input
|
63 |
-
|
64 |
-
@staticmethod
|
65 |
-
def backward(ctx, gradgrad_input):
|
66 |
-
(kernel,) = ctx.saved_tensors
|
67 |
-
|
68 |
-
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
|
69 |
-
|
70 |
-
gradgrad_out = upfirdn2d_op.upfirdn2d(
|
71 |
-
gradgrad_input,
|
72 |
-
kernel,
|
73 |
-
ctx.up_x,
|
74 |
-
ctx.up_y,
|
75 |
-
ctx.down_x,
|
76 |
-
ctx.down_y,
|
77 |
-
ctx.pad_x0,
|
78 |
-
ctx.pad_x1,
|
79 |
-
ctx.pad_y0,
|
80 |
-
ctx.pad_y1,
|
81 |
-
)
|
82 |
-
# gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
|
83 |
-
gradgrad_out = gradgrad_out.view(
|
84 |
-
ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
|
85 |
-
)
|
86 |
-
|
87 |
-
return gradgrad_out, None, None, None, None, None, None, None, None
|
88 |
-
|
89 |
-
|
90 |
-
class UpFirDn2d(Function):
|
91 |
-
@staticmethod
|
92 |
-
def forward(ctx, input, kernel, up, down, pad):
|
93 |
-
up_x, up_y = up
|
94 |
-
down_x, down_y = down
|
95 |
-
pad_x0, pad_x1, pad_y0, pad_y1 = pad
|
96 |
-
|
97 |
-
kernel_h, kernel_w = kernel.shape
|
98 |
-
batch, channel, in_h, in_w = input.shape
|
99 |
-
ctx.in_size = input.shape
|
100 |
-
|
101 |
-
input = input.reshape(-1, in_h, in_w, 1)
|
102 |
-
|
103 |
-
ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
|
104 |
-
|
105 |
-
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
106 |
-
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
107 |
-
ctx.out_size = (out_h, out_w)
|
108 |
-
|
109 |
-
ctx.up = (up_x, up_y)
|
110 |
-
ctx.down = (down_x, down_y)
|
111 |
-
ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
|
112 |
-
|
113 |
-
g_pad_x0 = kernel_w - pad_x0 - 1
|
114 |
-
g_pad_y0 = kernel_h - pad_y0 - 1
|
115 |
-
g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
|
116 |
-
g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
|
117 |
-
|
118 |
-
ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
|
119 |
-
|
120 |
-
out = upfirdn2d_op.upfirdn2d(
|
121 |
-
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
122 |
-
)
|
123 |
-
# out = out.view(major, out_h, out_w, minor)
|
124 |
-
out = out.view(-1, channel, out_h, out_w)
|
125 |
-
|
126 |
-
return out
|
127 |
-
|
128 |
-
@staticmethod
|
129 |
-
def backward(ctx, grad_output):
|
130 |
-
kernel, grad_kernel = ctx.saved_tensors
|
131 |
-
|
132 |
-
grad_input = UpFirDn2dBackward.apply(
|
133 |
-
grad_output,
|
134 |
-
kernel,
|
135 |
-
grad_kernel,
|
136 |
-
ctx.up,
|
137 |
-
ctx.down,
|
138 |
-
ctx.pad,
|
139 |
-
ctx.g_pad,
|
140 |
-
ctx.in_size,
|
141 |
-
ctx.out_size,
|
142 |
-
)
|
143 |
-
|
144 |
-
return grad_input, None, None, None, None
|
145 |
-
|
146 |
-
|
147 |
-
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
148 |
-
if input.device.type == "cpu":
|
149 |
-
out = upfirdn2d_native(
|
150 |
-
input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
|
151 |
-
)
|
152 |
-
|
153 |
-
else:
|
154 |
-
out = UpFirDn2d.apply(
|
155 |
-
input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
|
156 |
-
)
|
157 |
-
|
158 |
-
return out
|
159 |
-
|
160 |
-
|
161 |
-
def upfirdn2d_native(
|
162 |
-
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
163 |
-
):
|
164 |
-
_, channel, in_h, in_w = input.shape
|
165 |
-
input = input.reshape(-1, in_h, in_w, 1)
|
166 |
-
|
167 |
-
_, in_h, in_w, minor = input.shape
|
168 |
-
kernel_h, kernel_w = kernel.shape
|
169 |
-
|
170 |
-
out = input.view(-1, in_h, 1, in_w, 1, minor)
|
171 |
-
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
172 |
-
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
173 |
-
|
174 |
-
out = F.pad(
|
175 |
-
out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
|
176 |
-
)
|
177 |
-
out = out[
|
178 |
-
:,
|
179 |
-
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
180 |
-
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
181 |
-
:,
|
182 |
-
]
|
183 |
-
|
184 |
-
out = out.permute(0, 3, 1, 2)
|
185 |
-
out = out.reshape(
|
186 |
-
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
|
187 |
-
)
|
188 |
-
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
189 |
-
out = F.conv2d(out, w)
|
190 |
-
out = out.reshape(
|
191 |
-
-1,
|
192 |
-
minor,
|
193 |
-
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
194 |
-
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
195 |
-
)
|
196 |
-
out = out.permute(0, 2, 3, 1)
|
197 |
-
out = out[:, ::down_y, ::down_x, :]
|
198 |
-
|
199 |
-
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
200 |
-
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
201 |
-
|
202 |
-
return out.view(-1, channel, out_h, out_w)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/op_edit/upfirdn2d_kernel.cu
DELETED
@@ -1,371 +0,0 @@
|
|
1 |
-
// Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
|
4 |
-
//
|
5 |
-
// This work is made available under the Nvidia Source Code License-NC.
|
6 |
-
// To view a copy of this license, visit
|
7 |
-
// https://nvlabs.github.io/stylegan2/license.html
|
8 |
-
|
9 |
-
#include <torch/types.h>
|
10 |
-
|
11 |
-
#include <ATen/ATen.h>
|
12 |
-
#include <ATen/AccumulateType.h>
|
13 |
-
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
14 |
-
#include <ATen/cuda/CUDAContext.h>
|
15 |
-
|
16 |
-
#include <cuda.h>
|
17 |
-
#include <cuda_runtime.h>
|
18 |
-
|
19 |
-
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
|
20 |
-
int c = a / b;
|
21 |
-
|
22 |
-
if (c * b > a) {
|
23 |
-
c--;
|
24 |
-
}
|
25 |
-
|
26 |
-
return c;
|
27 |
-
}
|
28 |
-
|
29 |
-
struct UpFirDn2DKernelParams {
|
30 |
-
int up_x;
|
31 |
-
int up_y;
|
32 |
-
int down_x;
|
33 |
-
int down_y;
|
34 |
-
int pad_x0;
|
35 |
-
int pad_x1;
|
36 |
-
int pad_y0;
|
37 |
-
int pad_y1;
|
38 |
-
|
39 |
-
int major_dim;
|
40 |
-
int in_h;
|
41 |
-
int in_w;
|
42 |
-
int minor_dim;
|
43 |
-
int kernel_h;
|
44 |
-
int kernel_w;
|
45 |
-
int out_h;
|
46 |
-
int out_w;
|
47 |
-
int loop_major;
|
48 |
-
int loop_x;
|
49 |
-
};
|
50 |
-
|
51 |
-
template <typename scalar_t>
|
52 |
-
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
|
53 |
-
const scalar_t *kernel,
|
54 |
-
const UpFirDn2DKernelParams p) {
|
55 |
-
int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
56 |
-
int out_y = minor_idx / p.minor_dim;
|
57 |
-
minor_idx -= out_y * p.minor_dim;
|
58 |
-
int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
|
59 |
-
int major_idx_base = blockIdx.z * p.loop_major;
|
60 |
-
|
61 |
-
if (out_x_base >= p.out_w || out_y >= p.out_h ||
|
62 |
-
major_idx_base >= p.major_dim) {
|
63 |
-
return;
|
64 |
-
}
|
65 |
-
|
66 |
-
int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
|
67 |
-
int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
|
68 |
-
int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
|
69 |
-
int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
|
70 |
-
|
71 |
-
for (int loop_major = 0, major_idx = major_idx_base;
|
72 |
-
loop_major < p.loop_major && major_idx < p.major_dim;
|
73 |
-
loop_major++, major_idx++) {
|
74 |
-
for (int loop_x = 0, out_x = out_x_base;
|
75 |
-
loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
|
76 |
-
int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
|
77 |
-
int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
|
78 |
-
int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
|
79 |
-
int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
|
80 |
-
|
81 |
-
const scalar_t *x_p =
|
82 |
-
&input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
|
83 |
-
minor_idx];
|
84 |
-
const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
|
85 |
-
int x_px = p.minor_dim;
|
86 |
-
int k_px = -p.up_x;
|
87 |
-
int x_py = p.in_w * p.minor_dim;
|
88 |
-
int k_py = -p.up_y * p.kernel_w;
|
89 |
-
|
90 |
-
scalar_t v = 0.0f;
|
91 |
-
|
92 |
-
for (int y = 0; y < h; y++) {
|
93 |
-
for (int x = 0; x < w; x++) {
|
94 |
-
v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
|
95 |
-
x_p += x_px;
|
96 |
-
k_p += k_px;
|
97 |
-
}
|
98 |
-
|
99 |
-
x_p += x_py - w * x_px;
|
100 |
-
k_p += k_py - w * k_px;
|
101 |
-
}
|
102 |
-
|
103 |
-
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
|
104 |
-
minor_idx] = v;
|
105 |
-
}
|
106 |
-
}
|
107 |
-
}
|
108 |
-
|
109 |
-
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
|
110 |
-
int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
|
111 |
-
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
|
112 |
-
const scalar_t *kernel,
|
113 |
-
const UpFirDn2DKernelParams p) {
|
114 |
-
const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
|
115 |
-
const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
|
116 |
-
|
117 |
-
__shared__ volatile float sk[kernel_h][kernel_w];
|
118 |
-
__shared__ volatile float sx[tile_in_h][tile_in_w];
|
119 |
-
|
120 |
-
int minor_idx = blockIdx.x;
|
121 |
-
int tile_out_y = minor_idx / p.minor_dim;
|
122 |
-
minor_idx -= tile_out_y * p.minor_dim;
|
123 |
-
tile_out_y *= tile_out_h;
|
124 |
-
int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
|
125 |
-
int major_idx_base = blockIdx.z * p.loop_major;
|
126 |
-
|
127 |
-
if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
|
128 |
-
major_idx_base >= p.major_dim) {
|
129 |
-
return;
|
130 |
-
}
|
131 |
-
|
132 |
-
for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
|
133 |
-
tap_idx += blockDim.x) {
|
134 |
-
int ky = tap_idx / kernel_w;
|
135 |
-
int kx = tap_idx - ky * kernel_w;
|
136 |
-
scalar_t v = 0.0;
|
137 |
-
|
138 |
-
if (kx < p.kernel_w & ky < p.kernel_h) {
|
139 |
-
v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
|
140 |
-
}
|
141 |
-
|
142 |
-
sk[ky][kx] = v;
|
143 |
-
}
|
144 |
-
|
145 |
-
for (int loop_major = 0, major_idx = major_idx_base;
|
146 |
-
loop_major < p.loop_major & major_idx < p.major_dim;
|
147 |
-
loop_major++, major_idx++) {
|
148 |
-
for (int loop_x = 0, tile_out_x = tile_out_x_base;
|
149 |
-
loop_x < p.loop_x & tile_out_x < p.out_w;
|
150 |
-
loop_x++, tile_out_x += tile_out_w) {
|
151 |
-
int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
|
152 |
-
int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
|
153 |
-
int tile_in_x = floor_div(tile_mid_x, up_x);
|
154 |
-
int tile_in_y = floor_div(tile_mid_y, up_y);
|
155 |
-
|
156 |
-
__syncthreads();
|
157 |
-
|
158 |
-
for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
|
159 |
-
in_idx += blockDim.x) {
|
160 |
-
int rel_in_y = in_idx / tile_in_w;
|
161 |
-
int rel_in_x = in_idx - rel_in_y * tile_in_w;
|
162 |
-
int in_x = rel_in_x + tile_in_x;
|
163 |
-
int in_y = rel_in_y + tile_in_y;
|
164 |
-
|
165 |
-
scalar_t v = 0.0;
|
166 |
-
|
167 |
-
if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
|
168 |
-
v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
|
169 |
-
p.minor_dim +
|
170 |
-
minor_idx];
|
171 |
-
}
|
172 |
-
|
173 |
-
sx[rel_in_y][rel_in_x] = v;
|
174 |
-
}
|
175 |
-
|
176 |
-
__syncthreads();
|
177 |
-
for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
|
178 |
-
out_idx += blockDim.x) {
|
179 |
-
int rel_out_y = out_idx / tile_out_w;
|
180 |
-
int rel_out_x = out_idx - rel_out_y * tile_out_w;
|
181 |
-
int out_x = rel_out_x + tile_out_x;
|
182 |
-
int out_y = rel_out_y + tile_out_y;
|
183 |
-
|
184 |
-
int mid_x = tile_mid_x + rel_out_x * down_x;
|
185 |
-
int mid_y = tile_mid_y + rel_out_y * down_y;
|
186 |
-
int in_x = floor_div(mid_x, up_x);
|
187 |
-
int in_y = floor_div(mid_y, up_y);
|
188 |
-
int rel_in_x = in_x - tile_in_x;
|
189 |
-
int rel_in_y = in_y - tile_in_y;
|
190 |
-
int kernel_x = (in_x + 1) * up_x - mid_x - 1;
|
191 |
-
int kernel_y = (in_y + 1) * up_y - mid_y - 1;
|
192 |
-
|
193 |
-
scalar_t v = 0.0;
|
194 |
-
|
195 |
-
#pragma unroll
|
196 |
-
for (int y = 0; y < kernel_h / up_y; y++)
|
197 |
-
#pragma unroll
|
198 |
-
for (int x = 0; x < kernel_w / up_x; x++)
|
199 |
-
v += sx[rel_in_y + y][rel_in_x + x] *
|
200 |
-
sk[kernel_y + y * up_y][kernel_x + x * up_x];
|
201 |
-
|
202 |
-
if (out_x < p.out_w & out_y < p.out_h) {
|
203 |
-
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
|
204 |
-
minor_idx] = v;
|
205 |
-
}
|
206 |
-
}
|
207 |
-
}
|
208 |
-
}
|
209 |
-
}
|
210 |
-
|
211 |
-
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
|
212 |
-
const torch::Tensor &kernel, int up_x, int up_y,
|
213 |
-
int down_x, int down_y, int pad_x0, int pad_x1,
|
214 |
-
int pad_y0, int pad_y1) {
|
215 |
-
int curDevice = -1;
|
216 |
-
cudaGetDevice(&curDevice);
|
217 |
-
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
|
218 |
-
|
219 |
-
UpFirDn2DKernelParams p;
|
220 |
-
|
221 |
-
auto x = input.contiguous();
|
222 |
-
auto k = kernel.contiguous();
|
223 |
-
|
224 |
-
p.major_dim = x.size(0);
|
225 |
-
p.in_h = x.size(1);
|
226 |
-
p.in_w = x.size(2);
|
227 |
-
p.minor_dim = x.size(3);
|
228 |
-
p.kernel_h = k.size(0);
|
229 |
-
p.kernel_w = k.size(1);
|
230 |
-
p.up_x = up_x;
|
231 |
-
p.up_y = up_y;
|
232 |
-
p.down_x = down_x;
|
233 |
-
p.down_y = down_y;
|
234 |
-
p.pad_x0 = pad_x0;
|
235 |
-
p.pad_x1 = pad_x1;
|
236 |
-
p.pad_y0 = pad_y0;
|
237 |
-
p.pad_y1 = pad_y1;
|
238 |
-
|
239 |
-
p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
|
240 |
-
p.down_y;
|
241 |
-
p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
|
242 |
-
p.down_x;
|
243 |
-
|
244 |
-
auto out =
|
245 |
-
at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
|
246 |
-
|
247 |
-
int mode = -1;
|
248 |
-
|
249 |
-
int tile_out_h = -1;
|
250 |
-
int tile_out_w = -1;
|
251 |
-
|
252 |
-
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
|
253 |
-
p.kernel_h <= 4 && p.kernel_w <= 4) {
|
254 |
-
mode = 1;
|
255 |
-
tile_out_h = 16;
|
256 |
-
tile_out_w = 64;
|
257 |
-
}
|
258 |
-
|
259 |
-
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
|
260 |
-
p.kernel_h <= 3 && p.kernel_w <= 3) {
|
261 |
-
mode = 2;
|
262 |
-
tile_out_h = 16;
|
263 |
-
tile_out_w = 64;
|
264 |
-
}
|
265 |
-
|
266 |
-
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
|
267 |
-
p.kernel_h <= 4 && p.kernel_w <= 4) {
|
268 |
-
mode = 3;
|
269 |
-
tile_out_h = 16;
|
270 |
-
tile_out_w = 64;
|
271 |
-
}
|
272 |
-
|
273 |
-
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
|
274 |
-
p.kernel_h <= 2 && p.kernel_w <= 2) {
|
275 |
-
mode = 4;
|
276 |
-
tile_out_h = 16;
|
277 |
-
tile_out_w = 64;
|
278 |
-
}
|
279 |
-
|
280 |
-
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
|
281 |
-
p.kernel_h <= 4 && p.kernel_w <= 4) {
|
282 |
-
mode = 5;
|
283 |
-
tile_out_h = 8;
|
284 |
-
tile_out_w = 32;
|
285 |
-
}
|
286 |
-
|
287 |
-
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
|
288 |
-
p.kernel_h <= 2 && p.kernel_w <= 2) {
|
289 |
-
mode = 6;
|
290 |
-
tile_out_h = 8;
|
291 |
-
tile_out_w = 32;
|
292 |
-
}
|
293 |
-
|
294 |
-
dim3 block_size;
|
295 |
-
dim3 grid_size;
|
296 |
-
|
297 |
-
if (tile_out_h > 0 && tile_out_w > 0) {
|
298 |
-
p.loop_major = (p.major_dim - 1) / 16384 + 1;
|
299 |
-
p.loop_x = 1;
|
300 |
-
block_size = dim3(32 * 8, 1, 1);
|
301 |
-
grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
|
302 |
-
(p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
|
303 |
-
(p.major_dim - 1) / p.loop_major + 1);
|
304 |
-
} else {
|
305 |
-
p.loop_major = (p.major_dim - 1) / 16384 + 1;
|
306 |
-
p.loop_x = 4;
|
307 |
-
block_size = dim3(4, 32, 1);
|
308 |
-
grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
|
309 |
-
(p.out_w - 1) / (p.loop_x * block_size.y) + 1,
|
310 |
-
(p.major_dim - 1) / p.loop_major + 1);
|
311 |
-
}
|
312 |
-
|
313 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
|
314 |
-
switch (mode) {
|
315 |
-
case 1:
|
316 |
-
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
|
317 |
-
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
318 |
-
x.data_ptr<scalar_t>(),
|
319 |
-
k.data_ptr<scalar_t>(), p);
|
320 |
-
|
321 |
-
break;
|
322 |
-
|
323 |
-
case 2:
|
324 |
-
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
|
325 |
-
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
326 |
-
x.data_ptr<scalar_t>(),
|
327 |
-
k.data_ptr<scalar_t>(), p);
|
328 |
-
|
329 |
-
break;
|
330 |
-
|
331 |
-
case 3:
|
332 |
-
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
|
333 |
-
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
334 |
-
x.data_ptr<scalar_t>(),
|
335 |
-
k.data_ptr<scalar_t>(), p);
|
336 |
-
|
337 |
-
break;
|
338 |
-
|
339 |
-
case 4:
|
340 |
-
upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
|
341 |
-
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
342 |
-
x.data_ptr<scalar_t>(),
|
343 |
-
k.data_ptr<scalar_t>(), p);
|
344 |
-
|
345 |
-
break;
|
346 |
-
|
347 |
-
case 5:
|
348 |
-
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
|
349 |
-
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
350 |
-
x.data_ptr<scalar_t>(),
|
351 |
-
k.data_ptr<scalar_t>(), p);
|
352 |
-
|
353 |
-
break;
|
354 |
-
|
355 |
-
case 6:
|
356 |
-
upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
|
357 |
-
<<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
|
358 |
-
x.data_ptr<scalar_t>(),
|
359 |
-
k.data_ptr<scalar_t>(), p);
|
360 |
-
|
361 |
-
break;
|
362 |
-
|
363 |
-
default:
|
364 |
-
upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
|
365 |
-
out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
|
366 |
-
k.data_ptr<scalar_t>(), p);
|
367 |
-
}
|
368 |
-
});
|
369 |
-
|
370 |
-
return out;
|
371 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/ops/__init__.py
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
#empty
|
|
|
|
|
|
|
|
torch_utils/ops/bias_act.cpp
DELETED
@@ -1,101 +0,0 @@
|
|
1 |
-
// Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
//
|
5 |
-
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
-
// and proprietary rights in and to this software, related documentation
|
7 |
-
// and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
-
// distribution of this software and related documentation without an express
|
9 |
-
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
-
|
11 |
-
#include <torch/extension.h>
|
12 |
-
#include <ATen/cuda/CUDAContext.h>
|
13 |
-
#include <c10/cuda/CUDAGuard.h>
|
14 |
-
#include "bias_act.h"
|
15 |
-
|
16 |
-
//------------------------------------------------------------------------
|
17 |
-
|
18 |
-
static bool has_same_layout(torch::Tensor x, torch::Tensor y)
|
19 |
-
{
|
20 |
-
if (x.dim() != y.dim())
|
21 |
-
return false;
|
22 |
-
for (int64_t i = 0; i < x.dim(); i++)
|
23 |
-
{
|
24 |
-
if (x.size(i) != y.size(i))
|
25 |
-
return false;
|
26 |
-
if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
|
27 |
-
return false;
|
28 |
-
}
|
29 |
-
return true;
|
30 |
-
}
|
31 |
-
|
32 |
-
//------------------------------------------------------------------------
|
33 |
-
|
34 |
-
static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
|
35 |
-
{
|
36 |
-
// Validate arguments.
|
37 |
-
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
38 |
-
TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
|
39 |
-
TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
|
40 |
-
TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
|
41 |
-
TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
|
42 |
-
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
|
43 |
-
TORCH_CHECK(b.dim() == 1, "b must have rank 1");
|
44 |
-
TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
|
45 |
-
TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
|
46 |
-
TORCH_CHECK(grad >= 0, "grad must be non-negative");
|
47 |
-
|
48 |
-
// Validate layout.
|
49 |
-
TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
|
50 |
-
TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
|
51 |
-
TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
|
52 |
-
TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
|
53 |
-
TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
|
54 |
-
|
55 |
-
// Create output tensor.
|
56 |
-
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
57 |
-
torch::Tensor y = torch::empty_like(x);
|
58 |
-
TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
|
59 |
-
|
60 |
-
// Initialize CUDA kernel parameters.
|
61 |
-
bias_act_kernel_params p;
|
62 |
-
p.x = x.data_ptr();
|
63 |
-
p.b = (b.numel()) ? b.data_ptr() : NULL;
|
64 |
-
p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
|
65 |
-
p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
|
66 |
-
p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
|
67 |
-
p.y = y.data_ptr();
|
68 |
-
p.grad = grad;
|
69 |
-
p.act = act;
|
70 |
-
p.alpha = alpha;
|
71 |
-
p.gain = gain;
|
72 |
-
p.clamp = clamp;
|
73 |
-
p.sizeX = (int)x.numel();
|
74 |
-
p.sizeB = (int)b.numel();
|
75 |
-
p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
|
76 |
-
|
77 |
-
// Choose CUDA kernel.
|
78 |
-
void* kernel;
|
79 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
|
80 |
-
{
|
81 |
-
kernel = choose_bias_act_kernel<scalar_t>(p);
|
82 |
-
});
|
83 |
-
TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
|
84 |
-
|
85 |
-
// Launch CUDA kernel.
|
86 |
-
p.loopX = 4;
|
87 |
-
int blockSize = 4 * 32;
|
88 |
-
int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
|
89 |
-
void* args[] = {&p};
|
90 |
-
AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
|
91 |
-
return y;
|
92 |
-
}
|
93 |
-
|
94 |
-
//------------------------------------------------------------------------
|
95 |
-
|
96 |
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
97 |
-
{
|
98 |
-
m.def("bias_act", &bias_act);
|
99 |
-
}
|
100 |
-
|
101 |
-
//------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/ops/bias_act.cu
DELETED
@@ -1,175 +0,0 @@
|
|
1 |
-
// Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
//
|
5 |
-
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
-
// and proprietary rights in and to this software, related documentation
|
7 |
-
// and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
-
// distribution of this software and related documentation without an express
|
9 |
-
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
-
|
11 |
-
#include <c10/util/Half.h>
|
12 |
-
#include "bias_act.h"
|
13 |
-
|
14 |
-
//------------------------------------------------------------------------
|
15 |
-
// Helpers.
|
16 |
-
|
17 |
-
template <class T> struct InternalType;
|
18 |
-
template <> struct InternalType<double> { typedef double scalar_t; };
|
19 |
-
template <> struct InternalType<float> { typedef float scalar_t; };
|
20 |
-
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
21 |
-
|
22 |
-
//------------------------------------------------------------------------
|
23 |
-
// CUDA kernel.
|
24 |
-
|
25 |
-
template <class T, int A>
|
26 |
-
__global__ void bias_act_kernel(bias_act_kernel_params p)
|
27 |
-
{
|
28 |
-
typedef typename InternalType<T>::scalar_t scalar_t;
|
29 |
-
int G = p.grad;
|
30 |
-
scalar_t alpha = (scalar_t)p.alpha;
|
31 |
-
scalar_t gain = (scalar_t)p.gain;
|
32 |
-
scalar_t clamp = (scalar_t)p.clamp;
|
33 |
-
scalar_t one = (scalar_t)1;
|
34 |
-
scalar_t two = (scalar_t)2;
|
35 |
-
scalar_t expRange = (scalar_t)80;
|
36 |
-
scalar_t halfExpRange = (scalar_t)40;
|
37 |
-
scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
|
38 |
-
scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
|
39 |
-
|
40 |
-
// Loop over elements.
|
41 |
-
int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
|
42 |
-
for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
|
43 |
-
{
|
44 |
-
// Load.
|
45 |
-
scalar_t x = (scalar_t)((const T*)p.x)[xi];
|
46 |
-
scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
|
47 |
-
scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
|
48 |
-
scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
|
49 |
-
scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
|
50 |
-
scalar_t yy = (gain != 0) ? yref / gain : 0;
|
51 |
-
scalar_t y = 0;
|
52 |
-
|
53 |
-
// Apply bias.
|
54 |
-
((G == 0) ? x : xref) += b;
|
55 |
-
|
56 |
-
// linear
|
57 |
-
if (A == 1)
|
58 |
-
{
|
59 |
-
if (G == 0) y = x;
|
60 |
-
if (G == 1) y = x;
|
61 |
-
}
|
62 |
-
|
63 |
-
// relu
|
64 |
-
if (A == 2)
|
65 |
-
{
|
66 |
-
if (G == 0) y = (x > 0) ? x : 0;
|
67 |
-
if (G == 1) y = (yy > 0) ? x : 0;
|
68 |
-
}
|
69 |
-
|
70 |
-
// lrelu
|
71 |
-
if (A == 3)
|
72 |
-
{
|
73 |
-
if (G == 0) y = (x > 0) ? x : x * alpha;
|
74 |
-
if (G == 1) y = (yy > 0) ? x : x * alpha;
|
75 |
-
}
|
76 |
-
|
77 |
-
// tanh
|
78 |
-
if (A == 4)
|
79 |
-
{
|
80 |
-
if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
|
81 |
-
if (G == 1) y = x * (one - yy * yy);
|
82 |
-
if (G == 2) y = x * (one - yy * yy) * (-two * yy);
|
83 |
-
}
|
84 |
-
|
85 |
-
// sigmoid
|
86 |
-
if (A == 5)
|
87 |
-
{
|
88 |
-
if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
|
89 |
-
if (G == 1) y = x * yy * (one - yy);
|
90 |
-
if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
|
91 |
-
}
|
92 |
-
|
93 |
-
// elu
|
94 |
-
if (A == 6)
|
95 |
-
{
|
96 |
-
if (G == 0) y = (x >= 0) ? x : exp(x) - one;
|
97 |
-
if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
|
98 |
-
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
|
99 |
-
}
|
100 |
-
|
101 |
-
// selu
|
102 |
-
if (A == 7)
|
103 |
-
{
|
104 |
-
if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
|
105 |
-
if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
|
106 |
-
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
|
107 |
-
}
|
108 |
-
|
109 |
-
// softplus
|
110 |
-
if (A == 8)
|
111 |
-
{
|
112 |
-
if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
|
113 |
-
if (G == 1) y = x * (one - exp(-yy));
|
114 |
-
if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
|
115 |
-
}
|
116 |
-
|
117 |
-
// swish
|
118 |
-
if (A == 9)
|
119 |
-
{
|
120 |
-
if (G == 0)
|
121 |
-
y = (x < -expRange) ? 0 : x / (exp(-x) + one);
|
122 |
-
else
|
123 |
-
{
|
124 |
-
scalar_t c = exp(xref);
|
125 |
-
scalar_t d = c + one;
|
126 |
-
if (G == 1)
|
127 |
-
y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
|
128 |
-
else
|
129 |
-
y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
|
130 |
-
yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
|
131 |
-
}
|
132 |
-
}
|
133 |
-
|
134 |
-
// Apply gain.
|
135 |
-
y *= gain * dy;
|
136 |
-
|
137 |
-
// Clamp.
|
138 |
-
if (clamp >= 0)
|
139 |
-
{
|
140 |
-
if (G == 0)
|
141 |
-
y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
|
142 |
-
else
|
143 |
-
y = (yref > -clamp & yref < clamp) ? y : 0;
|
144 |
-
}
|
145 |
-
|
146 |
-
// Store.
|
147 |
-
((T*)p.y)[xi] = (T)y;
|
148 |
-
}
|
149 |
-
}
|
150 |
-
|
151 |
-
//------------------------------------------------------------------------
|
152 |
-
// CUDA kernel selection.
|
153 |
-
|
154 |
-
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
|
155 |
-
{
|
156 |
-
if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
|
157 |
-
if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
|
158 |
-
if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
|
159 |
-
if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
|
160 |
-
if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
|
161 |
-
if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
|
162 |
-
if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
|
163 |
-
if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
|
164 |
-
if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
|
165 |
-
return NULL;
|
166 |
-
}
|
167 |
-
|
168 |
-
//------------------------------------------------------------------------
|
169 |
-
// Template specializations.
|
170 |
-
|
171 |
-
template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
|
172 |
-
template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
|
173 |
-
template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
|
174 |
-
|
175 |
-
//------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/ops/bias_act.h
DELETED
@@ -1,40 +0,0 @@
|
|
1 |
-
// Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
//
|
5 |
-
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
-
// and proprietary rights in and to this software, related documentation
|
7 |
-
// and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
-
// distribution of this software and related documentation without an express
|
9 |
-
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
-
|
11 |
-
//------------------------------------------------------------------------
|
12 |
-
// CUDA kernel parameters.
|
13 |
-
|
14 |
-
struct bias_act_kernel_params
|
15 |
-
{
|
16 |
-
const void* x; // [sizeX]
|
17 |
-
const void* b; // [sizeB] or NULL
|
18 |
-
const void* xref; // [sizeX] or NULL
|
19 |
-
const void* yref; // [sizeX] or NULL
|
20 |
-
const void* dy; // [sizeX] or NULL
|
21 |
-
void* y; // [sizeX]
|
22 |
-
|
23 |
-
int grad;
|
24 |
-
int act;
|
25 |
-
float alpha;
|
26 |
-
float gain;
|
27 |
-
float clamp;
|
28 |
-
|
29 |
-
int sizeX;
|
30 |
-
int sizeB;
|
31 |
-
int stepB;
|
32 |
-
int loopX;
|
33 |
-
};
|
34 |
-
|
35 |
-
//------------------------------------------------------------------------
|
36 |
-
// CUDA kernel selection.
|
37 |
-
|
38 |
-
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
|
39 |
-
|
40 |
-
//------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/ops/bias_act.py
DELETED
@@ -1,214 +0,0 @@
|
|
1 |
-
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
#
|
5 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
-
# and proprietary rights in and to this software, related documentation
|
7 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
-
# distribution of this software and related documentation without an express
|
9 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
-
|
11 |
-
"""Custom PyTorch ops for efficient bias and activation."""
|
12 |
-
|
13 |
-
import os
|
14 |
-
import warnings
|
15 |
-
import numpy as np
|
16 |
-
import torch
|
17 |
-
import dnnlib
|
18 |
-
import traceback
|
19 |
-
|
20 |
-
from .. import custom_ops
|
21 |
-
from .. import misc
|
22 |
-
|
23 |
-
#----------------------------------------------------------------------------
|
24 |
-
|
25 |
-
activation_funcs = {
|
26 |
-
'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
|
27 |
-
'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
|
28 |
-
'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
|
29 |
-
'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
|
30 |
-
'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
|
31 |
-
'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
|
32 |
-
'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
|
33 |
-
'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
|
34 |
-
'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
|
35 |
-
}
|
36 |
-
|
37 |
-
#----------------------------------------------------------------------------
|
38 |
-
|
39 |
-
_inited = False
|
40 |
-
_plugin = None
|
41 |
-
_null_tensor = torch.empty([0])
|
42 |
-
|
43 |
-
def _init():
|
44 |
-
global _inited, _plugin
|
45 |
-
if not _inited:
|
46 |
-
_inited = True
|
47 |
-
sources = ['bias_act.cpp', 'bias_act.cu']
|
48 |
-
sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
|
49 |
-
try:
|
50 |
-
_plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
|
51 |
-
except:
|
52 |
-
warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
|
53 |
-
return _plugin is not None
|
54 |
-
|
55 |
-
#----------------------------------------------------------------------------
|
56 |
-
|
57 |
-
def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
|
58 |
-
r"""Fused bias and activation function.
|
59 |
-
|
60 |
-
Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
|
61 |
-
and scales the result by `gain`. Each of the steps is optional. In most cases,
|
62 |
-
the fused op is considerably more efficient than performing the same calculation
|
63 |
-
using standard PyTorch ops. It supports first and second order gradients,
|
64 |
-
but not third order gradients.
|
65 |
-
|
66 |
-
Args:
|
67 |
-
x: Input activation tensor. Can be of any shape.
|
68 |
-
b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
|
69 |
-
as `x`. The shape must be known, and it must match the dimension of `x`
|
70 |
-
corresponding to `dim`.
|
71 |
-
dim: The dimension in `x` corresponding to the elements of `b`.
|
72 |
-
The value of `dim` is ignored if `b` is not specified.
|
73 |
-
act: Name of the activation function to evaluate, or `"linear"` to disable.
|
74 |
-
Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
|
75 |
-
See `activation_funcs` for a full list. `None` is not allowed.
|
76 |
-
alpha: Shape parameter for the activation function, or `None` to use the default.
|
77 |
-
gain: Scaling factor for the output tensor, or `None` to use default.
|
78 |
-
See `activation_funcs` for the default scaling of each activation function.
|
79 |
-
If unsure, consider specifying 1.
|
80 |
-
clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
|
81 |
-
the clamping (default).
|
82 |
-
impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
|
83 |
-
|
84 |
-
Returns:
|
85 |
-
Tensor of the same shape and datatype as `x`.
|
86 |
-
"""
|
87 |
-
assert isinstance(x, torch.Tensor)
|
88 |
-
assert impl in ['ref', 'cuda']
|
89 |
-
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
90 |
-
return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
|
91 |
-
return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
|
92 |
-
|
93 |
-
#----------------------------------------------------------------------------
|
94 |
-
|
95 |
-
@misc.profiled_function
|
96 |
-
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
97 |
-
"""Slow reference implementation of `bias_act()` using standard TensorFlow ops.
|
98 |
-
"""
|
99 |
-
assert isinstance(x, torch.Tensor)
|
100 |
-
assert clamp is None or clamp >= 0
|
101 |
-
spec = activation_funcs[act]
|
102 |
-
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
103 |
-
gain = float(gain if gain is not None else spec.def_gain)
|
104 |
-
clamp = float(clamp if clamp is not None else -1)
|
105 |
-
|
106 |
-
# Add bias.
|
107 |
-
if b is not None:
|
108 |
-
assert isinstance(b, torch.Tensor) and b.ndim == 1
|
109 |
-
assert 0 <= dim < x.ndim
|
110 |
-
assert b.shape[0] == x.shape[dim]
|
111 |
-
x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
|
112 |
-
|
113 |
-
# Evaluate activation function.
|
114 |
-
alpha = float(alpha)
|
115 |
-
x = spec.func(x, alpha=alpha)
|
116 |
-
|
117 |
-
# Scale by gain.
|
118 |
-
gain = float(gain)
|
119 |
-
if gain != 1:
|
120 |
-
x = x * gain
|
121 |
-
|
122 |
-
# Clamp.
|
123 |
-
if clamp >= 0:
|
124 |
-
x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
|
125 |
-
return x
|
126 |
-
|
127 |
-
#----------------------------------------------------------------------------
|
128 |
-
|
129 |
-
_bias_act_cuda_cache = dict()
|
130 |
-
|
131 |
-
def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
132 |
-
"""Fast CUDA implementation of `bias_act()` using custom ops.
|
133 |
-
"""
|
134 |
-
# Parse arguments.
|
135 |
-
assert clamp is None or clamp >= 0
|
136 |
-
spec = activation_funcs[act]
|
137 |
-
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
138 |
-
gain = float(gain if gain is not None else spec.def_gain)
|
139 |
-
clamp = float(clamp if clamp is not None else -1)
|
140 |
-
|
141 |
-
# Lookup from cache.
|
142 |
-
key = (dim, act, alpha, gain, clamp)
|
143 |
-
if key in _bias_act_cuda_cache:
|
144 |
-
return _bias_act_cuda_cache[key]
|
145 |
-
|
146 |
-
# Forward op.
|
147 |
-
class BiasActCuda(torch.autograd.Function):
|
148 |
-
@staticmethod
|
149 |
-
def forward(ctx, x, b): # pylint: disable=arguments-differ
|
150 |
-
ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
|
151 |
-
x = x.contiguous(memory_format=ctx.memory_format)
|
152 |
-
b = b.contiguous() if b is not None else _null_tensor
|
153 |
-
y = x
|
154 |
-
if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
|
155 |
-
y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
|
156 |
-
ctx.save_for_backward(
|
157 |
-
x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
158 |
-
b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
159 |
-
y if 'y' in spec.ref else _null_tensor)
|
160 |
-
return y
|
161 |
-
|
162 |
-
@staticmethod
|
163 |
-
def backward(ctx, dy): # pylint: disable=arguments-differ
|
164 |
-
dy = dy.contiguous(memory_format=ctx.memory_format)
|
165 |
-
x, b, y = ctx.saved_tensors
|
166 |
-
dx = None
|
167 |
-
db = None
|
168 |
-
|
169 |
-
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
|
170 |
-
dx = dy
|
171 |
-
if act != 'linear' or gain != 1 or clamp >= 0:
|
172 |
-
dx = BiasActCudaGrad.apply(dy, x, b, y)
|
173 |
-
|
174 |
-
if ctx.needs_input_grad[1]:
|
175 |
-
db = dx.sum([i for i in range(dx.ndim) if i != dim])
|
176 |
-
|
177 |
-
return dx, db
|
178 |
-
|
179 |
-
# Backward op.
|
180 |
-
class BiasActCudaGrad(torch.autograd.Function):
|
181 |
-
@staticmethod
|
182 |
-
def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
|
183 |
-
ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
|
184 |
-
dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
|
185 |
-
ctx.save_for_backward(
|
186 |
-
dy if spec.has_2nd_grad else _null_tensor,
|
187 |
-
x, b, y)
|
188 |
-
return dx
|
189 |
-
|
190 |
-
@staticmethod
|
191 |
-
def backward(ctx, d_dx): # pylint: disable=arguments-differ
|
192 |
-
d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
|
193 |
-
dy, x, b, y = ctx.saved_tensors
|
194 |
-
d_dy = None
|
195 |
-
d_x = None
|
196 |
-
d_b = None
|
197 |
-
d_y = None
|
198 |
-
|
199 |
-
if ctx.needs_input_grad[0]:
|
200 |
-
d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
|
201 |
-
|
202 |
-
if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
|
203 |
-
d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
|
204 |
-
|
205 |
-
if spec.has_2nd_grad and ctx.needs_input_grad[2]:
|
206 |
-
d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
|
207 |
-
|
208 |
-
return d_dy, d_x, d_b, d_y
|
209 |
-
|
210 |
-
# Add to cache.
|
211 |
-
_bias_act_cuda_cache[key] = BiasActCuda
|
212 |
-
return BiasActCuda
|
213 |
-
|
214 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/ops/conv2d_gradfix.py
DELETED
@@ -1,172 +0,0 @@
|
|
1 |
-
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
#
|
5 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
-
# and proprietary rights in and to this software, related documentation
|
7 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
-
# distribution of this software and related documentation without an express
|
9 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
-
|
11 |
-
"""Custom replacement for `torch.nn.functional.conv2d` that supports
|
12 |
-
arbitrarily high order gradients with zero performance penalty."""
|
13 |
-
|
14 |
-
import warnings
|
15 |
-
import contextlib
|
16 |
-
import torch
|
17 |
-
|
18 |
-
# pylint: disable=redefined-builtin
|
19 |
-
# pylint: disable=arguments-differ
|
20 |
-
# pylint: disable=protected-access
|
21 |
-
|
22 |
-
#----------------------------------------------------------------------------
|
23 |
-
|
24 |
-
enabled = False # Enable the custom op by setting this to true.
|
25 |
-
weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
|
26 |
-
|
27 |
-
@contextlib.contextmanager
|
28 |
-
def no_weight_gradients():
|
29 |
-
global weight_gradients_disabled
|
30 |
-
old = weight_gradients_disabled
|
31 |
-
weight_gradients_disabled = True
|
32 |
-
yield
|
33 |
-
weight_gradients_disabled = old
|
34 |
-
|
35 |
-
#----------------------------------------------------------------------------
|
36 |
-
|
37 |
-
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
38 |
-
if _should_use_custom_op(input):
|
39 |
-
return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
|
40 |
-
return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
41 |
-
|
42 |
-
def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
|
43 |
-
if _should_use_custom_op(input):
|
44 |
-
return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
|
45 |
-
return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
|
46 |
-
|
47 |
-
#----------------------------------------------------------------------------
|
48 |
-
|
49 |
-
def _should_use_custom_op(input):
|
50 |
-
assert isinstance(input, torch.Tensor)
|
51 |
-
if (not enabled) or (not torch.backends.cudnn.enabled):
|
52 |
-
return False
|
53 |
-
if input.device.type != 'cuda':
|
54 |
-
return False
|
55 |
-
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
|
56 |
-
return True
|
57 |
-
warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
|
58 |
-
return False
|
59 |
-
|
60 |
-
def _tuple_of_ints(xs, ndim):
|
61 |
-
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
|
62 |
-
assert len(xs) == ndim
|
63 |
-
assert all(isinstance(x, int) for x in xs)
|
64 |
-
return xs
|
65 |
-
|
66 |
-
#----------------------------------------------------------------------------
|
67 |
-
|
68 |
-
_conv2d_gradfix_cache = dict()
|
69 |
-
|
70 |
-
def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
|
71 |
-
# Parse arguments.
|
72 |
-
ndim = 2
|
73 |
-
weight_shape = tuple(weight_shape)
|
74 |
-
stride = _tuple_of_ints(stride, ndim)
|
75 |
-
padding = _tuple_of_ints(padding, ndim)
|
76 |
-
output_padding = _tuple_of_ints(output_padding, ndim)
|
77 |
-
dilation = _tuple_of_ints(dilation, ndim)
|
78 |
-
|
79 |
-
# Lookup from cache.
|
80 |
-
key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
|
81 |
-
if key in _conv2d_gradfix_cache:
|
82 |
-
return _conv2d_gradfix_cache[key]
|
83 |
-
|
84 |
-
# Validate arguments.
|
85 |
-
assert groups >= 1
|
86 |
-
assert len(weight_shape) == ndim + 2
|
87 |
-
assert all(stride[i] >= 1 for i in range(ndim))
|
88 |
-
assert all(padding[i] >= 0 for i in range(ndim))
|
89 |
-
assert all(dilation[i] >= 0 for i in range(ndim))
|
90 |
-
if not transpose:
|
91 |
-
assert all(output_padding[i] == 0 for i in range(ndim))
|
92 |
-
else: # transpose
|
93 |
-
assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
|
94 |
-
|
95 |
-
# Helpers.
|
96 |
-
common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
|
97 |
-
def calc_output_padding(input_shape, output_shape):
|
98 |
-
if transpose:
|
99 |
-
return [0, 0]
|
100 |
-
return [
|
101 |
-
input_shape[i + 2]
|
102 |
-
- (output_shape[i + 2] - 1) * stride[i]
|
103 |
-
- (1 - 2 * padding[i])
|
104 |
-
- dilation[i] * (weight_shape[i + 2] - 1)
|
105 |
-
for i in range(ndim)
|
106 |
-
]
|
107 |
-
|
108 |
-
# Forward & backward.
|
109 |
-
class Conv2d(torch.autograd.Function):
|
110 |
-
@staticmethod
|
111 |
-
def forward(ctx, input, weight, bias):
|
112 |
-
assert weight.shape == weight_shape
|
113 |
-
if not transpose:
|
114 |
-
output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
|
115 |
-
else: # transpose
|
116 |
-
output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
|
117 |
-
ctx.save_for_backward(input, weight)
|
118 |
-
return output
|
119 |
-
|
120 |
-
@staticmethod
|
121 |
-
def backward(ctx, grad_output):
|
122 |
-
input, weight = ctx.saved_tensors
|
123 |
-
grad_input = None
|
124 |
-
grad_weight = None
|
125 |
-
grad_bias = None
|
126 |
-
|
127 |
-
if ctx.needs_input_grad[0]:
|
128 |
-
p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
|
129 |
-
grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
|
130 |
-
assert grad_input.shape == input.shape
|
131 |
-
|
132 |
-
if ctx.needs_input_grad[1] and not weight_gradients_disabled:
|
133 |
-
grad_weight = Conv2dGradWeight.apply(grad_output, input)
|
134 |
-
assert grad_weight.shape == weight_shape
|
135 |
-
|
136 |
-
if ctx.needs_input_grad[2]:
|
137 |
-
grad_bias = grad_output.sum([0, 2, 3])
|
138 |
-
|
139 |
-
return grad_input, grad_weight, grad_bias
|
140 |
-
|
141 |
-
# Gradient with respect to the weights.
|
142 |
-
class Conv2dGradWeight(torch.autograd.Function):
|
143 |
-
@staticmethod
|
144 |
-
def forward(ctx, grad_output, input):
|
145 |
-
op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
|
146 |
-
flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
|
147 |
-
grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
|
148 |
-
assert grad_weight.shape == weight_shape
|
149 |
-
ctx.save_for_backward(grad_output, input)
|
150 |
-
return grad_weight
|
151 |
-
|
152 |
-
@staticmethod
|
153 |
-
def backward(ctx, grad2_grad_weight):
|
154 |
-
grad_output, input = ctx.saved_tensors
|
155 |
-
grad2_grad_output = None
|
156 |
-
grad2_input = None
|
157 |
-
|
158 |
-
if ctx.needs_input_grad[0]:
|
159 |
-
grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
|
160 |
-
assert grad2_grad_output.shape == grad_output.shape
|
161 |
-
|
162 |
-
if ctx.needs_input_grad[1]:
|
163 |
-
p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
|
164 |
-
grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
|
165 |
-
assert grad2_input.shape == input.shape
|
166 |
-
|
167 |
-
return grad2_grad_output, grad2_input
|
168 |
-
|
169 |
-
_conv2d_gradfix_cache[key] = Conv2d
|
170 |
-
return Conv2d
|
171 |
-
|
172 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/ops/conv2d_resample.py
DELETED
@@ -1,158 +0,0 @@
|
|
1 |
-
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
#
|
5 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
-
# and proprietary rights in and to this software, related documentation
|
7 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
-
# distribution of this software and related documentation without an express
|
9 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
-
|
11 |
-
"""2D convolution with optional up/downsampling."""
|
12 |
-
|
13 |
-
import torch
|
14 |
-
|
15 |
-
from .. import misc
|
16 |
-
from . import conv2d_gradfix
|
17 |
-
from . import upfirdn2d
|
18 |
-
from .upfirdn2d import _parse_padding
|
19 |
-
from .upfirdn2d import _get_filter_size
|
20 |
-
|
21 |
-
#----------------------------------------------------------------------------
|
22 |
-
|
23 |
-
def _get_weight_shape(w):
|
24 |
-
with misc.suppress_tracer_warnings(): # this value will be treated as a constant
|
25 |
-
shape = [int(sz) for sz in w.shape]
|
26 |
-
misc.assert_shape(w, shape)
|
27 |
-
return shape
|
28 |
-
|
29 |
-
#----------------------------------------------------------------------------
|
30 |
-
|
31 |
-
def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
|
32 |
-
"""Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
|
33 |
-
"""
|
34 |
-
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
35 |
-
|
36 |
-
# Flip weight if requested.
|
37 |
-
if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
|
38 |
-
w = w.flip([2, 3])
|
39 |
-
|
40 |
-
# Workaround performance pitfall in cuDNN 8.0.5, triggered when using
|
41 |
-
# 1x1 kernel + memory_format=channels_last + less than 64 channels.
|
42 |
-
if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
|
43 |
-
if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
|
44 |
-
if out_channels <= 4 and groups == 1:
|
45 |
-
in_shape = x.shape
|
46 |
-
x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
|
47 |
-
x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
|
48 |
-
else:
|
49 |
-
x = x.to(memory_format=torch.contiguous_format)
|
50 |
-
w = w.to(memory_format=torch.contiguous_format)
|
51 |
-
x = conv2d_gradfix.conv2d(x, w, groups=groups)
|
52 |
-
return x.to(memory_format=torch.channels_last)
|
53 |
-
|
54 |
-
# Otherwise => execute using conv2d_gradfix.
|
55 |
-
op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
|
56 |
-
return op(x, w, stride=stride, padding=padding, groups=groups)
|
57 |
-
|
58 |
-
#----------------------------------------------------------------------------
|
59 |
-
|
60 |
-
@misc.profiled_function
|
61 |
-
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
|
62 |
-
r"""2D convolution with optional up/downsampling.
|
63 |
-
|
64 |
-
Padding is performed only once at the beginning, not between the operations.
|
65 |
-
|
66 |
-
Args:
|
67 |
-
x: Input tensor of shape
|
68 |
-
`[batch_size, in_channels, in_height, in_width]`.
|
69 |
-
w: Weight tensor of shape
|
70 |
-
`[out_channels, in_channels//groups, kernel_height, kernel_width]`.
|
71 |
-
f: Low-pass filter for up/downsampling. Must be prepared beforehand by
|
72 |
-
calling upfirdn2d.setup_filter(). None = identity (default).
|
73 |
-
up: Integer upsampling factor (default: 1).
|
74 |
-
down: Integer downsampling factor (default: 1).
|
75 |
-
padding: Padding with respect to the upsampled image. Can be a single number
|
76 |
-
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
77 |
-
(default: 0).
|
78 |
-
groups: Split input channels into N groups (default: 1).
|
79 |
-
flip_weight: False = convolution, True = correlation (default: True).
|
80 |
-
flip_filter: False = convolution, True = correlation (default: False).
|
81 |
-
|
82 |
-
Returns:
|
83 |
-
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
84 |
-
"""
|
85 |
-
# Validate arguments.
|
86 |
-
assert isinstance(x, torch.Tensor) and (x.ndim == 4)
|
87 |
-
assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
|
88 |
-
assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
|
89 |
-
assert isinstance(up, int) and (up >= 1)
|
90 |
-
assert isinstance(down, int) and (down >= 1)
|
91 |
-
assert isinstance(groups, int) and (groups >= 1)
|
92 |
-
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
|
93 |
-
fw, fh = _get_filter_size(f)
|
94 |
-
px0, px1, py0, py1 = _parse_padding(padding)
|
95 |
-
|
96 |
-
# Adjust padding to account for up/downsampling.
|
97 |
-
if up > 1:
|
98 |
-
px0 += (fw + up - 1) // 2
|
99 |
-
px1 += (fw - up) // 2
|
100 |
-
py0 += (fh + up - 1) // 2
|
101 |
-
py1 += (fh - up) // 2
|
102 |
-
if down > 1:
|
103 |
-
px0 += (fw - down + 1) // 2
|
104 |
-
px1 += (fw - down) // 2
|
105 |
-
py0 += (fh - down + 1) // 2
|
106 |
-
py1 += (fh - down) // 2
|
107 |
-
|
108 |
-
# Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
|
109 |
-
if kw == 1 and kh == 1 and (down > 1 and up == 1):
|
110 |
-
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
|
111 |
-
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
112 |
-
return x
|
113 |
-
|
114 |
-
# Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
|
115 |
-
if kw == 1 and kh == 1 and (up > 1 and down == 1):
|
116 |
-
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
117 |
-
x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
|
118 |
-
return x
|
119 |
-
|
120 |
-
# Fast path: downsampling only => use strided convolution.
|
121 |
-
if down > 1 and up == 1:
|
122 |
-
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
|
123 |
-
x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
|
124 |
-
return x
|
125 |
-
|
126 |
-
# Fast path: upsampling with optional downsampling => use transpose strided convolution.
|
127 |
-
if up > 1:
|
128 |
-
if groups == 1:
|
129 |
-
w = w.transpose(0, 1)
|
130 |
-
else:
|
131 |
-
w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
|
132 |
-
w = w.transpose(1, 2)
|
133 |
-
w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
|
134 |
-
px0 -= kw - 1
|
135 |
-
px1 -= kw - up
|
136 |
-
py0 -= kh - 1
|
137 |
-
py1 -= kh - up
|
138 |
-
pxt = max(min(-px0, -px1), 0)
|
139 |
-
pyt = max(min(-py0, -py1), 0)
|
140 |
-
x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
|
141 |
-
x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
|
142 |
-
if down > 1:
|
143 |
-
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
144 |
-
return x
|
145 |
-
|
146 |
-
# Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
|
147 |
-
if up == 1 and down == 1:
|
148 |
-
if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
|
149 |
-
return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
|
150 |
-
|
151 |
-
# Fallback: Generic reference implementation.
|
152 |
-
x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
|
153 |
-
x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
|
154 |
-
if down > 1:
|
155 |
-
x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
|
156 |
-
return x
|
157 |
-
|
158 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/ops/filtered_lrelu.cpp
DELETED
@@ -1,300 +0,0 @@
|
|
1 |
-
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
//
|
3 |
-
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
// and proprietary rights in and to this software, related documentation
|
5 |
-
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
// distribution of this software and related documentation without an express
|
7 |
-
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
#include <torch/extension.h>
|
10 |
-
#include <ATen/cuda/CUDAContext.h>
|
11 |
-
#include <c10/cuda/CUDAGuard.h>
|
12 |
-
#include "filtered_lrelu.h"
|
13 |
-
|
14 |
-
//------------------------------------------------------------------------
|
15 |
-
|
16 |
-
static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
|
17 |
-
torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
|
18 |
-
int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
|
19 |
-
{
|
20 |
-
// Set CUDA device.
|
21 |
-
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
22 |
-
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
23 |
-
|
24 |
-
// Validate arguments.
|
25 |
-
TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
|
26 |
-
TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
|
27 |
-
TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
|
28 |
-
TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
|
29 |
-
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
|
30 |
-
TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
|
31 |
-
TORCH_CHECK(x.numel() > 0, "x is empty");
|
32 |
-
TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
|
33 |
-
TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
|
34 |
-
TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
|
35 |
-
TORCH_CHECK(fu.numel() > 0, "fu is empty");
|
36 |
-
TORCH_CHECK(fd.numel() > 0, "fd is empty");
|
37 |
-
TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
|
38 |
-
TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
|
39 |
-
|
40 |
-
// Figure out how much shared memory is available on the device.
|
41 |
-
int maxSharedBytes = 0;
|
42 |
-
AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
|
43 |
-
int sharedKB = maxSharedBytes >> 10;
|
44 |
-
|
45 |
-
// Populate enough launch parameters to check if a CUDA kernel exists.
|
46 |
-
filtered_lrelu_kernel_params p;
|
47 |
-
p.up = up;
|
48 |
-
p.down = down;
|
49 |
-
p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
|
50 |
-
p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
|
51 |
-
filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
|
52 |
-
if (!test_spec.exec)
|
53 |
-
{
|
54 |
-
// No kernel found - return empty tensors and indicate missing kernel with return code of -1.
|
55 |
-
return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
|
56 |
-
}
|
57 |
-
|
58 |
-
// Input/output element size.
|
59 |
-
int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
|
60 |
-
|
61 |
-
// Input sizes.
|
62 |
-
int64_t xw = (int)x.size(3);
|
63 |
-
int64_t xh = (int)x.size(2);
|
64 |
-
int64_t fut_w = (int)fu.size(-1) - 1;
|
65 |
-
int64_t fut_h = (int)fu.size(0) - 1;
|
66 |
-
int64_t fdt_w = (int)fd.size(-1) - 1;
|
67 |
-
int64_t fdt_h = (int)fd.size(0) - 1;
|
68 |
-
|
69 |
-
// Logical size of upsampled buffer.
|
70 |
-
int64_t cw = xw * up + (px0 + px1) - fut_w;
|
71 |
-
int64_t ch = xh * up + (py0 + py1) - fut_h;
|
72 |
-
TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
|
73 |
-
TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
|
74 |
-
|
75 |
-
// Compute output size and allocate.
|
76 |
-
int64_t yw = (cw - fdt_w + (down - 1)) / down;
|
77 |
-
int64_t yh = (ch - fdt_h + (down - 1)) / down;
|
78 |
-
TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
|
79 |
-
TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
|
80 |
-
torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
|
81 |
-
|
82 |
-
// Allocate sign tensor.
|
83 |
-
torch::Tensor so;
|
84 |
-
torch::Tensor s = si;
|
85 |
-
bool readSigns = !!s.numel();
|
86 |
-
int64_t sw_active = 0; // Active width of sign tensor.
|
87 |
-
if (writeSigns)
|
88 |
-
{
|
89 |
-
sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
|
90 |
-
int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
|
91 |
-
int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
|
92 |
-
TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
|
93 |
-
s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
|
94 |
-
}
|
95 |
-
else if (readSigns)
|
96 |
-
sw_active = s.size(3) << 2;
|
97 |
-
|
98 |
-
// Validate sign tensor if in use.
|
99 |
-
if (readSigns || writeSigns)
|
100 |
-
{
|
101 |
-
TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
|
102 |
-
TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
|
103 |
-
TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
|
104 |
-
TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
|
105 |
-
TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
|
106 |
-
TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
|
107 |
-
}
|
108 |
-
|
109 |
-
// Populate rest of CUDA kernel parameters.
|
110 |
-
p.x = x.data_ptr();
|
111 |
-
p.y = y.data_ptr();
|
112 |
-
p.b = b.data_ptr();
|
113 |
-
p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
|
114 |
-
p.fu = fu.data_ptr<float>();
|
115 |
-
p.fd = fd.data_ptr<float>();
|
116 |
-
p.pad0 = make_int2(px0, py0);
|
117 |
-
p.gain = gain;
|
118 |
-
p.slope = slope;
|
119 |
-
p.clamp = clamp;
|
120 |
-
p.flip = (flip_filters) ? 1 : 0;
|
121 |
-
p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
|
122 |
-
p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
|
123 |
-
p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
|
124 |
-
p.sOfs = make_int2(sx, sy);
|
125 |
-
p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
|
126 |
-
|
127 |
-
// x, y, b strides are in bytes.
|
128 |
-
p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
|
129 |
-
p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
|
130 |
-
p.bStride = sz * b.stride(0);
|
131 |
-
|
132 |
-
// fu, fd strides are in elements.
|
133 |
-
p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
|
134 |
-
p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
|
135 |
-
|
136 |
-
// Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
|
137 |
-
bool index64b = false;
|
138 |
-
if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
|
139 |
-
if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
|
140 |
-
if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
|
141 |
-
if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
|
142 |
-
if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
|
143 |
-
if (s.numel() > INT_MAX) index64b = true;
|
144 |
-
|
145 |
-
// Choose CUDA kernel.
|
146 |
-
filtered_lrelu_kernel_spec spec = { 0 };
|
147 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
|
148 |
-
{
|
149 |
-
if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
|
150 |
-
{
|
151 |
-
// Choose kernel based on index type, datatype and sign read/write modes.
|
152 |
-
if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true, false>(p, sharedKB);
|
153 |
-
else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true >(p, sharedKB);
|
154 |
-
else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
|
155 |
-
else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true, false>(p, sharedKB);
|
156 |
-
else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true >(p, sharedKB);
|
157 |
-
else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
|
158 |
-
}
|
159 |
-
});
|
160 |
-
TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists.
|
161 |
-
|
162 |
-
// Launch CUDA kernel.
|
163 |
-
void* args[] = {&p};
|
164 |
-
int bx = spec.numWarps * 32;
|
165 |
-
int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
|
166 |
-
int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
|
167 |
-
int gz = p.yShape.z * p.yShape.w;
|
168 |
-
|
169 |
-
// Repeat multiple horizontal tiles in a CTA?
|
170 |
-
if (spec.xrep)
|
171 |
-
{
|
172 |
-
p.tilesXrep = spec.xrep;
|
173 |
-
p.tilesXdim = gx;
|
174 |
-
|
175 |
-
gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
|
176 |
-
std::swap(gx, gy);
|
177 |
-
}
|
178 |
-
else
|
179 |
-
{
|
180 |
-
p.tilesXrep = 0;
|
181 |
-
p.tilesXdim = 0;
|
182 |
-
}
|
183 |
-
|
184 |
-
// Launch filter setup kernel.
|
185 |
-
AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
|
186 |
-
|
187 |
-
// Copy kernels to constant memory.
|
188 |
-
if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true, false>(at::cuda::getCurrentCUDAStream())));
|
189 |
-
else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters<false, true >(at::cuda::getCurrentCUDAStream())));
|
190 |
-
else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));
|
191 |
-
|
192 |
-
// Set cache and shared memory configurations for main kernel.
|
193 |
-
AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
|
194 |
-
if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
|
195 |
-
AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
|
196 |
-
AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
|
197 |
-
|
198 |
-
// Launch main kernel.
|
199 |
-
const int maxSubGz = 65535; // CUDA maximum for block z dimension.
|
200 |
-
for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
|
201 |
-
{
|
202 |
-
p.blockZofs = zofs;
|
203 |
-
int subGz = std::min(maxSubGz, gz - zofs);
|
204 |
-
AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
|
205 |
-
}
|
206 |
-
|
207 |
-
// Done.
|
208 |
-
return std::make_tuple(y, so, 0);
|
209 |
-
}
|
210 |
-
|
211 |
-
//------------------------------------------------------------------------
|
212 |
-
|
213 |
-
static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
|
214 |
-
{
|
215 |
-
// Set CUDA device.
|
216 |
-
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
217 |
-
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
218 |
-
|
219 |
-
// Validate arguments.
|
220 |
-
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
|
221 |
-
TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
|
222 |
-
TORCH_CHECK(x.numel() > 0, "x is empty");
|
223 |
-
TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
|
224 |
-
|
225 |
-
// Output signs if we don't have sign input.
|
226 |
-
torch::Tensor so;
|
227 |
-
torch::Tensor s = si;
|
228 |
-
bool readSigns = !!s.numel();
|
229 |
-
if (writeSigns)
|
230 |
-
{
|
231 |
-
int64_t sw = x.size(3);
|
232 |
-
sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
|
233 |
-
s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
|
234 |
-
}
|
235 |
-
|
236 |
-
// Validate sign tensor if in use.
|
237 |
-
if (readSigns || writeSigns)
|
238 |
-
{
|
239 |
-
TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
|
240 |
-
TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
|
241 |
-
TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
|
242 |
-
TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
|
243 |
-
TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
|
244 |
-
TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
|
245 |
-
}
|
246 |
-
|
247 |
-
// Initialize CUDA kernel parameters.
|
248 |
-
filtered_lrelu_act_kernel_params p;
|
249 |
-
p.x = x.data_ptr();
|
250 |
-
p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
|
251 |
-
p.gain = gain;
|
252 |
-
p.slope = slope;
|
253 |
-
p.clamp = clamp;
|
254 |
-
p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
|
255 |
-
p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
|
256 |
-
p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
|
257 |
-
p.sOfs = make_int2(sx, sy);
|
258 |
-
|
259 |
-
// Choose CUDA kernel.
|
260 |
-
void* func = 0;
|
261 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
|
262 |
-
{
|
263 |
-
if (writeSigns)
|
264 |
-
func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
|
265 |
-
else if (readSigns)
|
266 |
-
func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
|
267 |
-
else
|
268 |
-
func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
|
269 |
-
});
|
270 |
-
TORCH_CHECK(func, "internal error - CUDA kernel not found");
|
271 |
-
|
272 |
-
// Launch CUDA kernel.
|
273 |
-
void* args[] = {&p};
|
274 |
-
int bx = 128; // 4 warps per block.
|
275 |
-
|
276 |
-
// Logical size of launch = writeSigns ? p.s : p.x
|
277 |
-
uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
|
278 |
-
uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
|
279 |
-
uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
|
280 |
-
gx = (gx - 1) / bx + 1;
|
281 |
-
|
282 |
-
// Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
|
283 |
-
const uint32_t gmax = 65535;
|
284 |
-
gy = std::min(gy, gmax);
|
285 |
-
gz = std::min(gz, gmax);
|
286 |
-
|
287 |
-
// Launch.
|
288 |
-
AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
|
289 |
-
return so;
|
290 |
-
}
|
291 |
-
|
292 |
-
//------------------------------------------------------------------------
|
293 |
-
|
294 |
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
295 |
-
{
|
296 |
-
m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
|
297 |
-
m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
|
298 |
-
}
|
299 |
-
|
300 |
-
//------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/ops/filtered_lrelu.cu
DELETED
@@ -1,1284 +0,0 @@
|
|
1 |
-
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
//
|
3 |
-
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
// and proprietary rights in and to this software, related documentation
|
5 |
-
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
// distribution of this software and related documentation without an express
|
7 |
-
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
#include <c10/util/Half.h>
|
10 |
-
#include "filtered_lrelu.h"
|
11 |
-
#include <cstdint>
|
12 |
-
|
13 |
-
//------------------------------------------------------------------------
|
14 |
-
// Helpers.
|
15 |
-
|
16 |
-
enum // Filter modes.
|
17 |
-
{
|
18 |
-
MODE_SUSD = 0, // Separable upsampling, separable downsampling.
|
19 |
-
MODE_FUSD = 1, // Full upsampling, separable downsampling.
|
20 |
-
MODE_SUFD = 2, // Separable upsampling, full downsampling.
|
21 |
-
MODE_FUFD = 3, // Full upsampling, full downsampling.
|
22 |
-
};
|
23 |
-
|
24 |
-
template <class T> struct InternalType;
|
25 |
-
template <> struct InternalType<double>
|
26 |
-
{
|
27 |
-
typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
|
28 |
-
__device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
|
29 |
-
__device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
|
30 |
-
__device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
|
31 |
-
};
|
32 |
-
template <> struct InternalType<float>
|
33 |
-
{
|
34 |
-
typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
|
35 |
-
__device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
|
36 |
-
__device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
|
37 |
-
__device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
|
38 |
-
};
|
39 |
-
template <> struct InternalType<c10::Half>
|
40 |
-
{
|
41 |
-
typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
|
42 |
-
__device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
|
43 |
-
__device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
|
44 |
-
__device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
|
45 |
-
};
|
46 |
-
|
47 |
-
#define MIN(A, B) ((A) < (B) ? (A) : (B))
|
48 |
-
#define MAX(A, B) ((A) > (B) ? (A) : (B))
|
49 |
-
#define CEIL_DIV(A, B) (((B)==1) ? (A) : \
|
50 |
-
((B)==2) ? ((int)((A)+1) >> 1) : \
|
51 |
-
((B)==4) ? ((int)((A)+3) >> 2) : \
|
52 |
-
(((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))
|
53 |
-
|
54 |
-
// This works only up to blocks of size 256 x 256 and for all N that are powers of two.
|
55 |
-
template <int N> __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
|
56 |
-
{
|
57 |
-
if ((N & (N-1)) && N <= 256)
|
58 |
-
y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
|
59 |
-
else
|
60 |
-
y = i/N;
|
61 |
-
|
62 |
-
x = i - y*N;
|
63 |
-
}
|
64 |
-
|
65 |
-
// Type cast stride before reading it.
|
66 |
-
template <class T> __device__ __forceinline__ T get_stride(const int64_t& x)
|
67 |
-
{
|
68 |
-
return *reinterpret_cast<const T*>(&x);
|
69 |
-
}
|
70 |
-
|
71 |
-
//------------------------------------------------------------------------
|
72 |
-
// Filters, setup kernel, copying function.
|
73 |
-
|
74 |
-
#define MAX_FILTER_SIZE 32
|
75 |
-
|
76 |
-
// Combined up/down filter buffers so that transfer can be done with one copy.
|
77 |
-
__device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
|
78 |
-
__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.
|
79 |
-
|
80 |
-
// Accessors to combined buffers to index up/down filters individually.
|
81 |
-
#define c_fu (c_fbuf)
|
82 |
-
#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
|
83 |
-
#define g_fu (g_fbuf)
|
84 |
-
#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
|
85 |
-
|
86 |
-
// Set up filters into global memory buffer.
|
87 |
-
static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
|
88 |
-
{
|
89 |
-
for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
|
90 |
-
{
|
91 |
-
int x, y;
|
92 |
-
fast_div_mod<MAX_FILTER_SIZE>(x, y, idx);
|
93 |
-
|
94 |
-
int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
|
95 |
-
int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
|
96 |
-
if (p.fuShape.y > 0)
|
97 |
-
g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
|
98 |
-
else
|
99 |
-
g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
|
100 |
-
|
101 |
-
int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
|
102 |
-
int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
|
103 |
-
if (p.fdShape.y > 0)
|
104 |
-
g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
|
105 |
-
else
|
106 |
-
g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
|
107 |
-
}
|
108 |
-
}
|
109 |
-
|
110 |
-
// Host function to copy filters written by setup kernel into constant buffer for main kernel.
|
111 |
-
template <bool, bool> static cudaError_t copy_filters(cudaStream_t stream)
|
112 |
-
{
|
113 |
-
void* src = 0;
|
114 |
-
cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
|
115 |
-
if (err) return err;
|
116 |
-
return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
|
117 |
-
}
|
118 |
-
|
119 |
-
//------------------------------------------------------------------------
|
120 |
-
// Coordinate spaces:
|
121 |
-
// - Relative to input tensor: inX, inY, tileInX, tileInY
|
122 |
-
// - Relative to input tile: relInX, relInY, tileInW, tileInH
|
123 |
-
// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
|
124 |
-
// - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH
|
125 |
-
// - Relative to output tensor: outX, outY, tileOutX, tileOutY
|
126 |
-
//
|
127 |
-
// Relationships between coordinate spaces:
|
128 |
-
// - inX = tileInX + relInX
|
129 |
-
// - inY = tileInY + relInY
|
130 |
-
// - relUpX = relInX * up + phaseInX
|
131 |
-
// - relUpY = relInY * up + phaseInY
|
132 |
-
// - relUpX = relOutX * down
|
133 |
-
// - relUpY = relOutY * down
|
134 |
-
// - outX = tileOutX + relOutX
|
135 |
-
// - outY = tileOutY + relOutY
|
136 |
-
|
137 |
-
extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.
|
138 |
-
|
139 |
-
template <class T, class index_t, int sharedKB, bool signWrite, bool signRead, int filterMode, int up, int fuSize, int down, int fdSize, int tileOutW, int tileOutH, int threadsPerBlock, bool enableXrep, bool enableWriteSkip>
|
140 |
-
static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
|
141 |
-
{
|
142 |
-
// Check that we don't try to support non-existing filter modes.
|
143 |
-
static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported");
|
144 |
-
static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
|
145 |
-
static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor");
|
146 |
-
static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
|
147 |
-
static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor");
|
148 |
-
static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor");
|
149 |
-
static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
|
150 |
-
static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
|
151 |
-
static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
|
152 |
-
static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
|
153 |
-
static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");
|
154 |
-
|
155 |
-
// Static definitions.
|
156 |
-
typedef typename InternalType<T>::scalar_t scalar_t;
|
157 |
-
typedef typename InternalType<T>::vec2_t vec2_t;
|
158 |
-
typedef typename InternalType<T>::vec4_t vec4_t;
|
159 |
-
const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
|
160 |
-
const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height.
|
161 |
-
const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width.
|
162 |
-
const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height.
|
163 |
-
const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up.
|
164 |
-
const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4.
|
165 |
-
|
166 |
-
// Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
|
167 |
-
const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));
|
168 |
-
|
169 |
-
// Sizes of logical buffers.
|
170 |
-
const int szIn = tileInH_up * tileInW;
|
171 |
-
const int szUpX = tileInH_up * tileUpW;
|
172 |
-
const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
|
173 |
-
const int szDownX = tileUpH * tileOutW;
|
174 |
-
|
175 |
-
// Sizes for shared memory arrays.
|
176 |
-
const int s_buf0_size_base =
|
177 |
-
(filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
|
178 |
-
(filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
|
179 |
-
(filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
|
180 |
-
(filterMode == MODE_FUFD) ? szIn :
|
181 |
-
-1;
|
182 |
-
const int s_buf1_size_base =
|
183 |
-
(filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
|
184 |
-
(filterMode == MODE_FUSD) ? szUpXY :
|
185 |
-
(filterMode == MODE_SUFD) ? szUpX :
|
186 |
-
(filterMode == MODE_FUFD) ? szUpXY :
|
187 |
-
-1;
|
188 |
-
|
189 |
-
// Ensure U128 alignment.
|
190 |
-
const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
|
191 |
-
const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
|
192 |
-
|
193 |
-
// Check at compile time that we don't use too much shared memory.
|
194 |
-
static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");
|
195 |
-
|
196 |
-
// Declare shared memory arrays.
|
197 |
-
scalar_t* s_buf0;
|
198 |
-
scalar_t* s_buf1;
|
199 |
-
if (sharedKB <= 48)
|
200 |
-
{
|
201 |
-
// Allocate shared memory arrays here.
|
202 |
-
__shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.
|
203 |
-
s_buf0 = s_buf0_st;
|
204 |
-
s_buf1 = s_buf0 + s_buf0_size;
|
205 |
-
}
|
206 |
-
else
|
207 |
-
{
|
208 |
-
// Use the dynamically allocated shared memory array.
|
209 |
-
s_buf0 = (scalar_t*)s_buf_raw;
|
210 |
-
s_buf1 = s_buf0 + s_buf0_size;
|
211 |
-
}
|
212 |
-
|
213 |
-
// Pointers to the buffers.
|
214 |
-
scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY]
|
215 |
-
scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX]
|
216 |
-
scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX]
|
217 |
-
scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX]
|
218 |
-
if (filterMode == MODE_SUSD)
|
219 |
-
{
|
220 |
-
s_tileIn = s_buf0;
|
221 |
-
s_tileUpX = s_buf1;
|
222 |
-
s_tileUpXY = s_buf0;
|
223 |
-
s_tileDownX = s_buf1;
|
224 |
-
}
|
225 |
-
else if (filterMode == MODE_FUSD)
|
226 |
-
{
|
227 |
-
s_tileIn = s_buf0;
|
228 |
-
s_tileUpXY = s_buf1;
|
229 |
-
s_tileDownX = s_buf0;
|
230 |
-
}
|
231 |
-
else if (filterMode == MODE_SUFD)
|
232 |
-
{
|
233 |
-
s_tileIn = s_buf0;
|
234 |
-
s_tileUpX = s_buf1;
|
235 |
-
s_tileUpXY = s_buf0;
|
236 |
-
}
|
237 |
-
else if (filterMode == MODE_FUFD)
|
238 |
-
{
|
239 |
-
s_tileIn = s_buf0;
|
240 |
-
s_tileUpXY = s_buf1;
|
241 |
-
}
|
242 |
-
|
243 |
-
// Allow large grids in z direction via per-launch offset.
|
244 |
-
int channelIdx = blockIdx.z + p.blockZofs;
|
245 |
-
int batchIdx = channelIdx / p.yShape.z;
|
246 |
-
channelIdx -= batchIdx * p.yShape.z;
|
247 |
-
|
248 |
-
// Offset to output feature map. In bytes.
|
249 |
-
index_t mapOfsOut = channelIdx * get_stride<index_t>(p.yStride.z) + batchIdx * get_stride<index_t>(p.yStride.w);
|
250 |
-
|
251 |
-
// Sign shift amount.
|
252 |
-
uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
|
253 |
-
|
254 |
-
// Inner tile loop.
|
255 |
-
#pragma unroll 1
|
256 |
-
for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)
|
257 |
-
{
|
258 |
-
// Locate output tile.
|
259 |
-
int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
|
260 |
-
int tileOutX = tileX * tileOutW;
|
261 |
-
int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
|
262 |
-
|
263 |
-
// Locate input tile.
|
264 |
-
int tmpX = tileOutX * down - p.pad0.x;
|
265 |
-
int tmpY = tileOutY * down - p.pad0.y;
|
266 |
-
int tileInX = CEIL_DIV(tmpX, up);
|
267 |
-
int tileInY = CEIL_DIV(tmpY, up);
|
268 |
-
const int phaseInX = tileInX * up - tmpX;
|
269 |
-
const int phaseInY = tileInY * up - tmpY;
|
270 |
-
|
271 |
-
// Extra sync if input and output buffers are the same and we are not on first tile.
|
272 |
-
if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))
|
273 |
-
__syncthreads();
|
274 |
-
|
275 |
-
// Load input tile & apply bias. Unrolled.
|
276 |
-
scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride<index_t>(p.bStride)));
|
277 |
-
index_t mapOfsIn = channelIdx * get_stride<index_t>(p.xStride.z) + batchIdx * get_stride<index_t>(p.xStride.w);
|
278 |
-
int idx = threadIdx.x;
|
279 |
-
const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
|
280 |
-
#pragma unroll
|
281 |
-
for (int loop = 0; loop < loopCountIN; loop++)
|
282 |
-
{
|
283 |
-
int relInX, relInY;
|
284 |
-
fast_div_mod<tileInW>(relInX, relInY, idx);
|
285 |
-
int inX = tileInX + relInX;
|
286 |
-
int inY = tileInY + relInY;
|
287 |
-
scalar_t v = 0;
|
288 |
-
|
289 |
-
if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
|
290 |
-
v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride<index_t>(p.xStride.x) + inY * get_stride<index_t>(p.xStride.y) + mapOfsIn))) + b;
|
291 |
-
|
292 |
-
bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);
|
293 |
-
if (!skip)
|
294 |
-
s_tileIn[idx] = v;
|
295 |
-
|
296 |
-
idx += threadsPerBlock;
|
297 |
-
}
|
298 |
-
|
299 |
-
if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.
|
300 |
-
{
|
301 |
-
// Horizontal upsampling.
|
302 |
-
__syncthreads();
|
303 |
-
if (up == 4)
|
304 |
-
{
|
305 |
-
for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
|
306 |
-
{
|
307 |
-
int relUpX0, relInY;
|
308 |
-
fast_div_mod<tileUpW>(relUpX0, relInY, idx);
|
309 |
-
int relInX0 = relUpX0 / up;
|
310 |
-
int src0 = relInX0 + tileInW * relInY;
|
311 |
-
int dst = relInY * tileUpW + relUpX0;
|
312 |
-
vec4_t v = InternalType<T>::zero_vec4();
|
313 |
-
scalar_t a = s_tileIn[src0];
|
314 |
-
if (phaseInX == 0)
|
315 |
-
{
|
316 |
-
#pragma unroll
|
317 |
-
for (int step = 0; step < fuSize / up; step++)
|
318 |
-
{
|
319 |
-
v.x += a * (scalar_t)c_fu[step * up + 0];
|
320 |
-
a = s_tileIn[src0 + step + 1];
|
321 |
-
v.y += a * (scalar_t)c_fu[step * up + 3];
|
322 |
-
v.z += a * (scalar_t)c_fu[step * up + 2];
|
323 |
-
v.w += a * (scalar_t)c_fu[step * up + 1];
|
324 |
-
}
|
325 |
-
}
|
326 |
-
else if (phaseInX == 1)
|
327 |
-
{
|
328 |
-
#pragma unroll
|
329 |
-
for (int step = 0; step < fuSize / up; step++)
|
330 |
-
{
|
331 |
-
v.x += a * (scalar_t)c_fu[step * up + 1];
|
332 |
-
v.y += a * (scalar_t)c_fu[step * up + 0];
|
333 |
-
a = s_tileIn[src0 + step + 1];
|
334 |
-
v.z += a * (scalar_t)c_fu[step * up + 3];
|
335 |
-
v.w += a * (scalar_t)c_fu[step * up + 2];
|
336 |
-
}
|
337 |
-
}
|
338 |
-
else if (phaseInX == 2)
|
339 |
-
{
|
340 |
-
#pragma unroll
|
341 |
-
for (int step = 0; step < fuSize / up; step++)
|
342 |
-
{
|
343 |
-
v.x += a * (scalar_t)c_fu[step * up + 2];
|
344 |
-
v.y += a * (scalar_t)c_fu[step * up + 1];
|
345 |
-
v.z += a * (scalar_t)c_fu[step * up + 0];
|
346 |
-
a = s_tileIn[src0 + step + 1];
|
347 |
-
v.w += a * (scalar_t)c_fu[step * up + 3];
|
348 |
-
}
|
349 |
-
}
|
350 |
-
else // (phaseInX == 3)
|
351 |
-
{
|
352 |
-
#pragma unroll
|
353 |
-
for (int step = 0; step < fuSize / up; step++)
|
354 |
-
{
|
355 |
-
v.x += a * (scalar_t)c_fu[step * up + 3];
|
356 |
-
v.y += a * (scalar_t)c_fu[step * up + 2];
|
357 |
-
v.z += a * (scalar_t)c_fu[step * up + 1];
|
358 |
-
v.w += a * (scalar_t)c_fu[step * up + 0];
|
359 |
-
a = s_tileIn[src0 + step + 1];
|
360 |
-
}
|
361 |
-
}
|
362 |
-
s_tileUpX[dst+0] = v.x;
|
363 |
-
s_tileUpX[dst+1] = v.y;
|
364 |
-
s_tileUpX[dst+2] = v.z;
|
365 |
-
s_tileUpX[dst+3] = v.w;
|
366 |
-
}
|
367 |
-
}
|
368 |
-
else if (up == 2)
|
369 |
-
{
|
370 |
-
bool p0 = (phaseInX == 0);
|
371 |
-
for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
|
372 |
-
{
|
373 |
-
int relUpX0, relInY;
|
374 |
-
fast_div_mod<tileUpW>(relUpX0, relInY, idx);
|
375 |
-
int relInX0 = relUpX0 / up;
|
376 |
-
int src0 = relInX0 + tileInW * relInY;
|
377 |
-
int dst = relInY * tileUpW + relUpX0;
|
378 |
-
vec2_t v = InternalType<T>::zero_vec2();
|
379 |
-
scalar_t a = s_tileIn[src0];
|
380 |
-
if (p0) // (phaseInX == 0)
|
381 |
-
{
|
382 |
-
#pragma unroll
|
383 |
-
for (int step = 0; step < fuSize / up; step++)
|
384 |
-
{
|
385 |
-
v.x += a * (scalar_t)c_fu[step * up + 0];
|
386 |
-
a = s_tileIn[src0 + step + 1];
|
387 |
-
v.y += a * (scalar_t)c_fu[step * up + 1];
|
388 |
-
}
|
389 |
-
}
|
390 |
-
else // (phaseInX == 1)
|
391 |
-
{
|
392 |
-
#pragma unroll
|
393 |
-
for (int step = 0; step < fuSize / up; step++)
|
394 |
-
{
|
395 |
-
v.x += a * (scalar_t)c_fu[step * up + 1];
|
396 |
-
v.y += a * (scalar_t)c_fu[step * up + 0];
|
397 |
-
a = s_tileIn[src0 + step + 1];
|
398 |
-
}
|
399 |
-
}
|
400 |
-
s_tileUpX[dst+0] = v.x;
|
401 |
-
s_tileUpX[dst+1] = v.y;
|
402 |
-
}
|
403 |
-
}
|
404 |
-
|
405 |
-
// Vertical upsampling & nonlinearity.
|
406 |
-
|
407 |
-
__syncthreads();
|
408 |
-
int groupMask = 15 << ((threadIdx.x & 31) & ~3);
|
409 |
-
int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
|
410 |
-
int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.
|
411 |
-
if (up == 4)
|
412 |
-
{
|
413 |
-
minY -= 3; // Adjust according to block height.
|
414 |
-
for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
|
415 |
-
{
|
416 |
-
int relUpX, relInY0;
|
417 |
-
fast_div_mod<tileUpW>(relUpX, relInY0, idx);
|
418 |
-
int relUpY0 = relInY0 * up;
|
419 |
-
int src0 = relInY0 * tileUpW + relUpX;
|
420 |
-
int dst = relUpY0 * tileUpW + relUpX;
|
421 |
-
vec4_t v = InternalType<T>::zero_vec4();
|
422 |
-
|
423 |
-
scalar_t a = s_tileUpX[src0];
|
424 |
-
if (phaseInY == 0)
|
425 |
-
{
|
426 |
-
#pragma unroll
|
427 |
-
for (int step = 0; step < fuSize / up; step++)
|
428 |
-
{
|
429 |
-
v.x += a * (scalar_t)c_fu[step * up + 0];
|
430 |
-
a = s_tileUpX[src0 + (step + 1) * tileUpW];
|
431 |
-
v.y += a * (scalar_t)c_fu[step * up + 3];
|
432 |
-
v.z += a * (scalar_t)c_fu[step * up + 2];
|
433 |
-
v.w += a * (scalar_t)c_fu[step * up + 1];
|
434 |
-
}
|
435 |
-
}
|
436 |
-
else if (phaseInY == 1)
|
437 |
-
{
|
438 |
-
#pragma unroll
|
439 |
-
for (int step = 0; step < fuSize / up; step++)
|
440 |
-
{
|
441 |
-
v.x += a * (scalar_t)c_fu[step * up + 1];
|
442 |
-
v.y += a * (scalar_t)c_fu[step * up + 0];
|
443 |
-
a = s_tileUpX[src0 + (step + 1) * tileUpW];
|
444 |
-
v.z += a * (scalar_t)c_fu[step * up + 3];
|
445 |
-
v.w += a * (scalar_t)c_fu[step * up + 2];
|
446 |
-
}
|
447 |
-
}
|
448 |
-
else if (phaseInY == 2)
|
449 |
-
{
|
450 |
-
#pragma unroll
|
451 |
-
for (int step = 0; step < fuSize / up; step++)
|
452 |
-
{
|
453 |
-
v.x += a * (scalar_t)c_fu[step * up + 2];
|
454 |
-
v.y += a * (scalar_t)c_fu[step * up + 1];
|
455 |
-
v.z += a * (scalar_t)c_fu[step * up + 0];
|
456 |
-
a = s_tileUpX[src0 + (step + 1) * tileUpW];
|
457 |
-
v.w += a * (scalar_t)c_fu[step * up + 3];
|
458 |
-
}
|
459 |
-
}
|
460 |
-
else // (phaseInY == 3)
|
461 |
-
{
|
462 |
-
#pragma unroll
|
463 |
-
for (int step = 0; step < fuSize / up; step++)
|
464 |
-
{
|
465 |
-
v.x += a * (scalar_t)c_fu[step * up + 3];
|
466 |
-
v.y += a * (scalar_t)c_fu[step * up + 2];
|
467 |
-
v.z += a * (scalar_t)c_fu[step * up + 1];
|
468 |
-
v.w += a * (scalar_t)c_fu[step * up + 0];
|
469 |
-
a = s_tileUpX[src0 + (step + 1) * tileUpW];
|
470 |
-
}
|
471 |
-
}
|
472 |
-
|
473 |
-
int x = tileOutX * down + relUpX;
|
474 |
-
int y = tileOutY * down + relUpY0;
|
475 |
-
int signX = x + p.sOfs.x;
|
476 |
-
int signY = y + p.sOfs.y;
|
477 |
-
int signZ = blockIdx.z + p.blockZofs;
|
478 |
-
int signXb = signX >> 2;
|
479 |
-
index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
|
480 |
-
index_t si1 = si0 + p.sShape.x;
|
481 |
-
index_t si2 = si0 + p.sShape.x * 2;
|
482 |
-
index_t si3 = si0 + p.sShape.x * 3;
|
483 |
-
|
484 |
-
v.x *= (scalar_t)((float)up * (float)up * p.gain);
|
485 |
-
v.y *= (scalar_t)((float)up * (float)up * p.gain);
|
486 |
-
v.z *= (scalar_t)((float)up * (float)up * p.gain);
|
487 |
-
v.w *= (scalar_t)((float)up * (float)up * p.gain);
|
488 |
-
|
489 |
-
if (signWrite)
|
490 |
-
{
|
491 |
-
if (!enableWriteSkip)
|
492 |
-
{
|
493 |
-
// Determine and write signs.
|
494 |
-
int sx = __float_as_uint(v.x) >> 31 << 0;
|
495 |
-
int sy = __float_as_uint(v.y) >> 31 << 8;
|
496 |
-
int sz = __float_as_uint(v.z) >> 31 << 16;
|
497 |
-
int sw = __float_as_uint(v.w) >> 31 << 24;
|
498 |
-
if (sx) v.x *= p.slope;
|
499 |
-
if (sy) v.y *= p.slope;
|
500 |
-
if (sz) v.z *= p.slope;
|
501 |
-
if (sw) v.w *= p.slope;
|
502 |
-
if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
|
503 |
-
if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
|
504 |
-
if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
|
505 |
-
if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
|
506 |
-
|
507 |
-
if ((uint32_t)signXb < p.swLimit && signY >= minY)
|
508 |
-
{
|
509 |
-
// Combine signs.
|
510 |
-
uint32_t s = sx + sy + sw + sz;
|
511 |
-
s <<= (signX & 3) << 1;
|
512 |
-
s |= __shfl_xor_sync(groupMask, s, 1);
|
513 |
-
s |= __shfl_xor_sync(groupMask, s, 2);
|
514 |
-
|
515 |
-
// Write signs.
|
516 |
-
if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
|
517 |
-
if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
|
518 |
-
if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
|
519 |
-
if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
|
520 |
-
}
|
521 |
-
}
|
522 |
-
else
|
523 |
-
{
|
524 |
-
// Determine and write signs.
|
525 |
-
if ((uint32_t)signXb < p.swLimit && signY >= minY)
|
526 |
-
{
|
527 |
-
int sx = __float_as_uint(v.x) >> 31 << 0;
|
528 |
-
int sy = __float_as_uint(v.y) >> 31 << 8;
|
529 |
-
int sz = __float_as_uint(v.z) >> 31 << 16;
|
530 |
-
int sw = __float_as_uint(v.w) >> 31 << 24;
|
531 |
-
if (sx) v.x *= p.slope;
|
532 |
-
if (sy) v.y *= p.slope;
|
533 |
-
if (sz) v.z *= p.slope;
|
534 |
-
if (sw) v.w *= p.slope;
|
535 |
-
if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
|
536 |
-
if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
|
537 |
-
if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
|
538 |
-
if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
|
539 |
-
|
540 |
-
// Combine signs.
|
541 |
-
uint32_t s = sx + sy + sw + sz;
|
542 |
-
s <<= (signX & 3) << 1;
|
543 |
-
s |= __shfl_xor_sync(groupMask, s, 1);
|
544 |
-
s |= __shfl_xor_sync(groupMask, s, 2);
|
545 |
-
|
546 |
-
// Write signs.
|
547 |
-
if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
|
548 |
-
if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
|
549 |
-
if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
|
550 |
-
if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
|
551 |
-
}
|
552 |
-
else
|
553 |
-
{
|
554 |
-
// Just compute the values.
|
555 |
-
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
|
556 |
-
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
|
557 |
-
if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
|
558 |
-
if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
|
559 |
-
}
|
560 |
-
}
|
561 |
-
}
|
562 |
-
else if (signRead) // Read signs and apply.
|
563 |
-
{
|
564 |
-
if ((uint32_t)signXb < p.swLimit)
|
565 |
-
{
|
566 |
-
int ss = (signX & 3) << 1;
|
567 |
-
if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
|
568 |
-
if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
|
569 |
-
if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; }
|
570 |
-
if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; }
|
571 |
-
}
|
572 |
-
}
|
573 |
-
else // Forward pass with no sign write.
|
574 |
-
{
|
575 |
-
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
|
576 |
-
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
|
577 |
-
if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
|
578 |
-
if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
|
579 |
-
}
|
580 |
-
|
581 |
-
s_tileUpXY[dst + 0 * tileUpW] = v.x;
|
582 |
-
if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;
|
583 |
-
if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;
|
584 |
-
if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;
|
585 |
-
}
|
586 |
-
}
|
587 |
-
else if (up == 2)
|
588 |
-
{
|
589 |
-
minY -= 1; // Adjust according to block height.
|
590 |
-
for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
|
591 |
-
{
|
592 |
-
int relUpX, relInY0;
|
593 |
-
fast_div_mod<tileUpW>(relUpX, relInY0, idx);
|
594 |
-
int relUpY0 = relInY0 * up;
|
595 |
-
int src0 = relInY0 * tileUpW + relUpX;
|
596 |
-
int dst = relUpY0 * tileUpW + relUpX;
|
597 |
-
vec2_t v = InternalType<T>::zero_vec2();
|
598 |
-
|
599 |
-
scalar_t a = s_tileUpX[src0];
|
600 |
-
if (phaseInY == 0)
|
601 |
-
{
|
602 |
-
#pragma unroll
|
603 |
-
for (int step = 0; step < fuSize / up; step++)
|
604 |
-
{
|
605 |
-
v.x += a * (scalar_t)c_fu[step * up + 0];
|
606 |
-
a = s_tileUpX[src0 + (step + 1) * tileUpW];
|
607 |
-
v.y += a * (scalar_t)c_fu[step * up + 1];
|
608 |
-
}
|
609 |
-
}
|
610 |
-
else // (phaseInY == 1)
|
611 |
-
{
|
612 |
-
#pragma unroll
|
613 |
-
for (int step = 0; step < fuSize / up; step++)
|
614 |
-
{
|
615 |
-
v.x += a * (scalar_t)c_fu[step * up + 1];
|
616 |
-
v.y += a * (scalar_t)c_fu[step * up + 0];
|
617 |
-
a = s_tileUpX[src0 + (step + 1) * tileUpW];
|
618 |
-
}
|
619 |
-
}
|
620 |
-
|
621 |
-
int x = tileOutX * down + relUpX;
|
622 |
-
int y = tileOutY * down + relUpY0;
|
623 |
-
int signX = x + p.sOfs.x;
|
624 |
-
int signY = y + p.sOfs.y;
|
625 |
-
int signZ = blockIdx.z + p.blockZofs;
|
626 |
-
int signXb = signX >> 2;
|
627 |
-
index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
|
628 |
-
index_t si1 = si0 + p.sShape.x;
|
629 |
-
|
630 |
-
v.x *= (scalar_t)((float)up * (float)up * p.gain);
|
631 |
-
v.y *= (scalar_t)((float)up * (float)up * p.gain);
|
632 |
-
|
633 |
-
if (signWrite)
|
634 |
-
{
|
635 |
-
if (!enableWriteSkip)
|
636 |
-
{
|
637 |
-
// Determine and write signs.
|
638 |
-
int sx = __float_as_uint(v.x) >> 31 << 0;
|
639 |
-
int sy = __float_as_uint(v.y) >> 31 << 8;
|
640 |
-
if (sx) v.x *= p.slope;
|
641 |
-
if (sy) v.y *= p.slope;
|
642 |
-
if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
|
643 |
-
if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
|
644 |
-
|
645 |
-
if ((uint32_t)signXb < p.swLimit && signY >= minY)
|
646 |
-
{
|
647 |
-
// Combine signs.
|
648 |
-
int s = sx + sy;
|
649 |
-
s <<= signXo;
|
650 |
-
s |= __shfl_xor_sync(groupMask, s, 1);
|
651 |
-
s |= __shfl_xor_sync(groupMask, s, 2);
|
652 |
-
|
653 |
-
// Write signs.
|
654 |
-
if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
|
655 |
-
if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
|
656 |
-
}
|
657 |
-
}
|
658 |
-
else
|
659 |
-
{
|
660 |
-
// Determine and write signs.
|
661 |
-
if ((uint32_t)signXb < p.swLimit && signY >= minY)
|
662 |
-
{
|
663 |
-
int sx = __float_as_uint(v.x) >> 31 << 0;
|
664 |
-
int sy = __float_as_uint(v.y) >> 31 << 8;
|
665 |
-
if (sx) v.x *= p.slope;
|
666 |
-
if (sy) v.y *= p.slope;
|
667 |
-
if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
|
668 |
-
if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
|
669 |
-
|
670 |
-
// Combine signs.
|
671 |
-
int s = sx + sy;
|
672 |
-
s <<= signXo;
|
673 |
-
s |= __shfl_xor_sync(groupMask, s, 1);
|
674 |
-
s |= __shfl_xor_sync(groupMask, s, 2);
|
675 |
-
|
676 |
-
// Write signs.
|
677 |
-
if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
|
678 |
-
if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
|
679 |
-
}
|
680 |
-
else
|
681 |
-
{
|
682 |
-
// Just compute the values.
|
683 |
-
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
|
684 |
-
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
|
685 |
-
}
|
686 |
-
}
|
687 |
-
}
|
688 |
-
else if (signRead) // Read signs and apply.
|
689 |
-
{
|
690 |
-
if ((uint32_t)signXb < p.swLimit)
|
691 |
-
{
|
692 |
-
if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
|
693 |
-
if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
|
694 |
-
}
|
695 |
-
}
|
696 |
-
else // Forward pass with no sign write.
|
697 |
-
{
|
698 |
-
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
|
699 |
-
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
|
700 |
-
}
|
701 |
-
|
702 |
-
if (!downInline)
|
703 |
-
{
|
704 |
-
// Write into temporary buffer.
|
705 |
-
s_tileUpXY[dst] = v.x;
|
706 |
-
if (relUpY0 < tileUpH - 1)
|
707 |
-
s_tileUpXY[dst + tileUpW] = v.y;
|
708 |
-
}
|
709 |
-
else
|
710 |
-
{
|
711 |
-
// Write directly into output buffer.
|
712 |
-
if ((uint32_t)x < p.yShape.x)
|
713 |
-
{
|
714 |
-
int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down);
|
715 |
-
index_t ofs = x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut;
|
716 |
-
if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
|
717 |
-
if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]);
|
718 |
-
}
|
719 |
-
}
|
720 |
-
}
|
721 |
-
}
|
722 |
-
}
|
723 |
-
else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD)
|
724 |
-
{
|
725 |
-
// Full upsampling filter.
|
726 |
-
|
727 |
-
if (up == 2)
|
728 |
-
{
|
729 |
-
// 2 x 2-wide.
|
730 |
-
__syncthreads();
|
731 |
-
int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs.
|
732 |
-
for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4)
|
733 |
-
{
|
734 |
-
int relUpX0, relUpY0;
|
735 |
-
fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
|
736 |
-
int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up);
|
737 |
-
int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up);
|
738 |
-
int src0 = relInX0 + tileInW * relInY0;
|
739 |
-
int tap0y = (relInY0 * up + phaseInY - relUpY0);
|
740 |
-
|
741 |
-
#define X_LOOP(TAPY, PX) \
|
742 |
-
for (int sx = 0; sx < fuSize / up; sx++) \
|
743 |
-
{ \
|
744 |
-
v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
|
745 |
-
v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
|
746 |
-
v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
|
747 |
-
v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
|
748 |
-
}
|
749 |
-
|
750 |
-
vec4_t v = InternalType<T>::zero_vec4();
|
751 |
-
if (tap0y == 0 && phaseInX == 0)
|
752 |
-
#pragma unroll
|
753 |
-
for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
|
754 |
-
#pragma unroll
|
755 |
-
X_LOOP(0, 0) }
|
756 |
-
if (tap0y == 0 && phaseInX == 1)
|
757 |
-
#pragma unroll
|
758 |
-
for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
|
759 |
-
#pragma unroll
|
760 |
-
X_LOOP(0, 1) }
|
761 |
-
if (tap0y == 1 && phaseInX == 0)
|
762 |
-
#pragma unroll
|
763 |
-
for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
|
764 |
-
#pragma unroll
|
765 |
-
X_LOOP(1, 0) }
|
766 |
-
if (tap0y == 1 && phaseInX == 1)
|
767 |
-
#pragma unroll
|
768 |
-
for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
|
769 |
-
#pragma unroll
|
770 |
-
X_LOOP(1, 1) }
|
771 |
-
|
772 |
-
#undef X_LOOP
|
773 |
-
|
774 |
-
int x = tileOutX * down + relUpX0;
|
775 |
-
int y = tileOutY * down + relUpY0;
|
776 |
-
int signX = x + p.sOfs.x;
|
777 |
-
int signY = y + p.sOfs.y;
|
778 |
-
int signZ = blockIdx.z + p.blockZofs;
|
779 |
-
int signXb = signX >> 2;
|
780 |
-
index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
|
781 |
-
|
782 |
-
v.x *= (scalar_t)((float)up * (float)up * p.gain);
|
783 |
-
v.y *= (scalar_t)((float)up * (float)up * p.gain);
|
784 |
-
v.z *= (scalar_t)((float)up * (float)up * p.gain);
|
785 |
-
v.w *= (scalar_t)((float)up * (float)up * p.gain);
|
786 |
-
|
787 |
-
if (signWrite)
|
788 |
-
{
|
789 |
-
if (!enableWriteSkip)
|
790 |
-
{
|
791 |
-
// Determine and write signs.
|
792 |
-
int sx = __float_as_uint(v.x) >> 31;
|
793 |
-
int sy = __float_as_uint(v.y) >> 31;
|
794 |
-
int sz = __float_as_uint(v.z) >> 31;
|
795 |
-
int sw = __float_as_uint(v.w) >> 31;
|
796 |
-
if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
|
797 |
-
if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
|
798 |
-
if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
|
799 |
-
if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
|
800 |
-
|
801 |
-
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
|
802 |
-
{
|
803 |
-
p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
|
804 |
-
}
|
805 |
-
}
|
806 |
-
else
|
807 |
-
{
|
808 |
-
// Determine and write signs.
|
809 |
-
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
|
810 |
-
{
|
811 |
-
int sx = __float_as_uint(v.x) >> 31;
|
812 |
-
int sy = __float_as_uint(v.y) >> 31;
|
813 |
-
int sz = __float_as_uint(v.z) >> 31;
|
814 |
-
int sw = __float_as_uint(v.w) >> 31;
|
815 |
-
if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
|
816 |
-
if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
|
817 |
-
if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
|
818 |
-
if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
|
819 |
-
|
820 |
-
p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
|
821 |
-
}
|
822 |
-
else
|
823 |
-
{
|
824 |
-
// Just compute the values.
|
825 |
-
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
|
826 |
-
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
|
827 |
-
if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
|
828 |
-
if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
|
829 |
-
}
|
830 |
-
}
|
831 |
-
}
|
832 |
-
else if (signRead) // Read sign and apply.
|
833 |
-
{
|
834 |
-
if ((uint32_t)signY < p.sShape.y)
|
835 |
-
{
|
836 |
-
int s = 0;
|
837 |
-
if ((uint32_t)signXb < p.swLimit) s = p.s[si];
|
838 |
-
if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8;
|
839 |
-
s >>= (signX & 3) << 1;
|
840 |
-
if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f;
|
841 |
-
if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f;
|
842 |
-
if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f;
|
843 |
-
if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f;
|
844 |
-
}
|
845 |
-
}
|
846 |
-
else // Forward pass with no sign write.
|
847 |
-
{
|
848 |
-
if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
|
849 |
-
if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
|
850 |
-
if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
|
851 |
-
if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
|
852 |
-
}
|
853 |
-
|
854 |
-
s_tileUpXY[idx + 0] = v.x;
|
855 |
-
s_tileUpXY[idx + 1] = v.y;
|
856 |
-
s_tileUpXY[idx + 2] = v.z;
|
857 |
-
s_tileUpXY[idx + 3] = v.w;
|
858 |
-
}
|
859 |
-
}
|
860 |
-
else if (up == 1)
|
861 |
-
{
|
862 |
-
__syncthreads();
|
863 |
-
uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3);
|
864 |
-
int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
|
865 |
-
for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x)
|
866 |
-
{
|
867 |
-
int relUpX0, relUpY0;
|
868 |
-
fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
|
869 |
-
scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter.
|
870 |
-
|
871 |
-
int x = tileOutX * down + relUpX0;
|
872 |
-
int y = tileOutY * down + relUpY0;
|
873 |
-
int signX = x + p.sOfs.x;
|
874 |
-
int signY = y + p.sOfs.y;
|
875 |
-
int signZ = blockIdx.z + p.blockZofs;
|
876 |
-
int signXb = signX >> 2;
|
877 |
-
index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
|
878 |
-
v *= (scalar_t)((float)up * (float)up * p.gain);
|
879 |
-
|
880 |
-
if (signWrite)
|
881 |
-
{
|
882 |
-
if (!enableWriteSkip)
|
883 |
-
{
|
884 |
-
// Determine and write sign.
|
885 |
-
uint32_t s = 0;
|
886 |
-
uint32_t signXbit = (1u << signXo);
|
887 |
-
if (v < 0.f)
|
888 |
-
{
|
889 |
-
s = signXbit;
|
890 |
-
v *= p.slope;
|
891 |
-
}
|
892 |
-
if (fabsf(v) > p.clamp)
|
893 |
-
{
|
894 |
-
s = signXbit * 2;
|
895 |
-
v = InternalType<T>::clamp(v, p.clamp);
|
896 |
-
}
|
897 |
-
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
|
898 |
-
{
|
899 |
-
s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
|
900 |
-
s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
|
901 |
-
p.s[si] = s; // Write.
|
902 |
-
}
|
903 |
-
}
|
904 |
-
else
|
905 |
-
{
|
906 |
-
// Determine and write sign.
|
907 |
-
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
|
908 |
-
{
|
909 |
-
uint32_t s = 0;
|
910 |
-
uint32_t signXbit = (1u << signXo);
|
911 |
-
if (v < 0.f)
|
912 |
-
{
|
913 |
-
s = signXbit;
|
914 |
-
v *= p.slope;
|
915 |
-
}
|
916 |
-
if (fabsf(v) > p.clamp)
|
917 |
-
{
|
918 |
-
s = signXbit * 2;
|
919 |
-
v = InternalType<T>::clamp(v, p.clamp);
|
920 |
-
}
|
921 |
-
s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
|
922 |
-
s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
|
923 |
-
p.s[si] = s; // Write.
|
924 |
-
}
|
925 |
-
else
|
926 |
-
{
|
927 |
-
// Just compute the value.
|
928 |
-
if (v < 0.f) v *= p.slope;
|
929 |
-
v = InternalType<T>::clamp(v, p.clamp);
|
930 |
-
}
|
931 |
-
}
|
932 |
-
}
|
933 |
-
else if (signRead)
|
934 |
-
{
|
935 |
-
// Read sign and apply if within sign tensor bounds.
|
936 |
-
if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y)
|
937 |
-
{
|
938 |
-
int s = p.s[si];
|
939 |
-
s >>= signXo;
|
940 |
-
if (s & 1) v *= p.slope;
|
941 |
-
if (s & 2) v = 0.f;
|
942 |
-
}
|
943 |
-
}
|
944 |
-
else // Forward pass with no sign write.
|
945 |
-
{
|
946 |
-
if (v < 0.f) v *= p.slope;
|
947 |
-
v = InternalType<T>::clamp(v, p.clamp);
|
948 |
-
}
|
949 |
-
|
950 |
-
if (!downInline) // Write into temporary buffer.
|
951 |
-
s_tileUpXY[idx] = v;
|
952 |
-
else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer
|
953 |
-
*((T*)((char*)p.y + (x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
|
954 |
-
}
|
955 |
-
}
|
956 |
-
}
|
957 |
-
|
958 |
-
// Downsampling.
|
959 |
-
if (filterMode == MODE_SUSD || filterMode == MODE_FUSD)
|
960 |
-
{
|
961 |
-
// Horizontal downsampling.
|
962 |
-
__syncthreads();
|
963 |
-
if (down == 4 && tileOutW % 4 == 0)
|
964 |
-
{
|
965 |
-
// Calculate 4 pixels at a time.
|
966 |
-
for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4)
|
967 |
-
{
|
968 |
-
int relOutX0, relUpY;
|
969 |
-
fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
|
970 |
-
int relUpX0 = relOutX0 * down;
|
971 |
-
int src0 = relUpY * tileUpW + relUpX0;
|
972 |
-
vec4_t v = InternalType<T>::zero_vec4();
|
973 |
-
#pragma unroll
|
974 |
-
for (int step = 0; step < fdSize; step++)
|
975 |
-
{
|
976 |
-
v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
|
977 |
-
v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step];
|
978 |
-
v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step];
|
979 |
-
v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step];
|
980 |
-
}
|
981 |
-
s_tileDownX[idx+0] = v.x;
|
982 |
-
s_tileDownX[idx+1] = v.y;
|
983 |
-
s_tileDownX[idx+2] = v.z;
|
984 |
-
s_tileDownX[idx+3] = v.w;
|
985 |
-
}
|
986 |
-
}
|
987 |
-
else if ((down == 2 || down == 4) && (tileOutW % 2 == 0))
|
988 |
-
{
|
989 |
-
// Calculate 2 pixels at a time.
|
990 |
-
for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2)
|
991 |
-
{
|
992 |
-
int relOutX0, relUpY;
|
993 |
-
fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
|
994 |
-
int relUpX0 = relOutX0 * down;
|
995 |
-
int src0 = relUpY * tileUpW + relUpX0;
|
996 |
-
vec2_t v = InternalType<T>::zero_vec2();
|
997 |
-
#pragma unroll
|
998 |
-
for (int step = 0; step < fdSize; step++)
|
999 |
-
{
|
1000 |
-
v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
|
1001 |
-
v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step];
|
1002 |
-
}
|
1003 |
-
s_tileDownX[idx+0] = v.x;
|
1004 |
-
s_tileDownX[idx+1] = v.y;
|
1005 |
-
}
|
1006 |
-
}
|
1007 |
-
else
|
1008 |
-
{
|
1009 |
-
// Calculate 1 pixel at a time.
|
1010 |
-
for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x)
|
1011 |
-
{
|
1012 |
-
int relOutX0, relUpY;
|
1013 |
-
fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
|
1014 |
-
int relUpX0 = relOutX0 * down;
|
1015 |
-
int src = relUpY * tileUpW + relUpX0;
|
1016 |
-
scalar_t v = 0.f;
|
1017 |
-
#pragma unroll
|
1018 |
-
for (int step = 0; step < fdSize; step++)
|
1019 |
-
v += s_tileUpXY[src + step] * (scalar_t)c_fd[step];
|
1020 |
-
s_tileDownX[idx] = v;
|
1021 |
-
}
|
1022 |
-
}
|
1023 |
-
|
1024 |
-
// Vertical downsampling & store output tile.
|
1025 |
-
__syncthreads();
|
1026 |
-
for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
|
1027 |
-
{
|
1028 |
-
int relOutX, relOutY0;
|
1029 |
-
fast_div_mod<tileOutW>(relOutX, relOutY0, idx);
|
1030 |
-
int relUpY0 = relOutY0 * down;
|
1031 |
-
int src0 = relUpY0 * tileOutW + relOutX;
|
1032 |
-
scalar_t v = 0;
|
1033 |
-
#pragma unroll
|
1034 |
-
for (int step = 0; step < fdSize; step++)
|
1035 |
-
v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step];
|
1036 |
-
|
1037 |
-
int outX = tileOutX + relOutX;
|
1038 |
-
int outY = tileOutY + relOutY0;
|
1039 |
-
|
1040 |
-
if (outX < p.yShape.x & outY < p.yShape.y)
|
1041 |
-
*((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
|
1042 |
-
}
|
1043 |
-
}
|
1044 |
-
else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD)
|
1045 |
-
{
|
1046 |
-
// Full downsampling filter.
|
1047 |
-
if (down == 2)
|
1048 |
-
{
|
1049 |
-
// 2-wide.
|
1050 |
-
__syncthreads();
|
1051 |
-
for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2)
|
1052 |
-
{
|
1053 |
-
int relOutX0, relOutY0;
|
1054 |
-
fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
|
1055 |
-
int relUpX0 = relOutX0 * down;
|
1056 |
-
int relUpY0 = relOutY0 * down;
|
1057 |
-
int src0 = relUpY0 * tileUpW + relUpX0;
|
1058 |
-
vec2_t v = InternalType<T>::zero_vec2();
|
1059 |
-
#pragma unroll
|
1060 |
-
for (int sy = 0; sy < fdSize; sy++)
|
1061 |
-
#pragma unroll
|
1062 |
-
for (int sx = 0; sx < fdSize; sx++)
|
1063 |
-
{
|
1064 |
-
v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
|
1065 |
-
v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
|
1066 |
-
}
|
1067 |
-
|
1068 |
-
int outX = tileOutX + relOutX0;
|
1069 |
-
int outY = tileOutY + relOutY0;
|
1070 |
-
if ((uint32_t)outY < p.yShape.y)
|
1071 |
-
{
|
1072 |
-
index_t ofs = outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut;
|
1073 |
-
if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x;
|
1074 |
-
if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.x))) = (T)v.y;
|
1075 |
-
}
|
1076 |
-
}
|
1077 |
-
}
|
1078 |
-
else if (down == 1 && !downInline)
|
1079 |
-
{
|
1080 |
-
// Thread per pixel.
|
1081 |
-
__syncthreads();
|
1082 |
-
for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
|
1083 |
-
{
|
1084 |
-
int relOutX0, relOutY0;
|
1085 |
-
fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
|
1086 |
-
scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter.
|
1087 |
-
|
1088 |
-
int outX = tileOutX + relOutX0;
|
1089 |
-
int outY = tileOutY + relOutY0;
|
1090 |
-
if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)
|
1091 |
-
*((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
|
1092 |
-
}
|
1093 |
-
}
|
1094 |
-
}
|
1095 |
-
|
1096 |
-
if (!enableXrep)
|
1097 |
-
break;
|
1098 |
-
}
|
1099 |
-
}
|
1100 |
-
|
1101 |
-
//------------------------------------------------------------------------
|
1102 |
-
// Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant.
|
1103 |
-
// Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used.
|
1104 |
-
|
1105 |
-
template <class T, bool signWrite, bool signRead>
|
1106 |
-
static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p)
|
1107 |
-
{
|
1108 |
-
typedef typename InternalType<T>::scalar_t scalar_t;
|
1109 |
-
|
1110 |
-
// Indexing.
|
1111 |
-
int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
|
1112 |
-
int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;
|
1113 |
-
int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.
|
1114 |
-
|
1115 |
-
// Loop to accommodate oversized tensors.
|
1116 |
-
for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
|
1117 |
-
for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y)
|
1118 |
-
{
|
1119 |
-
// Extract z and w (channel, minibatch index).
|
1120 |
-
int32_t w = q / p.xShape.z;
|
1121 |
-
int32_t z = q - w * p.xShape.z;
|
1122 |
-
|
1123 |
-
// Choose behavior based on sign read/write mode.
|
1124 |
-
if (signWrite)
|
1125 |
-
{
|
1126 |
-
// Process value if in p.x.
|
1127 |
-
uint32_t s = 0;
|
1128 |
-
if (x < p.xShape.x && y < p.xShape.y)
|
1129 |
-
{
|
1130 |
-
int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
|
1131 |
-
T* pv = ((T*)p.x) + ix;
|
1132 |
-
scalar_t v = (scalar_t)(*pv);
|
1133 |
-
|
1134 |
-
// Gain, LReLU, clamp.
|
1135 |
-
v *= p.gain;
|
1136 |
-
if (v < 0.f)
|
1137 |
-
{
|
1138 |
-
v *= p.slope;
|
1139 |
-
s = 1; // Sign.
|
1140 |
-
}
|
1141 |
-
if (fabsf(v) > p.clamp)
|
1142 |
-
{
|
1143 |
-
v = InternalType<T>::clamp(v, p.clamp);
|
1144 |
-
s = 2; // Clamp.
|
1145 |
-
}
|
1146 |
-
|
1147 |
-
*pv = (T)v; // Write value.
|
1148 |
-
}
|
1149 |
-
|
1150 |
-
// Coalesce into threads 0 and 16 of warp.
|
1151 |
-
uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
|
1152 |
-
s <<= ((threadIdx.x & 15) << 1); // Shift into place.
|
1153 |
-
s |= __shfl_xor_sync(m, s, 1); // Distribute.
|
1154 |
-
s |= __shfl_xor_sync(m, s, 2);
|
1155 |
-
s |= __shfl_xor_sync(m, s, 4);
|
1156 |
-
s |= __shfl_xor_sync(m, s, 8);
|
1157 |
-
|
1158 |
-
// Write signs if leader and in p.s.
|
1159 |
-
if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
|
1160 |
-
{
|
1161 |
-
uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
|
1162 |
-
((uint32_t*)p.s)[is >> 4] = s;
|
1163 |
-
}
|
1164 |
-
}
|
1165 |
-
else if (signRead)
|
1166 |
-
{
|
1167 |
-
// Process value if in p.x.
|
1168 |
-
if (x < p.xShape.x) // y is always in.
|
1169 |
-
{
|
1170 |
-
int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
|
1171 |
-
T* pv = ((T*)p.x) + ix;
|
1172 |
-
scalar_t v = (scalar_t)(*pv);
|
1173 |
-
v *= p.gain;
|
1174 |
-
|
1175 |
-
// Apply sign buffer offset.
|
1176 |
-
uint32_t sx = x + p.sOfs.x;
|
1177 |
-
uint32_t sy = y + p.sOfs.y;
|
1178 |
-
|
1179 |
-
// Read and apply signs if we land inside valid region of sign buffer.
|
1180 |
-
if (sx < p.sShape.x && sy < p.sShape.y)
|
1181 |
-
{
|
1182 |
-
uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous.
|
1183 |
-
unsigned char s = p.s[is];
|
1184 |
-
s >>= (sx & 3) << 1; // Shift into place.
|
1185 |
-
if (s & 1) // Sign?
|
1186 |
-
v *= p.slope;
|
1187 |
-
if (s & 2) // Clamp?
|
1188 |
-
v = 0.f;
|
1189 |
-
}
|
1190 |
-
|
1191 |
-
*pv = (T)v; // Write value.
|
1192 |
-
}
|
1193 |
-
}
|
1194 |
-
else
|
1195 |
-
{
|
1196 |
-
// Forward pass with no sign write. Process value if in p.x.
|
1197 |
-
if (x < p.xShape.x) // y is always in.
|
1198 |
-
{
|
1199 |
-
int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
|
1200 |
-
T* pv = ((T*)p.x) + ix;
|
1201 |
-
scalar_t v = (scalar_t)(*pv);
|
1202 |
-
v *= p.gain;
|
1203 |
-
if (v < 0.f)
|
1204 |
-
v *= p.slope;
|
1205 |
-
if (fabsf(v) > p.clamp)
|
1206 |
-
v = InternalType<T>::clamp(v, p.clamp);
|
1207 |
-
*pv = (T)v; // Write value.
|
1208 |
-
}
|
1209 |
-
}
|
1210 |
-
}
|
1211 |
-
}
|
1212 |
-
|
1213 |
-
template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void)
|
1214 |
-
{
|
1215 |
-
return (void*)filtered_lrelu_act_kernel<T, signWrite, signRead>;
|
1216 |
-
}
|
1217 |
-
|
1218 |
-
//------------------------------------------------------------------------
|
1219 |
-
// CUDA kernel selection.
|
1220 |
-
|
1221 |
-
template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB)
|
1222 |
-
{
|
1223 |
-
filtered_lrelu_kernel_spec s = { 0 };
|
1224 |
-
|
1225 |
-
// Return the first matching kernel.
|
1226 |
-
#define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \
|
1227 |
-
if (sharedKB >= SH) \
|
1228 |
-
if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \
|
1229 |
-
if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \
|
1230 |
-
if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \
|
1231 |
-
{ \
|
1232 |
-
static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \
|
1233 |
-
static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \
|
1234 |
-
static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \
|
1235 |
-
s.setup = (void*)setup_filters_kernel; \
|
1236 |
-
s.exec = (void*)filtered_lrelu_kernel<T, index_t, SH, signWrite, signRead, MODE, U, FU, D, FD, TW, TH, W*32, !!XR, !!WS>; \
|
1237 |
-
s.tileOut = make_int2(TW, TH); \
|
1238 |
-
s.numWarps = W; \
|
1239 |
-
s.xrep = XR; \
|
1240 |
-
s.dynamicSharedKB = (SH == 48) ? 0 : SH; \
|
1241 |
-
return s; \
|
1242 |
-
}
|
1243 |
-
|
1244 |
-
// Launch parameters for various kernel specializations.
|
1245 |
-
// Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first.
|
1246 |
-
// Kernels that use more shared memory must be listed before those that use less, for the same reason.
|
1247 |
-
|
1248 |
-
CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1
|
1249 |
-
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1
|
1250 |
-
CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2
|
1251 |
-
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2
|
1252 |
-
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2
|
1253 |
-
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2
|
1254 |
-
CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2
|
1255 |
-
CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2
|
1256 |
-
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4
|
1257 |
-
CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4
|
1258 |
-
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1
|
1259 |
-
CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2
|
1260 |
-
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2
|
1261 |
-
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2
|
1262 |
-
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2
|
1263 |
-
CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2
|
1264 |
-
CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2
|
1265 |
-
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4
|
1266 |
-
CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB
|
1267 |
-
CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4
|
1268 |
-
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1
|
1269 |
-
CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2
|
1270 |
-
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2
|
1271 |
-
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2
|
1272 |
-
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2
|
1273 |
-
CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2
|
1274 |
-
CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2
|
1275 |
-
CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB
|
1276 |
-
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4
|
1277 |
-
CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB
|
1278 |
-
CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4
|
1279 |
-
|
1280 |
-
#undef CASE
|
1281 |
-
return s; // No kernel found.
|
1282 |
-
}
|
1283 |
-
|
1284 |
-
//------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/ops/filtered_lrelu.h
DELETED
@@ -1,90 +0,0 @@
|
|
1 |
-
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
//
|
3 |
-
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
// and proprietary rights in and to this software, related documentation
|
5 |
-
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
// distribution of this software and related documentation without an express
|
7 |
-
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
#include <cuda_runtime.h>
|
10 |
-
|
11 |
-
//------------------------------------------------------------------------
|
12 |
-
// CUDA kernel parameters.
|
13 |
-
|
14 |
-
struct filtered_lrelu_kernel_params
|
15 |
-
{
|
16 |
-
// These parameters decide which kernel to use.
|
17 |
-
int up; // upsampling ratio (1, 2, 4)
|
18 |
-
int down; // downsampling ratio (1, 2, 4)
|
19 |
-
int2 fuShape; // [size, 1] | [size, size]
|
20 |
-
int2 fdShape; // [size, 1] | [size, size]
|
21 |
-
|
22 |
-
int _dummy; // Alignment.
|
23 |
-
|
24 |
-
// Rest of the parameters.
|
25 |
-
const void* x; // Input tensor.
|
26 |
-
void* y; // Output tensor.
|
27 |
-
const void* b; // Bias tensor.
|
28 |
-
unsigned char* s; // Sign tensor in/out. NULL if unused.
|
29 |
-
const float* fu; // Upsampling filter.
|
30 |
-
const float* fd; // Downsampling filter.
|
31 |
-
|
32 |
-
int2 pad0; // Left/top padding.
|
33 |
-
float gain; // Additional gain factor.
|
34 |
-
float slope; // Leaky ReLU slope on negative side.
|
35 |
-
float clamp; // Clamp after nonlinearity.
|
36 |
-
int flip; // Filter kernel flip for gradient computation.
|
37 |
-
|
38 |
-
int tilesXdim; // Original number of horizontal output tiles.
|
39 |
-
int tilesXrep; // Number of horizontal tiles per CTA.
|
40 |
-
int blockZofs; // Block z offset to support large minibatch, channel dimensions.
|
41 |
-
|
42 |
-
int4 xShape; // [width, height, channel, batch]
|
43 |
-
int4 yShape; // [width, height, channel, batch]
|
44 |
-
int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
|
45 |
-
int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
|
46 |
-
int swLimit; // Active width of sign tensor in bytes.
|
47 |
-
|
48 |
-
longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
|
49 |
-
longlong4 yStride; //
|
50 |
-
int64_t bStride; //
|
51 |
-
longlong3 fuStride; //
|
52 |
-
longlong3 fdStride; //
|
53 |
-
};
|
54 |
-
|
55 |
-
struct filtered_lrelu_act_kernel_params
|
56 |
-
{
|
57 |
-
void* x; // Input/output, modified in-place.
|
58 |
-
unsigned char* s; // Sign tensor in/out. NULL if unused.
|
59 |
-
|
60 |
-
float gain; // Additional gain factor.
|
61 |
-
float slope; // Leaky ReLU slope on negative side.
|
62 |
-
float clamp; // Clamp after nonlinearity.
|
63 |
-
|
64 |
-
int4 xShape; // [width, height, channel, batch]
|
65 |
-
longlong4 xStride; // Input/output tensor strides, same order as in shape.
|
66 |
-
int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
|
67 |
-
int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
|
68 |
-
};
|
69 |
-
|
70 |
-
//------------------------------------------------------------------------
|
71 |
-
// CUDA kernel specialization.
|
72 |
-
|
73 |
-
struct filtered_lrelu_kernel_spec
|
74 |
-
{
|
75 |
-
void* setup; // Function for filter kernel setup.
|
76 |
-
void* exec; // Function for main operation.
|
77 |
-
int2 tileOut; // Width/height of launch tile.
|
78 |
-
int numWarps; // Number of warps per thread block, determines launch block size.
|
79 |
-
int xrep; // For processing multiple horizontal tiles per thread block.
|
80 |
-
int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
|
81 |
-
};
|
82 |
-
|
83 |
-
//------------------------------------------------------------------------
|
84 |
-
// CUDA kernel selection.
|
85 |
-
|
86 |
-
template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
|
87 |
-
template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
|
88 |
-
template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
|
89 |
-
|
90 |
-
//------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/ops/filtered_lrelu.py
DELETED
@@ -1,282 +0,0 @@
|
|
1 |
-
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
#
|
3 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
# and proprietary rights in and to this software, related documentation
|
5 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
# distribution of this software and related documentation without an express
|
7 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
import os
|
10 |
-
import numpy as np
|
11 |
-
import torch
|
12 |
-
import warnings
|
13 |
-
|
14 |
-
from .. import custom_ops
|
15 |
-
from .. import misc
|
16 |
-
from . import upfirdn2d
|
17 |
-
from . import bias_act
|
18 |
-
|
19 |
-
#----------------------------------------------------------------------------
|
20 |
-
|
21 |
-
_plugin = None
|
22 |
-
|
23 |
-
def _init():
|
24 |
-
global _plugin
|
25 |
-
if _plugin is None:
|
26 |
-
|
27 |
-
# sources=['filtered_lrelu.h', 'filtered_lrelu.cu', 'filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu']
|
28 |
-
# sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
|
29 |
-
# try:
|
30 |
-
# _plugin = custom_ops.get_plugin('filtered_lrelu_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'])
|
31 |
-
# except:
|
32 |
-
# warnings.warn('Failed to build CUDA kernels for filtered_lrelu_plugin. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
|
33 |
-
|
34 |
-
_plugin = custom_ops.get_plugin_v3(
|
35 |
-
module_name='filtered_lrelu_plugin',
|
36 |
-
sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
|
37 |
-
headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
|
38 |
-
source_dir=os.path.dirname(__file__),
|
39 |
-
extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
|
40 |
-
)
|
41 |
-
return True
|
42 |
-
|
43 |
-
def _get_filter_size(f):
|
44 |
-
if f is None:
|
45 |
-
return 1, 1
|
46 |
-
assert isinstance(f, torch.Tensor)
|
47 |
-
assert 1 <= f.ndim <= 2
|
48 |
-
return f.shape[-1], f.shape[0] # width, height
|
49 |
-
|
50 |
-
def _parse_padding(padding):
|
51 |
-
if isinstance(padding, int):
|
52 |
-
padding = [padding, padding]
|
53 |
-
assert isinstance(padding, (list, tuple))
|
54 |
-
assert all(isinstance(x, (int, np.integer)) for x in padding)
|
55 |
-
padding = [int(x) for x in padding]
|
56 |
-
if len(padding) == 2:
|
57 |
-
px, py = padding
|
58 |
-
padding = [px, px, py, py]
|
59 |
-
px0, px1, py0, py1 = padding
|
60 |
-
return px0, px1, py0, py1
|
61 |
-
|
62 |
-
#----------------------------------------------------------------------------
|
63 |
-
|
64 |
-
def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
|
65 |
-
r"""Filtered leaky ReLU for a batch of 2D images.
|
66 |
-
|
67 |
-
Performs the following sequence of operations for each channel:
|
68 |
-
|
69 |
-
1. Add channel-specific bias if provided (`b`).
|
70 |
-
|
71 |
-
2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
|
72 |
-
|
73 |
-
3. Pad the image with the specified number of zeros on each side (`padding`).
|
74 |
-
Negative padding corresponds to cropping the image.
|
75 |
-
|
76 |
-
4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
|
77 |
-
so that the footprint of all output pixels lies within the input image.
|
78 |
-
|
79 |
-
5. Multiply each value by the provided gain factor (`gain`).
|
80 |
-
|
81 |
-
6. Apply leaky ReLU activation function to each value.
|
82 |
-
|
83 |
-
7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.
|
84 |
-
|
85 |
-
8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
|
86 |
-
it so that the footprint of all output pixels lies within the input image.
|
87 |
-
|
88 |
-
9. Downsample the image by keeping every Nth pixel (`down`).
|
89 |
-
|
90 |
-
The fused op is considerably more efficient than performing the same calculation
|
91 |
-
using standard PyTorch ops. It supports gradients of arbitrary order.
|
92 |
-
|
93 |
-
Args:
|
94 |
-
x: Float32/float16/float64 input tensor of the shape
|
95 |
-
`[batch_size, num_channels, in_height, in_width]`.
|
96 |
-
fu: Float32 upsampling FIR filter of the shape
|
97 |
-
`[filter_height, filter_width]` (non-separable),
|
98 |
-
`[filter_taps]` (separable), or
|
99 |
-
`None` (identity).
|
100 |
-
fd: Float32 downsampling FIR filter of the shape
|
101 |
-
`[filter_height, filter_width]` (non-separable),
|
102 |
-
`[filter_taps]` (separable), or
|
103 |
-
`None` (identity).
|
104 |
-
b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
|
105 |
-
as `x`. The length of vector must must match the channel dimension of `x`.
|
106 |
-
up: Integer upsampling factor (default: 1).
|
107 |
-
down: Integer downsampling factor. (default: 1).
|
108 |
-
padding: Padding with respect to the upsampled image. Can be a single number
|
109 |
-
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
110 |
-
(default: 0).
|
111 |
-
gain: Overall scaling factor for signal magnitude (default: sqrt(2)).
|
112 |
-
slope: Slope on the negative side of leaky ReLU (default: 0.2).
|
113 |
-
clamp: Maximum magnitude for leaky ReLU output (default: None).
|
114 |
-
flip_filter: False = convolution, True = correlation (default: False).
|
115 |
-
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
116 |
-
|
117 |
-
Returns:
|
118 |
-
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
119 |
-
"""
|
120 |
-
assert isinstance(x, torch.Tensor)
|
121 |
-
assert impl in ['ref', 'cuda']
|
122 |
-
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
123 |
-
return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
|
124 |
-
return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
|
125 |
-
|
126 |
-
#----------------------------------------------------------------------------
|
127 |
-
|
128 |
-
@misc.profiled_function
|
129 |
-
def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
|
130 |
-
"""Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
|
131 |
-
existing `upfirdn2n()` and `bias_act()` ops.
|
132 |
-
"""
|
133 |
-
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
134 |
-
fu_w, fu_h = _get_filter_size(fu)
|
135 |
-
fd_w, fd_h = _get_filter_size(fd)
|
136 |
-
if b is not None:
|
137 |
-
assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
|
138 |
-
misc.assert_shape(b, [x.shape[1]])
|
139 |
-
assert isinstance(up, int) and up >= 1
|
140 |
-
assert isinstance(down, int) and down >= 1
|
141 |
-
px0, px1, py0, py1 = _parse_padding(padding)
|
142 |
-
assert gain == float(gain) and gain > 0
|
143 |
-
assert slope == float(slope) and slope >= 0
|
144 |
-
assert clamp is None or (clamp == float(clamp) and clamp >= 0)
|
145 |
-
|
146 |
-
# Calculate output size.
|
147 |
-
batch_size, channels, in_h, in_w = x.shape
|
148 |
-
in_dtype = x.dtype
|
149 |
-
out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
|
150 |
-
out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down
|
151 |
-
|
152 |
-
# Compute using existing ops.
|
153 |
-
x = bias_act.bias_act(x=x, b=b) # Apply bias.
|
154 |
-
x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
|
155 |
-
x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
|
156 |
-
x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.
|
157 |
-
|
158 |
-
# Check output shape & dtype.
|
159 |
-
misc.assert_shape(x, [batch_size, channels, out_h, out_w])
|
160 |
-
assert x.dtype == in_dtype
|
161 |
-
return x
|
162 |
-
|
163 |
-
#----------------------------------------------------------------------------
|
164 |
-
|
165 |
-
_filtered_lrelu_cuda_cache = dict()
|
166 |
-
|
167 |
-
def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
|
168 |
-
"""Fast CUDA implementation of `filtered_lrelu()` using custom ops.
|
169 |
-
"""
|
170 |
-
assert isinstance(up, int) and up >= 1
|
171 |
-
assert isinstance(down, int) and down >= 1
|
172 |
-
px0, px1, py0, py1 = _parse_padding(padding)
|
173 |
-
assert gain == float(gain) and gain > 0
|
174 |
-
gain = float(gain)
|
175 |
-
assert slope == float(slope) and slope >= 0
|
176 |
-
slope = float(slope)
|
177 |
-
assert clamp is None or (clamp == float(clamp) and clamp >= 0)
|
178 |
-
clamp = float(clamp if clamp is not None else 'inf')
|
179 |
-
|
180 |
-
# Lookup from cache.
|
181 |
-
key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
|
182 |
-
if key in _filtered_lrelu_cuda_cache:
|
183 |
-
return _filtered_lrelu_cuda_cache[key]
|
184 |
-
|
185 |
-
# Forward op.
|
186 |
-
class FilteredLReluCuda(torch.autograd.Function):
|
187 |
-
@staticmethod
|
188 |
-
def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
|
189 |
-
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
190 |
-
|
191 |
-
# Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
|
192 |
-
if fu is None:
|
193 |
-
fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
194 |
-
if fd is None:
|
195 |
-
fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
196 |
-
assert 1 <= fu.ndim <= 2
|
197 |
-
assert 1 <= fd.ndim <= 2
|
198 |
-
|
199 |
-
# Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
|
200 |
-
if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
|
201 |
-
fu = fu.square()[None]
|
202 |
-
if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
|
203 |
-
fd = fd.square()[None]
|
204 |
-
|
205 |
-
# Missing sign input tensor.
|
206 |
-
if si is None:
|
207 |
-
si = torch.empty([0])
|
208 |
-
|
209 |
-
# Missing bias tensor.
|
210 |
-
if b is None:
|
211 |
-
b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)
|
212 |
-
|
213 |
-
# Construct internal sign tensor only if gradients are needed.
|
214 |
-
write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)
|
215 |
-
|
216 |
-
# Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
|
217 |
-
strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
|
218 |
-
if any(a < b for a, b in zip(strides[:-1], strides[1:])):
|
219 |
-
warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)
|
220 |
-
|
221 |
-
# Call C++/Cuda plugin if datatype is supported.
|
222 |
-
if x.dtype in [torch.float16, torch.float32]:
|
223 |
-
if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
|
224 |
-
warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
|
225 |
-
y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
|
226 |
-
else:
|
227 |
-
return_code = -1
|
228 |
-
|
229 |
-
# No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
|
230 |
-
# only the bit-packed sign tensor is retained for gradient computation.
|
231 |
-
if return_code < 0:
|
232 |
-
warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
|
233 |
-
|
234 |
-
y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
|
235 |
-
y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
|
236 |
-
so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
|
237 |
-
y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.
|
238 |
-
|
239 |
-
# Prepare for gradient computation.
|
240 |
-
ctx.save_for_backward(fu, fd, (si if si.numel() else so))
|
241 |
-
ctx.x_shape = x.shape
|
242 |
-
ctx.y_shape = y.shape
|
243 |
-
ctx.s_ofs = sx, sy
|
244 |
-
return y
|
245 |
-
|
246 |
-
@staticmethod
|
247 |
-
def backward(ctx, dy): # pylint: disable=arguments-differ
|
248 |
-
fu, fd, si = ctx.saved_tensors
|
249 |
-
_, _, xh, xw = ctx.x_shape
|
250 |
-
_, _, yh, yw = ctx.y_shape
|
251 |
-
sx, sy = ctx.s_ofs
|
252 |
-
dx = None # 0
|
253 |
-
dfu = None; assert not ctx.needs_input_grad[1]
|
254 |
-
dfd = None; assert not ctx.needs_input_grad[2]
|
255 |
-
db = None # 3
|
256 |
-
dsi = None; assert not ctx.needs_input_grad[4]
|
257 |
-
dsx = None; assert not ctx.needs_input_grad[5]
|
258 |
-
dsy = None; assert not ctx.needs_input_grad[6]
|
259 |
-
|
260 |
-
if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
|
261 |
-
pp = [
|
262 |
-
(fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
|
263 |
-
xw * up - yw * down + px0 - (up - 1),
|
264 |
-
(fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
|
265 |
-
xh * up - yh * down + py0 - (up - 1),
|
266 |
-
]
|
267 |
-
gg = gain * (up ** 2) / (down ** 2)
|
268 |
-
ff = (not flip_filter)
|
269 |
-
sx = sx - (fu.shape[-1] - 1) + px0
|
270 |
-
sy = sy - (fu.shape[0] - 1) + py0
|
271 |
-
dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)
|
272 |
-
|
273 |
-
if ctx.needs_input_grad[3]:
|
274 |
-
db = dx.sum([0, 2, 3])
|
275 |
-
|
276 |
-
return dx, dfu, dfd, db, dsi, dsx, dsy
|
277 |
-
|
278 |
-
# Add to cache.
|
279 |
-
_filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
|
280 |
-
return FilteredLReluCuda
|
281 |
-
|
282 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/ops/filtered_lrelu_ns.cu
DELETED
@@ -1,27 +0,0 @@
|
|
1 |
-
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
//
|
3 |
-
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
// and proprietary rights in and to this software, related documentation
|
5 |
-
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
// distribution of this software and related documentation without an express
|
7 |
-
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
#include "filtered_lrelu.cu"
|
10 |
-
|
11 |
-
// Template/kernel specializations for no signs mode (no gradients required).
|
12 |
-
|
13 |
-
// Full op, 32-bit indexing.
|
14 |
-
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
15 |
-
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
16 |
-
|
17 |
-
// Full op, 64-bit indexing.
|
18 |
-
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
19 |
-
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
20 |
-
|
21 |
-
// Activation/signs only for generic variant. 64-bit indexing.
|
22 |
-
template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
|
23 |
-
template void* choose_filtered_lrelu_act_kernel<float, false, false>(void);
|
24 |
-
template void* choose_filtered_lrelu_act_kernel<double, false, false>(void);
|
25 |
-
|
26 |
-
// Copy filters to constant memory.
|
27 |
-
template cudaError_t copy_filters<false, false>(cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/ops/filtered_lrelu_rd.cu
DELETED
@@ -1,27 +0,0 @@
|
|
1 |
-
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
//
|
3 |
-
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
// and proprietary rights in and to this software, related documentation
|
5 |
-
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
// distribution of this software and related documentation without an express
|
7 |
-
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
#include "filtered_lrelu.cu"
|
10 |
-
|
11 |
-
// Template/kernel specializations for sign read mode.
|
12 |
-
|
13 |
-
// Full op, 32-bit indexing.
|
14 |
-
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
15 |
-
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
16 |
-
|
17 |
-
// Full op, 64-bit indexing.
|
18 |
-
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
19 |
-
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
20 |
-
|
21 |
-
// Activation/signs only for generic variant. 64-bit indexing.
|
22 |
-
template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
|
23 |
-
template void* choose_filtered_lrelu_act_kernel<float, false, true>(void);
|
24 |
-
template void* choose_filtered_lrelu_act_kernel<double, false, true>(void);
|
25 |
-
|
26 |
-
// Copy filters to constant memory.
|
27 |
-
template cudaError_t copy_filters<false, true>(cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/ops/filtered_lrelu_wr.cu
DELETED
@@ -1,27 +0,0 @@
|
|
1 |
-
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
-
//
|
3 |
-
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
-
// and proprietary rights in and to this software, related documentation
|
5 |
-
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
-
// distribution of this software and related documentation without an express
|
7 |
-
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
-
|
9 |
-
#include "filtered_lrelu.cu"
|
10 |
-
|
11 |
-
// Template/kernel specializations for sign write mode.
|
12 |
-
|
13 |
-
// Full op, 32-bit indexing.
|
14 |
-
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
15 |
-
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
16 |
-
|
17 |
-
// Full op, 64-bit indexing.
|
18 |
-
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
19 |
-
template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
|
20 |
-
|
21 |
-
// Activation/signs only for generic variant. 64-bit indexing.
|
22 |
-
template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
|
23 |
-
template void* choose_filtered_lrelu_act_kernel<float, true, false>(void);
|
24 |
-
template void* choose_filtered_lrelu_act_kernel<double, true, false>(void);
|
25 |
-
|
26 |
-
// Copy filters to constant memory.
|
27 |
-
template cudaError_t copy_filters<true, false>(cudaStream_t stream);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/ops/fma.py
DELETED
@@ -1,62 +0,0 @@
|
|
1 |
-
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
#
|
5 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
-
# and proprietary rights in and to this software, related documentation
|
7 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
-
# distribution of this software and related documentation without an express
|
9 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
-
|
11 |
-
"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
|
12 |
-
|
13 |
-
import torch
|
14 |
-
|
15 |
-
#----------------------------------------------------------------------------
|
16 |
-
|
17 |
-
def fma(a, b, c): # => a * b + c
|
18 |
-
return _FusedMultiplyAdd.apply(a, b, c)
|
19 |
-
|
20 |
-
#----------------------------------------------------------------------------
|
21 |
-
|
22 |
-
class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
|
23 |
-
@staticmethod
|
24 |
-
def forward(ctx, a, b, c): # pylint: disable=arguments-differ
|
25 |
-
out = torch.addcmul(c, a, b)
|
26 |
-
ctx.save_for_backward(a, b)
|
27 |
-
ctx.c_shape = c.shape
|
28 |
-
return out
|
29 |
-
|
30 |
-
@staticmethod
|
31 |
-
def backward(ctx, dout): # pylint: disable=arguments-differ
|
32 |
-
a, b = ctx.saved_tensors
|
33 |
-
c_shape = ctx.c_shape
|
34 |
-
da = None
|
35 |
-
db = None
|
36 |
-
dc = None
|
37 |
-
|
38 |
-
if ctx.needs_input_grad[0]:
|
39 |
-
da = _unbroadcast(dout * b, a.shape)
|
40 |
-
|
41 |
-
if ctx.needs_input_grad[1]:
|
42 |
-
db = _unbroadcast(dout * a, b.shape)
|
43 |
-
|
44 |
-
if ctx.needs_input_grad[2]:
|
45 |
-
dc = _unbroadcast(dout, c_shape)
|
46 |
-
|
47 |
-
return da, db, dc
|
48 |
-
|
49 |
-
#----------------------------------------------------------------------------
|
50 |
-
|
51 |
-
def _unbroadcast(x, shape):
|
52 |
-
extra_dims = x.ndim - len(shape)
|
53 |
-
assert extra_dims >= 0
|
54 |
-
dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
|
55 |
-
if len(dim):
|
56 |
-
x = x.sum(dim=dim, keepdim=True)
|
57 |
-
if extra_dims:
|
58 |
-
x = x.reshape(-1, *x.shape[extra_dims+1:])
|
59 |
-
assert x.shape == shape
|
60 |
-
return x
|
61 |
-
|
62 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/ops/grid_sample_gradfix.py
DELETED
@@ -1,85 +0,0 @@
|
|
1 |
-
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
#
|
5 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
-
# and proprietary rights in and to this software, related documentation
|
7 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
-
# distribution of this software and related documentation without an express
|
9 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
-
|
11 |
-
"""Custom replacement for `torch.nn.functional.grid_sample` that
|
12 |
-
supports arbitrarily high order gradients between the input and output.
|
13 |
-
Only works on 2D images and assumes
|
14 |
-
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
|
15 |
-
|
16 |
-
import warnings
|
17 |
-
import torch
|
18 |
-
|
19 |
-
# pylint: disable=redefined-builtin
|
20 |
-
# pylint: disable=arguments-differ
|
21 |
-
# pylint: disable=protected-access
|
22 |
-
|
23 |
-
#----------------------------------------------------------------------------
|
24 |
-
|
25 |
-
enabled = False # Enable the custom op by setting this to true.
|
26 |
-
|
27 |
-
#----------------------------------------------------------------------------
|
28 |
-
|
29 |
-
def grid_sample(input, grid):
|
30 |
-
if _should_use_custom_op():
|
31 |
-
return _GridSample2dForward.apply(input, grid)
|
32 |
-
return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
33 |
-
|
34 |
-
#----------------------------------------------------------------------------
|
35 |
-
|
36 |
-
def _should_use_custom_op():
|
37 |
-
if not enabled:
|
38 |
-
return False
|
39 |
-
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
|
40 |
-
return True
|
41 |
-
warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
|
42 |
-
return False
|
43 |
-
|
44 |
-
#----------------------------------------------------------------------------
|
45 |
-
|
46 |
-
class _GridSample2dForward(torch.autograd.Function):
|
47 |
-
@staticmethod
|
48 |
-
def forward(ctx, input, grid):
|
49 |
-
assert input.ndim == 4
|
50 |
-
assert grid.ndim == 4
|
51 |
-
output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
52 |
-
ctx.save_for_backward(input, grid)
|
53 |
-
return output
|
54 |
-
|
55 |
-
@staticmethod
|
56 |
-
def backward(ctx, grad_output):
|
57 |
-
input, grid = ctx.saved_tensors
|
58 |
-
grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
|
59 |
-
return grad_input, grad_grid
|
60 |
-
|
61 |
-
#----------------------------------------------------------------------------
|
62 |
-
|
63 |
-
class _GridSample2dBackward(torch.autograd.Function):
|
64 |
-
@staticmethod
|
65 |
-
def forward(ctx, grad_output, input, grid):
|
66 |
-
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
|
67 |
-
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
|
68 |
-
ctx.save_for_backward(grid)
|
69 |
-
return grad_input, grad_grid
|
70 |
-
|
71 |
-
@staticmethod
|
72 |
-
def backward(ctx, grad2_grad_input, grad2_grad_grid):
|
73 |
-
_ = grad2_grad_grid # unused
|
74 |
-
grid, = ctx.saved_tensors
|
75 |
-
grad2_grad_output = None
|
76 |
-
grad2_input = None
|
77 |
-
grad2_grid = None
|
78 |
-
|
79 |
-
if ctx.needs_input_grad[0]:
|
80 |
-
grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
|
81 |
-
|
82 |
-
assert not ctx.needs_input_grad[2]
|
83 |
-
return grad2_grad_output, grad2_input, grad2_grid
|
84 |
-
|
85 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/ops/upfirdn2d.cpp
DELETED
@@ -1,105 +0,0 @@
|
|
1 |
-
// Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
//
|
5 |
-
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
-
// and proprietary rights in and to this software, related documentation
|
7 |
-
// and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
-
// distribution of this software and related documentation without an express
|
9 |
-
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
-
|
11 |
-
#include <torch/extension.h>
|
12 |
-
#include <ATen/cuda/CUDAContext.h>
|
13 |
-
#include <c10/cuda/CUDAGuard.h>
|
14 |
-
#include "upfirdn2d.h"
|
15 |
-
|
16 |
-
//------------------------------------------------------------------------
|
17 |
-
|
18 |
-
static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
|
19 |
-
{
|
20 |
-
// Validate arguments.
|
21 |
-
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
22 |
-
TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
|
23 |
-
TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
|
24 |
-
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
|
25 |
-
TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
|
26 |
-
TORCH_CHECK(x.dim() == 4, "x must be rank 4");
|
27 |
-
TORCH_CHECK(f.dim() == 2, "f must be rank 2");
|
28 |
-
TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
|
29 |
-
TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
|
30 |
-
TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
|
31 |
-
|
32 |
-
// Create output tensor.
|
33 |
-
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
34 |
-
int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
|
35 |
-
int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
|
36 |
-
TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
|
37 |
-
torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
|
38 |
-
TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
|
39 |
-
|
40 |
-
// Initialize CUDA kernel parameters.
|
41 |
-
upfirdn2d_kernel_params p;
|
42 |
-
p.x = x.data_ptr();
|
43 |
-
p.f = f.data_ptr<float>();
|
44 |
-
p.y = y.data_ptr();
|
45 |
-
p.up = make_int2(upx, upy);
|
46 |
-
p.down = make_int2(downx, downy);
|
47 |
-
p.pad0 = make_int2(padx0, pady0);
|
48 |
-
p.flip = (flip) ? 1 : 0;
|
49 |
-
p.gain = gain;
|
50 |
-
p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
|
51 |
-
p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
|
52 |
-
p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
|
53 |
-
p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
|
54 |
-
p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
|
55 |
-
p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
|
56 |
-
p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
|
57 |
-
p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
|
58 |
-
|
59 |
-
// Choose CUDA kernel.
|
60 |
-
upfirdn2d_kernel_spec spec;
|
61 |
-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
|
62 |
-
{
|
63 |
-
spec = choose_upfirdn2d_kernel<scalar_t>(p);
|
64 |
-
});
|
65 |
-
|
66 |
-
// Set looping options.
|
67 |
-
p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
|
68 |
-
p.loopMinor = spec.loopMinor;
|
69 |
-
p.loopX = spec.loopX;
|
70 |
-
p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
|
71 |
-
p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
|
72 |
-
|
73 |
-
// Compute grid size.
|
74 |
-
dim3 blockSize, gridSize;
|
75 |
-
if (spec.tileOutW < 0) // large
|
76 |
-
{
|
77 |
-
blockSize = dim3(4, 32, 1);
|
78 |
-
gridSize = dim3(
|
79 |
-
((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
|
80 |
-
(p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
|
81 |
-
p.launchMajor);
|
82 |
-
}
|
83 |
-
else // small
|
84 |
-
{
|
85 |
-
blockSize = dim3(256, 1, 1);
|
86 |
-
gridSize = dim3(
|
87 |
-
((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
|
88 |
-
(p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
|
89 |
-
p.launchMajor);
|
90 |
-
}
|
91 |
-
|
92 |
-
// Launch CUDA kernel.
|
93 |
-
void* args[] = {&p};
|
94 |
-
AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
|
95 |
-
return y;
|
96 |
-
}
|
97 |
-
|
98 |
-
//------------------------------------------------------------------------
|
99 |
-
|
100 |
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
101 |
-
{
|
102 |
-
m.def("upfirdn2d", &upfirdn2d);
|
103 |
-
}
|
104 |
-
|
105 |
-
//------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/ops/upfirdn2d.cu
DELETED
@@ -1,352 +0,0 @@
|
|
1 |
-
// Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
//
|
5 |
-
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
-
// and proprietary rights in and to this software, related documentation
|
7 |
-
// and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
-
// distribution of this software and related documentation without an express
|
9 |
-
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
-
|
11 |
-
#include <c10/util/Half.h>
|
12 |
-
#include "upfirdn2d.h"
|
13 |
-
|
14 |
-
//------------------------------------------------------------------------
|
15 |
-
// Helpers.
|
16 |
-
|
17 |
-
template <class T> struct InternalType;
|
18 |
-
template <> struct InternalType<double> { typedef double scalar_t; };
|
19 |
-
template <> struct InternalType<float> { typedef float scalar_t; };
|
20 |
-
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
21 |
-
|
22 |
-
static __device__ __forceinline__ int floor_div(int a, int b)
|
23 |
-
{
|
24 |
-
int t = 1 - a / b;
|
25 |
-
return (a + t * b) / b - t;
|
26 |
-
}
|
27 |
-
|
28 |
-
//------------------------------------------------------------------------
|
29 |
-
// Generic CUDA implementation for large filters.
|
30 |
-
|
31 |
-
template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
|
32 |
-
{
|
33 |
-
typedef typename InternalType<T>::scalar_t scalar_t;
|
34 |
-
|
35 |
-
// Calculate thread index.
|
36 |
-
int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
|
37 |
-
int outY = minorBase / p.launchMinor;
|
38 |
-
minorBase -= outY * p.launchMinor;
|
39 |
-
int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
|
40 |
-
int majorBase = blockIdx.z * p.loopMajor;
|
41 |
-
if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
|
42 |
-
return;
|
43 |
-
|
44 |
-
// Setup Y receptive field.
|
45 |
-
int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
|
46 |
-
int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
|
47 |
-
int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
|
48 |
-
int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
|
49 |
-
if (p.flip)
|
50 |
-
filterY = p.filterSize.y - 1 - filterY;
|
51 |
-
|
52 |
-
// Loop over major, minor, and X.
|
53 |
-
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
|
54 |
-
for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
|
55 |
-
{
|
56 |
-
int nc = major * p.sizeMinor + minor;
|
57 |
-
int n = nc / p.inSize.z;
|
58 |
-
int c = nc - n * p.inSize.z;
|
59 |
-
for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
|
60 |
-
{
|
61 |
-
// Setup X receptive field.
|
62 |
-
int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
|
63 |
-
int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
|
64 |
-
int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
|
65 |
-
int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
|
66 |
-
if (p.flip)
|
67 |
-
filterX = p.filterSize.x - 1 - filterX;
|
68 |
-
|
69 |
-
// Initialize pointers.
|
70 |
-
const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
|
71 |
-
const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
|
72 |
-
int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
|
73 |
-
int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
|
74 |
-
|
75 |
-
// Inner loop.
|
76 |
-
scalar_t v = 0;
|
77 |
-
for (int y = 0; y < h; y++)
|
78 |
-
{
|
79 |
-
for (int x = 0; x < w; x++)
|
80 |
-
{
|
81 |
-
v += (scalar_t)(*xp) * (scalar_t)(*fp);
|
82 |
-
xp += p.inStride.x;
|
83 |
-
fp += filterStepX;
|
84 |
-
}
|
85 |
-
xp += p.inStride.y - w * p.inStride.x;
|
86 |
-
fp += filterStepY - w * filterStepX;
|
87 |
-
}
|
88 |
-
|
89 |
-
// Store result.
|
90 |
-
v *= p.gain;
|
91 |
-
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
|
92 |
-
}
|
93 |
-
}
|
94 |
-
}
|
95 |
-
|
96 |
-
//------------------------------------------------------------------------
|
97 |
-
// Specialized CUDA implementation for small filters.
|
98 |
-
|
99 |
-
template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
|
100 |
-
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
|
101 |
-
{
|
102 |
-
typedef typename InternalType<T>::scalar_t scalar_t;
|
103 |
-
const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
|
104 |
-
const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
|
105 |
-
__shared__ volatile scalar_t sf[filterH][filterW];
|
106 |
-
__shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
|
107 |
-
|
108 |
-
// Calculate tile index.
|
109 |
-
int minorBase = blockIdx.x;
|
110 |
-
int tileOutY = minorBase / p.launchMinor;
|
111 |
-
minorBase -= tileOutY * p.launchMinor;
|
112 |
-
minorBase *= loopMinor;
|
113 |
-
tileOutY *= tileOutH;
|
114 |
-
int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
|
115 |
-
int majorBase = blockIdx.z * p.loopMajor;
|
116 |
-
if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
|
117 |
-
return;
|
118 |
-
|
119 |
-
// Load filter (flipped).
|
120 |
-
for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
|
121 |
-
{
|
122 |
-
int fy = tapIdx / filterW;
|
123 |
-
int fx = tapIdx - fy * filterW;
|
124 |
-
scalar_t v = 0;
|
125 |
-
if (fx < p.filterSize.x & fy < p.filterSize.y)
|
126 |
-
{
|
127 |
-
int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
|
128 |
-
int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
|
129 |
-
v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
|
130 |
-
}
|
131 |
-
sf[fy][fx] = v;
|
132 |
-
}
|
133 |
-
|
134 |
-
// Loop over major and X.
|
135 |
-
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
|
136 |
-
{
|
137 |
-
int baseNC = major * p.sizeMinor + minorBase;
|
138 |
-
int n = baseNC / p.inSize.z;
|
139 |
-
int baseC = baseNC - n * p.inSize.z;
|
140 |
-
for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
|
141 |
-
{
|
142 |
-
// Load input pixels.
|
143 |
-
int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
|
144 |
-
int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
|
145 |
-
int tileInX = floor_div(tileMidX, upx);
|
146 |
-
int tileInY = floor_div(tileMidY, upy);
|
147 |
-
__syncthreads();
|
148 |
-
for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
|
149 |
-
{
|
150 |
-
int relC = inIdx;
|
151 |
-
int relInX = relC / loopMinor;
|
152 |
-
int relInY = relInX / tileInW;
|
153 |
-
relC -= relInX * loopMinor;
|
154 |
-
relInX -= relInY * tileInW;
|
155 |
-
int c = baseC + relC;
|
156 |
-
int inX = tileInX + relInX;
|
157 |
-
int inY = tileInY + relInY;
|
158 |
-
scalar_t v = 0;
|
159 |
-
if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
|
160 |
-
v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
|
161 |
-
sx[relInY][relInX][relC] = v;
|
162 |
-
}
|
163 |
-
|
164 |
-
// Loop over output pixels.
|
165 |
-
__syncthreads();
|
166 |
-
for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
|
167 |
-
{
|
168 |
-
int relC = outIdx;
|
169 |
-
int relOutX = relC / loopMinor;
|
170 |
-
int relOutY = relOutX / tileOutW;
|
171 |
-
relC -= relOutX * loopMinor;
|
172 |
-
relOutX -= relOutY * tileOutW;
|
173 |
-
int c = baseC + relC;
|
174 |
-
int outX = tileOutX + relOutX;
|
175 |
-
int outY = tileOutY + relOutY;
|
176 |
-
|
177 |
-
// Setup receptive field.
|
178 |
-
int midX = tileMidX + relOutX * downx;
|
179 |
-
int midY = tileMidY + relOutY * downy;
|
180 |
-
int inX = floor_div(midX, upx);
|
181 |
-
int inY = floor_div(midY, upy);
|
182 |
-
int relInX = inX - tileInX;
|
183 |
-
int relInY = inY - tileInY;
|
184 |
-
int filterX = (inX + 1) * upx - midX - 1; // flipped
|
185 |
-
int filterY = (inY + 1) * upy - midY - 1; // flipped
|
186 |
-
|
187 |
-
// Inner loop.
|
188 |
-
if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
|
189 |
-
{
|
190 |
-
scalar_t v = 0;
|
191 |
-
#pragma unroll
|
192 |
-
for (int y = 0; y < filterH / upy; y++)
|
193 |
-
#pragma unroll
|
194 |
-
for (int x = 0; x < filterW / upx; x++)
|
195 |
-
v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
|
196 |
-
v *= p.gain;
|
197 |
-
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
|
198 |
-
}
|
199 |
-
}
|
200 |
-
}
|
201 |
-
}
|
202 |
-
}
|
203 |
-
|
204 |
-
//------------------------------------------------------------------------
|
205 |
-
// CUDA kernel selection.
|
206 |
-
|
207 |
-
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
|
208 |
-
{
|
209 |
-
int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
|
210 |
-
|
211 |
-
upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
|
212 |
-
if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
|
213 |
-
|
214 |
-
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
|
215 |
-
{
|
216 |
-
if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
|
217 |
-
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
|
218 |
-
if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
|
219 |
-
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
|
220 |
-
if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
|
221 |
-
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
|
222 |
-
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
|
223 |
-
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
|
224 |
-
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
|
225 |
-
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
|
226 |
-
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
|
227 |
-
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
|
228 |
-
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
|
229 |
-
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
|
230 |
-
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
|
231 |
-
}
|
232 |
-
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
|
233 |
-
{
|
234 |
-
if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
|
235 |
-
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
236 |
-
if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
237 |
-
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
238 |
-
if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
239 |
-
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
|
240 |
-
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
|
241 |
-
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
|
242 |
-
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
|
243 |
-
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
|
244 |
-
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
|
245 |
-
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
|
246 |
-
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
|
247 |
-
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
|
248 |
-
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
|
249 |
-
}
|
250 |
-
if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
|
251 |
-
{
|
252 |
-
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
|
253 |
-
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
|
254 |
-
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
|
255 |
-
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
|
256 |
-
}
|
257 |
-
if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
|
258 |
-
{
|
259 |
-
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
|
260 |
-
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
|
261 |
-
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
|
262 |
-
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
|
263 |
-
}
|
264 |
-
if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
|
265 |
-
{
|
266 |
-
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
|
267 |
-
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
|
268 |
-
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
|
269 |
-
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
|
270 |
-
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
|
271 |
-
}
|
272 |
-
if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
|
273 |
-
{
|
274 |
-
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
|
275 |
-
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
|
276 |
-
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
|
277 |
-
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
|
278 |
-
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
|
279 |
-
}
|
280 |
-
if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
|
281 |
-
{
|
282 |
-
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
|
283 |
-
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
|
284 |
-
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
|
285 |
-
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
|
286 |
-
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
|
287 |
-
}
|
288 |
-
if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
|
289 |
-
{
|
290 |
-
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
|
291 |
-
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
|
292 |
-
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
|
293 |
-
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
|
294 |
-
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
|
295 |
-
}
|
296 |
-
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
|
297 |
-
{
|
298 |
-
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
|
299 |
-
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
|
300 |
-
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
|
301 |
-
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
|
302 |
-
}
|
303 |
-
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
|
304 |
-
{
|
305 |
-
if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
|
306 |
-
if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
|
307 |
-
if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
|
308 |
-
if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
|
309 |
-
}
|
310 |
-
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
|
311 |
-
{
|
312 |
-
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
|
313 |
-
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
|
314 |
-
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
|
315 |
-
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
|
316 |
-
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
|
317 |
-
}
|
318 |
-
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
|
319 |
-
{
|
320 |
-
if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
|
321 |
-
if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
|
322 |
-
if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
|
323 |
-
if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
|
324 |
-
if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
|
325 |
-
}
|
326 |
-
if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
|
327 |
-
{
|
328 |
-
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
|
329 |
-
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
|
330 |
-
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
|
331 |
-
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
|
332 |
-
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
|
333 |
-
}
|
334 |
-
if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
|
335 |
-
{
|
336 |
-
if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
|
337 |
-
if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
|
338 |
-
if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
|
339 |
-
if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
|
340 |
-
if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
|
341 |
-
}
|
342 |
-
return spec;
|
343 |
-
}
|
344 |
-
|
345 |
-
//------------------------------------------------------------------------
|
346 |
-
// Template specializations.
|
347 |
-
|
348 |
-
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
|
349 |
-
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
|
350 |
-
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
|
351 |
-
|
352 |
-
//------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/ops/upfirdn2d.h
DELETED
@@ -1,61 +0,0 @@
|
|
1 |
-
// Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
//
|
5 |
-
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
-
// and proprietary rights in and to this software, related documentation
|
7 |
-
// and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
-
// distribution of this software and related documentation without an express
|
9 |
-
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
-
|
11 |
-
#include <cuda_runtime.h>
|
12 |
-
|
13 |
-
//------------------------------------------------------------------------
|
14 |
-
// CUDA kernel parameters.
|
15 |
-
|
16 |
-
struct upfirdn2d_kernel_params
|
17 |
-
{
|
18 |
-
const void* x;
|
19 |
-
const float* f;
|
20 |
-
void* y;
|
21 |
-
|
22 |
-
int2 up;
|
23 |
-
int2 down;
|
24 |
-
int2 pad0;
|
25 |
-
int flip;
|
26 |
-
float gain;
|
27 |
-
|
28 |
-
int4 inSize; // [width, height, channel, batch]
|
29 |
-
int4 inStride;
|
30 |
-
int2 filterSize; // [width, height]
|
31 |
-
int2 filterStride;
|
32 |
-
int4 outSize; // [width, height, channel, batch]
|
33 |
-
int4 outStride;
|
34 |
-
int sizeMinor;
|
35 |
-
int sizeMajor;
|
36 |
-
|
37 |
-
int loopMinor;
|
38 |
-
int loopMajor;
|
39 |
-
int loopX;
|
40 |
-
int launchMinor;
|
41 |
-
int launchMajor;
|
42 |
-
};
|
43 |
-
|
44 |
-
//------------------------------------------------------------------------
|
45 |
-
// CUDA kernel specialization.
|
46 |
-
|
47 |
-
struct upfirdn2d_kernel_spec
|
48 |
-
{
|
49 |
-
void* kernel;
|
50 |
-
int tileOutW;
|
51 |
-
int tileOutH;
|
52 |
-
int loopMinor;
|
53 |
-
int loopX;
|
54 |
-
};
|
55 |
-
|
56 |
-
//------------------------------------------------------------------------
|
57 |
-
// CUDA kernel selection.
|
58 |
-
|
59 |
-
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
|
60 |
-
|
61 |
-
//------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/ops/upfirdn2d.py
DELETED
@@ -1,386 +0,0 @@
|
|
1 |
-
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
#
|
5 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
-
# and proprietary rights in and to this software, related documentation
|
7 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
-
# distribution of this software and related documentation without an express
|
9 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
-
|
11 |
-
"""Custom PyTorch ops for efficient resampling of 2D images."""
|
12 |
-
|
13 |
-
import os
|
14 |
-
import warnings
|
15 |
-
import numpy as np
|
16 |
-
import torch
|
17 |
-
import traceback
|
18 |
-
|
19 |
-
from .. import custom_ops
|
20 |
-
from .. import misc
|
21 |
-
from . import conv2d_gradfix
|
22 |
-
|
23 |
-
#----------------------------------------------------------------------------
|
24 |
-
|
25 |
-
_inited = False
|
26 |
-
_plugin = None
|
27 |
-
|
28 |
-
def _init():
|
29 |
-
global _inited, _plugin
|
30 |
-
if not _inited:
|
31 |
-
sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
|
32 |
-
sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
|
33 |
-
try:
|
34 |
-
_plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
|
35 |
-
except:
|
36 |
-
warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
|
37 |
-
return _plugin is not None
|
38 |
-
|
39 |
-
def _parse_scaling(scaling):
|
40 |
-
if isinstance(scaling, int):
|
41 |
-
scaling = [scaling, scaling]
|
42 |
-
assert isinstance(scaling, (list, tuple))
|
43 |
-
assert all(isinstance(x, int) for x in scaling)
|
44 |
-
sx, sy = scaling
|
45 |
-
assert sx >= 1 and sy >= 1
|
46 |
-
return sx, sy
|
47 |
-
|
48 |
-
def _parse_padding(padding):
|
49 |
-
if isinstance(padding, int):
|
50 |
-
padding = [padding, padding]
|
51 |
-
assert isinstance(padding, (list, tuple))
|
52 |
-
assert all(isinstance(x, int) for x in padding)
|
53 |
-
if len(padding) == 2:
|
54 |
-
padx, pady = padding
|
55 |
-
padding = [padx, padx, pady, pady]
|
56 |
-
padx0, padx1, pady0, pady1 = padding
|
57 |
-
return padx0, padx1, pady0, pady1
|
58 |
-
|
59 |
-
def _get_filter_size(f):
|
60 |
-
if f is None:
|
61 |
-
return 1, 1
|
62 |
-
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
63 |
-
fw = f.shape[-1]
|
64 |
-
fh = f.shape[0]
|
65 |
-
with misc.suppress_tracer_warnings():
|
66 |
-
fw = int(fw)
|
67 |
-
fh = int(fh)
|
68 |
-
misc.assert_shape(f, [fh, fw][:f.ndim])
|
69 |
-
assert fw >= 1 and fh >= 1
|
70 |
-
return fw, fh
|
71 |
-
|
72 |
-
#----------------------------------------------------------------------------
|
73 |
-
|
74 |
-
def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
|
75 |
-
r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
|
76 |
-
|
77 |
-
Args:
|
78 |
-
f: Torch tensor, numpy array, or python list of the shape
|
79 |
-
`[filter_height, filter_width]` (non-separable),
|
80 |
-
`[filter_taps]` (separable),
|
81 |
-
`[]` (impulse), or
|
82 |
-
`None` (identity).
|
83 |
-
device: Result device (default: cpu).
|
84 |
-
normalize: Normalize the filter so that it retains the magnitude
|
85 |
-
for constant input signal (DC)? (default: True).
|
86 |
-
flip_filter: Flip the filter? (default: False).
|
87 |
-
gain: Overall scaling factor for signal magnitude (default: 1).
|
88 |
-
separable: Return a separable filter? (default: select automatically).
|
89 |
-
|
90 |
-
Returns:
|
91 |
-
Float32 tensor of the shape
|
92 |
-
`[filter_height, filter_width]` (non-separable) or
|
93 |
-
`[filter_taps]` (separable).
|
94 |
-
"""
|
95 |
-
# Validate.
|
96 |
-
if f is None:
|
97 |
-
f = 1
|
98 |
-
f = torch.as_tensor(f, dtype=torch.float32)
|
99 |
-
assert f.ndim in [0, 1, 2]
|
100 |
-
assert f.numel() > 0
|
101 |
-
if f.ndim == 0:
|
102 |
-
f = f[np.newaxis]
|
103 |
-
|
104 |
-
# Separable?
|
105 |
-
if separable is None:
|
106 |
-
separable = (f.ndim == 1 and f.numel() >= 8)
|
107 |
-
if f.ndim == 1 and not separable:
|
108 |
-
f = f.ger(f)
|
109 |
-
assert f.ndim == (1 if separable else 2)
|
110 |
-
|
111 |
-
# Apply normalize, flip, gain, and device.
|
112 |
-
if normalize:
|
113 |
-
f /= f.sum()
|
114 |
-
if flip_filter:
|
115 |
-
f = f.flip(list(range(f.ndim)))
|
116 |
-
f = f * (gain ** (f.ndim / 2))
|
117 |
-
f = f.to(device=device)
|
118 |
-
return f
|
119 |
-
|
120 |
-
#----------------------------------------------------------------------------
|
121 |
-
|
122 |
-
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
123 |
-
r"""Pad, upsample, filter, and downsample a batch of 2D images.
|
124 |
-
|
125 |
-
Performs the following sequence of operations for each channel:
|
126 |
-
|
127 |
-
1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
|
128 |
-
|
129 |
-
2. Pad the image with the specified number of zeros on each side (`padding`).
|
130 |
-
Negative padding corresponds to cropping the image.
|
131 |
-
|
132 |
-
3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
|
133 |
-
so that the footprint of all output pixels lies within the input image.
|
134 |
-
|
135 |
-
4. Downsample the image by keeping every Nth pixel (`down`).
|
136 |
-
|
137 |
-
This sequence of operations bears close resemblance to scipy.signal.upfirdn().
|
138 |
-
The fused op is considerably more efficient than performing the same calculation
|
139 |
-
using standard PyTorch ops. It supports gradients of arbitrary order.
|
140 |
-
|
141 |
-
Args:
|
142 |
-
x: Float32/float64/float16 input tensor of the shape
|
143 |
-
`[batch_size, num_channels, in_height, in_width]`.
|
144 |
-
f: Float32 FIR filter of the shape
|
145 |
-
`[filter_height, filter_width]` (non-separable),
|
146 |
-
`[filter_taps]` (separable), or
|
147 |
-
`None` (identity).
|
148 |
-
up: Integer upsampling factor. Can be a single int or a list/tuple
|
149 |
-
`[x, y]` (default: 1).
|
150 |
-
down: Integer downsampling factor. Can be a single int or a list/tuple
|
151 |
-
`[x, y]` (default: 1).
|
152 |
-
padding: Padding with respect to the upsampled image. Can be a single number
|
153 |
-
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
154 |
-
(default: 0).
|
155 |
-
flip_filter: False = convolution, True = correlation (default: False).
|
156 |
-
gain: Overall scaling factor for signal magnitude (default: 1).
|
157 |
-
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
158 |
-
|
159 |
-
Returns:
|
160 |
-
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
161 |
-
"""
|
162 |
-
assert isinstance(x, torch.Tensor)
|
163 |
-
assert impl in ['ref', 'cuda']
|
164 |
-
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
165 |
-
return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
|
166 |
-
return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
|
167 |
-
|
168 |
-
#----------------------------------------------------------------------------
|
169 |
-
|
170 |
-
@misc.profiled_function
|
171 |
-
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
|
172 |
-
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
|
173 |
-
"""
|
174 |
-
# Validate arguments.
|
175 |
-
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
176 |
-
if f is None:
|
177 |
-
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
178 |
-
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
179 |
-
assert f.dtype == torch.float32 and not f.requires_grad
|
180 |
-
batch_size, num_channels, in_height, in_width = x.shape
|
181 |
-
upx, upy = _parse_scaling(up)
|
182 |
-
downx, downy = _parse_scaling(down)
|
183 |
-
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
184 |
-
|
185 |
-
# Upsample by inserting zeros.
|
186 |
-
x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
|
187 |
-
x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
|
188 |
-
x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
|
189 |
-
|
190 |
-
# Pad or crop.
|
191 |
-
x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
|
192 |
-
x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
|
193 |
-
|
194 |
-
# Setup filter.
|
195 |
-
f = f * (gain ** (f.ndim / 2))
|
196 |
-
f = f.to(x.dtype)
|
197 |
-
if not flip_filter:
|
198 |
-
f = f.flip(list(range(f.ndim)))
|
199 |
-
|
200 |
-
# Convolve with the filter.
|
201 |
-
f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
|
202 |
-
if f.ndim == 4:
|
203 |
-
x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
|
204 |
-
else:
|
205 |
-
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
|
206 |
-
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
|
207 |
-
|
208 |
-
# Downsample by throwing away pixels.
|
209 |
-
x = x[:, :, ::downy, ::downx]
|
210 |
-
return x
|
211 |
-
|
212 |
-
#----------------------------------------------------------------------------
|
213 |
-
|
214 |
-
_upfirdn2d_cuda_cache = dict()
|
215 |
-
|
216 |
-
def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
|
217 |
-
"""Fast CUDA implementation of `upfirdn2d()` using custom ops.
|
218 |
-
"""
|
219 |
-
# Parse arguments.
|
220 |
-
upx, upy = _parse_scaling(up)
|
221 |
-
downx, downy = _parse_scaling(down)
|
222 |
-
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
223 |
-
|
224 |
-
# Lookup from cache.
|
225 |
-
key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
|
226 |
-
if key in _upfirdn2d_cuda_cache:
|
227 |
-
return _upfirdn2d_cuda_cache[key]
|
228 |
-
|
229 |
-
# Forward op.
|
230 |
-
class Upfirdn2dCuda(torch.autograd.Function):
|
231 |
-
@staticmethod
|
232 |
-
def forward(ctx, x, f): # pylint: disable=arguments-differ
|
233 |
-
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
234 |
-
if f is None:
|
235 |
-
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
236 |
-
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
237 |
-
y = x
|
238 |
-
if f.ndim == 2:
|
239 |
-
y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
|
240 |
-
else:
|
241 |
-
y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
|
242 |
-
y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
|
243 |
-
ctx.save_for_backward(f)
|
244 |
-
ctx.x_shape = x.shape
|
245 |
-
return y
|
246 |
-
|
247 |
-
@staticmethod
|
248 |
-
def backward(ctx, dy): # pylint: disable=arguments-differ
|
249 |
-
f, = ctx.saved_tensors
|
250 |
-
_, _, ih, iw = ctx.x_shape
|
251 |
-
_, _, oh, ow = dy.shape
|
252 |
-
fw, fh = _get_filter_size(f)
|
253 |
-
p = [
|
254 |
-
fw - padx0 - 1,
|
255 |
-
iw * upx - ow * downx + padx0 - upx + 1,
|
256 |
-
fh - pady0 - 1,
|
257 |
-
ih * upy - oh * downy + pady0 - upy + 1,
|
258 |
-
]
|
259 |
-
dx = None
|
260 |
-
df = None
|
261 |
-
|
262 |
-
if ctx.needs_input_grad[0]:
|
263 |
-
dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
|
264 |
-
|
265 |
-
assert not ctx.needs_input_grad[1]
|
266 |
-
return dx, df
|
267 |
-
|
268 |
-
# Add to cache.
|
269 |
-
_upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
|
270 |
-
return Upfirdn2dCuda
|
271 |
-
|
272 |
-
#----------------------------------------------------------------------------
|
273 |
-
|
274 |
-
def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
275 |
-
r"""Filter a batch of 2D images using the given 2D FIR filter.
|
276 |
-
|
277 |
-
By default, the result is padded so that its shape matches the input.
|
278 |
-
User-specified padding is applied on top of that, with negative values
|
279 |
-
indicating cropping. Pixels outside the image are assumed to be zero.
|
280 |
-
|
281 |
-
Args:
|
282 |
-
x: Float32/float64/float16 input tensor of the shape
|
283 |
-
`[batch_size, num_channels, in_height, in_width]`.
|
284 |
-
f: Float32 FIR filter of the shape
|
285 |
-
`[filter_height, filter_width]` (non-separable),
|
286 |
-
`[filter_taps]` (separable), or
|
287 |
-
`None` (identity).
|
288 |
-
padding: Padding with respect to the output. Can be a single number or a
|
289 |
-
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
290 |
-
(default: 0).
|
291 |
-
flip_filter: False = convolution, True = correlation (default: False).
|
292 |
-
gain: Overall scaling factor for signal magnitude (default: 1).
|
293 |
-
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
294 |
-
|
295 |
-
Returns:
|
296 |
-
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
297 |
-
"""
|
298 |
-
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
299 |
-
fw, fh = _get_filter_size(f)
|
300 |
-
p = [
|
301 |
-
padx0 + fw // 2,
|
302 |
-
padx1 + (fw - 1) // 2,
|
303 |
-
pady0 + fh // 2,
|
304 |
-
pady1 + (fh - 1) // 2,
|
305 |
-
]
|
306 |
-
return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
|
307 |
-
|
308 |
-
#----------------------------------------------------------------------------
|
309 |
-
|
310 |
-
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
311 |
-
r"""Upsample a batch of 2D images using the given 2D FIR filter.
|
312 |
-
|
313 |
-
By default, the result is padded so that its shape is a multiple of the input.
|
314 |
-
User-specified padding is applied on top of that, with negative values
|
315 |
-
indicating cropping. Pixels outside the image are assumed to be zero.
|
316 |
-
|
317 |
-
Args:
|
318 |
-
x: Float32/float64/float16 input tensor of the shape
|
319 |
-
`[batch_size, num_channels, in_height, in_width]`.
|
320 |
-
f: Float32 FIR filter of the shape
|
321 |
-
`[filter_height, filter_width]` (non-separable),
|
322 |
-
`[filter_taps]` (separable), or
|
323 |
-
`None` (identity).
|
324 |
-
up: Integer upsampling factor. Can be a single int or a list/tuple
|
325 |
-
`[x, y]` (default: 1).
|
326 |
-
padding: Padding with respect to the output. Can be a single number or a
|
327 |
-
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
328 |
-
(default: 0).
|
329 |
-
flip_filter: False = convolution, True = correlation (default: False).
|
330 |
-
gain: Overall scaling factor for signal magnitude (default: 1).
|
331 |
-
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
332 |
-
|
333 |
-
Returns:
|
334 |
-
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
335 |
-
"""
|
336 |
-
upx, upy = _parse_scaling(up)
|
337 |
-
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
338 |
-
fw, fh = _get_filter_size(f)
|
339 |
-
p = [
|
340 |
-
padx0 + (fw + upx - 1) // 2,
|
341 |
-
padx1 + (fw - upx) // 2,
|
342 |
-
pady0 + (fh + upy - 1) // 2,
|
343 |
-
pady1 + (fh - upy) // 2,
|
344 |
-
]
|
345 |
-
return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
|
346 |
-
|
347 |
-
#----------------------------------------------------------------------------
|
348 |
-
|
349 |
-
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
350 |
-
r"""Downsample a batch of 2D images using the given 2D FIR filter.
|
351 |
-
|
352 |
-
By default, the result is padded so that its shape is a fraction of the input.
|
353 |
-
User-specified padding is applied on top of that, with negative values
|
354 |
-
indicating cropping. Pixels outside the image are assumed to be zero.
|
355 |
-
|
356 |
-
Args:
|
357 |
-
x: Float32/float64/float16 input tensor of the shape
|
358 |
-
`[batch_size, num_channels, in_height, in_width]`.
|
359 |
-
f: Float32 FIR filter of the shape
|
360 |
-
`[filter_height, filter_width]` (non-separable),
|
361 |
-
`[filter_taps]` (separable), or
|
362 |
-
`None` (identity).
|
363 |
-
down: Integer downsampling factor. Can be a single int or a list/tuple
|
364 |
-
`[x, y]` (default: 1).
|
365 |
-
padding: Padding with respect to the input. Can be a single number or a
|
366 |
-
list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
367 |
-
(default: 0).
|
368 |
-
flip_filter: False = convolution, True = correlation (default: False).
|
369 |
-
gain: Overall scaling factor for signal magnitude (default: 1).
|
370 |
-
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
371 |
-
|
372 |
-
Returns:
|
373 |
-
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
374 |
-
"""
|
375 |
-
downx, downy = _parse_scaling(down)
|
376 |
-
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
377 |
-
fw, fh = _get_filter_size(f)
|
378 |
-
p = [
|
379 |
-
padx0 + (fw - downx + 1) // 2,
|
380 |
-
padx1 + (fw - downx) // 2,
|
381 |
-
pady0 + (fh - downy + 1) // 2,
|
382 |
-
pady1 + (fh - downy) // 2,
|
383 |
-
]
|
384 |
-
return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
|
385 |
-
|
386 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/persistence.py
DELETED
@@ -1,253 +0,0 @@
|
|
1 |
-
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
#
|
5 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
-
# and proprietary rights in and to this software, related documentation
|
7 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
-
# distribution of this software and related documentation without an express
|
9 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
-
|
11 |
-
"""Facilities for pickling Python code alongside other data.
|
12 |
-
|
13 |
-
The pickled code is automatically imported into a separate Python module
|
14 |
-
during unpickling. This way, any previously exported pickles will remain
|
15 |
-
usable even if the original code is no longer available, or if the current
|
16 |
-
version of the code is not consistent with what was originally pickled."""
|
17 |
-
|
18 |
-
import sys
|
19 |
-
import pickle
|
20 |
-
import io
|
21 |
-
import inspect
|
22 |
-
import copy
|
23 |
-
import uuid
|
24 |
-
import types
|
25 |
-
import dnnlib
|
26 |
-
|
27 |
-
#----------------------------------------------------------------------------
|
28 |
-
|
29 |
-
_version = 6 # internal version number
|
30 |
-
_decorators = set() # {decorator_class, ...}
|
31 |
-
_import_hooks = [] # [hook_function, ...]
|
32 |
-
_module_to_src_dict = dict() # {module: src, ...}
|
33 |
-
_src_to_module_dict = dict() # {src: module, ...}
|
34 |
-
|
35 |
-
#----------------------------------------------------------------------------
|
36 |
-
|
37 |
-
def persistent_class(orig_class):
|
38 |
-
r"""Class decorator that extends a given class to save its source code
|
39 |
-
when pickled.
|
40 |
-
|
41 |
-
Example:
|
42 |
-
|
43 |
-
from torch_utils import persistence
|
44 |
-
|
45 |
-
@persistence.persistent_class
|
46 |
-
class MyNetwork(torch.nn.Module):
|
47 |
-
def __init__(self, num_inputs, num_outputs):
|
48 |
-
super().__init__()
|
49 |
-
self.fc = MyLayer(num_inputs, num_outputs)
|
50 |
-
...
|
51 |
-
|
52 |
-
@persistence.persistent_class
|
53 |
-
class MyLayer(torch.nn.Module):
|
54 |
-
...
|
55 |
-
|
56 |
-
When pickled, any instance of `MyNetwork` and `MyLayer` will save its
|
57 |
-
source code alongside other internal state (e.g., parameters, buffers,
|
58 |
-
and submodules). This way, any previously exported pickle will remain
|
59 |
-
usable even if the class definitions have been modified or are no
|
60 |
-
longer available.
|
61 |
-
|
62 |
-
The decorator saves the source code of the entire Python module
|
63 |
-
containing the decorated class. It does *not* save the source code of
|
64 |
-
any imported modules. Thus, the imported modules must be available
|
65 |
-
during unpickling, also including `torch_utils.persistence` itself.
|
66 |
-
|
67 |
-
It is ok to call functions defined in the same module from the
|
68 |
-
decorated class. However, if the decorated class depends on other
|
69 |
-
classes defined in the same module, they must be decorated as well.
|
70 |
-
This is illustrated in the above example in the case of `MyLayer`.
|
71 |
-
|
72 |
-
It is also possible to employ the decorator just-in-time before
|
73 |
-
calling the constructor. For example:
|
74 |
-
|
75 |
-
cls = MyLayer
|
76 |
-
if want_to_make_it_persistent:
|
77 |
-
cls = persistence.persistent_class(cls)
|
78 |
-
layer = cls(num_inputs, num_outputs)
|
79 |
-
|
80 |
-
As an additional feature, the decorator also keeps track of the
|
81 |
-
arguments that were used to construct each instance of the decorated
|
82 |
-
class. The arguments can be queried via `obj.init_args` and
|
83 |
-
`obj.init_kwargs`, and they are automatically pickled alongside other
|
84 |
-
object state. A typical use case is to first unpickle a previous
|
85 |
-
instance of a persistent class, and then upgrade it to use the latest
|
86 |
-
version of the source code:
|
87 |
-
|
88 |
-
with open('old_pickle.pkl', 'rb') as f:
|
89 |
-
old_net = pickle.load(f)
|
90 |
-
new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
|
91 |
-
misc.copy_params_and_buffers(old_net, new_net, require_all=True)
|
92 |
-
"""
|
93 |
-
assert isinstance(orig_class, type)
|
94 |
-
if is_persistent(orig_class):
|
95 |
-
return orig_class
|
96 |
-
|
97 |
-
assert orig_class.__module__ in sys.modules
|
98 |
-
orig_module = sys.modules[orig_class.__module__]
|
99 |
-
orig_module_src = _module_to_src(orig_module)
|
100 |
-
|
101 |
-
class Decorator(orig_class):
|
102 |
-
_orig_module_src = orig_module_src
|
103 |
-
_orig_class_name = orig_class.__name__
|
104 |
-
|
105 |
-
def __init__(self, *args, **kwargs):
|
106 |
-
super().__init__(*args, **kwargs)
|
107 |
-
self._init_args = copy.deepcopy(args)
|
108 |
-
self._init_kwargs = copy.deepcopy(kwargs)
|
109 |
-
assert orig_class.__name__ in orig_module.__dict__
|
110 |
-
_check_pickleable(self.__reduce__())
|
111 |
-
|
112 |
-
@property
|
113 |
-
def init_args(self):
|
114 |
-
return copy.deepcopy(self._init_args)
|
115 |
-
|
116 |
-
@property
|
117 |
-
def init_kwargs(self):
|
118 |
-
return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
|
119 |
-
|
120 |
-
def __reduce__(self):
|
121 |
-
fields = list(super().__reduce__())
|
122 |
-
fields += [None] * max(3 - len(fields), 0)
|
123 |
-
if fields[0] is not _reconstruct_persistent_obj:
|
124 |
-
meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
|
125 |
-
fields[0] = _reconstruct_persistent_obj # reconstruct func
|
126 |
-
fields[1] = (meta,) # reconstruct args
|
127 |
-
fields[2] = None # state dict
|
128 |
-
return tuple(fields)
|
129 |
-
|
130 |
-
Decorator.__name__ = orig_class.__name__
|
131 |
-
_decorators.add(Decorator)
|
132 |
-
return Decorator
|
133 |
-
|
134 |
-
#----------------------------------------------------------------------------
|
135 |
-
|
136 |
-
def is_persistent(obj):
|
137 |
-
r"""Test whether the given object or class is persistent, i.e.,
|
138 |
-
whether it will save its source code when pickled.
|
139 |
-
"""
|
140 |
-
try:
|
141 |
-
if obj in _decorators:
|
142 |
-
return True
|
143 |
-
except TypeError:
|
144 |
-
pass
|
145 |
-
return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
|
146 |
-
|
147 |
-
#----------------------------------------------------------------------------
|
148 |
-
|
149 |
-
def import_hook(hook):
|
150 |
-
r"""Register an import hook that is called whenever a persistent object
|
151 |
-
is being unpickled. A typical use case is to patch the pickled source
|
152 |
-
code to avoid errors and inconsistencies when the API of some imported
|
153 |
-
module has changed.
|
154 |
-
|
155 |
-
The hook should have the following signature:
|
156 |
-
|
157 |
-
hook(meta) -> modified meta
|
158 |
-
|
159 |
-
`meta` is an instance of `dnnlib.EasyDict` with the following fields:
|
160 |
-
|
161 |
-
type: Type of the persistent object, e.g. `'class'`.
|
162 |
-
version: Internal version number of `torch_utils.persistence`.
|
163 |
-
module_src Original source code of the Python module.
|
164 |
-
class_name: Class name in the original Python module.
|
165 |
-
state: Internal state of the object.
|
166 |
-
|
167 |
-
Example:
|
168 |
-
|
169 |
-
@persistence.import_hook
|
170 |
-
def wreck_my_network(meta):
|
171 |
-
if meta.class_name == 'MyNetwork':
|
172 |
-
print('MyNetwork is being imported. I will wreck it!')
|
173 |
-
meta.module_src = meta.module_src.replace("True", "False")
|
174 |
-
return meta
|
175 |
-
"""
|
176 |
-
assert callable(hook)
|
177 |
-
_import_hooks.append(hook)
|
178 |
-
|
179 |
-
#----------------------------------------------------------------------------
|
180 |
-
|
181 |
-
def _reconstruct_persistent_obj(meta):
|
182 |
-
r"""Hook that is called internally by the `pickle` module to unpickle
|
183 |
-
a persistent object.
|
184 |
-
"""
|
185 |
-
meta = dnnlib.EasyDict(meta)
|
186 |
-
meta.state = dnnlib.EasyDict(meta.state)
|
187 |
-
for hook in _import_hooks:
|
188 |
-
meta = hook(meta)
|
189 |
-
assert meta is not None
|
190 |
-
|
191 |
-
assert meta.version == _version
|
192 |
-
module = _src_to_module(meta.module_src)
|
193 |
-
|
194 |
-
assert meta.type == 'class'
|
195 |
-
orig_class = module.__dict__[meta.class_name]
|
196 |
-
decorator_class = persistent_class(orig_class)
|
197 |
-
obj = decorator_class.__new__(decorator_class)
|
198 |
-
|
199 |
-
setstate = getattr(obj, '__setstate__', None)
|
200 |
-
if callable(setstate):
|
201 |
-
setstate(meta.state) # pylint: disable=not-callable
|
202 |
-
else:
|
203 |
-
obj.__dict__.update(meta.state)
|
204 |
-
return obj
|
205 |
-
|
206 |
-
#----------------------------------------------------------------------------
|
207 |
-
|
208 |
-
def _module_to_src(module):
|
209 |
-
r"""Query the source code of a given Python module.
|
210 |
-
"""
|
211 |
-
src = _module_to_src_dict.get(module, None)
|
212 |
-
if src is None:
|
213 |
-
src = inspect.getsource(module)
|
214 |
-
_module_to_src_dict[module] = src
|
215 |
-
_src_to_module_dict[src] = module
|
216 |
-
return src
|
217 |
-
|
218 |
-
def _src_to_module(src):
|
219 |
-
r"""Get or create a Python module for the given source code.
|
220 |
-
"""
|
221 |
-
module = _src_to_module_dict.get(src, None)
|
222 |
-
if module is None:
|
223 |
-
module_name = "_imported_module_" + uuid.uuid4().hex
|
224 |
-
module = types.ModuleType(module_name)
|
225 |
-
sys.modules[module_name] = module
|
226 |
-
_module_to_src_dict[module] = src
|
227 |
-
_src_to_module_dict[src] = module
|
228 |
-
exec(src, module.__dict__) # pylint: disable=exec-used
|
229 |
-
return module
|
230 |
-
|
231 |
-
#----------------------------------------------------------------------------
|
232 |
-
|
233 |
-
def _check_pickleable(obj):
|
234 |
-
r"""Check that the given object is pickleable, raising an exception if
|
235 |
-
it is not. This function is expected to be considerably more efficient
|
236 |
-
than actually pickling the object.
|
237 |
-
"""
|
238 |
-
def recurse(obj):
|
239 |
-
if isinstance(obj, (list, tuple, set)):
|
240 |
-
return [recurse(x) for x in obj]
|
241 |
-
if isinstance(obj, dict):
|
242 |
-
return [[recurse(x), recurse(y)] for x, y in obj.items()]
|
243 |
-
if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
|
244 |
-
return None # Python primitive types are pickleable.
|
245 |
-
if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
|
246 |
-
return None # NumPy arrays and PyTorch tensors are pickleable.
|
247 |
-
if is_persistent(obj):
|
248 |
-
return None # Persistent objects are pickleable, by virtue of the constructor check.
|
249 |
-
return obj
|
250 |
-
with io.BytesIO() as f:
|
251 |
-
pickle.dump(recurse(obj), f)
|
252 |
-
|
253 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch_utils/training_stats.py
DELETED
@@ -1,270 +0,0 @@
|
|
1 |
-
# Copyright (c) SenseTime Research. All rights reserved.
|
2 |
-
|
3 |
-
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
4 |
-
#
|
5 |
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
6 |
-
# and proprietary rights in and to this software, related documentation
|
7 |
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
8 |
-
# distribution of this software and related documentation without an express
|
9 |
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
10 |
-
|
11 |
-
"""Facilities for reporting and collecting training statistics across
|
12 |
-
multiple processes and devices. The interface is designed to minimize
|
13 |
-
synchronization overhead as well as the amount of boilerplate in user
|
14 |
-
code."""
|
15 |
-
|
16 |
-
import re
|
17 |
-
import numpy as np
|
18 |
-
import torch
|
19 |
-
import dnnlib
|
20 |
-
|
21 |
-
from . import misc
|
22 |
-
|
23 |
-
#----------------------------------------------------------------------------
|
24 |
-
|
25 |
-
_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
|
26 |
-
_reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
|
27 |
-
_counter_dtype = torch.float64 # Data type to use for the internal counters.
|
28 |
-
_rank = 0 # Rank of the current process.
|
29 |
-
_sync_device = None # Device to use for multiprocess communication. None = single-process.
|
30 |
-
_sync_called = False # Has _sync() been called yet?
|
31 |
-
_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
|
32 |
-
_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
|
33 |
-
|
34 |
-
#----------------------------------------------------------------------------
|
35 |
-
|
36 |
-
def init_multiprocessing(rank, sync_device):
|
37 |
-
r"""Initializes `torch_utils.training_stats` for collecting statistics
|
38 |
-
across multiple processes.
|
39 |
-
|
40 |
-
This function must be called after
|
41 |
-
`torch.distributed.init_process_group()` and before `Collector.update()`.
|
42 |
-
The call is not necessary if multi-process collection is not needed.
|
43 |
-
|
44 |
-
Args:
|
45 |
-
rank: Rank of the current process.
|
46 |
-
sync_device: PyTorch device to use for inter-process
|
47 |
-
communication, or None to disable multi-process
|
48 |
-
collection. Typically `torch.device('cuda', rank)`.
|
49 |
-
"""
|
50 |
-
global _rank, _sync_device
|
51 |
-
assert not _sync_called
|
52 |
-
_rank = rank
|
53 |
-
_sync_device = sync_device
|
54 |
-
|
55 |
-
#----------------------------------------------------------------------------
|
56 |
-
|
57 |
-
@misc.profiled_function
|
58 |
-
def report(name, value):
|
59 |
-
r"""Broadcasts the given set of scalars to all interested instances of
|
60 |
-
`Collector`, across device and process boundaries.
|
61 |
-
|
62 |
-
This function is expected to be extremely cheap and can be safely
|
63 |
-
called from anywhere in the training loop, loss function, or inside a
|
64 |
-
`torch.nn.Module`.
|
65 |
-
|
66 |
-
Warning: The current implementation expects the set of unique names to
|
67 |
-
be consistent across processes. Please make sure that `report()` is
|
68 |
-
called at least once for each unique name by each process, and in the
|
69 |
-
same order. If a given process has no scalars to broadcast, it can do
|
70 |
-
`report(name, [])` (empty list).
|
71 |
-
|
72 |
-
Args:
|
73 |
-
name: Arbitrary string specifying the name of the statistic.
|
74 |
-
Averages are accumulated separately for each unique name.
|
75 |
-
value: Arbitrary set of scalars. Can be a list, tuple,
|
76 |
-
NumPy array, PyTorch tensor, or Python scalar.
|
77 |
-
|
78 |
-
Returns:
|
79 |
-
The same `value` that was passed in.
|
80 |
-
"""
|
81 |
-
if name not in _counters:
|
82 |
-
_counters[name] = dict()
|
83 |
-
|
84 |
-
elems = torch.as_tensor(value)
|
85 |
-
if elems.numel() == 0:
|
86 |
-
return value
|
87 |
-
|
88 |
-
elems = elems.detach().flatten().to(_reduce_dtype)
|
89 |
-
moments = torch.stack([
|
90 |
-
torch.ones_like(elems).sum(),
|
91 |
-
elems.sum(),
|
92 |
-
elems.square().sum(),
|
93 |
-
])
|
94 |
-
assert moments.ndim == 1 and moments.shape[0] == _num_moments
|
95 |
-
moments = moments.to(_counter_dtype)
|
96 |
-
|
97 |
-
device = moments.device
|
98 |
-
if device not in _counters[name]:
|
99 |
-
_counters[name][device] = torch.zeros_like(moments)
|
100 |
-
_counters[name][device].add_(moments)
|
101 |
-
return value
|
102 |
-
|
103 |
-
#----------------------------------------------------------------------------
|
104 |
-
|
105 |
-
def report0(name, value):
|
106 |
-
r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
|
107 |
-
but ignores any scalars provided by the other processes.
|
108 |
-
See `report()` for further details.
|
109 |
-
"""
|
110 |
-
report(name, value if _rank == 0 else [])
|
111 |
-
return value
|
112 |
-
|
113 |
-
#----------------------------------------------------------------------------
|
114 |
-
|
115 |
-
class Collector:
|
116 |
-
r"""Collects the scalars broadcasted by `report()` and `report0()` and
|
117 |
-
computes their long-term averages (mean and standard deviation) over
|
118 |
-
user-defined periods of time.
|
119 |
-
|
120 |
-
The averages are first collected into internal counters that are not
|
121 |
-
directly visible to the user. They are then copied to the user-visible
|
122 |
-
state as a result of calling `update()` and can then be queried using
|
123 |
-
`mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
|
124 |
-
internal counters for the next round, so that the user-visible state
|
125 |
-
effectively reflects averages collected between the last two calls to
|
126 |
-
`update()`.
|
127 |
-
|
128 |
-
Args:
|
129 |
-
regex: Regular expression defining which statistics to
|
130 |
-
collect. The default is to collect everything.
|
131 |
-
keep_previous: Whether to retain the previous averages if no
|
132 |
-
scalars were collected on a given round
|
133 |
-
(default: True).
|
134 |
-
"""
|
135 |
-
def __init__(self, regex='.*', keep_previous=True):
|
136 |
-
self._regex = re.compile(regex)
|
137 |
-
self._keep_previous = keep_previous
|
138 |
-
self._cumulative = dict()
|
139 |
-
self._moments = dict()
|
140 |
-
self.update()
|
141 |
-
self._moments.clear()
|
142 |
-
|
143 |
-
def names(self):
|
144 |
-
r"""Returns the names of all statistics broadcasted so far that
|
145 |
-
match the regular expression specified at construction time.
|
146 |
-
"""
|
147 |
-
return [name for name in _counters if self._regex.fullmatch(name)]
|
148 |
-
|
149 |
-
def update(self):
|
150 |
-
r"""Copies current values of the internal counters to the
|
151 |
-
user-visible state and resets them for the next round.
|
152 |
-
|
153 |
-
If `keep_previous=True` was specified at construction time, the
|
154 |
-
operation is skipped for statistics that have received no scalars
|
155 |
-
since the last update, retaining their previous averages.
|
156 |
-
|
157 |
-
This method performs a number of GPU-to-CPU transfers and one
|
158 |
-
`torch.distributed.all_reduce()`. It is intended to be called
|
159 |
-
periodically in the main training loop, typically once every
|
160 |
-
N training steps.
|
161 |
-
"""
|
162 |
-
if not self._keep_previous:
|
163 |
-
self._moments.clear()
|
164 |
-
for name, cumulative in _sync(self.names()):
|
165 |
-
if name not in self._cumulative:
|
166 |
-
self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
|
167 |
-
delta = cumulative - self._cumulative[name]
|
168 |
-
self._cumulative[name].copy_(cumulative)
|
169 |
-
if float(delta[0]) != 0:
|
170 |
-
self._moments[name] = delta
|
171 |
-
|
172 |
-
def _get_delta(self, name):
|
173 |
-
r"""Returns the raw moments that were accumulated for the given
|
174 |
-
statistic between the last two calls to `update()`, or zero if
|
175 |
-
no scalars were collected.
|
176 |
-
"""
|
177 |
-
assert self._regex.fullmatch(name)
|
178 |
-
if name not in self._moments:
|
179 |
-
self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
|
180 |
-
return self._moments[name]
|
181 |
-
|
182 |
-
def num(self, name):
|
183 |
-
r"""Returns the number of scalars that were accumulated for the given
|
184 |
-
statistic between the last two calls to `update()`, or zero if
|
185 |
-
no scalars were collected.
|
186 |
-
"""
|
187 |
-
delta = self._get_delta(name)
|
188 |
-
return int(delta[0])
|
189 |
-
|
190 |
-
def mean(self, name):
|
191 |
-
r"""Returns the mean of the scalars that were accumulated for the
|
192 |
-
given statistic between the last two calls to `update()`, or NaN if
|
193 |
-
no scalars were collected.
|
194 |
-
"""
|
195 |
-
delta = self._get_delta(name)
|
196 |
-
if int(delta[0]) == 0:
|
197 |
-
return float('nan')
|
198 |
-
return float(delta[1] / delta[0])
|
199 |
-
|
200 |
-
def std(self, name):
|
201 |
-
r"""Returns the standard deviation of the scalars that were
|
202 |
-
accumulated for the given statistic between the last two calls to
|
203 |
-
`update()`, or NaN if no scalars were collected.
|
204 |
-
"""
|
205 |
-
delta = self._get_delta(name)
|
206 |
-
if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
|
207 |
-
return float('nan')
|
208 |
-
if int(delta[0]) == 1:
|
209 |
-
return float(0)
|
210 |
-
mean = float(delta[1] / delta[0])
|
211 |
-
raw_var = float(delta[2] / delta[0])
|
212 |
-
return np.sqrt(max(raw_var - np.square(mean), 0))
|
213 |
-
|
214 |
-
def as_dict(self):
|
215 |
-
r"""Returns the averages accumulated between the last two calls to
|
216 |
-
`update()` as an `dnnlib.EasyDict`. The contents are as follows:
|
217 |
-
|
218 |
-
dnnlib.EasyDict(
|
219 |
-
NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
|
220 |
-
...
|
221 |
-
)
|
222 |
-
"""
|
223 |
-
stats = dnnlib.EasyDict()
|
224 |
-
for name in self.names():
|
225 |
-
stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
|
226 |
-
return stats
|
227 |
-
|
228 |
-
def __getitem__(self, name):
|
229 |
-
r"""Convenience getter.
|
230 |
-
`collector[name]` is a synonym for `collector.mean(name)`.
|
231 |
-
"""
|
232 |
-
return self.mean(name)
|
233 |
-
|
234 |
-
#----------------------------------------------------------------------------
|
235 |
-
|
236 |
-
def _sync(names):
|
237 |
-
r"""Synchronize the global cumulative counters across devices and
|
238 |
-
processes. Called internally by `Collector.update()`.
|
239 |
-
"""
|
240 |
-
if len(names) == 0:
|
241 |
-
return []
|
242 |
-
global _sync_called
|
243 |
-
_sync_called = True
|
244 |
-
|
245 |
-
# Collect deltas within current rank.
|
246 |
-
deltas = []
|
247 |
-
device = _sync_device if _sync_device is not None else torch.device('cpu')
|
248 |
-
for name in names:
|
249 |
-
delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
|
250 |
-
for counter in _counters[name].values():
|
251 |
-
delta.add_(counter.to(device))
|
252 |
-
counter.copy_(torch.zeros_like(counter))
|
253 |
-
deltas.append(delta)
|
254 |
-
deltas = torch.stack(deltas)
|
255 |
-
|
256 |
-
# Sum deltas across ranks.
|
257 |
-
if _sync_device is not None:
|
258 |
-
torch.distributed.all_reduce(deltas)
|
259 |
-
|
260 |
-
# Update cumulative values.
|
261 |
-
deltas = deltas.cpu()
|
262 |
-
for idx, name in enumerate(names):
|
263 |
-
if name not in _cumulative:
|
264 |
-
_cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
|
265 |
-
_cumulative[name].add_(deltas[idx])
|
266 |
-
|
267 |
-
# Return name-value pairs.
|
268 |
-
return [(name, _cumulative[name]) for name in names]
|
269 |
-
|
270 |
-
#----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|