feng2022 committed on
Commit
664952a
1 Parent(s): 1fa48f0
Files changed (47)
  1. dnnlib/__init__.py +11 -0
  2. dnnlib/tflib/__init__.py +20 -0
  3. dnnlib/tflib/autosummary.py +193 -0
  4. dnnlib/tflib/custom_ops.py +171 -0
  5. dnnlib/tflib/network.py +592 -0
  6. dnnlib/tflib/ops/__init__.py +9 -0
  7. dnnlib/tflib/ops/fused_bias_act.cu +190 -0
  8. dnnlib/tflib/ops/fused_bias_act.py +198 -0
  9. dnnlib/tflib/ops/upfirdn_2d.cu +328 -0
  10. dnnlib/tflib/ops/upfirdn_2d.py +366 -0
  11. dnnlib/tflib/optimizer.py +338 -0
  12. dnnlib/tflib/tfutil.py +254 -0
  13. dnnlib/util.py +479 -0
  14. torch_utils/__init__.py +11 -0
  15. torch_utils/custom_ops.py +238 -0
  16. torch_utils/misc.py +264 -0
  17. torch_utils/models.py +756 -0
  18. torch_utils/models_face.py +809 -0
  19. torch_utils/op_edit/__init__.py +4 -0
  20. torch_utils/op_edit/fused_act.py +99 -0
  21. torch_utils/op_edit/fused_bias_act.cpp +23 -0
  22. torch_utils/op_edit/fused_bias_act_kernel.cu +101 -0
  23. torch_utils/op_edit/upfirdn2d.cpp +25 -0
  24. torch_utils/op_edit/upfirdn2d.py +202 -0
  25. torch_utils/op_edit/upfirdn2d_kernel.cu +371 -0
  26. torch_utils/ops/__init__.py +3 -0
  27. torch_utils/ops/bias_act.cpp +101 -0
  28. torch_utils/ops/bias_act.cu +175 -0
  29. torch_utils/ops/bias_act.h +40 -0
  30. torch_utils/ops/bias_act.py +214 -0
  31. torch_utils/ops/conv2d_gradfix.py +172 -0
  32. torch_utils/ops/conv2d_resample.py +158 -0
  33. torch_utils/ops/filtered_lrelu.cpp +300 -0
  34. torch_utils/ops/filtered_lrelu.cu +1284 -0
  35. torch_utils/ops/filtered_lrelu.h +90 -0
  36. torch_utils/ops/filtered_lrelu.py +282 -0
  37. torch_utils/ops/filtered_lrelu_ns.cu +27 -0
  38. torch_utils/ops/filtered_lrelu_rd.cu +27 -0
  39. torch_utils/ops/filtered_lrelu_wr.cu +27 -0
  40. torch_utils/ops/fma.py +62 -0
  41. torch_utils/ops/grid_sample_gradfix.py +85 -0
  42. torch_utils/ops/upfirdn2d.cpp +105 -0
  43. torch_utils/ops/upfirdn2d.cu +352 -0
  44. torch_utils/ops/upfirdn2d.h +61 -0
  45. torch_utils/ops/upfirdn2d.py +386 -0
  46. torch_utils/persistence.py +253 -0
  47. torch_utils/training_stats.py +270 -0
dnnlib/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ from .util import EasyDict, make_cache_dir_path
dnnlib/tflib/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ #
5
+ # This work is made available under the Nvidia Source Code License-NC.
6
+ # To view a copy of this license, visit
7
+ # https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ from . import autosummary
10
+ from . import network
11
+ from . import optimizer
12
+ from . import tfutil
13
+ from . import custom_ops
14
+
15
+ from .tfutil import *
16
+ from .network import Network
17
+
18
+ from .optimizer import Optimizer
19
+
20
+ from .custom_ops import get_plugin
dnnlib/tflib/autosummary.py ADDED
@@ -0,0 +1,193 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ #
5
+ # This work is made available under the Nvidia Source Code License-NC.
6
+ # To view a copy of this license, visit
7
+ # https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ """Helper for adding automatically tracked values to Tensorboard.
10
+
11
+ Autosummary creates an identity op that internally keeps track of the input
12
+ values and automatically shows up in TensorBoard. The reported value
13
+ represents an average over input components. The average is accumulated
14
+ constantly over time and flushed when save_summaries() is called.
15
+
16
+ Notes:
17
+ - The output tensor must be used as an input for something else in the
18
+ graph. Otherwise, the autosummary op will not get executed, and the average
19
+ value will not get accumulated.
20
+ - It is perfectly fine to include autosummaries with the same name in
21
+ several places throughout the graph, even if they are executed concurrently.
22
+ - It is ok to also pass in a python scalar or numpy array. In this case, it
23
+ is added to the average immediately.
24
+ """
25
+
26
+ from collections import OrderedDict
27
+ import numpy as np
28
+ import tensorflow as tf
29
+ from tensorboard import summary as summary_lib
30
+ from tensorboard.plugins.custom_scalar import layout_pb2
31
+
32
+ from . import tfutil
33
+ from .tfutil import TfExpression
34
+ from .tfutil import TfExpressionEx
35
+
36
+ # Enable "Custom scalars" tab in TensorBoard for advanced formatting.
37
+ # Disabled by default to reduce tfevents file size.
38
+ enable_custom_scalars = False
39
+
40
+ _dtype = tf.float64
41
+ _vars = OrderedDict() # name => [var, ...]
42
+ _immediate = OrderedDict() # name => update_op, update_value
43
+ _finalized = False
44
+ _merge_op = None
45
+
46
+
47
+ def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
48
+ """Internal helper for creating autosummary accumulators."""
49
+ assert not _finalized
50
+ name_id = name.replace("/", "_")
51
+ v = tf.cast(value_expr, _dtype)
52
+
53
+ if v.shape.is_fully_defined():
54
+ size = np.prod(v.shape.as_list())
55
+ size_expr = tf.constant(size, dtype=_dtype)
56
+ else:
57
+ size = None
58
+ size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))
59
+
60
+ if size == 1:
61
+ if v.shape.ndims != 0:
62
+ v = tf.reshape(v, [])
63
+ v = [size_expr, v, tf.square(v)]
64
+ else:
65
+ v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]
66
+ v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))
67
+
68
+ with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None):
69
+ var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)]
70
+ update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))
71
+
72
+ if name in _vars:
73
+ _vars[name].append(var)
74
+ else:
75
+ _vars[name] = [var]
76
+ return update_op
77
+
78
+
79
+ def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None, condition: TfExpressionEx = True) -> TfExpressionEx:
80
+ """Create a new autosummary.
81
+
82
+ Args:
83
+ name: Name to use in TensorBoard
84
+ value: TensorFlow expression or python value to track
85
+ passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.
86
+
87
+ Example use of the passthru mechanism:
88
+
89
+ n = autosummary('l2loss', loss, passthru=n)
90
+
91
+ This is a shorthand for the following code:
92
+
93
+ with tf.control_dependencies([autosummary('l2loss', loss)]):
94
+ n = tf.identity(n)
95
+ """
96
+ tfutil.assert_tf_initialized()
97
+ name_id = name.replace("/", "_")
98
+
99
+ if tfutil.is_tf_expression(value):
100
+ with tf.name_scope("summary_" + name_id), tf.device(value.device):
101
+ condition = tf.convert_to_tensor(condition, name='condition')
102
+ update_op = tf.cond(condition, lambda: tf.group(_create_var(name, value)), tf.no_op)
103
+ with tf.control_dependencies([update_op]):
104
+ return tf.identity(value if passthru is None else passthru)
105
+
106
+ else: # python scalar or numpy array
107
+ assert not tfutil.is_tf_expression(passthru)
108
+ assert not tfutil.is_tf_expression(condition)
109
+ if condition:
110
+ if name not in _immediate:
111
+ with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
112
+ update_value = tf.placeholder(_dtype)
113
+ update_op = _create_var(name, update_value)
114
+ _immediate[name] = update_op, update_value
115
+ update_op, update_value = _immediate[name]
116
+ tfutil.run(update_op, {update_value: value})
117
+ return value if passthru is None else passthru
118
+
119
+
120
+ def finalize_autosummaries() -> None:
121
+ """Create the necessary ops to include autosummaries in TensorBoard report.
122
+ Note: This should be done only once per graph.
123
+ """
124
+ global _finalized
125
+ tfutil.assert_tf_initialized()
126
+
127
+ if _finalized:
128
+ return None
129
+
130
+ _finalized = True
131
+ tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])
132
+
133
+ # Create summary ops.
134
+ with tf.device(None), tf.control_dependencies(None):
135
+ for name, vars_list in _vars.items():
136
+ name_id = name.replace("/", "_")
137
+ with tfutil.absolute_name_scope("Autosummary/" + name_id):
138
+ moments = tf.add_n(vars_list)
139
+ moments /= moments[0]
140
+ with tf.control_dependencies([moments]): # read before resetting
141
+ reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
142
+ with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting
143
+ mean = moments[1]
144
+ std = tf.sqrt(moments[2] - tf.square(moments[1]))
145
+ tf.summary.scalar(name, mean)
146
+ if enable_custom_scalars:
147
+ tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
148
+ tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)
149
+
150
+ # Setup layout for custom scalars.
151
+ layout = None
152
+ if enable_custom_scalars:
153
+ cat_dict = OrderedDict()
154
+ for series_name in sorted(_vars.keys()):
155
+ p = series_name.split("/")
156
+ cat = p[0] if len(p) >= 2 else ""
157
+ chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
158
+ if cat not in cat_dict:
159
+ cat_dict[cat] = OrderedDict()
160
+ if chart not in cat_dict[cat]:
161
+ cat_dict[cat][chart] = []
162
+ cat_dict[cat][chart].append(series_name)
163
+ categories = []
164
+ for cat_name, chart_dict in cat_dict.items():
165
+ charts = []
166
+ for chart_name, series_names in chart_dict.items():
167
+ series = []
168
+ for series_name in series_names:
169
+ series.append(layout_pb2.MarginChartContent.Series(
170
+ value=series_name,
171
+ lower="xCustomScalars/" + series_name + "/margin_lo",
172
+ upper="xCustomScalars/" + series_name + "/margin_hi"))
173
+ margin = layout_pb2.MarginChartContent(series=series)
174
+ charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
175
+ categories.append(layout_pb2.Category(title=cat_name, chart=charts))
176
+ layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
177
+ return layout
178
+
179
+ def save_summaries(file_writer, global_step=None):
180
+ """Call FileWriter.add_summary() with all summaries in the default graph,
181
+ automatically finalizing and merging them on the first call.
182
+ """
183
+ global _merge_op
184
+ tfutil.assert_tf_initialized()
185
+
186
+ if _merge_op is None:
187
+ layout = finalize_autosummaries()
188
+ if layout is not None:
189
+ file_writer.add_summary(layout)
190
+ with tf.device(None), tf.control_dependencies(None):
191
+ _merge_op = tf.summary.merge_all()
192
+
193
+ file_writer.add_summary(_merge_op.eval(), global_step)
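For orientation, a minimal usage sketch of the module above. It assumes TensorFlow 1.x and that tflib.init_tf() has been called to create the default session; the 'logs' directory and the stand-in loss tensor are illustrative, not part of this commit.

import tensorflow as tf
import dnnlib.tflib as tflib
from dnnlib.tflib import autosummary

tflib.init_tf()
loss = tf.reduce_mean(tf.square(tf.random_normal([16])))              # stand-in for a real training loss
tracked = autosummary.autosummary('Loss/train', loss, passthru=loss)  # identity op that feeds the accumulator
writer = tf.summary.FileWriter('logs', tf.get_default_graph())

for step in range(10):
    tflib.run(tracked)                                    # executing the op accumulates the running average
    autosummary.autosummary('Progress/step', step)        # python scalars are added to the average immediately
    autosummary.save_summaries(writer, global_step=step)  # finalizes and merges the summaries on the first call
writer.close()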
dnnlib/tflib/custom_ops.py ADDED
@@ -0,0 +1,171 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ #
5
+ # This work is made available under the Nvidia Source Code License-NC.
6
+ # To view a copy of this license, visit
7
+ # https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ """TensorFlow custom ops builder.
10
+ """
11
+
12
+ import os
13
+ import re
14
+ import uuid
15
+ import hashlib
16
+ import tempfile
17
+ import shutil
18
+ import tensorflow as tf
19
+ from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module
20
+
21
+ #----------------------------------------------------------------------------
22
+ # Global options.
23
+
24
+ cuda_cache_path = os.path.join(os.path.dirname(__file__), '_cudacache')
25
+ cuda_cache_version_tag = 'v1'
26
+ do_not_hash_included_headers = False # Speed up compilation by assuming that headers included by the CUDA code never change. Unsafe!
27
+ verbose = True # Print status messages to stdout.
28
+
29
+ compiler_bindir_search_path = [
30
+ 'C:/Program Files (x86)/Microsoft Visual Studio/2017/Community/VC/Tools/MSVC/14.14.26428/bin/Hostx64/x64',
31
+ 'C:/Program Files (x86)/Microsoft Visual Studio/2019/Community/VC/Tools/MSVC/14.23.28105/bin/Hostx64/x64',
32
+ 'C:/Program Files (x86)/Microsoft Visual Studio 14.0/vc/bin',
33
+ ]
34
+
35
+ #----------------------------------------------------------------------------
36
+ # Internal helper funcs.
37
+
38
+ def _find_compiler_bindir():
39
+ for compiler_path in compiler_bindir_search_path:
40
+ if os.path.isdir(compiler_path):
41
+ return compiler_path
42
+ return None
43
+
44
+ def _get_compute_cap(device):
45
+ caps_str = device.physical_device_desc
46
+ m = re.search('compute capability: (\\d+).(\\d+)', caps_str)
47
+ major = m.group(1)
48
+ minor = m.group(2)
49
+ return (major, minor)
50
+
51
+ def _get_cuda_gpu_arch_string():
52
+ gpus = [x for x in device_lib.list_local_devices() if x.device_type == 'GPU']
53
+ if len(gpus) == 0:
54
+ raise RuntimeError('No GPU devices found')
55
+ (major, minor) = _get_compute_cap(gpus[0])
56
+ return 'sm_%s%s' % (major, minor)
57
+
58
+ def _run_cmd(cmd):
59
+ with os.popen(cmd) as pipe:
60
+ output = pipe.read()
61
+ status = pipe.close()
62
+ if status is not None:
63
+ raise RuntimeError('NVCC returned an error. See below for full command line and output log:\n\n%s\n\n%s' % (cmd, output))
64
+
65
+ def _prepare_nvcc_cli(opts):
66
+ cmd = 'nvcc ' + opts.strip()
67
+ cmd += ' --disable-warnings'
68
+ cmd += ' --include-path "%s"' % tf.sysconfig.get_include()
69
+ cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'protobuf_archive', 'src')
70
+ cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'com_google_absl')
71
+ cmd += ' --include-path "%s"' % os.path.join(tf.sysconfig.get_include(), 'external', 'eigen_archive')
72
+
73
+ compiler_bindir = _find_compiler_bindir()
74
+ if compiler_bindir is None:
75
+ # Require that _find_compiler_bindir succeeds on Windows. Allow
76
+ # nvcc to use whatever is the default on Linux.
77
+ if os.name == 'nt':
78
+ raise RuntimeError('Could not find MSVC/GCC/CLANG installation on this computer. Check compiler_bindir_search_path list in "%s".' % __file__)
79
+ else:
80
+ cmd += ' --compiler-bindir "%s"' % compiler_bindir
81
+ cmd += ' 2>&1'
82
+ return cmd
83
+
84
+ #----------------------------------------------------------------------------
85
+ # Main entry point.
86
+
87
+ _plugin_cache = dict()
88
+
89
+ def get_plugin(cuda_file):
90
+ cuda_file_base = os.path.basename(cuda_file)
91
+ cuda_file_name, cuda_file_ext = os.path.splitext(cuda_file_base)
92
+
93
+ # Already in cache?
94
+ if cuda_file in _plugin_cache:
95
+ return _plugin_cache[cuda_file]
96
+
97
+ # Setup plugin.
98
+ if verbose:
99
+ print('Setting up TensorFlow plugin "%s": ' % cuda_file_base, end='', flush=True)
100
+ try:
101
+ # Hash CUDA source.
102
+ md5 = hashlib.md5()
103
+ with open(cuda_file, 'rb') as f:
104
+ md5.update(f.read())
105
+ md5.update(b'\n')
106
+
107
+ # Hash headers included by the CUDA code by running it through the preprocessor.
108
+ if not do_not_hash_included_headers:
109
+ if verbose:
110
+ print('Preprocessing... ', end='', flush=True)
111
+ with tempfile.TemporaryDirectory() as tmp_dir:
112
+ tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + cuda_file_ext)
113
+ _run_cmd(_prepare_nvcc_cli('"%s" --preprocess -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir)))
114
+ with open(tmp_file, 'rb') as f:
115
+ bad_file_str = ('"' + cuda_file.replace('\\', '/') + '"').encode('utf-8') # __FILE__ in error check macros
116
+ good_file_str = ('"' + cuda_file_base + '"').encode('utf-8')
117
+ for ln in f:
118
+ if not ln.startswith(b'# ') and not ln.startswith(b'#line '): # ignore line number pragmas
119
+ ln = ln.replace(bad_file_str, good_file_str)
120
+ md5.update(ln)
121
+ md5.update(b'\n')
122
+
123
+ # Select compiler options.
124
+ compile_opts = ''
125
+ if os.name == 'nt':
126
+ compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.lib')
127
+ elif os.name == 'posix':
128
+ compile_opts += '"%s"' % os.path.join(tf.sysconfig.get_lib(), 'python', '_pywrap_tensorflow_internal.so')
129
+ compile_opts += ' --compiler-options \'-fPIC -D_GLIBCXX_USE_CXX11_ABI=0\''
130
+ else:
131
+ assert False # not Windows or Linux, w00t?
132
+ compile_opts += ' --gpu-architecture=%s' % _get_cuda_gpu_arch_string()
133
+ compile_opts += ' --use_fast_math'
134
+ nvcc_cmd = _prepare_nvcc_cli(compile_opts)
135
+
136
+ # Hash build configuration.
137
+ md5.update(('nvcc_cmd: ' + nvcc_cmd).encode('utf-8') + b'\n')
138
+ md5.update(('tf.VERSION: ' + tf.VERSION).encode('utf-8') + b'\n')
139
+ md5.update(('cuda_cache_version_tag: ' + cuda_cache_version_tag).encode('utf-8') + b'\n')
140
+
141
+ # Compile if not already compiled.
142
+ bin_file_ext = '.dll' if os.name == 'nt' else '.so'
143
+ bin_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + md5.hexdigest() + bin_file_ext)
144
+ if not os.path.isfile(bin_file):
145
+ if verbose:
146
+ print('Compiling... ', end='', flush=True)
147
+ with tempfile.TemporaryDirectory() as tmp_dir:
148
+ tmp_file = os.path.join(tmp_dir, cuda_file_name + '_tmp' + bin_file_ext)
149
+ _run_cmd(nvcc_cmd + ' "%s" --shared -o "%s" --keep --keep-dir "%s"' % (cuda_file, tmp_file, tmp_dir))
150
+ os.makedirs(cuda_cache_path, exist_ok=True)
151
+ intermediate_file = os.path.join(cuda_cache_path, cuda_file_name + '_' + uuid.uuid4().hex + '_tmp' + bin_file_ext)
152
+ shutil.copyfile(tmp_file, intermediate_file)
153
+ os.rename(intermediate_file, bin_file) # atomic
154
+
155
+ # Load.
156
+ if verbose:
157
+ print('Loading... ', end='', flush=True)
158
+ plugin = tf.load_op_library(bin_file)
159
+
160
+ # Add to cache.
161
+ _plugin_cache[cuda_file] = plugin
162
+ if verbose:
163
+ print('Done.', flush=True)
164
+ return plugin
165
+
166
+ except:
167
+ if verbose:
168
+ print('Failed!', flush=True)
169
+ raise
170
+
171
+ #----------------------------------------------------------------------------
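For orientation, a minimal sketch of how this builder is used by the op wrappers added later in this commit (for example dnnlib/tflib/ops/fused_bias_act.py). It assumes TensorFlow 1.x, a CUDA-capable GPU, and nvcc on the PATH; the upfirdn_2d.cu path refers to file 9 in this commit.

import dnnlib.tflib as tflib

# First call: preprocesses and compiles the .cu file with nvcc, then caches the resulting
# shared library under dnnlib/tflib/_cudacache, keyed by an MD5 of the source, its included
# headers, and the build configuration. Later calls return the entry from _plugin_cache.
plugin = tflib.get_plugin('dnnlib/tflib/ops/upfirdn_2d.cu')
print(plugin)   # module-like object exposing the ops registered in the .cu file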
dnnlib/tflib/network.py ADDED
@@ -0,0 +1,592 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ #
5
+ # This work is made available under the Nvidia Source Code License-NC.
6
+ # To view a copy of this license, visit
7
+ # https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ """Helper for managing networks."""
10
+
11
+ import types
12
+ import inspect
13
+ import re
14
+ import uuid
15
+ import sys
16
+ import numpy as np
17
+ import tensorflow as tf
18
+
19
+ from collections import OrderedDict
20
+ from typing import Any, List, Tuple, Union
21
+
22
+ from . import tfutil
23
+ from .. import util
24
+
25
+ from .tfutil import TfExpression, TfExpressionEx
26
+
27
+ _import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import.
28
+ _import_module_src = dict() # Source code for temporary modules created during pickle import.
29
+
30
+
31
+ def import_handler(handler_func):
32
+ """Function decorator for declaring custom import handlers."""
33
+ _import_handlers.append(handler_func)
34
+ return handler_func
35
+
36
+
37
+ class Network:
38
+ """Generic network abstraction.
39
+
40
+ Acts as a convenience wrapper for a parameterized network construction
41
+ function, providing several utility methods and convenient access to
42
+ the inputs/outputs/weights.
43
+
44
+ Network objects can be safely pickled and unpickled for long-term
45
+ archival purposes. The pickling works reliably as long as the underlying
46
+ network construction function is defined in a standalone Python module
47
+ that has no side effects or application-specific imports.
48
+
49
+ Args:
50
+ name: Network name. Used to select TensorFlow name and variable scopes.
51
+ func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
52
+ static_kwargs: Keyword arguments to be passed in to the network construction function.
53
+
54
+ Attributes:
55
+ name: User-specified name, defaults to build func name if None.
56
+ scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name.
57
+ static_kwargs: Arguments passed to the user-supplied build func.
58
+ components: Container for sub-networks. Passed to the build func, and retained between calls.
59
+ num_inputs: Number of input tensors.
60
+ num_outputs: Number of output tensors.
61
+ input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension.
62
+ output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension.
63
+ input_shape: Short-hand for input_shapes[0].
64
+ output_shape: Short-hand for output_shapes[0].
65
+ input_templates: Input placeholders in the template graph.
66
+ output_templates: Output tensors in the template graph.
67
+ input_names: Name string for each input.
68
+ output_names: Name string for each output.
69
+ own_vars: Variables defined by this network (local_name => var), excluding sub-networks.
70
+ vars: All variables (local_name => var).
71
+ trainables: All trainable variables (local_name => var).
72
+ var_global_to_local: Mapping from variable global names to local names.
73
+ """
74
+
75
+ def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
76
+ tfutil.assert_tf_initialized()
77
+ assert isinstance(name, str) or name is None
78
+ assert func_name is not None
79
+ assert isinstance(func_name, str) or util.is_top_level_function(func_name)
80
+ assert util.is_pickleable(static_kwargs)
81
+
82
+ self._init_fields()
83
+ self.name = name
84
+ self.static_kwargs = util.EasyDict(static_kwargs)
85
+
86
+ # Locate the user-specified network build function.
87
+ if util.is_top_level_function(func_name):
88
+ func_name = util.get_top_level_function_name(func_name)
89
+ module, self._build_func_name = util.get_module_from_obj_name(func_name)
90
+ self._build_func = util.get_obj_from_module(module, self._build_func_name)
91
+ assert callable(self._build_func)
92
+
93
+ # Dig up source code for the module containing the build function.
94
+ self._build_module_src = _import_module_src.get(module, None)
95
+ if self._build_module_src is None:
96
+ self._build_module_src = inspect.getsource(module)
97
+
98
+ # Init TensorFlow graph.
99
+ self._init_graph()
100
+ self.reset_own_vars()
101
+
102
+ def _init_fields(self) -> None:
103
+ self.name = None
104
+ self.scope = None
105
+ self.static_kwargs = util.EasyDict()
106
+ self.components = util.EasyDict()
107
+ self.num_inputs = 0
108
+ self.num_outputs = 0
109
+ self.input_shapes = [[]]
110
+ self.output_shapes = [[]]
111
+ self.input_shape = []
112
+ self.output_shape = []
113
+ self.input_templates = []
114
+ self.output_templates = []
115
+ self.input_names = []
116
+ self.output_names = []
117
+ self.own_vars = OrderedDict()
118
+ self.vars = OrderedDict()
119
+ self.trainables = OrderedDict()
120
+ self.var_global_to_local = OrderedDict()
121
+
122
+ self._build_func = None # User-supplied build function that constructs the network.
123
+ self._build_func_name = None # Name of the build function.
124
+ self._build_module_src = None # Full source code of the module containing the build function.
125
+ self._run_cache = dict() # Cached graph data for Network.run().
126
+
127
+ def _init_graph(self) -> None:
128
+ # Collect inputs.
129
+ self.input_names = []
130
+
131
+ for param in inspect.signature(self._build_func).parameters.values():
132
+ if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
133
+ self.input_names.append(param.name)
134
+
135
+ self.num_inputs = len(self.input_names)
136
+ assert self.num_inputs >= 1
137
+
138
+ # Choose name and scope.
139
+ if self.name is None:
140
+ self.name = self._build_func_name
141
+ assert re.match("^[A-Za-z0-9_.\\-]*$", self.name)
142
+ with tf.name_scope(None):
143
+ self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True)
144
+
145
+ # Finalize build func kwargs.
146
+ build_kwargs = dict(self.static_kwargs)
147
+ build_kwargs["is_template_graph"] = True
148
+ build_kwargs["components"] = self.components
149
+
150
+ # Build template graph.
151
+ with tfutil.absolute_variable_scope(self.scope, reuse=False), tfutil.absolute_name_scope(self.scope): # ignore surrounding scopes
152
+ assert tf.get_variable_scope().name == self.scope
153
+ assert tf.get_default_graph().get_name_scope() == self.scope
154
+ with tf.control_dependencies(None): # ignore surrounding control dependencies
155
+ self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
156
+ out_expr = self._build_func(*self.input_templates, **build_kwargs)
157
+
158
+ # Collect outputs.
159
+ assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
160
+ self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
161
+ self.num_outputs = len(self.output_templates)
162
+ assert self.num_outputs >= 1
163
+ assert all(tfutil.is_tf_expression(t) for t in self.output_templates)
164
+
165
+ # Perform sanity checks.
166
+ if any(t.shape.ndims is None for t in self.input_templates):
167
+ raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
168
+ if any(t.shape.ndims is None for t in self.output_templates):
169
+ raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
170
+ if any(not isinstance(comp, Network) for comp in self.components.values()):
171
+ raise ValueError("Components of a Network must be Networks themselves.")
172
+ if len(self.components) != len(set(comp.name for comp in self.components.values())):
173
+ raise ValueError("Components of a Network must have unique names.")
174
+
175
+ # List inputs and outputs.
176
+ self.input_shapes = [t.shape.as_list() for t in self.input_templates]
177
+ self.output_shapes = [t.shape.as_list() for t in self.output_templates]
178
+ self.input_shape = self.input_shapes[0]
179
+ self.output_shape = self.output_shapes[0]
180
+ self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]
181
+
182
+ # List variables.
183
+ self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))
184
+ self.vars = OrderedDict(self.own_vars)
185
+ self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items())
186
+ self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
187
+ self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())
188
+
189
+ def reset_own_vars(self) -> None:
190
+ """Re-initialize all variables of this network, excluding sub-networks."""
191
+ tfutil.run([var.initializer for var in self.own_vars.values()])
192
+
193
+ def reset_vars(self) -> None:
194
+ """Re-initialize all variables of this network, including sub-networks."""
195
+ tfutil.run([var.initializer for var in self.vars.values()])
196
+
197
+ def reset_trainables(self) -> None:
198
+ """Re-initialize all trainable variables of this network, including sub-networks."""
199
+ tfutil.run([var.initializer for var in self.trainables.values()])
200
+
201
+ def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
202
+ """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s)."""
203
+ assert len(in_expr) == self.num_inputs
204
+ assert not all(expr is None for expr in in_expr)
205
+
206
+ # Finalize build func kwargs.
207
+ build_kwargs = dict(self.static_kwargs)
208
+ build_kwargs.update(dynamic_kwargs)
209
+ build_kwargs["is_template_graph"] = False
210
+ build_kwargs["components"] = self.components
211
+
212
+ # Build TensorFlow graph to evaluate the network.
213
+ with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
214
+ assert tf.get_variable_scope().name == self.scope
215
+ valid_inputs = [expr for expr in in_expr if expr is not None]
216
+ final_inputs = []
217
+ for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
218
+ if expr is not None:
219
+ expr = tf.identity(expr, name=name)
220
+ else:
221
+ expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
222
+ final_inputs.append(expr)
223
+ out_expr = self._build_func(*final_inputs, **build_kwargs)
224
+
225
+ # Propagate input shapes back to the user-specified expressions.
226
+ for expr, final in zip(in_expr, final_inputs):
227
+ if isinstance(expr, tf.Tensor):
228
+ expr.set_shape(final.shape)
229
+
230
+ # Express outputs in the desired format.
231
+ assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
232
+ if return_as_list:
233
+ out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
234
+ return out_expr
235
+
236
+ def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
237
+ """Get the local name of a given variable, without any surrounding name scopes."""
238
+ assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
239
+ global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
240
+ return self.var_global_to_local[global_name]
241
+
242
+ def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
243
+ """Find variable by local or global name."""
244
+ assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
245
+ return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name
246
+
247
+ def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
248
+ """Get the value of a given variable as NumPy array.
249
+ Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
250
+ return self.find_var(var_or_local_name).eval()
251
+
252
+ def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
253
+ """Set the value of a given variable based on the given NumPy array.
254
+ Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
255
+ tfutil.set_vars({self.find_var(var_or_local_name): new_value})
256
+
257
+ def __getstate__(self) -> dict:
258
+ """Pickle export."""
259
+ state = dict()
260
+ state["version"] = 4
261
+ state["name"] = self.name
262
+ state["static_kwargs"] = dict(self.static_kwargs)
263
+ state["components"] = dict(self.components)
264
+ state["build_module_src"] = self._build_module_src
265
+ state["build_func_name"] = self._build_func_name
266
+ state["variables"] = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values()))))
267
+ return state
268
+
269
+ def __setstate__(self, state: dict) -> None:
270
+ """Pickle import."""
271
+ # pylint: disable=attribute-defined-outside-init
272
+ tfutil.assert_tf_initialized()
273
+ self._init_fields()
274
+
275
+ # Execute custom import handlers.
276
+ for handler in _import_handlers:
277
+ state = handler(state)
278
+
279
+ # Set basic fields.
280
+ assert state["version"] in [2, 3, 4]
281
+ self.name = state["name"]
282
+ self.static_kwargs = util.EasyDict(state["static_kwargs"])
283
+ self.components = util.EasyDict(state.get("components", {}))
284
+ self._build_module_src = state["build_module_src"]
285
+ self._build_func_name = state["build_func_name"]
286
+
287
+ # Create temporary module from the imported source code.
288
+ module_name = "_tflib_network_import_" + uuid.uuid4().hex
289
+ module = types.ModuleType(module_name)
290
+ sys.modules[module_name] = module
291
+ _import_module_src[module] = self._build_module_src
292
+ exec(self._build_module_src, module.__dict__) # pylint: disable=exec-used
293
+
294
+ # Locate network build function in the temporary module.
295
+ self._build_func = util.get_obj_from_module(module, self._build_func_name)
296
+ assert callable(self._build_func)
297
+
298
+ # Init TensorFlow graph.
299
+ self._init_graph()
300
+ self.reset_own_vars()
301
+ tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]})
302
+
303
+ def clone(self, name: str = None, **new_static_kwargs) -> "Network":
304
+ """Create a clone of this network with its own copy of the variables."""
305
+ # pylint: disable=protected-access
306
+ net = object.__new__(Network)
307
+ net._init_fields()
308
+ net.name = name if name is not None else self.name
309
+ net.static_kwargs = util.EasyDict(self.static_kwargs)
310
+ net.static_kwargs.update(new_static_kwargs)
311
+ net._build_module_src = self._build_module_src
312
+ net._build_func_name = self._build_func_name
313
+ net._build_func = self._build_func
314
+ net._init_graph()
315
+ net.copy_vars_from(self)
316
+ return net
317
+
318
+ def copy_own_vars_from(self, src_net: "Network") -> None:
319
+ """Copy the values of all variables from the given network, excluding sub-networks."""
320
+ names = [name for name in self.own_vars.keys() if name in src_net.own_vars]
321
+ tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
322
+
323
+ def copy_vars_from(self, src_net: "Network") -> None:
324
+ """Copy the values of all variables from the given network, including sub-networks."""
325
+ names = [name for name in self.vars.keys() if name in src_net.vars]
326
+ tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
327
+
328
+ def copy_trainables_from(self, src_net: "Network") -> None:
329
+ """Copy the values of all trainable variables from the given network, including sub-networks."""
330
+ names = [name for name in self.trainables.keys() if name in src_net.trainables]
331
+ tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
332
+
333
+ def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
334
+ """Create new network with the given parameters, and copy all variables from this network."""
335
+ if new_name is None:
336
+ new_name = self.name
337
+ static_kwargs = dict(self.static_kwargs)
338
+ static_kwargs.update(new_static_kwargs)
339
+ net = Network(name=new_name, func_name=new_func_name, **static_kwargs)
340
+ net.copy_vars_from(self)
341
+ return net
342
+
343
+ def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
344
+ """Construct a TensorFlow op that updates the variables of this network
345
+ to be slightly closer to those of the given network."""
346
+ with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
347
+ ops = []
348
+ for name, var in self.vars.items():
349
+ if name in src_net.vars:
350
+ cur_beta = beta if name in self.trainables else beta_nontrainable
351
+ new_value = tfutil.lerp(src_net.vars[name], var, cur_beta)
352
+ ops.append(var.assign(new_value))
353
+ return tf.group(*ops)
354
+
355
+ def run(self,
356
+ *in_arrays: Tuple[Union[np.ndarray, None], ...],
357
+ input_transform: dict = None,
358
+ output_transform: dict = None,
359
+ return_as_list: bool = False,
360
+ print_progress: bool = False,
361
+ minibatch_size: int = None,
362
+ num_gpus: int = 1,
363
+ assume_frozen: bool = False,
364
+ **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
365
+ """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).
366
+
367
+ Args:
368
+ input_transform: A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
369
+ The dict must contain a 'func' field that points to a top-level function. The function is called with the input
370
+ TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
371
+ output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
372
+ The dict must contain a 'func' field that points to a top-level function. The function is called with the output
373
+ TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
374
+ return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
375
+ print_progress: Print progress to the console? Useful for very large input arrays.
376
+ minibatch_size: Maximum minibatch size to use, None = disable batching.
377
+ num_gpus: Number of GPUs to use.
378
+ assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
379
+ dynamic_kwargs: Additional keyword arguments to be passed into the network build function.
380
+ """
381
+ assert len(in_arrays) == self.num_inputs
382
+ assert not all(arr is None for arr in in_arrays)
383
+ assert input_transform is None or util.is_top_level_function(input_transform["func"])
384
+ assert output_transform is None or util.is_top_level_function(output_transform["func"])
385
+ output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
386
+ num_items = in_arrays[0].shape[0]
387
+ if minibatch_size is None:
388
+ minibatch_size = num_items
389
+
390
+ # Construct unique hash key from all arguments that affect the TensorFlow graph.
391
+ key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
392
+ def unwind_key(obj):
393
+ if isinstance(obj, dict):
394
+ return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
395
+ if callable(obj):
396
+ return util.get_top_level_function_name(obj)
397
+ return obj
398
+ key = repr(unwind_key(key))
399
+
400
+ # Build graph.
401
+ if key not in self._run_cache:
402
+ with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
403
+ with tf.device("/cpu:0"):
404
+ in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
405
+ in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
406
+
407
+ out_split = []
408
+ for gpu in range(num_gpus):
409
+ with tf.device("/gpu:%d" % gpu):
410
+ net_gpu = self.clone() if assume_frozen else self
411
+ in_gpu = in_split[gpu]
412
+
413
+ if input_transform is not None:
414
+ in_kwargs = dict(input_transform)
415
+ in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
416
+ in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)
417
+
418
+ assert len(in_gpu) == self.num_inputs
419
+ out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)
420
+
421
+ if output_transform is not None:
422
+ out_kwargs = dict(output_transform)
423
+ out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
424
+ out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)
425
+
426
+ assert len(out_gpu) == self.num_outputs
427
+ out_split.append(out_gpu)
428
+
429
+ with tf.device("/cpu:0"):
430
+ out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
431
+ self._run_cache[key] = in_expr, out_expr
432
+
433
+ # Run minibatches.
434
+ in_expr, out_expr = self._run_cache[key]
435
+ out_arrays = [np.empty([num_items] + expr.shape.as_list()[1:], expr.dtype.name) for expr in out_expr]
436
+
437
+ for mb_begin in range(0, num_items, minibatch_size):
438
+ if print_progress:
439
+ print("\r%d / %d" % (mb_begin, num_items), end="")
440
+
441
+ mb_end = min(mb_begin + minibatch_size, num_items)
442
+ mb_num = mb_end - mb_begin
443
+ mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
444
+ mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))
445
+
446
+ for dst, src in zip(out_arrays, mb_out):
447
+ dst[mb_begin: mb_end] = src
448
+
449
+ # Done.
450
+ if print_progress:
451
+ print("\r%d / %d" % (num_items, num_items))
452
+
453
+ if not return_as_list:
454
+ out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
455
+ return out_arrays
456
+
457
+ def list_ops(self) -> List[TfExpression]:
458
+ include_prefix = self.scope + "/"
459
+ exclude_prefix = include_prefix + "_"
460
+ ops = tf.get_default_graph().get_operations()
461
+ ops = [op for op in ops if op.name.startswith(include_prefix)]
462
+ ops = [op for op in ops if not op.name.startswith(exclude_prefix)]
463
+ return ops
464
+
465
+ def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
466
+ """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
467
+ individual layers of the network. Mainly intended to be used for reporting."""
468
+ layers = []
469
+
470
+ def recurse(scope, parent_ops, parent_vars, level):
471
+ # Ignore specific patterns.
472
+ if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
473
+ return
474
+
475
+ # Filter ops and vars by scope.
476
+ global_prefix = scope + "/"
477
+ local_prefix = global_prefix[len(self.scope) + 1:]
478
+ cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
479
+ cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
480
+ if not cur_ops and not cur_vars:
481
+ return
482
+
483
+ # Filter out all ops related to variables.
484
+ for var in [op for op in cur_ops if op.type.startswith("Variable")]:
485
+ var_prefix = var.name + "/"
486
+ cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]
487
+
488
+ # Scope does not contain ops as immediate children => recurse deeper.
489
+ contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type not in ["Identity", "Cast", "Transpose"] for op in cur_ops)
490
+ if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1:
491
+ visited = set()
492
+ for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
493
+ token = rel_name.split("/")[0]
494
+ if token not in visited:
495
+ recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
496
+ visited.add(token)
497
+ return
498
+
499
+ # Report layer.
500
+ layer_name = scope[len(self.scope) + 1:]
501
+ layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
502
+ layer_trainables = [var for _name, var in cur_vars if var.trainable]
503
+ layers.append((layer_name, layer_output, layer_trainables))
504
+
505
+ recurse(self.scope, self.list_ops(), list(self.vars.items()), 0)
506
+ return layers
507
+
508
+ def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
509
+ """Print a summary table of the network structure."""
510
+ rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]]
511
+ rows += [["---"] * 4]
512
+ total_params = 0
513
+
514
+ for layer_name, layer_output, layer_trainables in self.list_layers():
515
+ num_params = sum(int(np.prod(var.shape.as_list())) for var in layer_trainables)
516
+ weights = [var for var in layer_trainables if var.name.endswith("/weight:0")]
517
+ weights.sort(key=lambda x: len(x.name))
518
+ if len(weights) == 0 and len(layer_trainables) == 1:
519
+ weights = layer_trainables
520
+ total_params += num_params
521
+
522
+ if not hide_layers_with_no_params or num_params != 0:
523
+ num_params_str = str(num_params) if num_params > 0 else "-"
524
+ output_shape_str = str(layer_output.shape)
525
+ weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-"
526
+ rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]
527
+
528
+ rows += [["---"] * 4]
529
+ rows += [["Total", str(total_params), "", ""]]
530
+
531
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
532
+ print()
533
+ for row in rows:
534
+ print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
535
+ print()
536
+
537
+ def setup_weight_histograms(self, title: str = None) -> None:
538
+ """Construct summary ops to include histograms of all trainable parameters in TensorBoard."""
539
+ if title is None:
540
+ title = self.name
541
+
542
+ with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
543
+ for local_name, var in self.trainables.items():
544
+ if "/" in local_name:
545
+ p = local_name.split("/")
546
+ name = title + "_" + p[-1] + "/" + "_".join(p[:-1])
547
+ else:
548
+ name = title + "_toplevel/" + local_name
549
+
550
+ tf.summary.histogram(name, var)
551
+
552
+ #----------------------------------------------------------------------------
553
+ # Backwards-compatible emulation of legacy output transformation in Network.run().
554
+
555
+ _print_legacy_warning = True
556
+
557
+ def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
558
+ global _print_legacy_warning
559
+ legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
560
+ if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
561
+ return output_transform, dynamic_kwargs
562
+
563
+ if _print_legacy_warning:
564
+ _print_legacy_warning = False
565
+ print()
566
+ print("WARNING: Old-style output transformations in Network.run() are deprecated.")
567
+ print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'")
568
+ print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.")
569
+ print()
570
+ assert output_transform is None
571
+
572
+ new_kwargs = dict(dynamic_kwargs)
573
+ new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}
574
+ new_transform["func"] = _legacy_output_transform_func
575
+ return new_transform, new_kwargs
576
+
577
+ def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
578
+ if out_mul != 1.0:
579
+ expr = [x * out_mul for x in expr]
580
+
581
+ if out_add != 0.0:
582
+ expr = [x + out_add for x in expr]
583
+
584
+ if out_shrink > 1:
585
+ ksize = [1, 1, out_shrink, out_shrink]
586
+ expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr]
587
+
588
+ if out_dtype is not None:
589
+ if tf.as_dtype(out_dtype).is_integer:
590
+ expr = [tf.round(x) for x in expr]
591
+ expr = [tf.saturate_cast(x, out_dtype) for x in expr]
592
+ return expr
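For orientation, a minimal sketch of the Network wrapper above, assuming TensorFlow 1.x, an initialized tflib session, and a GPU visible to TensorFlow. The toy build function and its 'scale' kwarg are hypothetical stand-ins for the StyleGAN2 build functions this commit targets; the build function must live at module scope so that it can be pickled.

import numpy as np
import tensorflow as tf
import dnnlib.tflib as tflib

def toy_net(x, scale=2.0, is_template_graph=False, components=None, **_kwargs):
    # Build functions receive is_template_graph/components plus the static kwargs.
    x.set_shape([None, 4])                     # input shapes must be defined (see the _init_graph checks)
    return tf.identity(x * scale, name='out')

tflib.init_tf()
net = tflib.Network('ToyNet', func_name=toy_net, scale=3.0)   # top-level function objects are accepted
net.print_layers()
out = net.run(np.random.randn(8, 4).astype(np.float32), minibatch_size=4)
print(out.shape)   # (8, 4)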
dnnlib/tflib/ops/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ #
5
+ # This work is made available under the Nvidia Source Code License-NC.
6
+ # To view a copy of this license, visit
7
+ # https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ # empty
dnnlib/tflib/ops/fused_bias_act.cu ADDED
@@ -0,0 +1,190 @@
1
+ // Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ //
5
+ // This work is made available under the Nvidia Source Code License-NC.
6
+ // To view a copy of this license, visit
7
+ // https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ #define EIGEN_USE_GPU
10
+ #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
11
+ #include "tensorflow/core/framework/op.h"
12
+ #include "tensorflow/core/framework/op_kernel.h"
13
+ #include "tensorflow/core/framework/shape_inference.h"
14
+ #include <stdio.h>
15
+
16
+ using namespace tensorflow;
17
+ using namespace tensorflow::shape_inference;
18
+
19
+ #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
20
+
21
+ //------------------------------------------------------------------------
22
+ // CUDA kernel.
23
+
24
+ template <class T>
25
+ struct FusedBiasActKernelParams
26
+ {
27
+ const T* x; // [sizeX]
28
+ const T* b; // [sizeB] or NULL
29
+ const T* ref; // [sizeX] or NULL
30
+ T* y; // [sizeX]
31
+
32
+ int grad;
33
+ int axis;
34
+ int act;
35
+ float alpha;
36
+ float gain;
37
+
38
+ int sizeX;
39
+ int sizeB;
40
+ int stepB;
41
+ int loopX;
42
+ };
43
+
44
+ template <class T>
45
+ static __global__ void FusedBiasActKernel(const FusedBiasActKernelParams<T> p)
46
+ {
47
+ const float expRange = 80.0f;
48
+ const float halfExpRange = 40.0f;
49
+ const float seluScale = 1.0507009873554804934193349852946f;
50
+ const float seluAlpha = 1.6732632423543772848170429916717f;
51
+
52
+ // Loop over elements.
53
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
54
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
55
+ {
56
+ // Load and apply bias.
57
+ float x = (float)p.x[xi];
58
+ if (p.b)
59
+ x += (float)p.b[(xi / p.stepB) % p.sizeB];
60
+ float ref = (p.ref) ? (float)p.ref[xi] : 0.0f;
61
+ if (p.gain != 0.0f & p.act != 9)
62
+ ref /= p.gain;
63
+
64
+ // Evaluate activation func.
65
+ float y;
66
+ switch (p.act * 10 + p.grad)
67
+ {
68
+ // linear
69
+ default:
70
+ case 10: y = x; break;
71
+ case 11: y = x; break;
72
+ case 12: y = 0.0f; break;
73
+
74
+ // relu
75
+ case 20: y = (x > 0.0f) ? x : 0.0f; break;
76
+ case 21: y = (ref > 0.0f) ? x : 0.0f; break;
77
+ case 22: y = 0.0f; break;
78
+
79
+ // lrelu
80
+ case 30: y = (x > 0.0f) ? x : x * p.alpha; break;
81
+ case 31: y = (ref > 0.0f) ? x : x * p.alpha; break;
82
+ case 32: y = 0.0f; break;
83
+
84
+ // tanh
85
+ case 40: { float c = expf(x); float d = 1.0f / c; y = (x < -expRange) ? -1.0f : (x > expRange) ? 1.0f : (c - d) / (c + d); } break;
86
+ case 41: y = x * (1.0f - ref * ref); break;
87
+ case 42: y = x * (1.0f - ref * ref) * (-2.0f * ref); break;
88
+
89
+ // sigmoid
90
+ case 50: y = (x < -expRange) ? 0.0f : 1.0f / (expf(-x) + 1.0f); break;
91
+ case 51: y = x * ref * (1.0f - ref); break;
92
+ case 52: y = x * ref * (1.0f - ref) * (1.0f - 2.0f * ref); break;
93
+
94
+ // elu
95
+ case 60: y = (x >= 0.0f) ? x : expf(x) - 1.0f; break;
96
+ case 61: y = (ref >= 0.0f) ? x : x * (ref + 1.0f); break;
97
+ case 62: y = (ref >= 0.0f) ? 0.0f : x * (ref + 1.0f); break;
98
+
99
+ // selu
100
+ case 70: y = (x >= 0.0f) ? seluScale * x : (seluScale * seluAlpha) * (expf(x) - 1.0f); break;
101
+ case 71: y = (ref >= 0.0f) ? x * seluScale : x * (ref + seluScale * seluAlpha); break;
102
+ case 72: y = (ref >= 0.0f) ? 0.0f : x * (ref + seluScale * seluAlpha); break;
103
+
104
+ // softplus
105
+ case 80: y = (x > expRange) ? x : logf(expf(x) + 1.0f); break;
106
+ case 81: y = x * (1.0f - expf(-ref)); break;
107
+ case 82: { float c = expf(-ref); y = x * c * (1.0f - c); } break;
108
+
109
+ // swish
110
+ case 90: y = (x < -expRange) ? 0.0f : x / (expf(-x) + 1.0f); break;
111
+ case 91: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? x : x * c * (ref + d) / (d * d); } break;
112
+ case 92: { float c = expf(ref); float d = c + 1.0f; y = (ref > halfExpRange) ? 0.0f : x * c * (ref * (2.0f - d) + 2.0f * d) / (d * d * d); } break;
113
+ }
114
+
115
+ // Apply gain and store.
116
+ p.y[xi] = (T)(y * p.gain);
117
+ }
118
+ }
119
+
120
+ //------------------------------------------------------------------------
121
+ // TensorFlow op.
122
+
123
+ template <class T>
124
+ struct FusedBiasActOp : public OpKernel
125
+ {
126
+ FusedBiasActKernelParams<T> m_attribs;
127
+
128
+ FusedBiasActOp(OpKernelConstruction* ctx) : OpKernel(ctx)
129
+ {
130
+ memset(&m_attribs, 0, sizeof(m_attribs));
131
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("grad", &m_attribs.grad));
132
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &m_attribs.axis));
133
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("act", &m_attribs.act));
134
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &m_attribs.alpha));
135
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("gain", &m_attribs.gain));
136
+ OP_REQUIRES(ctx, m_attribs.grad >= 0, errors::InvalidArgument("grad must be non-negative"));
137
+ OP_REQUIRES(ctx, m_attribs.axis >= 0, errors::InvalidArgument("axis must be non-negative"));
138
+ OP_REQUIRES(ctx, m_attribs.act >= 0, errors::InvalidArgument("act must be non-negative"));
139
+ }
140
+
141
+ void Compute(OpKernelContext* ctx)
142
+ {
143
+ FusedBiasActKernelParams<T> p = m_attribs;
144
+ cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
145
+
146
+ const Tensor& x = ctx->input(0); // [...]
147
+ const Tensor& b = ctx->input(1); // [sizeB] or [0]
148
+ const Tensor& ref = ctx->input(2); // x.shape or [0]
149
+ p.x = x.flat<T>().data();
150
+ p.b = (b.NumElements()) ? b.flat<T>().data() : NULL;
151
+ p.ref = (ref.NumElements()) ? ref.flat<T>().data() : NULL;
152
+ OP_REQUIRES(ctx, b.NumElements() == 0 || m_attribs.axis < x.dims(), errors::InvalidArgument("axis out of bounds"));
153
+ OP_REQUIRES(ctx, b.dims() == 1, errors::InvalidArgument("b must have rank 1"));
154
+ OP_REQUIRES(ctx, b.NumElements() == 0 || b.NumElements() == x.dim_size(m_attribs.axis), errors::InvalidArgument("b has wrong number of elements"));
155
+ OP_REQUIRES(ctx, ref.NumElements() == ((p.grad == 0) ? 0 : x.NumElements()), errors::InvalidArgument("ref has wrong number of elements"));
156
+ OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("x is too large"));
157
+
158
+ p.sizeX = (int)x.NumElements();
159
+ p.sizeB = (int)b.NumElements();
160
+ p.stepB = 1;
161
+ for (int i = m_attribs.axis + 1; i < x.dims(); i++)
162
+ p.stepB *= (int)x.dim_size(i);
163
+
164
+ Tensor* y = NULL; // x.shape
165
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y));
166
+ p.y = y->flat<T>().data();
167
+
168
+ p.loopX = 4;
169
+ int blockSize = 4 * 32;
170
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
171
+ void* args[] = {&p};
172
+ OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel((void*)FusedBiasActKernel<T>, gridSize, blockSize, args, 0, stream));
173
+ }
174
+ };
175
+
176
+ REGISTER_OP("FusedBiasAct")
177
+ .Input ("x: T")
178
+ .Input ("b: T")
179
+ .Input ("ref: T")
180
+ .Output ("y: T")
181
+ .Attr ("T: {float, half}")
182
+ .Attr ("grad: int = 0")
183
+ .Attr ("axis: int = 1")
184
+ .Attr ("act: int = 0")
185
+ .Attr ("alpha: float = 0.0")
186
+ .Attr ("gain: float = 1.0");
187
+ REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<float>("T"), FusedBiasActOp<float>);
188
+ REGISTER_KERNEL_BUILDER(Name("FusedBiasAct").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), FusedBiasActOp<Eigen::half>);
189
+
190
+ //------------------------------------------------------------------------
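For reference, the softplus branches above (case 80 for the forward pass, case 81 for the first gradient) rely on the identity d/dx softplus(x) = sigmoid(x) = 1 - exp(-y) with y = softplus(x), which is why the gradient kernel only needs `ref = y`. A minimal NumPy sketch of that identity (values and helper names are illustrative, ignoring the large-x overflow shortcut):

import numpy as np

def softplus(x):
    # Case 80 without the expRange shortcut: y = log(exp(x) + 1).
    return np.log1p(np.exp(x))

def softplus_grad_from_y(dy, y):
    # Case 81: gradient expressed through ref = y as dy * (1 - exp(-y)).
    return dy * (1.0 - np.exp(-y))

x = np.linspace(-4.0, 4.0, 9)
dy = np.ones_like(x)
y = softplus(x)
sigmoid = 1.0 / (1.0 + np.exp(-x))
# 1 - exp(-softplus(x)) equals sigmoid(x), the true derivative of softplus.
assert np.allclose(softplus_grad_from_y(dy, y), dy * sigmoid)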
dnnlib/tflib/ops/fused_bias_act.py ADDED
@@ -0,0 +1,198 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ #
5
+ # This work is made available under the Nvidia Source Code License-NC.
6
+ # To view a copy of this license, visit
7
+ # https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ """Custom TensorFlow ops for efficient bias and activation."""
10
+
11
+ import os
12
+ import numpy as np
13
+ import tensorflow as tf
14
+ from .. import custom_ops
15
+ from ...util import EasyDict
16
+
17
+ def _get_plugin():
18
+ return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ activation_funcs = {
23
+ 'linear': EasyDict(func=lambda x, **_: x, def_alpha=None, def_gain=1.0, cuda_idx=1, ref='y', zero_2nd_grad=True),
24
+ 'relu': EasyDict(func=lambda x, **_: tf.nn.relu(x), def_alpha=None, def_gain=np.sqrt(2), cuda_idx=2, ref='y', zero_2nd_grad=True),
25
+ 'lrelu': EasyDict(func=lambda x, alpha, **_: tf.nn.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', zero_2nd_grad=True),
26
+ 'tanh': EasyDict(func=lambda x, **_: tf.nn.tanh(x), def_alpha=None, def_gain=1.0, cuda_idx=4, ref='y', zero_2nd_grad=False),
27
+ 'sigmoid': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x), def_alpha=None, def_gain=1.0, cuda_idx=5, ref='y', zero_2nd_grad=False),
28
+ 'elu': EasyDict(func=lambda x, **_: tf.nn.elu(x), def_alpha=None, def_gain=1.0, cuda_idx=6, ref='y', zero_2nd_grad=False),
29
+ 'selu': EasyDict(func=lambda x, **_: tf.nn.selu(x), def_alpha=None, def_gain=1.0, cuda_idx=7, ref='y', zero_2nd_grad=False),
30
+ 'softplus': EasyDict(func=lambda x, **_: tf.nn.softplus(x), def_alpha=None, def_gain=1.0, cuda_idx=8, ref='y', zero_2nd_grad=False),
31
+ 'swish': EasyDict(func=lambda x, **_: tf.nn.sigmoid(x) * x, def_alpha=None, def_gain=np.sqrt(2), cuda_idx=9, ref='x', zero_2nd_grad=False),
32
+ }
33
+
34
+ #----------------------------------------------------------------------------
35
+
36
+ def fused_bias_act(x, b=None, axis=1, act='linear', alpha=None, gain=None, impl='cuda'):
37
+ r"""Fused bias and activation function.
38
+
39
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
40
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
41
+ the fused op is considerably more efficient than performing the same calculation
42
+ using standard TensorFlow ops. It supports first and second order gradients,
43
+ but not third order gradients.
44
+
45
+ Args:
46
+ x: Input activation tensor. Can have any shape, but if `b` is defined, the
47
+ dimension corresponding to `axis`, as well as the rank, must be known.
48
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
49
+ as `x`. The shape must be known, and it must match the dimension of `x`
50
+ corresponding to `axis`.
51
+ axis: The dimension in `x` corresponding to the elements of `b`.
52
+ The value of `axis` is ignored if `b` is not specified.
53
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
54
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
55
+ See `activation_funcs` for a full list. `None` is not allowed.
56
+ alpha: Shape parameter for the activation function, or `None` to use the default.
57
+ gain: Scaling factor for the output tensor, or `None` to use default.
58
+ See `activation_funcs` for the default scaling of each activation function.
59
+ If unsure, consider specifying `1.0`.
60
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
61
+
62
+ Returns:
63
+ Tensor of the same shape and datatype as `x`.
64
+ """
65
+
66
+ impl_dict = {
67
+ 'ref': _fused_bias_act_ref,
68
+ 'cuda': _fused_bias_act_cuda,
69
+ }
70
+ return impl_dict[impl](x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain)
71
+
72
+ #----------------------------------------------------------------------------
73
+
74
+ def _fused_bias_act_ref(x, b, axis, act, alpha, gain):
75
+ """Slow reference implementation of `fused_bias_act()` using standard TensorFlow ops."""
76
+
77
+ # Validate arguments.
78
+ x = tf.convert_to_tensor(x)
79
+ b = tf.convert_to_tensor(b) if b is not None else tf.constant([], dtype=x.dtype)
80
+ act_spec = activation_funcs[act]
81
+ assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
82
+ assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
83
+ if alpha is None:
84
+ alpha = act_spec.def_alpha
85
+ if gain is None:
86
+ gain = act_spec.def_gain
87
+
88
+ # Add bias.
89
+ if b.shape[0] != 0:
90
+ x += tf.reshape(b, [-1 if i == axis else 1 for i in range(x.shape.rank)])
91
+
92
+ # Evaluate activation function.
93
+ x = act_spec.func(x, alpha=alpha)
94
+
95
+ # Scale by gain.
96
+ if gain != 1:
97
+ x *= gain
98
+ return x
99
+
100
+ #----------------------------------------------------------------------------
101
+
102
+ def _fused_bias_act_cuda(x, b, axis, act, alpha, gain):
103
+ """Fast CUDA implementation of `fused_bias_act()` using custom ops."""
104
+
105
+ # Validate arguments.
106
+ x = tf.convert_to_tensor(x)
107
+ empty_tensor = tf.constant([], dtype=x.dtype)
108
+ b = tf.convert_to_tensor(b) if b is not None else empty_tensor
109
+ act_spec = activation_funcs[act]
110
+ assert b.shape.rank == 1 and (b.shape[0] == 0 or b.shape[0] == x.shape[axis])
111
+ assert b.shape[0] == 0 or 0 <= axis < x.shape.rank
112
+ if alpha is None:
113
+ alpha = act_spec.def_alpha
114
+ if gain is None:
115
+ gain = act_spec.def_gain
116
+
117
+ # Special cases.
118
+ if act == 'linear' and b is None and gain == 1.0:
119
+ return x
120
+ if act_spec.cuda_idx is None:
121
+ return _fused_bias_act_ref(x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain)
122
+
123
+ # CUDA kernel.
124
+ cuda_kernel = _get_plugin().fused_bias_act
125
+ cuda_kwargs = dict(axis=axis, act=act_spec.cuda_idx, alpha=alpha, gain=gain)
126
+
127
+ # Forward pass: y = func(x, b).
128
+ def func_y(x, b):
129
+ y = cuda_kernel(x=x, b=b, ref=empty_tensor, grad=0, **cuda_kwargs)
130
+ y.set_shape(x.shape)
131
+ return y
132
+
133
+ # Backward pass: dx, db = grad(dy, x, y)
134
+ def grad_dx(dy, x, y):
135
+ ref = {'x': x, 'y': y}[act_spec.ref]
136
+ dx = cuda_kernel(x=dy, b=empty_tensor, ref=ref, grad=1, **cuda_kwargs)
137
+ dx.set_shape(x.shape)
138
+ return dx
139
+ def grad_db(dx):
140
+ if b.shape[0] == 0:
141
+ return empty_tensor
142
+ db = dx
143
+ if axis < x.shape.rank - 1:
144
+ db = tf.reduce_sum(db, list(range(axis + 1, x.shape.rank)))
145
+ if axis > 0:
146
+ db = tf.reduce_sum(db, list(range(axis)))
147
+ db.set_shape(b.shape)
148
+ return db
149
+
150
+ # Second order gradients: d_dy, d_x = grad2(d_dx, d_db, x, y)
151
+ def grad2_d_dy(d_dx, d_db, x, y):
152
+ ref = {'x': x, 'y': y}[act_spec.ref]
153
+ d_dy = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=1, **cuda_kwargs)
154
+ d_dy.set_shape(x.shape)
155
+ return d_dy
156
+ def grad2_d_x(d_dx, d_db, x, y):
157
+ ref = {'x': x, 'y': y}[act_spec.ref]
158
+ d_x = cuda_kernel(x=d_dx, b=d_db, ref=ref, grad=2, **cuda_kwargs)
159
+ d_x.set_shape(x.shape)
160
+ return d_x
161
+
162
+ # Fast version for piecewise-linear activation funcs.
163
+ @tf.custom_gradient
164
+ def func_zero_2nd_grad(x, b):
165
+ y = func_y(x, b)
166
+ @tf.custom_gradient
167
+ def grad(dy):
168
+ dx = grad_dx(dy, x, y)
169
+ db = grad_db(dx)
170
+ def grad2(d_dx, d_db):
171
+ d_dy = grad2_d_dy(d_dx, d_db, x, y)
172
+ return d_dy
173
+ return (dx, db), grad2
174
+ return y, grad
175
+
176
+ # Slow version for general activation funcs.
177
+ @tf.custom_gradient
178
+ def func_nonzero_2nd_grad(x, b):
179
+ y = func_y(x, b)
180
+ def grad_wrap(dy):
181
+ @tf.custom_gradient
182
+ def grad_impl(dy, x):
183
+ dx = grad_dx(dy, x, y)
184
+ db = grad_db(dx)
185
+ def grad2(d_dx, d_db):
186
+ d_dy = grad2_d_dy(d_dx, d_db, x, y)
187
+ d_x = grad2_d_x(d_dx, d_db, x, y)
188
+ return d_dy, d_x
189
+ return (dx, db), grad2
190
+ return grad_impl(dy, x)
191
+ return y, grad_wrap
192
+
193
+ # Which version to use?
194
+ if act_spec.zero_2nd_grad:
195
+ return func_zero_2nd_grad(x, b)
196
+ return func_nonzero_2nd_grad(x, b)
197
+
198
+ #----------------------------------------------------------------------------
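A minimal usage sketch for `fused_bias_act()` via the portable reference path (`impl='ref'`), assuming a TensorFlow 1.x graph-mode setup; the shapes and variable names are illustrative:

import tensorflow as tf
from dnnlib.tflib.ops.fused_bias_act import fused_bias_act

# [N, C, H, W] activations with a per-channel bias on axis=1.
x = tf.placeholder(tf.float32, [None, 64, 32, 32])
b = tf.get_variable('bias', shape=[64], initializer=tf.zeros_initializer())

# Leaky ReLU with its default gain sqrt(2); impl='cuda' would additionally
# require the compiled fused_bias_act.cu plugin above.
y = fused_bias_act(x, b=b, axis=1, act='lrelu', alpha=0.2, impl='ref')

Both implementations return a tensor with the same shape and dtype as `x`, so `impl` can be switched without touching the surrounding graph.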
dnnlib/tflib/ops/upfirdn_2d.cu ADDED
@@ -0,0 +1,328 @@
1
+ // Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ //
5
+ // This work is made available under the Nvidia Source Code License-NC.
6
+ // To view a copy of this license, visit
7
+ // https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ #define EIGEN_USE_GPU
10
+ #define __CUDA_INCLUDE_COMPILER_INTERNAL_HEADERS__
11
+ #include "tensorflow/core/framework/op.h"
12
+ #include "tensorflow/core/framework/op_kernel.h"
13
+ #include "tensorflow/core/framework/shape_inference.h"
14
+ #include <stdio.h>
15
+
16
+ using namespace tensorflow;
17
+ using namespace tensorflow::shape_inference;
18
+
19
+ //------------------------------------------------------------------------
20
+ // Helpers.
21
+
22
+ #define OP_CHECK_CUDA_ERROR(CTX, CUDA_CALL) do { cudaError_t err = CUDA_CALL; OP_REQUIRES(CTX, err == cudaSuccess, errors::Internal(cudaGetErrorName(err))); } while (false)
23
+
24
+ static __host__ __device__ __forceinline__ int floorDiv(int a, int b)
25
+ {
26
+ int c = a / b;
27
+ if (c * b > a)
28
+ c--;
29
+ return c;
30
+ }
31
+
32
+ //------------------------------------------------------------------------
33
+ // CUDA kernel params.
34
+
35
+ template <class T>
36
+ struct UpFirDn2DKernelParams
37
+ {
38
+ const T* x; // [majorDim, inH, inW, minorDim]
39
+ const T* k; // [kernelH, kernelW]
40
+ T* y; // [majorDim, outH, outW, minorDim]
41
+
42
+ int upx;
43
+ int upy;
44
+ int downx;
45
+ int downy;
46
+ int padx0;
47
+ int padx1;
48
+ int pady0;
49
+ int pady1;
50
+
51
+ int majorDim;
52
+ int inH;
53
+ int inW;
54
+ int minorDim;
55
+ int kernelH;
56
+ int kernelW;
57
+ int outH;
58
+ int outW;
59
+ int loopMajor;
60
+ int loopX;
61
+ };
62
+
63
+ //------------------------------------------------------------------------
64
+ // General CUDA implementation for large filter kernels.
65
+
66
+ template <class T>
67
+ static __global__ void UpFirDn2DKernel_large(const UpFirDn2DKernelParams<T> p)
68
+ {
69
+ // Calculate thread index.
70
+ int minorIdx = blockIdx.x * blockDim.x + threadIdx.x;
71
+ int outY = minorIdx / p.minorDim;
72
+ minorIdx -= outY * p.minorDim;
73
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
74
+ int majorIdxBase = blockIdx.z * p.loopMajor;
75
+ if (outXBase >= p.outW || outY >= p.outH || majorIdxBase >= p.majorDim)
76
+ return;
77
+
78
+ // Setup Y receptive field.
79
+ int midY = outY * p.downy + p.upy - 1 - p.pady0;
80
+ int inY = min(max(floorDiv(midY, p.upy), 0), p.inH);
81
+ int h = min(max(floorDiv(midY + p.kernelH, p.upy), 0), p.inH) - inY;
82
+ int kernelY = midY + p.kernelH - (inY + 1) * p.upy;
83
+
84
+ // Loop over majorDim and outX.
85
+ for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor && majorIdx < p.majorDim; loopMajor++, majorIdx++)
86
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX && outX < p.outW; loopX++, outX += blockDim.y)
87
+ {
88
+ // Setup X receptive field.
89
+ int midX = outX * p.downx + p.upx - 1 - p.padx0;
90
+ int inX = min(max(floorDiv(midX, p.upx), 0), p.inW);
91
+ int w = min(max(floorDiv(midX + p.kernelW, p.upx), 0), p.inW) - inX;
92
+ int kernelX = midX + p.kernelW - (inX + 1) * p.upx;
93
+
94
+ // Initialize pointers.
95
+ const T* xp = &p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
96
+ const T* kp = &p.k[kernelY * p.kernelW + kernelX];
97
+ int xpx = p.minorDim;
98
+ int kpx = -p.upx;
99
+ int xpy = p.inW * p.minorDim;
100
+ int kpy = -p.upy * p.kernelW;
101
+
102
+ // Inner loop.
103
+ float v = 0.0f;
104
+ for (int y = 0; y < h; y++)
105
+ {
106
+ for (int x = 0; x < w; x++)
107
+ {
108
+ v += (float)(*xp) * (float)(*kp);
109
+ xp += xpx;
110
+ kp += kpx;
111
+ }
112
+ xp += xpy - w * xpx;
113
+ kp += kpy - w * kpx;
114
+ }
115
+
116
+ // Store result.
117
+ p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
118
+ }
119
+ }
120
+
121
+ //------------------------------------------------------------------------
122
+ // Specialized CUDA implementation for small filter kernels.
123
+
124
+ template <class T, int upx, int upy, int downx, int downy, int kernelW, int kernelH, int tileOutW, int tileOutH>
125
+ static __global__ void UpFirDn2DKernel_small(const UpFirDn2DKernelParams<T> p)
126
+ {
127
+ //assert(kernelW % upx == 0);
128
+ //assert(kernelH % upy == 0);
129
+ const int tileInW = ((tileOutW - 1) * downx + kernelW - 1) / upx + 1;
130
+ const int tileInH = ((tileOutH - 1) * downy + kernelH - 1) / upy + 1;
131
+ __shared__ volatile float sk[kernelH][kernelW];
132
+ __shared__ volatile float sx[tileInH][tileInW];
133
+
134
+ // Calculate tile index.
135
+ int minorIdx = blockIdx.x;
136
+ int tileOutY = minorIdx / p.minorDim;
137
+ minorIdx -= tileOutY * p.minorDim;
138
+ tileOutY *= tileOutH;
139
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
140
+ int majorIdxBase = blockIdx.z * p.loopMajor;
141
+ if (tileOutXBase >= p.outW | tileOutY >= p.outH | majorIdxBase >= p.majorDim)
142
+ return;
143
+
144
+ // Load filter kernel (flipped).
145
+ for (int tapIdx = threadIdx.x; tapIdx < kernelH * kernelW; tapIdx += blockDim.x)
146
+ {
147
+ int ky = tapIdx / kernelW;
148
+ int kx = tapIdx - ky * kernelW;
149
+ float v = 0.0f;
150
+ if (kx < p.kernelW & ky < p.kernelH)
151
+ v = (float)p.k[(p.kernelH - 1 - ky) * p.kernelW + (p.kernelW - 1 - kx)];
152
+ sk[ky][kx] = v;
153
+ }
154
+
155
+ // Loop over majorDim and outX.
156
+ for (int loopMajor = 0, majorIdx = majorIdxBase; loopMajor < p.loopMajor & majorIdx < p.majorDim; loopMajor++, majorIdx++)
157
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outW; loopX++, tileOutX += tileOutW)
158
+ {
159
+ // Load input pixels.
160
+ int tileMidX = tileOutX * downx + upx - 1 - p.padx0;
161
+ int tileMidY = tileOutY * downy + upy - 1 - p.pady0;
162
+ int tileInX = floorDiv(tileMidX, upx);
163
+ int tileInY = floorDiv(tileMidY, upy);
164
+ __syncthreads();
165
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW; inIdx += blockDim.x)
166
+ {
167
+ int relInY = inIdx / tileInW;
168
+ int relInX = inIdx - relInY * tileInW;
169
+ int inX = relInX + tileInX;
170
+ int inY = relInY + tileInY;
171
+ float v = 0.0f;
172
+ if (inX >= 0 & inY >= 0 & inX < p.inW & inY < p.inH)
173
+ v = (float)p.x[((majorIdx * p.inH + inY) * p.inW + inX) * p.minorDim + minorIdx];
174
+ sx[relInY][relInX] = v;
175
+ }
176
+
177
+ // Loop over output pixels.
178
+ __syncthreads();
179
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW; outIdx += blockDim.x)
180
+ {
181
+ int relOutY = outIdx / tileOutW;
182
+ int relOutX = outIdx - relOutY * tileOutW;
183
+ int outX = relOutX + tileOutX;
184
+ int outY = relOutY + tileOutY;
185
+
186
+ // Setup receptive field.
187
+ int midX = tileMidX + relOutX * downx;
188
+ int midY = tileMidY + relOutY * downy;
189
+ int inX = floorDiv(midX, upx);
190
+ int inY = floorDiv(midY, upy);
191
+ int relInX = inX - tileInX;
192
+ int relInY = inY - tileInY;
193
+ int kernelX = (inX + 1) * upx - midX - 1; // flipped
194
+ int kernelY = (inY + 1) * upy - midY - 1; // flipped
195
+
196
+ // Inner loop.
197
+ float v = 0.0f;
198
+ #pragma unroll
199
+ for (int y = 0; y < kernelH / upy; y++)
200
+ #pragma unroll
201
+ for (int x = 0; x < kernelW / upx; x++)
202
+ v += sx[relInY + y][relInX + x] * sk[kernelY + y * upy][kernelX + x * upx];
203
+
204
+ // Store result.
205
+ if (outX < p.outW & outY < p.outH)
206
+ p.y[((majorIdx * p.outH + outY) * p.outW + outX) * p.minorDim + minorIdx] = (T)v;
207
+ }
208
+ }
209
+ }
210
+
211
+ //------------------------------------------------------------------------
212
+ // TensorFlow op.
213
+
214
+ template <class T>
215
+ struct UpFirDn2DOp : public OpKernel
216
+ {
217
+ UpFirDn2DKernelParams<T> m_attribs;
218
+
219
+ UpFirDn2DOp(OpKernelConstruction* ctx) : OpKernel(ctx)
220
+ {
221
+ memset(&m_attribs, 0, sizeof(m_attribs));
222
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("upx", &m_attribs.upx));
223
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("upy", &m_attribs.upy));
224
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("downx", &m_attribs.downx));
225
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("downy", &m_attribs.downy));
226
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("padx0", &m_attribs.padx0));
227
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("padx1", &m_attribs.padx1));
228
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("pady0", &m_attribs.pady0));
229
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("pady1", &m_attribs.pady1));
230
+ OP_REQUIRES(ctx, m_attribs.upx >= 1 && m_attribs.upy >= 1, errors::InvalidArgument("upx and upy must be at least 1x1"));
231
+ OP_REQUIRES(ctx, m_attribs.downx >= 1 && m_attribs.downy >= 1, errors::InvalidArgument("downx and downy must be at least 1x1"));
232
+ }
233
+
234
+ void Compute(OpKernelContext* ctx)
235
+ {
236
+ UpFirDn2DKernelParams<T> p = m_attribs;
237
+ cudaStream_t stream = ctx->eigen_device<Eigen::GpuDevice>().stream();
238
+
239
+ const Tensor& x = ctx->input(0); // [majorDim, inH, inW, minorDim]
240
+ const Tensor& k = ctx->input(1); // [kernelH, kernelW]
241
+ p.x = x.flat<T>().data();
242
+ p.k = k.flat<T>().data();
243
+ OP_REQUIRES(ctx, x.dims() == 4, errors::InvalidArgument("input must have rank 4"));
244
+ OP_REQUIRES(ctx, k.dims() == 2, errors::InvalidArgument("kernel must have rank 2"));
245
+ OP_REQUIRES(ctx, x.NumElements() <= kint32max, errors::InvalidArgument("input too large"));
246
+ OP_REQUIRES(ctx, k.NumElements() <= kint32max, errors::InvalidArgument("kernel too large"));
247
+
248
+ p.majorDim = (int)x.dim_size(0);
249
+ p.inH = (int)x.dim_size(1);
250
+ p.inW = (int)x.dim_size(2);
251
+ p.minorDim = (int)x.dim_size(3);
252
+ p.kernelH = (int)k.dim_size(0);
253
+ p.kernelW = (int)k.dim_size(1);
254
+ OP_REQUIRES(ctx, p.kernelW >= 1 && p.kernelH >= 1, errors::InvalidArgument("kernel must be at least 1x1"));
255
+
256
+ p.outW = (p.inW * p.upx + p.padx0 + p.padx1 - p.kernelW + p.downx) / p.downx;
257
+ p.outH = (p.inH * p.upy + p.pady0 + p.pady1 - p.kernelH + p.downy) / p.downy;
258
+ OP_REQUIRES(ctx, p.outW >= 1 && p.outH >= 1, errors::InvalidArgument("output must be at least 1x1"));
259
+
260
+ Tensor* y = NULL; // [majorDim, outH, outW, minorDim]
261
+ TensorShape ys;
262
+ ys.AddDim(p.majorDim);
263
+ ys.AddDim(p.outH);
264
+ ys.AddDim(p.outW);
265
+ ys.AddDim(p.minorDim);
266
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, ys, &y));
267
+ p.y = y->flat<T>().data();
268
+ OP_REQUIRES(ctx, y->NumElements() <= kint32max, errors::InvalidArgument("output too large"));
269
+
270
+ // Choose CUDA kernel to use.
271
+ void* cudaKernel = (void*)UpFirDn2DKernel_large<T>;
272
+ int tileOutW = -1;
273
+ int tileOutH = -1;
274
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 7 && p.kernelH <= 7) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 7,7, 64,16>; tileOutW = 64; tileOutH = 16; }
275
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
276
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 5 && p.kernelH <= 5) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 5,5, 64,16>; tileOutW = 64; tileOutH = 16; }
277
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
278
+ if (p.upx == 1 && p.upy == 1 && p.downx == 1 && p.downy == 1 && p.kernelW <= 3 && p.kernelH <= 3) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 1,1, 3,3, 64,16>; tileOutW = 64; tileOutH = 16; }
279
+ if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 8,8, 64,16>; tileOutW = 64; tileOutH = 16; }
280
+ if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 6,6, 64,16>; tileOutW = 64; tileOutH = 16; }
281
+ if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 4,4, 64,16>; tileOutW = 64; tileOutH = 16; }
282
+ if (p.upx == 2 && p.upy == 2 && p.downx == 1 && p.downy == 1 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 2,2, 1,1, 2,2, 64,16>; tileOutW = 64; tileOutH = 16; }
283
+ if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 8 && p.kernelH <= 8) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 8,8, 32,8>; tileOutW = 32; tileOutH = 8; }
284
+ if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 6 && p.kernelH <= 6) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 6,6, 32,8>; tileOutW = 32; tileOutH = 8; }
285
+ if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 4 && p.kernelH <= 4) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 4,4, 32,8>; tileOutW = 32; tileOutH = 8; }
286
+ if (p.upx == 1 && p.upy == 1 && p.downx == 2 && p.downy == 2 && p.kernelW <= 2 && p.kernelH <= 2) { cudaKernel = (void*)UpFirDn2DKernel_small<T, 1,1, 2,2, 2,2, 32,8>; tileOutW = 32; tileOutH = 8; }
287
+
288
+ // Choose launch params.
289
+ dim3 blockSize;
290
+ dim3 gridSize;
291
+ if (tileOutW > 0 && tileOutH > 0) // small
292
+ {
293
+ p.loopMajor = (p.majorDim - 1) / 16384 + 1;
294
+ p.loopX = 1;
295
+ blockSize = dim3(32 * 8, 1, 1);
296
+ gridSize = dim3(((p.outH - 1) / tileOutH + 1) * p.minorDim, (p.outW - 1) / (p.loopX * tileOutW) + 1, (p.majorDim - 1) / p.loopMajor + 1);
297
+ }
298
+ else // large
299
+ {
300
+ p.loopMajor = (p.majorDim - 1) / 16384 + 1;
301
+ p.loopX = 4;
302
+ blockSize = dim3(4, 32, 1);
303
+ gridSize = dim3((p.outH * p.minorDim - 1) / blockSize.x + 1, (p.outW - 1) / (p.loopX * blockSize.y) + 1, (p.majorDim - 1) / p.loopMajor + 1);
304
+ }
305
+
306
+ // Launch CUDA kernel.
307
+ void* args[] = {&p};
308
+ OP_CHECK_CUDA_ERROR(ctx, cudaLaunchKernel(cudaKernel, gridSize, blockSize, args, 0, stream));
309
+ }
310
+ };
311
+
312
+ REGISTER_OP("UpFirDn2D")
313
+ .Input ("x: T")
314
+ .Input ("k: T")
315
+ .Output ("y: T")
316
+ .Attr ("T: {float, half}")
317
+ .Attr ("upx: int = 1")
318
+ .Attr ("upy: int = 1")
319
+ .Attr ("downx: int = 1")
320
+ .Attr ("downy: int = 1")
321
+ .Attr ("padx0: int = 0")
322
+ .Attr ("padx1: int = 0")
323
+ .Attr ("pady0: int = 0")
324
+ .Attr ("pady1: int = 0");
325
+ REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<float>("T"), UpFirDn2DOp<float>);
326
+ REGISTER_KERNEL_BUILDER(Name("UpFirDn2D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), UpFirDn2DOp<Eigen::half>);
327
+
328
+ //------------------------------------------------------------------------
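The output extent computed in `UpFirDn2DOp::Compute`, `outW = (inW*upx + padx0 + padx1 - kernelW + downx) / downx`, is simply the length left after zero-insertion upsampling, padding, 'valid' filtering, and decimation. A 1-D NumPy sketch of the same arithmetic (parameters and helper name are illustrative):

import numpy as np

def upfirdn_1d_ref(x, k, up=2, down=2, pad0=1, pad1=1):
    # Upsample by inserting up-1 zeros after each sample.
    xu = np.zeros(len(x) * up, dtype=np.float32)
    xu[::up] = x
    # Pad with zeros (a negative pad would crop instead).
    xp = np.pad(xu, (pad0, pad1))
    # FIR filter: convolution, i.e. 'valid' correlation with the flipped kernel.
    y = np.convolve(xp, k, mode='valid')
    # Downsample by keeping every down-th sample.
    return y[::down]

x = np.arange(16, dtype=np.float32)
k = np.array([1, 3, 3, 1], dtype=np.float32) / 8.0
y = upfirdn_1d_ref(x, k)
# Length matches the 1-D version of the formula used above:
# (inW*up + pad0 + pad1 - kernelW + down) // down.
assert len(y) == (len(x) * 2 + 1 + 1 - len(k) + 2) // 2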
dnnlib/tflib/ops/upfirdn_2d.py ADDED
@@ -0,0 +1,366 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ #
5
+ # This work is made available under the Nvidia Source Code License-NC.
6
+ # To view a copy of this license, visit
7
+ # https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ """Custom TensorFlow ops for efficient resampling of 2D images."""
10
+
11
+ import os
12
+ import numpy as np
13
+ import tensorflow as tf
14
+ from .. import custom_ops
15
+
16
+ def _get_plugin():
17
+ return custom_ops.get_plugin(os.path.splitext(__file__)[0] + '.cu')
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ def upfirdn_2d(x, k, upx=1, upy=1, downx=1, downy=1, padx0=0, padx1=0, pady0=0, pady1=0, impl='cuda'):
22
+ r"""Pad, upsample, FIR filter, and downsample a batch of 2D images.
23
+
24
+ Accepts a batch of 2D images of the shape `[majorDim, inH, inW, minorDim]`
25
+ and performs the following operations for each image, batched across
26
+ `majorDim` and `minorDim`:
27
+
28
+ 1. Pad the image with zeros by the specified number of pixels on each side
29
+ (`padx0`, `padx1`, `pady0`, `pady1`). Specifying a negative value
30
+ corresponds to cropping the image.
31
+
32
+ 2. Upsample the image by inserting the zeros after each pixel (`upx`, `upy`).
33
+
34
+ 3. Convolve the image with the specified 2D FIR filter (`k`), shrinking the
35
+ image so that the footprint of all output pixels lies within the input image.
36
+
37
+ 4. Downsample the image by throwing away pixels (`downx`, `downy`).
38
+
39
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
40
+ The fused op is considerably more efficient than performing the same calculation
41
+ using standard TensorFlow ops. It supports gradients of arbitrary order.
42
+
43
+ Args:
44
+ x: Input tensor of the shape `[majorDim, inH, inW, minorDim]`.
45
+ k: 2D FIR filter of the shape `[firH, firW]`.
46
+ upx: Integer upsampling factor along the X-axis (default: 1).
47
+ upy: Integer upsampling factor along the Y-axis (default: 1).
48
+ downx: Integer downsampling factor along the X-axis (default: 1).
49
+ downy: Integer downsampling factor along the Y-axis (default: 1).
50
+ padx0: Number of pixels to pad on the left side (default: 0).
51
+ padx1: Number of pixels to pad on the right side (default: 0).
52
+ pady0: Number of pixels to pad on the top side (default: 0).
53
+ pady1: Number of pixels to pad on the bottom side (default: 0).
54
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
55
+
56
+ Returns:
57
+ Tensor of the shape `[majorDim, outH, outW, minorDim]`, and same datatype as `x`.
58
+ """
59
+
60
+ impl_dict = {
61
+ 'ref': _upfirdn_2d_ref,
62
+ 'cuda': _upfirdn_2d_cuda,
63
+ }
64
+ return impl_dict[impl](x=x, k=k, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)
65
+
66
+ #----------------------------------------------------------------------------
67
+
68
+ def _upfirdn_2d_ref(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
69
+ """Slow reference implementation of `upfirdn_2d()` using standard TensorFlow ops."""
70
+
71
+ x = tf.convert_to_tensor(x)
72
+ k = np.asarray(k, dtype=np.float32)
73
+ assert x.shape.rank == 4
74
+ inH = x.shape[1].value
75
+ inW = x.shape[2].value
76
+ minorDim = _shape(x, 3)
77
+ kernelH, kernelW = k.shape
78
+ assert inW >= 1 and inH >= 1
79
+ assert kernelW >= 1 and kernelH >= 1
80
+ assert isinstance(upx, int) and isinstance(upy, int)
81
+ assert isinstance(downx, int) and isinstance(downy, int)
82
+ assert isinstance(padx0, int) and isinstance(padx1, int)
83
+ assert isinstance(pady0, int) and isinstance(pady1, int)
84
+
85
+ # Upsample (insert zeros).
86
+ x = tf.reshape(x, [-1, inH, 1, inW, 1, minorDim])
87
+ x = tf.pad(x, [[0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1], [0, 0]])
88
+ x = tf.reshape(x, [-1, inH * upy, inW * upx, minorDim])
89
+
90
+ # Pad (crop if negative).
91
+ x = tf.pad(x, [[0, 0], [max(pady0, 0), max(pady1, 0)], [max(padx0, 0), max(padx1, 0)], [0, 0]])
92
+ x = x[:, max(-pady0, 0) : x.shape[1].value - max(-pady1, 0), max(-padx0, 0) : x.shape[2].value - max(-padx1, 0), :]
93
+
94
+ # Convolve with filter.
95
+ x = tf.transpose(x, [0, 3, 1, 2])
96
+ x = tf.reshape(x, [-1, 1, inH * upy + pady0 + pady1, inW * upx + padx0 + padx1])
97
+ w = tf.constant(k[::-1, ::-1, np.newaxis, np.newaxis], dtype=x.dtype)
98
+ x = tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='VALID', data_format='NCHW')
99
+ x = tf.reshape(x, [-1, minorDim, inH * upy + pady0 + pady1 - kernelH + 1, inW * upx + padx0 + padx1 - kernelW + 1])
100
+ x = tf.transpose(x, [0, 2, 3, 1])
101
+
102
+ # Downsample (throw away pixels).
103
+ return x[:, ::downy, ::downx, :]
104
+
105
+ #----------------------------------------------------------------------------
106
+
107
+ def _upfirdn_2d_cuda(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
108
+ """Fast CUDA implementation of `upfirdn_2d()` using custom ops."""
109
+
110
+ x = tf.convert_to_tensor(x)
111
+ k = np.asarray(k, dtype=np.float32)
112
+ majorDim, inH, inW, minorDim = x.shape.as_list()
113
+ kernelH, kernelW = k.shape
114
+ assert inW >= 1 and inH >= 1
115
+ assert kernelW >= 1 and kernelH >= 1
116
+ assert isinstance(upx, int) and isinstance(upy, int)
117
+ assert isinstance(downx, int) and isinstance(downy, int)
118
+ assert isinstance(padx0, int) and isinstance(padx1, int)
119
+ assert isinstance(pady0, int) and isinstance(pady1, int)
120
+
121
+ outW = (inW * upx + padx0 + padx1 - kernelW) // downx + 1
122
+ outH = (inH * upy + pady0 + pady1 - kernelH) // downy + 1
123
+ assert outW >= 1 and outH >= 1
124
+
125
+ kc = tf.constant(k, dtype=x.dtype)
126
+ gkc = tf.constant(k[::-1, ::-1], dtype=x.dtype)
127
+ gpadx0 = kernelW - padx0 - 1
128
+ gpady0 = kernelH - pady0 - 1
129
+ gpadx1 = inW * upx - outW * downx + padx0 - upx + 1
130
+ gpady1 = inH * upy - outH * downy + pady0 - upy + 1
131
+
132
+ @tf.custom_gradient
133
+ def func(x):
134
+ y = _get_plugin().up_fir_dn2d(x=x, k=kc, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)
135
+ y.set_shape([majorDim, outH, outW, minorDim])
136
+ @tf.custom_gradient
137
+ def grad(dy):
138
+ dx = _get_plugin().up_fir_dn2d(x=dy, k=gkc, upx=downx, upy=downy, downx=upx, downy=upy, padx0=gpadx0, padx1=gpadx1, pady0=gpady0, pady1=gpady1)
139
+ dx.set_shape([majorDim, inH, inW, minorDim])
140
+ return dx, func
141
+ return y, grad
142
+ return func(x)
143
+
144
+ #----------------------------------------------------------------------------
145
+
146
+ def filter_2d(x, k, gain=1, data_format='NCHW', impl='cuda'):
147
+ r"""Filter a batch of 2D images with the given FIR filter.
148
+
149
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
150
+ and filters each image with the given filter. The filter is normalized so that
151
+ if the input pixels are constant, they will be scaled by the specified `gain`.
152
+ Pixels outside the image are assumed to be zero.
153
+
154
+ Args:
155
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
156
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
157
+ gain: Scaling factor for signal magnitude (default: 1.0).
158
+ data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
159
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
160
+
161
+ Returns:
162
+ Tensor of the same shape and datatype as `x`.
163
+ """
164
+
165
+ k = _setup_kernel(k) * gain
166
+ p = k.shape[0] - 1
167
+ return _simple_upfirdn_2d(x, k, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl)
168
+
169
+ #----------------------------------------------------------------------------
170
+
171
+ def upsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
172
+ r"""Upsample a batch of 2D images with the given filter.
173
+
174
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
175
+ and upsamples each image with the given filter. The filter is normalized so that
176
+ if the input pixels are constant, they will be scaled by the specified `gain`.
177
+ Pixels outside the image are assumed to be zero, and the filter is padded with
178
+ zeros so that its shape is a multiple of the upsampling factor.
179
+
180
+ Args:
181
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
182
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
183
+ The default is `[1] * factor`, which corresponds to nearest-neighbor
184
+ upsampling.
185
+ factor: Integer upsampling factor (default: 2).
186
+ gain: Scaling factor for signal magnitude (default: 1.0).
187
+ data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
188
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
189
+
190
+ Returns:
191
+ Tensor of the shape `[N, C, H * factor, W * factor]` or
192
+ `[N, H * factor, W * factor, C]`, and same datatype as `x`.
193
+ """
194
+
195
+ assert isinstance(factor, int) and factor >= 1
196
+ if k is None:
197
+ k = [1] * factor
198
+ k = _setup_kernel(k) * (gain * (factor ** 2))
199
+ p = k.shape[0] - factor
200
+ return _simple_upfirdn_2d(x, k, up=factor, pad0=(p+1)//2+factor-1, pad1=p//2, data_format=data_format, impl=impl)
201
+
202
+ #----------------------------------------------------------------------------
203
+
204
+ def downsample_2d(x, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
205
+ r"""Downsample a batch of 2D images with the given filter.
206
+
207
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
208
+ and downsamples each image with the given filter. The filter is normalized so that
209
+ if the input pixels are constant, they will be scaled by the specified `gain`.
210
+ Pixels outside the image are assumed to be zero, and the filter is padded with
211
+ zeros so that its shape is a multiple of the downsampling factor.
212
+
213
+ Args:
214
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
215
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
216
+ The default is `[1] * factor`, which corresponds to average pooling.
217
+ factor: Integer downsampling factor (default: 2).
218
+ gain: Scaling factor for signal magnitude (default: 1.0).
219
+ data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
220
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
221
+
222
+ Returns:
223
+ Tensor of the shape `[N, C, H // factor, W // factor]` or
224
+ `[N, H // factor, W // factor, C]`, and same datatype as `x`.
225
+ """
226
+
227
+ assert isinstance(factor, int) and factor >= 1
228
+ if k is None:
229
+ k = [1] * factor
230
+ k = _setup_kernel(k) * gain
231
+ p = k.shape[0] - factor
232
+ return _simple_upfirdn_2d(x, k, down=factor, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl)
233
+
234
+ #----------------------------------------------------------------------------
235
+
236
+ def upsample_conv_2d(x, w, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
237
+ r"""Fused `upsample_2d()` followed by `tf.nn.conv2d()`.
238
+
239
+ Padding is performed only once at the beginning, not between the operations.
240
+ The fused op is considerably more efficient than performing the same calculation
241
+ using standard TensorFlow ops. It supports gradients of arbitrary order.
242
+
243
+ Args:
244
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
245
+ w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
246
+ Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
247
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
248
+ The default is `[1] * factor`, which corresponds to nearest-neighbor
249
+ upsampling.
250
+ factor: Integer upsampling factor (default: 2).
251
+ gain: Scaling factor for signal magnitude (default: 1.0).
252
+ data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
253
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
254
+
255
+ Returns:
256
+ Tensor of the shape `[N, C, H * factor, W * factor]` or
257
+ `[N, H * factor, W * factor, C]`, and same datatype as `x`.
258
+ """
259
+
260
+ assert isinstance(factor, int) and factor >= 1
261
+
262
+ # Check weight shape.
263
+ w = tf.convert_to_tensor(w)
264
+ assert w.shape.rank == 4
265
+ convH = w.shape[0].value
266
+ convW = w.shape[1].value
267
+ inC = _shape(w, 2)
268
+ outC = _shape(w, 3)
269
+ assert convW == convH
270
+
271
+ # Setup filter kernel.
272
+ if k is None:
273
+ k = [1] * factor
274
+ k = _setup_kernel(k) * (gain * (factor ** 2))
275
+ p = (k.shape[0] - factor) - (convW - 1)
276
+
277
+ # Determine data dimensions.
278
+ if data_format == 'NCHW':
279
+ stride = [1, 1, factor, factor]
280
+ output_shape = [_shape(x, 0), outC, (_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW]
281
+ num_groups = _shape(x, 1) // inC
282
+ else:
283
+ stride = [1, factor, factor, 1]
284
+ output_shape = [_shape(x, 0), (_shape(x, 1) - 1) * factor + convH, (_shape(x, 2) - 1) * factor + convW, outC]
285
+ num_groups = _shape(x, 3) // inC
286
+
287
+ # Transpose weights.
288
+ w = tf.reshape(w, [convH, convW, inC, num_groups, -1])
289
+ w = tf.transpose(w[::-1, ::-1], [0, 1, 4, 3, 2])
290
+ w = tf.reshape(w, [convH, convW, -1, num_groups * inC])
291
+
292
+ # Execute.
293
+ x = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=stride, padding='VALID', data_format=data_format)
294
+ return _simple_upfirdn_2d(x, k, pad0=(p+1)//2+factor-1, pad1=p//2+1, data_format=data_format, impl=impl)
295
+
296
+ #----------------------------------------------------------------------------
297
+
298
+ def conv_downsample_2d(x, w, k=None, factor=2, gain=1, data_format='NCHW', impl='cuda'):
299
+ r"""Fused `tf.nn.conv2d()` followed by `downsample_2d()`.
300
+
301
+ Padding is performed only once at the beginning, not between the operations.
302
+ The fused op is considerably more efficient than performing the same calculation
303
+ using standard TensorFlow ops. It supports gradients of arbitrary order.
304
+
305
+ Args:
306
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
307
+ w: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`.
308
+ Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
309
+ k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable).
310
+ The default is `[1] * factor`, which corresponds to average pooling.
311
+ factor: Integer downsampling factor (default: 2).
312
+ gain: Scaling factor for signal magnitude (default: 1.0).
313
+ data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`).
314
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
315
+
316
+ Returns:
317
+ Tensor of the shape `[N, C, H // factor, W // factor]` or
318
+ `[N, H // factor, W // factor, C]`, and same datatype as `x`.
319
+ """
320
+
321
+ assert isinstance(factor, int) and factor >= 1
322
+ w = tf.convert_to_tensor(w)
323
+ convH, convW, _inC, _outC = w.shape.as_list()
324
+ assert convW == convH
325
+ if k is None:
326
+ k = [1] * factor
327
+ k = _setup_kernel(k) * gain
328
+ p = (k.shape[0] - factor) + (convW - 1)
329
+ if data_format == 'NCHW':
330
+ s = [1, 1, factor, factor]
331
+ else:
332
+ s = [1, factor, factor, 1]
333
+ x = _simple_upfirdn_2d(x, k, pad0=(p+1)//2, pad1=p//2, data_format=data_format, impl=impl)
334
+ return tf.nn.conv2d(x, w, strides=s, padding='VALID', data_format=data_format)
335
+
336
+ #----------------------------------------------------------------------------
337
+ # Internal helper funcs.
338
+
339
+ def _shape(tf_expr, dim_idx):
340
+ if tf_expr.shape.rank is not None:
341
+ dim = tf_expr.shape[dim_idx].value
342
+ if dim is not None:
343
+ return dim
344
+ return tf.shape(tf_expr)[dim_idx]
345
+
346
+ def _setup_kernel(k):
347
+ k = np.asarray(k, dtype=np.float32)
348
+ if k.ndim == 1:
349
+ k = np.outer(k, k)
350
+ k /= np.sum(k)
351
+ assert k.ndim == 2
352
+ assert k.shape[0] == k.shape[1]
353
+ return k
354
+
355
+ def _simple_upfirdn_2d(x, k, up=1, down=1, pad0=0, pad1=0, data_format='NCHW', impl='cuda'):
356
+ assert data_format in ['NCHW', 'NHWC']
357
+ assert x.shape.rank == 4
358
+ y = x
359
+ if data_format == 'NCHW':
360
+ y = tf.reshape(y, [-1, _shape(y, 2), _shape(y, 3), 1])
361
+ y = upfirdn_2d(y, k, upx=up, upy=up, downx=down, downy=down, padx0=pad0, padx1=pad1, pady0=pad0, pady1=pad1, impl=impl)
362
+ if data_format == 'NCHW':
363
+ y = tf.reshape(y, [-1, _shape(x, 1), _shape(y, 1), _shape(y, 2)])
364
+ return y
365
+
366
+ #----------------------------------------------------------------------------
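A minimal usage sketch for the resampling helpers above, assuming a TensorFlow 1.x setup with known spatial dimensions; the shapes and the `[1, 3, 3, 1]` binomial filter are illustrative:

import tensorflow as tf
from dnnlib.tflib.ops.upfirdn_2d import upsample_2d, downsample_2d

# NCHW images; the reference path needs H and W to be statically known.
x = tf.placeholder(tf.float32, [None, 3, 128, 128])

# 2x resampling with a separable [1, 3, 3, 1] filter. impl='ref' runs on
# standard TF ops; impl='cuda' needs the compiled upfirdn_2d.cu plugin above.
hi = upsample_2d(x, k=[1, 3, 3, 1], factor=2, impl='ref')    # -> [None, 3, 256, 256]
lo = downsample_2d(x, k=[1, 3, 3, 1], factor=2, impl='ref')  # -> [None, 3, 64, 64]

Passing `k=None` falls back to `[1] * factor`, i.e. nearest-neighbor upsampling and average-pooling downsampling, as noted in the docstrings.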
dnnlib/tflib/optimizer.py ADDED
@@ -0,0 +1,338 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ #
5
+ # This work is made available under the Nvidia Source Code License-NC.
6
+ # To view a copy of this license, visit
7
+ # https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ """Helper wrapper for a Tensorflow optimizer."""
10
+
11
+ import numpy as np
12
+ import tensorflow as tf
13
+
14
+ from collections import OrderedDict
15
+ from typing import List, Union
16
+
17
+ from . import autosummary
18
+ from . import tfutil
19
+ from .. import util
20
+
21
+ from .tfutil import TfExpression, TfExpressionEx
22
+
23
+ try:
24
+ # TensorFlow 1.13
25
+ from tensorflow.python.ops import nccl_ops
26
+ except:
27
+ # Older TensorFlow versions
28
+ import tensorflow.contrib.nccl as nccl_ops
29
+
30
+ class Optimizer:
31
+ """A Wrapper for tf.train.Optimizer.
32
+
33
+ Automatically takes care of:
34
+ - Gradient averaging for multi-GPU training.
35
+ - Gradient accumulation for arbitrarily large minibatches.
36
+ - Dynamic loss scaling and typecasts for FP16 training.
37
+ - Ignoring corrupted gradients that contain NaNs/Infs.
38
+ - Reporting statistics.
39
+ - Well-chosen default settings.
40
+ """
41
+
42
+ def __init__(self,
43
+ name: str = "Train", # Name string that will appear in TensorFlow graph.
44
+ tf_optimizer: str = "tf.train.AdamOptimizer", # Underlying optimizer class.
45
+ learning_rate: TfExpressionEx = 0.001, # Learning rate. Can vary over time.
46
+ minibatch_multiplier: TfExpressionEx = None, # Treat N consecutive minibatches as one by accumulating gradients.
47
+ share: "Optimizer" = None, # Share internal state with a previously created optimizer?
48
+ use_loss_scaling: bool = False, # Enable dynamic loss scaling for robust mixed-precision training?
49
+ loss_scaling_init: float = 64.0, # Log2 of initial loss scaling factor.
50
+ loss_scaling_inc: float = 0.0005, # Log2 of per-minibatch loss scaling increment when there is no overflow.
51
+ loss_scaling_dec: float = 1.0, # Log2 of per-minibatch loss scaling decrement when there is an overflow.
52
+ report_mem_usage: bool = False, # Report fine-grained memory usage statistics in TensorBoard?
53
+ **kwargs):
54
+
55
+ # Public fields.
56
+ self.name = name
57
+ self.learning_rate = learning_rate
58
+ self.minibatch_multiplier = minibatch_multiplier
59
+ self.id = self.name.replace("/", ".")
60
+ self.scope = tf.get_default_graph().unique_name(self.id)
61
+ self.optimizer_class = util.get_obj_by_name(tf_optimizer)
62
+ self.optimizer_kwargs = dict(kwargs)
63
+ self.use_loss_scaling = use_loss_scaling
64
+ self.loss_scaling_init = loss_scaling_init
65
+ self.loss_scaling_inc = loss_scaling_inc
66
+ self.loss_scaling_dec = loss_scaling_dec
67
+
68
+ # Private fields.
69
+ self._updates_applied = False
70
+ self._devices = OrderedDict() # device_name => EasyDict()
71
+ self._shared_optimizers = OrderedDict() # device_name => optimizer_class
72
+ self._gradient_shapes = None # [shape, ...]
73
+ self._report_mem_usage = report_mem_usage
74
+
75
+ # Validate arguments.
76
+ assert callable(self.optimizer_class)
77
+
78
+ # Share internal state if requested.
79
+ if share is not None:
80
+ assert isinstance(share, Optimizer)
81
+ assert self.optimizer_class is share.optimizer_class
82
+ assert self.learning_rate is share.learning_rate
83
+ assert self.optimizer_kwargs == share.optimizer_kwargs
84
+ self._shared_optimizers = share._shared_optimizers # pylint: disable=protected-access
85
+
86
+ def _get_device(self, device_name: str):
87
+ """Get internal state for the given TensorFlow device."""
88
+ tfutil.assert_tf_initialized()
89
+ if device_name in self._devices:
90
+ return self._devices[device_name]
91
+
92
+ # Initialize fields.
93
+ device = util.EasyDict()
94
+ device.name = device_name
95
+ device.optimizer = None # Underlying optimizer: optimizer_class
96
+ device.loss_scaling_var = None # Log2 of loss scaling: tf.Variable
97
+ device.grad_raw = OrderedDict() # Raw gradients: var => [grad, ...]
98
+ device.grad_clean = OrderedDict() # Clean gradients: var => grad
99
+ device.grad_acc_vars = OrderedDict() # Accumulation sums: var => tf.Variable
100
+ device.grad_acc_count = None # Accumulation counter: tf.Variable
101
+ device.grad_acc = OrderedDict() # Accumulated gradients: var => grad
102
+
103
+ # Setup TensorFlow objects.
104
+ with tfutil.absolute_name_scope(self.scope + "/Devices"), tf.device(device_name), tf.control_dependencies(None):
105
+ if device_name not in self._shared_optimizers:
106
+ optimizer_name = self.scope.replace("/", "_") + "_opt%d" % len(self._shared_optimizers)
107
+ self._shared_optimizers[device_name] = self.optimizer_class(name=optimizer_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
108
+ device.optimizer = self._shared_optimizers[device_name]
109
+ if self.use_loss_scaling:
110
+ device.loss_scaling_var = tf.Variable(np.float32(self.loss_scaling_init), trainable=False, name="loss_scaling_var")
111
+
112
+ # Register device.
113
+ self._devices[device_name] = device
114
+ return device
115
+
116
+ def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
117
+ """Register the gradients of the given loss function with respect to the given variables.
118
+ Intended to be called once per GPU."""
119
+ tfutil.assert_tf_initialized()
120
+ assert not self._updates_applied
121
+ device = self._get_device(loss.device)
122
+
123
+ # Validate trainables.
124
+ if isinstance(trainable_vars, dict):
125
+ trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars
126
+ assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
127
+ assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
128
+ assert all(var.device == device.name for var in trainable_vars)
129
+
130
+ # Validate shapes.
131
+ if self._gradient_shapes is None:
132
+ self._gradient_shapes = [var.shape.as_list() for var in trainable_vars]
133
+ assert len(trainable_vars) == len(self._gradient_shapes)
134
+ assert all(var.shape.as_list() == var_shape for var, var_shape in zip(trainable_vars, self._gradient_shapes))
135
+
136
+ # Report memory usage if requested.
137
+ deps = []
138
+ if self._report_mem_usage:
139
+ self._report_mem_usage = False
140
+ try:
141
+ with tf.name_scope(self.id + '_mem'), tf.device(device.name), tf.control_dependencies([loss]):
142
+ deps.append(autosummary.autosummary(self.id + "/mem_usage_gb", tf.contrib.memory_stats.BytesInUse() / 2**30))
143
+ except tf.errors.NotFoundError:
144
+ pass
145
+
146
+ # Compute gradients.
147
+ with tf.name_scope(self.id + "_grad"), tf.device(device.name), tf.control_dependencies(deps):
148
+ loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
149
+ gate = tf.train.Optimizer.GATE_NONE # disable gating to reduce memory usage
150
+ grad_list = device.optimizer.compute_gradients(loss=loss, var_list=trainable_vars, gate_gradients=gate)
151
+
152
+ # Register gradients.
153
+ for grad, var in grad_list:
154
+ if var not in device.grad_raw:
155
+ device.grad_raw[var] = []
156
+ device.grad_raw[var].append(grad)
157
+
158
+ def apply_updates(self, allow_no_op: bool = False) -> tf.Operation:
159
+ """Construct training op to update the registered variables based on their gradients."""
160
+ tfutil.assert_tf_initialized()
161
+ assert not self._updates_applied
162
+ self._updates_applied = True
163
+ all_ops = []
164
+
165
+ # Check for no-op.
166
+ if allow_no_op and len(self._devices) == 0:
167
+ with tfutil.absolute_name_scope(self.scope):
168
+ return tf.no_op(name='TrainingOp')
169
+
170
+ # Clean up gradients.
171
+ for device_idx, device in enumerate(self._devices.values()):
172
+ with tfutil.absolute_name_scope(self.scope + "/Clean%d" % device_idx), tf.device(device.name):
173
+ for var, grad in device.grad_raw.items():
174
+
175
+ # Filter out disconnected gradients and convert to float32.
176
+ grad = [g for g in grad if g is not None]
177
+ grad = [tf.cast(g, tf.float32) for g in grad]
178
+
179
+ # Sum within the device.
180
+ if len(grad) == 0:
181
+ grad = tf.zeros(var.shape) # No gradients => zero.
182
+ elif len(grad) == 1:
183
+ grad = grad[0] # Single gradient => use as is.
184
+ else:
185
+ grad = tf.add_n(grad) # Multiple gradients => sum.
186
+
187
+ # Scale as needed.
188
+ scale = 1.0 / len(device.grad_raw[var]) / len(self._devices)
189
+ scale = tf.constant(scale, dtype=tf.float32, name="scale")
190
+ if self.minibatch_multiplier is not None:
191
+ scale /= tf.cast(self.minibatch_multiplier, tf.float32)
192
+ scale = self.undo_loss_scaling(scale)
193
+ device.grad_clean[var] = grad * scale
194
+
195
+ # Sum gradients across devices.
196
+ if len(self._devices) > 1:
197
+ with tfutil.absolute_name_scope(self.scope + "/Broadcast"), tf.device(None):
198
+ for all_vars in zip(*[device.grad_clean.keys() for device in self._devices.values()]):
199
+ if len(all_vars) > 0 and all(dim > 0 for dim in all_vars[0].shape.as_list()): # NCCL does not support zero-sized tensors.
200
+ all_grads = [device.grad_clean[var] for device, var in zip(self._devices.values(), all_vars)]
201
+ all_grads = nccl_ops.all_sum(all_grads)
202
+ for device, var, grad in zip(self._devices.values(), all_vars, all_grads):
203
+ device.grad_clean[var] = grad
204
+
205
+ # Apply updates separately on each device.
206
+ for device_idx, device in enumerate(self._devices.values()):
207
+ with tfutil.absolute_name_scope(self.scope + "/Apply%d" % device_idx), tf.device(device.name):
208
+ # pylint: disable=cell-var-from-loop
209
+
210
+ # Accumulate gradients over time.
211
+ if self.minibatch_multiplier is None:
212
+ acc_ok = tf.constant(True, name='acc_ok')
213
+ device.grad_acc = OrderedDict(device.grad_clean)
214
+ else:
215
+ # Create variables.
216
+ with tf.control_dependencies(None):
217
+ for var in device.grad_clean.keys():
218
+ device.grad_acc_vars[var] = tf.Variable(tf.zeros(var.shape), trainable=False, name="grad_acc_var")
219
+ device.grad_acc_count = tf.Variable(tf.zeros([]), trainable=False, name="grad_acc_count")
220
+
221
+ # Track counter.
222
+ count_cur = device.grad_acc_count + 1.0
223
+ count_inc_op = lambda: tf.assign(device.grad_acc_count, count_cur)
224
+ count_reset_op = lambda: tf.assign(device.grad_acc_count, tf.zeros([]))
225
+ acc_ok = (count_cur >= tf.cast(self.minibatch_multiplier, tf.float32))
226
+ all_ops.append(tf.cond(acc_ok, count_reset_op, count_inc_op))
227
+
228
+ # Track gradients.
229
+ for var, grad in device.grad_clean.items():
230
+ acc_var = device.grad_acc_vars[var]
231
+ acc_cur = acc_var + grad
232
+ device.grad_acc[var] = acc_cur
233
+ with tf.control_dependencies([acc_cur]):
234
+ acc_inc_op = lambda: tf.assign(acc_var, acc_cur)
235
+ acc_reset_op = lambda: tf.assign(acc_var, tf.zeros(var.shape))
236
+ all_ops.append(tf.cond(acc_ok, acc_reset_op, acc_inc_op))
237
+
238
+ # No overflow => apply gradients.
239
+ all_ok = tf.reduce_all(tf.stack([acc_ok] + [tf.reduce_all(tf.is_finite(g)) for g in device.grad_acc.values()]))
240
+ apply_op = lambda: device.optimizer.apply_gradients([(tf.cast(grad, var.dtype), var) for var, grad in device.grad_acc.items()])
241
+ all_ops.append(tf.cond(all_ok, apply_op, tf.no_op))
242
+
243
+ # Adjust loss scaling.
244
+ if self.use_loss_scaling:
245
+ ls_inc_op = lambda: tf.assign_add(device.loss_scaling_var, self.loss_scaling_inc)
246
+ ls_dec_op = lambda: tf.assign_sub(device.loss_scaling_var, self.loss_scaling_dec)
247
+ ls_update_op = lambda: tf.group(tf.cond(all_ok, ls_inc_op, ls_dec_op))
248
+ all_ops.append(tf.cond(acc_ok, ls_update_op, tf.no_op))
249
+
250
+ # Last device => report statistics.
251
+ if device_idx == len(self._devices) - 1:
252
+ all_ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate))
253
+ all_ops.append(autosummary.autosummary(self.id + "/overflow_frequency", tf.where(all_ok, 0, 1), condition=acc_ok))
254
+ if self.use_loss_scaling:
255
+ all_ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", device.loss_scaling_var))
256
+
257
+ # Initialize variables.
258
+ self.reset_optimizer_state()
259
+ if self.use_loss_scaling:
260
+ tfutil.init_uninitialized_vars([device.loss_scaling_var for device in self._devices.values()])
261
+ if self.minibatch_multiplier is not None:
262
+ tfutil.run([var.initializer for device in self._devices.values() for var in list(device.grad_acc_vars.values()) + [device.grad_acc_count]])
263
+
264
+ # Group everything into a single op.
265
+ with tfutil.absolute_name_scope(self.scope):
266
+ return tf.group(*all_ops, name="TrainingOp")
267
+
268
+ def reset_optimizer_state(self) -> None:
269
+ """Reset internal state of the underlying optimizer."""
270
+ tfutil.assert_tf_initialized()
271
+ tfutil.run([var.initializer for device in self._devices.values() for var in device.optimizer.variables()])
272
+
273
+ def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
274
+ """Get or create variable representing log2 of the current dynamic loss scaling factor."""
275
+ return self._get_device(device).loss_scaling_var
276
+
277
+ def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
278
+ """Apply dynamic loss scaling for the given expression."""
279
+ assert tfutil.is_tf_expression(value)
280
+ if not self.use_loss_scaling:
281
+ return value
282
+ return value * tfutil.exp2(self.get_loss_scaling_var(value.device))
283
+
284
+ def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
285
+ """Undo the effect of dynamic loss scaling for the given expression."""
286
+ assert tfutil.is_tf_expression(value)
287
+ if not self.use_loss_scaling:
288
+ return value
289
+ return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type
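A minimal NumPy sketch (not part of the committed file) of the dynamic loss-scaling loop that apply_loss_scaling / undo_loss_scaling and the ls_inc_op / ls_dec_op conds implement; the scaling constants below are placeholders, not the optimizer's actual defaults.

import numpy as np

log2_scale, inc, dec = 16.0, 0.0005, 1.0            # placeholder init/inc/dec values
grad = np.array([1e-3, 2e-3]) * 2.0 ** log2_scale   # gradient of the scaled loss
clean = grad * 2.0 ** (-log2_scale)                 # undo_loss_scaling()
if np.all(np.isfinite(clean)):
    log2_scale += inc                               # clean step => ls_inc_op
else:
    log2_scale -= dec                               # overflow   => ls_dec_op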
290
+
291
+
292
+ class SimpleAdam:
293
+ """Simplified version of tf.train.AdamOptimizer that behaves identically when used with dnnlib.tflib.Optimizer."""
294
+
295
+ def __init__(self, name="Adam", learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
296
+ self.name = name
297
+ self.learning_rate = learning_rate
298
+ self.beta1 = beta1
299
+ self.beta2 = beta2
300
+ self.epsilon = epsilon
301
+ self.all_state_vars = []
302
+
303
+ def variables(self):
304
+ return self.all_state_vars
305
+
306
+ def compute_gradients(self, loss, var_list, gate_gradients=tf.train.Optimizer.GATE_NONE):
307
+ assert gate_gradients == tf.train.Optimizer.GATE_NONE
308
+ return list(zip(tf.gradients(loss, var_list), var_list))
309
+
310
+ def apply_gradients(self, grads_and_vars):
311
+ with tf.name_scope(self.name):
312
+ state_vars = []
313
+ update_ops = []
314
+
315
+ # Adjust learning rate to deal with startup bias.
316
+ with tf.control_dependencies(None):
317
+ b1pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
318
+ b2pow_var = tf.Variable(dtype=tf.float32, initial_value=1, trainable=False)
319
+ state_vars += [b1pow_var, b2pow_var]
320
+ b1pow_new = b1pow_var * self.beta1
321
+ b2pow_new = b2pow_var * self.beta2
322
+ update_ops += [tf.assign(b1pow_var, b1pow_new), tf.assign(b2pow_var, b2pow_new)]
323
+ lr_new = self.learning_rate * tf.sqrt(1 - b2pow_new) / (1 - b1pow_new)
324
+
325
+ # Construct ops to update each variable.
326
+ for grad, var in grads_and_vars:
327
+ with tf.control_dependencies(None):
328
+ m_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
329
+ v_var = tf.Variable(dtype=tf.float32, initial_value=tf.zeros_like(var), trainable=False)
330
+ state_vars += [m_var, v_var]
331
+ m_new = self.beta1 * m_var + (1 - self.beta1) * grad
332
+ v_new = self.beta2 * v_var + (1 - self.beta2) * tf.square(grad)
333
+ var_delta = lr_new * m_new / (tf.sqrt(v_new) + self.epsilon)
334
+ update_ops += [tf.assign(m_var, m_new), tf.assign(v_var, v_new), tf.assign_sub(var, var_delta)]
335
+
336
+ # Group everything together.
337
+ self.all_state_vars += state_vars
338
+ return tf.group(*update_ops)
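For reference, a NumPy sketch (not part of the committed file) of the bias-corrected Adam update that SimpleAdam.apply_gradients expresses with tf.assign ops; the parameter and gradient values are made up.

import numpy as np

def adam_step(param, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    m = beta1 * m + (1 - beta1) * grad                      # first-moment estimate
    v = beta2 * v + (1 - beta2) * np.square(grad)           # second-moment estimate
    lr_t = lr * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)  # startup-bias correction
    return param - lr_t * m / (np.sqrt(v) + eps), m, v

p, m, v = np.ones(3), np.zeros(3), np.zeros(3)
p, m, v = adam_step(p, np.array([0.1, -0.2, 0.3]), m, v, t=1)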
dnnlib/tflib/tfutil.py ADDED
@@ -0,0 +1,254 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ #
5
+ # This work is made available under the Nvidia Source Code License-NC.
6
+ # To view a copy of this license, visit
7
+ # https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ """Miscellaneous helper utils for Tensorflow."""
10
+
11
+ import os
12
+ import numpy as np
13
+ import tensorflow as tf
14
+
15
+ # Silence deprecation warnings from TensorFlow 1.13 onwards
16
+ import logging
17
+ logging.getLogger('tensorflow').setLevel(logging.ERROR)
18
+ import tensorflow.contrib # requires TensorFlow 1.x!
19
+ tf.contrib = tensorflow.contrib
20
+
21
+ from typing import Any, Iterable, List, Union
22
+
23
+ TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
24
+ """A type that represents a valid Tensorflow expression."""
25
+
26
+ TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
27
+ """A type that can be converted to a valid Tensorflow expression."""
28
+
29
+
30
+ def run(*args, **kwargs) -> Any:
31
+ """Run the specified ops in the default session."""
32
+ assert_tf_initialized()
33
+ return tf.get_default_session().run(*args, **kwargs)
34
+
35
+
36
+ def is_tf_expression(x: Any) -> bool:
37
+ """Check whether the input is a valid Tensorflow expression, i.e., Tensorflow Tensor, Variable, or Operation."""
38
+ return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))
39
+
40
+
41
+ def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
42
+ """Convert a Tensorflow shape to a list of ints. Retained for backwards compatibility -- use TensorShape.as_list() in new code."""
43
+ return [dim.value for dim in shape]
44
+
45
+
46
+ def flatten(x: TfExpressionEx) -> TfExpression:
47
+ """Shortcut function for flattening a tensor."""
48
+ with tf.name_scope("Flatten"):
49
+ return tf.reshape(x, [-1])
50
+
51
+
52
+ def log2(x: TfExpressionEx) -> TfExpression:
53
+ """Logarithm in base 2."""
54
+ with tf.name_scope("Log2"):
55
+ return tf.log(x) * np.float32(1.0 / np.log(2.0))
56
+
57
+
58
+ def exp2(x: TfExpressionEx) -> TfExpression:
59
+ """Exponent in base 2."""
60
+ with tf.name_scope("Exp2"):
61
+ return tf.exp(x * np.float32(np.log(2.0)))
62
+
63
+
64
+ def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
65
+ """Linear interpolation."""
66
+ with tf.name_scope("Lerp"):
67
+ return a + (b - a) * t
68
+
69
+
70
+ def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
71
+ """Linear interpolation with clip."""
72
+ with tf.name_scope("LerpClip"):
73
+ return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
74
+
75
+
76
+ def absolute_name_scope(scope: str) -> tf.name_scope:
77
+ """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
78
+ return tf.name_scope(scope + "/")
79
+
80
+
81
+ def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
82
+ """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
83
+ return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)
84
+
85
+
86
+ def _sanitize_tf_config(config_dict: dict = None) -> dict:
87
+ # Defaults.
88
+ cfg = dict()
89
+ cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
90
+ cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
91
+ cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
92
+ cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
93
+ cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
94
+
95
+ # Remove defaults for environment variables that are already set.
96
+ for key in list(cfg):
97
+ fields = key.split(".")
98
+ if fields[0] == "env":
99
+ assert len(fields) == 2
100
+ if fields[1] in os.environ:
101
+ del cfg[key]
102
+
103
+ # User overrides.
104
+ if config_dict is not None:
105
+ cfg.update(config_dict)
106
+ return cfg
107
+
108
+
109
+ def init_tf(config_dict: dict = None) -> None:
110
+ """Initialize TensorFlow session using good default settings."""
111
+ # Skip if already initialized.
112
+ if tf.get_default_session() is not None:
113
+ return
114
+
115
+ # Setup config dict and random seeds.
116
+ cfg = _sanitize_tf_config(config_dict)
117
+ np_random_seed = cfg["rnd.np_random_seed"]
118
+ if np_random_seed is not None:
119
+ np.random.seed(np_random_seed)
120
+ tf_random_seed = cfg["rnd.tf_random_seed"]
121
+ if tf_random_seed == "auto":
122
+ tf_random_seed = np.random.randint(1 << 31)
123
+ if tf_random_seed is not None:
124
+ tf.set_random_seed(tf_random_seed)
125
+
126
+ # Setup environment variables.
127
+ for key, value in cfg.items():
128
+ fields = key.split(".")
129
+ if fields[0] == "env":
130
+ assert len(fields) == 2
131
+ os.environ[fields[1]] = str(value)
132
+
133
+ # Create default TensorFlow session.
134
+ create_session(cfg, force_as_default=True)
135
+
136
+
137
+ def assert_tf_initialized():
138
+ """Check that TensorFlow session has been initialized."""
139
+ if tf.get_default_session() is None:
140
+ raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
141
+
142
+
143
+ def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
144
+ """Create tf.Session based on config dict."""
145
+ # Setup TensorFlow config proto.
146
+ cfg = _sanitize_tf_config(config_dict)
147
+ config_proto = tf.ConfigProto()
148
+ for key, value in cfg.items():
149
+ fields = key.split(".")
150
+ if fields[0] not in ["rnd", "env"]:
151
+ obj = config_proto
152
+ for field in fields[:-1]:
153
+ obj = getattr(obj, field)
154
+ setattr(obj, fields[-1], value)
155
+
156
+ # Create session.
157
+ session = tf.Session(config=config_proto)
158
+ if force_as_default:
159
+ # pylint: disable=protected-access
160
+ session._default_session = session.as_default()
161
+ session._default_session.enforce_nesting = False
162
+ session._default_session.__enter__()
163
+ return session
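A minimal usage sketch (not part of the committed file), assuming a TensorFlow 1.x environment with this repo on the import path: init_tf() applies the config keys recognized by _sanitize_tf_config() and installs a default session.

import tensorflow as tf
from dnnlib.tflib import tfutil

tfutil.init_tf({
    'rnd.np_random_seed': 1000,          # fix the NumPy seed; the TF seed is derived from it
    'gpu_options.allow_growth': True,    # allocate GPU memory on demand
})
assert tf.get_default_session() is not None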
164
+
165
+
166
+ def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
167
+ """Initialize all tf.Variables that have not already been initialized.
168
+
169
+ Equivalent to the following, but more efficient and does not bloat the tf graph:
170
+ tf.variables_initializer(tf.report_uninitialized_variables()).run()
171
+ """
172
+ assert_tf_initialized()
173
+ if target_vars is None:
174
+ target_vars = tf.global_variables()
175
+
176
+ test_vars = []
177
+ test_ops = []
178
+
179
+ with tf.control_dependencies(None): # ignore surrounding control_dependencies
180
+ for var in target_vars:
181
+ assert is_tf_expression(var)
182
+
183
+ try:
184
+ tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
185
+ except KeyError:
186
+ # Op does not exist => variable may be uninitialized.
187
+ test_vars.append(var)
188
+
189
+ with absolute_name_scope(var.name.split(":")[0]):
190
+ test_ops.append(tf.is_variable_initialized(var))
191
+
192
+ init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
193
+ run([var.initializer for var in init_vars])
194
+
195
+
196
+ def set_vars(var_to_value_dict: dict) -> None:
197
+ """Set the values of given tf.Variables.
198
+
199
+ Equivalent to the following, but more efficient and does not bloat the tf graph:
200
+ tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])
201
+ """
202
+ assert_tf_initialized()
203
+ ops = []
204
+ feed_dict = {}
205
+
206
+ for var, value in var_to_value_dict.items():
207
+ assert is_tf_expression(var)
208
+
209
+ try:
210
+ setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op
211
+ except KeyError:
212
+ with absolute_name_scope(var.name.split(":")[0]):
213
+ with tf.control_dependencies(None): # ignore surrounding control_dependencies
214
+ setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter
215
+
216
+ ops.append(setter)
217
+ feed_dict[setter.op.inputs[1]] = value
218
+
219
+ run(ops, feed_dict)
220
+
221
+
222
+ def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
223
+ """Create tf.Variable with large initial value without bloating the tf graph."""
224
+ assert_tf_initialized()
225
+ assert isinstance(initial_value, np.ndarray)
226
+ zeros = tf.zeros(initial_value.shape, initial_value.dtype)
227
+ var = tf.Variable(zeros, *args, **kwargs)
228
+ set_vars({var: initial_value})
229
+ return var
230
+
231
+
232
+ def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
233
+ """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
234
+ Can be used as an input transformation for Network.run().
235
+ """
236
+ images = tf.cast(images, tf.float32)
237
+ if nhwc_to_nchw:
238
+ images = tf.transpose(images, [0, 3, 1, 2])
239
+ return images * ((drange[1] - drange[0]) / 255) + drange[0]
240
+
241
+
242
+ def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
243
+ """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
244
+ Can be used as an output transformation for Network.run().
245
+ """
246
+ images = tf.cast(images, tf.float32)
247
+ if shrink > 1:
248
+ ksize = [1, 1, shrink, shrink]
249
+ images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
250
+ if nchw_to_nhwc:
251
+ images = tf.transpose(images, [0, 2, 3, 1])
252
+ scale = 255 / (drange[1] - drange[0])
253
+ images = images * scale + (0.5 - drange[0] * scale)
254
+ return tf.saturate_cast(images, tf.uint8)
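A NumPy-only sketch (not part of the committed file) of the dynamic-range arithmetic used by convert_images_from_uint8 / convert_images_to_uint8 for drange = [-1, 1]; the same scale and offset round-trip exactly.

import numpy as np

drange = [-1, 1]
u = np.array([0, 128, 255], dtype=np.uint8)

f = u.astype(np.float32) * ((drange[1] - drange[0]) / 255) + drange[0]   # uint8 -> [-1, 1]
scale = 255 / (drange[1] - drange[0])
back = np.clip(f * scale + (0.5 - drange[0] * scale), 0, 255).astype(np.uint8)
assert (back == u).all()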
dnnlib/util.py ADDED
@@ -0,0 +1,479 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
3
+ #
4
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
5
+ # and proprietary rights in and to this software, related documentation
6
+ # and any modifications thereto. Any use, reproduction, disclosure or
7
+ # distribution of this software and related documentation without an express
8
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
9
+
10
+ """Miscellaneous utility classes and functions."""
11
+
12
+ import ctypes
13
+ import fnmatch
14
+ import importlib
15
+ import inspect
16
+ import numpy as np
17
+ import os
18
+ import shutil
19
+ import sys
20
+ import types
21
+ import io
22
+ import pickle
23
+ import re
24
+ import requests
25
+ import html
26
+ import hashlib
27
+ import glob
28
+ import tempfile
29
+ import urllib
30
+ import urllib.request
31
+ import uuid
32
+
33
+ from distutils.util import strtobool
34
+ from typing import Any, List, Tuple, Union
35
+
36
+
37
+ # Util classes
38
+ # ------------------------------------------------------------------------------------------
39
+
40
+
41
+ class EasyDict(dict):
42
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
43
+
44
+ def __getattr__(self, name: str) -> Any:
45
+ try:
46
+ return self[name]
47
+ except KeyError:
48
+ raise AttributeError(name)
49
+
50
+ def __setattr__(self, name: str, value: Any) -> None:
51
+ self[name] = value
52
+
53
+ def __delattr__(self, name: str) -> None:
54
+ del self[name]
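A minimal usage sketch (not part of the committed file): attribute access and item access are interchangeable on an EasyDict.

from dnnlib.util import EasyDict

cfg = EasyDict(lr=0.002, batch_size=32)
cfg.optimizer = 'adam'              # equivalent to cfg['optimizer'] = 'adam'
assert cfg['lr'] == cfg.lr == 0.002
del cfg.batch_size                  # equivalent to del cfg['batch_size']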
55
+
56
+
57
+ class Logger(object):
58
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
59
+
60
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
61
+ self.file = None
62
+
63
+ if file_name is not None:
64
+ self.file = open(file_name, file_mode)
65
+
66
+ self.should_flush = should_flush
67
+ self.stdout = sys.stdout
68
+ self.stderr = sys.stderr
69
+
70
+ sys.stdout = self
71
+ sys.stderr = self
72
+
73
+ def __enter__(self) -> "Logger":
74
+ return self
75
+
76
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
77
+ self.close()
78
+
79
+ def write(self, text: Union[str, bytes]) -> None:
80
+ """Write text to stdout (and a file) and optionally flush."""
81
+ if isinstance(text, bytes):
82
+ text = text.decode()
83
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
84
+ return
85
+
86
+ if self.file is not None:
87
+ self.file.write(text)
88
+
89
+ self.stdout.write(text)
90
+
91
+ if self.should_flush:
92
+ self.flush()
93
+
94
+ def flush(self) -> None:
95
+ """Flush written text to both stdout and a file, if open."""
96
+ if self.file is not None:
97
+ self.file.flush()
98
+
99
+ self.stdout.flush()
100
+
101
+ def close(self) -> None:
102
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
103
+ self.flush()
104
+
105
+ # if using multiple loggers, prevent closing in wrong order
106
+ if sys.stdout is self:
107
+ sys.stdout = self.stdout
108
+ if sys.stderr is self:
109
+ sys.stderr = self.stderr
110
+
111
+ if self.file is not None:
112
+ self.file.close()
113
+ self.file = None
114
+
115
+
116
+ # Cache directories
117
+ # ------------------------------------------------------------------------------------------
118
+
119
+ _dnnlib_cache_dir = None
120
+
121
+ def set_cache_dir(path: str) -> None:
122
+ global _dnnlib_cache_dir
123
+ _dnnlib_cache_dir = path
124
+
125
+ def make_cache_dir_path(*paths: str) -> str:
126
+ if _dnnlib_cache_dir is not None:
127
+ return os.path.join(_dnnlib_cache_dir, *paths)
128
+ if 'DNNLIB_CACHE_DIR' in os.environ:
129
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
130
+ if 'HOME' in os.environ:
131
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
132
+ if 'USERPROFILE' in os.environ:
133
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
134
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
135
+
136
+ # Small util functions
137
+ # ------------------------------------------------------------------------------------------
138
+
139
+
140
+ def format_time(seconds: Union[int, float]) -> str:
141
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
142
+ s = int(np.rint(seconds))
143
+
144
+ if s < 60:
145
+ return "{0}s".format(s)
146
+ elif s < 60 * 60:
147
+ return "{0}m {1:02}s".format(s // 60, s % 60)
148
+ elif s < 24 * 60 * 60:
149
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
150
+ else:
151
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
152
+
153
+
154
+ def ask_yes_no(question: str) -> bool:
155
+ """Ask the user the question until the user inputs a valid answer."""
156
+ while True:
157
+ try:
158
+ print("{0} [y/n]".format(question))
159
+ return strtobool(input().lower())
160
+ except ValueError:
161
+ pass
162
+
163
+
164
+ def tuple_product(t: Tuple) -> Any:
165
+ """Calculate the product of the tuple elements."""
166
+ result = 1
167
+
168
+ for v in t:
169
+ result *= v
170
+
171
+ return result
172
+
173
+
174
+ _str_to_ctype = {
175
+ "uint8": ctypes.c_ubyte,
176
+ "uint16": ctypes.c_uint16,
177
+ "uint32": ctypes.c_uint32,
178
+ "uint64": ctypes.c_uint64,
179
+ "int8": ctypes.c_byte,
180
+ "int16": ctypes.c_int16,
181
+ "int32": ctypes.c_int32,
182
+ "int64": ctypes.c_int64,
183
+ "float32": ctypes.c_float,
184
+ "float64": ctypes.c_double
185
+ }
186
+
187
+
188
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
189
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
190
+ type_str = None
191
+
192
+ if isinstance(type_obj, str):
193
+ type_str = type_obj
194
+ elif hasattr(type_obj, "__name__"):
195
+ type_str = type_obj.__name__
196
+ elif hasattr(type_obj, "name"):
197
+ type_str = type_obj.name
198
+ else:
199
+ raise RuntimeError("Cannot infer type name from input")
200
+
201
+ assert type_str in _str_to_ctype.keys()
202
+
203
+ my_dtype = np.dtype(type_str)
204
+ my_ctype = _str_to_ctype[type_str]
205
+
206
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
207
+
208
+ return my_dtype, my_ctype
209
+
210
+
211
+ def is_pickleable(obj: Any) -> bool:
212
+ try:
213
+ with io.BytesIO() as stream:
214
+ pickle.dump(obj, stream)
215
+ return True
216
+ except:
217
+ return False
218
+
219
+
220
+ # Functionality to import modules/objects by name, and call functions by name
221
+ # ------------------------------------------------------------------------------------------
222
+
223
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
224
+ """Searches for the underlying module behind the name to some python object.
225
+ Returns the module and the object name (original name with module part removed)."""
226
+
227
+ # allow convenience shorthands, substitute them by full names
228
+ obj_name = re.sub("^np.", "numpy.", obj_name)
229
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
230
+
231
+ # list alternatives for (module_name, local_obj_name)
232
+ parts = obj_name.split(".")
233
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
234
+
235
+ # try each alternative in turn
236
+ for module_name, local_obj_name in name_pairs:
237
+ try:
238
+ module = importlib.import_module(module_name) # may raise ImportError
239
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
240
+ return module, local_obj_name
241
+ except:
242
+ pass
243
+
244
+ # maybe some of the modules themselves contain errors?
245
+ for module_name, _local_obj_name in name_pairs:
246
+ try:
247
+ importlib.import_module(module_name) # may raise ImportError
248
+ except ImportError:
249
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
250
+ raise
251
+
252
+ # maybe the requested attribute is missing?
253
+ for module_name, local_obj_name in name_pairs:
254
+ try:
255
+ module = importlib.import_module(module_name) # may raise ImportError
256
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
257
+ except ImportError:
258
+ pass
259
+
260
+ # we are out of luck, but we have no idea why
261
+ raise ImportError(obj_name)
262
+
263
+
264
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
265
+ """Traverses the object name and returns the last (rightmost) python object."""
266
+ if obj_name == '':
267
+ return module
268
+ obj = module
269
+ for part in obj_name.split("."):
270
+ obj = getattr(obj, part)
271
+ return obj
272
+
273
+
274
+ def get_obj_by_name(name: str) -> Any:
275
+ """Finds the python object with the given name."""
276
+ module, obj_name = get_module_from_obj_name(name)
277
+ return get_obj_from_module(module, obj_name)
278
+
279
+
280
+ def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
281
+ """Finds the python object with the given name and calls it as a function."""
282
+ assert func_name is not None
283
+ # print('func_name: ', func_name) #'training.dataset.ImageFolderDataset'
284
+ func_obj = get_obj_by_name(func_name)
285
+ assert callable(func_obj)
286
+ return func_obj(*args, **kwargs)
287
+
288
+
289
+ def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
290
+ """Finds the python class with the given name and constructs it with the given arguments."""
291
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
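A minimal usage sketch (not part of the committed file) of the name-based lookup helpers above; the looked-up names are standard-library / NumPy examples rather than objects from this repo.

from dnnlib.util import get_obj_by_name, construct_class_by_name

mean_fn = get_obj_by_name('np.mean')     # the 'np.' shorthand expands to 'numpy.'
assert mean_fn([1, 2, 3]) == 2.0
od = construct_class_by_name(class_name='collections.OrderedDict', a=1)
assert list(od.items()) == [('a', 1)]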
292
+
293
+
294
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
295
+ """Get the directory path of the module containing the given object name."""
296
+ module, _ = get_module_from_obj_name(obj_name)
297
+ return os.path.dirname(inspect.getfile(module))
298
+
299
+
300
+ def is_top_level_function(obj: Any) -> bool:
301
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
302
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
303
+
304
+
305
+ def get_top_level_function_name(obj: Any) -> str:
306
+ """Return the fully-qualified name of a top-level function."""
307
+ assert is_top_level_function(obj)
308
+ module = obj.__module__
309
+ if module == '__main__':
310
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
311
+ return module + "." + obj.__name__
312
+
313
+
314
+ # File system helpers
315
+ # ------------------------------------------------------------------------------------------
316
+
317
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
318
+ """List all files recursively in a given directory while ignoring given file and directory names.
319
+ Returns list of tuples containing both absolute and relative paths."""
320
+ assert os.path.isdir(dir_path)
321
+ base_name = os.path.basename(os.path.normpath(dir_path))
322
+
323
+ if ignores is None:
324
+ ignores = []
325
+
326
+ result = []
327
+
328
+ for root, dirs, files in os.walk(dir_path, topdown=True):
329
+ for ignore_ in ignores:
330
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
331
+
332
+ # dirs need to be edited in-place
333
+ for d in dirs_to_remove:
334
+ dirs.remove(d)
335
+
336
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
337
+
338
+ absolute_paths = [os.path.join(root, f) for f in files]
339
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
340
+
341
+ if add_base_to_relative:
342
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
343
+
344
+ assert len(absolute_paths) == len(relative_paths)
345
+ result += zip(absolute_paths, relative_paths)
346
+
347
+ return result
348
+
349
+
350
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
351
+ """Takes in a list of tuples of (src, dst) paths and copies files.
352
+ Will create all necessary directories."""
353
+ for file in files:
354
+ target_dir_name = os.path.dirname(file[1])
355
+
356
+ # will create all intermediate-level directories
357
+ if not os.path.exists(target_dir_name):
358
+ os.makedirs(target_dir_name)
359
+
360
+ shutil.copyfile(file[0], file[1])
361
+
362
+
363
+ # URL helpers
364
+ # ------------------------------------------------------------------------------------------
365
+
366
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
367
+ """Determine whether the given object is a valid URL string."""
368
+ if not isinstance(obj, str) or not "://" in obj:
369
+ return False
370
+ if allow_file_urls and obj.startswith('file://'):
371
+ return True
372
+ try:
373
+ res = requests.compat.urlparse(obj)
374
+ if not res.scheme or not res.netloc or not "." in res.netloc:
375
+ return False
376
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
377
+ if not res.scheme or not res.netloc or not "." in res.netloc:
378
+ return False
379
+ except:
380
+ return False
381
+ return True
382
+
383
+
384
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
385
+ """Download the given URL and return a binary-mode file object to access the data."""
386
+ assert num_attempts >= 1
387
+ assert not (return_filename and (not cache))
388
+
389
+ # Doesn't look like a URL scheme, so interpret it as a local filename.
390
+ if not re.match('^[a-z]+://', url):
391
+ return url if return_filename else open(url, "rb")
392
+
393
+ # Handle file URLs. This code handles unusual file:// patterns that
394
+ # arise on Windows:
395
+ #
396
+ # file:///c:/foo.txt
397
+ #
398
+ # which would translate to a local '/c:/foo.txt' filename that's
399
+ # invalid. Drop the forward slash for such pathnames.
400
+ #
401
+ # If you touch this code path, you should test it on both Linux and
402
+ # Windows.
403
+ #
404
+ # Some internet resources suggest using urllib.request.url2pathname() but
405
+ # that converts forward slashes to backslashes and this causes
406
+ # its own set of problems.
407
+ if url.startswith('file://'):
408
+ filename = urllib.parse.urlparse(url).path
409
+ if re.match(r'^/[a-zA-Z]:', filename):
410
+ filename = filename[1:]
411
+ return filename if return_filename else open(filename, "rb")
412
+
413
+ assert is_url(url)
414
+
415
+ # Lookup from cache.
416
+ if cache_dir is None:
417
+ cache_dir = make_cache_dir_path('downloads')
418
+
419
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
420
+ if cache:
421
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
422
+ if len(cache_files) == 1:
423
+ filename = cache_files[0]
424
+ return filename if return_filename else open(filename, "rb")
425
+
426
+ # Download.
427
+ url_name = None
428
+ url_data = None
429
+ with requests.Session() as session:
430
+ if verbose:
431
+ print("Downloading %s ..." % url, end="", flush=True)
432
+ for attempts_left in reversed(range(num_attempts)):
433
+ try:
434
+ with session.get(url) as res:
435
+ res.raise_for_status()
436
+ if len(res.content) == 0:
437
+ raise IOError("No data received")
438
+
439
+ if len(res.content) < 8192:
440
+ content_str = res.content.decode("utf-8")
441
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
442
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
443
+ if len(links) == 1:
444
+ url = requests.compat.urljoin(url, links[0])
445
+ raise IOError("Google Drive virus checker nag")
446
+ if "Google Drive - Quota exceeded" in content_str:
447
+ raise IOError("Google Drive download quota exceeded -- please try again later")
448
+
449
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
450
+ url_name = match[1] if match else url
451
+ url_data = res.content
452
+ if verbose:
453
+ print(" done")
454
+ break
455
+ except KeyboardInterrupt:
456
+ raise
457
+ except:
458
+ if not attempts_left:
459
+ if verbose:
460
+ print(" failed")
461
+ raise
462
+ if verbose:
463
+ print(".", end="", flush=True)
464
+
465
+ # Save to cache.
466
+ if cache:
467
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
468
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
469
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
470
+ os.makedirs(cache_dir, exist_ok=True)
471
+ with open(temp_file, "wb") as f:
472
+ f.write(url_data)
473
+ os.replace(temp_file, cache_file) # atomic
474
+ if return_filename:
475
+ return cache_file
476
+
477
+ # Return data as file object.
478
+ assert not return_filename
479
+ return io.BytesIO(url_data)
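A minimal usage sketch (not part of the committed file); the URL is a hypothetical placeholder. The first call downloads and caches the payload under make_cache_dir_path('downloads'); later calls are served from the cache.

from dnnlib.util import open_url

with open_url('https://example.com/checkpoints/model.pkl', cache=True) as f:
    payload = f.read()
print(f'fetched {len(payload)} bytes')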
torch_utils/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ # empty
torch_utils/custom_ops.py ADDED
@@ -0,0 +1,238 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ import os
12
+ import glob
13
+ import torch
14
+ import torch.utils.cpp_extension
15
+ import importlib
16
+ import hashlib
17
+ import shutil
18
+ from pathlib import Path
19
+ import re
20
+ import uuid
21
+
22
+ from torch.utils.file_baton import FileBaton
23
+
24
+ #----------------------------------------------------------------------------
25
+ # Global options.
26
+
27
+ verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
28
+
29
+ #----------------------------------------------------------------------------
30
+ # Internal helper funcs.
31
+
32
+ def _find_compiler_bindir():
33
+ patterns = [
34
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
35
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
36
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
37
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
38
+ ]
39
+ for pattern in patterns:
40
+ matches = sorted(glob.glob(pattern))
41
+ if len(matches):
42
+ return matches[-1]
43
+ return None
44
+
45
+ def _get_mangled_gpu_name():
46
+ name = torch.cuda.get_device_name().lower()
47
+ out = []
48
+ for c in name:
49
+ if re.match('[a-z0-9_-]+', c):
50
+ out.append(c)
51
+ else:
52
+ out.append('-')
53
+ return ''.join(out)
54
+
55
+
56
+ #----------------------------------------------------------------------------
57
+ # Main entry point for compiling and loading C++/CUDA plugins.
58
+
59
+ _cached_plugins = dict()
60
+
61
+ def get_plugin(module_name, sources, **build_kwargs):
62
+ assert verbosity in ['none', 'brief', 'full']
63
+
64
+ # Already cached?
65
+ if module_name in _cached_plugins:
66
+ return _cached_plugins[module_name]
67
+
68
+ # Print status.
69
+ if verbosity == 'full':
70
+ print(f'Setting up PyTorch plugin "{module_name}"...')
71
+ elif verbosity == 'brief':
72
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
73
+
74
+ try: # pylint: disable=too-many-nested-blocks
75
+ # Make sure we can find the necessary compiler binaries.
76
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
77
+ compiler_bindir = _find_compiler_bindir()
78
+ if compiler_bindir is None:
79
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
80
+ os.environ['PATH'] += ';' + compiler_bindir
81
+
82
+ # Compile and load.
83
+ verbose_build = (verbosity == 'full')
84
+
85
+ # Incremental build md5sum trickery. Copies all the input source files
86
+ # into a cached build directory under a combined md5 digest of the input
87
+ # source files. Copying is done only if the combined digest has changed.
88
+ # This keeps input file timestamps and filenames the same as in previous
89
+ # extension builds, allowing for fast incremental rebuilds.
90
+ #
91
+ # This optimization is done only in case all the source files reside in
92
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
93
+ # environment variable is set (we take this as a signal that the user
94
+ # actually cares about this.)
95
+ source_dirs_set = set(os.path.dirname(source) for source in sources)
96
+ if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
97
+ all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
98
+
99
+ # Compute a combined hash digest for all source files in the same
100
+ # custom op directory (usually .cu, .cpp, .py and .h files).
101
+ hash_md5 = hashlib.md5()
102
+ for src in all_source_files:
103
+ with open(src, 'rb') as f:
104
+ hash_md5.update(f.read())
105
+ build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
106
+ digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
107
+
108
+ if not os.path.isdir(digest_build_dir):
109
+ os.makedirs(digest_build_dir, exist_ok=True)
110
+ baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
111
+ if baton.try_acquire():
112
+ try:
113
+ for src in all_source_files:
114
+ shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
115
+ finally:
116
+ baton.release()
117
+ else:
118
+ # Someone else is copying source files under the digest dir,
119
+ # wait until done and continue.
120
+ baton.wait()
121
+ digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
122
+ torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
123
+ verbose=verbose_build, sources=digest_sources, **build_kwargs)
124
+ else:
125
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
126
+ module = importlib.import_module(module_name)
127
+
128
+ except:
129
+ if verbosity == 'brief':
130
+ print('Failed!')
131
+ raise
132
+
133
+ # Print status and add to cache.
134
+ if verbosity == 'full':
135
+ print(f'Done setting up PyTorch plugin "{module_name}".')
136
+ elif verbosity == 'brief':
137
+ print('Done.')
138
+ _cached_plugins[module_name] = module
139
+ return module
140
+
141
+ #----------------------------------------------------------------------------
142
+ def get_plugin_v3(module_name, sources, headers=None, source_dir=None, **build_kwargs):
143
+ assert verbosity in ['none', 'brief', 'full']
144
+ if headers is None:
145
+ headers = []
146
+ if source_dir is not None:
147
+ sources = [os.path.join(source_dir, fname) for fname in sources]
148
+ headers = [os.path.join(source_dir, fname) for fname in headers]
149
+
150
+ # Already cached?
151
+ if module_name in _cached_plugins:
152
+ return _cached_plugins[module_name]
153
+
154
+ # Print status.
155
+ if verbosity == 'full':
156
+ print(f'Setting up PyTorch plugin "{module_name}"...')
157
+ elif verbosity == 'brief':
158
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
159
+ verbose_build = (verbosity == 'full')
160
+
161
+ # Compile and load.
162
+ try: # pylint: disable=too-many-nested-blocks
163
+ # Make sure we can find the necessary compiler binaries.
164
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
165
+ compiler_bindir = _find_compiler_bindir()
166
+ if compiler_bindir is None:
167
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
168
+ os.environ['PATH'] += ';' + compiler_bindir
169
+
170
+ # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
171
+ # break the build or unnecessarily restrict what's available to nvcc.
172
+ # Unset it to let nvcc decide based on what's available on the
173
+ # machine.
174
+ os.environ['TORCH_CUDA_ARCH_LIST'] = ''
175
+
176
+ # Incremental build md5sum trickery. Copies all the input source files
177
+ # into a cached build directory under a combined md5 digest of the input
178
+ # source files. Copying is done only if the combined digest has changed.
179
+ # This keeps input file timestamps and filenames the same as in previous
180
+ # extension builds, allowing for fast incremental rebuilds.
181
+ #
182
+ # This optimization is done only in case all the source files reside in
183
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
184
+ # environment variable is set (we take this as a signal that the user
185
+ # actually cares about this.)
186
+ #
187
+ # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
188
+ # around the *.cu dependency bug in ninja config.
189
+ #
190
+ all_source_files = sorted(sources + headers)
191
+ all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
192
+ if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
193
+
194
+ # Compute combined hash digest for all source files.
195
+ hash_md5 = hashlib.md5()
196
+ for src in all_source_files:
197
+ with open(src, 'rb') as f:
198
+ hash_md5.update(f.read())
199
+
200
+ # Select cached build directory name.
201
+ source_digest = hash_md5.hexdigest()
202
+ build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
203
+ cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
204
+
205
+ if not os.path.isdir(cached_build_dir):
206
+ tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
207
+ os.makedirs(tmpdir)
208
+ for src in all_source_files:
209
+ shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
210
+ try:
211
+ os.replace(tmpdir, cached_build_dir) # atomic
212
+ except OSError:
213
+ # source directory already exists, delete tmpdir and its contents.
214
+ shutil.rmtree(tmpdir)
215
+ if not os.path.isdir(cached_build_dir): raise
216
+
217
+ # Compile.
218
+ cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
219
+ torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
220
+ verbose=verbose_build, sources=cached_sources, **build_kwargs)
221
+ else:
222
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
223
+
224
+ # Load.
225
+ module = importlib.import_module(module_name)
226
+
227
+ except:
228
+ if verbosity == 'brief':
229
+ print('Failed!')
230
+ raise
231
+
232
+ # Print status and add to cache dict.
233
+ if verbosity == 'full':
234
+ print(f'Done setting up PyTorch plugin "{module_name}".')
235
+ elif verbosity == 'brief':
236
+ print('Done.')
237
+ _cached_plugins[module_name] = module
238
+ return module
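A minimal usage sketch (not part of the committed file); the module name, source files, and source_dir are hypothetical placeholders, and extra_cuda_cflags is just one example of a kwarg forwarded to torch.utils.cpp_extension.load().

from torch_utils import custom_ops

my_ext = custom_ops.get_plugin_v3(
    module_name='my_example_ext',
    sources=['my_example_ext.cpp', 'my_example_ext.cu'],
    headers=['my_example_ext.h'],
    source_dir='/path/to/plugin/sources',
    extra_cuda_cflags=['--use_fast_math'],
)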
torch_utils/misc.py ADDED
@@ -0,0 +1,264 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ import re
12
+ import contextlib
13
+ import numpy as np
14
+ import torch
15
+ import warnings
16
+ import dnnlib
17
+
18
+ #----------------------------------------------------------------------------
19
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
20
+ # same constant is used multiple times.
21
+
22
+ _constant_cache = dict()
23
+
24
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
25
+ value = np.asarray(value)
26
+ if shape is not None:
27
+ shape = tuple(shape)
28
+ if dtype is None:
29
+ dtype = torch.get_default_dtype()
30
+ if device is None:
31
+ device = torch.device('cpu')
32
+ if memory_format is None:
33
+ memory_format = torch.contiguous_format
34
+
35
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
36
+ tensor = _constant_cache.get(key, None)
37
+ if tensor is None:
38
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
39
+ if shape is not None:
40
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
41
+ tensor = tensor.contiguous(memory_format=memory_format)
42
+ _constant_cache[key] = tensor
43
+ return tensor
44
+
45
+ #----------------------------------------------------------------------------
46
+ # Replace NaN/Inf with specified numerical values.
47
+
48
+ try:
49
+ nan_to_num = torch.nan_to_num # 1.8.0a0
50
+ except AttributeError:
51
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
52
+ assert isinstance(input, torch.Tensor)
53
+ if posinf is None:
54
+ posinf = torch.finfo(input.dtype).max
55
+ if neginf is None:
56
+ neginf = torch.finfo(input.dtype).min
57
+ assert nan == 0
58
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
59
+
60
+ #----------------------------------------------------------------------------
61
+ # Symbolic assert.
62
+
63
+ try:
64
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
65
+ except AttributeError:
66
+ symbolic_assert = torch.Assert # 1.7.0
67
+
68
+ #----------------------------------------------------------------------------
69
+ # Context manager to suppress known warnings in torch.jit.trace().
70
+
71
+ class suppress_tracer_warnings(warnings.catch_warnings):
72
+ def __enter__(self):
73
+ super().__enter__()
74
+ warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
75
+ return self
76
+
77
+ #----------------------------------------------------------------------------
78
+ # Assert that the shape of a tensor matches the given list of integers.
79
+ # None indicates that the size of a dimension is allowed to vary.
80
+ # Performs symbolic assertion when used in torch.jit.trace().
81
+
82
+ def assert_shape(tensor, ref_shape):
83
+ if tensor.ndim != len(ref_shape):
84
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
85
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
86
+ if ref_size is None:
87
+ pass
88
+ elif isinstance(ref_size, torch.Tensor):
89
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
90
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
91
+ elif isinstance(size, torch.Tensor):
92
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
93
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
94
+ elif size != ref_size:
95
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
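A minimal usage sketch (not part of the committed file): a None entry leaves that dimension unconstrained, so only the channel and spatial sizes are checked here.

import torch
from torch_utils.misc import assert_shape

x = torch.zeros(4, 3, 256, 256)
assert_shape(x, [None, 3, 256, 256])   # passes for any batch size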
96
+
97
+ #----------------------------------------------------------------------------
98
+ # Function decorator that calls torch.autograd.profiler.record_function().
99
+
100
+ def profiled_function(fn):
101
+ def decorator(*args, **kwargs):
102
+ with torch.autograd.profiler.record_function(fn.__name__):
103
+ return fn(*args, **kwargs)
104
+ decorator.__name__ = fn.__name__
105
+ return decorator
106
+
107
+ #----------------------------------------------------------------------------
108
+ # Sampler for torch.utils.data.DataLoader that loops over the dataset
109
+ # indefinitely, shuffling items as it goes.
110
+
111
+ class InfiniteSampler(torch.utils.data.Sampler):
112
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
113
+ assert len(dataset) > 0
114
+ assert num_replicas > 0
115
+ assert 0 <= rank < num_replicas
116
+ assert 0 <= window_size <= 1
117
+ super().__init__(dataset)
118
+ self.dataset = dataset
119
+ self.rank = rank
120
+ self.num_replicas = num_replicas
121
+ self.shuffle = shuffle
122
+ self.seed = seed
123
+ self.window_size = window_size
124
+
125
+ def __iter__(self):
126
+ order = np.arange(len(self.dataset))
127
+ rnd = None
128
+ window = 0
129
+ if self.shuffle:
130
+ rnd = np.random.RandomState(self.seed)
131
+ rnd.shuffle(order)
132
+ window = int(np.rint(order.size * self.window_size))
133
+
134
+ idx = 0
135
+ while True:
136
+ i = idx % order.size
137
+ if idx % self.num_replicas == self.rank:
138
+ yield order[i]
139
+ if window >= 2:
140
+ j = (i - rnd.randint(window)) % order.size
141
+ order[i], order[j] = order[j], order[i]
142
+ idx += 1
143
+
144
+ #----------------------------------------------------------------------------
145
+ # Utilities for operating with torch.nn.Module parameters and buffers.
146
+
147
+ def params_and_buffers(module):
148
+ assert isinstance(module, torch.nn.Module)
149
+ return list(module.parameters()) + list(module.buffers())
150
+
151
+ def named_params_and_buffers(module):
152
+ assert isinstance(module, torch.nn.Module)
153
+ return list(module.named_parameters()) + list(module.named_buffers())
154
+
155
+ def copy_params_and_buffers(src_module, dst_module, require_all=False):
156
+ assert isinstance(src_module, torch.nn.Module)
157
+ assert isinstance(dst_module, torch.nn.Module)
158
+ src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
159
+ for name, tensor in named_params_and_buffers(dst_module):
160
+ assert (name in src_tensors) or (not require_all)
161
+ if name in src_tensors:
162
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
163
+
164
+ #----------------------------------------------------------------------------
165
+ # Context manager for easily enabling/disabling DistributedDataParallel
166
+ # synchronization.
167
+
168
+ @contextlib.contextmanager
169
+ def ddp_sync(module, sync):
170
+ assert isinstance(module, torch.nn.Module)
171
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
172
+ yield
173
+ else:
174
+ with module.no_sync():
175
+ yield
176
+
177
+ #----------------------------------------------------------------------------
178
+ # Check DistributedDataParallel consistency across processes.
179
+
180
+ def check_ddp_consistency(module, ignore_regex=None):
181
+ assert isinstance(module, torch.nn.Module)
182
+ for name, tensor in named_params_and_buffers(module):
183
+ fullname = type(module).__name__ + '.' + name
184
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
185
+ continue
186
+ tensor = tensor.detach()
187
+ other = tensor.clone()
188
+ torch.distributed.broadcast(tensor=other, src=0)
189
+ assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
190
+
191
+ #----------------------------------------------------------------------------
192
+ # Print summary table of module hierarchy.
193
+
194
+ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
195
+ assert isinstance(module, torch.nn.Module)
196
+ assert not isinstance(module, torch.jit.ScriptModule)
197
+ assert isinstance(inputs, (tuple, list))
198
+
199
+ # Register hooks.
200
+ entries = []
201
+ nesting = [0]
202
+ def pre_hook(_mod, _inputs):
203
+ nesting[0] += 1
204
+ def post_hook(mod, _inputs, outputs):
205
+ nesting[0] -= 1
206
+ if nesting[0] <= max_nesting:
207
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
208
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
209
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
210
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
211
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
212
+
213
+ # Run module.
214
+ outputs = module(*inputs)
215
+ for hook in hooks:
216
+ hook.remove()
217
+
218
+ # Identify unique outputs, parameters, and buffers.
219
+ tensors_seen = set()
220
+ for e in entries:
221
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
222
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
223
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
224
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
225
+
226
+ # Filter out redundant entries.
227
+ if skip_redundant:
228
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
229
+
230
+ # Construct table.
231
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
232
+ rows += [['---'] * len(rows[0])]
233
+ param_total = 0
234
+ buffer_total = 0
235
+ submodule_names = {mod: name for name, mod in module.named_modules()}
236
+ for e in entries:
237
+ name = '<top-level>' if e.mod is module else submodule_names[e.mod]
238
+ param_size = sum(t.numel() for t in e.unique_params)
239
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
240
+ output_shapes = [str(list(t.shape)) for t in e.outputs]
241
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
242
+ rows += [[
243
+ name + (':0' if len(e.outputs) >= 2 else ''),
244
+ str(param_size) if param_size else '-',
245
+ str(buffer_size) if buffer_size else '-',
246
+ (output_shapes + ['-'])[0],
247
+ (output_dtypes + ['-'])[0],
248
+ ]]
249
+ for idx in range(1, len(e.outputs)):
250
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
251
+ param_total += param_size
252
+ buffer_total += buffer_size
253
+ rows += [['---'] * len(rows[0])]
254
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
255
+
256
+ # Print table.
257
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
258
+ print()
259
+ for row in rows:
260
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
261
+ print()
262
+ return outputs
263
+
264
+ #----------------------------------------------------------------------------
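A minimal usage sketch (not part of the committed file), single-process case: InfiniteSampler yields an endless, gradually reshuffled stream of indices, so the DataLoader iterator below never raises StopIteration.

import torch
from torch_utils.misc import InfiniteSampler

dataset = torch.utils.data.TensorDataset(torch.arange(10, dtype=torch.float32))
sampler = InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True, seed=0)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, sampler=sampler)
first_batch = next(iter(loader))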
torch_utils/models.py ADDED
@@ -0,0 +1,756 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
4
+
5
+ import math
6
+ import random
7
+ import functools
8
+ import operator
9
+
10
+ import torch
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+ import torch.nn.init as init
14
+ from torch.autograd import Function
15
+
16
+ from .op_edit import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
17
+
18
+
19
+ class PixelNorm(nn.Module):
20
+ def __init__(self):
21
+ super().__init__()
22
+
23
+ def forward(self, input):
24
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
25
+
26
+
27
+ def make_kernel(k):
28
+ k = torch.tensor(k, dtype=torch.float32)
29
+ if k.ndim == 1:
30
+ k = k[None, :] * k[:, None]
31
+ k /= k.sum()
32
+ return k
33
+
34
+
35
+ class Upsample(nn.Module):
36
+ def __init__(self, kernel, factor=2):
37
+ super().__init__()
38
+
39
+ self.factor = factor
40
+ kernel = make_kernel(kernel) * (factor ** 2)
41
+ self.register_buffer("kernel", kernel)
42
+
43
+ p = kernel.shape[0] - factor
44
+
45
+ pad0 = (p + 1) // 2 + factor - 1
46
+ pad1 = p // 2
47
+
48
+ self.pad = (pad0, pad1)
49
+
50
+ def forward(self, input):
51
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
52
+ return out
53
+
54
+
55
+ class Downsample(nn.Module):
56
+ def __init__(self, kernel, factor=2):
57
+ super().__init__()
58
+
59
+ self.factor = factor
60
+ kernel = make_kernel(kernel)
61
+ self.register_buffer("kernel", kernel)
62
+
63
+ p = kernel.shape[0] - factor
64
+
65
+ pad0 = (p + 1) // 2
66
+ pad1 = p // 2
67
+
68
+ self.pad = (pad0, pad1)
69
+
70
+ def forward(self, input):
71
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
72
+ return out
73
+
74
+
75
+ class Blur(nn.Module):
76
+ def __init__(self, kernel, pad, upsample_factor=1):
77
+ super().__init__()
78
+
79
+ kernel = make_kernel(kernel)
80
+
81
+ if upsample_factor > 1:
82
+ kernel = kernel * (upsample_factor ** 2)
83
+
84
+ self.register_buffer("kernel", kernel)
85
+
86
+ self.pad = pad
87
+
88
+ def forward(self, input):
89
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
90
+ return out
91
+
92
+
93
+ class EqualConv2d(nn.Module):
94
+ def __init__(
95
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
96
+ ):
97
+ super().__init__()
98
+
99
+ self.weight = nn.Parameter(
100
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
101
+ )
102
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
103
+
104
+ self.stride = stride
105
+ self.padding = padding
106
+
107
+ if bias:
108
+ self.bias = nn.Parameter(torch.zeros(out_channel))
109
+
110
+ else:
111
+ self.bias = None
112
+
113
+ def forward(self, input):
114
+ out = F.conv2d(
115
+ input,
116
+ self.weight * self.scale,
117
+ bias=self.bias,
118
+ stride=self.stride,
119
+ padding=self.padding,
120
+ )
121
+ return out
122
+
123
+ def __repr__(self):
124
+ return (
125
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
126
+ f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
127
+ )
128
+
129
+
130
+ class EqualLinear(nn.Module):
131
+ def __init__(
132
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
133
+ ):
134
+ super().__init__()
135
+
136
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
137
+
138
+ if bias:
139
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
140
+ else:
141
+ self.bias = None
142
+
143
+ self.activation = activation
144
+
145
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
146
+ self.lr_mul = lr_mul
147
+
148
+ def forward(self, input):
149
+ if self.activation:
150
+ out = F.linear(input, self.weight * self.scale)
151
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
152
+ else:
153
+ out = F.linear(
154
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
155
+ )
156
+ return out
157
+
158
+ def __repr__(self):
159
+ return (
160
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
161
+ )
162
+
163
+
164
+ class ScaledLeakyReLU(nn.Module):
165
+ def __init__(self, negative_slope=0.2):
166
+ super().__init__()
167
+ self.negative_slope = negative_slope
168
+
169
+ def forward(self, input):
170
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
171
+ return out * math.sqrt(2)
172
+
173
+
174
+ class ModulatedConv2d(nn.Module):
175
+ def __init__(
176
+ self,
177
+ in_channel,
178
+ out_channel,
179
+ kernel_size,
180
+ style_dim,
181
+ demodulate=True,
182
+ upsample=False,
183
+ downsample=False,
184
+ blur_kernel=[1, 3, 3, 1],
185
+ ):
186
+ super().__init__()
187
+
188
+ self.eps = 1e-8
189
+ self.kernel_size = kernel_size
190
+ self.in_channel = in_channel
191
+ self.out_channel = out_channel
192
+ self.upsample = upsample
193
+ self.downsample = downsample
194
+
195
+ if upsample:
196
+ factor = 2
197
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
198
+ pad0 = (p + 1) // 2 + factor - 1
199
+ pad1 = p // 2 + 1
200
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
201
+
202
+ if downsample:
203
+ factor = 2
204
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
205
+ pad0 = (p + 1) // 2
206
+ pad1 = p // 2
207
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
208
+
209
+ fan_in = in_channel * kernel_size ** 2
210
+ self.scale = 1 / math.sqrt(fan_in)
211
+ self.padding = kernel_size // 2
212
+ self.weight = nn.Parameter(
213
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
214
+ )
215
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
216
+ self.demodulate = demodulate
217
+
218
+ def __repr__(self):
219
+ return (
220
+ f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
221
+ f"upsample={self.upsample}, downsample={self.downsample})"
222
+ )
223
+
224
+ def forward(self, input, style):
225
+ batch, in_channel, height, width = input.shape
226
+
227
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
228
+ weight = self.scale * self.weight * style
229
+
230
+ if self.demodulate:
231
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
232
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
233
+
234
+ weight = weight.view(
235
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
236
+ )
237
+
238
+ if self.upsample:
239
+ input = input.view(1, batch * in_channel, height, width)
240
+ weight = weight.view(
241
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
242
+ )
243
+ weight = weight.transpose(1, 2).reshape(
244
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
245
+ )
246
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
247
+ _, _, height, width = out.shape
248
+ out = out.view(batch, self.out_channel, height, width)
249
+ out = self.blur(out)
250
+
251
+ elif self.downsample:
252
+ input = self.blur(input)
253
+ _, _, height, width = input.shape
254
+ input = input.view(1, batch * in_channel, height, width)
255
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
256
+ _, _, height, width = out.shape
257
+ out = out.view(batch, self.out_channel, height, width)
258
+
259
+ else:
260
+ input = input.view(1, batch * in_channel, height, width)
261
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
262
+ _, _, height, width = out.shape
263
+ out = out.view(batch, self.out_channel, height, width)
264
+
265
+ return out
266
+
267
+
268
+ class NoiseInjection(nn.Module):
269
+ def __init__(self):
270
+ super().__init__()
271
+ self.weight = nn.Parameter(torch.zeros(1))
272
+
273
+ def forward(self, image, noise=None):
274
+ if noise is None:
275
+ batch, _, height, width = image.shape
276
+ noise = image.new_empty(batch, 1, height, width).normal_()
277
+ return image + self.weight * noise
278
+
279
+
280
+ class ConstantInput(nn.Module):
281
+ def __init__(self, channel, size=4):
282
+ super().__init__()
283
+ self.input = nn.Parameter(torch.randn(1, channel, size, size // 2))
284
+
285
+ def forward(self, input):
286
+ batch = input.shape[0]
287
+ out = self.input.repeat(batch, 1, 1, 1)
288
+ return out
289
+
290
+
291
+ class StyledConv(nn.Module):
292
+ def __init__(
293
+ self,
294
+ in_channel,
295
+ out_channel,
296
+ kernel_size,
297
+ style_dim,
298
+ upsample=False,
299
+ blur_kernel=[1, 3, 3, 1],
300
+ demodulate=True,
301
+ ):
302
+ super().__init__()
303
+ self.conv = ModulatedConv2d(
304
+ in_channel,
305
+ out_channel,
306
+ kernel_size,
307
+ style_dim,
308
+ upsample=upsample,
309
+ blur_kernel=blur_kernel,
310
+ demodulate=demodulate,
311
+ )
312
+ self.noise = NoiseInjection()
313
+ self.activate = FusedLeakyReLU(out_channel)
314
+
315
+ def forward(self, input, style, noise=None):
316
+ out = self.conv(input, style)
317
+ out = self.noise(out, noise=noise)
318
+ out = self.activate(out)
319
+ return out
320
+
321
+
322
+ class ToRGB(nn.Module):
323
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
324
+ super().__init__()
325
+ if upsample:
326
+ self.upsample = Upsample(blur_kernel)
327
+
328
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
329
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
330
+
331
+ def forward(self, input, style, skip=None):
332
+ out = self.conv(input, style)
333
+ out = out + self.bias
334
+
335
+ if skip is not None:
336
+ skip = self.upsample(skip)
337
+ out = out + skip
338
+
339
+ return out
340
+
341
+
342
+ class Generator(nn.Module):
343
+ def __init__(
344
+ self,
345
+ size,
346
+ style_dim,
347
+ n_mlp,
348
+ channel_multiplier=1,
349
+ blur_kernel=[1, 3, 3, 1],
350
+ lr_mlp=0.01,
351
+ small=False,
352
+ small_isaac=False,
353
+ ):
354
+ super().__init__()
355
+
356
+ self.size = size
357
+
358
+ if small and size > 64:
359
+ raise ValueError("small only works for sizes <= 64")
360
+
361
+ self.style_dim = style_dim
362
+ layers = [PixelNorm()]
363
+
364
+ for i in range(n_mlp):
365
+ layers.append(
366
+ EqualLinear(
367
+ style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
368
+ )
369
+ )
370
+
371
+ self.style = nn.Sequential(*layers)
372
+
373
+ if small:
374
+ self.channels = {
375
+ 4: 64 * channel_multiplier,
376
+ 8: 64 * channel_multiplier,
377
+ 16: 64 * channel_multiplier,
378
+ 32: 64 * channel_multiplier,
379
+ 64: 64 * channel_multiplier,
380
+ }
381
+ elif small_isaac:
382
+ self.channels = {4: 256, 8: 256, 16: 256, 32: 256, 64: 128, 128: 128}
383
+ else:
384
+ self.channels = {
385
+ 4: 512,
386
+ 8: 512,
387
+ 16: 512,
388
+ 32: 512,
389
+ 64: 256 * channel_multiplier,
390
+ 128: 128 * channel_multiplier,
391
+ 256: 64 * channel_multiplier,
392
+ 512: 32 * channel_multiplier,
393
+ 1024: 16 * channel_multiplier,
394
+ }
395
+
396
+ self.input = ConstantInput(self.channels[4])
397
+ self.conv1 = StyledConv(
398
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
399
+ )
400
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
401
+
402
+ self.log_size = int(math.log(size, 2))
403
+ self.num_layers = (self.log_size - 2) * 2 + 1
404
+
405
+ self.convs = nn.ModuleList()
406
+ self.upsamples = nn.ModuleList()
407
+ self.to_rgbs = nn.ModuleList()
408
+ self.noises = nn.Module()
409
+
410
+ in_channel = self.channels[4]
411
+
412
+ for layer_idx in range(self.num_layers):
413
+ res = (layer_idx + 5) // 2
414
+ shape = [1, 1, 2 ** res, 2 ** res // 2]
415
+ self.noises.register_buffer(
416
+ "noise_{}".format(layer_idx), torch.randn(*shape)
417
+ )
418
+
419
+ for i in range(3, self.log_size + 1):
420
+ out_channel = self.channels[2 ** i]
421
+
422
+ self.convs.append(
423
+ StyledConv(
424
+ in_channel,
425
+ out_channel,
426
+ 3,
427
+ style_dim,
428
+ upsample=True,
429
+ blur_kernel=blur_kernel,
430
+ )
431
+ )
432
+
433
+ self.convs.append(
434
+ StyledConv(
435
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
436
+ )
437
+ )
438
+
439
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
440
+ in_channel = out_channel
441
+
442
+ self.n_latent = self.log_size * 2 - 2
443
+
444
+ def make_noise(self):
445
+ device = self.input.input.device
446
+
447
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2 // 2, device=device)]
448
+
449
+ for i in range(3, self.log_size + 1):
450
+ for _ in range(2):
451
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i // 2, device=device))
452
+
453
+ return noises
454
+
455
+ def mean_latent(self, n_latent):
456
+ latent_in = torch.randn(
457
+ n_latent, self.style_dim, device=self.input.input.device
458
+ )
459
+ latent = self.style(latent_in).mean(0, keepdim=True)
460
+
461
+ return latent
462
+
463
+ def get_latent(self, input):
464
+ return self.style(input)
465
+
466
+ def forward(
467
+ self,
468
+ styles,
469
+ return_latents=False,
470
+ return_features=False,
471
+ inject_index=None,
472
+ truncation=1,
473
+ truncation_latent=None,
474
+ input_is_latent=False,
475
+ noise=None,
476
+ randomize_noise=True,
477
+ real=False,
478
+ ):
479
+ if not input_is_latent:
480
+ styles = [self.style(s) for s in styles]
481
+ if noise is None:
482
+ if randomize_noise:
483
+ noise = [None] * self.num_layers
484
+ else:
485
+ noise = [
486
+ getattr(self.noises, "noise_{}".format(i))
487
+ for i in range(self.num_layers)
488
+ ]
489
+
490
+ if truncation < 1:
491
+ # print('truncation_latent: ', truncation_latent.shape)
492
+ if not real: #if type(styles) == list:
493
+ style_t = []
494
+ for style in styles:
495
+ style_t.append(
496
+ truncation_latent + truncation * (style - truncation_latent)
497
+ )  # i.e. w' = w_avg + truncation * (w - w_avg)
498
+ styles = style_t
499
+ else: # styles are latent (tensor: 1,18,512), for real PTI output
500
+ truncation_latent = truncation_latent.repeat(18, 1).unsqueeze(0) # (1, 512) --> (1, 18, 512); 18 == self.n_latent for the size-1024 generator
501
+ styles = torch.add(truncation_latent,torch.mul(torch.sub(styles,truncation_latent),truncation))
502
+ # print('now styles after truncation : ', styles)
503
+ # if type(styles) == list and len(styles) < 2:  # this branch handles input given as a list of (1, 512) tensors
504
+ if not real:
505
+ if len(styles) < 2:
506
+ inject_index = self.n_latent
507
+ if styles[0].ndim < 3:
508
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
509
+ else:
510
+ latent = styles[0]
511
+ elif type(styles) == list:
512
+ if inject_index is None:
513
+ inject_index = 4
514
+
515
+ latent = styles[0].unsqueeze(0)
516
+ if latent.shape[1] == 1:
517
+ latent = latent.repeat(1, inject_index, 1)
518
+ else:
519
+ latent = latent[:, :inject_index, :]
520
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
521
+ latent = torch.cat([latent, latent2], 1)
522
+ else: # input is tensor of size with torch.Size([1, 18, 512]), for real PTI output
523
+ latent = styles
524
+
525
+ # print(f'processed latent: {latent.shape}')
526
+
527
+ features = {}
528
+ out = self.input(latent)
529
+ features["out_0"] = out
530
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
531
+ features["conv1_0"] = out
532
+
533
+ skip = self.to_rgb1(out, latent[:, 1])
534
+ features["skip_0"] = skip
535
+ i = 1
536
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
537
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
538
+ ):
539
+ out = conv1(out, latent[:, i], noise=noise1)
540
+ features["conv1_{}".format(i)] = out
541
+ out = conv2(out, latent[:, i + 1], noise=noise2)
542
+ features["conv2_{}".format(i)] = out
543
+ skip = to_rgb(out, latent[:, i + 2], skip)
544
+ features["skip_{}".format(i)] = skip
545
+
546
+ i += 2
547
+
548
+ image = skip
549
+
550
+ if return_latents:
551
+ return image, latent
552
+ elif return_features:
553
+ return image, features
554
+ else:
555
+ return image, None
556
+
557
+
558
+ class ConvLayer(nn.Sequential):
559
+ def __init__(
560
+ self,
561
+ in_channel,
562
+ out_channel,
563
+ kernel_size,
564
+ downsample=False,
565
+ blur_kernel=[1, 3, 3, 1],
566
+ bias=True,
567
+ activate=True,
568
+ ):
569
+ layers = []
570
+
571
+ if downsample:
572
+ factor = 2
573
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
574
+ pad0 = (p + 1) // 2
575
+ pad1 = p // 2
576
+
577
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
578
+
579
+ stride = 2
580
+ self.padding = 0
581
+
582
+ else:
583
+ stride = 1
584
+ self.padding = kernel_size // 2
585
+
586
+ layers.append(
587
+ EqualConv2d(
588
+ in_channel,
589
+ out_channel,
590
+ kernel_size,
591
+ padding=self.padding,
592
+ stride=stride,
593
+ bias=bias and not activate,
594
+ )
595
+ )
596
+
597
+ if activate:
598
+ if bias:
599
+ layers.append(FusedLeakyReLU(out_channel))
600
+ else:
601
+ layers.append(ScaledLeakyReLU(0.2))
602
+
603
+ super().__init__(*layers)
604
+
605
+
606
+ class ResBlock(nn.Module):
607
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
608
+ super().__init__()
609
+
610
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
611
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
612
+
613
+ self.skip = ConvLayer(
614
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
615
+ )
616
+
617
+ def forward(self, input):
618
+ out = self.conv1(input)
619
+ out = self.conv2(out)
620
+
621
+ skip = self.skip(input)
622
+ out = (out + skip) / math.sqrt(2)
623
+
624
+ return out
625
+
626
+
627
+ class StyleDiscriminator(nn.Module):
628
+ def __init__(
629
+ self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], small=False
630
+ ):
631
+ super().__init__()
632
+
633
+ if small:
634
+ channels = {4: 64, 8: 64, 16: 64, 32: 64, 64: 64}
635
+
636
+ else:
637
+ channels = {
638
+ 4: 512,
639
+ 8: 512,
640
+ 16: 512,
641
+ 32: 512,
642
+ 64: 256 * channel_multiplier,
643
+ 128: 128 * channel_multiplier,
644
+ 256: 64 * channel_multiplier,
645
+ 512: 32 * channel_multiplier,
646
+ 1024: 16 * channel_multiplier,
647
+ }
648
+
649
+ convs = [ConvLayer(3, channels[size], 1)]
650
+
651
+ log_size = int(math.log(size, 2))
652
+ in_channel = channels[size]
653
+
654
+ for i in range(log_size, 2, -1):
655
+ out_channel = channels[2 ** (i - 1)]
656
+
657
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
658
+
659
+ in_channel = out_channel
660
+
661
+ self.convs = nn.Sequential(*convs)
662
+
663
+ self.stddev_group = 4
664
+ self.stddev_feat = 1
665
+
666
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
667
+ self.final_linear = nn.Sequential(
668
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
669
+ EqualLinear(channels[4], 1),
670
+ )
671
+
672
+
673
+ def forward(self, input):
674
+ h = input
675
+ h_list = []
676
+
677
+ for index, blocklist in enumerate(self.convs):
678
+ h = blocklist(h)
679
+ h_list.append(h)
680
+
681
+ out = h
682
+ batch, channel, height, width = out.shape
683
+ group = min(batch, self.stddev_group)
684
+ stddev = out.view(
685
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
686
+ )
687
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
688
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
689
+ stddev = stddev.repeat(group, 1, height, width)
690
+ out = torch.cat([out, stddev], 1)
691
+
692
+ out = self.final_conv(out)
693
+ h_list.append(out)
694
+
695
+ out = out.view(batch, -1)
696
+ out = self.final_linear(out)
697
+
698
+ return out, h_list
699
+
700
+
701
+ class StyleEncoder(nn.Module):
702
+ def __init__(self, size, w_dim=512):
703
+ super().__init__()
704
+
705
+ channels = {
706
+ 4: 512,
707
+ 8: 512,
708
+ 16: 512,
709
+ 32: 512,
710
+ 64: 256,
711
+ 128: 128,
712
+ 256: 64,
713
+ 512: 32,
714
+ 1024: 16
715
+ }
716
+
717
+ self.w_dim = w_dim
718
+ log_size = int(math.log(size, 2))
719
+ convs = [ConvLayer(3, channels[size], 1)]
720
+
721
+ in_channel = channels[size]
722
+ for i in range(log_size, 2, -1):
723
+ out_channel = channels[2 ** (i - 1)]
724
+ convs.append(ResBlock(in_channel, out_channel))
725
+ in_channel = out_channel
726
+
727
+ convs.append(EqualConv2d(in_channel,2*self.w_dim, 4, padding=0, bias=False))
728
+
729
+ self.convs = nn.Sequential(*convs)
730
+
731
+ def forward(self, input):
732
+ out = self.convs(input)
733
+ # return out.view(len(input), self.n_latents, self.w_dim)
734
+ reshaped = out.view(len(input), 2*self.w_dim)
735
+ return reshaped[:,:self.w_dim], reshaped[:,self.w_dim:]
736
+
737
+ def kaiming_init(m):
738
+ if isinstance(m, (nn.Linear, nn.Conv2d)):
739
+ init.kaiming_normal_(m.weight)
740
+ if m.bias is not None:
741
+ m.bias.data.fill_(0)
742
+ elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
743
+ m.weight.data.fill_(1)
744
+ if m.bias is not None:
745
+ m.bias.data.fill_(0)
746
+
747
+
748
+ def normal_init(m):
749
+ if isinstance(m, (nn.Linear, nn.Conv2d)):
750
+ init.normal_(m.weight, 0, 0.02)
751
+ if m.bias is not None:
752
+ m.bias.data.fill_(0)
753
+ elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
754
+ m.weight.data.fill_(1)
755
+ if m.bias is not None:
756
+ m.bias.data.fill_(0)
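torch_utils/models.py above is a rosinality-style StyleGAN2 generator adapted to the 2:1 human-body setting: ConstantInput and the registered noise buffers use width 2**res // 2, so a size=1024 model produces 1024x512 images, and Generator.forward adds a `real` branch that applies truncation directly to (1, 18, 512) PTI latents. A minimal forward-pass sketch follows; it assumes the op_edit CUDA extensions JIT-compile successfully at import time (a CUDA toolchain is required for that step), and the channel_multiplier value is only an illustration.

# Sketch only: assumes the op_edit extensions build when torch_utils.models is imported.
import torch
from torch_utils.models import Generator

g = Generator(size=1024, style_dim=512, n_mlp=8, channel_multiplier=2).eval()
z = torch.randn(2, 512)
with torch.no_grad():
    img, _ = g([z], truncation=0.7, truncation_latent=g.mean_latent(4096))
print(img.shape)  # torch.Size([2, 3, 1024, 512]): height x width follows the 2:1 constant input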
torch_utils/models_face.py ADDED
@@ -0,0 +1,809 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ import math
4
+ import random
5
+ import functools
6
+ import operator
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ import torch.nn.init as init
12
+ from torch.autograd import Function
13
+
14
+ from .op_edit import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
15
+
16
+
17
+ class PixelNorm(nn.Module):
18
+ def __init__(self):
19
+ super().__init__()
20
+
21
+ def forward(self, input):
22
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
23
+
24
+
25
+ def make_kernel(k):
26
+ k = torch.tensor(k, dtype=torch.float32)
27
+
28
+ if k.ndim == 1:
29
+ k = k[None, :] * k[:, None]
30
+
31
+ k /= k.sum()
32
+
33
+ return k
34
+
35
+
36
+ class Upsample(nn.Module):
37
+ def __init__(self, kernel, factor=2):
38
+ super().__init__()
39
+
40
+ self.factor = factor
41
+ kernel = make_kernel(kernel) * (factor ** 2)
42
+ self.register_buffer("kernel", kernel)
43
+
44
+ p = kernel.shape[0] - factor
45
+
46
+ pad0 = (p + 1) // 2 + factor - 1
47
+ pad1 = p // 2
48
+
49
+ self.pad = (pad0, pad1)
50
+
51
+ def forward(self, input):
52
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
53
+
54
+ return out
55
+
56
+
57
+ class Downsample(nn.Module):
58
+ def __init__(self, kernel, factor=2):
59
+ super().__init__()
60
+
61
+ self.factor = factor
62
+ kernel = make_kernel(kernel)
63
+ self.register_buffer("kernel", kernel)
64
+
65
+ p = kernel.shape[0] - factor
66
+
67
+ pad0 = (p + 1) // 2
68
+ pad1 = p // 2
69
+
70
+ self.pad = (pad0, pad1)
71
+
72
+ def forward(self, input):
73
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
74
+
75
+ return out
76
+
77
+
78
+ class Blur(nn.Module):
79
+ def __init__(self, kernel, pad, upsample_factor=1):
80
+ super().__init__()
81
+
82
+ kernel = make_kernel(kernel)
83
+
84
+ if upsample_factor > 1:
85
+ kernel = kernel * (upsample_factor ** 2)
86
+
87
+ self.register_buffer("kernel", kernel)
88
+
89
+ self.pad = pad
90
+
91
+ def forward(self, input):
92
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
93
+
94
+ return out
95
+
96
+
97
+ class EqualConv2d(nn.Module):
98
+ def __init__(
99
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
100
+ ):
101
+ super().__init__()
102
+
103
+ self.weight = nn.Parameter(
104
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
105
+ )
106
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
107
+
108
+ self.stride = stride
109
+ self.padding = padding
110
+
111
+ if bias:
112
+ self.bias = nn.Parameter(torch.zeros(out_channel))
113
+
114
+ else:
115
+ self.bias = None
116
+
117
+ def forward(self, input):
118
+ out = F.conv2d(
119
+ input,
120
+ self.weight * self.scale,
121
+ bias=self.bias,
122
+ stride=self.stride,
123
+ padding=self.padding,
124
+ )
125
+
126
+ return out
127
+
128
+ def __repr__(self):
129
+ return (
130
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
131
+ f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
132
+ )
133
+
134
+
135
+ class EqualLinear(nn.Module):
136
+ def __init__(
137
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
138
+ ):
139
+ super().__init__()
140
+
141
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
142
+
143
+ if bias:
144
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
145
+
146
+ else:
147
+ self.bias = None
148
+
149
+ self.activation = activation
150
+
151
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
152
+ self.lr_mul = lr_mul
153
+
154
+ def forward(self, input):
155
+ if self.activation:
156
+ out = F.linear(input, self.weight * self.scale)
157
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
158
+
159
+ else:
160
+ out = F.linear(
161
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
162
+ )
163
+
164
+ return out
165
+
166
+ def __repr__(self):
167
+ return (
168
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
169
+ )
170
+
171
+
172
+ class ScaledLeakyReLU(nn.Module):
173
+ def __init__(self, negative_slope=0.2):
174
+ super().__init__()
175
+
176
+ self.negative_slope = negative_slope
177
+
178
+ def forward(self, input):
179
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
180
+
181
+ return out * math.sqrt(2)
182
+
183
+
184
+ class ModulatedConv2d(nn.Module):
185
+ def __init__(
186
+ self,
187
+ in_channel,
188
+ out_channel,
189
+ kernel_size,
190
+ style_dim,
191
+ demodulate=True,
192
+ upsample=False,
193
+ downsample=False,
194
+ blur_kernel=[1, 3, 3, 1],
195
+ ):
196
+ super().__init__()
197
+
198
+ self.eps = 1e-8
199
+ self.kernel_size = kernel_size
200
+ self.in_channel = in_channel
201
+ self.out_channel = out_channel
202
+ self.upsample = upsample
203
+ self.downsample = downsample
204
+
205
+ if upsample:
206
+ factor = 2
207
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
208
+ pad0 = (p + 1) // 2 + factor - 1
209
+ pad1 = p // 2 + 1
210
+
211
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
212
+
213
+ if downsample:
214
+ factor = 2
215
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
216
+ pad0 = (p + 1) // 2
217
+ pad1 = p // 2
218
+
219
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
220
+
221
+ fan_in = in_channel * kernel_size ** 2
222
+ self.scale = 1 / math.sqrt(fan_in)
223
+ self.padding = kernel_size // 2
224
+
225
+ self.weight = nn.Parameter(
226
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
227
+ )
228
+
229
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
230
+
231
+ self.demodulate = demodulate
232
+
233
+ def __repr__(self):
234
+ return (
235
+ f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
236
+ f"upsample={self.upsample}, downsample={self.downsample})"
237
+ )
238
+
239
+ def forward(self, input, style):
240
+ batch, in_channel, height, width = input.shape
241
+
242
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
243
+ weight = self.scale * self.weight * style
244
+
245
+ if self.demodulate:
246
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
247
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
248
+
249
+ weight = weight.view(
250
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
251
+ )
252
+
253
+ if self.upsample:
254
+ input = input.view(1, batch * in_channel, height, width)
255
+ weight = weight.view(
256
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
257
+ )
258
+ weight = weight.transpose(1, 2).reshape(
259
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
260
+ )
261
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
262
+ _, _, height, width = out.shape
263
+ out = out.view(batch, self.out_channel, height, width)
264
+ out = self.blur(out)
265
+
266
+ elif self.downsample:
267
+ input = self.blur(input)
268
+ _, _, height, width = input.shape
269
+ input = input.view(1, batch * in_channel, height, width)
270
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
271
+ _, _, height, width = out.shape
272
+ out = out.view(batch, self.out_channel, height, width)
273
+
274
+ else:
275
+ input = input.view(1, batch * in_channel, height, width)
276
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
277
+ _, _, height, width = out.shape
278
+ out = out.view(batch, self.out_channel, height, width)
279
+
280
+ return out
281
+
282
+
283
+ class NoiseInjection(nn.Module):
284
+ def __init__(self):
285
+ super().__init__()
286
+
287
+ self.weight = nn.Parameter(torch.zeros(1))
288
+
289
+ def forward(self, image, noise=None):
290
+ if noise is None:
291
+ batch, _, height, width = image.shape
292
+ noise = image.new_empty(batch, 1, height, width).normal_()
293
+
294
+ return image + self.weight * noise
295
+
296
+
297
+ class ConstantInput(nn.Module):
298
+ def __init__(self, channel, size=4):
299
+ super().__init__()
300
+
301
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
302
+
303
+ def forward(self, input):
304
+ batch = input.shape[0]
305
+ out = self.input.repeat(batch, 1, 1, 1)
306
+
307
+ return out
308
+
309
+
310
+ class StyledConv(nn.Module):
311
+ def __init__(
312
+ self,
313
+ in_channel,
314
+ out_channel,
315
+ kernel_size,
316
+ style_dim,
317
+ upsample=False,
318
+ blur_kernel=[1, 3, 3, 1],
319
+ demodulate=True,
320
+ ):
321
+ super().__init__()
322
+
323
+ self.conv = ModulatedConv2d(
324
+ in_channel,
325
+ out_channel,
326
+ kernel_size,
327
+ style_dim,
328
+ upsample=upsample,
329
+ blur_kernel=blur_kernel,
330
+ demodulate=demodulate,
331
+ )
332
+
333
+ self.noise = NoiseInjection()
334
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
335
+ # self.activate = ScaledLeakyReLU(0.2)
336
+ self.activate = FusedLeakyReLU(out_channel)
337
+
338
+ def forward(self, input, style, noise=None):
339
+ out = self.conv(input, style)
340
+ out = self.noise(out, noise=noise)
341
+ # out = out + self.bias
342
+ out = self.activate(out)
343
+
344
+ return out
345
+
346
+
347
+ class ToRGB(nn.Module):
348
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
349
+ super().__init__()
350
+
351
+ if upsample:
352
+ self.upsample = Upsample(blur_kernel)
353
+
354
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
355
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
356
+
357
+ def forward(self, input, style, skip=None):
358
+ out = self.conv(input, style)
359
+ out = out + self.bias
360
+
361
+ if skip is not None:
362
+ skip = self.upsample(skip)
363
+
364
+ out = out + skip
365
+
366
+ return out
367
+
368
+
369
+ class Generator(nn.Module):
370
+ def __init__(
371
+ self,
372
+ size,
373
+ style_dim,
374
+ n_mlp,
375
+ channel_multiplier=1,
376
+ blur_kernel=[1, 3, 3, 1],
377
+ lr_mlp=0.01,
378
+ small=False,
379
+ small_isaac=False,
380
+ ):
381
+ super().__init__()
382
+
383
+ self.size = size
384
+
385
+ if small and size > 64:
386
+ raise ValueError("small only works for sizes <= 64")
387
+
388
+ self.style_dim = style_dim
389
+
390
+ layers = [PixelNorm()]
391
+
392
+ for i in range(n_mlp):
393
+ layers.append(
394
+ EqualLinear(
395
+ style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
396
+ )
397
+ )
398
+
399
+ self.style = nn.Sequential(*layers)
400
+
401
+ if small:
402
+ self.channels = {
403
+ 4: 64 * channel_multiplier,
404
+ 8: 64 * channel_multiplier,
405
+ 16: 64 * channel_multiplier,
406
+ 32: 64 * channel_multiplier,
407
+ 64: 64 * channel_multiplier,
408
+ }
409
+ elif small_isaac:
410
+ self.channels = {4: 256, 8: 256, 16: 256, 32: 256, 64: 128, 128: 128}
411
+ else:
412
+ self.channels = {
413
+ 4: 512,
414
+ 8: 512,
415
+ 16: 512,
416
+ 32: 512,
417
+ 64: 256 * channel_multiplier,
418
+ 128: 128 * channel_multiplier,
419
+ 256: 64 * channel_multiplier,
420
+ 512: 32 * channel_multiplier,
421
+ 1024: 16 * channel_multiplier,
422
+ }
423
+
424
+ self.input = ConstantInput(self.channels[4])
425
+ self.conv1 = StyledConv(
426
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
427
+ )
428
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
429
+
430
+ self.log_size = int(math.log(size, 2))
431
+ self.num_layers = (self.log_size - 2) * 2 + 1
432
+
433
+ self.convs = nn.ModuleList()
434
+ self.upsamples = nn.ModuleList()
435
+ self.to_rgbs = nn.ModuleList()
436
+ self.noises = nn.Module()
437
+
438
+ in_channel = self.channels[4]
439
+
440
+ for layer_idx in range(self.num_layers):
441
+ res = (layer_idx + 5) // 2
442
+ shape = [1, 1, 2 ** res, 2 ** res]
443
+ self.noises.register_buffer(
444
+ "noise_{}".format(layer_idx), torch.randn(*shape)
445
+ )
446
+
447
+ for i in range(3, self.log_size + 1):
448
+ out_channel = self.channels[2 ** i]
449
+
450
+ self.convs.append(
451
+ StyledConv(
452
+ in_channel,
453
+ out_channel,
454
+ 3,
455
+ style_dim,
456
+ upsample=True,
457
+ blur_kernel=blur_kernel,
458
+ )
459
+ )
460
+
461
+ self.convs.append(
462
+ StyledConv(
463
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
464
+ )
465
+ )
466
+
467
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
468
+
469
+ in_channel = out_channel
470
+
471
+ self.n_latent = self.log_size * 2 - 2
472
+
473
+ def make_noise(self):
474
+ device = self.input.input.device
475
+
476
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
477
+
478
+ for i in range(3, self.log_size + 1):
479
+ for _ in range(2):
480
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
481
+
482
+ return noises
483
+
484
+ def mean_latent(self, n_latent):
485
+ latent_in = torch.randn(
486
+ n_latent, self.style_dim, device=self.input.input.device
487
+ )
488
+ latent = self.style(latent_in).mean(0, keepdim=True)
489
+
490
+ return latent
491
+
492
+ def get_latent(self, input):
493
+ return self.style(input)
494
+
495
+ def forward(
496
+ self,
497
+ styles,
498
+ return_latents=False,
499
+ return_features=False,
500
+ inject_index=None,
501
+ truncation=1,
502
+ truncation_latent=None,
503
+ input_is_latent=False,
504
+ noise=None,
505
+ randomize_noise=True,
506
+ ):
507
+ if not input_is_latent:
508
+ # print("haha")
509
+ styles = [self.style(s) for s in styles]
510
+ if noise is None:
511
+ if randomize_noise:
512
+ noise = [None] * self.num_layers
513
+ else:
514
+ noise = [
515
+ getattr(self.noises, "noise_{}".format(i))
516
+ for i in range(self.num_layers)
517
+ ]
518
+
519
+ if truncation < 1:
520
+ style_t = []
521
+
522
+ for style in styles:
523
+ style_t.append(
524
+ truncation_latent + truncation * (style - truncation_latent)
525
+ )
526
+
527
+ styles = style_t
528
+ # print(styles)
529
+ if len(styles) < 2:
530
+ inject_index = self.n_latent
531
+
532
+ if styles[0].ndim < 3:
533
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
534
+ # print("a")
535
+ else:
536
+ # print(len(styles))
537
+ latent = styles[0]
538
+ # print("b", latent.shape)
539
+
540
+ else:
541
+ # print("c")
542
+ if inject_index is None:
543
+ inject_index = 4
544
+
545
+ latent = styles[0].unsqueeze(0)
546
+ if latent.shape[1] == 1:
547
+ latent = latent.repeat(1, inject_index, 1)
548
+ else:
549
+ latent = latent[:, :inject_index, :]
550
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
551
+
552
+ latent = torch.cat([latent, latent2], 1)
553
+
554
+ features = {}
555
+ out = self.input(latent)
556
+ features["out_0"] = out
557
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
558
+ features["conv1_0"] = out
559
+
560
+ skip = self.to_rgb1(out, latent[:, 1])
561
+ features["skip_0"] = skip
562
+ i = 1
563
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
564
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
565
+ ):
566
+ out = conv1(out, latent[:, i], noise=noise1)
567
+ features["conv1_{}".format(i)] = out
568
+ out = conv2(out, latent[:, i + 1], noise=noise2)
569
+ features["conv2_{}".format(i)] = out
570
+ skip = to_rgb(out, latent[:, i + 2], skip)
571
+ features["skip_{}".format(i)] = skip
572
+
573
+ i += 2
574
+
575
+ image = skip
576
+
577
+ if return_latents:
578
+ return image, latent
579
+ elif return_features:
580
+ return image, features
581
+ else:
582
+ return image, None
583
+
584
+
585
+ class ConvLayer(nn.Sequential):
586
+ def __init__(
587
+ self,
588
+ in_channel,
589
+ out_channel,
590
+ kernel_size,
591
+ downsample=False,
592
+ blur_kernel=[1, 3, 3, 1],
593
+ bias=True,
594
+ activate=True,
595
+ ):
596
+ layers = []
597
+
598
+ if downsample:
599
+ factor = 2
600
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
601
+ pad0 = (p + 1) // 2
602
+ pad1 = p // 2
603
+
604
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
605
+
606
+ stride = 2
607
+ self.padding = 0
608
+
609
+ else:
610
+ stride = 1
611
+ self.padding = kernel_size // 2
612
+
613
+ layers.append(
614
+ EqualConv2d(
615
+ in_channel,
616
+ out_channel,
617
+ kernel_size,
618
+ padding=self.padding,
619
+ stride=stride,
620
+ bias=bias and not activate,
621
+ )
622
+ )
623
+
624
+ if activate:
625
+ if bias:
626
+ layers.append(FusedLeakyReLU(out_channel))
627
+
628
+ else:
629
+ layers.append(ScaledLeakyReLU(0.2))
630
+
631
+ super().__init__(*layers)
632
+
633
+
634
+ class ResBlock(nn.Module):
635
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
636
+ super().__init__()
637
+
638
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
639
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
640
+
641
+ self.skip = ConvLayer(
642
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
643
+ )
644
+
645
+ def forward(self, input):
646
+ out = self.conv1(input)
647
+ out = self.conv2(out)
648
+
649
+ skip = self.skip(input)
650
+ out = (out + skip) / math.sqrt(2)
651
+
652
+ return out
653
+
654
+
655
+ class StyleDiscriminator(nn.Module):
656
+ def __init__(
657
+ self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], small=False
658
+ ):
659
+ super().__init__()
660
+
661
+ if small:
662
+ channels = {4: 64, 8: 64, 16: 64, 32: 64, 64: 64}
663
+
664
+ else:
665
+ channels = {
666
+ 4: 512,
667
+ 8: 512,
668
+ 16: 512,
669
+ 32: 512,
670
+ 64: 256 * channel_multiplier,
671
+ 128: 128 * channel_multiplier,
672
+ 256: 64 * channel_multiplier,
673
+ 512: 32 * channel_multiplier,
674
+ 1024: 16 * channel_multiplier,
675
+ }
676
+
677
+ convs = [ConvLayer(3, channels[size], 1)]
678
+
679
+ log_size = int(math.log(size, 2))
680
+
681
+ in_channel = channels[size]
682
+
683
+ for i in range(log_size, 2, -1):
684
+ out_channel = channels[2 ** (i - 1)]
685
+
686
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
687
+
688
+ in_channel = out_channel
689
+
690
+ self.convs = nn.Sequential(*convs)
691
+
692
+ self.stddev_group = 4
693
+ self.stddev_feat = 1
694
+
695
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
696
+ self.final_linear = nn.Sequential(
697
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
698
+ EqualLinear(channels[4], 1),
699
+ )
700
+
701
+ # def forward(self, input):
702
+ # out = self.convs(input)
703
+
704
+ # batch, channel, height, width = out.shape
705
+ # group = min(batch, self.stddev_group)
706
+ # stddev = out.view(
707
+ # group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
708
+ # )
709
+ # stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
710
+ # stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
711
+ # stddev = stddev.repeat(group, 1, height, width)
712
+ # out = torch.cat([out, stddev], 1)
713
+
714
+ # out = self.final_conv(out)
715
+
716
+ # out = out.view(batch, -1)
717
+ # out = self.final_linear(out)
718
+
719
+ # return out
720
+
721
+ def forward(self, input):
722
+ h = input
723
+ h_list = []
724
+
725
+ for index, blocklist in enumerate(self.convs):
726
+ h = blocklist(h)
727
+ h_list.append(h)
728
+
729
+ out = h
730
+ batch, channel, height, width = out.shape
731
+ group = min(batch, self.stddev_group)
732
+ stddev = out.view(
733
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
734
+ )
735
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
736
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
737
+ stddev = stddev.repeat(group, 1, height, width)
738
+ out = torch.cat([out, stddev], 1)
739
+
740
+ out = self.final_conv(out)
741
+ h_list.append(out)
742
+
743
+ out = out.view(batch, -1)
744
+ out = self.final_linear(out)
745
+
746
+ return out, h_list
747
+
748
+
749
+ class StyleEncoder(nn.Module):
750
+ def __init__(self, size, w_dim=512):
751
+ super().__init__()
752
+
753
+ channels = {
754
+ 4: 512,
755
+ 8: 512,
756
+ 16: 512,
757
+ 32: 512,
758
+ 64: 256,
759
+ 128: 128,
760
+ 256: 64,
761
+ 512: 32,
762
+ 1024: 16
763
+ }
764
+
765
+ self.w_dim = w_dim
766
+ log_size = int(math.log(size, 2))
767
+
768
+ # self.n_latents = log_size*2 - 2
769
+
770
+ convs = [ConvLayer(3, channels[size], 1)]
771
+
772
+ in_channel = channels[size]
773
+ for i in range(log_size, 2, -1):
774
+ out_channel = channels[2 ** (i - 1)]
775
+ convs.append(ResBlock(in_channel, out_channel))
776
+ in_channel = out_channel
777
+
778
+ # convs.append(EqualConv2d(in_channel, self.n_latents*self.w_dim, 4, padding=0, bias=False))
779
+ convs.append(EqualConv2d(in_channel,2*self.w_dim, 4, padding=0, bias=False))
780
+
781
+
782
+ self.convs = nn.Sequential(*convs)
783
+
784
+ def forward(self, input):
785
+ out = self.convs(input)
786
+ # return out.view(len(input), self.n_latents, self.w_dim)
787
+ reshaped = out.view(len(input), 2*self.w_dim)
788
+ return reshaped[:,:self.w_dim], reshaped[:,self.w_dim:]
789
+
790
+ def kaiming_init(m):
791
+ if isinstance(m, (nn.Linear, nn.Conv2d)):
792
+ init.kaiming_normal_(m.weight)
793
+ if m.bias is not None:
794
+ m.bias.data.fill_(0)
795
+ elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
796
+ m.weight.data.fill_(1)
797
+ if m.bias is not None:
798
+ m.bias.data.fill_(0)
799
+
800
+
801
+ def normal_init(m):
802
+ if isinstance(m, (nn.Linear, nn.Conv2d)):
803
+ init.normal_(m.weight, 0, 0.02)
804
+ if m.bias is not None:
805
+ m.bias.data.fill_(0)
806
+ elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
807
+ m.weight.data.fill_(1)
808
+ if m.bias is not None:
809
+ m.bias.data.fill_(0)
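models_face.py is the square-aspect counterpart of models.py: ConstantInput and the noise buffers are size x size instead of size x (size // 2), and Generator.forward drops the `real`/PTI branch, keeping only the list-of-styles path. The truncation logic both files share is a lerp toward the mean latent; restated standalone in plain PyTorch:

# style_t = truncation_latent + truncation * (style - truncation_latent),
# exactly as written in Generator.forward above.
import torch

w_avg = torch.zeros(1, 512)          # stand-in for generator.mean_latent(n_samples)
w = torch.randn(4, 512)              # mapped latents, i.e. generator.style(z)
psi = 0.7                            # truncation strength
w_trunc = w_avg + psi * (w - w_avg)  # pulls each w toward the average latent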
torch_utils/op_edit/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
4
+ from .upfirdn2d import upfirdn2d
torch_utils/op_edit/fused_act.py ADDED
@@ -0,0 +1,99 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ import os
4
+
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from torch.autograd import Function
9
+ from torch.utils.cpp_extension import load
10
+
11
+
12
+ module_path = os.path.dirname(__file__)
13
+ fused = load(
14
+ "fused",
15
+ sources=[
16
+ os.path.join(module_path, "fused_bias_act.cpp"),
17
+ os.path.join(module_path, "fused_bias_act_kernel.cu"),
18
+ ],
19
+ )
20
+
21
+
22
+ class FusedLeakyReLUFunctionBackward(Function):
23
+ @staticmethod
24
+ def forward(ctx, grad_output, out, negative_slope, scale):
25
+ ctx.save_for_backward(out)
26
+ ctx.negative_slope = negative_slope
27
+ ctx.scale = scale
28
+
29
+ empty = grad_output.new_empty(0)
30
+
31
+ grad_input = fused.fused_bias_act(
32
+ grad_output, empty, out, 3, 1, negative_slope, scale
33
+ )
34
+
35
+ dim = [0]
36
+
37
+ if grad_input.ndim > 2:
38
+ dim += list(range(2, grad_input.ndim))
39
+
40
+ grad_bias = grad_input.sum(dim).detach()
41
+
42
+ return grad_input, grad_bias
43
+
44
+ @staticmethod
45
+ def backward(ctx, gradgrad_input, gradgrad_bias):
46
+ (out,) = ctx.saved_tensors
47
+ gradgrad_out = fused.fused_bias_act(
48
+ gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
49
+ )
50
+
51
+ return gradgrad_out, None, None, None
52
+
53
+
54
+ class FusedLeakyReLUFunction(Function):
55
+ @staticmethod
56
+ def forward(ctx, input, bias, negative_slope, scale):
57
+ empty = input.new_empty(0)
58
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
59
+ ctx.save_for_backward(out)
60
+ ctx.negative_slope = negative_slope
61
+ ctx.scale = scale
62
+
63
+ return out
64
+
65
+ @staticmethod
66
+ def backward(ctx, grad_output):
67
+ (out,) = ctx.saved_tensors
68
+
69
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
70
+ grad_output, out, ctx.negative_slope, ctx.scale
71
+ )
72
+
73
+ return grad_input, grad_bias, None, None
74
+
75
+
76
+ class FusedLeakyReLU(nn.Module):
77
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
78
+ super().__init__()
79
+
80
+ self.bias = nn.Parameter(torch.zeros(channel))
81
+ self.negative_slope = negative_slope
82
+ self.scale = scale
83
+
84
+ def forward(self, input):
85
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
86
+
87
+
88
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
89
+ if input.device.type == "cpu":
90
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
91
+ return (
92
+ F.leaky_relu(
93
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
94
+ )
95
+ * scale
96
+ )
97
+
98
+ else:
99
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
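fused_leaky_relu above adds a per-channel bias, applies a leaky ReLU, and rescales by `scale` (default sqrt(2)); on CPU it falls back to plain PyTorch ops, on CUDA it dispatches to the JIT-compiled `fused` extension. A reference check of what the CPU branch computes, using only standard PyTorch:

# Reference for the CPU branch of fused_leaky_relu (bias + leaky ReLU + scale).
import math
import torch
from torch.nn import functional as F

x = torch.randn(2, 8, 4, 4)                  # NCHW activations
bias = torch.randn(8)                        # one bias per channel
negative_slope, scale = 0.2, math.sqrt(2)

ref = F.leaky_relu(x + bias.view(1, -1, 1, 1), negative_slope) * scale
# On CUDA, FusedLeakyReLUFunction.apply(x, bias, negative_slope, scale)
# is expected to return the same values from a single fused kernel.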
torch_utils/op_edit/fused_bias_act.cpp ADDED
@@ -0,0 +1,23 @@
1
+ // Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ #include <torch/extension.h>
4
+
5
+
6
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
7
+ int act, int grad, float alpha, float scale);
8
+
9
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
10
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
11
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
12
+
13
+ torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
14
+ int act, int grad, float alpha, float scale) {
15
+ CHECK_CUDA(input);
16
+ CHECK_CUDA(bias);
17
+
18
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
19
+ }
20
+
21
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
23
+ }
torch_utils/op_edit/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,101 @@
1
+ // Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ //
5
+ // This work is made available under the Nvidia Source Code License-NC.
6
+ // To view a copy of this license, visit
7
+ // https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ #include <torch/types.h>
10
+
11
+ #include <ATen/ATen.h>
12
+ #include <ATen/AccumulateType.h>
13
+ #include <ATen/cuda/CUDAContext.h>
14
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
15
+
16
+ #include <cuda.h>
17
+ #include <cuda_runtime.h>
18
+
19
+
20
+ template <typename scalar_t>
21
+ static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
22
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
23
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
24
+
25
+ scalar_t zero = 0.0;
26
+
27
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
28
+ scalar_t x = p_x[xi];
29
+
30
+ if (use_bias) {
31
+ x += p_b[(xi / step_b) % size_b];
32
+ }
33
+
34
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
35
+
36
+ scalar_t y;
37
+
38
+ switch (act * 10 + grad) {
39
+ default:
40
+ case 10: y = x; break;
41
+ case 11: y = x; break;
42
+ case 12: y = 0.0; break;
43
+
44
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
45
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
46
+ case 32: y = 0.0; break;
47
+ }
48
+
49
+ out[xi] = y * scale;
50
+ }
51
+ }
52
+
53
+
54
+ torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
55
+ int act, int grad, float alpha, float scale) {
56
+ int curDevice = -1;
57
+ cudaGetDevice(&curDevice);
58
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
59
+
60
+ auto x = input.contiguous();
61
+ auto b = bias.contiguous();
62
+ auto ref = refer.contiguous();
63
+
64
+ int use_bias = b.numel() ? 1 : 0;
65
+ int use_ref = ref.numel() ? 1 : 0;
66
+
67
+ int size_x = x.numel();
68
+ int size_b = b.numel();
69
+ int step_b = 1;
70
+
71
+ for (int i = 1 + 1; i < x.dim(); i++) {
72
+ step_b *= x.size(i);
73
+ }
74
+
75
+ int loop_x = 4;
76
+ int block_size = 4 * 32;
77
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
78
+
79
+ auto y = torch::empty_like(x);
80
+
81
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
82
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
83
+ y.data_ptr<scalar_t>(),
84
+ x.data_ptr<scalar_t>(),
85
+ b.data_ptr<scalar_t>(),
86
+ ref.data_ptr<scalar_t>(),
87
+ act,
88
+ grad,
89
+ alpha,
90
+ scale,
91
+ loop_x,
92
+ size_x,
93
+ step_b,
94
+ size_b,
95
+ use_bias,
96
+ use_ref
97
+ );
98
+ });
99
+
100
+ return y;
101
+ }
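The kernel above encodes its behaviour in `act * 10 + grad`: act=1 is linear and act=3 is leaky ReLU, while grad=0 is the forward value, grad=1 the first derivative gated by the saved reference tensor, and grad=2 the second derivative (zero for these piecewise-linear activations). As a readability aid rather than a drop-in replacement, the same case table restated in plain PyTorch for a contiguous NCHW tensor with a per-channel bias:

# Pure-PyTorch restatement of fused_bias_act_kernel's case table (readability aid).
import torch

def fused_bias_act_ref(x, b, ref, act, grad, alpha, scale):
    if b.numel():
        x = x + b.view(1, -1, *([1] * (x.ndim - 2)))   # bias broadcast over dim 1
    if act == 3:                                        # leaky ReLU
        if grad == 0:
            y = torch.where(x > 0, x, x * alpha)        # case 30: forward
        elif grad == 1:
            y = torch.where(ref > 0, x, x * alpha)      # case 31: gradient, gated by saved output
        else:
            y = torch.zeros_like(x)                     # case 32: second derivative
    else:                                               # act == 1: linear
        y = x if grad < 2 else torch.zeros_like(x)      # cases 10 / 11 / 12
    return y * scale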
torch_utils/op_edit/upfirdn2d.cpp ADDED
@@ -0,0 +1,25 @@
1
+ // Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ #include <torch/extension.h>
4
+
5
+
6
+ torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
7
+ int up_x, int up_y, int down_x, int down_y,
8
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
9
+
10
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
12
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
13
+
14
+ torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
15
+ int up_x, int up_y, int down_x, int down_y,
16
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
17
+ CHECK_CUDA(input);
18
+ CHECK_CUDA(kernel);
19
+
20
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
21
+ }
22
+
23
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
24
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
25
+ }
torch_utils/op_edit/upfirdn2d.py ADDED
@@ -0,0 +1,202 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ import os
4
+
5
+ import torch
6
+ from torch.nn import functional as F
7
+ from torch.autograd import Function
8
+ from torch.utils.cpp_extension import load
9
+
10
+
11
+ module_path = os.path.dirname(__file__)
12
+ upfirdn2d_op = load(
13
+ "upfirdn2d",
14
+ sources=[
15
+ os.path.join(module_path, "upfirdn2d.cpp"),
16
+ os.path.join(module_path, "upfirdn2d_kernel.cu"),
17
+ ],
18
+ )
19
+
20
+
21
+ class UpFirDn2dBackward(Function):
22
+ @staticmethod
23
+ def forward(
24
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
25
+ ):
26
+
27
+ up_x, up_y = up
28
+ down_x, down_y = down
29
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
30
+
31
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
32
+
33
+ grad_input = upfirdn2d_op.upfirdn2d(
34
+ grad_output,
35
+ grad_kernel,
36
+ down_x,
37
+ down_y,
38
+ up_x,
39
+ up_y,
40
+ g_pad_x0,
41
+ g_pad_x1,
42
+ g_pad_y0,
43
+ g_pad_y1,
44
+ )
45
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
46
+
47
+ ctx.save_for_backward(kernel)
48
+
49
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
50
+
51
+ ctx.up_x = up_x
52
+ ctx.up_y = up_y
53
+ ctx.down_x = down_x
54
+ ctx.down_y = down_y
55
+ ctx.pad_x0 = pad_x0
56
+ ctx.pad_x1 = pad_x1
57
+ ctx.pad_y0 = pad_y0
58
+ ctx.pad_y1 = pad_y1
59
+ ctx.in_size = in_size
60
+ ctx.out_size = out_size
61
+
62
+ return grad_input
63
+
64
+ @staticmethod
65
+ def backward(ctx, gradgrad_input):
66
+ (kernel,) = ctx.saved_tensors
67
+
68
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
69
+
70
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
71
+ gradgrad_input,
72
+ kernel,
73
+ ctx.up_x,
74
+ ctx.up_y,
75
+ ctx.down_x,
76
+ ctx.down_y,
77
+ ctx.pad_x0,
78
+ ctx.pad_x1,
79
+ ctx.pad_y0,
80
+ ctx.pad_y1,
81
+ )
82
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
83
+ gradgrad_out = gradgrad_out.view(
84
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
85
+ )
86
+
87
+ return gradgrad_out, None, None, None, None, None, None, None, None
88
+
89
+
90
+ class UpFirDn2d(Function):
91
+ @staticmethod
92
+ def forward(ctx, input, kernel, up, down, pad):
93
+ up_x, up_y = up
94
+ down_x, down_y = down
95
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
96
+
97
+ kernel_h, kernel_w = kernel.shape
98
+ batch, channel, in_h, in_w = input.shape
99
+ ctx.in_size = input.shape
100
+
101
+ input = input.reshape(-1, in_h, in_w, 1)
102
+
103
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
104
+
105
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
106
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
107
+ ctx.out_size = (out_h, out_w)
108
+
109
+ ctx.up = (up_x, up_y)
110
+ ctx.down = (down_x, down_y)
111
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
112
+
113
+ g_pad_x0 = kernel_w - pad_x0 - 1
114
+ g_pad_y0 = kernel_h - pad_y0 - 1
115
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
116
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
117
+
118
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
119
+
120
+ out = upfirdn2d_op.upfirdn2d(
121
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
122
+ )
123
+ # out = out.view(major, out_h, out_w, minor)
124
+ out = out.view(-1, channel, out_h, out_w)
125
+
126
+ return out
127
+
128
+ @staticmethod
129
+ def backward(ctx, grad_output):
130
+ kernel, grad_kernel = ctx.saved_tensors
131
+
132
+ grad_input = UpFirDn2dBackward.apply(
133
+ grad_output,
134
+ kernel,
135
+ grad_kernel,
136
+ ctx.up,
137
+ ctx.down,
138
+ ctx.pad,
139
+ ctx.g_pad,
140
+ ctx.in_size,
141
+ ctx.out_size,
142
+ )
143
+
144
+ return grad_input, None, None, None, None
145
+
146
+
147
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
148
+ if input.device.type == "cpu":
149
+ out = upfirdn2d_native(
150
+ input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
151
+ )
152
+
153
+ else:
154
+ out = UpFirDn2d.apply(
155
+ input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
156
+ )
157
+
158
+ return out
159
+
160
+
161
+ def upfirdn2d_native(
162
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
163
+ ):
164
+ _, channel, in_h, in_w = input.shape
165
+ input = input.reshape(-1, in_h, in_w, 1)
166
+
167
+ _, in_h, in_w, minor = input.shape
168
+ kernel_h, kernel_w = kernel.shape
169
+
170
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
171
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
172
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
173
+
174
+ out = F.pad(
175
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
176
+ )
177
+ out = out[
178
+ :,
179
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
180
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
181
+ :,
182
+ ]
183
+
184
+ out = out.permute(0, 3, 1, 2)
185
+ out = out.reshape(
186
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
187
+ )
188
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
189
+ out = F.conv2d(out, w)
190
+ out = out.reshape(
191
+ -1,
192
+ minor,
193
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
194
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
195
+ )
196
+ out = out.permute(0, 2, 3, 1)
197
+ out = out[:, ::down_y, ::down_x, :]
198
+
199
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
200
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
201
+
202
+ return out.view(-1, channel, out_h, out_w)
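For reference, the operation implemented by upfirdn2d() / upfirdn2d_native() above can be sketched with stock PyTorch ops alone: zero-insert upsampling, padding, FIR filtering with the flipped kernel, then downsampling by striding. This is a hedged, self-contained sketch; the helper name upfirdn2d_reference and the box filter are illustrative assumptions, not part of the repository.

import torch
import torch.nn.functional as F

def upfirdn2d_reference(x, k, up=1, down=1, pad=(0, 0)):
    # x: [N, C, H, W], k: [kh, kw] float kernel, pad: (pad0, pad1) per axis.
    n, c, in_h, in_w = x.shape
    kh, kw = k.shape
    # Upsample by zero insertion.
    out = torch.zeros(n, c, in_h * up, in_w * up, dtype=x.dtype)
    out[:, :, ::up, ::up] = x
    # Pad, then apply the FIR filter as a depthwise convolution with the
    # flipped kernel (true convolution rather than correlation).
    out = F.pad(out, [pad[0], pad[1], pad[0], pad[1]])
    w = torch.flip(k, [0, 1]).view(1, 1, kh, kw).repeat(c, 1, 1, 1)
    out = F.conv2d(out, w, groups=c)
    # Downsample by striding.
    return out[:, :, ::down, ::down]

x = torch.randn(1, 3, 8, 8)
k = torch.ones(4, 4) / 16.0    # illustrative box filter
y = upfirdn2d_reference(x, k, up=2, down=1, pad=(1, 1))
print(y.shape)                 # torch.Size([1, 3, 15, 15]) = (8*2 + 1 + 1 - 4) // 1 + 1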
torch_utils/op_edit/upfirdn2d_kernel.cu ADDED
@@ -0,0 +1,371 @@
1
+ // Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
4
+ //
5
+ // This work is made available under the Nvidia Source Code License-NC.
6
+ // To view a copy of this license, visit
7
+ // https://nvlabs.github.io/stylegan2/license.html
8
+
9
+ #include <torch/types.h>
10
+
11
+ #include <ATen/ATen.h>
12
+ #include <ATen/AccumulateType.h>
13
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
14
+ #include <ATen/cuda/CUDAContext.h>
15
+
16
+ #include <cuda.h>
17
+ #include <cuda_runtime.h>
18
+
19
+ static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
20
+ int c = a / b;
21
+
22
+ if (c * b > a) {
23
+ c--;
24
+ }
25
+
26
+ return c;
27
+ }
28
+
29
+ struct UpFirDn2DKernelParams {
30
+ int up_x;
31
+ int up_y;
32
+ int down_x;
33
+ int down_y;
34
+ int pad_x0;
35
+ int pad_x1;
36
+ int pad_y0;
37
+ int pad_y1;
38
+
39
+ int major_dim;
40
+ int in_h;
41
+ int in_w;
42
+ int minor_dim;
43
+ int kernel_h;
44
+ int kernel_w;
45
+ int out_h;
46
+ int out_w;
47
+ int loop_major;
48
+ int loop_x;
49
+ };
50
+
51
+ template <typename scalar_t>
52
+ __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
53
+ const scalar_t *kernel,
54
+ const UpFirDn2DKernelParams p) {
55
+ int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
56
+ int out_y = minor_idx / p.minor_dim;
57
+ minor_idx -= out_y * p.minor_dim;
58
+ int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
59
+ int major_idx_base = blockIdx.z * p.loop_major;
60
+
61
+ if (out_x_base >= p.out_w || out_y >= p.out_h ||
62
+ major_idx_base >= p.major_dim) {
63
+ return;
64
+ }
65
+
66
+ int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
67
+ int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
68
+ int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
69
+ int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
70
+
71
+ for (int loop_major = 0, major_idx = major_idx_base;
72
+ loop_major < p.loop_major && major_idx < p.major_dim;
73
+ loop_major++, major_idx++) {
74
+ for (int loop_x = 0, out_x = out_x_base;
75
+ loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
76
+ int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
77
+ int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
78
+ int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
79
+ int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
80
+
81
+ const scalar_t *x_p =
82
+ &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
83
+ minor_idx];
84
+ const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
85
+ int x_px = p.minor_dim;
86
+ int k_px = -p.up_x;
87
+ int x_py = p.in_w * p.minor_dim;
88
+ int k_py = -p.up_y * p.kernel_w;
89
+
90
+ scalar_t v = 0.0f;
91
+
92
+ for (int y = 0; y < h; y++) {
93
+ for (int x = 0; x < w; x++) {
94
+ v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
95
+ x_p += x_px;
96
+ k_p += k_px;
97
+ }
98
+
99
+ x_p += x_py - w * x_px;
100
+ k_p += k_py - w * k_px;
101
+ }
102
+
103
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
104
+ minor_idx] = v;
105
+ }
106
+ }
107
+ }
108
+
109
+ template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
110
+ int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
111
+ __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
112
+ const scalar_t *kernel,
113
+ const UpFirDn2DKernelParams p) {
114
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
115
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
116
+
117
+ __shared__ volatile float sk[kernel_h][kernel_w];
118
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
119
+
120
+ int minor_idx = blockIdx.x;
121
+ int tile_out_y = minor_idx / p.minor_dim;
122
+ minor_idx -= tile_out_y * p.minor_dim;
123
+ tile_out_y *= tile_out_h;
124
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
125
+ int major_idx_base = blockIdx.z * p.loop_major;
126
+
127
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
128
+ major_idx_base >= p.major_dim) {
129
+ return;
130
+ }
131
+
132
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
133
+ tap_idx += blockDim.x) {
134
+ int ky = tap_idx / kernel_w;
135
+ int kx = tap_idx - ky * kernel_w;
136
+ scalar_t v = 0.0;
137
+
138
+ if (kx < p.kernel_w & ky < p.kernel_h) {
139
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
140
+ }
141
+
142
+ sk[ky][kx] = v;
143
+ }
144
+
145
+ for (int loop_major = 0, major_idx = major_idx_base;
146
+ loop_major < p.loop_major & major_idx < p.major_dim;
147
+ loop_major++, major_idx++) {
148
+ for (int loop_x = 0, tile_out_x = tile_out_x_base;
149
+ loop_x < p.loop_x & tile_out_x < p.out_w;
150
+ loop_x++, tile_out_x += tile_out_w) {
151
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
152
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
153
+ int tile_in_x = floor_div(tile_mid_x, up_x);
154
+ int tile_in_y = floor_div(tile_mid_y, up_y);
155
+
156
+ __syncthreads();
157
+
158
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
159
+ in_idx += blockDim.x) {
160
+ int rel_in_y = in_idx / tile_in_w;
161
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
162
+ int in_x = rel_in_x + tile_in_x;
163
+ int in_y = rel_in_y + tile_in_y;
164
+
165
+ scalar_t v = 0.0;
166
+
167
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
168
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
169
+ p.minor_dim +
170
+ minor_idx];
171
+ }
172
+
173
+ sx[rel_in_y][rel_in_x] = v;
174
+ }
175
+
176
+ __syncthreads();
177
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
178
+ out_idx += blockDim.x) {
179
+ int rel_out_y = out_idx / tile_out_w;
180
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
181
+ int out_x = rel_out_x + tile_out_x;
182
+ int out_y = rel_out_y + tile_out_y;
183
+
184
+ int mid_x = tile_mid_x + rel_out_x * down_x;
185
+ int mid_y = tile_mid_y + rel_out_y * down_y;
186
+ int in_x = floor_div(mid_x, up_x);
187
+ int in_y = floor_div(mid_y, up_y);
188
+ int rel_in_x = in_x - tile_in_x;
189
+ int rel_in_y = in_y - tile_in_y;
190
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
191
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
192
+
193
+ scalar_t v = 0.0;
194
+
195
+ #pragma unroll
196
+ for (int y = 0; y < kernel_h / up_y; y++)
197
+ #pragma unroll
198
+ for (int x = 0; x < kernel_w / up_x; x++)
199
+ v += sx[rel_in_y + y][rel_in_x + x] *
200
+ sk[kernel_y + y * up_y][kernel_x + x * up_x];
201
+
202
+ if (out_x < p.out_w & out_y < p.out_h) {
203
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
204
+ minor_idx] = v;
205
+ }
206
+ }
207
+ }
208
+ }
209
+ }
210
+
211
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
212
+ const torch::Tensor &kernel, int up_x, int up_y,
213
+ int down_x, int down_y, int pad_x0, int pad_x1,
214
+ int pad_y0, int pad_y1) {
215
+ int curDevice = -1;
216
+ cudaGetDevice(&curDevice);
217
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
218
+
219
+ UpFirDn2DKernelParams p;
220
+
221
+ auto x = input.contiguous();
222
+ auto k = kernel.contiguous();
223
+
224
+ p.major_dim = x.size(0);
225
+ p.in_h = x.size(1);
226
+ p.in_w = x.size(2);
227
+ p.minor_dim = x.size(3);
228
+ p.kernel_h = k.size(0);
229
+ p.kernel_w = k.size(1);
230
+ p.up_x = up_x;
231
+ p.up_y = up_y;
232
+ p.down_x = down_x;
233
+ p.down_y = down_y;
234
+ p.pad_x0 = pad_x0;
235
+ p.pad_x1 = pad_x1;
236
+ p.pad_y0 = pad_y0;
237
+ p.pad_y1 = pad_y1;
238
+
239
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
240
+ p.down_y;
241
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
242
+ p.down_x;
243
+
244
+ auto out =
245
+ at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
246
+
247
+ int mode = -1;
248
+
249
+ int tile_out_h = -1;
250
+ int tile_out_w = -1;
251
+
252
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
253
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
254
+ mode = 1;
255
+ tile_out_h = 16;
256
+ tile_out_w = 64;
257
+ }
258
+
259
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
260
+ p.kernel_h <= 3 && p.kernel_w <= 3) {
261
+ mode = 2;
262
+ tile_out_h = 16;
263
+ tile_out_w = 64;
264
+ }
265
+
266
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
267
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
268
+ mode = 3;
269
+ tile_out_h = 16;
270
+ tile_out_w = 64;
271
+ }
272
+
273
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
274
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
275
+ mode = 4;
276
+ tile_out_h = 16;
277
+ tile_out_w = 64;
278
+ }
279
+
280
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
281
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
282
+ mode = 5;
283
+ tile_out_h = 8;
284
+ tile_out_w = 32;
285
+ }
286
+
287
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
288
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
289
+ mode = 6;
290
+ tile_out_h = 8;
291
+ tile_out_w = 32;
292
+ }
293
+
294
+ dim3 block_size;
295
+ dim3 grid_size;
296
+
297
+ if (tile_out_h > 0 && tile_out_w > 0) {
298
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
299
+ p.loop_x = 1;
300
+ block_size = dim3(32 * 8, 1, 1);
301
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
302
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
303
+ (p.major_dim - 1) / p.loop_major + 1);
304
+ } else {
305
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
306
+ p.loop_x = 4;
307
+ block_size = dim3(4, 32, 1);
308
+ grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
309
+ (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
310
+ (p.major_dim - 1) / p.loop_major + 1);
311
+ }
312
+
313
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
314
+ switch (mode) {
315
+ case 1:
316
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
317
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
318
+ x.data_ptr<scalar_t>(),
319
+ k.data_ptr<scalar_t>(), p);
320
+
321
+ break;
322
+
323
+ case 2:
324
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
325
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
326
+ x.data_ptr<scalar_t>(),
327
+ k.data_ptr<scalar_t>(), p);
328
+
329
+ break;
330
+
331
+ case 3:
332
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
333
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
334
+ x.data_ptr<scalar_t>(),
335
+ k.data_ptr<scalar_t>(), p);
336
+
337
+ break;
338
+
339
+ case 4:
340
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
341
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
342
+ x.data_ptr<scalar_t>(),
343
+ k.data_ptr<scalar_t>(), p);
344
+
345
+ break;
346
+
347
+ case 5:
348
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
349
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
350
+ x.data_ptr<scalar_t>(),
351
+ k.data_ptr<scalar_t>(), p);
352
+
353
+ break;
354
+
355
+ case 6:
356
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
357
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
358
+ x.data_ptr<scalar_t>(),
359
+ k.data_ptr<scalar_t>(), p);
360
+
361
+ break;
362
+
363
+ default:
364
+ upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
365
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
366
+ k.data_ptr<scalar_t>(), p);
367
+ }
368
+ });
369
+
370
+ return out;
371
+ }
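The host function upfirdn2d_op() above derives the output resolution as (in * up + pad0 + pad1 - kernel + down) / down per axis and only dispatches to a specialized tiled kernel (modes 1-6) for the small up/down factors and filter sizes listed; everything else falls through to upfirdn2d_kernel_large. A hedged Python sketch of the size arithmetic, purely for orientation:

def upfirdn2d_out_size(in_h, in_w, kh, kw, up=(1, 1), down=(1, 1), pad=(0, 0, 0, 0)):
    # Mirrors the integer arithmetic in upfirdn2d_op(); for non-negative
    # operands, (a - k + down) // down equals (a - k) // down + 1.
    up_x, up_y = up
    down_x, down_y = down
    pad_x0, pad_x1, pad_y0, pad_y1 = pad
    out_h = (in_h * up_y + pad_y0 + pad_y1 - kh + down_y) // down_y
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kw + down_x) // down_x
    return out_h, out_w

# 2x upsampling with a 4-tap filter and padding (1, 1, 1, 1): 8x8 -> 15x15,
# which is served by the specialized mode-3 kernel (up=2x2, down=1x1, k<=4x4).
print(upfirdn2d_out_size(8, 8, 4, 4, up=(2, 2), pad=(1, 1, 1, 1)))  # (15, 15)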
torch_utils/ops/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ #empty
torch_utils/ops/bias_act.cpp ADDED
@@ -0,0 +1,101 @@
1
+ // Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ //
5
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ // and proprietary rights in and to this software, related documentation
7
+ // and any modifications thereto. Any use, reproduction, disclosure or
8
+ // distribution of this software and related documentation without an express
9
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ #include <torch/extension.h>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+ #include <c10/cuda/CUDAGuard.h>
14
+ #include "bias_act.h"
15
+
16
+ //------------------------------------------------------------------------
17
+
18
+ static bool has_same_layout(torch::Tensor x, torch::Tensor y)
19
+ {
20
+ if (x.dim() != y.dim())
21
+ return false;
22
+ for (int64_t i = 0; i < x.dim(); i++)
23
+ {
24
+ if (x.size(i) != y.size(i))
25
+ return false;
26
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
27
+ return false;
28
+ }
29
+ return true;
30
+ }
31
+
32
+ //------------------------------------------------------------------------
33
+
34
+ static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
35
+ {
36
+ // Validate arguments.
37
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
38
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
39
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
40
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
41
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same shape, dtype, and device as x");
42
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
43
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
44
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
45
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
46
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
47
+
48
+ // Validate layout.
49
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
50
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
51
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
52
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
53
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
54
+
55
+ // Create output tensor.
56
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
57
+ torch::Tensor y = torch::empty_like(x);
58
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
59
+
60
+ // Initialize CUDA kernel parameters.
61
+ bias_act_kernel_params p;
62
+ p.x = x.data_ptr();
63
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
64
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
65
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
66
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
67
+ p.y = y.data_ptr();
68
+ p.grad = grad;
69
+ p.act = act;
70
+ p.alpha = alpha;
71
+ p.gain = gain;
72
+ p.clamp = clamp;
73
+ p.sizeX = (int)x.numel();
74
+ p.sizeB = (int)b.numel();
75
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
76
+
77
+ // Choose CUDA kernel.
78
+ void* kernel;
79
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
80
+ {
81
+ kernel = choose_bias_act_kernel<scalar_t>(p);
82
+ });
83
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
84
+
85
+ // Launch CUDA kernel.
86
+ p.loopX = 4;
87
+ int blockSize = 4 * 32;
88
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
89
+ void* args[] = {&p};
90
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
91
+ return y;
92
+ }
93
+
94
+ //------------------------------------------------------------------------
95
+
96
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
97
+ {
98
+ m.def("bias_act", &bias_act);
99
+ }
100
+
101
+ //------------------------------------------------------------------------
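In the op above, p.stepB = x.stride(dim) is what lets the flat-indexed CUDA kernel broadcast the 1-D bias over dimension dim: element xi of the flattened tensor receives b[(xi / stepB) % sizeB]. A hedged sketch, assuming a contiguous NCHW tensor, that reproduces this indexing with plain PyTorch (not part of the repository):

import torch

x = torch.zeros(2, 3, 4, 5)                     # contiguous, bias applied over dim=1
b = torch.arange(3, dtype=torch.float32)
dim = 1
step_b = x.stride(dim)                          # 20 elements per channel slice

flat = x.reshape(-1).clone()
for xi in range(flat.numel()):
    flat[xi] += b[(xi // step_b) % b.numel()]   # the kernel's bias lookup

ref = x + b.reshape(1, -1, 1, 1)                # ordinary broadcasting
print(torch.equal(flat.view_as(x), ref))        # True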
torch_utils/ops/bias_act.cu ADDED
@@ -0,0 +1,175 @@
1
+ // Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ //
5
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ // and proprietary rights in and to this software, related documentation
7
+ // and any modifications thereto. Any use, reproduction, disclosure or
8
+ // distribution of this software and related documentation without an express
9
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ #include <c10/util/Half.h>
12
+ #include "bias_act.h"
13
+
14
+ //------------------------------------------------------------------------
15
+ // Helpers.
16
+
17
+ template <class T> struct InternalType;
18
+ template <> struct InternalType<double> { typedef double scalar_t; };
19
+ template <> struct InternalType<float> { typedef float scalar_t; };
20
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
21
+
22
+ //------------------------------------------------------------------------
23
+ // CUDA kernel.
24
+
25
+ template <class T, int A>
26
+ __global__ void bias_act_kernel(bias_act_kernel_params p)
27
+ {
28
+ typedef typename InternalType<T>::scalar_t scalar_t;
29
+ int G = p.grad;
30
+ scalar_t alpha = (scalar_t)p.alpha;
31
+ scalar_t gain = (scalar_t)p.gain;
32
+ scalar_t clamp = (scalar_t)p.clamp;
33
+ scalar_t one = (scalar_t)1;
34
+ scalar_t two = (scalar_t)2;
35
+ scalar_t expRange = (scalar_t)80;
36
+ scalar_t halfExpRange = (scalar_t)40;
37
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
38
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
39
+
40
+ // Loop over elements.
41
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
42
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
43
+ {
44
+ // Load.
45
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
46
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
47
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
48
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
49
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
50
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
51
+ scalar_t y = 0;
52
+
53
+ // Apply bias.
54
+ ((G == 0) ? x : xref) += b;
55
+
56
+ // linear
57
+ if (A == 1)
58
+ {
59
+ if (G == 0) y = x;
60
+ if (G == 1) y = x;
61
+ }
62
+
63
+ // relu
64
+ if (A == 2)
65
+ {
66
+ if (G == 0) y = (x > 0) ? x : 0;
67
+ if (G == 1) y = (yy > 0) ? x : 0;
68
+ }
69
+
70
+ // lrelu
71
+ if (A == 3)
72
+ {
73
+ if (G == 0) y = (x > 0) ? x : x * alpha;
74
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
75
+ }
76
+
77
+ // tanh
78
+ if (A == 4)
79
+ {
80
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
81
+ if (G == 1) y = x * (one - yy * yy);
82
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
83
+ }
84
+
85
+ // sigmoid
86
+ if (A == 5)
87
+ {
88
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
89
+ if (G == 1) y = x * yy * (one - yy);
90
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
91
+ }
92
+
93
+ // elu
94
+ if (A == 6)
95
+ {
96
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
97
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
98
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
99
+ }
100
+
101
+ // selu
102
+ if (A == 7)
103
+ {
104
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
105
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
106
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
107
+ }
108
+
109
+ // softplus
110
+ if (A == 8)
111
+ {
112
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
113
+ if (G == 1) y = x * (one - exp(-yy));
114
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
115
+ }
116
+
117
+ // swish
118
+ if (A == 9)
119
+ {
120
+ if (G == 0)
121
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
122
+ else
123
+ {
124
+ scalar_t c = exp(xref);
125
+ scalar_t d = c + one;
126
+ if (G == 1)
127
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
128
+ else
129
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
130
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
131
+ }
132
+ }
133
+
134
+ // Apply gain.
135
+ y *= gain * dy;
136
+
137
+ // Clamp.
138
+ if (clamp >= 0)
139
+ {
140
+ if (G == 0)
141
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
142
+ else
143
+ y = (yref > -clamp & yref < clamp) ? y : 0;
144
+ }
145
+
146
+ // Store.
147
+ ((T*)p.y)[xi] = (T)y;
148
+ }
149
+ }
150
+
151
+ //------------------------------------------------------------------------
152
+ // CUDA kernel selection.
153
+
154
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
155
+ {
156
+ if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
157
+ if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
158
+ if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
159
+ if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
160
+ if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
161
+ if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
162
+ if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
163
+ if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
164
+ if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
165
+ return NULL;
166
+ }
167
+
168
+ //------------------------------------------------------------------------
169
+ // Template specializations.
170
+
171
+ template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
172
+ template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
173
+ template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
174
+
175
+ //------------------------------------------------------------------------
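Each activation index A in the kernel above encodes the forward value (G == 0) and its first/second derivatives (G == 1, 2) in terms of the saved reference output yref. As a sanity check for the lrelu case (A == 3, G == 1), the rule dx = dy * gain * (1 if the pre-gain output is positive else alpha) can be compared against autograd. The sketch below is illustrative only, not part of the repository:

import torch

alpha, gain = 0.2, 2.0 ** 0.5
x = torch.randn(1000, dtype=torch.float64, requires_grad=True)
y = torch.nn.functional.leaky_relu(x, alpha) * gain
dy = torch.randn_like(y)
y.backward(dy)

yy = y.detach() / gain                          # "yref / gain" in the kernel
slope = torch.where(yy > 0, torch.ones_like(yy), torch.full_like(yy, alpha))
dx_kernel_rule = dy * gain * slope
print(torch.allclose(x.grad, dx_kernel_rule))   # True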
torch_utils/ops/bias_act.h ADDED
@@ -0,0 +1,40 @@
1
+ // Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ //
5
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ // and proprietary rights in and to this software, related documentation
7
+ // and any modifications thereto. Any use, reproduction, disclosure or
8
+ // distribution of this software and related documentation without an express
9
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ //------------------------------------------------------------------------
12
+ // CUDA kernel parameters.
13
+
14
+ struct bias_act_kernel_params
15
+ {
16
+ const void* x; // [sizeX]
17
+ const void* b; // [sizeB] or NULL
18
+ const void* xref; // [sizeX] or NULL
19
+ const void* yref; // [sizeX] or NULL
20
+ const void* dy; // [sizeX] or NULL
21
+ void* y; // [sizeX]
22
+
23
+ int grad;
24
+ int act;
25
+ float alpha;
26
+ float gain;
27
+ float clamp;
28
+
29
+ int sizeX;
30
+ int sizeB;
31
+ int stepB;
32
+ int loopX;
33
+ };
34
+
35
+ //------------------------------------------------------------------------
36
+ // CUDA kernel selection.
37
+
38
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
39
+
40
+ //------------------------------------------------------------------------
torch_utils/ops/bias_act.py ADDED
@@ -0,0 +1,214 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ """Custom PyTorch ops for efficient bias and activation."""
12
+
13
+ import os
14
+ import warnings
15
+ import numpy as np
16
+ import torch
17
+ import dnnlib
18
+ import traceback
19
+
20
+ from .. import custom_ops
21
+ from .. import misc
22
+
23
+ #----------------------------------------------------------------------------
24
+
25
+ activation_funcs = {
26
+ 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
27
+ 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
28
+ 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
29
+ 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
30
+ 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
31
+ 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
32
+ 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
33
+ 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
34
+ 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
35
+ }
36
+
37
+ #----------------------------------------------------------------------------
38
+
39
+ _inited = False
40
+ _plugin = None
41
+ _null_tensor = torch.empty([0])
42
+
43
+ def _init():
44
+ global _inited, _plugin
45
+ if not _inited:
46
+ _inited = True
47
+ sources = ['bias_act.cpp', 'bias_act.cu']
48
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
49
+ try:
50
+ _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
51
+ except:
52
+ warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
53
+ return _plugin is not None
54
+
55
+ #----------------------------------------------------------------------------
56
+
57
+ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
58
+ r"""Fused bias and activation function.
59
+
60
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
61
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
62
+ the fused op is considerably more efficient than performing the same calculation
63
+ using standard PyTorch ops. It supports first and second order gradients,
64
+ but not third order gradients.
65
+
66
+ Args:
67
+ x: Input activation tensor. Can be of any shape.
68
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
69
+ as `x`. The shape must be known, and it must match the dimension of `x`
70
+ corresponding to `dim`.
71
+ dim: The dimension in `x` corresponding to the elements of `b`.
72
+ The value of `dim` is ignored if `b` is not specified.
73
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
74
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
75
+ See `activation_funcs` for a full list. `None` is not allowed.
76
+ alpha: Shape parameter for the activation function, or `None` to use the default.
77
+ gain: Scaling factor for the output tensor, or `None` to use default.
78
+ See `activation_funcs` for the default scaling of each activation function.
79
+ If unsure, consider specifying 1.
80
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
81
+ the clamping (default).
82
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
83
+
84
+ Returns:
85
+ Tensor of the same shape and datatype as `x`.
86
+ """
87
+ assert isinstance(x, torch.Tensor)
88
+ assert impl in ['ref', 'cuda']
89
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
90
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
91
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
92
+
93
+ #----------------------------------------------------------------------------
94
+
95
+ @misc.profiled_function
96
+ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
97
+ """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
98
+ """
99
+ assert isinstance(x, torch.Tensor)
100
+ assert clamp is None or clamp >= 0
101
+ spec = activation_funcs[act]
102
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
103
+ gain = float(gain if gain is not None else spec.def_gain)
104
+ clamp = float(clamp if clamp is not None else -1)
105
+
106
+ # Add bias.
107
+ if b is not None:
108
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
109
+ assert 0 <= dim < x.ndim
110
+ assert b.shape[0] == x.shape[dim]
111
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
112
+
113
+ # Evaluate activation function.
114
+ alpha = float(alpha)
115
+ x = spec.func(x, alpha=alpha)
116
+
117
+ # Scale by gain.
118
+ gain = float(gain)
119
+ if gain != 1:
120
+ x = x * gain
121
+
122
+ # Clamp.
123
+ if clamp >= 0:
124
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
125
+ return x
126
+
127
+ #----------------------------------------------------------------------------
128
+
129
+ _bias_act_cuda_cache = dict()
130
+
131
+ def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
132
+ """Fast CUDA implementation of `bias_act()` using custom ops.
133
+ """
134
+ # Parse arguments.
135
+ assert clamp is None or clamp >= 0
136
+ spec = activation_funcs[act]
137
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
138
+ gain = float(gain if gain is not None else spec.def_gain)
139
+ clamp = float(clamp if clamp is not None else -1)
140
+
141
+ # Lookup from cache.
142
+ key = (dim, act, alpha, gain, clamp)
143
+ if key in _bias_act_cuda_cache:
144
+ return _bias_act_cuda_cache[key]
145
+
146
+ # Forward op.
147
+ class BiasActCuda(torch.autograd.Function):
148
+ @staticmethod
149
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
150
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
151
+ x = x.contiguous(memory_format=ctx.memory_format)
152
+ b = b.contiguous() if b is not None else _null_tensor
153
+ y = x
154
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
155
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
156
+ ctx.save_for_backward(
157
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
158
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
159
+ y if 'y' in spec.ref else _null_tensor)
160
+ return y
161
+
162
+ @staticmethod
163
+ def backward(ctx, dy): # pylint: disable=arguments-differ
164
+ dy = dy.contiguous(memory_format=ctx.memory_format)
165
+ x, b, y = ctx.saved_tensors
166
+ dx = None
167
+ db = None
168
+
169
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
170
+ dx = dy
171
+ if act != 'linear' or gain != 1 or clamp >= 0:
172
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
173
+
174
+ if ctx.needs_input_grad[1]:
175
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
176
+
177
+ return dx, db
178
+
179
+ # Backward op.
180
+ class BiasActCudaGrad(torch.autograd.Function):
181
+ @staticmethod
182
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
183
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
184
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
185
+ ctx.save_for_backward(
186
+ dy if spec.has_2nd_grad else _null_tensor,
187
+ x, b, y)
188
+ return dx
189
+
190
+ @staticmethod
191
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
192
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
193
+ dy, x, b, y = ctx.saved_tensors
194
+ d_dy = None
195
+ d_x = None
196
+ d_b = None
197
+ d_y = None
198
+
199
+ if ctx.needs_input_grad[0]:
200
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
201
+
202
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
203
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
204
+
205
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
206
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
207
+
208
+ return d_dy, d_x, d_b, d_y
209
+
210
+ # Add to cache.
211
+ _bias_act_cuda_cache[key] = BiasActCuda
212
+ return BiasActCuda
213
+
214
+ #----------------------------------------------------------------------------
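A hedged usage sketch for the Python API above: with impl='ref' the call stays on the slow reference path and needs no compiled extension, so it also runs on CPU. It assumes the repository root is on PYTHONPATH; the shapes are purely illustrative.

import torch
from torch_utils.ops.bias_act import bias_act

x = torch.randn(4, 512, 16, 16)          # NCHW activations
b = torch.randn(512)                     # one bias value per channel (dim=1)
y = bias_act(x, b, dim=1, act='lrelu', gain=2 ** 0.5, clamp=256, impl='ref')
print(y.shape)                           # torch.Size([4, 512, 16, 16])
print(float(y.abs().max()) <= 256.0)     # True, due to clamp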
torch_utils/ops/conv2d_gradfix.py ADDED
@@ -0,0 +1,172 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ """Custom replacement for `torch.nn.functional.conv2d` that supports
12
+ arbitrarily high order gradients with zero performance penalty."""
13
+
14
+ import warnings
15
+ import contextlib
16
+ import torch
17
+
18
+ # pylint: disable=redefined-builtin
19
+ # pylint: disable=arguments-differ
20
+ # pylint: disable=protected-access
21
+
22
+ #----------------------------------------------------------------------------
23
+
24
+ enabled = False # Enable the custom op by setting this to true.
25
+ weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
26
+
27
+ @contextlib.contextmanager
28
+ def no_weight_gradients():
29
+ global weight_gradients_disabled
30
+ old = weight_gradients_disabled
31
+ weight_gradients_disabled = True
32
+ yield
33
+ weight_gradients_disabled = old
34
+
35
+ #----------------------------------------------------------------------------
36
+
37
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
38
+ if _should_use_custom_op(input):
39
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
40
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
41
+
42
+ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
43
+ if _should_use_custom_op(input):
44
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
45
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
46
+
47
+ #----------------------------------------------------------------------------
48
+
49
+ def _should_use_custom_op(input):
50
+ assert isinstance(input, torch.Tensor)
51
+ if (not enabled) or (not torch.backends.cudnn.enabled):
52
+ return False
53
+ if input.device.type != 'cuda':
54
+ return False
55
+ if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
56
+ return True
57
+ warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
58
+ return False
59
+
60
+ def _tuple_of_ints(xs, ndim):
61
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
62
+ assert len(xs) == ndim
63
+ assert all(isinstance(x, int) for x in xs)
64
+ return xs
65
+
66
+ #----------------------------------------------------------------------------
67
+
68
+ _conv2d_gradfix_cache = dict()
69
+
70
+ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
71
+ # Parse arguments.
72
+ ndim = 2
73
+ weight_shape = tuple(weight_shape)
74
+ stride = _tuple_of_ints(stride, ndim)
75
+ padding = _tuple_of_ints(padding, ndim)
76
+ output_padding = _tuple_of_ints(output_padding, ndim)
77
+ dilation = _tuple_of_ints(dilation, ndim)
78
+
79
+ # Lookup from cache.
80
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
81
+ if key in _conv2d_gradfix_cache:
82
+ return _conv2d_gradfix_cache[key]
83
+
84
+ # Validate arguments.
85
+ assert groups >= 1
86
+ assert len(weight_shape) == ndim + 2
87
+ assert all(stride[i] >= 1 for i in range(ndim))
88
+ assert all(padding[i] >= 0 for i in range(ndim))
89
+ assert all(dilation[i] >= 0 for i in range(ndim))
90
+ if not transpose:
91
+ assert all(output_padding[i] == 0 for i in range(ndim))
92
+ else: # transpose
93
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
94
+
95
+ # Helpers.
96
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
97
+ def calc_output_padding(input_shape, output_shape):
98
+ if transpose:
99
+ return [0, 0]
100
+ return [
101
+ input_shape[i + 2]
102
+ - (output_shape[i + 2] - 1) * stride[i]
103
+ - (1 - 2 * padding[i])
104
+ - dilation[i] * (weight_shape[i + 2] - 1)
105
+ for i in range(ndim)
106
+ ]
107
+
108
+ # Forward & backward.
109
+ class Conv2d(torch.autograd.Function):
110
+ @staticmethod
111
+ def forward(ctx, input, weight, bias):
112
+ assert weight.shape == weight_shape
113
+ if not transpose:
114
+ output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
115
+ else: # transpose
116
+ output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
117
+ ctx.save_for_backward(input, weight)
118
+ return output
119
+
120
+ @staticmethod
121
+ def backward(ctx, grad_output):
122
+ input, weight = ctx.saved_tensors
123
+ grad_input = None
124
+ grad_weight = None
125
+ grad_bias = None
126
+
127
+ if ctx.needs_input_grad[0]:
128
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
129
+ grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
130
+ assert grad_input.shape == input.shape
131
+
132
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
133
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
134
+ assert grad_weight.shape == weight_shape
135
+
136
+ if ctx.needs_input_grad[2]:
137
+ grad_bias = grad_output.sum([0, 2, 3])
138
+
139
+ return grad_input, grad_weight, grad_bias
140
+
141
+ # Gradient with respect to the weights.
142
+ class Conv2dGradWeight(torch.autograd.Function):
143
+ @staticmethod
144
+ def forward(ctx, grad_output, input):
145
+ op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
146
+ flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
147
+ grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
148
+ assert grad_weight.shape == weight_shape
149
+ ctx.save_for_backward(grad_output, input)
150
+ return grad_weight
151
+
152
+ @staticmethod
153
+ def backward(ctx, grad2_grad_weight):
154
+ grad_output, input = ctx.saved_tensors
155
+ grad2_grad_output = None
156
+ grad2_input = None
157
+
158
+ if ctx.needs_input_grad[0]:
159
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
160
+ assert grad2_grad_output.shape == grad_output.shape
161
+
162
+ if ctx.needs_input_grad[1]:
163
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
164
+ grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
165
+ assert grad2_input.shape == input.shape
166
+
167
+ return grad2_grad_output, grad2_input
168
+
169
+ _conv2d_gradfix_cache[key] = Conv2d
170
+ return Conv2d
171
+
172
+ #----------------------------------------------------------------------------
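A hedged usage sketch for the module above: the custom op only engages when `enabled` is set, the tensor lives on CUDA, and the running PyTorch version is one of the listed releases; otherwise the calls silently fall back to torch.nn.functional.conv2d(). On CPU the snippet below therefore exercises the fallback path; shapes are illustrative.

import torch
from torch_utils.ops import conv2d_gradfix

conv2d_gradfix.enabled = True                    # opt in to the custom op
x = torch.randn(2, 8, 16, 16)
w = torch.randn(16, 8, 3, 3, requires_grad=True)
y = conv2d_gradfix.conv2d(x, w, padding=1)
print(y.shape)                                   # torch.Size([2, 16, 16, 16])

# Skip weight gradients inside the block, e.g. during a regularization pass
# (this only affects the custom op, not the plain conv2d fallback).
with conv2d_gradfix.no_weight_gradients():
    y_reg = conv2d_gradfix.conv2d(x, w, padding=1)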
torch_utils/ops/conv2d_resample.py ADDED
@@ -0,0 +1,158 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ """2D convolution with optional up/downsampling."""
12
+
13
+ import torch
14
+
15
+ from .. import misc
16
+ from . import conv2d_gradfix
17
+ from . import upfirdn2d
18
+ from .upfirdn2d import _parse_padding
19
+ from .upfirdn2d import _get_filter_size
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
+ def _get_weight_shape(w):
24
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
25
+ shape = [int(sz) for sz in w.shape]
26
+ misc.assert_shape(w, shape)
27
+ return shape
28
+
29
+ #----------------------------------------------------------------------------
30
+
31
+ def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
32
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
33
+ """
34
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
35
+
36
+ # Flip weight if requested.
37
+ if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
38
+ w = w.flip([2, 3])
39
+
40
+ # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
41
+ # 1x1 kernel + memory_format=channels_last + less than 64 channels.
42
+ if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
43
+ if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
44
+ if out_channels <= 4 and groups == 1:
45
+ in_shape = x.shape
46
+ x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
47
+ x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
48
+ else:
49
+ x = x.to(memory_format=torch.contiguous_format)
50
+ w = w.to(memory_format=torch.contiguous_format)
51
+ x = conv2d_gradfix.conv2d(x, w, groups=groups)
52
+ return x.to(memory_format=torch.channels_last)
53
+
54
+ # Otherwise => execute using conv2d_gradfix.
55
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
56
+ return op(x, w, stride=stride, padding=padding, groups=groups)
57
+
58
+ #----------------------------------------------------------------------------
59
+
60
+ @misc.profiled_function
61
+ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
62
+ r"""2D convolution with optional up/downsampling.
63
+
64
+ Padding is performed only once at the beginning, not between the operations.
65
+
66
+ Args:
67
+ x: Input tensor of shape
68
+ `[batch_size, in_channels, in_height, in_width]`.
69
+ w: Weight tensor of shape
70
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
71
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
72
+ calling upfirdn2d.setup_filter(). None = identity (default).
73
+ up: Integer upsampling factor (default: 1).
74
+ down: Integer downsampling factor (default: 1).
75
+ padding: Padding with respect to the upsampled image. Can be a single number
76
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
77
+ (default: 0).
78
+ groups: Split input channels into N groups (default: 1).
79
+ flip_weight: False = convolution, True = correlation (default: True).
80
+ flip_filter: False = convolution, True = correlation (default: False).
81
+
82
+ Returns:
83
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
84
+ """
85
+ # Validate arguments.
86
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
87
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
88
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
89
+ assert isinstance(up, int) and (up >= 1)
90
+ assert isinstance(down, int) and (down >= 1)
91
+ assert isinstance(groups, int) and (groups >= 1)
92
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
93
+ fw, fh = _get_filter_size(f)
94
+ px0, px1, py0, py1 = _parse_padding(padding)
95
+
96
+ # Adjust padding to account for up/downsampling.
97
+ if up > 1:
98
+ px0 += (fw + up - 1) // 2
99
+ px1 += (fw - up) // 2
100
+ py0 += (fh + up - 1) // 2
101
+ py1 += (fh - up) // 2
102
+ if down > 1:
103
+ px0 += (fw - down + 1) // 2
104
+ px1 += (fw - down) // 2
105
+ py0 += (fh - down + 1) // 2
106
+ py1 += (fh - down) // 2
107
+
108
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
109
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
110
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
111
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
112
+ return x
113
+
114
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
115
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
116
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
117
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
118
+ return x
119
+
120
+ # Fast path: downsampling only => use strided convolution.
121
+ if down > 1 and up == 1:
122
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
123
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
124
+ return x
125
+
126
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
127
+ if up > 1:
128
+ if groups == 1:
129
+ w = w.transpose(0, 1)
130
+ else:
131
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
132
+ w = w.transpose(1, 2)
133
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
134
+ px0 -= kw - 1
135
+ px1 -= kw - up
136
+ py0 -= kh - 1
137
+ py1 -= kh - up
138
+ pxt = max(min(-px0, -px1), 0)
139
+ pyt = max(min(-py0, -py1), 0)
140
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
141
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
142
+ if down > 1:
143
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
144
+ return x
145
+
146
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
147
+ if up == 1 and down == 1:
148
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
149
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
150
+
151
+ # Fallback: Generic reference implementation.
152
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
153
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
154
+ if down > 1:
155
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
156
+ return x
157
+
158
+ #----------------------------------------------------------------------------
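A hedged usage sketch for conv2d_resample() above: a 3x3 convolution fused with 2x upsampling, with the low-pass filter prepared by upfirdn2d.setup_filter() as the docstring requires. It assumes the repository root is on PYTHONPATH; on CPU the underlying ops fall back to their reference implementations, and the shapes are illustrative.

import torch
from torch_utils.ops import conv2d_resample, upfirdn2d

x = torch.randn(1, 8, 16, 16)                    # [batch, in_channels, H, W]
w = torch.randn(4, 8, 3, 3)                      # [out_channels, in_channels, kh, kw]
f = upfirdn2d.setup_filter([1, 3, 3, 1])         # standard 4-tap low-pass filter
y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
print(y.shape)                                   # torch.Size([1, 4, 32, 32])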
torch_utils/ops/filtered_lrelu.cpp ADDED
@@ -0,0 +1,300 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "filtered_lrelu.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
17
+ torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
18
+ int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
19
+ {
20
+ // Set CUDA device.
21
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
22
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
23
+
24
+ // Validate arguments.
25
+ TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
26
+ TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
27
+ TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
28
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
29
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
30
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
31
+ TORCH_CHECK(x.numel() > 0, "x is empty");
32
+ TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
33
+ TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
34
+ TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
35
+ TORCH_CHECK(fu.numel() > 0, "fu is empty");
36
+ TORCH_CHECK(fd.numel() > 0, "fd is empty");
37
+ TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
38
+ TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
39
+
40
+ // Figure out how much shared memory is available on the device.
41
+ int maxSharedBytes = 0;
42
+ AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
43
+ int sharedKB = maxSharedBytes >> 10;
44
+
45
+ // Populate enough launch parameters to check if a CUDA kernel exists.
46
+ filtered_lrelu_kernel_params p;
47
+ p.up = up;
48
+ p.down = down;
49
+ p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
50
+ p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
51
+ filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
52
+ if (!test_spec.exec)
53
+ {
54
+ // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
55
+ return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
56
+ }
57
+
58
+ // Input/output element size.
59
+ int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
60
+
61
+ // Input sizes.
62
+ int64_t xw = (int)x.size(3);
63
+ int64_t xh = (int)x.size(2);
64
+ int64_t fut_w = (int)fu.size(-1) - 1;
65
+ int64_t fut_h = (int)fu.size(0) - 1;
66
+ int64_t fdt_w = (int)fd.size(-1) - 1;
67
+ int64_t fdt_h = (int)fd.size(0) - 1;
68
+
69
+ // Logical size of upsampled buffer.
70
+ int64_t cw = xw * up + (px0 + px1) - fut_w;
71
+ int64_t ch = xh * up + (py0 + py1) - fut_h;
72
+ TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
73
+ TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
74
+
75
+ // Compute output size and allocate.
76
+ int64_t yw = (cw - fdt_w + (down - 1)) / down;
77
+ int64_t yh = (ch - fdt_h + (down - 1)) / down;
78
+ TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
79
+ TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
80
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
81
+
82
+ // Allocate sign tensor.
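+ // Two bits are stored per output element (1 = leaky branch taken, 2 = value clamped), packed four elements per byte; the active width is rounded up to a multiple of 16 elements so each row is an integer number of aligned 4-byte groups.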
83
+ torch::Tensor so;
84
+ torch::Tensor s = si;
85
+ bool readSigns = !!s.numel();
86
+ int64_t sw_active = 0; // Active width of sign tensor.
87
+ if (writeSigns)
88
+ {
89
+ sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
90
+ int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
91
+ int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
92
+ TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
93
+ s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
94
+ }
95
+ else if (readSigns)
96
+ sw_active = s.size(3) << 2;
97
+
98
+ // Validate sign tensor if in use.
99
+ if (readSigns || writeSigns)
100
+ {
101
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
102
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
103
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
104
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
105
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
106
+ TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
107
+ }
108
+
109
+ // Populate rest of CUDA kernel parameters.
110
+ p.x = x.data_ptr();
111
+ p.y = y.data_ptr();
112
+ p.b = b.data_ptr();
113
+ p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
114
+ p.fu = fu.data_ptr<float>();
115
+ p.fd = fd.data_ptr<float>();
116
+ p.pad0 = make_int2(px0, py0);
117
+ p.gain = gain;
118
+ p.slope = slope;
119
+ p.clamp = clamp;
120
+ p.flip = (flip_filters) ? 1 : 0;
121
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
122
+ p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
123
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
124
+ p.sOfs = make_int2(sx, sy);
125
+ p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
126
+
127
+ // x, y, b strides are in bytes.
128
+ p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
129
+ p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
130
+ p.bStride = sz * b.stride(0);
131
+
132
+ // fu, fd strides are in elements.
133
+ p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
134
+ p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
135
+
136
+ // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
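+ // The checks below bound the most negative and most positive byte offsets reachable inside x and y (plus the bias and sign tensor checks); if any bound falls outside the int32 range, the int64_t indexing variants are selected further down.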
137
+ bool index64b = false;
138
+ if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
139
+ if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
140
+ if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
141
+ if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
142
+ if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
143
+ if (s.numel() > INT_MAX) index64b = true;
144
+
145
+ // Choose CUDA kernel.
146
+ filtered_lrelu_kernel_spec spec = { 0 };
147
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
148
+ {
149
+ if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
150
+ {
151
+ // Choose kernel based on index type, datatype and sign read/write modes.
152
+ if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true, false>(p, sharedKB);
153
+ else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true >(p, sharedKB);
154
+ else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
155
+ else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true, false>(p, sharedKB);
156
+ else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true >(p, sharedKB);
157
+ else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
158
+ }
159
+ });
160
+ TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found"); // This should not happen because we tested earlier that kernel exists.
161
+
162
+ // Launch CUDA kernel.
163
+ void* args[] = {&p};
164
+ int bx = spec.numWarps * 32;
165
+ int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
166
+ int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
167
+ int gz = p.yShape.z * p.yShape.w;
168
+
169
+ // Repeat multiple horizontal tiles in a CTA?
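+ // If the kernel spec requests x-repetition, each block iterates over up to spec.xrep horizontally adjacent output tiles, so gx is divided by xrep and the x/y grid dimensions are swapped to match the kernel's tile indexing (blockIdx.y selects the group of x tiles, blockIdx.x the y tile).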
170
+ if (spec.xrep)
171
+ {
172
+ p.tilesXrep = spec.xrep;
173
+ p.tilesXdim = gx;
174
+
175
+ gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
176
+ std::swap(gx, gy);
177
+ }
178
+ else
179
+ {
180
+ p.tilesXrep = 0;
181
+ p.tilesXdim = 0;
182
+ }
183
+
184
+ // Launch filter setup kernel.
185
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
186
+
187
+ // Copy kernels to constant memory.
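+ // copy_filters() performs an async device-to-device copy of the padded, reordered filter taps that the setup kernel just wrote into the global staging buffer over to the __constant__ buffer read by the main kernel.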
188
+ if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true, false>(at::cuda::getCurrentCUDAStream())));
189
+ else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters<false, true >(at::cuda::getCurrentCUDAStream())));
190
+ else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));
191
+
192
+ // Set cache and shared memory configurations for main kernel.
193
+ AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
194
+ if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
195
+ AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
196
+ AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
197
+
198
+ // Launch main kernel.
199
+ const int maxSubGz = 65535; // CUDA maximum for block z dimension.
200
+ for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
201
+ {
202
+ p.blockZofs = zofs;
203
+ int subGz = std::min(maxSubGz, gz - zofs);
204
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
205
+ }
206
+
207
+ // Done.
208
+ return std::make_tuple(y, so, 0);
209
+ }
210
+
211
+ //------------------------------------------------------------------------
212
+
213
+ static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
214
+ {
215
+ // Set CUDA device.
216
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
217
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
218
+
219
+ // Validate arguments.
220
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
221
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
222
+ TORCH_CHECK(x.numel() > 0, "x is empty");
223
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
224
+
225
+ // Output signs if we don't have sign input.
226
+ torch::Tensor so;
227
+ torch::Tensor s = si;
228
+ bool readSigns = !!s.numel();
229
+ if (writeSigns)
230
+ {
231
+ int64_t sw = x.size(3);
232
+ sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
233
+ s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
234
+ }
235
+
236
+ // Validate sign tensor if in use.
237
+ if (readSigns || writeSigns)
238
+ {
239
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
240
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
241
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
242
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
243
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
244
+ TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
245
+ }
246
+
247
+ // Initialize CUDA kernel parameters.
248
+ filtered_lrelu_act_kernel_params p;
249
+ p.x = x.data_ptr();
250
+ p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
251
+ p.gain = gain;
252
+ p.slope = slope;
253
+ p.clamp = clamp;
254
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
255
+ p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
256
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
257
+ p.sOfs = make_int2(sx, sy);
258
+
259
+ // Choose CUDA kernel.
260
+ void* func = 0;
261
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
262
+ {
263
+ if (writeSigns)
264
+ func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
265
+ else if (readSigns)
266
+ func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
267
+ else
268
+ func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
269
+ });
270
+ TORCH_CHECK(func, "internal error - CUDA kernel not found");
271
+
272
+ // Launch CUDA kernel.
273
+ void* args[] = {&p};
274
+ int bx = 128; // 4 warps per block.
275
+
276
+ // Logical size of launch = writeSigns ? p.s : p.x
277
+ uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
278
+ uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
279
+ uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
280
+ gx = (gx - 1) / bx + 1;
281
+
282
+ // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
283
+ const uint32_t gmax = 65535;
284
+ gy = std::min(gy, gmax);
285
+ gz = std::min(gz, gmax);
286
+
287
+ // Launch.
288
+ AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
289
+ return so;
290
+ }
291
+
292
+ //------------------------------------------------------------------------
293
+
294
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
295
+ {
296
+ m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
297
+ m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
298
+ }
299
+
300
+ //------------------------------------------------------------------------
torch_utils/ops/filtered_lrelu.cu ADDED
@@ -0,0 +1,1284 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "filtered_lrelu.h"
11
+ #include <cstdint>
12
+
13
+ //------------------------------------------------------------------------
14
+ // Helpers.
15
+
16
+ enum // Filter modes.
17
+ {
18
+ MODE_SUSD = 0, // Separable upsampling, separable downsampling.
19
+ MODE_FUSD = 1, // Full upsampling, separable downsampling.
20
+ MODE_SUFD = 2, // Separable upsampling, full downsampling.
21
+ MODE_FUFD = 3, // Full upsampling, full downsampling.
22
+ };
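+ // "Separable" filters are stored as 1-D taps and applied in two passes (horizontal, then vertical); "full" filters are dense 2-D kernels applied in a single pass.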
23
+
24
+ template <class T> struct InternalType;
25
+ template <> struct InternalType<double>
26
+ {
27
+ typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
28
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
29
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
30
+ __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
31
+ };
32
+ template <> struct InternalType<float>
33
+ {
34
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
35
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
36
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
37
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
38
+ };
39
+ template <> struct InternalType<c10::Half>
40
+ {
41
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
42
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
43
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
44
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
45
+ };
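+ // Note that half-precision inputs are promoted to float for all internal arithmetic; only the loads from p.x and the stores to p.y use 16-bit values.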
46
+
47
+ #define MIN(A, B) ((A) < (B) ? (A) : (B))
48
+ #define MAX(A, B) ((A) > (B) ? (A) : (B))
49
+ #define CEIL_DIV(A, B) (((B)==1) ? (A) : \
50
+ ((B)==2) ? ((int)((A)+1) >> 1) : \
51
+ ((B)==4) ? ((int)((A)+3) >> 2) : \
52
+ (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))
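+ // CEIL_DIV computes ceil(A/B): shift fast paths for B = 2 and 4, and in the general case B-1 is added only for positive A because C++ integer division already rounds negative quotients toward zero (i.e. upward).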
53
+
54
+ // This works only up to blocks of size 256 x 256 and for all N that are powers of two.
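+ // For non-power-of-two N up to 256 the division is replaced by a fixed-point multiply with the rounded-up reciprocal (1<<24)/N + 1, which is exact under the assumption noted below; powers of two (and larger N) fall back to i/N, which the compiler reduces to a shift when N is a power of two.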
55
+ template <int N> __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
56
+ {
57
+ if ((N & (N-1)) && N <= 256)
58
+ y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
59
+ else
60
+ y = i/N;
61
+
62
+ x = i - y*N;
63
+ }
64
+
65
+ // Type cast stride before reading it.
66
+ template <class T> __device__ __forceinline__ T get_stride(const int64_t& x)
67
+ {
68
+ return *reinterpret_cast<const T*>(&x);
69
+ }
70
+
71
+ //------------------------------------------------------------------------
72
+ // Filters, setup kernel, copying function.
73
+
74
+ #define MAX_FILTER_SIZE 32
75
+
76
+ // Combined up/down filter buffers so that transfer can be done with one copy.
77
+ __device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
78
+ __device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.
79
+
80
+ // Accessors to combined buffers to index up/down filters individually.
81
+ #define c_fu (c_fbuf)
82
+ #define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
83
+ #define g_fu (g_fbuf)
84
+ #define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
85
+
86
+ // Set up filters into global memory buffer.
87
+ static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
88
+ {
89
+ for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
90
+ {
91
+ int x, y;
92
+ fast_div_mod<MAX_FILTER_SIZE>(x, y, idx);
93
+
94
+ int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
95
+ int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
96
+ if (p.fuShape.y > 0)
97
+ g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
98
+ else
99
+ g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
100
+
101
+ int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
102
+ int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
103
+ if (p.fdShape.y > 0)
104
+ g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
105
+ else
106
+ g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
107
+ }
108
+ }
109
+
110
+ // Host function to copy filters written by setup kernel into constant buffer for main kernel.
111
+ template <bool, bool> static cudaError_t copy_filters(cudaStream_t stream)
112
+ {
113
+ void* src = 0;
114
+ cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
115
+ if (err) return err;
116
+ return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
117
+ }
118
+
119
+ //------------------------------------------------------------------------
120
+ // Coordinate spaces:
121
+ // - Relative to input tensor: inX, inY, tileInX, tileInY
122
+ // - Relative to input tile: relInX, relInY, tileInW, tileInH
123
+ // - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
124
+ // - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH
125
+ // - Relative to output tensor: outX, outY, tileOutX, tileOutY
126
+ //
127
+ // Relationships between coordinate spaces:
128
+ // - inX = tileInX + relInX
129
+ // - inY = tileInY + relInY
130
+ // - relUpX = relInX * up + phaseInX
131
+ // - relUpY = relInY * up + phaseInY
132
+ // - relUpX = relOutX * down
133
+ // - relUpY = relOutY * down
134
+ // - outX = tileOutX + relOutX
135
+ // - outY = tileOutY + relOutY
136
+
137
+ extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.
138
+
139
+ template <class T, class index_t, int sharedKB, bool signWrite, bool signRead, int filterMode, int up, int fuSize, int down, int fdSize, int tileOutW, int tileOutH, int threadsPerBlock, bool enableXrep, bool enableWriteSkip>
140
+ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
141
+ {
142
+ // Check that we don't try to support non-existing filter modes.
143
+ static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported");
144
+ static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
145
+ static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor");
146
+ static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
147
+ static_assert(fuSize % up == 0, "upsampling filter size must be divisible by upsampling factor");
148
+ static_assert(fdSize % down == 0, "downsampling filter size must be divisible by downsampling factor");
149
+ static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
150
+ static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
151
+ static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
152
+ static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
153
+ static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");
154
+
155
+ // Static definitions.
156
+ typedef typename InternalType<T>::scalar_t scalar_t;
157
+ typedef typename InternalType<T>::vec2_t vec2_t;
158
+ typedef typename InternalType<T>::vec4_t vec4_t;
159
+ const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
160
+ const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height.
161
+ const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width.
162
+ const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height.
163
+ const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up.
164
+ const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4.
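+ // In short, the upsampled tile is the output tile's footprint at the higher resolution plus the support of the downsampling filter, and the input tile is the upsampled tile mapped back through the upsampling factor plus the support of the upsampling filter.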
165
+
166
+ // Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
167
+ const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));
168
+
169
+ // Sizes of logical buffers.
170
+ const int szIn = tileInH_up * tileInW;
171
+ const int szUpX = tileInH_up * tileUpW;
172
+ const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
173
+ const int szDownX = tileUpH * tileOutW;
174
+
175
+ // Sizes for shared memory arrays.
176
+ const int s_buf0_size_base =
177
+ (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
178
+ (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
179
+ (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
180
+ (filterMode == MODE_FUFD) ? szIn :
181
+ -1;
182
+ const int s_buf1_size_base =
183
+ (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
184
+ (filterMode == MODE_FUSD) ? szUpXY :
185
+ (filterMode == MODE_SUFD) ? szUpX :
186
+ (filterMode == MODE_FUFD) ? szUpXY :
187
+ -1;
188
+
189
+ // Ensure U128 alignment.
190
+ const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
191
+ const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
192
+
193
+ // Check at compile time that we don't use too much shared memory.
194
+ static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");
195
+
196
+ // Declare shared memory arrays.
197
+ scalar_t* s_buf0;
198
+ scalar_t* s_buf1;
199
+ if (sharedKB <= 48)
200
+ {
201
+ // Allocate shared memory arrays here.
202
+ __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.
203
+ s_buf0 = s_buf0_st;
204
+ s_buf1 = s_buf0 + s_buf0_size;
205
+ }
206
+ else
207
+ {
208
+ // Use the dynamically allocated shared memory array.
209
+ s_buf0 = (scalar_t*)s_buf_raw;
210
+ s_buf1 = s_buf0 + s_buf0_size;
211
+ }
212
+
213
+ // Pointers to the buffers.
214
+ scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY]
215
+ scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX]
216
+ scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX]
217
+ scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX]
218
+ if (filterMode == MODE_SUSD)
219
+ {
220
+ s_tileIn = s_buf0;
221
+ s_tileUpX = s_buf1;
222
+ s_tileUpXY = s_buf0;
223
+ s_tileDownX = s_buf1;
224
+ }
225
+ else if (filterMode == MODE_FUSD)
226
+ {
227
+ s_tileIn = s_buf0;
228
+ s_tileUpXY = s_buf1;
229
+ s_tileDownX = s_buf0;
230
+ }
231
+ else if (filterMode == MODE_SUFD)
232
+ {
233
+ s_tileIn = s_buf0;
234
+ s_tileUpX = s_buf1;
235
+ s_tileUpXY = s_buf0;
236
+ }
237
+ else if (filterMode == MODE_FUFD)
238
+ {
239
+ s_tileIn = s_buf0;
240
+ s_tileUpXY = s_buf1;
241
+ }
242
+
243
+ // Allow large grids in z direction via per-launch offset.
244
+ int channelIdx = blockIdx.z + p.blockZofs;
245
+ int batchIdx = channelIdx / p.yShape.z;
246
+ channelIdx -= batchIdx * p.yShape.z;
247
+
248
+ // Offset to output feature map. In bytes.
249
+ index_t mapOfsOut = channelIdx * get_stride<index_t>(p.yStride.z) + batchIdx * get_stride<index_t>(p.yStride.w);
250
+
251
+ // Sign shift amount.
252
+ uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
253
+
254
+ // Inner tile loop.
255
+ #pragma unroll 1
256
+ for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)
257
+ {
258
+ // Locate output tile.
259
+ int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
260
+ int tileOutX = tileX * tileOutW;
261
+ int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
262
+
263
+ // Locate input tile.
264
+ int tmpX = tileOutX * down - p.pad0.x;
265
+ int tmpY = tileOutY * down - p.pad0.y;
266
+ int tileInX = CEIL_DIV(tmpX, up);
267
+ int tileInY = CEIL_DIV(tmpY, up);
268
+ const int phaseInX = tileInX * up - tmpX;
269
+ const int phaseInY = tileInY * up - tmpY;
270
+
271
+ // Extra sync if input and output buffers are the same and we are not on first tile.
272
+ if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))
273
+ __syncthreads();
274
+
275
+ // Load input tile & apply bias. Unrolled.
276
+ scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride<index_t>(p.bStride)));
277
+ index_t mapOfsIn = channelIdx * get_stride<index_t>(p.xStride.z) + batchIdx * get_stride<index_t>(p.xStride.w);
278
+ int idx = threadIdx.x;
279
+ const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
280
+ #pragma unroll
281
+ for (int loop = 0; loop < loopCountIN; loop++)
282
+ {
283
+ int relInX, relInY;
284
+ fast_div_mod<tileInW>(relInX, relInY, idx);
285
+ int inX = tileInX + relInX;
286
+ int inY = tileInY + relInY;
287
+ scalar_t v = 0;
288
+
289
+ if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
290
+ v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride<index_t>(p.xStride.x) + inY * get_stride<index_t>(p.xStride.y) + mapOfsIn))) + b;
291
+
292
+ bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);
293
+ if (!skip)
294
+ s_tileIn[idx] = v;
295
+
296
+ idx += threadsPerBlock;
297
+ }
298
+
299
+ if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.
300
+ {
301
+ // Horizontal upsampling.
302
+ __syncthreads();
303
+ if (up == 4)
304
+ {
305
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
306
+ {
307
+ int relUpX0, relInY;
308
+ fast_div_mod<tileUpW>(relUpX0, relInY, idx);
309
+ int relInX0 = relUpX0 / up;
310
+ int src0 = relInX0 + tileInW * relInY;
311
+ int dst = relInY * tileUpW + relUpX0;
312
+ vec4_t v = InternalType<T>::zero_vec4();
313
+ scalar_t a = s_tileIn[src0];
314
+ if (phaseInX == 0)
315
+ {
316
+ #pragma unroll
317
+ for (int step = 0; step < fuSize / up; step++)
318
+ {
319
+ v.x += a * (scalar_t)c_fu[step * up + 0];
320
+ a = s_tileIn[src0 + step + 1];
321
+ v.y += a * (scalar_t)c_fu[step * up + 3];
322
+ v.z += a * (scalar_t)c_fu[step * up + 2];
323
+ v.w += a * (scalar_t)c_fu[step * up + 1];
324
+ }
325
+ }
326
+ else if (phaseInX == 1)
327
+ {
328
+ #pragma unroll
329
+ for (int step = 0; step < fuSize / up; step++)
330
+ {
331
+ v.x += a * (scalar_t)c_fu[step * up + 1];
332
+ v.y += a * (scalar_t)c_fu[step * up + 0];
333
+ a = s_tileIn[src0 + step + 1];
334
+ v.z += a * (scalar_t)c_fu[step * up + 3];
335
+ v.w += a * (scalar_t)c_fu[step * up + 2];
336
+ }
337
+ }
338
+ else if (phaseInX == 2)
339
+ {
340
+ #pragma unroll
341
+ for (int step = 0; step < fuSize / up; step++)
342
+ {
343
+ v.x += a * (scalar_t)c_fu[step * up + 2];
344
+ v.y += a * (scalar_t)c_fu[step * up + 1];
345
+ v.z += a * (scalar_t)c_fu[step * up + 0];
346
+ a = s_tileIn[src0 + step + 1];
347
+ v.w += a * (scalar_t)c_fu[step * up + 3];
348
+ }
349
+ }
350
+ else // (phaseInX == 3)
351
+ {
352
+ #pragma unroll
353
+ for (int step = 0; step < fuSize / up; step++)
354
+ {
355
+ v.x += a * (scalar_t)c_fu[step * up + 3];
356
+ v.y += a * (scalar_t)c_fu[step * up + 2];
357
+ v.z += a * (scalar_t)c_fu[step * up + 1];
358
+ v.w += a * (scalar_t)c_fu[step * up + 0];
359
+ a = s_tileIn[src0 + step + 1];
360
+ }
361
+ }
362
+ s_tileUpX[dst+0] = v.x;
363
+ s_tileUpX[dst+1] = v.y;
364
+ s_tileUpX[dst+2] = v.z;
365
+ s_tileUpX[dst+3] = v.w;
366
+ }
367
+ }
368
+ else if (up == 2)
369
+ {
370
+ bool p0 = (phaseInX == 0);
371
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
372
+ {
373
+ int relUpX0, relInY;
374
+ fast_div_mod<tileUpW>(relUpX0, relInY, idx);
375
+ int relInX0 = relUpX0 / up;
376
+ int src0 = relInX0 + tileInW * relInY;
377
+ int dst = relInY * tileUpW + relUpX0;
378
+ vec2_t v = InternalType<T>::zero_vec2();
379
+ scalar_t a = s_tileIn[src0];
380
+ if (p0) // (phaseInX == 0)
381
+ {
382
+ #pragma unroll
383
+ for (int step = 0; step < fuSize / up; step++)
384
+ {
385
+ v.x += a * (scalar_t)c_fu[step * up + 0];
386
+ a = s_tileIn[src0 + step + 1];
387
+ v.y += a * (scalar_t)c_fu[step * up + 1];
388
+ }
389
+ }
390
+ else // (phaseInX == 1)
391
+ {
392
+ #pragma unroll
393
+ for (int step = 0; step < fuSize / up; step++)
394
+ {
395
+ v.x += a * (scalar_t)c_fu[step * up + 1];
396
+ v.y += a * (scalar_t)c_fu[step * up + 0];
397
+ a = s_tileIn[src0 + step + 1];
398
+ }
399
+ }
400
+ s_tileUpX[dst+0] = v.x;
401
+ s_tileUpX[dst+1] = v.y;
402
+ }
403
+ }
404
+
405
+ // Vertical upsampling & nonlinearity.
406
+
407
+ __syncthreads();
408
+ int groupMask = 15 << ((threadIdx.x & 31) & ~3);
409
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
410
+ int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.
411
+ if (up == 4)
412
+ {
413
+ minY -= 3; // Adjust according to block height.
414
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
415
+ {
416
+ int relUpX, relInY0;
417
+ fast_div_mod<tileUpW>(relUpX, relInY0, idx);
418
+ int relUpY0 = relInY0 * up;
419
+ int src0 = relInY0 * tileUpW + relUpX;
420
+ int dst = relUpY0 * tileUpW + relUpX;
421
+ vec4_t v = InternalType<T>::zero_vec4();
422
+
423
+ scalar_t a = s_tileUpX[src0];
424
+ if (phaseInY == 0)
425
+ {
426
+ #pragma unroll
427
+ for (int step = 0; step < fuSize / up; step++)
428
+ {
429
+ v.x += a * (scalar_t)c_fu[step * up + 0];
430
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
431
+ v.y += a * (scalar_t)c_fu[step * up + 3];
432
+ v.z += a * (scalar_t)c_fu[step * up + 2];
433
+ v.w += a * (scalar_t)c_fu[step * up + 1];
434
+ }
435
+ }
436
+ else if (phaseInY == 1)
437
+ {
438
+ #pragma unroll
439
+ for (int step = 0; step < fuSize / up; step++)
440
+ {
441
+ v.x += a * (scalar_t)c_fu[step * up + 1];
442
+ v.y += a * (scalar_t)c_fu[step * up + 0];
443
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
444
+ v.z += a * (scalar_t)c_fu[step * up + 3];
445
+ v.w += a * (scalar_t)c_fu[step * up + 2];
446
+ }
447
+ }
448
+ else if (phaseInY == 2)
449
+ {
450
+ #pragma unroll
451
+ for (int step = 0; step < fuSize / up; step++)
452
+ {
453
+ v.x += a * (scalar_t)c_fu[step * up + 2];
454
+ v.y += a * (scalar_t)c_fu[step * up + 1];
455
+ v.z += a * (scalar_t)c_fu[step * up + 0];
456
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
457
+ v.w += a * (scalar_t)c_fu[step * up + 3];
458
+ }
459
+ }
460
+ else // (phaseInY == 3)
461
+ {
462
+ #pragma unroll
463
+ for (int step = 0; step < fuSize / up; step++)
464
+ {
465
+ v.x += a * (scalar_t)c_fu[step * up + 3];
466
+ v.y += a * (scalar_t)c_fu[step * up + 2];
467
+ v.z += a * (scalar_t)c_fu[step * up + 1];
468
+ v.w += a * (scalar_t)c_fu[step * up + 0];
469
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
470
+ }
471
+ }
472
+
473
+ int x = tileOutX * down + relUpX;
474
+ int y = tileOutY * down + relUpY0;
475
+ int signX = x + p.sOfs.x;
476
+ int signY = y + p.sOfs.y;
477
+ int signZ = blockIdx.z + p.blockZofs;
478
+ int signXb = signX >> 2;
479
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
480
+ index_t si1 = si0 + p.sShape.x;
481
+ index_t si2 = si0 + p.sShape.x * 2;
482
+ index_t si3 = si0 + p.sShape.x * 3;
483
+
484
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
485
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
486
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
487
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
488
+
489
+ if (signWrite)
490
+ {
491
+ if (!enableWriteSkip)
492
+ {
493
+ // Determine and write signs.
494
+ int sx = __float_as_uint(v.x) >> 31 << 0;
495
+ int sy = __float_as_uint(v.y) >> 31 << 8;
496
+ int sz = __float_as_uint(v.z) >> 31 << 16;
497
+ int sw = __float_as_uint(v.w) >> 31 << 24;
498
+ if (sx) v.x *= p.slope;
499
+ if (sy) v.y *= p.slope;
500
+ if (sz) v.z *= p.slope;
501
+ if (sw) v.w *= p.slope;
502
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
503
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
504
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
505
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
506
+
507
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
508
+ {
509
+ // Combine signs.
510
+ uint32_t s = sx + sy + sw + sz;
511
+ s <<= (signX & 3) << 1;
512
+ s |= __shfl_xor_sync(groupMask, s, 1);
513
+ s |= __shfl_xor_sync(groupMask, s, 2);
514
+
515
+ // Write signs.
516
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
517
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
518
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
519
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
520
+ }
521
+ }
522
+ else
523
+ {
524
+ // Determine and write signs.
525
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
526
+ {
527
+ int sx = __float_as_uint(v.x) >> 31 << 0;
528
+ int sy = __float_as_uint(v.y) >> 31 << 8;
529
+ int sz = __float_as_uint(v.z) >> 31 << 16;
530
+ int sw = __float_as_uint(v.w) >> 31 << 24;
531
+ if (sx) v.x *= p.slope;
532
+ if (sy) v.y *= p.slope;
533
+ if (sz) v.z *= p.slope;
534
+ if (sw) v.w *= p.slope;
535
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
536
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
537
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
538
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
539
+
540
+ // Combine signs.
541
+ uint32_t s = sx + sy + sw + sz;
542
+ s <<= (signX & 3) << 1;
543
+ s |= __shfl_xor_sync(groupMask, s, 1);
544
+ s |= __shfl_xor_sync(groupMask, s, 2);
545
+
546
+ // Write signs.
547
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
548
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
549
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
550
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
551
+ }
552
+ else
553
+ {
554
+ // Just compute the values.
555
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
556
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
557
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
558
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
559
+ }
560
+ }
561
+ }
562
+ else if (signRead) // Read signs and apply.
563
+ {
564
+ if ((uint32_t)signXb < p.swLimit)
565
+ {
566
+ int ss = (signX & 3) << 1;
567
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
568
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
569
+ if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; }
570
+ if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; }
571
+ }
572
+ }
573
+ else // Forward pass with no sign write.
574
+ {
575
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
576
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
577
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
578
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
579
+ }
580
+
581
+ s_tileUpXY[dst + 0 * tileUpW] = v.x;
582
+ if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;
583
+ if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;
584
+ if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;
585
+ }
586
+ }
587
+ else if (up == 2)
588
+ {
589
+ minY -= 1; // Adjust according to block height.
590
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
591
+ {
592
+ int relUpX, relInY0;
593
+ fast_div_mod<tileUpW>(relUpX, relInY0, idx);
594
+ int relUpY0 = relInY0 * up;
595
+ int src0 = relInY0 * tileUpW + relUpX;
596
+ int dst = relUpY0 * tileUpW + relUpX;
597
+ vec2_t v = InternalType<T>::zero_vec2();
598
+
599
+ scalar_t a = s_tileUpX[src0];
600
+ if (phaseInY == 0)
601
+ {
602
+ #pragma unroll
603
+ for (int step = 0; step < fuSize / up; step++)
604
+ {
605
+ v.x += a * (scalar_t)c_fu[step * up + 0];
606
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
607
+ v.y += a * (scalar_t)c_fu[step * up + 1];
608
+ }
609
+ }
610
+ else // (phaseInY == 1)
611
+ {
612
+ #pragma unroll
613
+ for (int step = 0; step < fuSize / up; step++)
614
+ {
615
+ v.x += a * (scalar_t)c_fu[step * up + 1];
616
+ v.y += a * (scalar_t)c_fu[step * up + 0];
617
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
618
+ }
619
+ }
620
+
621
+ int x = tileOutX * down + relUpX;
622
+ int y = tileOutY * down + relUpY0;
623
+ int signX = x + p.sOfs.x;
624
+ int signY = y + p.sOfs.y;
625
+ int signZ = blockIdx.z + p.blockZofs;
626
+ int signXb = signX >> 2;
627
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
628
+ index_t si1 = si0 + p.sShape.x;
629
+
630
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
631
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
632
+
633
+ if (signWrite)
634
+ {
635
+ if (!enableWriteSkip)
636
+ {
637
+ // Determine and write signs.
638
+ int sx = __float_as_uint(v.x) >> 31 << 0;
639
+ int sy = __float_as_uint(v.y) >> 31 << 8;
640
+ if (sx) v.x *= p.slope;
641
+ if (sy) v.y *= p.slope;
642
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
643
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
644
+
645
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
646
+ {
647
+ // Combine signs.
648
+ int s = sx + sy;
649
+ s <<= signXo;
650
+ s |= __shfl_xor_sync(groupMask, s, 1);
651
+ s |= __shfl_xor_sync(groupMask, s, 2);
652
+
653
+ // Write signs.
654
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
655
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
656
+ }
657
+ }
658
+ else
659
+ {
660
+ // Determine and write signs.
661
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
662
+ {
663
+ int sx = __float_as_uint(v.x) >> 31 << 0;
664
+ int sy = __float_as_uint(v.y) >> 31 << 8;
665
+ if (sx) v.x *= p.slope;
666
+ if (sy) v.y *= p.slope;
667
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
668
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
669
+
670
+ // Combine signs.
671
+ int s = sx + sy;
672
+ s <<= signXo;
673
+ s |= __shfl_xor_sync(groupMask, s, 1);
674
+ s |= __shfl_xor_sync(groupMask, s, 2);
675
+
676
+ // Write signs.
677
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
678
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
679
+ }
680
+ else
681
+ {
682
+ // Just compute the values.
683
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
684
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
685
+ }
686
+ }
687
+ }
688
+ else if (signRead) // Read signs and apply.
689
+ {
690
+ if ((uint32_t)signXb < p.swLimit)
691
+ {
692
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
693
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
694
+ }
695
+ }
696
+ else // Forward pass with no sign write.
697
+ {
698
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
699
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
700
+ }
701
+
702
+ if (!downInline)
703
+ {
704
+ // Write into temporary buffer.
705
+ s_tileUpXY[dst] = v.x;
706
+ if (relUpY0 < tileUpH - 1)
707
+ s_tileUpXY[dst + tileUpW] = v.y;
708
+ }
709
+ else
710
+ {
711
+ // Write directly into output buffer.
712
+ if ((uint32_t)x < p.yShape.x)
713
+ {
714
+ int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down);
715
+ index_t ofs = x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut;
716
+ if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
717
+ if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]);
718
+ }
719
+ }
720
+ }
721
+ }
722
+ }
723
+ else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD)
724
+ {
725
+ // Full upsampling filter.
726
+
727
+ if (up == 2)
728
+ {
729
+ // 2 x 2-wide.
730
+ __syncthreads();
731
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs.
732
+ for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4)
733
+ {
734
+ int relUpX0, relUpY0;
735
+ fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
736
+ int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up);
737
+ int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up);
738
+ int src0 = relInX0 + tileInW * relInY0;
739
+ int tap0y = (relInY0 * up + phaseInY - relUpY0);
740
+
741
+ #define X_LOOP(TAPY, PX) \
742
+ for (int sx = 0; sx < fuSize / up; sx++) \
743
+ { \
744
+ v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
745
+ v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
746
+ v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
747
+ v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
748
+ }
749
+
750
+ vec4_t v = InternalType<T>::zero_vec4();
751
+ if (tap0y == 0 && phaseInX == 0)
752
+ #pragma unroll
753
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
754
+ #pragma unroll
755
+ X_LOOP(0, 0) }
756
+ if (tap0y == 0 && phaseInX == 1)
757
+ #pragma unroll
758
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
759
+ #pragma unroll
760
+ X_LOOP(0, 1) }
761
+ if (tap0y == 1 && phaseInX == 0)
762
+ #pragma unroll
763
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
764
+ #pragma unroll
765
+ X_LOOP(1, 0) }
766
+ if (tap0y == 1 && phaseInX == 1)
767
+ #pragma unroll
768
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
769
+ #pragma unroll
770
+ X_LOOP(1, 1) }
771
+
772
+ #undef X_LOOP
773
+
774
+ int x = tileOutX * down + relUpX0;
775
+ int y = tileOutY * down + relUpY0;
776
+ int signX = x + p.sOfs.x;
777
+ int signY = y + p.sOfs.y;
778
+ int signZ = blockIdx.z + p.blockZofs;
779
+ int signXb = signX >> 2;
780
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
781
+
782
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
783
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
784
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
785
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
786
+
787
+ if (signWrite)
788
+ {
789
+ if (!enableWriteSkip)
790
+ {
791
+ // Determine and write signs.
792
+ int sx = __float_as_uint(v.x) >> 31;
793
+ int sy = __float_as_uint(v.y) >> 31;
794
+ int sz = __float_as_uint(v.z) >> 31;
795
+ int sw = __float_as_uint(v.w) >> 31;
796
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
797
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
798
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
799
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
800
+
801
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
802
+ {
803
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
804
+ }
805
+ }
806
+ else
807
+ {
808
+ // Determine and write signs.
809
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
810
+ {
811
+ int sx = __float_as_uint(v.x) >> 31;
812
+ int sy = __float_as_uint(v.y) >> 31;
813
+ int sz = __float_as_uint(v.z) >> 31;
814
+ int sw = __float_as_uint(v.w) >> 31;
815
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
816
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
817
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
818
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
819
+
820
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
821
+ }
822
+ else
823
+ {
824
+ // Just compute the values.
825
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
826
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
827
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
828
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
829
+ }
830
+ }
831
+ }
832
+ else if (signRead) // Read sign and apply.
833
+ {
834
+ if ((uint32_t)signY < p.sShape.y)
835
+ {
836
+ int s = 0;
837
+ if ((uint32_t)signXb < p.swLimit) s = p.s[si];
838
+ if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8;
839
+ s >>= (signX & 3) << 1;
840
+ if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f;
841
+ if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f;
842
+ if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f;
843
+ if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f;
844
+ }
845
+ }
846
+ else // Forward pass with no sign write.
847
+ {
848
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
849
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
850
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
851
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
852
+ }
853
+
854
+ s_tileUpXY[idx + 0] = v.x;
855
+ s_tileUpXY[idx + 1] = v.y;
856
+ s_tileUpXY[idx + 2] = v.z;
857
+ s_tileUpXY[idx + 3] = v.w;
858
+ }
859
+ }
860
+ else if (up == 1)
861
+ {
862
+ __syncthreads();
863
+ uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3);
864
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
865
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x)
866
+ {
867
+ int relUpX0, relUpY0;
868
+ fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
869
+ scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter.
870
+
871
+ int x = tileOutX * down + relUpX0;
872
+ int y = tileOutY * down + relUpY0;
873
+ int signX = x + p.sOfs.x;
874
+ int signY = y + p.sOfs.y;
875
+ int signZ = blockIdx.z + p.blockZofs;
876
+ int signXb = signX >> 2;
877
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
878
+ v *= (scalar_t)((float)up * (float)up * p.gain);
879
+
880
+ if (signWrite)
881
+ {
882
+ if (!enableWriteSkip)
883
+ {
884
+ // Determine and write sign.
885
+ uint32_t s = 0;
886
+ uint32_t signXbit = (1u << signXo);
887
+ if (v < 0.f)
888
+ {
889
+ s = signXbit;
890
+ v *= p.slope;
891
+ }
892
+ if (fabsf(v) > p.clamp)
893
+ {
894
+ s = signXbit * 2;
895
+ v = InternalType<T>::clamp(v, p.clamp);
896
+ }
897
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
898
+ {
899
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
900
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
901
+ p.s[si] = s; // Write.
902
+ }
903
+ }
904
+ else
905
+ {
906
+ // Determine and write sign.
907
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
908
+ {
909
+ uint32_t s = 0;
910
+ uint32_t signXbit = (1u << signXo);
911
+ if (v < 0.f)
912
+ {
913
+ s = signXbit;
914
+ v *= p.slope;
915
+ }
916
+ if (fabsf(v) > p.clamp)
917
+ {
918
+ s = signXbit * 2;
919
+ v = InternalType<T>::clamp(v, p.clamp);
920
+ }
921
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
922
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
923
+ p.s[si] = s; // Write.
924
+ }
925
+ else
926
+ {
927
+ // Just compute the value.
928
+ if (v < 0.f) v *= p.slope;
929
+ v = InternalType<T>::clamp(v, p.clamp);
930
+ }
931
+ }
932
+ }
933
+ else if (signRead)
934
+ {
935
+ // Read sign and apply if within sign tensor bounds.
936
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y)
937
+ {
938
+ int s = p.s[si];
939
+ s >>= signXo;
940
+ if (s & 1) v *= p.slope;
941
+ if (s & 2) v = 0.f;
942
+ }
943
+ }
944
+ else // Forward pass with no sign write.
945
+ {
946
+ if (v < 0.f) v *= p.slope;
947
+ v = InternalType<T>::clamp(v, p.clamp);
948
+ }
949
+
950
+ if (!downInline) // Write into temporary buffer.
951
+ s_tileUpXY[idx] = v;
952
+ else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer
953
+ *((T*)((char*)p.y + (x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
954
+ }
955
+ }
956
+ }
957
+
958
+ // Downsampling.
959
+ if (filterMode == MODE_SUSD || filterMode == MODE_FUSD)
960
+ {
961
+ // Horizontal downsampling.
962
+ __syncthreads();
963
+ if (down == 4 && tileOutW % 4 == 0)
964
+ {
965
+ // Calculate 4 pixels at a time.
966
+ for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4)
967
+ {
968
+ int relOutX0, relUpY;
969
+ fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
970
+ int relUpX0 = relOutX0 * down;
971
+ int src0 = relUpY * tileUpW + relUpX0;
972
+ vec4_t v = InternalType<T>::zero_vec4();
973
+ #pragma unroll
974
+ for (int step = 0; step < fdSize; step++)
975
+ {
976
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
977
+ v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step];
978
+ v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step];
979
+ v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step];
980
+ }
981
+ s_tileDownX[idx+0] = v.x;
982
+ s_tileDownX[idx+1] = v.y;
983
+ s_tileDownX[idx+2] = v.z;
984
+ s_tileDownX[idx+3] = v.w;
985
+ }
986
+ }
987
+ else if ((down == 2 || down == 4) && (tileOutW % 2 == 0))
988
+ {
989
+ // Calculate 2 pixels at a time.
990
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2)
991
+ {
992
+ int relOutX0, relUpY;
993
+ fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
994
+ int relUpX0 = relOutX0 * down;
995
+ int src0 = relUpY * tileUpW + relUpX0;
996
+ vec2_t v = InternalType<T>::zero_vec2();
997
+ #pragma unroll
998
+ for (int step = 0; step < fdSize; step++)
999
+ {
1000
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
1001
+ v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step];
1002
+ }
1003
+ s_tileDownX[idx+0] = v.x;
1004
+ s_tileDownX[idx+1] = v.y;
1005
+ }
1006
+ }
1007
+ else
1008
+ {
1009
+ // Calculate 1 pixel at a time.
1010
+ for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x)
1011
+ {
1012
+ int relOutX0, relUpY;
1013
+ fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
1014
+ int relUpX0 = relOutX0 * down;
1015
+ int src = relUpY * tileUpW + relUpX0;
1016
+ scalar_t v = 0.f;
1017
+ #pragma unroll
1018
+ for (int step = 0; step < fdSize; step++)
1019
+ v += s_tileUpXY[src + step] * (scalar_t)c_fd[step];
1020
+ s_tileDownX[idx] = v;
1021
+ }
1022
+ }
1023
+
1024
+ // Vertical downsampling & store output tile.
1025
+ __syncthreads();
1026
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
1027
+ {
1028
+ int relOutX, relOutY0;
1029
+ fast_div_mod<tileOutW>(relOutX, relOutY0, idx);
1030
+ int relUpY0 = relOutY0 * down;
1031
+ int src0 = relUpY0 * tileOutW + relOutX;
1032
+ scalar_t v = 0;
1033
+ #pragma unroll
1034
+ for (int step = 0; step < fdSize; step++)
1035
+ v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step];
1036
+
1037
+ int outX = tileOutX + relOutX;
1038
+ int outY = tileOutY + relOutY0;
1039
+
1040
+ if (outX < p.yShape.x & outY < p.yShape.y)
1041
+ *((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
1042
+ }
1043
+ }
1044
+ else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD)
1045
+ {
1046
+ // Full downsampling filter.
1047
+ if (down == 2)
1048
+ {
1049
+ // 2-wide.
1050
+ __syncthreads();
1051
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2)
1052
+ {
1053
+ int relOutX0, relOutY0;
1054
+ fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
1055
+ int relUpX0 = relOutX0 * down;
1056
+ int relUpY0 = relOutY0 * down;
1057
+ int src0 = relUpY0 * tileUpW + relUpX0;
1058
+ vec2_t v = InternalType<T>::zero_vec2();
1059
+ #pragma unroll
1060
+ for (int sy = 0; sy < fdSize; sy++)
1061
+ #pragma unroll
1062
+ for (int sx = 0; sx < fdSize; sx++)
1063
+ {
1064
+ v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
1065
+ v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
1066
+ }
1067
+
1068
+ int outX = tileOutX + relOutX0;
1069
+ int outY = tileOutY + relOutY0;
1070
+ if ((uint32_t)outY < p.yShape.y)
1071
+ {
1072
+ index_t ofs = outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut;
1073
+ if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x;
1074
+ if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.x))) = (T)v.y;
1075
+ }
1076
+ }
1077
+ }
1078
+ else if (down == 1 && !downInline)
1079
+ {
1080
+ // Thread per pixel.
1081
+ __syncthreads();
1082
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
1083
+ {
1084
+ int relOutX0, relOutY0;
1085
+ fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
1086
+ scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter.
1087
+
1088
+ int outX = tileOutX + relOutX0;
1089
+ int outY = tileOutY + relOutY0;
1090
+ if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)
1091
+ *((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
1092
+ }
1093
+ }
1094
+ }
1095
+
1096
+ if (!enableXrep)
1097
+ break;
1098
+ }
1099
+ }
1100
+
1101
+ //------------------------------------------------------------------------
1102
+ // Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant.
1103
+ // Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used.
1104
+
1105
+ template <class T, bool signWrite, bool signRead>
1106
+ static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p)
1107
+ {
1108
+ typedef typename InternalType<T>::scalar_t scalar_t;
1109
+
1110
+ // Indexing.
1111
+ int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
1112
+ int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;
1113
+ int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.
1114
+
1115
+ // Loop to accommodate oversized tensors.
1116
+ for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
1117
+ for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y)
1118
+ {
1119
+ // Extract z and w (channel, minibatch index).
1120
+ int32_t w = q / p.xShape.z;
1121
+ int32_t z = q - w * p.xShape.z;
1122
+
1123
+ // Choose behavior based on sign read/write mode.
1124
+ if (signWrite)
1125
+ {
1126
+ // Process value if in p.x.
1127
+ uint32_t s = 0;
1128
+ if (x < p.xShape.x && y < p.xShape.y)
1129
+ {
1130
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1131
+ T* pv = ((T*)p.x) + ix;
1132
+ scalar_t v = (scalar_t)(*pv);
1133
+
1134
+ // Gain, LReLU, clamp.
1135
+ v *= p.gain;
1136
+ if (v < 0.f)
1137
+ {
1138
+ v *= p.slope;
1139
+ s = 1; // Sign.
1140
+ }
1141
+ if (fabsf(v) > p.clamp)
1142
+ {
1143
+ v = InternalType<T>::clamp(v, p.clamp);
1144
+ s = 2; // Clamp.
1145
+ }
1146
+
1147
+ *pv = (T)v; // Write value.
1148
+ }
1149
+
1150
+ // Coalesce into threads 0 and 16 of warp.
1151
+ uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
1152
+ s <<= ((threadIdx.x & 15) << 1); // Shift into place.
1153
+ s |= __shfl_xor_sync(m, s, 1); // Distribute.
1154
+ s |= __shfl_xor_sync(m, s, 2);
1155
+ s |= __shfl_xor_sync(m, s, 4);
1156
+ s |= __shfl_xor_sync(m, s, 8);
1157
+
1158
+ // Write signs if leader and in p.s.
1159
+ if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
1160
+ {
1161
+ uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
1162
+ ((uint32_t*)p.s)[is >> 4] = s;
1163
+ }
1164
+ }
1165
+ else if (signRead)
1166
+ {
1167
+ // Process value if in p.x.
1168
+ if (x < p.xShape.x) // y is always in.
1169
+ {
1170
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1171
+ T* pv = ((T*)p.x) + ix;
1172
+ scalar_t v = (scalar_t)(*pv);
1173
+ v *= p.gain;
1174
+
1175
+ // Apply sign buffer offset.
1176
+ uint32_t sx = x + p.sOfs.x;
1177
+ uint32_t sy = y + p.sOfs.y;
1178
+
1179
+ // Read and apply signs if we land inside valid region of sign buffer.
1180
+ if (sx < p.sShape.x && sy < p.sShape.y)
1181
+ {
1182
+ uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous.
1183
+ unsigned char s = p.s[is];
1184
+ s >>= (sx & 3) << 1; // Shift into place.
1185
+ if (s & 1) // Sign?
1186
+ v *= p.slope;
1187
+ if (s & 2) // Clamp?
1188
+ v = 0.f;
1189
+ }
1190
+
1191
+ *pv = (T)v; // Write value.
1192
+ }
1193
+ }
1194
+ else
1195
+ {
1196
+ // Forward pass with no sign write. Process value if in p.x.
1197
+ if (x < p.xShape.x) // y is always in.
1198
+ {
1199
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1200
+ T* pv = ((T*)p.x) + ix;
1201
+ scalar_t v = (scalar_t)(*pv);
1202
+ v *= p.gain;
1203
+ if (v < 0.f)
1204
+ v *= p.slope;
1205
+ if (fabsf(v) > p.clamp)
1206
+ v = InternalType<T>::clamp(v, p.clamp);
1207
+ *pv = (T)v; // Write value.
1208
+ }
1209
+ }
1210
+ }
1211
+ }
1212
+
1213
+ template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void)
1214
+ {
1215
+ return (void*)filtered_lrelu_act_kernel<T, signWrite, signRead>;
1216
+ }
1217
+
1218
+ //------------------------------------------------------------------------
1219
+ // CUDA kernel selection.
1220
+
1221
+ template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB)
1222
+ {
1223
+ filtered_lrelu_kernel_spec s = { 0 };
1224
+
1225
+ // Return the first matching kernel.
1226
+ #define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \
1227
+ if (sharedKB >= SH) \
1228
+ if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \
1229
+ if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \
1230
+ if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \
1231
+ { \
1232
+ static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \
1233
+ static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \
1234
+ static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \
1235
+ s.setup = (void*)setup_filters_kernel; \
1236
+ s.exec = (void*)filtered_lrelu_kernel<T, index_t, SH, signWrite, signRead, MODE, U, FU, D, FD, TW, TH, W*32, !!XR, !!WS>; \
1237
+ s.tileOut = make_int2(TW, TH); \
1238
+ s.numWarps = W; \
1239
+ s.xrep = XR; \
1240
+ s.dynamicSharedKB = (SH == 48) ? 0 : SH; \
1241
+ return s; \
1242
+ }
1243
+
1244
+ // Launch parameters for various kernel specializations.
1245
+ // Small filters must be listed before large filters, otherwise the kernel for the larger filter will always match first.
1246
+ // Kernels that use more shared memory must be listed before those that use less, for the same reason.
1247
+
1248
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1
1249
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1
1250
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2
1251
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2
1252
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2
1253
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2
1254
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2
1255
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2
1256
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4
1257
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4
1258
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1
1259
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2
1260
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2
1261
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2
1262
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2
1263
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2
1264
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2
1265
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4
1266
+ CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB
1267
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4
1268
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1
1269
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2
1270
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2
1271
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2
1272
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2
1273
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2
1274
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2
1275
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB
1276
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4
1277
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB
1278
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4
1279
+
1280
+ #undef CASE
1281
+ return s; // No kernel found.
1282
+ }
1283
+
1284
+ //------------------------------------------------------------------------
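The CASE() table above is a first-match lookup, which is why small filters and high-shared-memory variants must be listed first. A minimal Python sketch of that selection idea, using purely illustrative spec tuples rather than the real launch parameters:

# First-match kernel selection, mirroring the intent of the CASE() macro above.
# The tuples are hypothetical (up, max_fu, down, max_fd, shared_kb) limits;
# the real entries also carry tile sizes, warp counts, xrep, etc.
SPECS = [
    (1, 1,  1, 1,  48),   # 1x1 up / 1x1 down
    (2, 8,  2, 8,  48),   # small separable filters
    (2, 16, 2, 16, 48),   # larger filters deliberately listed later
]

def choose_spec(up, fu_size, down, fd_size, shared_kb):
    # Return the first spec whose limits cover the request; listing small
    # filters first keeps them from being shadowed by the larger variants.
    for spec in SPECS:
        s_up, s_fu, s_down, s_fd, s_kb = spec
        if shared_kb >= s_kb and up == s_up and down == s_down \
                and fu_size <= s_fu and fd_size <= s_fd:
            return spec
    return None  # no kernel found -> caller falls back to the generic path

print(choose_spec(up=2, fu_size=6, down=2, fd_size=6, shared_kb=48))  # -> (2, 8, 2, 8, 48)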
torch_utils/ops/filtered_lrelu.h ADDED
@@ -0,0 +1,90 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <cuda_runtime.h>
10
+
11
+ //------------------------------------------------------------------------
12
+ // CUDA kernel parameters.
13
+
14
+ struct filtered_lrelu_kernel_params
15
+ {
16
+ // These parameters decide which kernel to use.
17
+ int up; // upsampling ratio (1, 2, 4)
18
+ int down; // downsampling ratio (1, 2, 4)
19
+ int2 fuShape; // [size, 1] | [size, size]
20
+ int2 fdShape; // [size, 1] | [size, size]
21
+
22
+ int _dummy; // Alignment.
23
+
24
+ // Rest of the parameters.
25
+ const void* x; // Input tensor.
26
+ void* y; // Output tensor.
27
+ const void* b; // Bias tensor.
28
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
29
+ const float* fu; // Upsampling filter.
30
+ const float* fd; // Downsampling filter.
31
+
32
+ int2 pad0; // Left/top padding.
33
+ float gain; // Additional gain factor.
34
+ float slope; // Leaky ReLU slope on negative side.
35
+ float clamp; // Clamp after nonlinearity.
36
+ int flip; // Filter kernel flip for gradient computation.
37
+
38
+ int tilesXdim; // Original number of horizontal output tiles.
39
+ int tilesXrep; // Number of horizontal tiles per CTA.
40
+ int blockZofs; // Block z offset to support large minibatch, channel dimensions.
41
+
42
+ int4 xShape; // [width, height, channel, batch]
43
+ int4 yShape; // [width, height, channel, batch]
44
+ int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
45
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
46
+ int swLimit; // Active width of sign tensor in bytes.
47
+
48
+ longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
49
+ longlong4 yStride; //
50
+ int64_t bStride; //
51
+ longlong3 fuStride; //
52
+ longlong3 fdStride; //
53
+ };
54
+
55
+ struct filtered_lrelu_act_kernel_params
56
+ {
57
+ void* x; // Input/output, modified in-place.
58
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
59
+
60
+ float gain; // Additional gain factor.
61
+ float slope; // Leaky ReLU slope on negative side.
62
+ float clamp; // Clamp after nonlinearity.
63
+
64
+ int4 xShape; // [width, height, channel, batch]
65
+ longlong4 xStride; // Input/output tensor strides, same order as in shape.
66
+ int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
67
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
68
+ };
69
+
70
+ //------------------------------------------------------------------------
71
+ // CUDA kernel specialization.
72
+
73
+ struct filtered_lrelu_kernel_spec
74
+ {
75
+ void* setup; // Function for filter kernel setup.
76
+ void* exec; // Function for main operation.
77
+ int2 tileOut; // Width/height of launch tile.
78
+ int numWarps; // Number of warps per thread block, determines launch block size.
79
+ int xrep; // For processing multiple horizontal tiles per thread block.
80
+ int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
81
+ };
82
+
83
+ //------------------------------------------------------------------------
84
+ // CUDA kernel selection.
85
+
86
+ template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
87
+ template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
88
+ template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
89
+
90
+ //------------------------------------------------------------------------
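The sign buffer described above stores two bits per upsampled value (bit 0: negative input, so the slope applies; bit 1: clamped, so the gradient is zero), packed four values per byte, which is why sShape.x is given in bytes for the main kernel. A minimal Python sketch of that bit layout, assuming the same packing as the read path of filtered_lrelu_act_kernel; unpack_signs is a hypothetical helper, not part of this file:

import numpy as np

def unpack_signs(sign_bytes, width):
    # Each byte holds four 2-bit codes; the shift matches the kernel's
    # "s >>= (sx & 3) << 1" read path.
    codes = np.empty(width, dtype=np.uint8)
    for sx in range(width):
        s = sign_bytes[sx >> 2] >> ((sx & 3) << 1)
        codes[sx] = s & 3
    return codes

packed = np.array([0b01_10_00_11], dtype=np.uint8)  # four 2-bit codes in one byte
print(unpack_signs(packed, 4))                      # -> [3 0 2 1]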
torch_utils/ops/filtered_lrelu.py ADDED
@@ -0,0 +1,282 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import numpy as np
11
+ import torch
12
+ import warnings
13
+
14
+ from .. import custom_ops
15
+ from .. import misc
16
+ from . import upfirdn2d
17
+ from . import bias_act
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ _plugin = None
22
+
23
+ def _init():
24
+ global _plugin
25
+ if _plugin is None:
26
+
27
+ # sources=['filtered_lrelu.h', 'filtered_lrelu.cu', 'filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu']
28
+ # sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
29
+ # try:
30
+ # _plugin = custom_ops.get_plugin('filtered_lrelu_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'])
31
+ # except:
32
+ # warnings.warn('Failed to build CUDA kernels for filtered_lrelu_plugin. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
33
+
34
+ _plugin = custom_ops.get_plugin_v3(
35
+ module_name='filtered_lrelu_plugin',
36
+ sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
37
+ headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
38
+ source_dir=os.path.dirname(__file__),
39
+ extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
40
+ )
41
+ return True
42
+
43
+ def _get_filter_size(f):
44
+ if f is None:
45
+ return 1, 1
46
+ assert isinstance(f, torch.Tensor)
47
+ assert 1 <= f.ndim <= 2
48
+ return f.shape[-1], f.shape[0] # width, height
49
+
50
+ def _parse_padding(padding):
51
+ if isinstance(padding, int):
52
+ padding = [padding, padding]
53
+ assert isinstance(padding, (list, tuple))
54
+ assert all(isinstance(x, (int, np.integer)) for x in padding)
55
+ padding = [int(x) for x in padding]
56
+ if len(padding) == 2:
57
+ px, py = padding
58
+ padding = [px, px, py, py]
59
+ px0, px1, py0, py1 = padding
60
+ return px0, px1, py0, py1
61
+
62
+ #----------------------------------------------------------------------------
63
+
64
+ def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
65
+ r"""Filtered leaky ReLU for a batch of 2D images.
66
+
67
+ Performs the following sequence of operations for each channel:
68
+
69
+ 1. Add channel-specific bias if provided (`b`).
70
+
71
+ 2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
72
+
73
+ 3. Pad the image with the specified number of zeros on each side (`padding`).
74
+ Negative padding corresponds to cropping the image.
75
+
76
+ 4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
77
+ so that the footprint of all output pixels lies within the input image.
78
+
79
+ 5. Multiply each value by the provided gain factor (`gain`).
80
+
81
+ 6. Apply leaky ReLU activation function to each value.
82
+
83
+ 7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.
84
+
85
+ 8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
86
+ it so that the footprint of all output pixels lies within the input image.
87
+
88
+ 9. Downsample the image by keeping every Nth pixel (`down`).
89
+
90
+ The fused op is considerably more efficient than performing the same calculation
91
+ using standard PyTorch ops. It supports gradients of arbitrary order.
92
+
93
+ Args:
94
+ x: Float32/float16/float64 input tensor of the shape
95
+ `[batch_size, num_channels, in_height, in_width]`.
96
+ fu: Float32 upsampling FIR filter of the shape
97
+ `[filter_height, filter_width]` (non-separable),
98
+ `[filter_taps]` (separable), or
99
+ `None` (identity).
100
+ fd: Float32 downsampling FIR filter of the shape
101
+ `[filter_height, filter_width]` (non-separable),
102
+ `[filter_taps]` (separable), or
103
+ `None` (identity).
104
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
105
+ as `x`. The length of the vector must match the channel dimension of `x`.
106
+ up: Integer upsampling factor (default: 1).
107
+ down: Integer downsampling factor. (default: 1).
108
+ padding: Padding with respect to the upsampled image. Can be a single number
109
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
110
+ (default: 0).
111
+ gain: Overall scaling factor for signal magnitude (default: sqrt(2)).
112
+ slope: Slope on the negative side of leaky ReLU (default: 0.2).
113
+ clamp: Maximum magnitude for leaky ReLU output (default: None).
114
+ flip_filter: False = convolution, True = correlation (default: False).
115
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
116
+
117
+ Returns:
118
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
119
+ """
120
+ assert isinstance(x, torch.Tensor)
121
+ assert impl in ['ref', 'cuda']
122
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
123
+ return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
124
+ return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
125
+
126
+ #----------------------------------------------------------------------------
127
+
128
+ @misc.profiled_function
129
+ def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
130
+ """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
131
+ existing `upfirdn2d()` and `bias_act()` ops.
132
+ """
133
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
134
+ fu_w, fu_h = _get_filter_size(fu)
135
+ fd_w, fd_h = _get_filter_size(fd)
136
+ if b is not None:
137
+ assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
138
+ misc.assert_shape(b, [x.shape[1]])
139
+ assert isinstance(up, int) and up >= 1
140
+ assert isinstance(down, int) and down >= 1
141
+ px0, px1, py0, py1 = _parse_padding(padding)
142
+ assert gain == float(gain) and gain > 0
143
+ assert slope == float(slope) and slope >= 0
144
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
145
+
146
+ # Calculate output size.
147
+ batch_size, channels, in_h, in_w = x.shape
148
+ in_dtype = x.dtype
149
+ out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
150
+ out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down
151
+
152
+ # Compute using existing ops.
153
+ x = bias_act.bias_act(x=x, b=b) # Apply bias.
154
+ x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
155
+ x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
156
+ x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.
157
+
158
+ # Check output shape & dtype.
159
+ misc.assert_shape(x, [batch_size, channels, out_h, out_w])
160
+ assert x.dtype == in_dtype
161
+ return x
162
+
163
+ #----------------------------------------------------------------------------
164
+
165
+ _filtered_lrelu_cuda_cache = dict()
166
+
167
+ def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
168
+ """Fast CUDA implementation of `filtered_lrelu()` using custom ops.
169
+ """
170
+ assert isinstance(up, int) and up >= 1
171
+ assert isinstance(down, int) and down >= 1
172
+ px0, px1, py0, py1 = _parse_padding(padding)
173
+ assert gain == float(gain) and gain > 0
174
+ gain = float(gain)
175
+ assert slope == float(slope) and slope >= 0
176
+ slope = float(slope)
177
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
178
+ clamp = float(clamp if clamp is not None else 'inf')
179
+
180
+ # Lookup from cache.
181
+ key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
182
+ if key in _filtered_lrelu_cuda_cache:
183
+ return _filtered_lrelu_cuda_cache[key]
184
+
185
+ # Forward op.
186
+ class FilteredLReluCuda(torch.autograd.Function):
187
+ @staticmethod
188
+ def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
189
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
190
+
191
+ # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
192
+ if fu is None:
193
+ fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
194
+ if fd is None:
195
+ fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
196
+ assert 1 <= fu.ndim <= 2
197
+ assert 1 <= fd.ndim <= 2
198
+
199
+ # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
200
+ if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
201
+ fu = fu.square()[None]
202
+ if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
203
+ fd = fd.square()[None]
204
+
205
+ # Missing sign input tensor.
206
+ if si is None:
207
+ si = torch.empty([0])
208
+
209
+ # Missing bias tensor.
210
+ if b is None:
211
+ b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)
212
+
213
+ # Construct internal sign tensor only if gradients are needed.
214
+ write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)
215
+
216
+ # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
217
+ strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
218
+ if any(a < b for a, b in zip(strides[:-1], strides[1:])):
219
+ warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)
220
+
221
+ # Call C++/Cuda plugin if datatype is supported.
222
+ if x.dtype in [torch.float16, torch.float32]:
223
+ if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
224
+ warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
225
+ y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
226
+ else:
227
+ return_code = -1
228
+
229
+ # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
230
+ # only the bit-packed sign tensor is retained for gradient computation.
231
+ if return_code < 0:
232
+ warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
233
+
234
+ y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
235
+ y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
236
+ so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
237
+ y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.
238
+
239
+ # Prepare for gradient computation.
240
+ ctx.save_for_backward(fu, fd, (si if si.numel() else so))
241
+ ctx.x_shape = x.shape
242
+ ctx.y_shape = y.shape
243
+ ctx.s_ofs = sx, sy
244
+ return y
245
+
246
+ @staticmethod
247
+ def backward(ctx, dy): # pylint: disable=arguments-differ
248
+ fu, fd, si = ctx.saved_tensors
249
+ _, _, xh, xw = ctx.x_shape
250
+ _, _, yh, yw = ctx.y_shape
251
+ sx, sy = ctx.s_ofs
252
+ dx = None # 0
253
+ dfu = None; assert not ctx.needs_input_grad[1]
254
+ dfd = None; assert not ctx.needs_input_grad[2]
255
+ db = None # 3
256
+ dsi = None; assert not ctx.needs_input_grad[4]
257
+ dsx = None; assert not ctx.needs_input_grad[5]
258
+ dsy = None; assert not ctx.needs_input_grad[6]
259
+
260
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
261
+ pp = [
262
+ (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
263
+ xw * up - yw * down + px0 - (up - 1),
264
+ (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
265
+ xh * up - yh * down + py0 - (up - 1),
266
+ ]
267
+ gg = gain * (up ** 2) / (down ** 2)
268
+ ff = (not flip_filter)
269
+ sx = sx - (fu.shape[-1] - 1) + px0
270
+ sy = sy - (fu.shape[0] - 1) + py0
271
+ dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)
272
+
273
+ if ctx.needs_input_grad[3]:
274
+ db = dx.sum([0, 2, 3])
275
+
276
+ return dx, dfu, dfd, db, dsi, dsx, dsy
277
+
278
+ # Add to cache.
279
+ _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
280
+ return FilteredLReluCuda
281
+
282
+ #----------------------------------------------------------------------------
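For reference, a minimal usage sketch of filtered_lrelu() as documented above; the filter taps are illustrative and the import path assumes the torch_utils package layout added in this commit:

import torch
from torch_utils.ops import filtered_lrelu

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x  = torch.randn(1, 3, 64, 64, device=device)            # [N, C, H, W]
fu = torch.tensor([1., 3., 3., 1.], device=device) / 8.  # illustrative separable upsampling taps
fd = torch.tensor([1., 3., 3., 1.], device=device) / 8.  # illustrative separable downsampling taps
b  = torch.zeros(3, device=device)                       # per-channel bias

# 2x upsample -> bias + leaky ReLU -> 2x downsample; CPU inputs take the
# slow reference path, CUDA inputs try the compiled plugin first.
y = filtered_lrelu.filtered_lrelu(x, fu=fu, fd=fd, b=b, up=2, down=2, padding=1)
print(y.shape)  # -> torch.Size([1, 3, 62, 62])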
torch_utils/ops/filtered_lrelu_ns.cu ADDED
@@ -0,0 +1,27 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include "filtered_lrelu.cu"
10
+
11
+ // Template/kernel specializations for no signs mode (no gradients required).
12
+
13
+ // Full op, 32-bit indexing.
14
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
+
17
+ // Full op, 64-bit indexing.
18
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
+
21
+ // Activation/signs only for generic variant. 64-bit indexing.
22
+ template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
23
+ template void* choose_filtered_lrelu_act_kernel<float, false, false>(void);
24
+ template void* choose_filtered_lrelu_act_kernel<double, false, false>(void);
25
+
26
+ // Copy filters to constant memory.
27
+ template cudaError_t copy_filters<false, false>(cudaStream_t stream);
torch_utils/ops/filtered_lrelu_rd.cu ADDED
@@ -0,0 +1,27 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include "filtered_lrelu.cu"
10
+
11
+ // Template/kernel specializations for sign read mode.
12
+
13
+ // Full op, 32-bit indexing.
14
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
+
17
+ // Full op, 64-bit indexing.
18
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
+
21
+ // Activation/signs only for generic variant. 64-bit indexing.
22
+ template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
23
+ template void* choose_filtered_lrelu_act_kernel<float, false, true>(void);
24
+ template void* choose_filtered_lrelu_act_kernel<double, false, true>(void);
25
+
26
+ // Copy filters to constant memory.
27
+ template cudaError_t copy_filters<false, true>(cudaStream_t stream);
torch_utils/ops/filtered_lrelu_wr.cu ADDED
@@ -0,0 +1,27 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include "filtered_lrelu.cu"
10
+
11
+ // Template/kernel specializations for sign write mode.
12
+
13
+ // Full op, 32-bit indexing.
14
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
+
17
+ // Full op, 64-bit indexing.
18
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
+
21
+ // Activation/signs only for generic variant. 64-bit indexing.
22
+ template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
23
+ template void* choose_filtered_lrelu_act_kernel<float, true, false>(void);
24
+ template void* choose_filtered_lrelu_act_kernel<double, true, false>(void);
25
+
26
+ // Copy filters to constant memory.
27
+ template cudaError_t copy_filters<true, false>(cudaStream_t stream);
torch_utils/ops/fma.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
12
+
13
+ import torch
14
+
15
+ #----------------------------------------------------------------------------
16
+
17
+ def fma(a, b, c): # => a * b + c
18
+ return _FusedMultiplyAdd.apply(a, b, c)
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
23
+ @staticmethod
24
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
25
+ out = torch.addcmul(c, a, b)
26
+ ctx.save_for_backward(a, b)
27
+ ctx.c_shape = c.shape
28
+ return out
29
+
30
+ @staticmethod
31
+ def backward(ctx, dout): # pylint: disable=arguments-differ
32
+ a, b = ctx.saved_tensors
33
+ c_shape = ctx.c_shape
34
+ da = None
35
+ db = None
36
+ dc = None
37
+
38
+ if ctx.needs_input_grad[0]:
39
+ da = _unbroadcast(dout * b, a.shape)
40
+
41
+ if ctx.needs_input_grad[1]:
42
+ db = _unbroadcast(dout * a, b.shape)
43
+
44
+ if ctx.needs_input_grad[2]:
45
+ dc = _unbroadcast(dout, c_shape)
46
+
47
+ return da, db, dc
48
+
49
+ #----------------------------------------------------------------------------
50
+
51
+ def _unbroadcast(x, shape):
52
+ extra_dims = x.ndim - len(shape)
53
+ assert extra_dims >= 0
54
+ dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
55
+ if len(dim):
56
+ x = x.sum(dim=dim, keepdim=True)
57
+ if extra_dims:
58
+ x = x.reshape(-1, *x.shape[extra_dims+1:])
59
+ assert x.shape == shape
60
+ return x
61
+
62
+ #----------------------------------------------------------------------------
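A quick sanity check of the op above: fma(a, b, c) equals a * b + c in the forward pass, and the backward pass un-broadcasts gradients to the original input shapes (minimal sketch, import path assumed from this commit's layout):

import torch
from torch_utils.ops import fma

a = torch.randn(4, 3, requires_grad=True)
b = torch.randn(4, 3, requires_grad=True)
c = torch.randn(1, 3, requires_grad=True)   # broadcast along the first dim

out = fma.fma(a, b, c)
assert torch.allclose(out, a * b + c)        # forward matches the plain expression

out.sum().backward()
print(a.grad.shape, b.grad.shape, c.grad.shape)  # gradients keep the input shapes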
torch_utils/ops/grid_sample_gradfix.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ """Custom replacement for `torch.nn.functional.grid_sample` that
12
+ supports arbitrarily high order gradients between the input and output.
13
+ Only works on 2D images and assumes
14
+ `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
15
+
16
+ import warnings
17
+ import torch
18
+
19
+ # pylint: disable=redefined-builtin
20
+ # pylint: disable=arguments-differ
21
+ # pylint: disable=protected-access
22
+
23
+ #----------------------------------------------------------------------------
24
+
25
+ enabled = False # Enable the custom op by setting this to true.
26
+
27
+ #----------------------------------------------------------------------------
28
+
29
+ def grid_sample(input, grid):
30
+ if _should_use_custom_op():
31
+ return _GridSample2dForward.apply(input, grid)
32
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
33
+
34
+ #----------------------------------------------------------------------------
35
+
36
+ def _should_use_custom_op():
37
+ if not enabled:
38
+ return False
39
+ if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
40
+ return True
41
+ warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
42
+ return False
43
+
44
+ #----------------------------------------------------------------------------
45
+
46
+ class _GridSample2dForward(torch.autograd.Function):
47
+ @staticmethod
48
+ def forward(ctx, input, grid):
49
+ assert input.ndim == 4
50
+ assert grid.ndim == 4
51
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
52
+ ctx.save_for_backward(input, grid)
53
+ return output
54
+
55
+ @staticmethod
56
+ def backward(ctx, grad_output):
57
+ input, grid = ctx.saved_tensors
58
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
59
+ return grad_input, grad_grid
60
+
61
+ #----------------------------------------------------------------------------
62
+
63
+ class _GridSample2dBackward(torch.autograd.Function):
64
+ @staticmethod
65
+ def forward(ctx, grad_output, input, grid):
66
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
67
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
68
+ ctx.save_for_backward(grid)
69
+ return grad_input, grad_grid
70
+
71
+ @staticmethod
72
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
73
+ _ = grad2_grad_grid # unused
74
+ grid, = ctx.saved_tensors
75
+ grad2_grad_output = None
76
+ grad2_input = None
77
+ grad2_grid = None
78
+
79
+ if ctx.needs_input_grad[0]:
80
+ grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
81
+
82
+ assert not ctx.needs_input_grad[2]
83
+ return grad2_grad_output, grad2_input, grad2_grid
84
+
85
+ #----------------------------------------------------------------------------
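The module above is opt-in: the custom double-backward path is used only when `enabled` is set and the PyTorch version matches; otherwise it delegates to torch.nn.functional.grid_sample(). A minimal usage sketch under those assumptions:

import torch
from torch_utils.ops import grid_sample_gradfix

grid_sample_gradfix.enabled = True   # opt in; unsupported PyTorch versions fall back with a warning

input = torch.randn(1, 3, 8, 8, requires_grad=True)
theta = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]])   # identity affine transform
grid  = torch.nn.functional.affine_grid(theta, size=(1, 3, 8, 8), align_corners=False)

out = grid_sample_gradfix.grid_sample(input, grid)     # [N, C, H, W] -> [N, C, H, W]
out.sum().backward()                                   # gradients flow back to `input`
print(out.shape, input.grad.shape)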
torch_utils/ops/upfirdn2d.cpp ADDED
@@ -0,0 +1,105 @@
1
+ // Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ //
5
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ // and proprietary rights in and to this software, related documentation
7
+ // and any modifications thereto. Any use, reproduction, disclosure or
8
+ // distribution of this software and related documentation without an express
9
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ #include <torch/extension.h>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+ #include <c10/cuda/CUDAGuard.h>
14
+ #include "upfirdn2d.h"
15
+
16
+ //------------------------------------------------------------------------
17
+
18
+ static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
19
+ {
20
+ // Validate arguments.
21
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
22
+ TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
23
+ TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
24
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
25
+ TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
26
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
27
+ TORCH_CHECK(f.dim() == 2, "f must be rank 2");
28
+ TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
29
+ TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
30
+ TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
31
+
32
+ // Create output tensor.
33
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
34
+ int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
35
+ int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
36
+ TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
37
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
38
+ TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
39
+
40
+ // Initialize CUDA kernel parameters.
41
+ upfirdn2d_kernel_params p;
42
+ p.x = x.data_ptr();
43
+ p.f = f.data_ptr<float>();
44
+ p.y = y.data_ptr();
45
+ p.up = make_int2(upx, upy);
46
+ p.down = make_int2(downx, downy);
47
+ p.pad0 = make_int2(padx0, pady0);
48
+ p.flip = (flip) ? 1 : 0;
49
+ p.gain = gain;
50
+ p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
51
+ p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
52
+ p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
53
+ p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
54
+ p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
55
+ p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
56
+ p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
57
+ p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
58
+
59
+ // Choose CUDA kernel.
60
+ upfirdn2d_kernel_spec spec;
61
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
62
+ {
63
+ spec = choose_upfirdn2d_kernel<scalar_t>(p);
64
+ });
65
+
66
+ // Set looping options.
67
+ p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
68
+ p.loopMinor = spec.loopMinor;
69
+ p.loopX = spec.loopX;
70
+ p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
71
+ p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
72
+
73
+ // Compute grid size.
74
+ dim3 blockSize, gridSize;
75
+ if (spec.tileOutW < 0) // large
76
+ {
77
+ blockSize = dim3(4, 32, 1);
78
+ gridSize = dim3(
79
+ ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
80
+ (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
81
+ p.launchMajor);
82
+ }
83
+ else // small
84
+ {
85
+ blockSize = dim3(256, 1, 1);
86
+ gridSize = dim3(
87
+ ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
88
+ (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
89
+ p.launchMajor);
90
+ }
91
+
92
+ // Launch CUDA kernel.
93
+ void* args[] = {&p};
94
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
95
+ return y;
96
+ }
97
+
98
+ //------------------------------------------------------------------------
99
+
100
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
101
+ {
102
+ m.def("upfirdn2d", &upfirdn2d);
103
+ }
104
+
105
+ //------------------------------------------------------------------------
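The output size above follows outW = (inW * upx + padx0 + padx1 - filterW + downx) / downx, and likewise for the height. A small worked check of that arithmetic (plain Python, illustrative values):

def upfirdn2d_out_size(in_size, up, pad0, pad1, filter_size, down):
    # Same integer arithmetic as the outW/outH computation in upfirdn2d.cpp.
    return (in_size * up + pad0 + pad1 - filter_size + down) // down

# 2x upsampling with a 4-tap filter and 1 pixel of padding on each side:
print(upfirdn2d_out_size(in_size=64, up=2, pad0=1, pad1=1, filter_size=4, down=1))   # -> 127
# Plain 2x downsampling with the same filter and no padding:
print(upfirdn2d_out_size(in_size=127, up=1, pad0=0, pad1=0, filter_size=4, down=2))  # -> 62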
torch_utils/ops/upfirdn2d.cu ADDED
@@ -0,0 +1,352 @@
1
+ // Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ //
5
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ // and proprietary rights in and to this software, related documentation
7
+ // and any modifications thereto. Any use, reproduction, disclosure or
8
+ // distribution of this software and related documentation without an express
9
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ #include <c10/util/Half.h>
12
+ #include "upfirdn2d.h"
13
+
14
+ //------------------------------------------------------------------------
15
+ // Helpers.
16
+
17
+ template <class T> struct InternalType;
18
+ template <> struct InternalType<double> { typedef double scalar_t; };
19
+ template <> struct InternalType<float> { typedef float scalar_t; };
20
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
21
+
22
+ static __device__ __forceinline__ int floor_div(int a, int b)
23
+ {
24
+ int t = 1 - a / b;
25
+ return (a + t * b) / b - t;
26
+ }
27
+
28
+ //------------------------------------------------------------------------
29
+ // Generic CUDA implementation for large filters.
30
+
31
+ template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
32
+ {
33
+ typedef typename InternalType<T>::scalar_t scalar_t;
34
+
35
+ // Calculate thread index.
36
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
37
+ int outY = minorBase / p.launchMinor;
38
+ minorBase -= outY * p.launchMinor;
39
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
40
+ int majorBase = blockIdx.z * p.loopMajor;
41
+ if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
42
+ return;
43
+
44
+ // Setup Y receptive field.
45
+ int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
46
+ int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
47
+ int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
48
+ int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
49
+ if (p.flip)
50
+ filterY = p.filterSize.y - 1 - filterY;
51
+
52
+ // Loop over major, minor, and X.
53
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
54
+ for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
55
+ {
56
+ int nc = major * p.sizeMinor + minor;
57
+ int n = nc / p.inSize.z;
58
+ int c = nc - n * p.inSize.z;
59
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
60
+ {
61
+ // Setup X receptive field.
62
+ int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
63
+ int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
64
+ int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
65
+ int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
66
+ if (p.flip)
67
+ filterX = p.filterSize.x - 1 - filterX;
68
+
69
+ // Initialize pointers.
70
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
71
+ const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
72
+ int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
73
+ int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
74
+
75
+ // Inner loop.
76
+ scalar_t v = 0;
77
+ for (int y = 0; y < h; y++)
78
+ {
79
+ for (int x = 0; x < w; x++)
80
+ {
81
+ v += (scalar_t)(*xp) * (scalar_t)(*fp);
82
+ xp += p.inStride.x;
83
+ fp += filterStepX;
84
+ }
85
+ xp += p.inStride.y - w * p.inStride.x;
86
+ fp += filterStepY - w * filterStepX;
87
+ }
88
+
89
+ // Store result.
90
+ v *= p.gain;
91
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
92
+ }
93
+ }
94
+ }
95
+
96
+ //------------------------------------------------------------------------
97
+ // Specialized CUDA implementation for small filters.
98
+
99
+ template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
100
+ static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
101
+ {
102
+ typedef typename InternalType<T>::scalar_t scalar_t;
103
+ const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
104
+ const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
105
+ __shared__ volatile scalar_t sf[filterH][filterW];
106
+ __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
107
+
108
+ // Calculate tile index.
109
+ int minorBase = blockIdx.x;
110
+ int tileOutY = minorBase / p.launchMinor;
111
+ minorBase -= tileOutY * p.launchMinor;
112
+ minorBase *= loopMinor;
113
+ tileOutY *= tileOutH;
114
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
115
+ int majorBase = blockIdx.z * p.loopMajor;
116
+ if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
117
+ return;
118
+
119
+ // Load filter (flipped).
120
+ for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
121
+ {
122
+ int fy = tapIdx / filterW;
123
+ int fx = tapIdx - fy * filterW;
124
+ scalar_t v = 0;
125
+ if (fx < p.filterSize.x & fy < p.filterSize.y)
126
+ {
127
+ int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
128
+ int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
129
+ v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
130
+ }
131
+ sf[fy][fx] = v;
132
+ }
133
+
134
+ // Loop over major and X.
135
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
136
+ {
137
+ int baseNC = major * p.sizeMinor + minorBase;
138
+ int n = baseNC / p.inSize.z;
139
+ int baseC = baseNC - n * p.inSize.z;
140
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
141
+ {
142
+ // Load input pixels.
143
+ int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
144
+ int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
145
+ int tileInX = floor_div(tileMidX, upx);
146
+ int tileInY = floor_div(tileMidY, upy);
147
+ __syncthreads();
148
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
149
+ {
150
+ int relC = inIdx;
151
+ int relInX = relC / loopMinor;
152
+ int relInY = relInX / tileInW;
153
+ relC -= relInX * loopMinor;
154
+ relInX -= relInY * tileInW;
155
+ int c = baseC + relC;
156
+ int inX = tileInX + relInX;
157
+ int inY = tileInY + relInY;
158
+ scalar_t v = 0;
159
+ if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
160
+ v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
161
+ sx[relInY][relInX][relC] = v;
162
+ }
163
+
164
+ // Loop over output pixels.
165
+ __syncthreads();
166
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
167
+ {
168
+ int relC = outIdx;
169
+ int relOutX = relC / loopMinor;
170
+ int relOutY = relOutX / tileOutW;
171
+ relC -= relOutX * loopMinor;
172
+ relOutX -= relOutY * tileOutW;
173
+ int c = baseC + relC;
174
+ int outX = tileOutX + relOutX;
175
+ int outY = tileOutY + relOutY;
176
+
177
+ // Setup receptive field.
178
+ int midX = tileMidX + relOutX * downx;
179
+ int midY = tileMidY + relOutY * downy;
180
+ int inX = floor_div(midX, upx);
181
+ int inY = floor_div(midY, upy);
182
+ int relInX = inX - tileInX;
183
+ int relInY = inY - tileInY;
184
+ int filterX = (inX + 1) * upx - midX - 1; // flipped
185
+ int filterY = (inY + 1) * upy - midY - 1; // flipped
186
+
187
+ // Inner loop.
188
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
189
+ {
190
+ scalar_t v = 0;
191
+ #pragma unroll
192
+ for (int y = 0; y < filterH / upy; y++)
193
+ #pragma unroll
194
+ for (int x = 0; x < filterW / upx; x++)
195
+ v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
196
+ v *= p.gain;
197
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
198
+ }
199
+ }
200
+ }
201
+ }
202
+ }
203
+
204
+ //------------------------------------------------------------------------
205
+ // CUDA kernel selection.
206
+
207
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
208
+ {
209
+ int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
210
+
211
+ upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
212
+ if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
213
+
214
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
215
+ {
216
+ if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
217
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
218
+ if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
219
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
220
+ if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
221
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
222
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
223
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
224
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
225
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
226
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
227
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
228
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
229
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
230
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
231
+ }
232
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
233
+ {
234
+ if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
235
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
236
+ if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
237
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
238
+ if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
239
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
240
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
241
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
242
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
243
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
244
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
245
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
246
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
247
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
248
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
249
+ }
250
+ if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
251
+ {
252
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
253
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
254
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
255
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
256
+ }
257
+ if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
258
+ {
259
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
260
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
261
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
262
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
263
+ }
264
+ if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
265
+ {
266
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
267
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
268
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
269
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
270
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
271
+ }
272
+ if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
273
+ {
274
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
275
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
276
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
277
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
278
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
279
+ }
280
+ if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
281
+ {
282
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
283
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
284
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
285
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
286
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
287
+ }
288
+ if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
289
+ {
290
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
291
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
292
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
293
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
294
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
295
+ }
296
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
297
+ {
298
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
299
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
300
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
301
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
302
+ }
303
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
304
+ {
305
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
306
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
307
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
308
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
309
+ }
310
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
311
+ {
312
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
313
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
314
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
315
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
316
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
317
+ }
318
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
319
+ {
320
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
321
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
322
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
323
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
324
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
325
+ }
326
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
327
+ {
328
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
329
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
330
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
331
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
332
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
333
+ }
334
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
335
+ {
336
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
337
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
338
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
339
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
340
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
341
+ }
342
+ return spec;
343
+ }
344
+
345
+ //------------------------------------------------------------------------
346
+ // Template specializations.
347
+
348
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
349
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
350
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
351
+
352
+ //------------------------------------------------------------------------
torch_utils/ops/upfirdn2d.h ADDED
@@ -0,0 +1,61 @@
1
+ // Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ //
5
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ // and proprietary rights in and to this software, related documentation
7
+ // and any modifications thereto. Any use, reproduction, disclosure or
8
+ // distribution of this software and related documentation without an express
9
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ #include <cuda_runtime.h>
12
+
13
+ //------------------------------------------------------------------------
14
+ // CUDA kernel parameters.
15
+
16
+ struct upfirdn2d_kernel_params
17
+ {
18
+ const void* x;
19
+ const float* f;
20
+ void* y;
21
+
22
+ int2 up;
23
+ int2 down;
24
+ int2 pad0;
25
+ int flip;
26
+ float gain;
27
+
28
+ int4 inSize; // [width, height, channel, batch]
29
+ int4 inStride;
30
+ int2 filterSize; // [width, height]
31
+ int2 filterStride;
32
+ int4 outSize; // [width, height, channel, batch]
33
+ int4 outStride;
34
+ int sizeMinor;
35
+ int sizeMajor;
36
+
37
+ int loopMinor;
38
+ int loopMajor;
39
+ int loopX;
40
+ int launchMinor;
41
+ int launchMajor;
42
+ };
43
+
44
+ //------------------------------------------------------------------------
45
+ // CUDA kernel specialization.
46
+
47
+ struct upfirdn2d_kernel_spec
48
+ {
49
+ void* kernel;
50
+ int tileOutW;
51
+ int tileOutH;
52
+ int loopMinor;
53
+ int loopX;
54
+ };
55
+
56
+ //------------------------------------------------------------------------
57
+ // CUDA kernel selection.
58
+
59
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
60
+
61
+ //------------------------------------------------------------------------
torch_utils/ops/upfirdn2d.py ADDED
@@ -0,0 +1,386 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ """Custom PyTorch ops for efficient resampling of 2D images."""
12
+
13
+ import os
14
+ import warnings
15
+ import numpy as np
16
+ import torch
17
+ import traceback
18
+
19
+ from .. import custom_ops
20
+ from .. import misc
21
+ from . import conv2d_gradfix
22
+
23
+ #----------------------------------------------------------------------------
24
+
25
+ _inited = False
26
+ _plugin = None
27
+
28
+ def _init():
29
+ global _inited, _plugin
30
+ if not _inited:
31
+ sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
32
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
33
+ try:
34
+ _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
35
+ except:
36
+ warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
37
+ return _plugin is not None
38
+
39
+ def _parse_scaling(scaling):
40
+ if isinstance(scaling, int):
41
+ scaling = [scaling, scaling]
42
+ assert isinstance(scaling, (list, tuple))
43
+ assert all(isinstance(x, int) for x in scaling)
44
+ sx, sy = scaling
45
+ assert sx >= 1 and sy >= 1
46
+ return sx, sy
47
+
48
+ def _parse_padding(padding):
49
+ if isinstance(padding, int):
50
+ padding = [padding, padding]
51
+ assert isinstance(padding, (list, tuple))
52
+ assert all(isinstance(x, int) for x in padding)
53
+ if len(padding) == 2:
54
+ padx, pady = padding
55
+ padding = [padx, padx, pady, pady]
56
+ padx0, padx1, pady0, pady1 = padding
57
+ return padx0, padx1, pady0, pady1
58
+
59
+ def _get_filter_size(f):
60
+ if f is None:
61
+ return 1, 1
62
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
63
+ fw = f.shape[-1]
64
+ fh = f.shape[0]
65
+ with misc.suppress_tracer_warnings():
66
+ fw = int(fw)
67
+ fh = int(fh)
68
+ misc.assert_shape(f, [fh, fw][:f.ndim])
69
+ assert fw >= 1 and fh >= 1
70
+ return fw, fh
71
+
72
+ #----------------------------------------------------------------------------
73
+
74
+ def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
75
+ r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
76
+
77
+ Args:
78
+ f: Torch tensor, numpy array, or python list of the shape
79
+ `[filter_height, filter_width]` (non-separable),
80
+ `[filter_taps]` (separable),
81
+ `[]` (impulse), or
82
+ `None` (identity).
83
+ device: Result device (default: cpu).
84
+ normalize: Normalize the filter so that it retains the magnitude
85
+ for constant input signal (DC)? (default: True).
86
+ flip_filter: Flip the filter? (default: False).
87
+ gain: Overall scaling factor for signal magnitude (default: 1).
88
+ separable: Return a separable filter? (default: select automatically).
89
+
90
+ Returns:
91
+ Float32 tensor of the shape
92
+ `[filter_height, filter_width]` (non-separable) or
93
+ `[filter_taps]` (separable).
94
+ """
95
+ # Validate.
96
+ if f is None:
97
+ f = 1
98
+ f = torch.as_tensor(f, dtype=torch.float32)
99
+ assert f.ndim in [0, 1, 2]
100
+ assert f.numel() > 0
101
+ if f.ndim == 0:
102
+ f = f[np.newaxis]
103
+
104
+ # Separable?
105
+ if separable is None:
106
+ separable = (f.ndim == 1 and f.numel() >= 8)
107
+ if f.ndim == 1 and not separable:
108
+ f = f.ger(f)
109
+ assert f.ndim == (1 if separable else 2)
110
+
111
+ # Apply normalize, flip, gain, and device.
112
+ if normalize:
113
+ f /= f.sum()
114
+ if flip_filter:
115
+ f = f.flip(list(range(f.ndim)))
116
+ f = f * (gain ** (f.ndim / 2))
117
+ f = f.to(device=device)
118
+ return f
119
+
120
+ #----------------------------------------------------------------------------
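For reference, a minimal usage sketch of `setup_filter()`; the `[1, 3, 3, 1]` binomial taps are an illustrative choice, not something the API requires:

    from torch_utils.ops import upfirdn2d

    f = upfirdn2d.setup_filter([1, 3, 3, 1])                      # fewer than 8 taps -> non-separable
    print(f.shape)                                                # torch.Size([4, 4]): outer product, normalized to sum to 1
    f_sep = upfirdn2d.setup_filter([1, 3, 3, 1], separable=True)
    print(f_sep.shape)                                            # torch.Size([4]): kept as 1-D separable taps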
121
+
122
+ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
123
+ r"""Pad, upsample, filter, and downsample a batch of 2D images.
124
+
125
+ Performs the following sequence of operations for each channel:
126
+
127
+ 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
128
+
129
+ 2. Pad the image with the specified number of zeros on each side (`padding`).
130
+ Negative padding corresponds to cropping the image.
131
+
132
+ 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
133
+ so that the footprint of all output pixels lies within the input image.
134
+
135
+ 4. Downsample the image by keeping every Nth pixel (`down`).
136
+
137
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
138
+ The fused op is considerably more efficient than performing the same calculation
139
+ using standard PyTorch ops. It supports gradients of arbitrary order.
140
+
141
+ Args:
142
+ x: Float32/float64/float16 input tensor of the shape
143
+ `[batch_size, num_channels, in_height, in_width]`.
144
+ f: Float32 FIR filter of the shape
145
+ `[filter_height, filter_width]` (non-separable),
146
+ `[filter_taps]` (separable), or
147
+ `None` (identity).
148
+ up: Integer upsampling factor. Can be a single int or a list/tuple
149
+ `[x, y]` (default: 1).
150
+ down: Integer downsampling factor. Can be a single int or a list/tuple
151
+ `[x, y]` (default: 1).
152
+ padding: Padding with respect to the upsampled image. Can be a single number
153
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
154
+ (default: 0).
155
+ flip_filter: False = convolution, True = correlation (default: False).
156
+ gain: Overall scaling factor for signal magnitude (default: 1).
157
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
158
+
159
+ Returns:
160
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
161
+ """
162
+ assert isinstance(x, torch.Tensor)
163
+ assert impl in ['ref', 'cuda']
164
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
165
+ return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
166
+ return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
167
+
168
+ #----------------------------------------------------------------------------
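A minimal sketch of calling `upfirdn2d()` directly for a 2x upsample; the padding and gain below are chosen by hand to match what `upsample2d()` would compute for a 4-tap filter, and `impl='ref'` keeps the sketch runnable without the CUDA plugin:

    import torch
    from torch_utils.ops import upfirdn2d

    x = torch.randn(1, 3, 16, 16)                                 # [batch, channels, height, width]
    f = upfirdn2d.setup_filter([1, 3, 3, 1])                      # 4x4 low-pass filter
    y = upfirdn2d.upfirdn2d(x, f, up=2, padding=[2, 1, 2, 1], gain=4, impl='ref')
    print(y.shape)                                                # torch.Size([1, 3, 32, 32])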
169
+
170
+ @misc.profiled_function
171
+ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
172
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
173
+ """
174
+ # Validate arguments.
175
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
176
+ if f is None:
177
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
178
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
179
+ assert f.dtype == torch.float32 and not f.requires_grad
180
+ batch_size, num_channels, in_height, in_width = x.shape
181
+ upx, upy = _parse_scaling(up)
182
+ downx, downy = _parse_scaling(down)
183
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
184
+
185
+ # Upsample by inserting zeros.
186
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
187
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
188
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
189
+
190
+ # Pad or crop.
191
+ x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
192
+ x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
193
+
194
+ # Setup filter.
195
+ f = f * (gain ** (f.ndim / 2))
196
+ f = f.to(x.dtype)
197
+ if not flip_filter:
198
+ f = f.flip(list(range(f.ndim)))
199
+
200
+ # Convolve with the filter.
201
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
202
+ if f.ndim == 4:
203
+ x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
204
+ else:
205
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
206
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
207
+
208
+ # Downsample by throwing away pixels.
209
+ x = x[:, :, ::downy, ::downx]
210
+ return x
211
+
212
+ #----------------------------------------------------------------------------
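The output resolution implied by the reference implementation can be written in closed form; the helper below is a sketch that mirrors it (only `int` or 4-tuple padding is handled here):

    def upfirdn2d_output_size(in_h, in_w, fh, fw, up=1, down=1, padding=0):
        # Upsample by zero insertion, pad/crop, valid convolution, then keep every Nth pixel.
        upx, upy = (up, up) if isinstance(up, int) else up
        downx, downy = (down, down) if isinstance(down, int) else down
        padx0, padx1, pady0, pady1 = (padding,) * 4 if isinstance(padding, int) else padding
        out_h = (in_h * upy + pady0 + pady1 - fh) // downy + 1
        out_w = (in_w * upx + padx0 + padx1 - fw) // downx + 1
        return out_h, out_w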
213
+
214
+ _upfirdn2d_cuda_cache = dict()
215
+
216
+ def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
217
+ """Fast CUDA implementation of `upfirdn2d()` using custom ops.
218
+ """
219
+ # Parse arguments.
220
+ upx, upy = _parse_scaling(up)
221
+ downx, downy = _parse_scaling(down)
222
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
223
+
224
+ # Lookup from cache.
225
+ key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
226
+ if key in _upfirdn2d_cuda_cache:
227
+ return _upfirdn2d_cuda_cache[key]
228
+
229
+ # Forward op.
230
+ class Upfirdn2dCuda(torch.autograd.Function):
231
+ @staticmethod
232
+ def forward(ctx, x, f): # pylint: disable=arguments-differ
233
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
234
+ if f is None:
235
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
236
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
237
+ y = x
238
+ if f.ndim == 2:
239
+ y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
240
+ else:
241
+ y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
242
+ y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
243
+ ctx.save_for_backward(f)
244
+ ctx.x_shape = x.shape
245
+ return y
246
+
247
+ @staticmethod
248
+ def backward(ctx, dy): # pylint: disable=arguments-differ
249
+ f, = ctx.saved_tensors
250
+ _, _, ih, iw = ctx.x_shape
251
+ _, _, oh, ow = dy.shape
252
+ fw, fh = _get_filter_size(f)
253
+ p = [
254
+ fw - padx0 - 1,
255
+ iw * upx - ow * downx + padx0 - upx + 1,
256
+ fh - pady0 - 1,
257
+ ih * upy - oh * downy + pady0 - upy + 1,
258
+ ]
259
+ dx = None
260
+ df = None
261
+
262
+ if ctx.needs_input_grad[0]:
263
+ dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
264
+
265
+ assert not ctx.needs_input_grad[1]
266
+ return dx, df
267
+
268
+ # Add to cache.
269
+ _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
270
+ return Upfirdn2dCuda
271
+
272
+ #----------------------------------------------------------------------------
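A brief sketch of how the cached classes are used, assuming the compiled plugin is already available (e.g., after a prior `upfirdn2d(..., impl='cuda')` call), with `x` a CUDA tensor and `f` a filter from `setup_filter()`; calling `_upfirdn2d_cuda()` twice with identical parameters returns the same `Upfirdn2dCuda` class:

    op = _upfirdn2d_cuda(up=2, down=1, padding=[2, 1, 2, 1], gain=4)   # built once, then cached by key
    y = op.apply(x, f)                                                 # runs the compiled plugin
    assert _upfirdn2d_cuda(up=2, down=1, padding=[2, 1, 2, 1], gain=4) is op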
273
+
274
+ def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
275
+ r"""Filter a batch of 2D images using the given 2D FIR filter.
276
+
277
+ By default, the result is padded so that its shape matches the input.
278
+ User-specified padding is applied on top of that, with negative values
279
+ indicating cropping. Pixels outside the image are assumed to be zero.
280
+
281
+ Args:
282
+ x: Float32/float64/float16 input tensor of the shape
283
+ `[batch_size, num_channels, in_height, in_width]`.
284
+ f: Float32 FIR filter of the shape
285
+ `[filter_height, filter_width]` (non-separable),
286
+ `[filter_taps]` (separable), or
287
+ `None` (identity).
288
+ padding: Padding with respect to the output. Can be a single number or a
289
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
290
+ (default: 0).
291
+ flip_filter: False = convolution, True = correlation (default: False).
292
+ gain: Overall scaling factor for signal magnitude (default: 1).
293
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
294
+
295
+ Returns:
296
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
297
+ """
298
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
299
+ fw, fh = _get_filter_size(f)
300
+ p = [
301
+ padx0 + fw // 2,
302
+ padx1 + (fw - 1) // 2,
303
+ pady0 + fh // 2,
304
+ pady1 + (fh - 1) // 2,
305
+ ]
306
+ return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
307
+
308
+ #----------------------------------------------------------------------------
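A short sketch of the shape-preserving behaviour described above; the `[1, 2, 1]` taps are illustrative:

    import torch
    from torch_utils.ops import upfirdn2d

    x = torch.randn(2, 3, 64, 64)
    f = upfirdn2d.setup_filter([1, 2, 1])                         # 3x3 smoothing filter
    y = upfirdn2d.filter2d(x, f, impl='ref')
    print(y.shape)                                                # torch.Size([2, 3, 64, 64]): same as the input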
309
+
310
+ def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
311
+ r"""Upsample a batch of 2D images using the given 2D FIR filter.
312
+
313
+ By default, the result is padded so that its shape is a multiple of the input.
314
+ User-specified padding is applied on top of that, with negative values
315
+ indicating cropping. Pixels outside the image are assumed to be zero.
316
+
317
+ Args:
318
+ x: Float32/float64/float16 input tensor of the shape
319
+ `[batch_size, num_channels, in_height, in_width]`.
320
+ f: Float32 FIR filter of the shape
321
+ `[filter_height, filter_width]` (non-separable),
322
+ `[filter_taps]` (separable), or
323
+ `None` (identity).
324
+ up: Integer upsampling factor. Can be a single int or a list/tuple
325
+ `[x, y]` (default: 1).
326
+ padding: Padding with respect to the output. Can be a single number or a
327
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
328
+ (default: 0).
329
+ flip_filter: False = convolution, True = correlation (default: False).
330
+ gain: Overall scaling factor for signal magnitude (default: 1).
331
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
332
+
333
+ Returns:
334
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
335
+ """
336
+ upx, upy = _parse_scaling(up)
337
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
338
+ fw, fh = _get_filter_size(f)
339
+ p = [
340
+ padx0 + (fw + upx - 1) // 2,
341
+ padx1 + (fw - upx) // 2,
342
+ pady0 + (fh + upy - 1) // 2,
343
+ pady1 + (fh - upy) // 2,
344
+ ]
345
+ return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
346
+
347
+ #----------------------------------------------------------------------------
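A short sketch of `upsample2d()`; the internal `gain*upx*upy` compensates for the zeros inserted during upsampling, so the signal magnitude is preserved:

    import torch
    from torch_utils.ops import upfirdn2d

    x = torch.randn(1, 3, 16, 16)
    f = upfirdn2d.setup_filter([1, 3, 3, 1])
    print(upfirdn2d.upsample2d(x, f, up=2, impl='ref').shape)     # torch.Size([1, 3, 32, 32])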
348
+
349
+ def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
350
+ r"""Downsample a batch of 2D images using the given 2D FIR filter.
351
+
352
+ By default, the result is padded so that its shape is a fraction of the input.
353
+ User-specified padding is applied on top of that, with negative values
354
+ indicating cropping. Pixels outside the image are assumed to be zero.
355
+
356
+ Args:
357
+ x: Float32/float64/float16 input tensor of the shape
358
+ `[batch_size, num_channels, in_height, in_width]`.
359
+ f: Float32 FIR filter of the shape
360
+ `[filter_height, filter_width]` (non-separable),
361
+ `[filter_taps]` (separable), or
362
+ `None` (identity).
363
+ down: Integer downsampling factor. Can be a single int or a list/tuple
364
+ `[x, y]` (default: 2).
365
+ padding: Padding with respect to the input. Can be a single number or a
366
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
367
+ (default: 0).
368
+ flip_filter: False = convolution, True = correlation (default: False).
369
+ gain: Overall scaling factor for signal magnitude (default: 1).
370
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
371
+
372
+ Returns:
373
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
374
+ """
375
+ downx, downy = _parse_scaling(down)
376
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
377
+ fw, fh = _get_filter_size(f)
378
+ p = [
379
+ padx0 + (fw - downx + 1) // 2,
380
+ padx1 + (fw - downx) // 2,
381
+ pady0 + (fh - downy + 1) // 2,
382
+ pady1 + (fh - downy) // 2,
383
+ ]
384
+ return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
385
+
386
+ #----------------------------------------------------------------------------
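And the symmetric sketch for `downsample2d()`, halving the resolution with the same kind of low-pass filter:

    import torch
    from torch_utils.ops import upfirdn2d

    x = torch.randn(1, 3, 32, 32)
    f = upfirdn2d.setup_filter([1, 3, 3, 1])
    print(upfirdn2d.downsample2d(x, f, down=2, impl='ref').shape) # torch.Size([1, 3, 16, 16])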
torch_utils/persistence.py ADDED
@@ -0,0 +1,253 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ """Facilities for pickling Python code alongside other data.
12
+
13
+ The pickled code is automatically imported into a separate Python module
14
+ during unpickling. This way, any previously exported pickles will remain
15
+ usable even if the original code is no longer available, or if the current
16
+ version of the code is not consistent with what was originally pickled."""
17
+
18
+ import sys
19
+ import pickle
20
+ import io
21
+ import inspect
22
+ import copy
23
+ import uuid
24
+ import types
25
+ import dnnlib
26
+
27
+ #----------------------------------------------------------------------------
28
+
29
+ _version = 6 # internal version number
30
+ _decorators = set() # {decorator_class, ...}
31
+ _import_hooks = [] # [hook_function, ...]
32
+ _module_to_src_dict = dict() # {module: src, ...}
33
+ _src_to_module_dict = dict() # {src: module, ...}
34
+
35
+ #----------------------------------------------------------------------------
36
+
37
+ def persistent_class(orig_class):
38
+ r"""Class decorator that extends a given class to save its source code
39
+ when pickled.
40
+
41
+ Example:
42
+
43
+ from torch_utils import persistence
44
+
45
+ @persistence.persistent_class
46
+ class MyNetwork(torch.nn.Module):
47
+ def __init__(self, num_inputs, num_outputs):
48
+ super().__init__()
49
+ self.fc = MyLayer(num_inputs, num_outputs)
50
+ ...
51
+
52
+ @persistence.persistent_class
53
+ class MyLayer(torch.nn.Module):
54
+ ...
55
+
56
+ When pickled, any instance of `MyNetwork` and `MyLayer` will save its
57
+ source code alongside other internal state (e.g., parameters, buffers,
58
+ and submodules). This way, any previously exported pickle will remain
59
+ usable even if the class definitions have been modified or are no
60
+ longer available.
61
+
62
+ The decorator saves the source code of the entire Python module
63
+ containing the decorated class. It does *not* save the source code of
64
+ any imported modules. Thus, the imported modules must be available
65
+ during unpickling, including `torch_utils.persistence` itself.
66
+
67
+ It is ok to call functions defined in the same module from the
68
+ decorated class. However, if the decorated class depends on other
69
+ classes defined in the same module, they must be decorated as well.
70
+ This is illustrated in the above example in the case of `MyLayer`.
71
+
72
+ It is also possible to employ the decorator just-in-time before
73
+ calling the constructor. For example:
74
+
75
+ cls = MyLayer
76
+ if want_to_make_it_persistent:
77
+ cls = persistence.persistent_class(cls)
78
+ layer = cls(num_inputs, num_outputs)
79
+
80
+ As an additional feature, the decorator also keeps track of the
81
+ arguments that were used to construct each instance of the decorated
82
+ class. The arguments can be queried via `obj.init_args` and
83
+ `obj.init_kwargs`, and they are automatically pickled alongside other
84
+ object state. A typical use case is to first unpickle a previous
85
+ instance of a persistent class, and then upgrade it to use the latest
86
+ version of the source code:
87
+
88
+ with open('old_pickle.pkl', 'rb') as f:
89
+ old_net = pickle.load(f)
90
+ new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
91
+ misc.copy_params_and_buffers(old_net, new_net, require_all=True)
92
+ """
93
+ assert isinstance(orig_class, type)
94
+ if is_persistent(orig_class):
95
+ return orig_class
96
+
97
+ assert orig_class.__module__ in sys.modules
98
+ orig_module = sys.modules[orig_class.__module__]
99
+ orig_module_src = _module_to_src(orig_module)
100
+
101
+ class Decorator(orig_class):
102
+ _orig_module_src = orig_module_src
103
+ _orig_class_name = orig_class.__name__
104
+
105
+ def __init__(self, *args, **kwargs):
106
+ super().__init__(*args, **kwargs)
107
+ self._init_args = copy.deepcopy(args)
108
+ self._init_kwargs = copy.deepcopy(kwargs)
109
+ assert orig_class.__name__ in orig_module.__dict__
110
+ _check_pickleable(self.__reduce__())
111
+
112
+ @property
113
+ def init_args(self):
114
+ return copy.deepcopy(self._init_args)
115
+
116
+ @property
117
+ def init_kwargs(self):
118
+ return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
119
+
120
+ def __reduce__(self):
121
+ fields = list(super().__reduce__())
122
+ fields += [None] * max(3 - len(fields), 0)
123
+ if fields[0] is not _reconstruct_persistent_obj:
124
+ meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
125
+ fields[0] = _reconstruct_persistent_obj # reconstruct func
126
+ fields[1] = (meta,) # reconstruct args
127
+ fields[2] = None # state dict
128
+ return tuple(fields)
129
+
130
+ Decorator.__name__ = orig_class.__name__
131
+ _decorators.add(Decorator)
132
+ return Decorator
133
+
134
+ #----------------------------------------------------------------------------
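A minimal round-trip sketch of the behaviour described in the docstring, assuming `MyNetwork` is defined at the top level of an importable module file and `tmp.pkl` is an arbitrary output path:

    import pickle
    import torch
    from torch_utils import persistence

    @persistence.persistent_class
    class MyNetwork(torch.nn.Module):                 # illustrative class name
        def __init__(self, num_inputs, num_outputs):
            super().__init__()
            self.fc = torch.nn.Linear(num_inputs, num_outputs)

    net = MyNetwork(8, 4)
    with open('tmp.pkl', 'wb') as fp:                 # the pickle embeds this module's source code
        pickle.dump(net, fp)
    with open('tmp.pkl', 'rb') as fp:                 # restored via _reconstruct_persistent_obj()
        net2 = pickle.load(fp)
    print(net2.init_args, dict(net2.init_kwargs))     # (8, 4) {}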
135
+
136
+ def is_persistent(obj):
137
+ r"""Test whether the given object or class is persistent, i.e.,
138
+ whether it will save its source code when pickled.
139
+ """
140
+ try:
141
+ if obj in _decorators:
142
+ return True
143
+ except TypeError:
144
+ pass
145
+ return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
146
+
147
+ #----------------------------------------------------------------------------
148
+
149
+ def import_hook(hook):
150
+ r"""Register an import hook that is called whenever a persistent object
151
+ is being unpickled. A typical use case is to patch the pickled source
152
+ code to avoid errors and inconsistencies when the API of some imported
153
+ module has changed.
154
+
155
+ The hook should have the following signature:
156
+
157
+ hook(meta) -> modified meta
158
+
159
+ `meta` is an instance of `dnnlib.EasyDict` with the following fields:
160
+
161
+ type: Type of the persistent object, e.g. `'class'`.
162
+ version: Internal version number of `torch_utils.persistence`.
163
+ module_src: Original source code of the Python module.
164
+ class_name: Class name in the original Python module.
165
+ state: Internal state of the object.
166
+
167
+ Example:
168
+
169
+ @persistence.import_hook
170
+ def wreck_my_network(meta):
171
+ if meta.class_name == 'MyNetwork':
172
+ print('MyNetwork is being imported. I will wreck it!')
173
+ meta.module_src = meta.module_src.replace("True", "False")
174
+ return meta
175
+ """
176
+ assert callable(hook)
177
+ _import_hooks.append(hook)
178
+
179
+ #----------------------------------------------------------------------------
180
+
181
+ def _reconstruct_persistent_obj(meta):
182
+ r"""Hook that is called internally by the `pickle` module to unpickle
183
+ a persistent object.
184
+ """
185
+ meta = dnnlib.EasyDict(meta)
186
+ meta.state = dnnlib.EasyDict(meta.state)
187
+ for hook in _import_hooks:
188
+ meta = hook(meta)
189
+ assert meta is not None
190
+
191
+ assert meta.version == _version
192
+ module = _src_to_module(meta.module_src)
193
+
194
+ assert meta.type == 'class'
195
+ orig_class = module.__dict__[meta.class_name]
196
+ decorator_class = persistent_class(orig_class)
197
+ obj = decorator_class.__new__(decorator_class)
198
+
199
+ setstate = getattr(obj, '__setstate__', None)
200
+ if callable(setstate):
201
+ setstate(meta.state) # pylint: disable=not-callable
202
+ else:
203
+ obj.__dict__.update(meta.state)
204
+ return obj
205
+
206
+ #----------------------------------------------------------------------------
207
+
208
+ def _module_to_src(module):
209
+ r"""Query the source code of a given Python module.
210
+ """
211
+ src = _module_to_src_dict.get(module, None)
212
+ if src is None:
213
+ src = inspect.getsource(module)
214
+ _module_to_src_dict[module] = src
215
+ _src_to_module_dict[src] = module
216
+ return src
217
+
218
+ def _src_to_module(src):
219
+ r"""Get or create a Python module for the given source code.
220
+ """
221
+ module = _src_to_module_dict.get(src, None)
222
+ if module is None:
223
+ module_name = "_imported_module_" + uuid.uuid4().hex
224
+ module = types.ModuleType(module_name)
225
+ sys.modules[module_name] = module
226
+ _module_to_src_dict[module] = src
227
+ _src_to_module_dict[src] = module
228
+ exec(src, module.__dict__) # pylint: disable=exec-used
229
+ return module
230
+
231
+ #----------------------------------------------------------------------------
232
+
233
+ def _check_pickleable(obj):
234
+ r"""Check that the given object is pickleable, raising an exception if
235
+ it is not. This function is expected to be considerably more efficient
236
+ than actually pickling the object.
237
+ """
238
+ def recurse(obj):
239
+ if isinstance(obj, (list, tuple, set)):
240
+ return [recurse(x) for x in obj]
241
+ if isinstance(obj, dict):
242
+ return [[recurse(x), recurse(y)] for x, y in obj.items()]
243
+ if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
244
+ return None # Python primitive types are pickleable.
245
+ if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
246
+ return None # NumPy arrays and PyTorch tensors are pickleable.
247
+ if is_persistent(obj):
248
+ return None # Persistent objects are pickleable, by virtue of the constructor check.
249
+ return obj
250
+ with io.BytesIO() as f:
251
+ pickle.dump(recurse(obj), f)
252
+
253
+ #----------------------------------------------------------------------------
torch_utils/training_stats.py ADDED
@@ -0,0 +1,270 @@
1
+ # Copyright (c) SenseTime Research. All rights reserved.
2
+
3
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
6
+ # and proprietary rights in and to this software, related documentation
7
+ # and any modifications thereto. Any use, reproduction, disclosure or
8
+ # distribution of this software and related documentation without an express
9
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
10
+
11
+ """Facilities for reporting and collecting training statistics across
12
+ multiple processes and devices. The interface is designed to minimize
13
+ synchronization overhead as well as the amount of boilerplate in user
14
+ code."""
15
+
16
+ import re
17
+ import numpy as np
18
+ import torch
19
+ import dnnlib
20
+
21
+ from . import misc
22
+
23
+ #----------------------------------------------------------------------------
24
+
25
+ _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
26
+ _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
27
+ _counter_dtype = torch.float64 # Data type to use for the internal counters.
28
+ _rank = 0 # Rank of the current process.
29
+ _sync_device = None # Device to use for multiprocess communication. None = single-process.
30
+ _sync_called = False # Has _sync() been called yet?
31
+ _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
32
+ _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
33
+
34
+ #----------------------------------------------------------------------------
35
+
36
+ def init_multiprocessing(rank, sync_device):
37
+ r"""Initializes `torch_utils.training_stats` for collecting statistics
38
+ across multiple processes.
39
+
40
+ This function must be called after
41
+ `torch.distributed.init_process_group()` and before `Collector.update()`.
42
+ The call is not necessary if multi-process collection is not needed.
43
+
44
+ Args:
45
+ rank: Rank of the current process.
46
+ sync_device: PyTorch device to use for inter-process
47
+ communication, or None to disable multi-process
48
+ collection. Typically `torch.device('cuda', rank)`.
49
+ """
50
+ global _rank, _sync_device
51
+ assert not _sync_called
52
+ _rank = rank
53
+ _sync_device = sync_device
54
+
55
+ #----------------------------------------------------------------------------
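A minimal setup sketch, assuming the processes are launched with `torchrun` (which provides the rendezvous environment variables):

    import torch
    from torch_utils import training_stats

    torch.distributed.init_process_group(backend='nccl')
    rank = torch.distributed.get_rank()
    training_stats.init_multiprocessing(rank=rank, sync_device=torch.device('cuda', rank))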
56
+
57
+ @misc.profiled_function
58
+ def report(name, value):
59
+ r"""Broadcasts the given set of scalars to all interested instances of
60
+ `Collector`, across device and process boundaries.
61
+
62
+ This function is expected to be extremely cheap and can be safely
63
+ called from anywhere in the training loop, loss function, or inside a
64
+ `torch.nn.Module`.
65
+
66
+ Warning: The current implementation expects the set of unique names to
67
+ be consistent across processes. Please make sure that `report()` is
68
+ called at least once for each unique name by each process, and in the
69
+ same order. If a given process has no scalars to broadcast, it can do
70
+ `report(name, [])` (empty list).
71
+
72
+ Args:
73
+ name: Arbitrary string specifying the name of the statistic.
74
+ Averages are accumulated separately for each unique name.
75
+ value: Arbitrary set of scalars. Can be a list, tuple,
76
+ NumPy array, PyTorch tensor, or Python scalar.
77
+
78
+ Returns:
79
+ The same `value` that was passed in.
80
+ """
81
+ if name not in _counters:
82
+ _counters[name] = dict()
83
+
84
+ elems = torch.as_tensor(value)
85
+ if elems.numel() == 0:
86
+ return value
87
+
88
+ elems = elems.detach().flatten().to(_reduce_dtype)
89
+ moments = torch.stack([
90
+ torch.ones_like(elems).sum(),
91
+ elems.sum(),
92
+ elems.square().sum(),
93
+ ])
94
+ assert moments.ndim == 1 and moments.shape[0] == _num_moments
95
+ moments = moments.to(_counter_dtype)
96
+
97
+ device = moments.device
98
+ if device not in _counters[name]:
99
+ _counters[name][device] = torch.zeros_like(moments)
100
+ _counters[name][device].add_(moments)
101
+ return value
102
+
103
+ #----------------------------------------------------------------------------
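A short sketch of reporting from inside a training step; `pred` and `target` are stand-ins for real model outputs:

    import torch
    from torch_utils import training_stats

    pred, target = torch.randn(8), torch.randn(8)
    loss = (pred - target).square().mean()
    training_stats.report('Loss/train', loss)         # accumulates [count, sum, sum_of_squares]
    training_stats.report0('Progress/lr', 1e-3)       # only rank 0 contributes this value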
104
+
105
+ def report0(name, value):
106
+ r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
107
+ but ignores any scalars provided by the other processes.
108
+ See `report()` for further details.
109
+ """
110
+ report(name, value if _rank == 0 else [])
111
+ return value
112
+
113
+ #----------------------------------------------------------------------------
114
+
115
+ class Collector:
116
+ r"""Collects the scalars broadcasted by `report()` and `report0()` and
117
+ computes their long-term averages (mean and standard deviation) over
118
+ user-defined periods of time.
119
+
120
+ The averages are first collected into internal counters that are not
121
+ directly visible to the user. They are then copied to the user-visible
122
+ state as a result of calling `update()` and can then be queried using
123
+ `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
124
+ internal counters for the next round, so that the user-visible state
125
+ effectively reflects averages collected between the last two calls to
126
+ `update()`.
127
+
128
+ Args:
129
+ regex: Regular expression defining which statistics to
130
+ collect. The default is to collect everything.
131
+ keep_previous: Whether to retain the previous averages if no
132
+ scalars were collected on a given round
133
+ (default: True).
134
+ """
135
+ def __init__(self, regex='.*', keep_previous=True):
136
+ self._regex = re.compile(regex)
137
+ self._keep_previous = keep_previous
138
+ self._cumulative = dict()
139
+ self._moments = dict()
140
+ self.update()
141
+ self._moments.clear()
142
+
143
+ def names(self):
144
+ r"""Returns the names of all statistics broadcasted so far that
145
+ match the regular expression specified at construction time.
146
+ """
147
+ return [name for name in _counters if self._regex.fullmatch(name)]
148
+
149
+ def update(self):
150
+ r"""Copies current values of the internal counters to the
151
+ user-visible state and resets them for the next round.
152
+
153
+ If `keep_previous=True` was specified at construction time, the
154
+ operation is skipped for statistics that have received no scalars
155
+ since the last update, retaining their previous averages.
156
+
157
+ This method performs a number of GPU-to-CPU transfers and one
158
+ `torch.distributed.all_reduce()`. It is intended to be called
159
+ periodically in the main training loop, typically once every
160
+ N training steps.
161
+ """
162
+ if not self._keep_previous:
163
+ self._moments.clear()
164
+ for name, cumulative in _sync(self.names()):
165
+ if name not in self._cumulative:
166
+ self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
167
+ delta = cumulative - self._cumulative[name]
168
+ self._cumulative[name].copy_(cumulative)
169
+ if float(delta[0]) != 0:
170
+ self._moments[name] = delta
171
+
172
+ def _get_delta(self, name):
173
+ r"""Returns the raw moments that were accumulated for the given
174
+ statistic between the last two calls to `update()`, or zero if
175
+ no scalars were collected.
176
+ """
177
+ assert self._regex.fullmatch(name)
178
+ if name not in self._moments:
179
+ self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
180
+ return self._moments[name]
181
+
182
+ def num(self, name):
183
+ r"""Returns the number of scalars that were accumulated for the given
184
+ statistic between the last two calls to `update()`, or zero if
185
+ no scalars were collected.
186
+ """
187
+ delta = self._get_delta(name)
188
+ return int(delta[0])
189
+
190
+ def mean(self, name):
191
+ r"""Returns the mean of the scalars that were accumulated for the
192
+ given statistic between the last two calls to `update()`, or NaN if
193
+ no scalars were collected.
194
+ """
195
+ delta = self._get_delta(name)
196
+ if int(delta[0]) == 0:
197
+ return float('nan')
198
+ return float(delta[1] / delta[0])
199
+
200
+ def std(self, name):
201
+ r"""Returns the standard deviation of the scalars that were
202
+ accumulated for the given statistic between the last two calls to
203
+ `update()`, or NaN if no scalars were collected.
204
+ """
205
+ delta = self._get_delta(name)
206
+ if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
207
+ return float('nan')
208
+ if int(delta[0]) == 1:
209
+ return float(0)
210
+ mean = float(delta[1] / delta[0])
211
+ raw_var = float(delta[2] / delta[0])
212
+ return np.sqrt(max(raw_var - np.square(mean), 0))
213
+
214
+ def as_dict(self):
215
+ r"""Returns the averages accumulated between the last two calls to
216
+ `update()` as a `dnnlib.EasyDict`. The contents are as follows:
217
+
218
+ dnnlib.EasyDict(
219
+ NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
220
+ ...
221
+ )
222
+ """
223
+ stats = dnnlib.EasyDict()
224
+ for name in self.names():
225
+ stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
226
+ return stats
227
+
228
+ def __getitem__(self, name):
229
+ r"""Convenience getter.
230
+ `collector[name]` is a synonym for `collector.mean(name)`.
231
+ """
232
+ return self.mean(name)
233
+
234
+ #----------------------------------------------------------------------------
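A minimal collection sketch tying `report()` and `Collector` together; the loop and the reported values are illustrative:

    from torch_utils import training_stats

    collector = training_stats.Collector(regex='Loss/.*')
    for step in range(1000):
        training_stats.report('Loss/train', float(step))      # normally reported from the loss function
        if step % 100 == 99:
            collector.update()                                 # syncs counters (all_reduce only after init_multiprocessing)
            print(step, collector.mean('Loss/train'), collector.std('Loss/train'))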
235
+
236
+ def _sync(names):
237
+ r"""Synchronize the global cumulative counters across devices and
238
+ processes. Called internally by `Collector.update()`.
239
+ """
240
+ if len(names) == 0:
241
+ return []
242
+ global _sync_called
243
+ _sync_called = True
244
+
245
+ # Collect deltas within current rank.
246
+ deltas = []
247
+ device = _sync_device if _sync_device is not None else torch.device('cpu')
248
+ for name in names:
249
+ delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
250
+ for counter in _counters[name].values():
251
+ delta.add_(counter.to(device))
252
+ counter.copy_(torch.zeros_like(counter))
253
+ deltas.append(delta)
254
+ deltas = torch.stack(deltas)
255
+
256
+ # Sum deltas across ranks.
257
+ if _sync_device is not None:
258
+ torch.distributed.all_reduce(deltas)
259
+
260
+ # Update cumulative values.
261
+ deltas = deltas.cpu()
262
+ for idx, name in enumerate(names):
263
+ if name not in _cumulative:
264
+ _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
265
+ _cumulative[name].add_(deltas[idx])
266
+
267
+ # Return name-value pairs.
268
+ return [(name, _cumulative[name]) for name in names]
269
+
270
+ #----------------------------------------------------------------------------