Spaces
akhaliq3 committed · Commit e86b33b
Parent(s): 4787ac2

spaces demo
- LICENSE.txt +97 -0
- calc_metrics.py +190 -0
- dataset_tool.py +444 -0
- dnnlib/__init__.py +9 -0
- dnnlib/util.py +477 -0
- generate.py +129 -0
- legacy.py +320 -0
- metrics/__init__.py +9 -0
- metrics/frechet_inception_distance.py +41 -0
- metrics/inception_score.py +38 -0
- metrics/kernel_inception_distance.py +46 -0
- metrics/metric_main.py +152 -0
- metrics/metric_utils.py +275 -0
- metrics/perceptual_path_length.py +131 -0
- metrics/precision_recall.py +62 -0
- projector.py +212 -0
- style_mixing.py +118 -0
- torch_utils/__init__.py +9 -0
- torch_utils/custom_ops.py +126 -0
- torch_utils/misc.py +262 -0
- torch_utils/ops/__init__.py +9 -0
- torch_utils/ops/bias_act.cpp +99 -0
- torch_utils/ops/bias_act.cu +173 -0
- torch_utils/ops/bias_act.h +38 -0
- torch_utils/ops/bias_act.py +212 -0
- torch_utils/ops/conv2d_gradfix.py +170 -0
- torch_utils/ops/conv2d_resample.py +156 -0
- torch_utils/ops/fma.py +60 -0
- torch_utils/ops/grid_sample_gradfix.py +83 -0
- torch_utils/ops/upfirdn2d.cpp +103 -0
- torch_utils/ops/upfirdn2d.cu +350 -0
- torch_utils/ops/upfirdn2d.h +59 -0
- torch_utils/ops/upfirdn2d.py +384 -0
- torch_utils/persistence.py +251 -0
- torch_utils/training_stats.py +268 -0
- train.py +540 -0
- training/__init__.py +9 -0
- training/augment.py +431 -0
- training/dataset.py +236 -0
- training/loss.py +133 -0
- training/networks.py +729 -0
- training/training_loop.py +421 -0
LICENSE.txt
ADDED
@@ -0,0 +1,97 @@
Copyright (c) 2021, NVIDIA Corporation. All rights reserved.


NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA)


=======================================================================

1. Definitions

"Licensor" means any person or entity that distributes its Work.

"Software" means the original work of authorship made available under
this License.

"Work" means the Software and any additions to or derivative works of
the Software that are made available under this License.

The terms "reproduce," "reproduction," "derivative works," and
"distribution" have the meaning as provided under U.S. copyright law;
provided, however, that for the purposes of this License, derivative
works shall not include works that remain separable from, or merely
link (or bind by name) to the interfaces of, the Work.

Works, including the Software, are "made available" under this License
by including in or with the Work either (a) a copyright notice
referencing the applicability of this License to the Work, or (b) a
copy of this License.

2. License Grants

2.1 Copyright Grant. Subject to the terms and conditions of this
License, each Licensor grants to you a perpetual, worldwide,
non-exclusive, royalty-free, copyright license to reproduce,
prepare derivative works of, publicly display, publicly perform,
sublicense and distribute its Work and any resulting derivative
works in any form.

3. Limitations

3.1 Redistribution. You may reproduce or distribute the Work only
if (a) you do so under this License, (b) you include a complete
copy of this License with your distribution, and (c) you retain
without modification any copyright, patent, trademark, or
attribution notices that are present in the Work.

3.2 Derivative Works. You may specify that additional or different
terms apply to the use, reproduction, and distribution of your
derivative works of the Work ("Your Terms") only if (a) Your Terms
provide that the use limitation in Section 3.3 applies to your
derivative works, and (b) you identify the specific derivative
works that are subject to Your Terms. Notwithstanding Your Terms,
this License (including the redistribution requirements in Section
3.1) will continue to apply to the Work itself.

3.3 Use Limitation. The Work and any derivative works thereof only
may be used or intended for use non-commercially. Notwithstanding
the foregoing, NVIDIA and its affiliates may use the Work and any
derivative works commercially. As used herein, "non-commercially"
means for research or evaluation purposes only.

3.4 Patent Claims. If you bring or threaten to bring a patent claim
against any Licensor (including any claim, cross-claim or
counterclaim in a lawsuit) to enforce any patents that you allege
are infringed by any Work, then your rights under this License from
such Licensor (including the grant in Section 2.1) will terminate
immediately.

3.5 Trademarks. This License does not grant any rights to use any
Licensor’s or its affiliates’ names, logos, or trademarks, except
as necessary to reproduce the notices described in this License.

3.6 Termination. If you violate any term of this License, then your
rights under this License (including the grant in Section 2.1) will
terminate immediately.

4. Disclaimer of Warranty.

THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
THIS LICENSE.

5. Limitation of Liability.

EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
THE POSSIBILITY OF SUCH DAMAGES.

=======================================================================
calc_metrics.py
ADDED
@@ -0,0 +1,190 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Calculate quality metrics for previous training run or pretrained network pickle."""

import os
import click
import json
import tempfile
import copy
import torch
import dnnlib

import legacy
from metrics import metric_main
from metrics import metric_utils
from torch_utils import training_stats
from torch_utils import custom_ops
from torch_utils import misc

#----------------------------------------------------------------------------

def subprocess_fn(rank, args, temp_dir):
    dnnlib.util.Logger(should_flush=True)

    # Init torch.distributed.
    if args.num_gpus > 1:
        init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
        if os.name == 'nt':
            init_method = 'file:///' + init_file.replace('\\', '/')
            torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
        else:
            init_method = f'file://{init_file}'
            torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)

    # Init torch_utils.
    sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
    if rank != 0 or not args.verbose:
        custom_ops.verbosity = 'none'

    # Print network summary.
    device = torch.device('cuda', rank)
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)
    if rank == 0 and args.verbose:
        z = torch.empty([1, G.z_dim], device=device)
        c = torch.empty([1, G.c_dim], device=device)
        misc.print_module_summary(G, [z, c])

    # Calculate each metric.
    for metric in args.metrics:
        if rank == 0 and args.verbose:
            print(f'Calculating {metric}...')
        progress = metric_utils.ProgressMonitor(verbose=args.verbose)
        result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs,
            num_gpus=args.num_gpus, rank=rank, device=device, progress=progress)
        if rank == 0:
            metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
        if rank == 0 and args.verbose:
            print()

    # Done.
    if rank == 0 and args.verbose:
        print('Exiting...')

#----------------------------------------------------------------------------
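# Editorial sketch (not part of the original file): the file-based rendezvous
# used by subprocess_fn() above, in isolation. All ranks derive the same init
# file from a shared temp_dir; the backend choice mirrors the code above
# (gloo on Windows, nccl elsewhere). Names here are illustrative.
def _init_distributed_sketch(rank, world_size, temp_dir):
    init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
    if os.name == 'nt':
        init_method = 'file:///' + init_file.replace('\\', '/') # file:// URLs need forward slashes
        backend = 'gloo'
    else:
        init_method = f'file://{init_file}'
        backend = 'nccl'
    torch.distributed.init_process_group(backend=backend, init_method=init_method,
        rank=rank, world_size=world_size)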

class CommaSeparatedList(click.ParamType):
    name = 'list'

    def convert(self, value, param, ctx):
        _ = param, ctx
        if value is None or value.lower() == 'none' or value == '':
            return []
        return value.split(',')

#----------------------------------------------------------------------------

@click.command()
@click.pass_context
@click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True)
@click.option('--metrics', help='Comma-separated list or "none"', type=CommaSeparatedList(), default='fid50k_full', show_default=True)
@click.option('--data', help='Dataset to evaluate metrics against (directory or zip) [default: same as training data]', metavar='PATH')
@click.option('--mirror', help='Whether the dataset was augmented with x-flips during training [default: look up]', type=bool, metavar='BOOL')
@click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
@click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True)

def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
    """Calculate quality metrics for previous training run or pretrained network pickle.

    Examples:

    \b
    # Previous training run: look up options automatically, save result to JSONL file.
    python calc_metrics.py --metrics=pr50k3_full \\
        --network=~/training-runs/00000-ffhq10k-res64-auto1/network-snapshot-000000.pkl

    \b
    # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
    python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl

    Available metrics:

    \b
      ADA paper:
        fid50k_full  Frechet inception distance against the full dataset.
        kid50k_full  Kernel inception distance against the full dataset.
        pr50k3_full  Precision and recall against the full dataset.
        is50k        Inception score for CIFAR-10.

    \b
      StyleGAN and StyleGAN2 papers:
        fid50k       Frechet inception distance against 50k real images.
        kid50k       Kernel inception distance against 50k real images.
        pr50k3       Precision and recall against 50k real images.
        ppl2_wend    Perceptual path length in W at path endpoints against full image.
        ppl_zfull    Perceptual path length in Z for full paths against cropped image.
        ppl_wfull    Perceptual path length in W for full paths against cropped image.
        ppl_zend     Perceptual path length in Z at path endpoints against cropped image.
        ppl_wend     Perceptual path length in W at path endpoints against cropped image.
    """
    dnnlib.util.Logger(should_flush=True)

    # Validate arguments.
    args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose)
    if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
        ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
    if not args.num_gpus >= 1:
        ctx.fail('--gpus must be at least 1')

    # Load network.
    if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
        ctx.fail('--network must point to a file or URL')
    if args.verbose:
        print(f'Loading network from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
        network_dict = legacy.load_network_pkl(f)
        args.G = network_dict['G_ema'] # subclass of torch.nn.Module

    # Initialize dataset options.
    if data is not None:
        args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data)
    elif network_dict['training_set_kwargs'] is not None:
        args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
    else:
        ctx.fail('Could not look up dataset options; please specify --data')

    # Finalize dataset options.
    args.dataset_kwargs.resolution = args.G.img_resolution
    args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
    if mirror is not None:
        args.dataset_kwargs.xflip = mirror

    # Print dataset options.
    if args.verbose:
        print('Dataset options:')
        print(json.dumps(args.dataset_kwargs, indent=2))

    # Locate run dir.
    args.run_dir = None
    if os.path.isfile(network_pkl):
        pkl_dir = os.path.dirname(network_pkl)
        if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
            args.run_dir = pkl_dir

    # Launch processes.
    if args.verbose:
        print('Launching processes...')
    torch.multiprocessing.set_start_method('spawn')
    with tempfile.TemporaryDirectory() as temp_dir:
        if args.num_gpus == 1:
            subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
        else:
            torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)

#----------------------------------------------------------------------------

if __name__ == "__main__":
    calc_metrics() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------
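For single-GPU use, the multiprocessing scaffolding above can be bypassed and a metric computed directly. A hedged sketch, assuming the repo's modules are importable; the pickle and dataset paths are illustrative:

import torch
import dnnlib
import legacy
from metrics import metric_main

device = torch.device('cuda', 0)
with dnnlib.util.open_url('ffhq.pkl') as f:           # illustrative snapshot path
    G = legacy.load_network_pkl(f)['G_ema'].eval().requires_grad_(False).to(device)

dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset',
    path='datasets/ffhq.zip',                         # illustrative dataset path
    resolution=G.img_resolution, use_labels=(G.c_dim != 0))
result = metric_main.calc_metric(metric='fid50k_full', G=G,
    dataset_kwargs=dataset_kwargs, num_gpus=1, rank=0, device=device)
print(result.results)                                 # e.g. {'fid50k_full': ...}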
dataset_tool.py
ADDED
@@ -0,0 +1,444 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import functools
import io
import json
import os
import pickle
import sys
import tarfile
import gzip
import zipfile
from pathlib import Path
from typing import Callable, Optional, Tuple, Union

import click
import numpy as np
import PIL.Image
from tqdm import tqdm

#----------------------------------------------------------------------------

def error(msg):
    print('Error: ' + msg)
    sys.exit(1)

#----------------------------------------------------------------------------

def maybe_min(a: int, b: Optional[int]) -> int:
    if b is not None:
        return min(a, b)
    return a

#----------------------------------------------------------------------------

def file_ext(name: Union[str, Path]) -> str:
    return str(name).split('.')[-1]

#----------------------------------------------------------------------------

def is_image_ext(fname: Union[str, Path]) -> bool:
    ext = file_ext(fname).lower()
    return f'.{ext}' in PIL.Image.EXTENSION # type: ignore

#----------------------------------------------------------------------------
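# Editorial sketch (not part of the original file): PIL.Image.EXTENSION is only
# populated once Pillow's format plugins are registered, which is why
# convert_dataset() below calls PIL.Image.init() before touching is_image_ext().
PIL.Image.init()                           # fills PIL.Image.EXTENSION
assert is_image_ext('photos/cat.JPG')      # '.jpg' is registered after init()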

def open_image_folder(source_dir, *, max_images: Optional[int]):
    input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)]

    # Load labels.
    labels = {}
    meta_fname = os.path.join(source_dir, 'dataset.json')
    if os.path.isfile(meta_fname):
        with open(meta_fname, 'r') as file:
            labels = json.load(file)['labels']
            if labels is not None:
                labels = { x[0]: x[1] for x in labels }
            else:
                labels = {}

    max_idx = maybe_min(len(input_images), max_images)

    def iterate_images():
        for idx, fname in enumerate(input_images):
            arch_fname = os.path.relpath(fname, source_dir)
            arch_fname = arch_fname.replace('\\', '/')
            img = np.array(PIL.Image.open(fname))
            yield dict(img=img, label=labels.get(arch_fname))
            if idx >= max_idx-1:
                break
    return max_idx, iterate_images()

#----------------------------------------------------------------------------

def open_image_zip(source, *, max_images: Optional[int]):
    with zipfile.ZipFile(source, mode='r') as z:
        input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]

        # Load labels.
        labels = {}
        if 'dataset.json' in z.namelist():
            with z.open('dataset.json', 'r') as file:
                labels = json.load(file)['labels']
                if labels is not None:
                    labels = { x[0]: x[1] for x in labels }
                else:
                    labels = {}

    max_idx = maybe_min(len(input_images), max_images)

    def iterate_images():
        with zipfile.ZipFile(source, mode='r') as z:
            for idx, fname in enumerate(input_images):
                with z.open(fname, 'r') as file:
                    img = PIL.Image.open(file) # type: ignore
                    img = np.array(img)
                yield dict(img=img, label=labels.get(fname))
                if idx >= max_idx-1:
                    break
    return max_idx, iterate_images()

#----------------------------------------------------------------------------

def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]):
    import cv2 # pip install opencv-python
    import lmdb # pip install lmdb # pylint: disable=import-error

    with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
        max_idx = maybe_min(txn.stat()['entries'], max_images)

    def iterate_images():
        with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
            for idx, (_key, value) in enumerate(txn.cursor()):
                try:
                    try:
                        img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
                        if img is None:
                            raise IOError('cv2.imdecode failed')
                        img = img[:, :, ::-1] # BGR => RGB
                    except IOError:
                        img = np.array(PIL.Image.open(io.BytesIO(value)))
                    yield dict(img=img, label=None)
                    if idx >= max_idx-1:
                        break
                except:
                    print(sys.exc_info()[1])

    return max_idx, iterate_images()

#----------------------------------------------------------------------------

def open_cifar10(tarball: str, *, max_images: Optional[int]):
    images = []
    labels = []

    with tarfile.open(tarball, 'r:gz') as tar:
        for batch in range(1, 6):
            member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}')
            with tar.extractfile(member) as file:
                data = pickle.load(file, encoding='latin1')
            images.append(data['data'].reshape(-1, 3, 32, 32))
            labels.append(data['labels'])

    images = np.concatenate(images)
    labels = np.concatenate(labels)
    images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC
    assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8
    assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9

    max_idx = maybe_min(len(images), max_images)

    def iterate_images():
        for idx, img in enumerate(images):
            yield dict(img=img, label=int(labels[idx]))
            if idx >= max_idx-1:
                break

    return max_idx, iterate_images()

#----------------------------------------------------------------------------

def open_mnist(images_gz: str, *, max_images: Optional[int]):
    labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
    assert labels_gz != images_gz
    images = []
    labels = []

    with gzip.open(images_gz, 'rb') as f:
        images = np.frombuffer(f.read(), np.uint8, offset=16)
    with gzip.open(labels_gz, 'rb') as f:
        labels = np.frombuffer(f.read(), np.uint8, offset=8)

    images = images.reshape(-1, 28, 28)
    images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
    assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
    assert labels.shape == (60000,) and labels.dtype == np.uint8
    assert np.min(images) == 0 and np.max(images) == 255
    assert np.min(labels) == 0 and np.max(labels) == 9

    max_idx = maybe_min(len(images), max_images)

    def iterate_images():
        for idx, img in enumerate(images):
            yield dict(img=img, label=int(labels[idx]))
            if idx >= max_idx-1:
                break

    return max_idx, iterate_images()

#----------------------------------------------------------------------------
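# Editorial sketch (not part of the original file): each open_* loader above
# returns (max_idx, iterator), and the iterator yields dict(img=..., label=...).
# Consuming one, with an illustrative directory path:
num_files, images = open_image_folder('datasets/ffhq-raw', max_images=4)
for entry in images:
    print(entry['img'].shape, entry['label'])  # HWC uint8 array, label or None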

def make_transform(
    transform: Optional[str],
    output_width: Optional[int],
    output_height: Optional[int],
    resize_filter: str
) -> Callable[[np.ndarray], Optional[np.ndarray]]:
    resample = { 'box': PIL.Image.BOX, 'lanczos': PIL.Image.LANCZOS }[resize_filter]

    def scale(width, height, img):
        w = img.shape[1]
        h = img.shape[0]
        if width == w and height == h:
            return img
        img = PIL.Image.fromarray(img)
        ww = width if width is not None else w
        hh = height if height is not None else h
        img = img.resize((ww, hh), resample)
        return np.array(img)

    def center_crop(width, height, img):
        crop = np.min(img.shape[:2])
        img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), resample)
        return np.array(img)

    def center_crop_wide(width, height, img):
        ch = int(np.round(width * img.shape[0] / img.shape[1]))
        if img.shape[1] < width or ch < height:
            return None

        img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
        img = PIL.Image.fromarray(img, 'RGB')
        img = img.resize((width, height), resample)
        img = np.array(img)

        canvas = np.zeros([width, width, 3], dtype=np.uint8)
        canvas[(width - height) // 2 : (width + height) // 2, :] = img
        return canvas

    if transform is None:
        return functools.partial(scale, output_width, output_height)
    if transform == 'center-crop':
        if (output_width is None) or (output_height is None):
            error('must specify --width and --height when using ' + transform + ' transform')
        return functools.partial(center_crop, output_width, output_height)
    if transform == 'center-crop-wide':
        if (output_width is None) or (output_height is None):
            error('must specify --width and --height when using ' + transform + ' transform')
        return functools.partial(center_crop_wide, output_width, output_height)
    assert False, 'unknown transform'

#----------------------------------------------------------------------------
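# Editorial sketch (not part of the original file): make_transform() on a
# synthetic RGB array; dimensions are illustrative.
_tf = make_transform('center-crop', 256, 256, 'lanczos')
_img = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
print(_tf(_img).shape)  # (256, 256, 3): square center crop, then Lanczos resize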

def open_dataset(source, *, max_images: Optional[int]):
    if os.path.isdir(source):
        if source.rstrip('/').endswith('_lmdb'):
            return open_lmdb(source, max_images=max_images)
        else:
            return open_image_folder(source, max_images=max_images)
    elif os.path.isfile(source):
        if os.path.basename(source) == 'cifar-10-python.tar.gz':
            return open_cifar10(source, max_images=max_images)
        elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
            return open_mnist(source, max_images=max_images)
        elif file_ext(source) == 'zip':
            return open_image_zip(source, max_images=max_images)
        else:
            assert False, 'unknown archive type'
    else:
        error(f'Missing input file or directory: {source}')

#----------------------------------------------------------------------------

def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
    dest_ext = file_ext(dest)

    if dest_ext == 'zip':
        if os.path.dirname(dest) != '':
            os.makedirs(os.path.dirname(dest), exist_ok=True)
        zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
        def zip_write_bytes(fname: str, data: Union[bytes, str]):
            zf.writestr(fname, data)
        return '', zip_write_bytes, zf.close
    else:
        # If the output folder already exists, check that it is empty.
        #
        # Note: creating the output directory is not strictly necessary as
        # folder_write_bytes() also mkdirs, but it's better to give an error
        # message earlier in case the dest folder somehow cannot be created.
        if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
            error('--dest folder must be empty')
        os.makedirs(dest, exist_ok=True)

        def folder_write_bytes(fname: str, data: Union[bytes, str]):
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            with open(fname, 'wb') as fout:
                if isinstance(data, str):
                    data = data.encode('utf8')
                fout.write(data)
        return dest, folder_write_bytes, lambda: None

#----------------------------------------------------------------------------
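# Editorial sketch (not part of the original file): open_dest() returns a
# (root, write_fn, close_fn) triple so the conversion loop below can treat a
# zip archive and a plain folder identically. The output name is illustrative.
_root, _save_bytes, _close_dest = open_dest('out/dataset.zip')
_save_bytes(os.path.join(_root, 'dataset.json'), '{"labels": null}')
_close_dest()  # finalizes the zip; a no-op for folder output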

@click.command()
@click.pass_context
@click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH')
@click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH')
@click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
@click.option('--resize-filter', help='Filter to use when resizing images for output resolution', type=click.Choice(['box', 'lanczos']), default='lanczos', show_default=True)
@click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
@click.option('--width', help='Output width', type=int)
@click.option('--height', help='Output height', type=int)
def convert_dataset(
    ctx: click.Context,
    source: str,
    dest: str,
    max_images: Optional[int],
    transform: Optional[str],
    resize_filter: str,
    width: Optional[int],
    height: Optional[int]
):
    """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.

    The input dataset format is guessed from the --source argument:

    \b
    --source *_lmdb/                    Load LSUN dataset
    --source cifar-10-python.tar.gz     Load CIFAR-10 dataset
    --source train-images-idx3-ubyte.gz Load MNIST dataset
    --source path/                      Recursively load all images from path/
    --source dataset.zip                Recursively load all images from dataset.zip

    Specifying the output format and path:

    \b
    --dest /path/to/dir                 Save output files under /path/to/dir
    --dest /path/to/dataset.zip         Save output files into /path/to/dataset.zip

    The output dataset format can be either an image folder or an uncompressed zip archive.
    Zip archives make it easier to move datasets around file servers and clusters, and may
    offer better training performance on network file systems.

    Images within the dataset archive will be stored as uncompressed PNG.
    Uncompressed PNGs can be efficiently decoded in the training loop.

    Class labels are stored in a file called 'dataset.json' that is stored at the
    dataset root folder. This file has the following structure:

    \b
    {
        "labels": [
            ["00000/img00000000.png",6],
            ["00000/img00000001.png",9],
            ... repeated for every image in the dataset
            ["00049/img00049999.png",1]
        ]
    }

    If the 'dataset.json' file cannot be found, the dataset is interpreted as
    not containing class labels.

    Image scale/crop and resolution requirements:

    Output images must be square-shaped and they must all have the same power-of-two
    dimensions.

    To scale arbitrary input image size to a specific width and height, use the
    --width and --height options. Output resolution will be either the original
    input resolution (if --width/--height was not specified) or the one specified
    with --width/--height.

    Use the --transform=center-crop or --transform=center-crop-wide options to apply a
    center crop transform on the input image. These options should be used with the
    --width and --height options. For example:

    \b
    python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
        --transform=center-crop-wide --width=512 --height=384
    """

    PIL.Image.init() # type: ignore

    if dest == '':
        ctx.fail('--dest output filename or directory must not be an empty string')

    num_files, input_iter = open_dataset(source, max_images=max_images)
    archive_root_dir, save_bytes, close_dest = open_dest(dest)

    transform_image = make_transform(transform, width, height, resize_filter)

    dataset_attrs = None

    labels = []
    for idx, image in tqdm(enumerate(input_iter), total=num_files):
        idx_str = f'{idx:08d}'
        archive_fname = f'{idx_str[:5]}/img{idx_str}.png'

        # Apply crop and resize.
        img = transform_image(image['img'])

        # Transform may drop images.
        if img is None:
            continue

        # Error check to require uniform image attributes across
        # the whole dataset.
        channels = img.shape[2] if img.ndim == 3 else 1
        cur_image_attrs = {
            'width': img.shape[1],
            'height': img.shape[0],
            'channels': channels
        }
        if dataset_attrs is None:
            dataset_attrs = cur_image_attrs
            width = dataset_attrs['width']
            height = dataset_attrs['height']
            if width != height:
                error(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
            if dataset_attrs['channels'] not in [1, 3]:
                error('Input images must be stored as RGB or grayscale')
            if width != 2 ** int(np.floor(np.log2(width))):
                error('Image width/height after scale and crop are required to be power-of-two')
        elif dataset_attrs != cur_image_attrs:
            err = [f'  dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()]
            error(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))

        # Save the image as an uncompressed PNG.
        img = PIL.Image.fromarray(img, { 1: 'L', 3: 'RGB' }[channels])
        image_bits = io.BytesIO()
        img.save(image_bits, format='png', compress_level=0, optimize=False)
        save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
        labels.append([archive_fname, image['label']] if image['label'] is not None else None)

    metadata = {
        'labels': labels if all(x is not None for x in labels) else None
    }
    save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
    close_dest()

#----------------------------------------------------------------------------

if __name__ == "__main__":
    convert_dataset() # pylint: disable=no-value-for-parameter
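A hedged sketch for sanity-checking a produced archive against the layout documented in the docstring above; the archive path is illustrative:

import json
import zipfile

with zipfile.ZipFile('out/dataset.zip') as z:        # illustrative archive path
    pngs = [n for n in z.namelist() if n.endswith('.png')]
    meta = json.loads(z.read('dataset.json'))
print(f'{len(pngs)} images, labeled: {meta["labels"] is not None}')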
dnnlib/__init__.py
ADDED
@@ -0,0 +1,9 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from .util import EasyDict, make_cache_dir_path
dnnlib/util.py
ADDED
@@ -0,0 +1,477 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Miscellaneous utility classes and functions."""

import ctypes
import fnmatch
import importlib
import inspect
import numpy as np
import os
import shutil
import sys
import types
import io
import pickle
import re
import requests
import html
import hashlib
import glob
import tempfile
import urllib
import urllib.request
import uuid

from distutils.util import strtobool
from typing import Any, List, Tuple, Union


# Util classes
# ------------------------------------------------------------------------------------------


class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]


class Logger(object):
    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""

    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
        self.file = None

        if file_name is not None:
            self.file = open(file_name, file_mode)

        self.should_flush = should_flush
        self.stdout = sys.stdout
        self.stderr = sys.stderr

        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: Union[str, bytes]) -> None:
        """Write text to stdout (and a file) and optionally flush."""
        if isinstance(text, bytes):
            text = text.decode()
        if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        if self.file is not None:
            self.file.write(text)

        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush written text to both stdout and a file, if open."""
        if self.file is not None:
            self.file.flush()

        self.stdout.flush()

    def close(self) -> None:
        """Flush, close possible files, and remove stdout/stderr mirroring."""
        self.flush()

        # if using multiple loggers, prevent closing in wrong order
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
            self.file = None
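# Editorial sketch (not part of the original file): the two classes above in
# use; the log file name is illustrative.
_cfg = EasyDict(lr=0.002, batch=32)
_cfg.gamma = 10                        # attribute writes go through to the dict
assert _cfg['gamma'] == _cfg.gamma

with Logger(file_name='log.txt', should_flush=True):
    print('mirrored to stdout and log.txt')  # stderr is redirected as well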

# Cache directories
# ------------------------------------------------------------------------------------------

_dnnlib_cache_dir = None

def set_cache_dir(path: str) -> None:
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path

def make_cache_dir_path(*paths: str) -> str:
    if _dnnlib_cache_dir is not None:
        return os.path.join(_dnnlib_cache_dir, *paths)
    if 'DNNLIB_CACHE_DIR' in os.environ:
        return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
    if 'HOME' in os.environ:
        return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
    if 'USERPROFILE' in os.environ:
        return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
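# Editorial sketch (not part of the original file): the resolution order is an
# explicit set_cache_dir(), then $DNNLIB_CACHE_DIR, then $HOME/%USERPROFILE%,
# then the system temp dir. The override path is illustrative.
set_cache_dir('/scratch/dnnlib-cache')
print(make_cache_dir_path('downloads'))  # -> /scratch/dnnlib-cache/downloads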
# Small util functions
# ------------------------------------------------------------------------------------------


def format_time(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
    s = int(np.rint(seconds))

    if s < 60:
        return "{0}s".format(s)
    elif s < 60 * 60:
        return "{0}m {1:02}s".format(s // 60, s % 60)
    elif s < 24 * 60 * 60:
        return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
    else:
        return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)


def ask_yes_no(question: str) -> bool:
    """Ask the user the question until the user inputs a valid answer."""
    while True:
        try:
            print("{0} [y/n]".format(question))
            return strtobool(input().lower())
        except ValueError:
            pass


def tuple_product(t: Tuple) -> Any:
    """Calculate the product of the tuple elements."""
    result = 1

    for v in t:
        result *= v

    return result


_str_to_ctype = {
    "uint8": ctypes.c_ubyte,
    "uint16": ctypes.c_uint16,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
    "int8": ctypes.c_byte,
    "int16": ctypes.c_int16,
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "float32": ctypes.c_float,
    "float64": ctypes.c_double
}


def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
    type_str = None

    if isinstance(type_obj, str):
        type_str = type_obj
    elif hasattr(type_obj, "__name__"):
        type_str = type_obj.__name__
    elif hasattr(type_obj, "name"):
        type_str = type_obj.name
    else:
        raise RuntimeError("Cannot infer type name from input")

    assert type_str in _str_to_ctype.keys()

    my_dtype = np.dtype(type_str)
    my_ctype = _str_to_ctype[type_str]

    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)

    return my_dtype, my_ctype


def is_pickleable(obj: Any) -> bool:
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except:
        return False
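# Editorial sketch (not part of the original file):
print(format_time(93784))                        # -> '1d 02h 03m'
_dtype, _ctype = get_dtype_and_ctype('float32')
assert _dtype.itemsize == ctypes.sizeof(_ctype)  # both 4 bytes
assert tuple_product((2, 3, 4)) == 24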
# Functionality to import modules/objects by name, and call functions by name
# ------------------------------------------------------------------------------------------

def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.
    Returns the module and the object name (original name with module part removed)."""

    # allow convenience shorthands, substitute them by full names
    obj_name = re.sub("^np.", "numpy.", obj_name)
    obj_name = re.sub("^tf.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name)
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
            return module, local_obj_name
        except:
            pass

    # maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name) # may raise ImportError
        except ImportError:
            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)


def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object."""
    if obj_name == '':
        return module
    obj = module
    for part in obj_name.split("."):
        obj = getattr(obj, part)
    return obj


def get_obj_by_name(name: str) -> Any:
    """Finds the python object with the given name."""
    module, obj_name = get_module_from_obj_name(name)
    return get_obj_from_module(module, obj_name)


def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Finds the python object with the given name and calls it as a function."""
    assert func_name is not None
    func_obj = get_obj_by_name(func_name)
    assert callable(func_obj)
    return func_obj(*args, **kwargs)


def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Finds the python class with the given name and constructs it with the given arguments."""
    return call_func_by_name(*args, func_name=class_name, **kwargs)


def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Get the directory path of the module containing the given object name."""
    module, _ = get_module_from_obj_name(obj_name)
    return os.path.dirname(inspect.getfile(module))


def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
    return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__


def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified name of a top-level function."""
    assert is_top_level_function(obj)
    module = obj.__module__
    if module == '__main__':
        module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
    return module + "." + obj.__name__
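# Editorial sketch (not part of the original file): this is the machinery
# behind config entries such as class_name='training.dataset.ImageFolderDataset'
# in calc_metrics.py above: a dotted name plus kwargs becomes a live object.
_rng = construct_class_by_name(class_name='numpy.random.RandomState', seed=0)
print(_rng.randint(10))
print(call_func_by_name(3.0, 4.0, func_name='np.hypot'))  # 'np.' expands to 'numpy.', prints 5.0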
312 |
+
# File system helpers
|
313 |
+
# ------------------------------------------------------------------------------------------
|
314 |
+
|
315 |
+
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
|
316 |
+
"""List all files recursively in a given directory while ignoring given file and directory names.
|
317 |
+
Returns list of tuples containing both absolute and relative paths."""
|
318 |
+
assert os.path.isdir(dir_path)
|
319 |
+
base_name = os.path.basename(os.path.normpath(dir_path))
|
320 |
+
|
321 |
+
if ignores is None:
|
322 |
+
ignores = []
|
323 |
+
|
324 |
+
result = []
|
325 |
+
|
326 |
+
for root, dirs, files in os.walk(dir_path, topdown=True):
|
327 |
+
for ignore_ in ignores:
|
328 |
+
dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
|
329 |
+
|
330 |
+
# dirs need to be edited in-place
|
331 |
+
for d in dirs_to_remove:
|
332 |
+
dirs.remove(d)
|
333 |
+
|
334 |
+
files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
|
335 |
+
|
336 |
+
absolute_paths = [os.path.join(root, f) for f in files]
|
337 |
+
relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
|
338 |
+
|
339 |
+
if add_base_to_relative:
|
340 |
+
relative_paths = [os.path.join(base_name, p) for p in relative_paths]
|
341 |
+
|
342 |
+
assert len(absolute_paths) == len(relative_paths)
|
343 |
+
result += zip(absolute_paths, relative_paths)
|
344 |
+
|
345 |
+
return result
|
346 |
+
|
347 |
+
|
348 |
+
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
|
349 |
+
"""Takes in a list of tuples of (src, dst) paths and copies files.
|
350 |
+
Will create all necessary directories."""
|
351 |
+
for file in files:
|
352 |
+
target_dir_name = os.path.dirname(file[1])
|
353 |
+
|
354 |
+
# will create all intermediate-level directories
|
355 |
+
if not os.path.exists(target_dir_name):
|
356 |
+
os.makedirs(target_dir_name)
|
357 |
+
|
358 |
+
shutil.copyfile(file[0], file[1])
|
359 |
+
|
360 |
+
|
# URL helpers
# ------------------------------------------------------------------------------------------

def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string."""
    if not isinstance(obj, str) or "://" not in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        res = requests.compat.urlparse(obj)
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or "." not in res.netloc:
            return False
    except:
        return False
    return True


def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data."""
    assert num_attempts >= 1
    assert not (return_filename and (not cache))

    # Doesn't look like a URL scheme, so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname(),
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except KeyboardInterrupt:
                raise
            except:
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
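A minimal usage sketch of open_url() as defined above (the URL is a hypothetical placeholder): it accepts plain local paths, file:// URLs, and HTTP(S) URLs, and caches downloads under the per-user cache directory.

    with open_url('https://example.com/network.pkl', cache=True) as f:
        data = f.read()  # bytes, served from the cache on subsequent calls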
generate.py
ADDED
@@ -0,0 +1,129 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Generate images using pretrained network pickle."""

import os
import re
from typing import List, Optional

import click
import dnnlib
import numpy as np
import PIL.Image
import torch

import legacy

#----------------------------------------------------------------------------

def num_range(s: str) -> List[int]:
    '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''

    range_re = re.compile(r'^(\d+)-(\d+)$')
    m = range_re.match(s)
    if m:
        return list(range(int(m.group(1)), int(m.group(2))+1))
    vals = s.split(',')
    return [int(x) for x in vals]

#----------------------------------------------------------------------------

@click.command()
@click.pass_context
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--seeds', type=num_range, help='List of random seeds')
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--projected-w', help='Projection result file', type=str, metavar='FILE')
@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
def generate_images(
    ctx: click.Context,
    network_pkl: str,
    seeds: Optional[List[int]],
    truncation_psi: float,
    noise_mode: str,
    outdir: str,
    class_idx: Optional[int],
    projected_w: Optional[str]
):
    """Generate images using pretrained network pickle.

    Examples:

    \b
    # Generate curated MetFaces images without truncation (Fig.10 left)
    python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl

    \b
    # Generate uncurated MetFaces images with truncation (Fig.12 upper left)
    python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl

    \b
    # Generate class conditional CIFAR-10 images (Fig.17 left, Car)
    python generate.py --outdir=out --seeds=0-35 --class=1 \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl

    \b
    # Render an image from projected W
    python generate.py --outdir=out --projected_w=projected_w.npz \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
    """

    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

    os.makedirs(outdir, exist_ok=True)

    # Synthesize the result of a W projection.
    if projected_w is not None:
        if seeds is not None:
            print('warn: --seeds is ignored when using --projected-w')
        print(f'Generating images from projected W "{projected_w}"')
        ws = np.load(projected_w)['w']
        ws = torch.tensor(ws, device=device) # pylint: disable=not-callable
        assert ws.shape[1:] == (G.num_ws, G.w_dim)
        for idx, w in enumerate(ws):
            img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode)
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/proj{idx:02d}.png')
        return

    if seeds is None:
        ctx.fail('--seeds option is required when not using --projected-w')

    # Labels.
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            ctx.fail('Must specify class label with --class when using a conditional network')
        label[:, class_idx] = 1
    else:
        if class_idx is not None:
            print('warn: --class=lbl ignored when running on an unconditional network')

    # Generate images.
    for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
        z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')


#----------------------------------------------------------------------------

if __name__ == "__main__":
    generate_images() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------
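As a quick check of the --seeds parser defined above, num_range() accepts both of the syntaxes shown in the docstring examples:

    assert num_range('85,265,297,849') == [85, 265, 297, 849]
    assert num_range('600-605') == [600, 601, 602, 603, 604, 605]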
legacy.py
ADDED
@@ -0,0 +1,320 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import click
import pickle
import re
import copy
import numpy as np
import torch
import dnnlib
from torch_utils import misc

#----------------------------------------------------------------------------

def load_network_pkl(f, force_fp16=False):
    data = _LegacyUnpickler(f).load()

    # Legacy TensorFlow pickle => convert.
    if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
        tf_G, tf_D, tf_Gs = data
        G = convert_tf_generator(tf_G)
        D = convert_tf_discriminator(tf_D)
        G_ema = convert_tf_generator(tf_Gs)
        data = dict(G=G, D=D, G_ema=G_ema)

    # Add missing fields.
    if 'training_set_kwargs' not in data:
        data['training_set_kwargs'] = None
    if 'augment_pipe' not in data:
        data['augment_pipe'] = None

    # Validate contents.
    assert isinstance(data['G'], torch.nn.Module)
    assert isinstance(data['D'], torch.nn.Module)
    assert isinstance(data['G_ema'], torch.nn.Module)
    assert isinstance(data['training_set_kwargs'], (dict, type(None)))
    assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))

    # Force FP16.
    if force_fp16:
        for key in ['G', 'D', 'G_ema']:
            old = data[key]
            kwargs = copy.deepcopy(old.init_kwargs)
            if key.startswith('G'):
                kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {}))
                kwargs.synthesis_kwargs.num_fp16_res = 4
                kwargs.synthesis_kwargs.conv_clamp = 256
            if key.startswith('D'):
                kwargs.num_fp16_res = 4
                kwargs.conv_clamp = 256
            if kwargs != old.init_kwargs:
                new = type(old)(**kwargs).eval().requires_grad_(False)
                misc.copy_params_and_buffers(old, new, require_all=True)
                data[key] = new
    return data

#----------------------------------------------------------------------------

class _TFNetworkStub(dnnlib.EasyDict):
    pass

class _LegacyUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == 'dnnlib.tflib.network' and name == 'Network':
            return _TFNetworkStub
        return super().find_class(module, name)

#----------------------------------------------------------------------------

def _collect_tf_params(tf_net):
    # pylint: disable=protected-access
    tf_params = dict()
    def recurse(prefix, tf_net):
        for name, value in tf_net.variables:
            tf_params[prefix + name] = value
        for name, comp in tf_net.components.items():
            recurse(prefix + name + '/', comp)
    recurse('', tf_net)
    return tf_params

#----------------------------------------------------------------------------

def _populate_module_params(module, *patterns):
    for name, tensor in misc.named_params_and_buffers(module):
        found = False
        value = None
        for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
            match = re.fullmatch(pattern, name)
            if match:
                found = True
                if value_fn is not None:
                    value = value_fn(*match.groups())
                break
        try:
            assert found
            if value is not None:
                tensor.copy_(torch.from_numpy(np.array(value)))
        except:
            print(name, list(tensor.shape))
            raise

#----------------------------------------------------------------------------

def convert_tf_generator(tf_G):
    if tf_G.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_G.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None, none=None):
        known_kwargs.add(tf_name)
        val = tf_kwargs.get(tf_name, default)
        return val if val is not None else none

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        z_dim = kwarg('latent_size', 512),
        c_dim = kwarg('label_size', 0),
        w_dim = kwarg('dlatent_size', 512),
        img_resolution = kwarg('resolution', 1024),
        img_channels = kwarg('num_channels', 3),
        mapping_kwargs = dnnlib.EasyDict(
            num_layers = kwarg('mapping_layers', 8),
            embed_features = kwarg('label_fmaps', None),
            layer_features = kwarg('mapping_fmaps', None),
            activation = kwarg('mapping_nonlinearity', 'lrelu'),
            lr_multiplier = kwarg('mapping_lrmul', 0.01),
            w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
        ),
        synthesis_kwargs = dnnlib.EasyDict(
            channel_base = kwarg('fmap_base', 16384) * 2,
            channel_max = kwarg('fmap_max', 512),
            num_fp16_res = kwarg('num_fp16_res', 0),
            conv_clamp = kwarg('conv_clamp', None),
            architecture = kwarg('architecture', 'skip'),
            resample_filter = kwarg('resample_kernel', [1,3,3,1]),
            use_noise = kwarg('use_noise', True),
            activation = kwarg('nonlinearity', 'lrelu'),
        ),
    )

    # Check for unknown kwargs.
    kwarg('truncation_psi')
    kwarg('truncation_cutoff')
    kwarg('style_mixing_prob')
    kwarg('structure')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_G)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
            kwargs.synthesis_kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    from training import networks
    G = networks.Generator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    _populate_module_params(G,
        r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
        r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
        r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
        r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
        r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
        r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
        r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
        r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
        r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
        r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
        r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
        r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
        r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
        r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
        r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
        r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
        r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
        r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
        r'.*\.resample_filter', None,
    )
    return G

#----------------------------------------------------------------------------

def convert_tf_discriminator(tf_D):
    if tf_D.version < 4:
        raise ValueError('TensorFlow pickle version too low')

    # Collect kwargs.
    tf_kwargs = tf_D.static_kwargs
    known_kwargs = set()
    def kwarg(tf_name, default=None):
        known_kwargs.add(tf_name)
        return tf_kwargs.get(tf_name, default)

    # Convert kwargs.
    kwargs = dnnlib.EasyDict(
        c_dim = kwarg('label_size', 0),
        img_resolution = kwarg('resolution', 1024),
        img_channels = kwarg('num_channels', 3),
        architecture = kwarg('architecture', 'resnet'),
        channel_base = kwarg('fmap_base', 16384) * 2,
        channel_max = kwarg('fmap_max', 512),
        num_fp16_res = kwarg('num_fp16_res', 0),
        conv_clamp = kwarg('conv_clamp', None),
        cmap_dim = kwarg('mapping_fmaps', None),
        block_kwargs = dnnlib.EasyDict(
            activation = kwarg('nonlinearity', 'lrelu'),
            resample_filter = kwarg('resample_kernel', [1,3,3,1]),
            freeze_layers = kwarg('freeze_layers', 0),
        ),
        mapping_kwargs = dnnlib.EasyDict(
            num_layers = kwarg('mapping_layers', 0),
            embed_features = kwarg('mapping_fmaps', None),
            layer_features = kwarg('mapping_fmaps', None),
            activation = kwarg('nonlinearity', 'lrelu'),
            lr_multiplier = kwarg('mapping_lrmul', 0.1),
        ),
        epilogue_kwargs = dnnlib.EasyDict(
            mbstd_group_size = kwarg('mbstd_group_size', None),
            mbstd_num_channels = kwarg('mbstd_num_features', 1),
            activation = kwarg('nonlinearity', 'lrelu'),
        ),
    )

    # Check for unknown kwargs.
    kwarg('structure')
    unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
    if len(unknown_kwargs) > 0:
        raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])

    # Collect params.
    tf_params = _collect_tf_params(tf_D)
    for name, value in list(tf_params.items()):
        match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
        if match:
            r = kwargs.img_resolution // (2 ** int(match.group(1)))
            tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
            kwargs.architecture = 'orig'
    #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')

    # Convert params.
    from training import networks
    D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
    # pylint: disable=unnecessary-lambda
    _populate_module_params(D,
        r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
        r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
        r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
        r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
        r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
        r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
        r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
        r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
        r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
        r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
        r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
        r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
        r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
        r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
        r'.*\.resample_filter', None,
    )
    return D

#----------------------------------------------------------------------------

@click.command()
@click.option('--source', help='Input pickle', required=True, metavar='PATH')
@click.option('--dest', help='Output pickle', required=True, metavar='PATH')
@click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
def convert_network_pickle(source, dest, force_fp16):
    """Convert legacy network pickle into the native PyTorch format.

    The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
    It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.

    Example:

    \b
    python legacy.py \\
        --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
        --dest=stylegan2-cat-config-f.pkl
    """
    print(f'Loading "{source}"...')
    with dnnlib.util.open_url(source) as f:
        data = load_network_pkl(f, force_fp16=force_fp16)
    print(f'Saving "{dest}"...')
    with open(dest, 'wb') as f:
        pickle.dump(data, f)
    print('Done.')

#----------------------------------------------------------------------------

if __name__ == "__main__":
    convert_network_pickle() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------
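A self-contained sketch of the alternating-varargs convention consumed by _populate_module_params() above: even positions are regexes matched against parameter names, odd positions are functions producing the replacement value (None marks a buffer as matched but left untouched). The helper name first_match is hypothetical, introduced only for illustration:

    import re

    def first_match(name, *patterns):
        # Walk the (pattern, value_fn) pairs in order; first full match wins.
        for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
            match = re.fullmatch(pattern, name)
            if match:
                return value_fn(*match.groups()) if value_fn is not None else None

    print(first_match('mapping.fc3.bias', r'mapping\.fc(\d+)\.bias', lambda i: f'Dense{i}/bias'))  # -> Dense3/bias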
metrics/__init__.py
ADDED
@@ -0,0 +1,9 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# empty
metrics/frechet_inception_distance.py
ADDED
@@ -0,0 +1,41 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Frechet Inception Distance (FID) from the paper
"GANs trained by a two time-scale update rule converge to a local Nash
equilibrium". Matches the original implementation by Heusel et al. at
https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""

import numpy as np
import scipy.linalg
from . import metric_utils

#----------------------------------------------------------------------------

def compute_fid(opts, max_real, num_gen):
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
    detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.

    mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()

    mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()

    if opts.rank != 0:
        return float('nan')

    m = np.square(mu_gen - mu_real).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
    fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
    return float(fid)

#----------------------------------------------------------------------------
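For reference, compute_fid() above assembles the standard FID formula from the feature statistics, with m the squared mean distance and s the matrix square root of the covariance product:

$$\mathrm{FID} = \lVert \mu_{gen} - \mu_{real} \rVert_2^2 + \operatorname{Tr}\big(\Sigma_{gen} + \Sigma_{real} - 2\,(\Sigma_{gen}\Sigma_{real})^{1/2}\big)$$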
metrics/inception_score.py
ADDED
@@ -0,0 +1,38 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Inception Score (IS) from the paper "Improved techniques for training
GANs". Matches the original implementation by Salimans et al. at
https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""

import numpy as np
from . import metric_utils

#----------------------------------------------------------------------------

def compute_is(opts, num_gen, num_splits):
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
    detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.

    gen_probs = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        capture_all=True, max_items=num_gen).get_all()

    if opts.rank != 0:
        return float('nan'), float('nan')

    scores = []
    for i in range(num_splits):
        part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
        kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
        kl = np.mean(np.sum(kl, axis=1))
        scores.append(np.exp(kl))
    return float(np.mean(scores)), float(np.std(scores))

#----------------------------------------------------------------------------
metrics/kernel_inception_distance.py
ADDED
@@ -0,0 +1,46 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
GANs". Matches the original implementation by Binkowski et al. at
https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""

import numpy as np
from . import metric_utils

#----------------------------------------------------------------------------

def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
    # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
    detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.

    real_features = metric_utils.compute_feature_stats_for_dataset(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()

    gen_features = metric_utils.compute_feature_stats_for_generator(
        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()

    if opts.rank != 0:
        return float('nan')

    n = real_features.shape[1]
    m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
    t = 0
    for _subset_idx in range(num_subsets):
        x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
        y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
        a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
        b = (x @ y.T / n + 1) ** 3
        t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
    kid = t / num_subsets / m
    return float(kid)

#----------------------------------------------------------------------------
metrics/metric_main.py
ADDED
@@ -0,0 +1,152 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import time
import json
import torch
import dnnlib

from . import metric_utils
from . import frechet_inception_distance
from . import kernel_inception_distance
from . import precision_recall
from . import perceptual_path_length
from . import inception_score

#----------------------------------------------------------------------------

_metric_dict = dict() # name => fn

def register_metric(fn):
    assert callable(fn)
    _metric_dict[fn.__name__] = fn
    return fn

def is_valid_metric(metric):
    return metric in _metric_dict

def list_valid_metrics():
    return list(_metric_dict.keys())

#----------------------------------------------------------------------------

def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
    assert is_valid_metric(metric)
    opts = metric_utils.MetricOptions(**kwargs)

    # Calculate.
    start_time = time.time()
    results = _metric_dict[metric](opts)
    total_time = time.time() - start_time

    # Broadcast results.
    for key, value in list(results.items()):
        if opts.num_gpus > 1:
            value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
            torch.distributed.broadcast(tensor=value, src=0)
            value = float(value.cpu())
        results[key] = value

    # Decorate with metadata.
    return dnnlib.EasyDict(
        results = dnnlib.EasyDict(results),
        metric = metric,
        total_time = total_time,
        total_time_str = dnnlib.util.format_time(total_time),
        num_gpus = opts.num_gpus,
    )

#----------------------------------------------------------------------------

def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
    metric = result_dict['metric']
    assert is_valid_metric(metric)
    if run_dir is not None and snapshot_pkl is not None:
        snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)

    jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
    print(jsonl_line)
    if run_dir is not None and os.path.isdir(run_dir):
        with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
            f.write(jsonl_line + '\n')

#----------------------------------------------------------------------------
# Primary metrics.

@register_metric
def fid50k_full(opts):
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
    return dict(fid50k_full=fid)

@register_metric
def kid50k_full(opts):
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
    return dict(kid50k_full=kid)

@register_metric
def pr50k3_full(opts):
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
    return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)

@register_metric
def ppl2_wend(opts):
    ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
    return dict(ppl2_wend=ppl)

@register_metric
def is50k(opts):
    opts.dataset_kwargs.update(max_size=None, xflip=False)
    mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
    return dict(is50k_mean=mean, is50k_std=std)

#----------------------------------------------------------------------------
# Legacy metrics.

@register_metric
def fid50k(opts):
    opts.dataset_kwargs.update(max_size=None)
    fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
    return dict(fid50k=fid)

@register_metric
def kid50k(opts):
    opts.dataset_kwargs.update(max_size=None)
    kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
    return dict(kid50k=kid)

@register_metric
def pr50k3(opts):
    opts.dataset_kwargs.update(max_size=None)
    precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
    return dict(pr50k3_precision=precision, pr50k3_recall=recall)

@register_metric
def ppl_zfull(opts):
    ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, batch_size=2)
    return dict(ppl_zfull=ppl)

@register_metric
def ppl_wfull(opts):
    ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, batch_size=2)
    return dict(ppl_wfull=ppl)

@register_metric
def ppl_zend(opts):
    ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, batch_size=2)
    return dict(ppl_zend=ppl)

@register_metric
def ppl_wend(opts):
    ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, batch_size=2)
    return dict(ppl_wend=ppl)

#----------------------------------------------------------------------------
metrics/metric_utils.py
ADDED
@@ -0,0 +1,275 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import time
import hashlib
import pickle
import copy
import uuid
import numpy as np
import torch
import dnnlib

#----------------------------------------------------------------------------

class MetricOptions:
    def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
        assert 0 <= rank < num_gpus
        self.G = G
        self.G_kwargs = dnnlib.EasyDict(G_kwargs)
        self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
        self.num_gpus = num_gpus
        self.rank = rank
        self.device = device if device is not None else torch.device('cuda', rank)
        self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
        self.cache = cache

#----------------------------------------------------------------------------

_feature_detector_cache = dict()

def get_feature_detector_name(url):
    return os.path.splitext(url.split('/')[-1])[0]

def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
    assert 0 <= rank < num_gpus
    key = (url, device)
    if key not in _feature_detector_cache:
        is_leader = (rank == 0)
        if not is_leader and num_gpus > 1:
            torch.distributed.barrier() # leader goes first
        with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
            _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
        if is_leader and num_gpus > 1:
            torch.distributed.barrier() # others follow
    return _feature_detector_cache[key]

#----------------------------------------------------------------------------

class FeatureStats:
    def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
        self.capture_all = capture_all
        self.capture_mean_cov = capture_mean_cov
        self.max_items = max_items
        self.num_items = 0
        self.num_features = None
        self.all_features = None
        self.raw_mean = None
        self.raw_cov = None

    def set_num_features(self, num_features):
        if self.num_features is not None:
            assert num_features == self.num_features
        else:
            self.num_features = num_features
            self.all_features = []
            self.raw_mean = np.zeros([num_features], dtype=np.float64)
            self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)

    def is_full(self):
        return (self.max_items is not None) and (self.num_items >= self.max_items)

    def append(self, x):
        x = np.asarray(x, dtype=np.float32)
        assert x.ndim == 2
        if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
            if self.num_items >= self.max_items:
                return
            x = x[:self.max_items - self.num_items]

        self.set_num_features(x.shape[1])
        self.num_items += x.shape[0]
        if self.capture_all:
            self.all_features.append(x)
        if self.capture_mean_cov:
            x64 = x.astype(np.float64)
            self.raw_mean += x64.sum(axis=0)
            self.raw_cov += x64.T @ x64

    def append_torch(self, x, num_gpus=1, rank=0):
        assert isinstance(x, torch.Tensor) and x.ndim == 2
        assert 0 <= rank < num_gpus
        if num_gpus > 1:
            ys = []
            for src in range(num_gpus):
                y = x.clone()
                torch.distributed.broadcast(y, src=src)
                ys.append(y)
            x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
        self.append(x.cpu().numpy())

    def get_all(self):
        assert self.capture_all
        return np.concatenate(self.all_features, axis=0)

    def get_all_torch(self):
        return torch.from_numpy(self.get_all())

    def get_mean_cov(self):
        assert self.capture_mean_cov
        mean = self.raw_mean / self.num_items
        cov = self.raw_cov / self.num_items
        cov = cov - np.outer(mean, mean)
        return mean, cov

    def save(self, pkl_file):
        with open(pkl_file, 'wb') as f:
            pickle.dump(self.__dict__, f)

    @staticmethod
    def load(pkl_file):
        with open(pkl_file, 'rb') as f:
            s = dnnlib.EasyDict(pickle.load(f))
        obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
        obj.__dict__.update(s)
        return obj

#----------------------------------------------------------------------------
+
|
135 |
+
class ProgressMonitor:
|
136 |
+
def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
|
137 |
+
self.tag = tag
|
138 |
+
self.num_items = num_items
|
139 |
+
self.verbose = verbose
|
140 |
+
self.flush_interval = flush_interval
|
141 |
+
self.progress_fn = progress_fn
|
142 |
+
self.pfn_lo = pfn_lo
|
143 |
+
self.pfn_hi = pfn_hi
|
144 |
+
self.pfn_total = pfn_total
|
145 |
+
self.start_time = time.time()
|
146 |
+
self.batch_time = self.start_time
|
147 |
+
self.batch_items = 0
|
148 |
+
if self.progress_fn is not None:
|
149 |
+
self.progress_fn(self.pfn_lo, self.pfn_total)
|
150 |
+
|
151 |
+
def update(self, cur_items):
|
152 |
+
assert (self.num_items is None) or (cur_items <= self.num_items)
|
153 |
+
if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
|
154 |
+
return
|
155 |
+
cur_time = time.time()
|
156 |
+
total_time = cur_time - self.start_time
|
157 |
+
time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
|
158 |
+
if (self.verbose) and (self.tag is not None):
|
159 |
+
print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
|
160 |
+
self.batch_time = cur_time
|
161 |
+
self.batch_items = cur_items
|
162 |
+
|
163 |
+
if (self.progress_fn is not None) and (self.num_items is not None):
|
164 |
+
self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
|
165 |
+
|
166 |
+
def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
|
167 |
+
return ProgressMonitor(
|
168 |
+
tag = tag,
|
169 |
+
num_items = num_items,
|
170 |
+
flush_interval = flush_interval,
|
171 |
+
verbose = self.verbose,
|
172 |
+
progress_fn = self.progress_fn,
|
173 |
+
pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
|
174 |
+
pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
|
175 |
+
pfn_total = self.pfn_total,
|
176 |
+
)
|
177 |
+
|
178 |
+
#----------------------------------------------------------------------------
|
179 |
+
|
180 |
+
def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
|
181 |
+
dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
|
182 |
+
if data_loader_kwargs is None:
|
183 |
+
data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
|
184 |
+
|
185 |
+
# Try to lookup from cache.
|
186 |
+
cache_file = None
|
187 |
+
if opts.cache:
|
188 |
+
# Choose cache file name.
|
189 |
+
args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
|
190 |
+
md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
|
191 |
+
cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
|
192 |
+
cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
|
193 |
+
|
194 |
+
# Check if the file exists (all processes must agree).
|
195 |
+
flag = os.path.isfile(cache_file) if opts.rank == 0 else False
|
196 |
+
if opts.num_gpus > 1:
|
197 |
+
flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
|
198 |
+
torch.distributed.broadcast(tensor=flag, src=0)
|
199 |
+
flag = (float(flag.cpu()) != 0)
|
200 |
+
|
201 |
+
# Load.
|
202 |
+
if flag:
|
203 |
+
return FeatureStats.load(cache_file)
|
204 |
+
|
205 |
+
# Initialize.
|
206 |
+
num_items = len(dataset)
|
207 |
+
if max_items is not None:
|
208 |
+
num_items = min(num_items, max_items)
|
209 |
+
stats = FeatureStats(max_items=num_items, **stats_kwargs)
|
210 |
+
progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
|
211 |
+
detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
|
212 |
+
|
213 |
+
# Main loop.
|
214 |
+
item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
|
215 |
+
for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
|
216 |
+
if images.shape[1] == 1:
|
217 |
+
images = images.repeat([1, 3, 1, 1])
|
218 |
+
features = detector(images.to(opts.device), **detector_kwargs)
|
219 |
+
stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
|
220 |
+
progress.update(stats.num_items)
|
221 |
+
|
222 |
+
# Save to cache.
|
223 |
+
if cache_file is not None and opts.rank == 0:
|
224 |
+
os.makedirs(os.path.dirname(cache_file), exist_ok=True)
|
225 |
+
temp_file = cache_file + '.' + uuid.uuid4().hex
|
226 |
+
stats.save(temp_file)
|
227 |
+
os.replace(temp_file, cache_file) # atomic
|
228 |
+
return stats
|
229 |
+
|
230 |
+
#----------------------------------------------------------------------------
|
231 |
+
|
232 |
+
def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, **stats_kwargs):
|
233 |
+
if batch_gen is None:
|
234 |
+
batch_gen = min(batch_size, 4)
|
235 |
+
assert batch_size % batch_gen == 0
|
236 |
+
|
237 |
+
# Setup generator and load labels.
|
238 |
+
G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
|
239 |
+
dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
|
240 |
+
|
241 |
+
# Image generation func.
|
242 |
+
def run_generator(z, c):
|
243 |
+
img = G(z=z, c=c, **opts.G_kwargs)
|
244 |
+
img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
|
245 |
+
return img
|
246 |
+
|
247 |
+
# JIT.
|
248 |
+
if jit:
|
249 |
+
z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
|
250 |
+
c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
|
251 |
+
run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False)
|
252 |
+
|
253 |
+
# Initialize.
|
254 |
+
stats = FeatureStats(**stats_kwargs)
|
255 |
+
assert stats.max_items is not None
|
256 |
+
progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
|
257 |
+
detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
|
258 |
+
|
259 |
+
# Main loop.
|
260 |
+
while not stats.is_full():
|
261 |
+
images = []
|
262 |
+
for _i in range(batch_size // batch_gen):
|
263 |
+
z = torch.randn([batch_gen, G.z_dim], device=opts.device)
|
264 |
+
c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)]
|
265 |
+
c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
|
266 |
+
images.append(run_generator(z, c))
|
267 |
+
images = torch.cat(images)
|
268 |
+
if images.shape[1] == 1:
|
269 |
+
images = images.repeat([1, 3, 1, 1])
|
270 |
+
features = detector(images, **detector_kwargs)
|
271 |
+
stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
|
272 |
+
progress.update(stats.num_items)
|
273 |
+
return stats
|
274 |
+
|
275 |
+
#----------------------------------------------------------------------------
|
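The item_subset expression in compute_feature_stats_for_dataset() above shards the dataset round-robin across ranks, wrapping at the end so every rank processes the same number of items; a self-contained illustration:

    num_items, num_gpus = 8, 3
    for rank in range(num_gpus):
        print(rank, [(i * num_gpus + rank) % num_items for i in range((num_items - 1) // num_gpus + 1)])
    # 0 [0, 3, 6]
    # 1 [1, 4, 7]
    # 2 [2, 5, 0]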
metrics/perceptual_path_length.py
ADDED
@@ -0,0 +1,131 @@
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator
|
10 |
+
Architecture for Generative Adversarial Networks". Matches the original
|
11 |
+
implementation by Karras et al. at
|
12 |
+
https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
|
13 |
+
|
14 |
+
import copy
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
import dnnlib
|
18 |
+
from . import metric_utils
|
19 |
+
|
20 |
+
#----------------------------------------------------------------------------
|
21 |
+
|
22 |
+
# Spherical interpolation of a batch of vectors.
|
23 |
+
def slerp(a, b, t):
|
24 |
+
a = a / a.norm(dim=-1, keepdim=True)
|
25 |
+
b = b / b.norm(dim=-1, keepdim=True)
|
26 |
+
d = (a * b).sum(dim=-1, keepdim=True)
|
27 |
+
p = t * torch.acos(d)
|
28 |
+
c = b - d * a
|
29 |
+
c = c / c.norm(dim=-1, keepdim=True)
|
30 |
+
d = a * torch.cos(p) + c * torch.sin(p)
|
31 |
+
d = d / d.norm(dim=-1, keepdim=True)
|
32 |
+
return d
|
33 |
+
|
34 |
+
#----------------------------------------------------------------------------
|
35 |
+
|
36 |
+
class PPLSampler(torch.nn.Module):
|
37 |
+
def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
|
38 |
+
assert space in ['z', 'w']
|
39 |
+
assert sampling in ['full', 'end']
|
40 |
+
super().__init__()
|
41 |
+
self.G = copy.deepcopy(G)
|
42 |
+
self.G_kwargs = G_kwargs
|
43 |
+
self.epsilon = epsilon
|
44 |
+
self.space = space
|
45 |
+
self.sampling = sampling
|
46 |
+
self.crop = crop
|
47 |
+
self.vgg16 = copy.deepcopy(vgg16)
|
48 |
+
|
49 |
+
def forward(self, c):
|
50 |
+
# Generate random latents and interpolation t-values.
|
51 |
+
t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
|
52 |
+
z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
|
53 |
+
|
54 |
+
# Interpolate in W or Z.
|
55 |
+
if self.space == 'w':
|
56 |
+
w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
|
57 |
+
wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
|
58 |
+
wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
|
59 |
+
else: # space == 'z'
|
60 |
+
zt0 = slerp(z0, z1, t.unsqueeze(1))
|
61 |
+
zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
|
62 |
+
wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
|
63 |
+
|
64 |
+
# Randomize noise buffers.
|
65 |
+
for name, buf in self.G.named_buffers():
|
66 |
+
if name.endswith('.noise_const'):
|
67 |
+
buf.copy_(torch.randn_like(buf))
|
68 |
+
|
69 |
+
# Generate images.
|
70 |
+
img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
|
71 |
+
|
72 |
+
# Center crop.
|
73 |
+
if self.crop:
|
74 |
+
assert img.shape[2] == img.shape[3]
|
75 |
+
c = img.shape[2] // 8
|
76 |
+
img = img[:, :, c*3 : c*7, c*2 : c*6]
|
77 |
+
|
78 |
+
# Downsample to 256x256.
|
79 |
+
factor = self.G.img_resolution // 256
|
80 |
+
if factor > 1:
|
81 |
+
img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
|
82 |
+
|
83 |
+
# Scale dynamic range from [-1,1] to [0,255].
|
84 |
+
img = (img + 1) * (255 / 2)
|
85 |
+
if self.G.img_channels == 1:
|
86 |
+
img = img.repeat([1, 3, 1, 1])
|
87 |
+
|
88 |
+
# Evaluate differential LPIPS.
|
89 |
+
lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
|
90 |
+
dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
|
91 |
+
return dist
|
92 |
+
|
93 |
+
#----------------------------------------------------------------------------
|
94 |
+
|
95 |
+
def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size, jit=False):
|
96 |
+
dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
|
97 |
+
vgg16_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
|
98 |
+
vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
|
99 |
+
|
100 |
+
# Setup sampler.
|
101 |
+
sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
|
102 |
+
sampler.eval().requires_grad_(False).to(opts.device)
|
103 |
+
if jit:
|
104 |
+
c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
|
105 |
+
sampler = torch.jit.trace(sampler, [c], check_trace=False)
|
106 |
+
|
107 |
+
# Sampling loop.
|
108 |
+
dist = []
|
109 |
+
progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
|
110 |
+
for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
|
111 |
+
progress.update(batch_start)
|
112 |
+
c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
|
113 |
+
c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
|
114 |
+
x = sampler(c)
|
115 |
+
for src in range(opts.num_gpus):
|
116 |
+
y = x.clone()
|
117 |
+
if opts.num_gpus > 1:
|
118 |
+
torch.distributed.broadcast(y, src=src)
|
119 |
+
dist.append(y)
|
120 |
+
progress.update(num_samples)
|
121 |
+
|
122 |
+
# Compute PPL.
|
123 |
+
if opts.rank != 0:
|
124 |
+
return float('nan')
|
125 |
+
dist = torch.cat(dist)[:num_samples].cpu().numpy()
|
126 |
+
lo = np.percentile(dist, 1, interpolation='lower')
|
127 |
+
hi = np.percentile(dist, 99, interpolation='higher')
|
128 |
+
ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
|
129 |
+
return float(ppl)
|
130 |
+
|
131 |
+
#----------------------------------------------------------------------------
|
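For reference, slerp() above reduces to the normalized endpoints at t=0 and t=1, which makes for a quick sanity check. A minimal sketch, assuming the module is importable from the repo root (illustration only, not part of this commit):

import torch
from metrics.perceptual_path_length import slerp

a = torch.randn(4, 512)
b = torch.randn(4, 512)
# t=0 returns a (normalized), t=1 returns b (normalized).
assert torch.allclose(slerp(a, b, torch.zeros(4, 1)), a / a.norm(dim=-1, keepdim=True), atol=1e-4)
assert torch.allclose(slerp(a, b, torch.ones(4, 1)), b / b.norm(dim=-1, keepdim=True), atol=1e-4)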
metrics/precision_recall.py
ADDED
@@ -0,0 +1,62 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Precision/Recall (PR) from the paper "Improved Precision and Recall
+Metric for Assessing Generative Models". Matches the original implementation
+by Kynkaanniemi et al. at
+https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
+
+import torch
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
+    assert 0 <= rank < num_gpus
+    num_cols = col_features.shape[0]
+    num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
+    col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
+    dist_batches = []
+    for col_batch in col_batches[rank :: num_gpus]:
+        dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
+        for src in range(num_gpus):
+            dist_broadcast = dist_batch.clone()
+            if num_gpus > 1:
+                torch.distributed.broadcast(dist_broadcast, src=src)
+            dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
+    return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
+
+#----------------------------------------------------------------------------
+
+def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
+    detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
+    detector_kwargs = dict(return_features=True)
+
+    real_features = metric_utils.compute_feature_stats_for_dataset(
+        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+        rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
+
+    gen_features = metric_utils.compute_feature_stats_for_generator(
+        opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+        rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
+
+    results = dict()
+    for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
+        kth = []
+        for manifold_batch in manifold.split(row_batch_size):
+            dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
+            kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
+        kth = torch.cat(kth) if opts.rank == 0 else None
+        pred = []
+        for probes_batch in probes.split(row_batch_size):
+            dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
+            pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
+        results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
+    return results['precision'], results['recall']
+
+#----------------------------------------------------------------------------
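For reference, the manifold test used by compute_pr() above counts a probe as "inside" the manifold if it lies within the distance from some reference point to that point's (nhood_size+1)-th nearest reference neighbor (the +1 skips the point itself, which sits at distance zero). A tiny single-GPU sketch with hypothetical feature vectors (illustration only, not part of this commit):

import torch

manifold = torch.randn(100, 8)  # hypothetical reference features
probes = torch.randn(20, 8)     # hypothetical probe features
nhood_size = 3
d_ref = torch.cdist(manifold, manifold)
kth = d_ref.kthvalue(nhood_size + 1, dim=1).values       # per-point neighborhood radius
inside = (torch.cdist(probes, manifold) <= kth).any(dim=1)
print(inside.float().mean())                              # fraction of probes on the manifold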
projector.py
ADDED
@@ -0,0 +1,212 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Project given image to the latent space of pretrained network pickle."""
+
+import copy
+import os
+from time import perf_counter
+
+import click
+import imageio
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+
+import dnnlib
+import legacy
+
+def project(
+    G,
+    target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
+    *,
+    num_steps = 1000,
+    w_avg_samples = 10000,
+    initial_learning_rate = 0.1,
+    initial_noise_factor = 0.05,
+    lr_rampdown_length = 0.25,
+    lr_rampup_length = 0.05,
+    noise_ramp_length = 0.75,
+    regularize_noise_weight = 1e5,
+    verbose = False,
+    device: torch.device
+):
+    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
+
+    def logprint(*args):
+        if verbose:
+            print(*args)
+
+    G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore
+
+    # Compute w stats.
+    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
+    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
+    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None) # [N, L, C]
+    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C]
+    w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C]
+    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
+
+    # Setup noise inputs.
+    noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }
+
+    # Load VGG16 feature detector.
+    url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
+    with dnnlib.util.open_url(url) as f:
+        vgg16 = torch.jit.load(f).eval().to(device)
+
+    # Features for target image.
+    target_images = target.unsqueeze(0).to(device).to(torch.float32)
+    if target_images.shape[2] > 256:
+        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
+    target_features = vgg16(target_images, resize_images=False, return_lpips=True)
+
+    w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
+    w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
+    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)
+
+    # Init noise.
+    for buf in noise_bufs.values():
+        buf[:] = torch.randn_like(buf)
+        buf.requires_grad = True
+
+    for step in range(num_steps):
+        # Learning rate schedule.
+        t = step / num_steps
+        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
+        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
+        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
+        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
+        lr = initial_learning_rate * lr_ramp
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = lr
+
+        # Synth images from opt_w.
+        w_noise = torch.randn_like(w_opt) * w_noise_scale
+        ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
+        synth_images = G.synthesis(ws, noise_mode='const')
+
+        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
+        synth_images = (synth_images + 1) * (255/2)
+        if synth_images.shape[2] > 256:
+            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
+
+        # Features for synth images.
+        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
+        dist = (target_features - synth_features).square().sum()
+
+        # Noise regularization.
+        reg_loss = 0.0
+        for v in noise_bufs.values():
+            noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
+            while True:
+                reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
+                reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
+                if noise.shape[2] <= 8:
+                    break
+                noise = F.avg_pool2d(noise, kernel_size=2)
+        loss = dist + reg_loss * regularize_noise_weight
+
+        # Step
+        optimizer.zero_grad(set_to_none=True)
+        loss.backward()
+        optimizer.step()
+        logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')
+
+        # Save projected W for each optimization step.
+        w_out[step] = w_opt.detach()[0]
+
+        # Normalize noise.
+        with torch.no_grad():
+            for buf in noise_bufs.values():
+                buf -= buf.mean()
+                buf *= buf.square().mean().rsqrt()
+
+    return w_out.repeat([1, G.mapping.num_ws, 1])
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--target', 'target_fname', help='Target image file to project to', required=True, metavar='FILE')
+@click.option('--num-steps', help='Number of optimization steps', type=int, default=1000, show_default=True)
+@click.option('--seed', help='Random seed', type=int, default=303, show_default=True)
+@click.option('--save-video', help='Save an mp4 video of optimization progress', type=bool, default=True, show_default=True)
+@click.option('--outdir', help='Where to save the output images', required=True, metavar='DIR')
+def run_projection(
+    network_pkl: str,
+    target_fname: str,
+    outdir: str,
+    save_video: bool,
+    seed: int,
+    num_steps: int
+):
+    """Project given image to the latent space of pretrained network pickle.
+
+    Examples:
+
+    \b
+    python projector.py --outdir=out --target=~/mytargetimg.png \\
+        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
+    """
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+    # Load networks.
+    print('Loading networks from "%s"...' % network_pkl)
+    device = torch.device('cuda')
+    with dnnlib.util.open_url(network_pkl) as fp:
+        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore
+
+    # Load target image.
+    target_pil = PIL.Image.open(target_fname).convert('RGB')
+    w, h = target_pil.size
+    s = min(w, h)
+    target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
+    target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
+    target_uint8 = np.array(target_pil, dtype=np.uint8)
+
+    # Optimize projection.
+    start_time = perf_counter()
+    projected_w_steps = project(
+        G,
+        target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device), # pylint: disable=not-callable
+        num_steps=num_steps,
+        device=device,
+        verbose=True
+    )
+    print (f'Elapsed: {(perf_counter()-start_time):.1f} s')
+
+    # Render debug output: optional video and projected image and W vector.
+    os.makedirs(outdir, exist_ok=True)
+    if save_video:
+        video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
+        print (f'Saving optimization progress video "{outdir}/proj.mp4"')
+        for projected_w in projected_w_steps:
+            synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
+            synth_image = (synth_image + 1) * (255/2)
+            synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
+            video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
+        video.close()
+
+    # Save final projected frame and W vector.
+    target_pil.save(f'{outdir}/target.png')
+    projected_w = projected_w_steps[-1]
+    synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
+    synth_image = (synth_image + 1) * (255/2)
+    synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
+    PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png')
+    np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+    run_projection() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
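For reference, the projector's learning-rate schedule combines a cosine ramp-down over the last 25% of steps with a linear ramp-up over the first 5%. A standalone sketch of the same arithmetic, with defaults mirroring project() above (illustration only, not part of this commit):

import numpy as np

def lr_at(step, num_steps=1000, initial_lr=0.1, rampdown=0.25, rampup=0.05):
    t = step / num_steps
    ramp = min(1.0, (1.0 - t) / rampdown)     # cosine ramp-down near the end
    ramp = 0.5 - 0.5 * np.cos(ramp * np.pi)
    ramp = ramp * min(1.0, t / rampup)        # linear ramp-up at the start
    return initial_lr * ramp

print([round(lr_at(s), 4) for s in (0, 25, 500, 900, 1000)])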
style_mixing.py
ADDED
@@ -0,0 +1,118 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Generate style mixing image matrix using pretrained network pickle."""
+
+import os
+import re
+from typing import List
+
+import click
+import dnnlib
+import numpy as np
+import PIL.Image
+import torch
+
+import legacy
+
+#----------------------------------------------------------------------------
+
+def num_range(s: str) -> List[int]:
+    '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
+
+    range_re = re.compile(r'^(\d+)-(\d+)$')
+    m = range_re.match(s)
+    if m:
+        return list(range(int(m.group(1)), int(m.group(2))+1))
+    vals = s.split(',')
+    return [int(x) for x in vals]
+
+#----------------------------------------------------------------------------
+
+@click.command()
+@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
+@click.option('--rows', 'row_seeds', type=num_range, help='Random seeds to use for image rows', required=True)
+@click.option('--cols', 'col_seeds', type=num_range, help='Random seeds to use for image columns', required=True)
+@click.option('--styles', 'col_styles', type=num_range, help='Style layer range', default='0-6', show_default=True)
+@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
+@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
+@click.option('--outdir', type=str, required=True)
+def generate_style_mix(
+    network_pkl: str,
+    row_seeds: List[int],
+    col_seeds: List[int],
+    col_styles: List[int],
+    truncation_psi: float,
+    noise_mode: str,
+    outdir: str
+):
+    """Generate images using pretrained network pickle.
+
+    Examples:
+
+    \b
+    python style_mixing.py --outdir=out --rows=85,100,75,458,1500 --cols=55,821,1789,293 \\
+        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
+    """
+    print('Loading networks from "%s"...' % network_pkl)
+    device = torch.device('cuda')
+    with dnnlib.util.open_url(network_pkl) as f:
+        G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
+
+    os.makedirs(outdir, exist_ok=True)
+
+    print('Generating W vectors...')
+    all_seeds = list(set(row_seeds + col_seeds))
+    all_z = np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])
+    all_w = G.mapping(torch.from_numpy(all_z).to(device), None)
+    w_avg = G.mapping.w_avg
+    all_w = w_avg + (all_w - w_avg) * truncation_psi
+    w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))}
+
+    print('Generating images...')
+    all_images = G.synthesis(all_w, noise_mode=noise_mode)
+    all_images = (all_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
+    image_dict = {(seed, seed): image for seed, image in zip(all_seeds, list(all_images))}
+
+    print('Generating style-mixed images...')
+    for row_seed in row_seeds:
+        for col_seed in col_seeds:
+            w = w_dict[row_seed].clone()
+            w[col_styles] = w_dict[col_seed][col_styles]
+            image = G.synthesis(w[np.newaxis], noise_mode=noise_mode)
+            image = (image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+            image_dict[(row_seed, col_seed)] = image[0].cpu().numpy()
+
+    print('Saving images...')
+    os.makedirs(outdir, exist_ok=True)
+    for (row_seed, col_seed), image in image_dict.items():
+        PIL.Image.fromarray(image, 'RGB').save(f'{outdir}/{row_seed}-{col_seed}.png')
+
+    print('Saving image grid...')
+    W = G.img_resolution
+    H = G.img_resolution
+    canvas = PIL.Image.new('RGB', (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), 'black')
+    for row_idx, row_seed in enumerate([0] + row_seeds):
+        for col_idx, col_seed in enumerate([0] + col_seeds):
+            if row_idx == 0 and col_idx == 0:
+                continue
+            key = (row_seed, col_seed)
+            if row_idx == 0:
+                key = (col_seed, col_seed)
+            if col_idx == 0:
+                key = (row_seed, row_seed)
+            canvas.paste(PIL.Image.fromarray(image_dict[key], 'RGB'), (W * col_idx, H * row_idx))
+    canvas.save(f'{outdir}/grid.png')
+
+
+#----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+    generate_style_mix() # pylint: disable=no-value-for-parameter
+
+#----------------------------------------------------------------------------
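For reference, the num_range() helper backing the --rows/--cols/--styles options accepts either an explicit comma list or an inclusive dash range. Assuming it is imported from the file above (illustration only, not part of this commit):

from style_mixing import num_range

print(num_range('0-6'))        # [0, 1, 2, 3, 4, 5, 6]
print(num_range('85,100,75'))  # [85, 100, 75]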
torch_utils/__init__.py
ADDED
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
torch_utils/custom_ops.py
ADDED
@@ -0,0 +1,126 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import os
+import glob
+import torch
+import torch.utils.cpp_extension
+import importlib
+import hashlib
+import shutil
+from pathlib import Path
+
+from torch.utils.file_baton import FileBaton
+
+#----------------------------------------------------------------------------
+# Global options.
+
+verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
+
+#----------------------------------------------------------------------------
+# Internal helper funcs.
+
+def _find_compiler_bindir():
+    patterns = [
+        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
+        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
+        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
+        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
+    ]
+    for pattern in patterns:
+        matches = sorted(glob.glob(pattern))
+        if len(matches):
+            return matches[-1]
+    return None
+
+#----------------------------------------------------------------------------
+# Main entry point for compiling and loading C++/CUDA plugins.
+
+_cached_plugins = dict()
+
+def get_plugin(module_name, sources, **build_kwargs):
+    assert verbosity in ['none', 'brief', 'full']
+
+    # Already cached?
+    if module_name in _cached_plugins:
+        return _cached_plugins[module_name]
+
+    # Print status.
+    if verbosity == 'full':
+        print(f'Setting up PyTorch plugin "{module_name}"...')
+    elif verbosity == 'brief':
+        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
+
+    try: # pylint: disable=too-many-nested-blocks
+        # Make sure we can find the necessary compiler binaries.
+        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
+            compiler_bindir = _find_compiler_bindir()
+            if compiler_bindir is None:
+                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
+            os.environ['PATH'] += ';' + compiler_bindir
+
+        # Compile and load.
+        verbose_build = (verbosity == 'full')
+
+        # Incremental build md5sum trickery. Copies all the input source files
+        # into a cached build directory under a combined md5 digest of the input
+        # source files. Copying is done only if the combined digest has changed.
+        # This keeps input file timestamps and filenames the same as in previous
+        # extension builds, allowing for fast incremental rebuilds.
+        #
+        # This optimization is done only in case all the source files reside in
+        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
+        # environment variable is set (we take this as a signal that the user
+        # actually cares about this.)
+        source_dirs_set = set(os.path.dirname(source) for source in sources)
+        if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
+            all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
+
+            # Compute a combined hash digest for all source files in the same
+            # custom op directory (usually .cu, .cpp, .py and .h files).
+            hash_md5 = hashlib.md5()
+            for src in all_source_files:
+                with open(src, 'rb') as f:
+                    hash_md5.update(f.read())
+            build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
+            digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
+
+            if not os.path.isdir(digest_build_dir):
+                os.makedirs(digest_build_dir, exist_ok=True)
+                baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
+                if baton.try_acquire():
+                    try:
+                        for src in all_source_files:
+                            shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
+                    finally:
+                        baton.release()
+                else:
+                    # Someone else is copying source files under the digest dir,
+                    # wait until done and continue.
+                    baton.wait()
+            digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
+            torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
+                verbose=verbose_build, sources=digest_sources, **build_kwargs)
+        else:
+            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
+        module = importlib.import_module(module_name)
+
+    except:
+        if verbosity == 'brief':
+            print('Failed!')
+        raise
+
+    # Print status and add to cache.
+    if verbosity == 'full':
+        print(f'Done setting up PyTorch plugin "{module_name}".')
+    elif verbosity == 'brief':
+        print('Done.')
+    _cached_plugins[module_name] = module
+    return module
+
+#----------------------------------------------------------------------------
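For reference, the CUDA ops elsewhere in this commit load their kernels through get_plugin() above. A minimal sketch of that call pattern, with a hypothetical source directory layout and assuming an NVCC toolchain is available at runtime (illustration only, not part of this commit):

import os
from torch_utils import custom_ops

src_dir = 'torch_utils/ops'  # hypothetical: run from the repo root
plugin = custom_ops.get_plugin(
    module_name='bias_act_plugin',
    sources=[os.path.join(src_dir, f) for f in ('bias_act.cpp', 'bias_act.cu')],
    extra_cuda_cflags=['--use_fast_math'],  # forwarded to torch.utils.cpp_extension.load()
)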
torch_utils/misc.py
ADDED
@@ -0,0 +1,262 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import re
+import contextlib
+import numpy as np
+import torch
+import warnings
+import dnnlib
+
+#----------------------------------------------------------------------------
+# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
+# same constant is used multiple times.
+
+_constant_cache = dict()
+
+def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+    value = np.asarray(value)
+    if shape is not None:
+        shape = tuple(shape)
+    if dtype is None:
+        dtype = torch.get_default_dtype()
+    if device is None:
+        device = torch.device('cpu')
+    if memory_format is None:
+        memory_format = torch.contiguous_format
+
+    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
+    tensor = _constant_cache.get(key, None)
+    if tensor is None:
+        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+        if shape is not None:
+            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+        tensor = tensor.contiguous(memory_format=memory_format)
+        _constant_cache[key] = tensor
+    return tensor
+
+#----------------------------------------------------------------------------
+# Replace NaN/Inf with specified numerical values.
+
+try:
+    nan_to_num = torch.nan_to_num # 1.8.0a0
+except AttributeError:
+    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
+        assert isinstance(input, torch.Tensor)
+        if posinf is None:
+            posinf = torch.finfo(input.dtype).max
+        if neginf is None:
+            neginf = torch.finfo(input.dtype).min
+        assert nan == 0
+        return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
+
+#----------------------------------------------------------------------------
+# Symbolic assert.
+
+try:
+    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
+except AttributeError:
+    symbolic_assert = torch.Assert # 1.7.0
+
+#----------------------------------------------------------------------------
+# Context manager to suppress known warnings in torch.jit.trace().
+
+class suppress_tracer_warnings(warnings.catch_warnings):
+    def __enter__(self):
+        super().__enter__()
+        warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
+        return self
+
+#----------------------------------------------------------------------------
+# Assert that the shape of a tensor matches the given list of integers.
+# None indicates that the size of a dimension is allowed to vary.
+# Performs symbolic assertion when used in torch.jit.trace().
+
+def assert_shape(tensor, ref_shape):
+    if tensor.ndim != len(ref_shape):
+        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
+    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
+        if ref_size is None:
+            pass
+        elif isinstance(ref_size, torch.Tensor):
+            with suppress_tracer_warnings(): # as_tensor results are registered as constants
+                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
+        elif isinstance(size, torch.Tensor):
+            with suppress_tracer_warnings(): # as_tensor results are registered as constants
+                symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
+        elif size != ref_size:
+            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
+
+#----------------------------------------------------------------------------
+# Function decorator that calls torch.autograd.profiler.record_function().
+
+def profiled_function(fn):
+    def decorator(*args, **kwargs):
+        with torch.autograd.profiler.record_function(fn.__name__):
+            return fn(*args, **kwargs)
+    decorator.__name__ = fn.__name__
+    return decorator
+
+#----------------------------------------------------------------------------
+# Sampler for torch.utils.data.DataLoader that loops over the dataset
+# indefinitely, shuffling items as it goes.
+
+class InfiniteSampler(torch.utils.data.Sampler):
+    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
+        assert len(dataset) > 0
+        assert num_replicas > 0
+        assert 0 <= rank < num_replicas
+        assert 0 <= window_size <= 1
+        super().__init__(dataset)
+        self.dataset = dataset
+        self.rank = rank
+        self.num_replicas = num_replicas
+        self.shuffle = shuffle
+        self.seed = seed
+        self.window_size = window_size
+
+    def __iter__(self):
+        order = np.arange(len(self.dataset))
+        rnd = None
+        window = 0
+        if self.shuffle:
+            rnd = np.random.RandomState(self.seed)
+            rnd.shuffle(order)
+            window = int(np.rint(order.size * self.window_size))
+
+        idx = 0
+        while True:
+            i = idx % order.size
+            if idx % self.num_replicas == self.rank:
+                yield order[i]
+                if window >= 2:
+                    j = (i - rnd.randint(window)) % order.size
+                    order[i], order[j] = order[j], order[i]
+            idx += 1
+
+#----------------------------------------------------------------------------
+# Utilities for operating with torch.nn.Module parameters and buffers.
+
+def params_and_buffers(module):
+    assert isinstance(module, torch.nn.Module)
+    return list(module.parameters()) + list(module.buffers())
+
+def named_params_and_buffers(module):
+    assert isinstance(module, torch.nn.Module)
+    return list(module.named_parameters()) + list(module.named_buffers())
+
+def copy_params_and_buffers(src_module, dst_module, require_all=False):
+    assert isinstance(src_module, torch.nn.Module)
+    assert isinstance(dst_module, torch.nn.Module)
+    src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
+    for name, tensor in named_params_and_buffers(dst_module):
+        assert (name in src_tensors) or (not require_all)
+        if name in src_tensors:
+            tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
+
+#----------------------------------------------------------------------------
+# Context manager for easily enabling/disabling DistributedDataParallel
+# synchronization.
+
+@contextlib.contextmanager
+def ddp_sync(module, sync):
+    assert isinstance(module, torch.nn.Module)
+    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
+        yield
+    else:
+        with module.no_sync():
+            yield
+
+#----------------------------------------------------------------------------
+# Check DistributedDataParallel consistency across processes.
+
+def check_ddp_consistency(module, ignore_regex=None):
+    assert isinstance(module, torch.nn.Module)
+    for name, tensor in named_params_and_buffers(module):
+        fullname = type(module).__name__ + '.' + name
+        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
+            continue
+        tensor = tensor.detach()
+        other = tensor.clone()
+        torch.distributed.broadcast(tensor=other, src=0)
+        assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
+
+#----------------------------------------------------------------------------
+# Print summary table of module hierarchy.
+
+def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
+    assert isinstance(module, torch.nn.Module)
+    assert not isinstance(module, torch.jit.ScriptModule)
+    assert isinstance(inputs, (tuple, list))
+
+    # Register hooks.
+    entries = []
+    nesting = [0]
+    def pre_hook(_mod, _inputs):
+        nesting[0] += 1
+    def post_hook(mod, _inputs, outputs):
+        nesting[0] -= 1
+        if nesting[0] <= max_nesting:
+            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
+            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
+            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
+    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
+    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
+
+    # Run module.
+    outputs = module(*inputs)
+    for hook in hooks:
+        hook.remove()
+
+    # Identify unique outputs, parameters, and buffers.
+    tensors_seen = set()
+    for e in entries:
+        e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
+        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
+        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
+        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
+
+    # Filter out redundant entries.
+    if skip_redundant:
+        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
+
+    # Construct table.
+    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
+    rows += [['---'] * len(rows[0])]
+    param_total = 0
+    buffer_total = 0
+    submodule_names = {mod: name for name, mod in module.named_modules()}
+    for e in entries:
+        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
+        param_size = sum(t.numel() for t in e.unique_params)
+        buffer_size = sum(t.numel() for t in e.unique_buffers)
+        output_shapes = [str(list(t.shape)) for t in e.outputs] # report each output's own shape
+        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+        rows += [[
+            name + (':0' if len(e.outputs) >= 2 else ''),
+            str(param_size) if param_size else '-',
+            str(buffer_size) if buffer_size else '-',
+            (output_shapes + ['-'])[0],
+            (output_dtypes + ['-'])[0],
+        ]]
+        for idx in range(1, len(e.outputs)):
+            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
+        param_total += param_size
+        buffer_total += buffer_size
+    rows += [['---'] * len(rows[0])]
+    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
+
+    # Print table.
+    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+    print()
+    for row in rows:
+        print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
+    print()
+    return outputs
+
+#----------------------------------------------------------------------------
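For reference, InfiniteSampler above never raises StopIteration, so training code can keep pulling batches without ever restarting the DataLoader. A small single-process sketch (illustration only, not part of this commit):

import itertools
import torch
from torch_utils.misc import InfiniteSampler

ds = torch.utils.data.TensorDataset(torch.arange(6))
sampler = InfiniteSampler(ds, rank=0, num_replicas=1, seed=0)
print(list(itertools.islice(iter(sampler), 10)))  # ten indices; iteration never ends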
torch_utils/ops/__init__.py
ADDED
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+# empty
torch_utils/ops/bias_act.cpp
ADDED
@@ -0,0 +1,99 @@
+// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+//
+// NVIDIA CORPORATION and its licensors retain all intellectual property
+// and proprietary rights in and to this software, related documentation
+// and any modifications thereto. Any use, reproduction, disclosure or
+// distribution of this software and related documentation without an express
+// license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+
+static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+{
+    if (x.dim() != y.dim())
+        return false;
+    for (int64_t i = 0; i < x.dim(); i++)
+    {
+        if (x.size(i) != y.size(i))
+            return false;
+        if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+            return false;
+    }
+    return true;
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+{
+    // Validate arguments.
+    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+    TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+    TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+    TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+    TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
+    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+    TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+    TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+    TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+    TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+    // Validate layout.
+    TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+    TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+    TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+    TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+    TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+    // Create output tensor.
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+    torch::Tensor y = torch::empty_like(x);
+    TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+    // Initialize CUDA kernel parameters.
+    bias_act_kernel_params p;
+    p.x = x.data_ptr();
+    p.b = (b.numel()) ? b.data_ptr() : NULL;
+    p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
+    p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
+    p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
+    p.y = y.data_ptr();
+    p.grad = grad;
+    p.act = act;
+    p.alpha = alpha;
+    p.gain = gain;
+    p.clamp = clamp;
+    p.sizeX = (int)x.numel();
+    p.sizeB = (int)b.numel();
+    p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+    // Choose CUDA kernel.
+    void* kernel;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+    {
+        kernel = choose_bias_act_kernel<scalar_t>(p);
+    });
+    TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+    // Launch CUDA kernel.
+    p.loopX = 4;
+    int blockSize = 4 * 32;
+    int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
+    void* args[] = {&p};
+    AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+    return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.def("bias_act", &bias_act);
+}
+
+//------------------------------------------------------------------------
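For reference, the integer `act` codes validated above are dispatched to templated kernels in bias_act.cu below; the kernel comments label them as follows, collected here for convenience (illustration only, not part of this commit):

# Activation codes as labeled by the bias_act kernels.
ACT_CODES = {1: 'linear', 2: 'relu', 3: 'lrelu', 4: 'tanh',
             5: 'sigmoid', 6: 'elu', 7: 'selu', 8: 'softplus', 9: 'swish'}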
torch_utils/ops/bias_act.cu
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <c10/util/Half.h>
#include "bias_act.h"

//------------------------------------------------------------------------
// Helpers.

template <class T> struct InternalType;
template <> struct InternalType<double>    { typedef double scalar_t; };
template <> struct InternalType<float>     { typedef float  scalar_t; };
template <> struct InternalType<c10::Half> { typedef float  scalar_t; };

//------------------------------------------------------------------------
// CUDA kernel.

template <class T, int A>
__global__ void bias_act_kernel(bias_act_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;
    int G                 = p.grad;
    scalar_t alpha        = (scalar_t)p.alpha;
    scalar_t gain         = (scalar_t)p.gain;
    scalar_t clamp        = (scalar_t)p.clamp;
    scalar_t one          = (scalar_t)1;
    scalar_t two          = (scalar_t)2;
    scalar_t expRange     = (scalar_t)80;
    scalar_t halfExpRange = (scalar_t)40;
    scalar_t seluScale    = (scalar_t)1.0507009873554804934193349852946;
    scalar_t seluAlpha    = (scalar_t)1.6732632423543772848170429916717;

    // Loop over elements.
    int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
    for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
    {
        // Load.
        scalar_t x = (scalar_t)((const T*)p.x)[xi];
        scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
        scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
        scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
        scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
        scalar_t yy = (gain != 0) ? yref / gain : 0;
        scalar_t y = 0;

        // Apply bias.
        ((G == 0) ? x : xref) += b;

        // linear
        if (A == 1)
        {
            if (G == 0) y = x;
            if (G == 1) y = x;
        }

        // relu
        if (A == 2)
        {
            if (G == 0) y = (x > 0) ? x : 0;
            if (G == 1) y = (yy > 0) ? x : 0;
        }

        // lrelu
        if (A == 3)
        {
            if (G == 0) y = (x > 0) ? x : x * alpha;
            if (G == 1) y = (yy > 0) ? x : x * alpha;
        }

        // tanh
        if (A == 4)
        {
            if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
            if (G == 1) y = x * (one - yy * yy);
            if (G == 2) y = x * (one - yy * yy) * (-two * yy);
        }

        // sigmoid
        if (A == 5)
        {
            if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
            if (G == 1) y = x * yy * (one - yy);
            if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
        }

        // elu
        if (A == 6)
        {
            if (G == 0) y = (x >= 0) ? x : exp(x) - one;
            if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
        }

        // selu
        if (A == 7)
        {
            if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
            if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
        }

        // softplus
        if (A == 8)
        {
            if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
            if (G == 1) y = x * (one - exp(-yy));
            if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
        }

        // swish
        if (A == 9)
        {
            if (G == 0)
                y = (x < -expRange) ? 0 : x / (exp(-x) + one);
            else
            {
                scalar_t c = exp(xref);
                scalar_t d = c + one;
                if (G == 1)
                    y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
                else
                    y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
                yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
            }
        }

        // Apply gain.
        y *= gain * dy;

        // Clamp.
        if (clamp >= 0)
        {
            if (G == 0)
                y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
            else
                y = (yref > -clamp & yref < clamp) ? y : 0;
        }

        // Store.
        ((T*)p.y)[xi] = (T)y;
    }
}

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
{
    if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
    if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
    if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
    if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
    if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
    if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
    if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
    if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
    if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
    return NULL;
}

//------------------------------------------------------------------------
// Template specializations.

template void* choose_bias_act_kernel<double>    (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<float>     (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);

//------------------------------------------------------------------------
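
The G == 1 and G == 2 branches above express the first and second derivative of each activation purely in terms of the saved reference values (xref, yref), so the backward kernels never recompute the forward pass. A minimal PyTorch sketch (not part of this commit) checking one such identity, d/dx tanh(x) = 1 - tanh(x)^2, which is what the tanh G == 1 branch relies on:

    import torch

    x = torch.linspace(-3, 3, 101, dtype=torch.float64, requires_grad=True)
    y = torch.tanh(x)
    (dy_dx,) = torch.autograd.grad(y.sum(), x)
    assert torch.allclose(dy_dx, 1 - y.detach() ** 2)  # gradient from output only
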
torch_utils/ops/bias_act.h
ADDED
@@ -0,0 +1,38 @@
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

//------------------------------------------------------------------------
// CUDA kernel parameters.

struct bias_act_kernel_params
{
    const void* x;      // [sizeX]
    const void* b;      // [sizeB] or NULL
    const void* xref;   // [sizeX] or NULL
    const void* yref;   // [sizeX] or NULL
    const void* dy;     // [sizeX] or NULL
    void*       y;      // [sizeX]

    int         grad;
    int         act;
    float       alpha;
    float       gain;
    float       clamp;

    int         sizeX;
    int         sizeB;
    int         stepB;
    int         loopX;
};

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);

//------------------------------------------------------------------------
torch_utils/ops/bias_act.py
ADDED
@@ -0,0 +1,212 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom PyTorch ops for efficient bias and activation."""

import os
import warnings
import numpy as np
import torch
import dnnlib
import traceback

from .. import custom_ops
from .. import misc

#----------------------------------------------------------------------------

activation_funcs = {
    'linear':   dnnlib.EasyDict(func=lambda x, **_:        x,                                         def_alpha=0,   def_gain=1,          cuda_idx=1, ref='',  has_2nd_grad=False),
    'relu':     dnnlib.EasyDict(func=lambda x, **_:        torch.nn.functional.relu(x),               def_alpha=0,   def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
    'lrelu':    dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),  def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
    'tanh':     dnnlib.EasyDict(func=lambda x, **_:        torch.tanh(x),                             def_alpha=0,   def_gain=1,          cuda_idx=4, ref='y', has_2nd_grad=True),
    'sigmoid':  dnnlib.EasyDict(func=lambda x, **_:        torch.sigmoid(x),                          def_alpha=0,   def_gain=1,          cuda_idx=5, ref='y', has_2nd_grad=True),
    'elu':      dnnlib.EasyDict(func=lambda x, **_:        torch.nn.functional.elu(x),                def_alpha=0,   def_gain=1,          cuda_idx=6, ref='y', has_2nd_grad=True),
    'selu':     dnnlib.EasyDict(func=lambda x, **_:        torch.nn.functional.selu(x),               def_alpha=0,   def_gain=1,          cuda_idx=7, ref='y', has_2nd_grad=True),
    'softplus': dnnlib.EasyDict(func=lambda x, **_:        torch.nn.functional.softplus(x),           def_alpha=0,   def_gain=1,          cuda_idx=8, ref='y', has_2nd_grad=True),
    'swish':    dnnlib.EasyDict(func=lambda x, **_:        torch.sigmoid(x) * x,                      def_alpha=0,   def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
}

#----------------------------------------------------------------------------

_inited = False
_plugin = None
_null_tensor = torch.empty([0])

def _init():
    global _inited, _plugin
    if not _inited:
        _inited = True
        sources = ['bias_act.cpp', 'bias_act.cu']
        sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
        try:
            _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
        except:
            warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
    return _plugin is not None

#----------------------------------------------------------------------------

def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
    r"""Fused bias and activation function.

    Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
    and scales the result by `gain`. Each of the steps is optional. In most cases,
    the fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports first and second order gradients,
    but not third order gradients.

    Args:
        x:      Input activation tensor. Can be of any shape.
        b:      Bias vector, or `None` to disable. Must be a 1D tensor of the same type
                as `x`. The shape must be known, and it must match the dimension of `x`
                corresponding to `dim`.
        dim:    The dimension in `x` corresponding to the elements of `b`.
                The value of `dim` is ignored if `b` is not specified.
        act:    Name of the activation function to evaluate, or `"linear"` to disable.
                Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
                See `activation_funcs` for a full list. `None` is not allowed.
        alpha:  Shape parameter for the activation function, or `None` to use the default.
        gain:   Scaling factor for the output tensor, or `None` to use the default.
                See `activation_funcs` for the default scaling of each activation function.
                If unsure, consider specifying 1.
        clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
                the clamping (default).
        impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the same shape and datatype as `x`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)

#----------------------------------------------------------------------------

@misc.profiled_function
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Slow reference implementation of `bias_act()` using standard PyTorch ops.
    """
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Add bias.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])

    # Evaluate activation function.
    alpha = float(alpha)
    x = spec.func(x, alpha=alpha)

    # Scale by gain.
    gain = float(gain)
    if gain != 1:
        x = x * gain

    # Clamp.
    if clamp >= 0:
        x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
    return x

#----------------------------------------------------------------------------

_bias_act_cuda_cache = dict()

def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Fast CUDA implementation of `bias_act()` using custom ops.
    """
    # Parse arguments.
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Lookup from cache.
    key = (dim, act, alpha, gain, clamp)
    if key in _bias_act_cuda_cache:
        return _bias_act_cuda_cache[key]

    # Forward op.
    class BiasActCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, b): # pylint: disable=arguments-differ
            ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
            x = x.contiguous(memory_format=ctx.memory_format)
            b = b.contiguous() if b is not None else _null_tensor
            y = x
            if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
                y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
            ctx.save_for_backward(
                x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                y if 'y' in spec.ref else _null_tensor)
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            dy = dy.contiguous(memory_format=ctx.memory_format)
            x, b, y = ctx.saved_tensors
            dx = None
            db = None

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                dx = dy
                if act != 'linear' or gain != 1 or clamp >= 0:
                    dx = BiasActCudaGrad.apply(dy, x, b, y)

            if ctx.needs_input_grad[1]:
                db = dx.sum([i for i in range(dx.ndim) if i != dim])

            return dx, db

    # Backward op.
    class BiasActCudaGrad(torch.autograd.Function):
        @staticmethod
        def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
            ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
            dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
            ctx.save_for_backward(
                dy if spec.has_2nd_grad else _null_tensor,
                x, b, y)
            return dx

        @staticmethod
        def backward(ctx, d_dx): # pylint: disable=arguments-differ
            d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
            dy, x, b, y = ctx.saved_tensors
            d_dy = None
            d_x = None
            d_b = None
            d_y = None

            if ctx.needs_input_grad[0]:
                d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)

            if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
                d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)

            if spec.has_2nd_grad and ctx.needs_input_grad[2]:
                d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])

            return d_dy, d_x, d_b, d_y

    # Add to cache.
    _bias_act_cuda_cache[key] = BiasActCuda
    return BiasActCuda

#----------------------------------------------------------------------------
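
A minimal usage sketch (not part of this commit): fused bias + leaky ReLU on a [batch, channels] tensor, with one bias element per channel (dim=1). impl='ref' exercises _bias_act_ref() and works without the CUDA plugin; on a CUDA tensor, impl='cuda' dispatches to the compiled kernel instead.

    import torch
    from torch_utils.ops import bias_act

    x = torch.randn(4, 512)              # [batch, channels]
    b = torch.zeros(512)                 # bias vector, matches x.shape[dim]
    y = bias_act.bias_act(x, b, dim=1, act='lrelu', impl='ref')
    print(y.shape)                       # torch.Size([4, 512])

Note that 'lrelu' applies its default gain of sqrt(2) from activation_funcs; pass gain=1 to disable the scaling.
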
torch_utils/ops/conv2d_gradfix.py
ADDED
@@ -0,0 +1,170 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom replacement for `torch.nn.functional.conv2d` that supports
arbitrarily high order gradients with zero performance penalty."""

import warnings
import contextlib
import torch

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access

#----------------------------------------------------------------------------

enabled = False                     # Enable the custom op by setting this to true.
weight_gradients_disabled = False   # Forcefully disable computation of gradients with respect to the weights.

@contextlib.contextmanager
def no_weight_gradients():
    global weight_gradients_disabled
    old = weight_gradients_disabled
    weight_gradients_disabled = True
    yield
    weight_gradients_disabled = old

#----------------------------------------------------------------------------

def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    if _should_use_custom_op(input):
        return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
    return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)

def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    if _should_use_custom_op(input):
        return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
    return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)

#----------------------------------------------------------------------------

def _should_use_custom_op(input):
    assert isinstance(input, torch.Tensor)
    if (not enabled) or (not torch.backends.cudnn.enabled):
        return False
    if input.device.type != 'cuda':
        return False
    if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
        return True
    warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
    return False

def _tuple_of_ints(xs, ndim):
    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
    assert len(xs) == ndim
    assert all(isinstance(x, int) for x in xs)
    return xs

#----------------------------------------------------------------------------

_conv2d_gradfix_cache = dict()

def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
    # Parse arguments.
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = _tuple_of_ints(stride, ndim)
    padding = _tuple_of_ints(padding, ndim)
    output_padding = _tuple_of_ints(output_padding, ndim)
    dilation = _tuple_of_ints(dilation, ndim)

    # Lookup from cache.
    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in _conv2d_gradfix_cache:
        return _conv2d_gradfix_cache[key]

    # Validate arguments.
    assert groups >= 1
    assert len(weight_shape) == ndim + 2
    assert all(stride[i] >= 1 for i in range(ndim))
    assert all(padding[i] >= 0 for i in range(ndim))
    assert all(dilation[i] >= 0 for i in range(ndim))
    if not transpose:
        assert all(output_padding[i] == 0 for i in range(ndim))
    else: # transpose
        assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))

    # Helpers.
    common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
    def calc_output_padding(input_shape, output_shape):
        if transpose:
            return [0, 0]
        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

    # Forward & backward.
    class Conv2d(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias):
            assert weight.shape == weight_shape
            if not transpose:
                output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
            else: # transpose
                output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
            ctx.save_for_backward(input, weight)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            grad_input = None
            grad_weight = None
            grad_bias = None

            if ctx.needs_input_grad[0]:
                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
                grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
                assert grad_input.shape == input.shape

            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input)
                assert grad_weight.shape == weight_shape

            if ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum([0, 2, 3])

            return grad_input, grad_weight, grad_bias

    # Gradient with respect to the weights.
    class Conv2dGradWeight(torch.autograd.Function):
        @staticmethod
        def forward(ctx, grad_output, input):
            op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
            flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
            grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
            assert grad_weight.shape == weight_shape
            ctx.save_for_backward(grad_output, input)
            return grad_weight

        @staticmethod
        def backward(ctx, grad2_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad2_grad_output = None
            grad2_input = None

            if ctx.needs_input_grad[0]:
                grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
                assert grad2_grad_output.shape == grad_output.shape

            if ctx.needs_input_grad[1]:
                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
                grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
                assert grad2_input.shape == input.shape

            return grad2_grad_output, grad2_input

    _conv2d_gradfix_cache[key] = Conv2d
    return Conv2d

#----------------------------------------------------------------------------
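
A minimal usage sketch (not part of this commit). The op is opt-in via the `enabled` flag and only takes effect for CUDA tensors on the PyTorch versions checked in _should_use_custom_op(); everywhere else it silently falls back to torch.nn.functional.conv2d(), where no_weight_gradients() has no effect. The training code uses no_weight_gradients() to skip dL/dW where it is not needed.

    import torch
    from torch_utils.ops import conv2d_gradfix

    conv2d_gradfix.enabled = True
    x = torch.randn(1, 3, 8, 8, requires_grad=True)
    w = torch.randn(4, 3, 3, 3, requires_grad=True)
    with conv2d_gradfix.no_weight_gradients():
        y = conv2d_gradfix.conv2d(x, w, padding=1)
    y.sum().backward()
    # On a supported CUDA setup: x.grad is populated, w.grad stays None.
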
torch_utils/ops/conv2d_resample.py
ADDED
@@ -0,0 +1,156 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""2D convolution with optional up/downsampling."""

import torch

from .. import misc
from . import conv2d_gradfix
from . import upfirdn2d
from .upfirdn2d import _parse_padding
from .upfirdn2d import _get_filter_size

#----------------------------------------------------------------------------

def _get_weight_shape(w):
    with misc.suppress_tracer_warnings(): # this value will be treated as a constant
        shape = [int(sz) for sz in w.shape]
    misc.assert_shape(w, shape)
    return shape

#----------------------------------------------------------------------------

def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
    """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
    """
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)

    # Flip weight if requested.
    if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
        w = w.flip([2, 3])

    # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
    # 1x1 kernel + memory_format=channels_last + less than 64 channels.
    if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
        if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
            if out_channels <= 4 and groups == 1:
                in_shape = x.shape
                x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
                x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
            else:
                x = x.to(memory_format=torch.contiguous_format)
                w = w.to(memory_format=torch.contiguous_format)
                x = conv2d_gradfix.conv2d(x, w, groups=groups)
            return x.to(memory_format=torch.channels_last)

    # Otherwise => execute using conv2d_gradfix.
    op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
    return op(x, w, stride=stride, padding=padding, groups=groups)

#----------------------------------------------------------------------------

@misc.profiled_function
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
    r"""2D convolution with optional up/downsampling.

    Padding is performed only once at the beginning, not between the operations.

    Args:
        x:              Input tensor of shape
                        `[batch_size, in_channels, in_height, in_width]`.
        w:              Weight tensor of shape
                        `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
        f:              Low-pass filter for up/downsampling. Must be prepared beforehand by
                        calling upfirdn2d.setup_filter(). None = identity (default).
        up:             Integer upsampling factor (default: 1).
        down:           Integer downsampling factor (default: 1).
        padding:        Padding with respect to the upsampled image. Can be a single number
                        or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                        (default: 0).
        groups:         Split input channels into N groups (default: 1).
        flip_weight:    False = convolution, True = correlation (default: True).
        flip_filter:    False = convolution, True = correlation (default: False).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and (x.ndim == 4)
    assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
    assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
    assert isinstance(up, int) and (up >= 1)
    assert isinstance(down, int) and (down >= 1)
    assert isinstance(groups, int) and (groups >= 1)
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
    fw, fh = _get_filter_size(f)
    px0, px1, py0, py1 = _parse_padding(padding)

    # Adjust padding to account for up/downsampling.
    if up > 1:
        px0 += (fw + up - 1) // 2
        px1 += (fw - up) // 2
        py0 += (fh + up - 1) // 2
        py1 += (fh - up) // 2
    if down > 1:
        px0 += (fw - down + 1) // 2
        px1 += (fw - down) // 2
        py0 += (fh - down + 1) // 2
        py1 += (fh - down) // 2

    # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
    if kw == 1 and kh == 1 and (down > 1 and up == 1):
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
    if kw == 1 and kh == 1 and (up > 1 and down == 1):
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
        return x

    # Fast path: downsampling only => use strided convolution.
    if down > 1 and up == 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: upsampling with optional downsampling => use transpose strided convolution.
    if up > 1:
        if groups == 1:
            w = w.transpose(0, 1)
        else:
            w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
            w = w.transpose(1, 2)
            w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
        px0 -= kw - 1
        px1 -= kw - up
        py0 -= kh - 1
        py1 -= kh - up
        pxt = max(min(-px0, -px1), 0)
        pyt = max(min(-py0, -py1), 0)
        x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
        if down > 1:
            x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
        return x

    # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
    if up == 1 and down == 1:
        if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
            return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)

    # Fallback: Generic reference implementation.
    x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
    x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
    if down > 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
    return x

#----------------------------------------------------------------------------
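
A minimal usage sketch (not part of this commit): a 3x3 convolution fused with 2x upsampling, using the separable [1, 3, 3, 1] low-pass filter prepared by upfirdn2d.setup_filter() as the docstring requires. With up=2 and padding=kernel_size//2, the spatial resolution doubles.

    import torch
    from torch_utils.ops import upfirdn2d, conv2d_resample

    x = torch.randn(1, 8, 16, 16)            # [batch_size, in_channels, H, W]
    w = torch.randn(4, 8, 3, 3)              # [out_channels, in_channels, kh, kw]
    f = upfirdn2d.setup_filter([1, 3, 3, 1]) # normalized low-pass filter
    y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
    print(y.shape)                           # torch.Size([1, 4, 32, 32])
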
torch_utils/ops/fma.py
ADDED
@@ -0,0 +1,60 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""

import torch

#----------------------------------------------------------------------------

def fma(a, b, c): # => a * b + c
    return _FusedMultiplyAdd.apply(a, b, c)

#----------------------------------------------------------------------------

class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
    @staticmethod
    def forward(ctx, a, b, c): # pylint: disable=arguments-differ
        out = torch.addcmul(c, a, b)
        ctx.save_for_backward(a, b)
        ctx.c_shape = c.shape
        return out

    @staticmethod
    def backward(ctx, dout): # pylint: disable=arguments-differ
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None

        if ctx.needs_input_grad[0]:
            da = _unbroadcast(dout * b, a.shape)

        if ctx.needs_input_grad[1]:
            db = _unbroadcast(dout * a, b.shape)

        if ctx.needs_input_grad[2]:
            dc = _unbroadcast(dout, c_shape)

        return da, db, dc

#----------------------------------------------------------------------------

def _unbroadcast(x, shape):
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
    if len(dim):
        x = x.sum(dim=dim, keepdim=True)
    if extra_dims:
        x = x.reshape(-1, *x.shape[extra_dims+1:])
    assert x.shape == shape
    return x

#----------------------------------------------------------------------------
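
A minimal sketch (not part of this commit): fma(a, b, c) matches a * b + c under broadcasting, and _unbroadcast() sums the incoming gradient back down to each operand's original shape.

    import torch
    from torch_utils.ops import fma

    a = torch.randn(4, 1, 8, requires_grad=True)
    b = torch.randn(1, 6, 8, requires_grad=True)
    c = torch.randn(8, requires_grad=True)
    out = fma.fma(a, b, c)                   # broadcasts to [4, 6, 8]
    assert torch.allclose(out, a * b + c)
    out.sum().backward()
    print(a.grad.shape, b.grad.shape, c.grad.shape)  # original shapes preserved
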
torch_utils/ops/grid_sample_gradfix.py
ADDED
@@ -0,0 +1,83 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom replacement for `torch.nn.functional.grid_sample` that
supports arbitrarily high order gradients between the input and output.
Only works on 2D images and assumes
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""

import warnings
import torch

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access

#----------------------------------------------------------------------------

enabled = False  # Enable the custom op by setting this to true.

#----------------------------------------------------------------------------

def grid_sample(input, grid):
    if _should_use_custom_op():
        return _GridSample2dForward.apply(input, grid)
    return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)

#----------------------------------------------------------------------------

def _should_use_custom_op():
    if not enabled:
        return False
    if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
        return True
    warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
    return False

#----------------------------------------------------------------------------

class _GridSample2dForward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, grid):
        assert input.ndim == 4
        assert grid.ndim == 4
        output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
        ctx.save_for_backward(input, grid)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
        return grad_input, grad_grid

#----------------------------------------------------------------------------

class _GridSample2dBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, grad_output, input, grid):
        op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
        grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
        ctx.save_for_backward(grid)
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        _ = grad2_grad_grid # unused
        grid, = ctx.saved_tensors
        grad2_grad_output = None
        grad2_input = None
        grad2_grid = None

        if ctx.needs_input_grad[0]:
            grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)

        assert not ctx.needs_input_grad[2]
        return grad2_grad_output, grad2_input, grad2_grid

#----------------------------------------------------------------------------
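
A minimal usage sketch (not part of this commit): sampling with an identity affine grid reproduces the input, since bilinear interpolation at exact pixel centers is exact. With `enabled = False` (the default) this is plain torch.nn.functional.grid_sample() under the only settings the custom op supports (bilinear, zeros, align_corners=False).

    import torch
    from torch_utils.ops import grid_sample_gradfix

    x = torch.randn(1, 3, 8, 8)
    theta = torch.eye(2, 3).unsqueeze(0)     # identity affine transform
    grid = torch.nn.functional.affine_grid(theta, list(x.shape), align_corners=False)
    y = grid_sample_gradfix.grid_sample(x, grid)
    assert torch.allclose(y, x, atol=1e-6)
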
torch_utils/ops/upfirdn2d.cpp
ADDED
@@ -0,0 +1,103 @@
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "upfirdn2d.h"

//------------------------------------------------------------------------

static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
{
    // Validate arguments.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
    TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
    TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
    TORCH_CHECK(f.dim() == 2, "f must be rank 2");
    TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
    TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
    TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");

    // Create output tensor.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
    int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
    TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
    torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
    TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");

    // Initialize CUDA kernel parameters.
    upfirdn2d_kernel_params p;
    p.x             = x.data_ptr();
    p.f             = f.data_ptr<float>();
    p.y             = y.data_ptr();
    p.up            = make_int2(upx, upy);
    p.down          = make_int2(downx, downy);
    p.pad0          = make_int2(padx0, pady0);
    p.flip          = (flip) ? 1 : 0;
    p.gain          = gain;
    p.inSize        = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
    p.inStride      = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
    p.filterSize    = make_int2((int)f.size(1), (int)f.size(0));
    p.filterStride  = make_int2((int)f.stride(1), (int)f.stride(0));
    p.outSize       = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
    p.outStride     = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
    p.sizeMajor     = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
    p.sizeMinor     = (p.inStride.z == 1) ? p.inSize.z : 1;

    // Choose CUDA kernel.
    upfirdn2d_kernel_spec spec;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
    {
        spec = choose_upfirdn2d_kernel<scalar_t>(p);
    });

    // Set looping options.
    p.loopMajor     = (p.sizeMajor - 1) / 16384 + 1;
    p.loopMinor     = spec.loopMinor;
    p.loopX         = spec.loopX;
    p.launchMinor   = (p.sizeMinor - 1) / p.loopMinor + 1;
    p.launchMajor   = (p.sizeMajor - 1) / p.loopMajor + 1;

    // Compute grid size.
    dim3 blockSize, gridSize;
    if (spec.tileOutW < 0) // large
    {
        blockSize = dim3(4, 32, 1);
        gridSize = dim3(
            ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
            (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
            p.launchMajor);
    }
    else // small
    {
        blockSize = dim3(256, 1, 1);
        gridSize = dim3(
            ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
            (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
            p.launchMajor);
    }

    // Launch CUDA kernel.
    void* args[] = {&p};
    AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
    return y;
}

//------------------------------------------------------------------------

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("upfirdn2d", &upfirdn2d);
}

//------------------------------------------------------------------------
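
A worked check (not part of this commit) of the output-size arithmetic above: outW = (inW * upx + padx0 + padx1 - filterW + downx) / downx with integer division, and likewise for outH.

    # Hypothetical helper mirroring the C++ formula.
    def upfirdn2d_out_size(in_size, up, pad0, pad1, filter_size, down):
        return (in_size * up + pad0 + pad1 - filter_size + down) // down

    # 2x upsampling of a 16-pixel row with a 4-tap filter and pad (2, 1):
    print(upfirdn2d_out_size(16, up=2, pad0=2, pad1=1, filter_size=4, down=1))  # 32
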
torch_utils/ops/upfirdn2d.cu
ADDED
@@ -0,0 +1,350 @@
1 |
+
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
//
|
3 |
+
// NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
// and proprietary rights in and to this software, related documentation
|
5 |
+
// and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
// distribution of this software and related documentation without an express
|
7 |
+
// license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
#include <c10/util/Half.h>
|
10 |
+
#include "upfirdn2d.h"
|
11 |
+
|
12 |
+
//------------------------------------------------------------------------
|
13 |
+
// Helpers.
|
14 |
+
|
15 |
+
template <class T> struct InternalType;
|
16 |
+
template <> struct InternalType<double> { typedef double scalar_t; };
|
17 |
+
template <> struct InternalType<float> { typedef float scalar_t; };
|
18 |
+
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
19 |
+
|
20 |
+
static __device__ __forceinline__ int floor_div(int a, int b)
|
21 |
+
{
|
22 |
+
int t = 1 - a / b;
|
23 |
+
return (a + t * b) / b - t;
|
24 |
+
}
|
25 |
+
|
26 |
+
//------------------------------------------------------------------------
|
27 |
+
// Generic CUDA implementation for large filters.
|
28 |
+
|
29 |
+
template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
|
30 |
+
{
|
31 |
+
typedef typename InternalType<T>::scalar_t scalar_t;
|
32 |
+
|
33 |
+
// Calculate thread index.
|
34 |
+
int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
|
35 |
+
int outY = minorBase / p.launchMinor;
|
36 |
+
minorBase -= outY * p.launchMinor;
|
37 |
+
int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
|
38 |
+
int majorBase = blockIdx.z * p.loopMajor;
|
39 |
+
if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
|
40 |
+
return;
|
41 |
+
|
42 |
+
// Setup Y receptive field.
|
43 |
+
int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
|
44 |
+
int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
|
45 |
+
int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
|
46 |
+
int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
|
47 |
+
if (p.flip)
|
48 |
+
filterY = p.filterSize.y - 1 - filterY;
|
49 |
+
|
50 |
+
// Loop over major, minor, and X.
|
51 |
+
for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
|
52 |
+
for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
|
53 |
+
{
|
54 |
+
int nc = major * p.sizeMinor + minor;
|
55 |
+
int n = nc / p.inSize.z;
|
56 |
+
int c = nc - n * p.inSize.z;
|
57 |
+
for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
|
58 |
+
{
|
59 |
+
// Setup X receptive field.
|
60 |
+
int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
|
61 |
+
int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
|
62 |
+
int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
|
63 |
+
int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
|
64 |
+
if (p.flip)
|
65 |
+
filterX = p.filterSize.x - 1 - filterX;
|
66 |
+
|
67 |
+
// Initialize pointers.
|
68 |
+
const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
|
69 |
+
const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
|
70 |
+
int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
|
71 |
+
int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
|
72 |
+
|
73 |
+
// Inner loop.
|
74 |
+
scalar_t v = 0;
|
75 |
+
for (int y = 0; y < h; y++)
|
76 |
+
{
|
77 |
+
for (int x = 0; x < w; x++)
|
78 |
+
{
|
79 |
+
v += (scalar_t)(*xp) * (scalar_t)(*fp);
|
80 |
+
xp += p.inStride.x;
|
81 |
+
fp += filterStepX;
|
82 |
+
}
|
83 |
+
xp += p.inStride.y - w * p.inStride.x;
|
84 |
+
fp += filterStepY - w * filterStepX;
|
85 |
+
}
|
86 |
+
|
87 |
+
// Store result.
|
88 |
+
v *= p.gain;
|
89 |
+
((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
|
90 |
+
}
|
91 |
+
}
|
92 |
+
}
|
93 |
+
|
94 |
+
//------------------------------------------------------------------------
|
95 |
+
// Specialized CUDA implementation for small filters.
|
96 |
+
|
97 |
+
template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
|
98 |
+
static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;
    const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
    const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
    __shared__ volatile scalar_t sf[filterH][filterW];
    __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];

    // Calculate tile index.
    int minorBase = blockIdx.x;
    int tileOutY = minorBase / p.launchMinor;
    minorBase -= tileOutY * p.launchMinor;
    minorBase *= loopMinor;
    tileOutY *= tileOutH;
    int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
    int majorBase = blockIdx.z * p.loopMajor;
    if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
        return;

    // Load filter (flipped).
    for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
    {
        int fy = tapIdx / filterW;
        int fx = tapIdx - fy * filterW;
        scalar_t v = 0;
        if (fx < p.filterSize.x & fy < p.filterSize.y)
        {
            int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
            int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
            v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
        }
        sf[fy][fx] = v;
    }

    // Loop over major and X.
    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
    {
        int baseNC = major * p.sizeMinor + minorBase;
        int n = baseNC / p.inSize.z;
        int baseC = baseNC - n * p.inSize.z;
        for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
        {
            // Load input pixels.
            int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
            int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
            int tileInX = floor_div(tileMidX, upx);
            int tileInY = floor_div(tileMidY, upy);
            __syncthreads();
            for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
            {
                int relC = inIdx;
                int relInX = relC / loopMinor;
                int relInY = relInX / tileInW;
                relC -= relInX * loopMinor;
                relInX -= relInY * tileInW;
                int c = baseC + relC;
                int inX = tileInX + relInX;
                int inY = tileInY + relInY;
                scalar_t v = 0;
                if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
                    v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
                sx[relInY][relInX][relC] = v;
            }

            // Loop over output pixels.
            __syncthreads();
            for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
            {
                int relC = outIdx;
                int relOutX = relC / loopMinor;
                int relOutY = relOutX / tileOutW;
                relC -= relOutX * loopMinor;
                relOutX -= relOutY * tileOutW;
                int c = baseC + relC;
                int outX = tileOutX + relOutX;
                int outY = tileOutY + relOutY;

                // Setup receptive field.
                int midX = tileMidX + relOutX * downx;
                int midY = tileMidY + relOutY * downy;
                int inX = floor_div(midX, upx);
                int inY = floor_div(midY, upy);
                int relInX = inX - tileInX;
                int relInY = inY - tileInY;
                int filterX = (inX + 1) * upx - midX - 1; // flipped
                int filterY = (inY + 1) * upy - midY - 1; // flipped

                // Inner loop.
                if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
                {
                    scalar_t v = 0;
                    #pragma unroll
                    for (int y = 0; y < filterH / upy; y++)
                        #pragma unroll
                        for (int x = 0; x < filterW / upx; x++)
                            v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
                    v *= p.gain;
                    ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
                }
            }
        }
    }
}

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
{
    int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;

    upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
    if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last

    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
    {
        if (fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7,  64,16,1>, 64,16,1, 1};
        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6,  64,16,1>, 64,16,1, 1};
        if (fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5,  64,16,1>, 64,16,1, 1};
        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4,  64,16,1>, 64,16,1, 1};
        if (fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3,  64,16,1>, 64,16,1, 1};
        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1,  128,8,1>, 128,8,1, 1};
        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8,  32,32,1>, 32,32,1, 1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
    {
        if (fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7,  16,16,8>, 16,16,8, 1};
        if (fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6,  16,16,8>, 16,16,8, 1};
        if (fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5,  16,16,8>, 16,16,8, 1};
        if (fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4,  16,16,8>, 16,16,8, 1};
        if (fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3,  16,16,8>, 16,16,8, 1};
        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1,  128,1,16>, 128,1,16, 1};
        if (fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
        if (fx <= 1  && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
        if (fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
        if (fx <= 1  && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
        if (fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8,  1,128,16>, 1,128,16, 1};
    }
    if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
    {
        if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
        if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
    }
    if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
    {
        if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
        if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
    }
    if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
    {
        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1,  128,8,1>, 128,8,1, 1};
    }
    if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
    {
        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1,  128,1,16>, 128,1,16, 1};
    }
    if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
    {
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8,  32,32,1>, 32,32,1, 1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
    {
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8,  1,128,16>, 1,128,16, 1};
    }
    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
    {
        if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
        if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
    {
        if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
        if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
        if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
        if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
    }
    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
    {
        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1,  64,8,1>, 64,8,1, 1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
    {
        if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
        if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
        if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
        if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
        if (fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1,  64,1,8>, 64,1,8, 1};
    }
    if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
    {
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8,  32,16,1>, 32,16,1, 1};
    }
    if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
    {
        if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
        if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
        if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
        if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
        if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8,  1,64,8>, 1,64,8, 1};
    }
    return spec;
}

//------------------------------------------------------------------------
// Template specializations.

template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>   (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>    (const upfirdn2d_kernel_params& p);
template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);

//------------------------------------------------------------------------
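The selection logic above is a fall-through cascade: every matching `if` overwrites `spec`, so the most specialized small kernel that still covers the filter wins, and `upfirdn2d_kernel_large` survives only as the general fallback. As a minimal sketch of the shared-memory tile arithmetic the small kernel relies on (plain Python; the parameter names mirror the CUDA template parameters and are otherwise assumptions for illustration):

def tile_in_size(tile_out, up, down, filter_taps):
    # Input pixels needed for tile_out outputs, matching
    # tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1.
    return ((tile_out - 1) * down + filter_taps - 1) // up + 1

# E.g. the 2x-upsampling spec <T, 2,2, 1,1, 8,8, 16,16,8> stages a
# 12-pixel-wide input tile per 16 output pixels:
print(tile_in_size(16, up=2, down=1, filter_taps=8))  # -> 12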
torch_utils/ops/upfirdn2d.h
ADDED
@@ -0,0 +1,59 @@
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <cuda_runtime.h>

//------------------------------------------------------------------------
// CUDA kernel parameters.

struct upfirdn2d_kernel_params
{
    const void* x;
    const float* f;
    void* y;

    int2 up;
    int2 down;
    int2 pad0;
    int flip;
    float gain;

    int4 inSize;       // [width, height, channel, batch]
    int4 inStride;
    int2 filterSize;   // [width, height]
    int2 filterStride;
    int4 outSize;      // [width, height, channel, batch]
    int4 outStride;
    int sizeMinor;
    int sizeMajor;

    int loopMinor;
    int loopMajor;
    int loopX;
    int launchMinor;
    int launchMajor;
};

//------------------------------------------------------------------------
// CUDA kernel specialization.

struct upfirdn2d_kernel_spec
{
    void* kernel;
    int tileOutW;
    int tileOutH;
    int loopMinor;
    int loopX;
};

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);

//------------------------------------------------------------------------
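For orientation, the sizes these fields describe follow simple integer arithmetic. A hedged sketch of the output-size relation implied by the upsample/pad/convolve/decimate sequence (the closed form below is inferred from that sequence, not quoted from the host-side C++ code, which is not part of this diff):

def upfirdn2d_out_size(in_size, up, down, pad0, pad1, filter_taps):
    # Zero-insertion upsample, pad, 'valid' convolve, then decimate.
    upsampled = in_size * up + pad0 + pad1
    convolved = upsampled - filter_taps + 1
    return (convolved - 1) // down + 1

# 64-pixel axis, 2x upsampling, 4-tap filter, pad (2, 1): exactly 128 out.
print(upfirdn2d_out_size(64, up=2, down=1, pad0=2, pad1=1, filter_taps=4))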
torch_utils/ops/upfirdn2d.py
ADDED
@@ -0,0 +1,384 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom PyTorch ops for efficient resampling of 2D images."""

import os
import warnings
import numpy as np
import torch
import traceback

from .. import custom_ops
from .. import misc
from . import conv2d_gradfix

#----------------------------------------------------------------------------

_inited = False
_plugin = None

def _init():
    global _inited, _plugin
    if not _inited:
        sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
        sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
        try:
            _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
        except:
            warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
    return _plugin is not None

def _parse_scaling(scaling):
    if isinstance(scaling, int):
        scaling = [scaling, scaling]
    assert isinstance(scaling, (list, tuple))
    assert all(isinstance(x, int) for x in scaling)
    sx, sy = scaling
    assert sx >= 1 and sy >= 1
    return sx, sy

def _parse_padding(padding):
    if isinstance(padding, int):
        padding = [padding, padding]
    assert isinstance(padding, (list, tuple))
    assert all(isinstance(x, int) for x in padding)
    if len(padding) == 2:
        padx, pady = padding
        padding = [padx, padx, pady, pady]
    padx0, padx1, pady0, pady1 = padding
    return padx0, padx1, pady0, pady1

def _get_filter_size(f):
    if f is None:
        return 1, 1
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    fw = f.shape[-1]
    fh = f.shape[0]
    with misc.suppress_tracer_warnings():
        fw = int(fw)
        fh = int(fh)
    misc.assert_shape(f, [fh, fw][:f.ndim])
    assert fw >= 1 and fh >= 1
    return fw, fh

#----------------------------------------------------------------------------

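The padding parser above accepts three spellings and normalizes them all to `(x0, x1, y0, y1)`; a quick illustration (assumes the repo root is on `sys.path` so the module imports):

from torch_utils.ops.upfirdn2d import _parse_padding

assert _parse_padding(3) == (3, 3, 3, 3)             # one int: all four sides
assert _parse_padding([2, 5]) == (2, 2, 5, 5)        # [x, y]: duplicated per axis
assert _parse_padding([1, 2, 3, 4]) == (1, 2, 3, 4)  # fully explicit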
def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
    r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.

    Args:
        f:           Torch tensor, numpy array, or python list of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable),
                     `[]` (impulse), or
                     `None` (identity).
        device:      Result device (default: cpu).
        normalize:   Normalize the filter so that it retains the magnitude
                     for constant input signal (DC)? (default: True).
        flip_filter: Flip the filter? (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        separable:   Return a separable filter? (default: select automatically).

    Returns:
        Float32 tensor of the shape
        `[filter_height, filter_width]` (non-separable) or
        `[filter_taps]` (separable).
    """
    # Validate.
    if f is None:
        f = 1
    f = torch.as_tensor(f, dtype=torch.float32)
    assert f.ndim in [0, 1, 2]
    assert f.numel() > 0
    if f.ndim == 0:
        f = f[np.newaxis]

    # Separable?
    if separable is None:
        separable = (f.ndim == 1 and f.numel() >= 8)
    if f.ndim == 1 and not separable:
        f = f.ger(f)
    assert f.ndim == (1 if separable else 2)

    # Apply normalize, flip, gain, and device.
    if normalize:
        f /= f.sum()
    if flip_filter:
        f = f.flip(list(range(f.ndim)))
    f = f * (gain ** (f.ndim / 2))
    f = f.to(device=device)
    return f

#----------------------------------------------------------------------------

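A short usage sketch for `setup_filter()`; the `[1, 3, 3, 1]` taps are just an example binomial filter, not a mandated default:

from torch_utils.ops import upfirdn2d

# Fewer than 8 taps, so the 1D input is expanded to a non-separable 4x4
# outer product and normalized to sum to 1.
f = upfirdn2d.setup_filter([1, 3, 3, 1])
print(f.shape, float(f.sum()))   # torch.Size([4, 4]) 1.0

# Forcing separable=True keeps the normalized 1D taps instead.
f1d = upfirdn2d.setup_filter([1, 3, 3, 1], separable=True)
print(f1d.shape)                 # torch.Size([4])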
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Pad, upsample, filter, and downsample a batch of 2D images.

    Performs the following sequence of operations for each channel:

    1. Upsample the image by inserting N-1 zeros after each pixel (`up`).

    2. Pad the image with the specified number of zeros on each side (`padding`).
       Negative padding corresponds to cropping the image.

    3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
       so that the footprint of all output pixels lies within the input image.

    4. Downsample the image by keeping every Nth pixel (`down`).

    This sequence of operations bears close resemblance to scipy.signal.upfirdn().
    The fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports gradients of arbitrary order.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        up:          Integer upsampling factor. Can be a single int or a list/tuple
                     `[x, y]` (default: 1).
        down:        Integer downsampling factor. Can be a single int or a list/tuple
                     `[x, y]` (default: 1).
        padding:     Padding with respect to the upsampled image. Can be a single number
                     or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
    return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)

#----------------------------------------------------------------------------

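For example, 2x zero-insertion upsampling with a 4-tap filter (a sketch on the reference path so it runs without the CUDA plugin; padding and gain are chosen by hand here, while `upsample2d()` below computes them automatically):

import torch
from torch_utils.ops import upfirdn2d

x = torch.randn(1, 3, 32, 32)
f = upfirdn2d.setup_filter([1, 3, 3, 1])

# gain=4 compensates for the 2x2 zero insertion; padding [2, 1, 2, 1]
# sizes the output to exactly 64x64.
y = upfirdn2d.upfirdn2d(x, f, up=2, padding=[2, 1, 2, 1], gain=4, impl='ref')
print(y.shape)  # torch.Size([1, 3, 64, 64])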
@misc.profiled_function
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
    """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    if f is None:
        f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    assert f.dtype == torch.float32 and not f.requires_grad
    batch_size, num_channels, in_height, in_width = x.shape
    upx, upy = _parse_scaling(up)
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)

    # Upsample by inserting zeros.
    x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
    x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
    x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])

    # Pad or crop.
    x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
    x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]

    # Setup filter.
    f = f * (gain ** (f.ndim / 2))
    f = f.to(x.dtype)
    if not flip_filter:
        f = f.flip(list(range(f.ndim)))

    # Convolve with the filter.
    f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
    if f.ndim == 4:
        x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
    else:
        x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
        x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)

    # Downsample by throwing away pixels.
    x = x[:, :, ::downy, ::downx]
    return x

#----------------------------------------------------------------------------

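The zero-insertion idiom in step 1 of the reference implementation is worth seeing on a concrete tensor (self-contained PyTorch, mirroring the reshape-and-pad trick above):

import torch

x = torch.arange(1, 5, dtype=torch.float32).reshape(1, 1, 2, 2)
up = 2
# Give each pixel trailing singleton dims, zero-pad those dims, flatten back.
x = x.reshape(1, 1, 2, 1, 2, 1)
x = torch.nn.functional.pad(x, [0, up - 1, 0, 0, 0, up - 1])
x = x.reshape(1, 1, 2 * up, 2 * up)
print(x[0, 0])
# tensor([[1., 0., 2., 0.],
#         [0., 0., 0., 0.],
#         [3., 0., 4., 0.],
#         [0., 0., 0., 0.]])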
_upfirdn2d_cuda_cache = dict()

def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
    """Fast CUDA implementation of `upfirdn2d()` using custom ops.
    """
    # Parse arguments.
    upx, upy = _parse_scaling(up)
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)

    # Lookup from cache.
    key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
    if key in _upfirdn2d_cuda_cache:
        return _upfirdn2d_cuda_cache[key]

    # Forward op.
    class Upfirdn2dCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, f): # pylint: disable=arguments-differ
            assert isinstance(x, torch.Tensor) and x.ndim == 4
            if f is None:
                f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
            assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
            y = x
            if f.ndim == 2:
                y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
            else:
                y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
                y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
            ctx.save_for_backward(f)
            ctx.x_shape = x.shape
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            f, = ctx.saved_tensors
            _, _, ih, iw = ctx.x_shape
            _, _, oh, ow = dy.shape
            fw, fh = _get_filter_size(f)
            p = [
                fw - padx0 - 1,
                iw * upx - ow * downx + padx0 - upx + 1,
                fh - pady0 - 1,
                ih * upy - oh * downy + pady0 - upy + 1,
            ]
            dx = None
            df = None

            if ctx.needs_input_grad[0]:
                dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)

            assert not ctx.needs_input_grad[1]
            return dx, df

    # Add to cache.
    _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
    return Upfirdn2dCuda

#----------------------------------------------------------------------------

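Note how `backward()` above is just another upfirdn2d with `up` and `down` swapped and the filter flip toggled; the padding `p` is chosen so the gradient lands back on the input size. A quick integer check of that identity (pure Python, mirroring the formulas in `backward()`):

def out_size(n, up, down, pad0, pad1, fw):
    # Output width along one axis: upsample, pad, 'valid' convolve, decimate.
    return (n * up + pad0 + pad1 - fw) // down + 1

# Forward: 32 px, 2x up, 4-tap filter, padding (2, 1) -> 64 px.
iw, upx, downx, padx0, fw = 32, 2, 1, 2, 4
ow = out_size(iw, upx, downx, padx0, 1, fw)

# Backward padding exactly as computed in Upfirdn2dCuda.backward().
bp0 = fw - padx0 - 1
bp1 = iw * upx - ow * downx + padx0 - upx + 1
assert out_size(ow, downx, upx, bp0, bp1, fw) == iw  # gradient matches input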
def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Filter a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape matches the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        padding:     Padding with respect to the output. Can be a single number or a
                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    padx0, padx1, pady0, pady1 = _parse_padding(padding)
    fw, fh = _get_filter_size(f)
    p = [
        padx0 + fw // 2,
        padx1 + (fw - 1) // 2,
        pady0 + fh // 2,
        pady1 + (fh - 1) // 2,
    ]
    return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)

#----------------------------------------------------------------------------

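The `fw // 2` / `(fw - 1) // 2` split above is what keeps the output the same size as the input for both odd and even tap counts; a one-line sanity check (plain arithmetic, independent of the repo):

# 'Same' padding split used by filter2d: total pad is fw - 1, so
# W + fw//2 + (fw-1)//2 - fw + 1 == W for any tap count fw >= 1.
for fw in range(1, 9):
    assert fw // 2 + (fw - 1) // 2 == fw - 1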
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Upsample a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape is a multiple of the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        up:          Integer upsampling factor. Can be a single int or a list/tuple
                     `[x, y]` (default: 2).
        padding:     Padding with respect to the output. Can be a single number or a
                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    upx, upy = _parse_scaling(up)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)
    fw, fh = _get_filter_size(f)
    p = [
        padx0 + (fw + upx - 1) // 2,
        padx1 + (fw - upx) // 2,
        pady0 + (fh + upy - 1) // 2,
        pady1 + (fh - upy) // 2,
    ]
    return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)

#----------------------------------------------------------------------------

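Note the `gain*upx*upy` in the call above: zero insertion leaves only one in every `upx*upy` samples nonzero, so a DC-normalized filter would otherwise dim the image by that factor. A tiny one-axis illustration (plain Python):

# A DC-normalized 2-tap filter sliding over a 2x zero-stuffed constant
# signal: each output mixes one real sample with one zero.
signal = [1.0, 0.0, 1.0, 0.0]            # constant 1.0 after 2x zero insertion
taps = [0.5, 0.5]                        # sums to 1 (DC-preserving)
out = [taps[0] * a + taps[1] * b for a, b in zip(signal, signal[1:])]
print(out)                               # [0.5, 0.5, 0.5] -> needs gain = 2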
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Downsample a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape is a fraction of the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        down:        Integer downsampling factor. Can be a single int or a list/tuple
                     `[x, y]` (default: 2).
        padding:     Padding with respect to the input. Can be a single number or a
                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)
    fw, fh = _get_filter_size(f)
    p = [
        padx0 + (fw - downx + 1) // 2,
        padx1 + (fw - downx) // 2,
        pady0 + (fh - downy + 1) // 2,
        pady1 + (fh - downy) // 2,
    ]
    return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)

#----------------------------------------------------------------------------
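A combined sketch of the two wrappers on the reference path (shapes are hypothetical examples):

import torch
from torch_utils.ops import upfirdn2d

x = torch.randn(2, 3, 64, 64)
f = upfirdn2d.setup_filter([1, 3, 3, 1])

hi = upfirdn2d.upsample2d(x, f, up=2, impl='ref')       # (2, 3, 128, 128)
lo = upfirdn2d.downsample2d(hi, f, down=2, impl='ref')  # (2, 3, 64, 64)
print(hi.shape, lo.shape)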
torch_utils/persistence.py
ADDED
@@ -0,0 +1,251 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Facilities for pickling Python code alongside other data.

The pickled code is automatically imported into a separate Python module
during unpickling. This way, any previously exported pickles will remain
usable even if the original code is no longer available, or if the current
version of the code is not consistent with what was originally pickled."""

import sys
import pickle
import io
import inspect
import copy
import uuid
import types
import dnnlib

#----------------------------------------------------------------------------

_version            = 6         # internal version number
_decorators         = set()     # {decorator_class, ...}
_import_hooks       = []        # [hook_function, ...]
_module_to_src_dict = dict()    # {module: src, ...}
_src_to_module_dict = dict()    # {src: module, ...}

#----------------------------------------------------------------------------

def persistent_class(orig_class):
    r"""Class decorator that extends a given class to save its source code
    when pickled.

    Example:

        from torch_utils import persistence

        @persistence.persistent_class
        class MyNetwork(torch.nn.Module):
            def __init__(self, num_inputs, num_outputs):
                super().__init__()
                self.fc = MyLayer(num_inputs, num_outputs)
                ...

        @persistence.persistent_class
        class MyLayer(torch.nn.Module):
            ...

    When pickled, any instance of `MyNetwork` and `MyLayer` will save its
    source code alongside other internal state (e.g., parameters, buffers,
    and submodules). This way, any previously exported pickle will remain
    usable even if the class definitions have been modified or are no
    longer available.

    The decorator saves the source code of the entire Python module
    containing the decorated class. It does *not* save the source code of
    any imported modules. Thus, the imported modules must be available
    during unpickling, also including `torch_utils.persistence` itself.

    It is ok to call functions defined in the same module from the
    decorated class. However, if the decorated class depends on other
    classes defined in the same module, they must be decorated as well.
    This is illustrated in the above example in the case of `MyLayer`.

    It is also possible to employ the decorator just-in-time before
    calling the constructor. For example:

        cls = MyLayer
        if want_to_make_it_persistent:
            cls = persistence.persistent_class(cls)
        layer = cls(num_inputs, num_outputs)

    As an additional feature, the decorator also keeps track of the
    arguments that were used to construct each instance of the decorated
    class. The arguments can be queried via `obj.init_args` and
    `obj.init_kwargs`, and they are automatically pickled alongside other
    object state. A typical use case is to first unpickle a previous
    instance of a persistent class, and then upgrade it to use the latest
    version of the source code:

        with open('old_pickle.pkl', 'rb') as f:
            old_net = pickle.load(f)
        new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
        misc.copy_params_and_buffers(old_net, new_net, require_all=True)
    """
    assert isinstance(orig_class, type)
    if is_persistent(orig_class):
        return orig_class

    assert orig_class.__module__ in sys.modules
    orig_module = sys.modules[orig_class.__module__]
    orig_module_src = _module_to_src(orig_module)

    class Decorator(orig_class):
        _orig_module_src = orig_module_src
        _orig_class_name = orig_class.__name__

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self._init_args = copy.deepcopy(args)
            self._init_kwargs = copy.deepcopy(kwargs)
            assert orig_class.__name__ in orig_module.__dict__
            _check_pickleable(self.__reduce__())

        @property
        def init_args(self):
            return copy.deepcopy(self._init_args)

        @property
        def init_kwargs(self):
            return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))

        def __reduce__(self):
            fields = list(super().__reduce__())
            fields += [None] * max(3 - len(fields), 0)
            if fields[0] is not _reconstruct_persistent_obj:
                meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
                fields[0] = _reconstruct_persistent_obj # reconstruct func
                fields[1] = (meta,) # reconstruct args
                fields[2] = None # state dict
            return tuple(fields)

    Decorator.__name__ = orig_class.__name__
    _decorators.add(Decorator)
    return Decorator

#----------------------------------------------------------------------------

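A minimal round trip showing what the decorator buys (a sketch; it assumes the class is defined in a real source file so that `inspect.getsource()` can capture its module, and `Blur` is a hypothetical example class):

import pickle
import torch
from torch_utils import persistence

@persistence.persistent_class
class Blur(torch.nn.Module):     # hypothetical example class
    def __init__(self, taps=4):
        super().__init__()
        self.taps = taps

net = Blur(taps=6)
data = pickle.dumps(net)         # module source travels inside the pickle
clone = pickle.loads(data)       # rebuilt via _reconstruct_persistent_obj
print(clone.taps, clone.init_kwargs)  # 6 {'taps': 6}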
def is_persistent(obj):
    r"""Test whether the given object or class is persistent, i.e.,
    whether it will save its source code when pickled.
    """
    try:
        if obj in _decorators:
            return True
    except TypeError:
        pass
    return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck

#----------------------------------------------------------------------------

def import_hook(hook):
    r"""Register an import hook that is called whenever a persistent object
    is being unpickled. A typical use case is to patch the pickled source
    code to avoid errors and inconsistencies when the API of some imported
    module has changed.

    The hook should have the following signature:

        hook(meta) -> modified meta

    `meta` is an instance of `dnnlib.EasyDict` with the following fields:

        type:       Type of the persistent object, e.g. `'class'`.
        version:    Internal version number of `torch_utils.persistence`.
        module_src: Original source code of the Python module.
        class_name: Class name in the original Python module.
        state:      Internal state of the object.

    Example:

        @persistence.import_hook
        def wreck_my_network(meta):
            if meta.class_name == 'MyNetwork':
                print('MyNetwork is being imported. I will wreck it!')
                meta.module_src = meta.module_src.replace("True", "False")
            return meta
    """
    assert callable(hook)
    _import_hooks.append(hook)

#----------------------------------------------------------------------------

def _reconstruct_persistent_obj(meta):
    r"""Hook that is called internally by the `pickle` module to unpickle
    a persistent object.
    """
    meta = dnnlib.EasyDict(meta)
    meta.state = dnnlib.EasyDict(meta.state)
    for hook in _import_hooks:
        meta = hook(meta)
        assert meta is not None

    assert meta.version == _version
    module = _src_to_module(meta.module_src)

    assert meta.type == 'class'
    orig_class = module.__dict__[meta.class_name]
    decorator_class = persistent_class(orig_class)
    obj = decorator_class.__new__(decorator_class)

    setstate = getattr(obj, '__setstate__', None)
    if callable(setstate):
        setstate(meta.state) # pylint: disable=not-callable
    else:
        obj.__dict__.update(meta.state)
    return obj

#----------------------------------------------------------------------------

def _module_to_src(module):
    r"""Query the source code of a given Python module.
    """
    src = _module_to_src_dict.get(module, None)
    if src is None:
        src = inspect.getsource(module)
        _module_to_src_dict[module] = src
        _src_to_module_dict[src] = module
    return src

def _src_to_module(src):
    r"""Get or create a Python module for the given source code.
    """
    module = _src_to_module_dict.get(src, None)
    if module is None:
        module_name = "_imported_module_" + uuid.uuid4().hex
        module = types.ModuleType(module_name)
        sys.modules[module_name] = module
        _module_to_src_dict[module] = src
        _src_to_module_dict[src] = module
        exec(src, module.__dict__) # pylint: disable=exec-used
    return module

#----------------------------------------------------------------------------

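The module-from-source trick in `_src_to_module()` is plain standard-library Python; stripped to its core it looks like this (self-contained sketch):

import sys
import types
import uuid

src = "def greet():\n    return 'hello from a pickled module'\n"

# Same mechanism as _src_to_module: synthesize a uniquely named module
# and execute the stored source inside its namespace.
module_name = '_imported_module_' + uuid.uuid4().hex
module = types.ModuleType(module_name)
sys.modules[module_name] = module
exec(src, module.__dict__)

print(module.greet())  # hello from a pickled module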
def _check_pickleable(obj):
    r"""Check that the given object is pickleable, raising an exception if
    it is not. This function is expected to be considerably more efficient
    than actually pickling the object.
    """
    def recurse(obj):
        if isinstance(obj, (list, tuple, set)):
            return [recurse(x) for x in obj]
        if isinstance(obj, dict):
            return [[recurse(x), recurse(y)] for x, y in obj.items()]
        if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
            return None # Python primitive types are pickleable.
        if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
            return None # NumPy arrays and PyTorch tensors are pickleable.
        if is_persistent(obj):
            return None # Persistent objects are pickleable, by virtue of the constructor check.
        return obj
    with io.BytesIO() as f:
        pickle.dump(recurse(obj), f)

#----------------------------------------------------------------------------
torch_utils/training_stats.py
ADDED
@@ -0,0 +1,268 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Facilities for reporting and collecting training statistics across
multiple processes and devices. The interface is designed to minimize
synchronization overhead as well as the amount of boilerplate in user
code."""

import re
import numpy as np
import torch
import dnnlib

from . import misc

#----------------------------------------------------------------------------

_num_moments    = 3             # [num_scalars, sum_of_scalars, sum_of_squares]
_reduce_dtype   = torch.float32 # Data type to use for initial per-tensor reduction.
_counter_dtype  = torch.float64 # Data type to use for the internal counters.
_rank           = 0             # Rank of the current process.
_sync_device    = None          # Device to use for multiprocess communication. None = single-process.
_sync_called    = False         # Has _sync() been called yet?
_counters       = dict()        # Running counters on each device, updated by report(): name => device => torch.Tensor
_cumulative     = dict()        # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor

#----------------------------------------------------------------------------

def init_multiprocessing(rank, sync_device):
    r"""Initializes `torch_utils.training_stats` for collecting statistics
    across multiple processes.

    This function must be called after
    `torch.distributed.init_process_group()` and before `Collector.update()`.
    The call is not necessary if multi-process collection is not needed.

    Args:
        rank:           Rank of the current process.
        sync_device:    PyTorch device to use for inter-process
                        communication, or None to disable multi-process
                        collection. Typically `torch.device('cuda', rank)`.
    """
    global _rank, _sync_device
    assert not _sync_called
    _rank = rank
    _sync_device = sync_device

#----------------------------------------------------------------------------

@misc.profiled_function
def report(name, value):
    r"""Broadcasts the given set of scalars to all interested instances of
    `Collector`, across device and process boundaries.

    This function is expected to be extremely cheap and can be safely
    called from anywhere in the training loop, loss function, or inside a
    `torch.nn.Module`.

    Warning: The current implementation expects the set of unique names to
    be consistent across processes. Please make sure that `report()` is
    called at least once for each unique name by each process, and in the
    same order. If a given process has no scalars to broadcast, it can do
    `report(name, [])` (empty list).

    Args:
        name:   Arbitrary string specifying the name of the statistic.
                Averages are accumulated separately for each unique name.
        value:  Arbitrary set of scalars. Can be a list, tuple,
                NumPy array, PyTorch tensor, or Python scalar.

    Returns:
        The same `value` that was passed in.
    """
    if name not in _counters:
        _counters[name] = dict()

    elems = torch.as_tensor(value)
    if elems.numel() == 0:
        return value

    elems = elems.detach().flatten().to(_reduce_dtype)
    moments = torch.stack([
        torch.ones_like(elems).sum(),
        elems.sum(),
        elems.square().sum(),
    ])
    assert moments.ndim == 1 and moments.shape[0] == _num_moments
    moments = moments.to(_counter_dtype)

    device = moments.device
    if device not in _counters[name]:
        _counters[name][device] = torch.zeros_like(moments)
    _counters[name][device].add_(moments)
    return value

#----------------------------------------------------------------------------

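The three accumulated moments are enough to recover count, mean, and standard deviation later; the arithmetic that `Collector.mean()` and `Collector.std()` apply, as a self-contained sketch:

import numpy as np

values = np.array([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0])

# The same [num, sum, sum of squares] triple that report() accumulates.
num, s1, s2 = len(values), values.sum(), np.square(values).sum()

mean = s1 / num
std = np.sqrt(max(s2 / num - mean ** 2, 0))  # E[x^2] - E[x]^2
print(num, mean, std)  # 8 5.0 2.0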
def report0(name, value):
|
104 |
+
r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
|
105 |
+
but ignores any scalars provided by the other processes.
|
106 |
+
See `report()` for further details.
|
107 |
+
"""
|
108 |
+
report(name, value if _rank == 0 else [])
|
109 |
+
return value
|
110 |
+
|
111 |
+
#----------------------------------------------------------------------------
|
112 |
+
|
113 |
+
class Collector:
|
114 |
+
r"""Collects the scalars broadcasted by `report()` and `report0()` and
|
115 |
+
computes their long-term averages (mean and standard deviation) over
|
116 |
+
user-defined periods of time.
|
117 |
+
|
118 |
+
The averages are first collected into internal counters that are not
|
119 |
+
    directly visible to the user. They are then copied to the user-visible
    state as a result of calling `update()` and can then be queried using
    `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
    internal counters for the next round, so that the user-visible state
    effectively reflects averages collected between the last two calls to
    `update()`.

    Args:
        regex:          Regular expression defining which statistics to
                        collect. The default is to collect everything.
        keep_previous:  Whether to retain the previous averages if no
                        scalars were collected on a given round
                        (default: True).
    """
    def __init__(self, regex='.*', keep_previous=True):
        self._regex = re.compile(regex)
        self._keep_previous = keep_previous
        self._cumulative = dict()
        self._moments = dict()
        self.update()
        self._moments.clear()

    def names(self):
        r"""Returns the names of all statistics broadcasted so far that
        match the regular expression specified at construction time.
        """
        return [name for name in _counters if self._regex.fullmatch(name)]

    def update(self):
        r"""Copies current values of the internal counters to the
        user-visible state and resets them for the next round.

        If `keep_previous=True` was specified at construction time, the
        operation is skipped for statistics that have received no scalars
        since the last update, retaining their previous averages.

        This method performs a number of GPU-to-CPU transfers and one
        `torch.distributed.all_reduce()`. It is intended to be called
        periodically in the main training loop, typically once every
        N training steps.
        """
        if not self._keep_previous:
            self._moments.clear()
        for name, cumulative in _sync(self.names()):
            if name not in self._cumulative:
                self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
            delta = cumulative - self._cumulative[name]
            self._cumulative[name].copy_(cumulative)
            if float(delta[0]) != 0:
                self._moments[name] = delta

    def _get_delta(self, name):
        r"""Returns the raw moments that were accumulated for the given
        statistic between the last two calls to `update()`, or zero if
        no scalars were collected.
        """
        assert self._regex.fullmatch(name)
        if name not in self._moments:
            self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        return self._moments[name]

    def num(self, name):
        r"""Returns the number of scalars that were accumulated for the given
        statistic between the last two calls to `update()`, or zero if
        no scalars were collected.
        """
        delta = self._get_delta(name)
        return int(delta[0])

    def mean(self, name):
        r"""Returns the mean of the scalars that were accumulated for the
        given statistic between the last two calls to `update()`, or NaN if
        no scalars were collected.
        """
        delta = self._get_delta(name)
        if int(delta[0]) == 0:
            return float('nan')
        return float(delta[1] / delta[0])

    def std(self, name):
        r"""Returns the standard deviation of the scalars that were
        accumulated for the given statistic between the last two calls to
        `update()`, or NaN if no scalars were collected.
        """
        delta = self._get_delta(name)
        if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
            return float('nan')
        if int(delta[0]) == 1:
            return float(0)
        mean = float(delta[1] / delta[0])
        raw_var = float(delta[2] / delta[0])
        return np.sqrt(max(raw_var - np.square(mean), 0))

    def as_dict(self):
        r"""Returns the averages accumulated between the last two calls to
        `update()` as a `dnnlib.EasyDict`. The contents are as follows:

            dnnlib.EasyDict(
                NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
                ...
            )
        """
        stats = dnnlib.EasyDict()
        for name in self.names():
            stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
        return stats

    def __getitem__(self, name):
        r"""Convenience getter.
        `collector[name]` is a synonym for `collector.mean(name)`.
        """
        return self.mean(name)

#----------------------------------------------------------------------------

def _sync(names):
    r"""Synchronize the global cumulative counters across devices and
    processes. Called internally by `Collector.update()`.
    """
    if len(names) == 0:
        return []
    global _sync_called
    _sync_called = True

    # Collect deltas within current rank.
    deltas = []
    device = _sync_device if _sync_device is not None else torch.device('cpu')
    for name in names:
        delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
        for counter in _counters[name].values():
            delta.add_(counter.to(device))
            counter.copy_(torch.zeros_like(counter))
        deltas.append(delta)
    deltas = torch.stack(deltas)

    # Sum deltas across ranks.
    if _sync_device is not None:
        torch.distributed.all_reduce(deltas)

    # Update cumulative values.
    deltas = deltas.cpu()
    for idx, name in enumerate(names):
        if name not in _cumulative:
            _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        _cumulative[name].add_(deltas[idx])

    # Return name-value pairs.
    return [(name, _cumulative[name]) for name in names]

#----------------------------------------------------------------------------
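
For orientation, a minimal usage sketch of the `Collector` class above, as one might drive it from a training loop. It assumes the `report()` helper defined earlier in this module; `compute_loss()` is a hypothetical stand-in for one training step.

    # Usage sketch (illustrative only; compute_loss() is hypothetical).
    from torch_utils import training_stats

    collector = training_stats.Collector(regex='Loss/.*')   # track loss statistics only
    for step in range(1000):
        loss = compute_loss()                               # one training step (hypothetical)
        training_stats.report('Loss/G/total', loss)         # accumulate into internal counters
        if step % 100 == 0:
            collector.update()                              # sync across ranks, reset counters
            stats = collector.as_dict()                     # EasyDict of num/mean/std per statistic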
train.py
ADDED
@@ -0,0 +1,540 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Train a GAN using the techniques described in the paper
"Training Generative Adversarial Networks with Limited Data"."""

import os
import click
import re
import json
import tempfile
import torch
import dnnlib

from training import training_loop
from metrics import metric_main
from torch_utils import training_stats
from torch_utils import custom_ops

#----------------------------------------------------------------------------

class UserError(Exception):
    pass

#----------------------------------------------------------------------------

def setup_training_loop_kwargs(
    # General options (not included in desc).
    gpus       = None, # Number of GPUs: <int>, default = 1 gpu
    snap       = None, # Snapshot interval: <int>, default = 50 ticks
    metrics    = None, # List of metric names: [], ['fid50k_full'] (default), ...
    seed       = None, # Random seed: <int>, default = 0

    # Dataset.
    data       = None, # Training dataset (required): <path>
    cond       = None, # Train conditional model based on dataset labels: <bool>, default = False
    subset     = None, # Train with only N images: <int>, default = all
    mirror     = None, # Augment dataset with x-flips: <bool>, default = False

    # Base config.
    cfg        = None, # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'
    gamma      = None, # Override R1 gamma: <float>
    kimg       = None, # Override training duration: <int>
    batch      = None, # Override batch size: <int>

    # Discriminator augmentation.
    aug        = None, # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
    p          = None, # Specify p for 'fixed' (required): <float>
    target     = None, # Override ADA target for 'ada': <float>, default = depends on aug
    augpipe    = None, # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc'

    # Transfer learning.
    resume     = None, # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
    freezed    = None, # Freeze-D: <int>, default = 0 discriminator layers

    # Performance options (not included in desc).
    fp32       = None, # Disable mixed-precision training: <bool>, default = False
    nhwc       = None, # Use NHWC memory format with FP16: <bool>, default = False
    allow_tf32 = None, # Allow PyTorch to use TF32 for matmul and convolutions: <bool>, default = False
    nobench    = None, # Disable cuDNN benchmarking: <bool>, default = False
    workers    = None, # Override number of DataLoader workers: <int>, default = 3
):
    args = dnnlib.EasyDict()

    # ------------------------------------------
    # General options: gpus, snap, metrics, seed
    # ------------------------------------------

    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError('--gpus must be a power of two')
    args.num_gpus = gpus

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError('--snap must be at least 1')
    args.image_snapshot_ticks = snap
    args.network_snapshot_ticks = snap

    if metrics is None:
        metrics = ['fid50k_full']
    assert isinstance(metrics, list)
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
    args.metrics = metrics

    if seed is None:
        seed = 0
    assert isinstance(seed, int)
    args.random_seed = seed

    # -----------------------------------
    # Dataset: data, cond, subset, mirror
    # -----------------------------------

    assert data is not None
    assert isinstance(data, str)
    args.training_set_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data, use_labels=True, max_size=None, xflip=False)
    args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=3, prefetch_factor=2)
    try:
        training_set = dnnlib.util.construct_class_by_name(**args.training_set_kwargs) # subclass of training.dataset.Dataset
        args.training_set_kwargs.resolution = training_set.resolution # be explicit about resolution
        args.training_set_kwargs.use_labels = training_set.has_labels # be explicit about labels
        args.training_set_kwargs.max_size = len(training_set) # be explicit about dataset size
        desc = training_set.name
        del training_set # conserve memory
    except IOError as err:
        raise UserError(f'--data: {err}')

    if cond is None:
        cond = False
    assert isinstance(cond, bool)
    if cond:
        if not args.training_set_kwargs.use_labels:
            raise UserError('--cond=True requires labels specified in dataset.json')
        desc += '-cond'
    else:
        args.training_set_kwargs.use_labels = False

    if subset is not None:
        assert isinstance(subset, int)
        if not 1 <= subset <= args.training_set_kwargs.max_size:
            raise UserError(f'--subset must be between 1 and {args.training_set_kwargs.max_size}')
        desc += f'-subset{subset}'
        if subset < args.training_set_kwargs.max_size:
            args.training_set_kwargs.max_size = subset
            args.training_set_kwargs.random_seed = args.random_seed

    if mirror is None:
        mirror = False
    assert isinstance(mirror, bool)
    if mirror:
        desc += '-mirror'
        args.training_set_kwargs.xflip = True

    # ------------------------------------
    # Base config: cfg, gamma, kimg, batch
    # ------------------------------------

    if cfg is None:
        cfg = 'auto'
    assert isinstance(cfg, str)
    desc += f'-{cfg}'

    cfg_specs = {
        'auto':      dict(ref_gpus=-1, kimg=25000,  mb=-1, mbstd=-1, fmaps=-1,  lrate=-1,     gamma=-1,   ema=-1,  ramp=0.05, map=2), # Populated dynamically based on resolution and GPU count.
        'stylegan2': dict(ref_gpus=8,  kimg=25000,  mb=32, mbstd=4,  fmaps=1,   lrate=0.002,  gamma=10,   ema=10,  ramp=None, map=8), # Uses mixed-precision, unlike the original StyleGAN2.
        'paper256':  dict(ref_gpus=8,  kimg=25000,  mb=64, mbstd=8,  fmaps=0.5, lrate=0.0025, gamma=1,    ema=20,  ramp=None, map=8),
        'paper512':  dict(ref_gpus=8,  kimg=25000,  mb=64, mbstd=8,  fmaps=1,   lrate=0.0025, gamma=0.5,  ema=20,  ramp=None, map=8),
        'paper1024': dict(ref_gpus=8,  kimg=25000,  mb=32, mbstd=4,  fmaps=1,   lrate=0.002,  gamma=2,    ema=10,  ramp=None, map=8),
        'cifar':     dict(ref_gpus=2,  kimg=100000, mb=64, mbstd=32, fmaps=1,   lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2),
    }

    assert cfg in cfg_specs
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == 'auto':
        desc += f'{gpus:d}'
        spec.ref_gpus = gpus
        res = args.training_set_kwargs.resolution
        spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus) # keep gpu memory consumption at bay
        spec.mbstd = min(spec.mb // gpus, 4) # other hyperparams behave more predictably if mbstd group size remains fixed
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.lrate = 0.002 if res >= 1024 else 0.0025
        spec.gamma = 0.0002 * (res ** 2) / spec.mb # heuristic formula
        spec.ema = spec.mb * 10 / 32

    args.G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator', z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict(), synthesis_kwargs=dnnlib.EasyDict())
    args.D_kwargs = dnnlib.EasyDict(class_name='training.networks.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict())
    args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(spec.fmaps * 32768)
    args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
    args.G_kwargs.mapping_kwargs.num_layers = spec.map
    args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4 # enable mixed-precision training
    args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256 # clamp activations to avoid float16 overflow
    args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd

    args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
    args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
    args.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma)

    args.total_kimg = spec.kimg
    args.batch_size = spec.mb
    args.batch_gpu = spec.mb // spec.ref_gpus
    args.ema_kimg = spec.ema
    args.ema_rampup = spec.ramp

    if cfg == 'cifar':
        args.loss_kwargs.pl_weight = 0 # disable path length regularization
        args.loss_kwargs.style_mixing_prob = 0 # disable style mixing
        args.D_kwargs.architecture = 'orig' # disable residual skip connections

    if gamma is not None:
        assert isinstance(gamma, float)
        if not gamma >= 0:
            raise UserError('--gamma must be non-negative')
        desc += f'-gamma{gamma:g}'
        args.loss_kwargs.r1_gamma = gamma

    if kimg is not None:
        assert isinstance(kimg, int)
        if not kimg >= 1:
            raise UserError('--kimg must be at least 1')
        desc += f'-kimg{kimg:d}'
        args.total_kimg = kimg

    if batch is not None:
        assert isinstance(batch, int)
        if not (batch >= 1 and batch % gpus == 0):
            raise UserError('--batch must be at least 1 and divisible by --gpus')
        desc += f'-batch{batch}'
        args.batch_size = batch
        args.batch_gpu = batch // gpus

    # ---------------------------------------------------
    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------

    if aug is None:
        aug = 'ada'
    else:
        assert isinstance(aug, str)
        desc += f'-{aug}'

    if aug == 'ada':
        args.ada_target = 0.6

    elif aug == 'noaug':
        pass

    elif aug == 'fixed':
        if p is None:
            raise UserError(f'--aug={aug} requires specifying --p')

    else:
        raise UserError(f'--aug={aug} not supported')

    if p is not None:
        assert isinstance(p, float)
        if aug != 'fixed':
            raise UserError('--p can only be specified with --aug=fixed')
        if not 0 <= p <= 1:
            raise UserError('--p must be between 0 and 1')
        desc += f'-p{p:g}'
        args.augment_p = p

    if target is not None:
        assert isinstance(target, float)
        if aug != 'ada':
            raise UserError('--target can only be specified with --aug=ada')
        if not 0 <= target <= 1:
            raise UserError('--target must be between 0 and 1')
        desc += f'-target{target:g}'
        args.ada_target = target

    assert augpipe is None or isinstance(augpipe, str)
    if augpipe is None:
        augpipe = 'bgc'
    else:
        if aug == 'noaug':
            raise UserError('--augpipe cannot be specified with --aug=noaug')
        desc += f'-{augpipe}'

    augpipe_specs = {
        'blit':   dict(xflip=1, rotate90=1, xint=1),
        'geom':   dict(scale=1, rotate=1, aniso=1, xfrac=1),
        'color':  dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'filter': dict(imgfilter=1),
        'noise':  dict(noise=1),
        'cutout': dict(cutout=1),
        'bg':     dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
        'bgc':    dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
        'bgcf':   dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
        'bgcfn':  dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
        'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
    }

    assert augpipe in augpipe_specs
    if aug != 'noaug':
        args.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', **augpipe_specs[augpipe])

    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------

    resume_specs = {
        'ffhq256':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
        'ffhq512':     'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
        'ffhq1024':    'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
        'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
        'lsundog256':  'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
    }

    assert resume is None or isinstance(resume, str)
    if resume is None:
        resume = 'noresume'
    elif resume == 'noresume':
        desc += '-noresume'
    elif resume in resume_specs:
        desc += f'-resume{resume}'
        args.resume_pkl = resume_specs[resume] # predefined url
    else:
        desc += '-resumecustom'
        args.resume_pkl = resume # custom path or url

    if resume != 'noresume':
        args.ada_kimg = 100 # make ADA react faster at the beginning
        args.ema_rampup = None # disable EMA rampup

    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError('--freezed must be non-negative')
        desc += f'-freezed{freezed:d}'
        args.D_kwargs.block_kwargs.freeze_layers = freezed

    # -------------------------------------------------
    # Performance options: fp32, nhwc, nobench, workers
    # -------------------------------------------------

    if fp32 is None:
        fp32 = False
    assert isinstance(fp32, bool)
    if fp32:
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None

    if nhwc is None:
        nhwc = False
    assert isinstance(nhwc, bool)
    if nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True

    if nobench is None:
        nobench = False
    assert isinstance(nobench, bool)
    if nobench:
        args.cudnn_benchmark = False

    if allow_tf32 is None:
        allow_tf32 = False
    assert isinstance(allow_tf32, bool)
    if allow_tf32:
        args.allow_tf32 = True

    if workers is not None:
        assert isinstance(workers, int)
        if not workers >= 1:
            raise UserError('--workers must be at least 1')
        args.data_loader_kwargs.num_workers = workers

    return desc, args

#----------------------------------------------------------------------------

def subprocess_fn(rank, args, temp_dir):
    dnnlib.util.Logger(file_name=os.path.join(args.run_dir, 'log.txt'), file_mode='a', should_flush=True)

    # Init torch.distributed.
    if args.num_gpus > 1:
        init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
        if os.name == 'nt':
            init_method = 'file:///' + init_file.replace('\\', '/')
            torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
        else:
            init_method = f'file://{init_file}'
            torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)

    # Init torch_utils.
    sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
    if rank != 0:
        custom_ops.verbosity = 'none'

    # Execute training loop.
    training_loop.training_loop(rank=rank, **args)

#----------------------------------------------------------------------------

class CommaSeparatedList(click.ParamType):
    name = 'list'

    def convert(self, value, param, ctx):
        _ = param, ctx
        if value is None or value.lower() == 'none' or value == '':
            return []
        return value.split(',')

#----------------------------------------------------------------------------

@click.command()
@click.pass_context

# General options.
@click.option('--outdir', help='Where to save the results', required=True, metavar='DIR')
@click.option('--gpus', help='Number of GPUs to use [default: 1]', type=int, metavar='INT')
@click.option('--snap', help='Snapshot interval [default: 50 ticks]', type=int, metavar='INT')
@click.option('--metrics', help='Comma-separated list or "none" [default: fid50k_full]', type=CommaSeparatedList())
@click.option('--seed', help='Random seed [default: 0]', type=int, metavar='INT')
@click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)

# Dataset.
@click.option('--data', help='Training data (directory or zip)', metavar='PATH', required=True)
@click.option('--cond', help='Train conditional model based on dataset labels [default: false]', type=bool, metavar='BOOL')
@click.option('--subset', help='Train with only N images [default: all]', type=int, metavar='INT')
@click.option('--mirror', help='Enable dataset x-flips [default: false]', type=bool, metavar='BOOL')

# Base config.
@click.option('--cfg', help='Base config [default: auto]', type=click.Choice(['auto', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar']))
@click.option('--gamma', help='Override R1 gamma', type=float)
@click.option('--kimg', help='Override training duration', type=int, metavar='INT')
@click.option('--batch', help='Override batch size', type=int, metavar='INT')

# Discriminator augmentation.
@click.option('--aug', help='Augmentation mode [default: ada]', type=click.Choice(['noaug', 'ada', 'fixed']))
@click.option('--p', help='Augmentation probability for --aug=fixed', type=float)
@click.option('--target', help='ADA target value for --aug=ada', type=float)
@click.option('--augpipe', help='Augmentation pipeline [default: bgc]', type=click.Choice(['blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc', 'bgcf', 'bgcfn', 'bgcfnc']))

# Transfer learning.
@click.option('--resume', help='Resume training [default: noresume]', metavar='PKL')
@click.option('--freezed', help='Freeze-D [default: 0 layers]', type=int, metavar='INT')

# Performance options.
@click.option('--fp32', help='Disable mixed-precision training', type=bool, metavar='BOOL')
@click.option('--nhwc', help='Use NHWC memory format with FP16', type=bool, metavar='BOOL')
@click.option('--nobench', help='Disable cuDNN benchmarking', type=bool, metavar='BOOL')
@click.option('--allow-tf32', help='Allow PyTorch to use TF32 internally', type=bool, metavar='BOOL')
@click.option('--workers', help='Override number of DataLoader workers', type=int, metavar='INT')

def main(ctx, outdir, dry_run, **config_kwargs):
    """Train a GAN using the techniques described in the paper
    "Training Generative Adversarial Networks with Limited Data".

    Examples:

    \b
    # Train with custom dataset using 1 GPU.
    python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1

    \b
    # Train class-conditional CIFAR-10 using 2 GPUs.
    python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \\
        --gpus=2 --cfg=cifar --cond=1

    \b
    # Transfer learn MetFaces from FFHQ using 4 GPUs.
    python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \\
        --gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10

    \b
    # Reproduce original StyleGAN2 config F.
    python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \\
        --gpus=8 --cfg=stylegan2 --mirror=1 --aug=noaug

    \b
    Base configs (--cfg):
      auto       Automatically select reasonable defaults based on resolution
                 and GPU count. Good starting point for new datasets.
      stylegan2  Reproduce results for StyleGAN2 config F at 1024x1024.
      paper256   Reproduce results for FFHQ and LSUN Cat at 256x256.
      paper512   Reproduce results for BreCaHAD and AFHQ at 512x512.
      paper1024  Reproduce results for MetFaces at 1024x1024.
      cifar      Reproduce results for CIFAR-10 at 32x32.

    \b
    Transfer learning source networks (--resume):
      ffhq256        FFHQ trained at 256x256 resolution.
      ffhq512        FFHQ trained at 512x512 resolution.
      ffhq1024       FFHQ trained at 1024x1024 resolution.
      celebahq256    CelebA-HQ trained at 256x256 resolution.
      lsundog256     LSUN Dog trained at 256x256 resolution.
      <PATH or URL>  Custom network pickle.
    """
    dnnlib.util.Logger(should_flush=True)

    # Setup training options.
    try:
        run_desc, args = setup_training_loop_kwargs(**config_kwargs)
    except UserError as err:
        ctx.fail(err)

    # Pick output directory.
    prev_run_dirs = []
    if os.path.isdir(outdir):
        prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]
    prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
    prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
    cur_run_id = max(prev_run_ids, default=-1) + 1
    args.run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{run_desc}')
    assert not os.path.exists(args.run_dir)

    # Print options.
    print()
    print('Training options:')
    print(json.dumps(args, indent=2))
    print()
    print(f'Output directory:   {args.run_dir}')
    print(f'Training data:      {args.training_set_kwargs.path}')
    print(f'Training duration:  {args.total_kimg} kimg')
    print(f'Number of GPUs:     {args.num_gpus}')
    print(f'Number of images:   {args.training_set_kwargs.max_size}')
    print(f'Image resolution:   {args.training_set_kwargs.resolution}')
    print(f'Conditional model:  {args.training_set_kwargs.use_labels}')
    print(f'Dataset x-flips:    {args.training_set_kwargs.xflip}')
    print()

    # Dry run?
    if dry_run:
        print('Dry run; exiting.')
        return

    # Create output directory.
    print('Creating output directory...')
    os.makedirs(args.run_dir)
    with open(os.path.join(args.run_dir, 'training_options.json'), 'wt') as f:
        json.dump(args, f, indent=2)

    # Launch processes.
    print('Launching processes...')
    torch.multiprocessing.set_start_method('spawn')
    with tempfile.TemporaryDirectory() as temp_dir:
        if args.num_gpus == 1:
            subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
        else:
            torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)

#----------------------------------------------------------------------------

if __name__ == "__main__":
    main() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------
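
To make the 'auto' base config above concrete, here is a small worked sketch of the heuristics in `setup_training_loop_kwargs` for a hypothetical 512x512 dataset on 2 GPUs; the values follow directly from the formulas shown in the code.

    # Worked example of the 'auto' heuristics (illustrative values).
    res, gpus = 512, 2
    mb    = max(min(gpus * min(4096 // res, 32), 64), gpus)  # max(min(2*8, 64), 2) -> 16
    mbstd = min(mb // gpus, 4)                               # min(8, 4)            -> 4
    fmaps = 1 if res >= 512 else 0.5                         #                      -> 1
    lrate = 0.002 if res >= 1024 else 0.0025                 #                      -> 0.0025
    gamma = 0.0002 * (res ** 2) / mb                         # 0.0002*262144/16     -> 3.2768
    ema   = mb * 10 / 32                                     #                      -> 5.0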
training/__init__.py
ADDED
@@ -0,0 +1,9 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# empty
training/augment.py
ADDED
@@ -0,0 +1,431 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import numpy as np
import scipy.signal
import torch
from torch_utils import persistence
from torch_utils import misc
from torch_utils.ops import upfirdn2d
from torch_utils.ops import grid_sample_gradfix
from torch_utils.ops import conv2d_gradfix

#----------------------------------------------------------------------------
# Coefficients of various wavelet decomposition low-pass filters.

wavelets = {
    'haar': [0.7071067811865476, 0.7071067811865476],
    'db1':  [0.7071067811865476, 0.7071067811865476],
    'db2':  [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'db3':  [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
    'db4':  [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
    'db5':  [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
    'db6':  [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
    'db7':  [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
    'db8':  [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
    'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
    'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
    'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
    'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
    'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
    'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
    'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
}

#----------------------------------------------------------------------------
# Helpers for constructing transformation matrices.

def matrix(*rows, device=None):
    assert all(len(row) == len(rows[0]) for row in rows)
    elems = [x for row in rows for x in row]
    ref = [x for x in elems if isinstance(x, torch.Tensor)]
    if len(ref) == 0:
        return misc.constant(np.asarray(rows), device=device)
    assert device is None or device == ref[0].device
    elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems]
    return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))

def translate2d(tx, ty, **kwargs):
    return matrix(
        [1, 0, tx],
        [0, 1, ty],
        [0, 0, 1],
        **kwargs)

def translate3d(tx, ty, tz, **kwargs):
    return matrix(
        [1, 0, 0, tx],
        [0, 1, 0, ty],
        [0, 0, 1, tz],
        [0, 0, 0, 1],
        **kwargs)

def scale2d(sx, sy, **kwargs):
    return matrix(
        [sx, 0,  0],
        [0,  sy, 0],
        [0,  0,  1],
        **kwargs)

def scale3d(sx, sy, sz, **kwargs):
    return matrix(
        [sx, 0,  0,  0],
        [0,  sy, 0,  0],
        [0,  0,  sz, 0],
        [0,  0,  0,  1],
        **kwargs)

def rotate2d(theta, **kwargs):
    return matrix(
        [torch.cos(theta), torch.sin(-theta), 0],
        [torch.sin(theta), torch.cos(theta),  0],
        [0,                0,                 1],
        **kwargs)

def rotate3d(v, theta, **kwargs):
    vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
    s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
    return matrix(
        [vx*vx*cc+c,    vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
        [vy*vx*cc+vz*s, vy*vy*cc+c,    vy*vz*cc-vx*s, 0],
        [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c,    0],
        [0,             0,             0,             1],
        **kwargs)

def translate2d_inv(tx, ty, **kwargs):
    return translate2d(-tx, -ty, **kwargs)

def scale2d_inv(sx, sy, **kwargs):
    return scale2d(1 / sx, 1 / sy, **kwargs)

def rotate2d_inv(theta, **kwargs):
    return rotate2d(-theta, **kwargs)

#----------------------------------------------------------------------------
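
A quick illustration of the helpers above: `matrix()` broadcasts tensor-valued entries, so scalar and per-sample (batched) transforms compose directly with `@`. A sketch, assuming this module's imports:

    # Illustrative composition (sketch; shapes follow matrix() broadcasting).
    theta = torch.rand([8]) * 2 * np.pi                      # one angle per batch element
    G_inv = scale2d_inv(2, 2) @ rotate2d_inv(theta) @ translate2d_inv(0.5, 0.5)
    # scalar helpers yield [3, 3] matrices; the batched rotation broadcasts to [8, 3, 3]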
# Versatile image augmentation pipeline from the paper
# "Training Generative Adversarial Networks with Limited Data".
#
# All augmentations are disabled by default; individual augmentations can
# be enabled by setting their probability multipliers to 1.

@persistence.persistent_class
class AugmentPipe(torch.nn.Module):
    def __init__(self,
        xflip=0, rotate90=0, xint=0, xint_max=0.125,
        scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
        brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1,
        imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1,
        noise=0, cutout=0, noise_std=0.1, cutout_size=0.5,
    ):
        super().__init__()
        self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability.

        # Pixel blitting.
        self.xflip           = float(xflip)           # Probability multiplier for x-flip.
        self.rotate90        = float(rotate90)        # Probability multiplier for 90 degree rotations.
        self.xint            = float(xint)            # Probability multiplier for integer translation.
        self.xint_max        = float(xint_max)        # Range of integer translation, relative to image dimensions.

        # General geometric transformations.
        self.scale           = float(scale)           # Probability multiplier for isotropic scaling.
        self.rotate          = float(rotate)          # Probability multiplier for arbitrary rotation.
        self.aniso           = float(aniso)           # Probability multiplier for anisotropic scaling.
        self.xfrac           = float(xfrac)           # Probability multiplier for fractional translation.
        self.scale_std       = float(scale_std)       # Log2 standard deviation of isotropic scaling.
        self.rotate_max      = float(rotate_max)      # Range of arbitrary rotation, 1 = full circle.
        self.aniso_std       = float(aniso_std)       # Log2 standard deviation of anisotropic scaling.
        self.xfrac_std       = float(xfrac_std)       # Standard deviation of fractional translation, relative to image dimensions.

        # Color transformations.
        self.brightness      = float(brightness)      # Probability multiplier for brightness.
        self.contrast        = float(contrast)        # Probability multiplier for contrast.
        self.lumaflip        = float(lumaflip)        # Probability multiplier for luma flip.
        self.hue             = float(hue)             # Probability multiplier for hue rotation.
        self.saturation      = float(saturation)      # Probability multiplier for saturation.
        self.brightness_std  = float(brightness_std)  # Standard deviation of brightness.
        self.contrast_std    = float(contrast_std)    # Log2 standard deviation of contrast.
        self.hue_max         = float(hue_max)         # Range of hue rotation, 1 = full circle.
        self.saturation_std  = float(saturation_std)  # Log2 standard deviation of saturation.

        # Image-space filtering.
        self.imgfilter       = float(imgfilter)       # Probability multiplier for image-space filtering.
        self.imgfilter_bands = list(imgfilter_bands)  # Probability multipliers for individual frequency bands.
        self.imgfilter_std   = float(imgfilter_std)   # Log2 standard deviation of image-space filter amplification.

        # Image-space corruptions.
        self.noise           = float(noise)           # Probability multiplier for additive RGB noise.
        self.cutout          = float(cutout)          # Probability multiplier for cutout.
        self.noise_std       = float(noise_std)       # Standard deviation of additive RGB noise.
        self.cutout_size     = float(cutout_size)     # Size of the cutout rectangle, relative to image dimensions.

        # Setup orthogonal lowpass filter for geometric augmentations.
        self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))

        # Construct filter bank for image-space filtering.
        Hz_lo = np.asarray(wavelets['sym2'])            # H(z)
        Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
        Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2    # H(z) * H(z^-1) / 2
        Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2    # H(-z) * H(-z^-1) / 2
        Hz_fbank = np.eye(4, 1)                         # Bandpass(H(z), b_i)
        for i in range(1, Hz_fbank.shape[0]):
            Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
            Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
            Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
        self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))

    def forward(self, images, debug_percentile=None):
        assert isinstance(images, torch.Tensor) and images.ndim == 4
        batch_size, num_channels, height, width = images.shape
        device = images.device
        if debug_percentile is not None:
            debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)

        # -------------------------------------
        # Select parameters for pixel blitting.
        # -------------------------------------

        # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
        I_3 = torch.eye(3, device=device)
        G_inv = I_3

        # Apply x-flip with probability (xflip * strength).
        if self.xflip > 0:
            i = torch.floor(torch.rand([batch_size], device=device) * 2)
            i = torch.where(torch.rand([batch_size], device=device) < self.xflip * self.p, i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 2))
            G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1)

        # Apply 90 degree rotations with probability (rotate90 * strength).
        if self.rotate90 > 0:
            i = torch.floor(torch.rand([batch_size], device=device) * 4)
            i = torch.where(torch.rand([batch_size], device=device) < self.rotate90 * self.p, i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 4))
            G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i)

        # Apply integer translation with probability (xint * strength).
        if self.xint > 0:
            t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
            t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
            G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height))

        # --------------------------------------------------------
        # Select parameters for general geometric transformations.
        # --------------------------------------------------------

        # Apply isotropic scaling with probability (scale * strength).
        if self.scale > 0:
            s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
            s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
            G_inv = G_inv @ scale2d_inv(s, s)

        # Apply pre-rotation with probability p_rot.
        p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
            theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
            G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling.

        # Apply anisotropic scaling with probability (aniso * strength).
        if self.aniso > 0:
            s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
            s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
            G_inv = G_inv @ scale2d_inv(s, 1 / s)

        # Apply post-rotation with probability p_rot.
        if self.rotate > 0:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
            theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.zeros_like(theta)
            G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling.

        # Apply fractional translation with probability (xfrac * strength).
        if self.xfrac > 0:
            t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
            t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t))
            if debug_percentile is not None:
                t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
            G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height)

        # ----------------------------------
        # Execute geometric transformations.
        # ----------------------------------

        # Execute if the transform is not identity.
        if G_inv is not I_3:

            # Calculate padding.
            cx = (width - 1) / 2
            cy = (height - 1) / 2
            cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
            cp = G_inv @ cp.t() # [batch, xyz, idx]
            Hz_pad = self.Hz_geom.shape[0] // 4
            margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
            margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
            margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
            margin = margin.max(misc.constant([0, 0] * 2, device=device))
            margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
            mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)

            # Pad image and adjust origin.
            images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
            G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv

            # Upsample.
            images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
            G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
            G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)

            # Execute transformation.
            shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
            G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
            grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
            images = grid_sample_gradfix.grid_sample(images, grid)

            # Downsample and crop.
            images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)

        # --------------------------------------------
        # Select parameters for color transformations.
        # --------------------------------------------

        # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
        I_4 = torch.eye(4, device=device)
        C = I_4

        # Apply brightness with probability (brightness * strength).
        if self.brightness > 0:
            b = torch.randn([batch_size], device=device) * self.brightness_std
            b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b))
            if debug_percentile is not None:
                b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
            C = translate3d(b, b, b) @ C

        # Apply contrast with probability (contrast * strength).
        if self.contrast > 0:
            c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
            c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c))
            if debug_percentile is not None:
                c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
            C = scale3d(c, c, c) @ C

        # Apply luma flip with probability (lumaflip * strength).
        v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis.
        if self.lumaflip > 0:
            i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2)
            i = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.lumaflip * self.p, i, torch.zeros_like(i))
            if debug_percentile is not None:
                i = torch.full_like(i, torch.floor(debug_percentile * 2))
            C = (I_4 - 2 * v.ger(v) * i) @ C # Householder reflection.

        # Apply hue rotation with probability (hue * strength).
        if self.hue > 0 and num_channels > 1:
            theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
            theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
            if debug_percentile is not None:
                theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
            C = rotate3d(v, theta) @ C # Rotate around v.

        # Apply saturation with probability (saturation * strength).
        if self.saturation > 0 and num_channels > 1:
            s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
            s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s))
            if debug_percentile is not None:
                s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
            C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C

        # ------------------------------
        # Execute color transformations.
        # ------------------------------

        # Execute if the transform is not identity.
        if C is not I_4:
            images = images.reshape([batch_size, num_channels, height * width])
            if num_channels == 3:
                images = C[:, :3, :3] @ images + C[:, :3, 3:]
            elif num_channels == 1:
                C = C[:, :3, :].mean(dim=1, keepdims=True)
                images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
            else:
                raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
            images = images.reshape([batch_size, num_channels, height, width])

        # ----------------------
        # Image-space filtering.
        # ----------------------

        if self.imgfilter > 0:
            num_bands = self.Hz_fbank.shape[0]
            assert len(self.imgfilter_bands) == num_bands
            expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f).

            # Apply amplification for each band with probability (imgfilter * strength * band_strength).
            g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity).
            for i, band_strength in enumerate(self.imgfilter_bands):
                t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
                t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i))
                if debug_percentile is not None:
                    t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i)
                t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector.
                t[:, i] = t_i # Replace i'th element.
                t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
                g = g * t # Accumulate into global gain.

            # Construct combined amplification filter.
            Hz_prime = g @ self.Hz_fbank # [batch, tap]
            Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap]
            Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]

            # Apply filter.
            p = self.Hz_fbank.shape[1] // 2
            images = images.reshape([1, batch_size * num_channels, height, width])
            images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
            images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
            images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
            images = images.reshape([batch_size, num_channels, height, width])

        # ------------------------
        # Image-space corruptions.
        # ------------------------

        # Apply additive RGB noise with probability (noise * strength).
        if self.noise > 0:
            sigma = torch.randn([batch_size, 1, 1, 1], device=device).abs() * self.noise_std
            sigma = torch.where(torch.rand([batch_size, 1, 1, 1], device=device) < self.noise * self.p, sigma, torch.zeros_like(sigma))
            if debug_percentile is not None:
                sigma = torch.full_like(sigma, torch.erfinv(debug_percentile) * self.noise_std)
            images = images + torch.randn([batch_size, num_channels, height, width], device=device) * sigma

        # Apply cutout with probability (cutout * strength).
        if self.cutout > 0:
            size = torch.full([batch_size, 2, 1, 1, 1], self.cutout_size, device=device)
            size = torch.where(torch.rand([batch_size, 1, 1, 1, 1], device=device) < self.cutout * self.p, size, torch.zeros_like(size))
            center = torch.rand([batch_size, 2, 1, 1, 1], device=device)
            if debug_percentile is not None:
                size = torch.full_like(size, self.cutout_size)
                center = torch.full_like(center, debug_percentile)
            coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1])
            coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1])
            mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2)
            mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2)
            mask = torch.logical_or(mask_x, mask_y).to(torch.float32)
            images = images * mask
|
428 |
+
|
429 |
+
return images
|
430 |
+
|
431 |
+
#----------------------------------------------------------------------------
|
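
Side note on the cutout branch above: the mask is built from relative pixel-center coordinates, so a pixel is zeroed only when it falls inside the rectangle along both axes. Below is a minimal standalone sketch of the same logic for a single image (illustrative, not part of the commit; `cutout_mask` is a hypothetical helper name):

import torch

def cutout_mask(height, width, center, cutout_size=0.5):
    # center = (cx, cy) in relative [0, 1] coordinates; cutout_size is the
    # rectangle's side length relative to the image, as in AugmentPipe.
    coord_x = (torch.arange(width) + 0.5) / width    # [W] relative pixel centers
    coord_y = (torch.arange(height) + 0.5) / height  # [H]
    keep_x = (coord_x - center[0]).abs() >= cutout_size / 2  # outside the rectangle along x
    keep_y = (coord_y - center[1]).abs() >= cutout_size / 2  # outside the rectangle along y
    # A pixel survives if it lies outside the rectangle along either axis.
    return torch.logical_or(keep_x.reshape(1, -1), keep_y.reshape(-1, 1)).float()  # [H, W]

print(cutout_mask(8, 8, center=(0.5, 0.5)))  # zeros in the central 4x4 block, ones elsewhere
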
training/dataset.py
ADDED
@@ -0,0 +1,236 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import numpy as np
import zipfile
import PIL.Image
import json
import torch
import dnnlib

try:
    import pyspng
except ImportError:
    pyspng = None

#----------------------------------------------------------------------------

class Dataset(torch.utils.data.Dataset):
    def __init__(self,
        name,                   # Name of the dataset.
        raw_shape,              # Shape of the raw image data (NCHW).
        max_size    = None,     # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
        use_labels  = False,    # Enable conditioning labels? False = label dimension is zero.
        xflip       = False,    # Artificially double the size of the dataset via x-flips. Applied after max_size.
        random_seed = 0,        # Random seed to use when applying max_size.
    ):
        self._name = name
        self._raw_shape = list(raw_shape)
        self._use_labels = use_labels
        self._raw_labels = None
        self._label_shape = None

        # Apply max_size.
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            np.random.RandomState(random_seed).shuffle(self._raw_idx)
            self._raw_idx = np.sort(self._raw_idx[:max_size])

        # Apply xflip.
        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
        if xflip:
            self._raw_idx = np.tile(self._raw_idx, 2)
            self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])

    def _get_raw_labels(self):
        if self._raw_labels is None:
            self._raw_labels = self._load_raw_labels() if self._use_labels else None
            if self._raw_labels is None:
                self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
            assert isinstance(self._raw_labels, np.ndarray)
            assert self._raw_labels.shape[0] == self._raw_shape[0]
            assert self._raw_labels.dtype in [np.float32, np.int64]
            if self._raw_labels.dtype == np.int64:
                assert self._raw_labels.ndim == 1
                assert np.all(self._raw_labels >= 0)
        return self._raw_labels

    def close(self): # to be overridden by subclass
        pass

    def _load_raw_image(self, raw_idx): # to be overridden by subclass
        raise NotImplementedError

    def _load_raw_labels(self): # to be overridden by subclass
        raise NotImplementedError

    def __getstate__(self):
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        try:
            self.close()
        except:
            pass

    def __len__(self):
        return self._raw_idx.size

    def __getitem__(self, idx):
        image = self._load_raw_image(self._raw_idx[idx])
        assert isinstance(image, np.ndarray)
        assert list(image.shape) == self.image_shape
        assert image.dtype == np.uint8
        if self._xflip[idx]:
            assert image.ndim == 3 # CHW
            image = image[:, :, ::-1]
        return image.copy(), self.get_label(idx)

    def get_label(self, idx):
        label = self._get_raw_labels()[self._raw_idx[idx]]
        if label.dtype == np.int64:
            onehot = np.zeros(self.label_shape, dtype=np.float32)
            onehot[label] = 1
            label = onehot
        return label.copy()

    def get_details(self, idx):
        d = dnnlib.EasyDict()
        d.raw_idx = int(self._raw_idx[idx])
        d.xflip = (int(self._xflip[idx]) != 0)
        d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
        return d

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self):
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        assert len(self.image_shape) == 3 # CHW
        return self.image_shape[0]

    @property
    def resolution(self):
        assert len(self.image_shape) == 3 # CHW
        assert self.image_shape[1] == self.image_shape[2]
        return self.image_shape[1]

    @property
    def label_shape(self):
        if self._label_shape is None:
            raw_labels = self._get_raw_labels()
            if raw_labels.dtype == np.int64:
                self._label_shape = [int(np.max(raw_labels)) + 1]
            else:
                self._label_shape = raw_labels.shape[1:]
        return list(self._label_shape)

    @property
    def label_dim(self):
        assert len(self.label_shape) == 1
        return self.label_shape[0]

    @property
    def has_labels(self):
        return any(x != 0 for x in self.label_shape)

    @property
    def has_onehot_labels(self):
        return self._get_raw_labels().dtype == np.int64

#----------------------------------------------------------------------------

class ImageFolderDataset(Dataset):
    def __init__(self,
        path,                   # Path to directory or zip.
        resolution      = None, # Ensure specific resolution, None = highest available.
        **super_kwargs,         # Additional arguments for the Dataset base class.
    ):
        self._path = path
        self._zipfile = None

        if os.path.isdir(self._path):
            self._type = 'dir'
            self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
        elif self._file_ext(self._path) == '.zip':
            self._type = 'zip'
            self._all_fnames = set(self._get_zipfile().namelist())
        else:
            raise IOError('Path must point to a directory or zip')

        PIL.Image.init()
        self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
        if len(self._image_fnames) == 0:
            raise IOError('No image files found in the specified path')

        name = os.path.splitext(os.path.basename(self._path))[0]
        raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
        if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
            raise IOError('Image files do not match the specified resolution')
        super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)

    @staticmethod
    def _file_ext(fname):
        return os.path.splitext(fname)[1].lower()

    def _get_zipfile(self):
        assert self._type == 'zip'
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self._path)
        return self._zipfile

    def _open_file(self, fname):
        if self._type == 'dir':
            return open(os.path.join(self._path, fname), 'rb')
        if self._type == 'zip':
            return self._get_zipfile().open(fname, 'r')
        return None

    def close(self):
        try:
            if self._zipfile is not None:
                self._zipfile.close()
        finally:
            self._zipfile = None

    def __getstate__(self):
        return dict(super().__getstate__(), _zipfile=None)

    def _load_raw_image(self, raw_idx):
        fname = self._image_fnames[raw_idx]
        with self._open_file(fname) as f:
            if pyspng is not None and self._file_ext(fname) == '.png':
                image = pyspng.load(f.read())
            else:
                image = np.array(PIL.Image.open(f))
        if image.ndim == 2:
            image = image[:, :, np.newaxis] # HW => HWC
        image = image.transpose(2, 0, 1) # HWC => CHW
        return image

    def _load_raw_labels(self):
        fname = 'dataset.json'
        if fname not in self._all_fnames:
            return None
        with self._open_file(fname) as f:
            labels = json.load(f)['labels']
        if labels is None:
            return None
        labels = dict(labels)
        labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
        labels = np.array(labels)
        labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
        return labels

#----------------------------------------------------------------------------
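
Usage sketch for ImageFolderDataset (illustrative, not part of the commit; the path below is hypothetical). __getitem__ returns a uint8 CHW numpy array plus a label, so a plain DataLoader yields uint8 batches; the training loop is responsible for casting to float and normalizing.

import torch
from training.dataset import ImageFolderDataset

dataset = ImageFolderDataset(path='./data/images.zip', use_labels=False, xflip=True)  # hypothetical path
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)

images, labels = next(iter(loader))
print(images.shape, images.dtype)  # [4, C, H, W] torch.uint8, pixel values in [0, 255]
print(labels.shape)                # [4, 0] when use_labels=False
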
training/loss.py
ADDED
@@ -0,0 +1,133 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import numpy as np
import torch
from torch_utils import training_stats
from torch_utils import misc
from torch_utils.ops import conv2d_gradfix

#----------------------------------------------------------------------------

class Loss:
    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain): # to be overridden by subclass
        raise NotImplementedError()

#----------------------------------------------------------------------------

class StyleGAN2Loss(Loss):
    def __init__(self, device, G_mapping, G_synthesis, D, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2):
        super().__init__()
        self.device = device
        self.G_mapping = G_mapping
        self.G_synthesis = G_synthesis
        self.D = D
        self.augment_pipe = augment_pipe
        self.style_mixing_prob = style_mixing_prob
        self.r1_gamma = r1_gamma
        self.pl_batch_shrink = pl_batch_shrink
        self.pl_decay = pl_decay
        self.pl_weight = pl_weight
        self.pl_mean = torch.zeros([], device=device)

    def run_G(self, z, c, sync):
        with misc.ddp_sync(self.G_mapping, sync):
            ws = self.G_mapping(z, c)
            if self.style_mixing_prob > 0:
                with torch.autograd.profiler.record_function('style_mixing'):
                    cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                    cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                    ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:]
        with misc.ddp_sync(self.G_synthesis, sync):
            img = self.G_synthesis(ws)
        return img, ws

    def run_D(self, img, c, sync):
        if self.augment_pipe is not None:
            img = self.augment_pipe(img)
        with misc.ddp_sync(self.D, sync):
            logits = self.D(img, c)
        return logits

    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
        do_Gmain = (phase in ['Gmain', 'Gboth'])
        do_Dmain = (phase in ['Dmain', 'Dboth'])
        do_Gpl   = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
        do_Dr1   = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)

        # Gmain: Maximize logits for generated images.
        if do_Gmain:
            with torch.autograd.profiler.record_function('Gmain_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl)) # May get synced by Gpl.
                gen_logits = self.run_D(gen_img, gen_c, sync=False)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
                training_stats.report('Loss/G/loss', loss_Gmain)
            with torch.autograd.profiler.record_function('Gmain_backward'):
                loss_Gmain.mean().mul(gain).backward()

        # Gpl: Apply path length regularization.
        if do_Gpl:
            with torch.autograd.profiler.record_function('Gpl_forward'):
                batch_size = gen_z.shape[0] // self.pl_batch_shrink
                gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync)
                pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
                with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
                    pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
                pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
                pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
                self.pl_mean.copy_(pl_mean.detach())
                pl_penalty = (pl_lengths - pl_mean).square()
                training_stats.report('Loss/pl_penalty', pl_penalty)
                loss_Gpl = pl_penalty * self.pl_weight
                training_stats.report('Loss/G/reg', loss_Gpl)
            with torch.autograd.profiler.record_function('Gpl_backward'):
                (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()

        # Dmain: Minimize logits for generated images.
        loss_Dgen = 0
        if do_Dmain:
            with torch.autograd.profiler.record_function('Dgen_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
                gen_logits = self.run_D(gen_img, gen_c, sync=False) # Gets synced by loss_Dreal.
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
            with torch.autograd.profiler.record_function('Dgen_backward'):
                loss_Dgen.mean().mul(gain).backward()

        # Dmain: Maximize logits for real images.
        # Dr1: Apply R1 regularization.
        if do_Dmain or do_Dr1:
            name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
            with torch.autograd.profiler.record_function(name + '_forward'):
                real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
                real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
                training_stats.report('Loss/scores/real', real_logits)
                training_stats.report('Loss/signs/real', real_logits.sign())

                loss_Dreal = 0
                if do_Dmain:
                    loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
                    training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)

                loss_Dr1 = 0
                if do_Dr1:
                    with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                        r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
                    r1_penalty = r1_grads.square().sum([1,2,3])
                    loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
                    training_stats.report('Loss/r1_penalty', r1_penalty)
                    training_stats.report('Loss/D/reg', loss_Dr1)

            with torch.autograd.profiler.record_function(name + '_backward'):
                (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()

#----------------------------------------------------------------------------
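
The comments in StyleGAN2Loss lean on two softplus identities for the non-saturating GAN loss; a quick numeric check (illustrative, not part of the commit):

import torch

x = torch.randn(5)
# Generator / real-discriminator term: softplus(-x) == -log(sigmoid(x)).
assert torch.allclose(torch.nn.functional.softplus(-x), -torch.log(torch.sigmoid(x)), atol=1e-6)
# Fake-discriminator term: softplus(x) == -log(1 - sigmoid(x)).
assert torch.allclose(torch.nn.functional.softplus(x), -torch.log(1 - torch.sigmoid(x)), atol=1e-6)
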
training/networks.py
ADDED
@@ -0,0 +1,729 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import numpy as np
import torch
from torch_utils import misc
from torch_utils import persistence
from torch_utils.ops import conv2d_resample
from torch_utils.ops import upfirdn2d
from torch_utils.ops import bias_act
from torch_utils.ops import fma

#----------------------------------------------------------------------------

@misc.profiled_function
def normalize_2nd_moment(x, dim=1, eps=1e-8):
    return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()

#----------------------------------------------------------------------------

@misc.profiled_function
def modulated_conv2d(
    x,                          # Input tensor of shape [batch_size, in_channels, in_height, in_width].
    weight,                     # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
    styles,                     # Modulation coefficients of shape [batch_size, in_channels].
    noise           = None,     # Optional noise tensor to add to the output activations.
    up              = 1,        # Integer upsampling factor.
    down            = 1,        # Integer downsampling factor.
    padding         = 0,        # Padding with respect to the upsampled image.
    resample_filter = None,     # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
    demodulate      = True,     # Apply weight demodulation?
    flip_weight     = True,     # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
    fused_modconv   = True,     # Perform modulation, convolution, and demodulation as a single fused operation?
):
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape
    misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
    misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
    misc.assert_shape(styles, [batch_size, in_channels]) # [NI]

    # Pre-normalize inputs to avoid FP16 overflow.
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
        styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I

    # Calculate per-sample weights and demodulation coefficients.
    w = None
    dcoefs = None
    if demodulate or fused_modconv:
        w = weight.unsqueeze(0) # [NOIkk]
        w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
    if demodulate:
        dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
    if demodulate and fused_modconv:
        w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]

    # Execute by scaling the activations before and after the convolution.
    if not fused_modconv:
        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
        x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
        if demodulate and noise is not None:
            x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
        elif demodulate:
            x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x

    # Execute as one fused op using grouped convolution.
    with misc.suppress_tracer_warnings(): # this value will be treated as a constant
        batch_size = int(batch_size)
    misc.assert_shape(x, [batch_size, in_channels, None, None])
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
    x = x.reshape(batch_size, -1, *x.shape[2:])
    if noise is not None:
        x = x.add_(noise)
    return x

#----------------------------------------------------------------------------
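
The demodulation step in modulated_conv2d rescales each per-sample, per-output-channel filter to unit L2 norm. A small check with raw tensor ops (illustrative, not part of the commit):

import torch

N, O, I, k = 2, 4, 3, 3
weight = torch.randn(O, I, k, k)          # shared convolution weights
styles = torch.randn(N, I)                # per-sample modulation coefficients

w = weight.unsqueeze(0) * styles.reshape(N, 1, I, 1, 1)   # [NOIkk] modulate
dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()   # [NO] demodulation coefficients
w = w * dcoefs.reshape(N, O, 1, 1, 1)                     # demodulate

print(w.square().sum(dim=[2, 3, 4]))  # ~1.0 for every (sample, out_channel) pair
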

@persistence.persistent_class
class FullyConnectedLayer(torch.nn.Module):
    def __init__(self,
        in_features,                # Number of input features.
        out_features,               # Number of output features.
        bias            = True,     # Apply additive bias before the activation function?
        activation      = 'linear', # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 1,        # Learning rate multiplier.
        bias_init       = 0,        # Initial value for the additive bias.
    ):
        super().__init__()
        self.activation = activation
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
        self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        w = self.weight.to(x.dtype) * self.weight_gain
        b = self.bias
        if b is not None:
            b = b.to(x.dtype)
            if self.bias_gain != 1:
                b = b * self.bias_gain

        if self.activation == 'linear' and b is not None:
            x = torch.addmm(b.unsqueeze(0), x, w.t())
        else:
            x = x.matmul(w.t())
            x = bias_act.bias_act(x, b, act=self.activation)
        return x

#----------------------------------------------------------------------------

@persistence.persistent_class
class Conv2dLayer(torch.nn.Module):
    def __init__(self,
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        kernel_size,                    # Width and height of the convolution kernel.
        bias            = True,         # Apply additive bias before the activation function?
        activation      = 'linear',     # Activation function: 'relu', 'lrelu', etc.
        up              = 1,            # Integer upsampling factor.
        down            = 1,            # Integer downsampling factor.
        resample_filter = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp      = None,         # Clamp the output to +-X, None = disable clamping.
        channels_last   = False,        # Expect the input to have memory_format=channels_last?
        trainable       = True,         # Update the weights of this layer during training?
    ):
        super().__init__()
        self.activation = activation
        self.up = up
        self.down = down
        self.conv_clamp = conv_clamp
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.padding = kernel_size // 2
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
        self.act_gain = bias_act.activation_funcs[activation].def_gain

        memory_format = torch.channels_last if channels_last else torch.contiguous_format
        weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
        bias = torch.zeros([out_channels]) if bias else None
        if trainable:
            self.weight = torch.nn.Parameter(weight)
            self.bias = torch.nn.Parameter(bias) if bias is not None else None
        else:
            self.register_buffer('weight', weight)
            if bias is not None:
                self.register_buffer('bias', bias)
            else:
                self.bias = None

    def forward(self, x, gain=1):
        w = self.weight * self.weight_gain
        b = self.bias.to(x.dtype) if self.bias is not None else None
        flip_weight = (self.up == 1) # slightly faster
        x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight)

        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
        return x

#----------------------------------------------------------------------------

@persistence.persistent_class
class MappingNetwork(torch.nn.Module):
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality, 0 = no latent.
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no label.
        w_dim,                      # Intermediate latent (W) dimensionality.
        num_ws,                     # Number of intermediate latents to output, None = do not broadcast.
        num_layers      = 8,        # Number of mapping layers.
        embed_features  = None,     # Label embedding dimensionality, None = same as w_dim.
        layer_features  = None,     # Number of intermediate features in the mapping layers, None = same as w_dim.
        activation      = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers.
        w_avg_beta      = 0.995,    # Decay for tracking the moving average of W during training, None = do not track.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        if embed_features is None:
            embed_features = w_dim
        if c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)

        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
        # Embed, normalize, and concat inputs.
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                misc.assert_shape(z, [None, self.z_dim])
                x = normalize_2nd_moment(z.to(torch.float32))
            if self.c_dim > 0:
                misc.assert_shape(c, [None, self.c_dim])
                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
                x = torch.cat([x, y], dim=1) if x is not None else y

        # Main layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        # Update moving average of W.
        if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x

#----------------------------------------------------------------------------
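
MappingNetwork's truncation trick is a plain linear blend toward the tracked average, w' = w_avg + psi * (w - w_avg), which is exactly what Tensor.lerp computes. Illustrative check (not part of the commit):

import torch

w_avg = torch.zeros(4)
w = torch.randn(4)
psi = 0.7
# Tensor.lerp(end, weight) returns start + weight * (end - start).
assert torch.allclose(w_avg.lerp(w, psi), w_avg + psi * (w - w_avg))
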

@persistence.persistent_class
class SynthesisLayer(torch.nn.Module):
    def __init__(self,
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        w_dim,                          # Intermediate latent (W) dimensionality.
        resolution,                     # Resolution of this layer.
        kernel_size     = 3,            # Convolution kernel size.
        up              = 1,            # Integer upsampling factor.
        use_noise       = True,         # Enable noise input?
        activation      = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
        resample_filter = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp      = None,         # Clamp the output of convolution layers to +-X, None = disable clamping.
        channels_last   = False,        # Use channels_last format for the weights?
    ):
        super().__init__()
        self.resolution = resolution
        self.up = up
        self.use_noise = use_noise
        self.activation = activation
        self.conv_clamp = conv_clamp
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.padding = kernel_size // 2
        self.act_gain = bias_act.activation_funcs[activation].def_gain

        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
        memory_format = torch.channels_last if channels_last else torch.contiguous_format
        self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
        if use_noise:
            self.register_buffer('noise_const', torch.randn([resolution, resolution]))
            self.noise_strength = torch.nn.Parameter(torch.zeros([]))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))

    def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
        assert noise_mode in ['random', 'const', 'none']
        in_resolution = self.resolution // self.up
        misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution])
        styles = self.affine(w)

        noise = None
        if self.use_noise and noise_mode == 'random':
            noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
        if self.use_noise and noise_mode == 'const':
            noise = self.noise_const * self.noise_strength

        flip_weight = (self.up == 1) # slightly faster
        x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
            padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv)

        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
        return x

#----------------------------------------------------------------------------

@persistence.persistent_class
class ToRGBLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
        super().__init__()
        self.conv_clamp = conv_clamp
        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
        memory_format = torch.channels_last if channels_last else torch.contiguous_format
        self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))

    def forward(self, x, w, fused_modconv=True):
        styles = self.affine(w) * self.weight_gain
        x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv)
        x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
        return x

#----------------------------------------------------------------------------

@persistence.persistent_class
class SynthesisBlock(torch.nn.Module):
    def __init__(self,
        in_channels,                        # Number of input channels, 0 = first block.
        out_channels,                       # Number of output channels.
        w_dim,                              # Intermediate latent (W) dimensionality.
        resolution,                         # Resolution of this block.
        img_channels,                       # Number of output color channels.
        is_last,                            # Is this the last block?
        architecture        = 'skip',       # Architecture: 'orig', 'skip', 'resnet'.
        resample_filter     = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp          = None,         # Clamp the output of convolution layers to +-X, None = disable clamping.
        use_fp16            = False,        # Use FP16 for this block?
        fp16_channels_last  = False,        # Use channels-last memory format with FP16?
        **layer_kwargs,                     # Arguments for SynthesisLayer.
    ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.is_last = is_last
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.num_conv = 0
        self.num_torgb = 0

        if in_channels == 0:
            self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))

        if in_channels != 0:
            self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2,
                resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
            self.num_conv += 1

        self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution,
            conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
        self.num_conv += 1

        if is_last or architecture == 'skip':
            self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
                conv_clamp=conv_clamp, channels_last=self.channels_last)
            self.num_torgb += 1

        if in_channels != 0 and architecture == 'resnet':
            self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
                resample_filter=resample_filter, channels_last=self.channels_last)

    def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
        misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            with misc.suppress_tracer_warnings(): # this value will be treated as a constant
                fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)

        # Input.
        if self.in_channels == 0:
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
            misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

        # ToRGB.
        if img is not None:
            misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
            img = upfirdn2d.upsample2d(img, self.resample_filter)
        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img

#----------------------------------------------------------------------------

@persistence.persistent_class
class SynthesisNetwork(torch.nn.Module):
    def __init__(self,
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output image resolution.
        img_channels,               # Number of color channels.
        channel_base    = 32768,    # Overall multiplier for the number of channels.
        channel_max     = 512,      # Maximum number of channels in any layer.
        num_fp16_res    = 0,        # Use FP16 for the N highest resolutions.
        **block_kwargs,             # Arguments for SynthesisBlock.
    ):
        assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
        super().__init__()
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        self.num_ws = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res // 2] if res > 4 else 0
            out_channels = channels_dict[res]
            use_fp16 = (res >= fp16_resolution)
            is_last = (res == self.img_resolution)
            block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
                img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
            self.num_ws += block.num_conv
            if is_last:
                self.num_ws += block.num_torgb
            setattr(self, f'b{res}', block)

    def forward(self, ws, **block_kwargs):
        block_ws = []
        with torch.autograd.profiler.record_function('split_ws'):
            misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
            ws = ws.to(torch.float32)
            w_idx = 0
            for res in self.block_resolutions:
                block = getattr(self, f'b{res}')
                block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
                w_idx += block.num_conv

        x = img = None
        for res, cur_ws in zip(self.block_resolutions, block_ws):
            block = getattr(self, f'b{res}')
            x, img = block(x, img, cur_ws, **block_kwargs)
        return img

#----------------------------------------------------------------------------
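
Worked example (not in the commit) of the channel schedule that SynthesisNetwork's constructor produces for a 256x256 generator with the default channel_base=32768 and channel_max=512:

import numpy as np

img_resolution, channel_base, channel_max = 256, 32768, 512
res_log2 = int(np.log2(img_resolution))                       # 8
block_resolutions = [2 ** i for i in range(2, res_log2 + 1)]  # [4, 8, 16, 32, 64, 128, 256]
channels = {res: min(channel_base // res, channel_max) for res in block_resolutions}
print(channels)  # {4: 512, 8: 512, 16: 512, 32: 512, 64: 512, 128: 256, 256: 128}
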

@persistence.persistent_class
class Generator(torch.nn.Module):
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality.
        c_dim,                      # Conditioning label (C) dimensionality.
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output resolution.
        img_channels,               # Number of output color channels.
        mapping_kwargs      = {},   # Arguments for MappingNetwork.
        synthesis_kwargs    = {},   # Arguments for SynthesisNetwork.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
        self.num_ws = self.synthesis.num_ws
        self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs):
        ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
        img = self.synthesis(ws, **synthesis_kwargs)
        return img

#----------------------------------------------------------------------------
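
Usage sketch (illustrative; the hyperparameters below are arbitrary small values, not the repo's training defaults): constructing an unconditional Generator and sampling one batch. On CPU, the custom ops should fall back to their reference PyTorch implementations.

import torch
from training.networks import Generator

G = Generator(z_dim=64, c_dim=0, w_dim=64, img_resolution=32, img_channels=3)
z = torch.randn(2, G.z_dim)
c = torch.zeros(2, G.c_dim)        # unused when c_dim == 0
img = G(z, c, truncation_psi=0.7)
print(img.shape)                   # torch.Size([2, 3, 32, 32])
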

@persistence.persistent_class
class DiscriminatorBlock(torch.nn.Module):
    def __init__(self,
        in_channels,                        # Number of input channels, 0 = first block.
        tmp_channels,                       # Number of intermediate channels.
        out_channels,                       # Number of output channels.
        resolution,                         # Resolution of this block.
        img_channels,                       # Number of input color channels.
        first_layer_idx,                    # Index of the first layer.
        architecture        = 'resnet',     # Architecture: 'orig', 'skip', 'resnet'.
        activation          = 'lrelu',      # Activation function: 'relu', 'lrelu', etc.
        resample_filter     = [1,3,3,1],    # Low-pass filter to apply when resampling activations.
        conv_clamp          = None,         # Clamp the output of convolution layers to +-X, None = disable clamping.
        use_fp16            = False,        # Use FP16 for this block?
        fp16_channels_last  = False,        # Use channels-last memory format with FP16?
        freeze_layers       = 0,            # Freeze-D: Number of layers to freeze.
    ):
        assert in_channels in [0, tmp_channels]
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.resolution = resolution
        self.img_channels = img_channels
        self.first_layer_idx = first_layer_idx
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))

        self.num_layers = 0
        def trainable_gen():
            while True:
                layer_idx = self.first_layer_idx + self.num_layers
                trainable = (layer_idx >= freeze_layers)
                self.num_layers += 1
                yield trainable
        trainable_iter = trainable_gen()

        if in_channels == 0 or architecture == 'skip':
            self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation,
                trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)

        self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
            trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)

        self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
            trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last)

        if architecture == 'resnet':
            self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
                trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last)

    def forward(self, x, img, force_fp32=False):
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format

        # Input.
        if x is not None:
            misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # FromRGB.
        if self.in_channels == 0 or self.architecture == 'skip':
            misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
            img = img.to(dtype=dtype, memory_format=memory_format)
            y = self.fromrgb(img)
            x = x + y if x is not None else y
            img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None

        # Main layers.
        if self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x)
            x = self.conv1(x, gain=np.sqrt(0.5))
            x = y.add_(x)
        else:
            x = self.conv0(x)
            x = self.conv1(x)

        assert x.dtype == dtype
        return x, img

#----------------------------------------------------------------------------

@persistence.persistent_class
class MinibatchStdLayer(torch.nn.Module):
    def __init__(self, group_size, num_channels=1):
        super().__init__()
        self.group_size = group_size
        self.num_channels = num_channels

    def forward(self, x):
        N, C, H, W = x.shape
        with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants
            G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N
        F = self.num_channels
        c = C // F

        y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
        y = y - y.mean(dim=0)            # [GnFcHW] Subtract mean over group.
        y = y.square().mean(dim=0)       # [nFcHW]  Calc variance over group.
        y = (y + 1e-8).sqrt()            # [nFcHW]  Calc stddev over group.
        y = y.mean(dim=[2,3,4])          # [nF]     Take average over channels and pixels.
        y = y.reshape(-1, F, 1, 1)       # [nF11]   Add missing dimensions.
        y = y.repeat(G, 1, H, W)         # [NFHW]   Replicate over group and pixels.
        x = torch.cat([x, y], dim=1)     # [NCHW]   Append to input as new channels.
        return x

#----------------------------------------------------------------------------
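
Illustrative shape check (not part of the commit) for MinibatchStdLayer: the layer appends num_channels extra feature maps carrying the cross-sample stddev statistic, so the channel count grows by num_channels.

import torch
from training.networks import MinibatchStdLayer

x = torch.randn(8, 16, 4, 4)                             # [N, C, H, W]
layer = MinibatchStdLayer(group_size=4, num_channels=1)  # stats over groups of 4 samples
print(layer(x).shape)                                    # torch.Size([8, 17, 4, 4])
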
613 |
+
|
614 |
+
@persistence.persistent_class
|
615 |
+
class DiscriminatorEpilogue(torch.nn.Module):
|
616 |
+
def __init__(self,
|
617 |
+
in_channels, # Number of input channels.
|
618 |
+
cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.
|
619 |
+
resolution, # Resolution of this block.
|
620 |
+
img_channels, # Number of input color channels.
|
621 |
+
architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
|
622 |
+
mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
|
623 |
+
mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable.
|
624 |
+
activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
|
625 |
+
conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
|
626 |
+
):
|
627 |
+
assert architecture in ['orig', 'skip', 'resnet']
|
628 |
+
super().__init__()
|
629 |
+
self.in_channels = in_channels
|
630 |
+
self.cmap_dim = cmap_dim
|
631 |
+
self.resolution = resolution
|
632 |
+
self.img_channels = img_channels
|
633 |
+
self.architecture = architecture
|
634 |
+
|
635 |
+
if architecture == 'skip':
|
636 |
+
self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation)
|
637 |
+
self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
|
638 |
+
self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp)
|
639 |
+
self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation)
|
640 |
+
self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim)
|
641 |
+
|
642 |
+
def forward(self, x, img, cmap, force_fp32=False):
|
643 |
+
misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW]
|
644 |
+
_ = force_fp32 # unused
|
645 |
+
dtype = torch.float32
|
646 |
+
memory_format = torch.contiguous_format
|
647 |
+
|
648 |
+
# FromRGB.
|
649 |
+
x = x.to(dtype=dtype, memory_format=memory_format)
|
650 |
+
if self.architecture == 'skip':
|
651 |
+
misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
|
652 |
+
img = img.to(dtype=dtype, memory_format=memory_format)
|
653 |
+
x = x + self.fromrgb(img)
|
654 |
+
|
655 |
+
# Main layers.
|
656 |
+
if self.mbstd is not None:
|
657 |
+
x = self.mbstd(x)
|
658 |
+
x = self.conv(x)
|
659 |
+
x = self.fc(x.flatten(1))
|
660 |
+
x = self.out(x)
|
661 |
+
|
662 |
+
# Conditioning.
|
663 |
+
if self.cmap_dim > 0:
|
664 |
+
misc.assert_shape(cmap, [None, self.cmap_dim])
|
665 |
+
x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
|
666 |
+
|
667 |
+
assert x.dtype == dtype
|
668 |
+
return x
|
669 |
+
|
670 |
+
#----------------------------------------------------------------------------
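
# A minimal sketch (assuming torch and np as imported above) of the projection
# conditioning above: with labels, the final score is the dot product between
# the epilogue's cmap_dim-wide output and the mapped label embedding, scaled by
# 1/sqrt(cmap_dim) to keep its magnitude independent of cmap_dim.
def _conditioning_projection_sketch():
    cmap_dim = 512
    feat = torch.randn(4, cmap_dim)   # stand-in for the epilogue's 'out' features.
    cmap = torch.randn(4, cmap_dim)   # stand-in for the mapped conditioning label.
    score = (feat * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(cmap_dim))
    print(score.shape)                # torch.Size([4, 1])

#----------------------------------------------------------------------------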

@persistence.persistent_class
class Discriminator(torch.nn.Module):
    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 32768,    # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 0,        # Use FP16 for the N highest resolutions.
        conv_clamp          = None,     # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
    ):
        super().__init__()
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)

    def forward(self, img, c, **block_kwargs):
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)

        cmap = None
        if self.c_dim > 0:
            cmap = self.mapping(None, c)
        x = self.b4(x, img, cmap)
        return x

#----------------------------------------------------------------------------
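
# A minimal usage sketch, assuming the definitions above are importable as
# training.networks and that the torch_utils custom ops fall back to native
# PyTorch when no CUDA device is present: build an unconditional 256x256
# discriminator and score a batch of random images.
def _discriminator_usage_sketch():
    D = Discriminator(c_dim=0, img_resolution=256, img_channels=3)
    img = torch.randn(2, 3, 256, 256)
    c = torch.empty(2, 0)             # no conditioning labels (c_dim == 0).
    print(D(img, c).shape)            # torch.Size([2, 1])

#----------------------------------------------------------------------------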
training/training_loop.py
ADDED
@@ -0,0 +1,421 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import time
import copy
import json
import pickle
import psutil
import PIL.Image
import numpy as np
import torch
import dnnlib
from torch_utils import misc
from torch_utils import training_stats
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import grid_sample_gradfix

import legacy
from metrics import metric_main

#----------------------------------------------------------------------------

def setup_snapshot_image_grid(training_set, random_seed=0):
    rnd = np.random.RandomState(random_seed)
    gw = np.clip(7680 // training_set.image_shape[2], 7, 32)
    gh = np.clip(4320 // training_set.image_shape[1], 4, 32)

    # No labels => show random subset of training samples.
    if not training_set.has_labels:
        all_indices = list(range(len(training_set)))
        rnd.shuffle(all_indices)
        grid_indices = [all_indices[i % len(all_indices)] for i in range(gw * gh)]

    else:
        # Group training samples by label.
        label_groups = dict() # label => [idx, ...]
        for idx in range(len(training_set)):
            label = tuple(training_set.get_details(idx).raw_label.flat[::-1])
            if label not in label_groups:
                label_groups[label] = []
            label_groups[label].append(idx)

        # Reorder.
        label_order = sorted(label_groups.keys())
        for label in label_order:
            rnd.shuffle(label_groups[label])

        # Organize into grid.
        grid_indices = []
        for y in range(gh):
            label = label_order[y % len(label_order)]
            indices = label_groups[label]
            grid_indices += [indices[x % len(indices)] for x in range(gw)]
            label_groups[label] = [indices[(i + gw) % len(indices)] for i in range(len(indices))]

    # Load data.
    images, labels = zip(*[training_set[i] for i in grid_indices])
    return (gw, gh), np.stack(images), np.stack(labels)

#----------------------------------------------------------------------------
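
# Worked example of the grid sizing above: the grid is fitted to a roughly
# 8K canvas and clipped to sane bounds. For 1024x1024 images, image_shape is
# [C, 1024, 1024], so gw = np.clip(7680 // 1024, 7, 32) = 7 and
# gh = np.clip(4320 // 1024, 4, 32) = 4, i.e. a 7x4 grid of 28 snapshot images.

#----------------------------------------------------------------------------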

def save_image_grid(img, fname, drange, grid_size):
    lo, hi = drange
    img = np.asarray(img, dtype=np.float32)
    img = (img - lo) * (255 / (hi - lo))
    img = np.rint(img).clip(0, 255).astype(np.uint8)

    gw, gh = grid_size
    _N, C, H, W = img.shape
    img = img.reshape(gh, gw, C, H, W)
    img = img.transpose(0, 3, 1, 4, 2)
    img = img.reshape(gh * H, gw * W, C)

    assert C in [1, 3]
    if C == 1:
        PIL.Image.fromarray(img[:, :, 0], 'L').save(fname)
    if C == 3:
        PIL.Image.fromarray(img, 'RGB').save(fname)

#----------------------------------------------------------------------------
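
# A shape sketch for the tiling above (assuming numpy as np): the
# reshape/transpose pair turns a batch of [N, C, H, W] tiles into a single
# [gh*H, gw*W, C] mosaic.
def _tiling_shape_sketch():
    gw, gh, C, H, W = 7, 4, 3, 64, 64
    img = np.zeros([gh * gw, C, H, W], dtype=np.uint8)
    img = img.reshape(gh, gw, C, H, W).transpose(0, 3, 1, 4, 2).reshape(gh * H, gw * W, C)
    print(img.shape)                  # (256, 448, 3)

#----------------------------------------------------------------------------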

def training_loop(
    run_dir                 = '.',      # Output directory.
    training_set_kwargs     = {},       # Options for training set.
    data_loader_kwargs      = {},       # Options for torch.utils.data.DataLoader.
    G_kwargs                = {},       # Options for generator network.
    D_kwargs                = {},       # Options for discriminator network.
    G_opt_kwargs            = {},       # Options for generator optimizer.
    D_opt_kwargs            = {},       # Options for discriminator optimizer.
    augment_kwargs          = None,     # Options for augmentation pipeline. None = disable.
    loss_kwargs             = {},       # Options for loss function.
    metrics                 = [],       # Metrics to evaluate during training.
    random_seed             = 0,        # Global random seed.
    num_gpus                = 1,        # Number of GPUs participating in the training.
    rank                    = 0,        # Rank of the current process in [0, num_gpus[.
    batch_size              = 4,        # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
    batch_gpu               = 4,        # Number of samples processed at a time by one GPU.
    ema_kimg                = 10,       # Half-life of the exponential moving average (EMA) of generator weights.
    ema_rampup              = None,     # EMA ramp-up coefficient.
    G_reg_interval          = 4,        # How often to perform regularization for G? None = disable lazy regularization.
    D_reg_interval          = 16,       # How often to perform regularization for D? None = disable lazy regularization.
    augment_p               = 0,        # Initial value of augmentation probability.
    ada_target              = None,     # ADA target value. None = fixed p.
    ada_interval            = 4,        # How often to perform ADA adjustment?
    ada_kimg                = 500,      # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
    total_kimg              = 25000,    # Total length of the training, measured in thousands of real images.
    kimg_per_tick           = 4,        # Progress snapshot interval.
    image_snapshot_ticks    = 50,       # How often to save image snapshots? None = disable.
    network_snapshot_ticks  = 50,       # How often to save network snapshots? None = disable.
    resume_pkl              = None,     # Network pickle to resume training from.
    cudnn_benchmark         = True,     # Enable torch.backends.cudnn.benchmark?
    allow_tf32              = False,    # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
    abort_fn                = None,     # Callback function for determining whether to abort training. Must return consistent results across ranks.
    progress_fn             = None,     # Callback function for updating training progress. Called for all ranks.
):
    # Initialize.
    start_time = time.time()
    device = torch.device('cuda', rank)
    np.random.seed(random_seed * num_gpus + rank)
    torch.manual_seed(random_seed * num_gpus + rank)
    torch.backends.cudnn.benchmark = cudnn_benchmark    # Improves training speed.
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32  # Allow PyTorch to internally use tf32 for matmul
    torch.backends.cudnn.allow_tf32 = allow_tf32        # Allow PyTorch to internally use tf32 for convolutions
    conv2d_gradfix.enabled = True                       # Improves training speed.
    grid_sample_gradfix.enabled = True                  # Avoids errors with the augmentation pipe.

    # Load training set.
    if rank == 0:
        print('Loading training set...')
    training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset
    training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
    training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs))
    if rank == 0:
        print()
        print('Num images: ', len(training_set))
        print('Image shape:', training_set.image_shape)
        print('Label shape:', training_set.label_shape)
        print()

    # Construct networks.
    if rank == 0:
        print('Constructing networks...')
    common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
    G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
    G_ema = copy.deepcopy(G).eval()

    # Resume from existing pickle.
    if (resume_pkl is not None) and (rank == 0):
        print(f'Resuming from "{resume_pkl}"')
        with dnnlib.util.open_url(resume_pkl) as f:
            resume_data = legacy.load_network_pkl(f)
        for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
            misc.copy_params_and_buffers(resume_data[name], module, require_all=False)

    # Print network summary tables.
    if rank == 0:
        z = torch.empty([batch_gpu, G.z_dim], device=device)
        c = torch.empty([batch_gpu, G.c_dim], device=device)
        img = misc.print_module_summary(G, [z, c])
        misc.print_module_summary(D, [img, c])

    # Setup augmentation.
    if rank == 0:
        print('Setting up augmentation...')
    augment_pipe = None
    ada_stats = None
    if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None):
        augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
        augment_pipe.p.copy_(torch.as_tensor(augment_p))
        if ada_target is not None:
            ada_stats = training_stats.Collector(regex='Loss/signs/real')

    # Distribute across GPUs.
    if rank == 0:
        print(f'Distributing across {num_gpus} GPUs...')
    ddp_modules = dict()
    for name, module in [('G_mapping', G.mapping), ('G_synthesis', G.synthesis), ('D', D), (None, G_ema), ('augment_pipe', augment_pipe)]:
        if (num_gpus > 1) and (module is not None) and len(list(module.parameters())) != 0:
            module.requires_grad_(True)
            module = torch.nn.parallel.DistributedDataParallel(module, device_ids=[device], broadcast_buffers=False)
            module.requires_grad_(False)
        if name is not None:
            ddp_modules[name] = module
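
    # Note: broadcast_buffers=False together with the per-round 'sync' flag
    # passed to loss.accumulate_gradients() below defers the DDP gradient
    # all-reduce to the last gradient-accumulation round instead of syncing on
    # every backward pass.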

    # Setup training phases.
    if rank == 0:
        print('Setting up training phases...')
    loss = dnnlib.util.construct_class_by_name(device=device, **ddp_modules, **loss_kwargs) # subclass of training.loss.Loss
    phases = []
    for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]:
        if reg_interval is None:
            opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name+'both', module=module, opt=opt, interval=1)]
        else: # Lazy regularization.
            mb_ratio = reg_interval / (reg_interval + 1)
            opt_kwargs = dnnlib.EasyDict(opt_kwargs)
            opt_kwargs.lr = opt_kwargs.lr * mb_ratio
            opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas]
            opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
            phases += [dnnlib.EasyDict(name=name+'main', module=module, opt=opt, interval=1)]
            phases += [dnnlib.EasyDict(name=name+'reg', module=module, opt=opt, interval=reg_interval)]
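            # Worked example of the rescaling above: with D_reg_interval = 16,
            # mb_ratio = 16/17 ~= 0.94, so the learning rate and Adam betas are
            # adjusted to keep the optimizer's effective behavior close to
            # running the regularizer on every minibatch, even though the 'reg'
            # phase only fires every 16th batch.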
    for phase in phases:
        phase.start_event = None
        phase.end_event = None
        if rank == 0:
            phase.start_event = torch.cuda.Event(enable_timing=True)
            phase.end_event = torch.cuda.Event(enable_timing=True)

    # Export sample images.
    grid_size = None
    grid_z = None
    grid_c = None
    if rank == 0:
        print('Exporting sample images...')
        grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set)
        save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size)
        grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu)
        grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
        images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
        save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)

    # Initialize logs.
    if rank == 0:
        print('Initializing logs...')
    stats_collector = training_stats.Collector(regex='.*')
    stats_metrics = dict()
    stats_jsonl = None
    stats_tfevents = None
    if rank == 0:
        stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
        try:
            import torch.utils.tensorboard as tensorboard
            stats_tfevents = tensorboard.SummaryWriter(run_dir)
        except ImportError as err:
            print('Skipping tfevents export:', err)

    # Train.
    if rank == 0:
        print(f'Training for {total_kimg} kimg...')
        print()
    cur_nimg = 0
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    batch_idx = 0
    if progress_fn is not None:
        progress_fn(0, total_kimg)
    while True:

        # Fetch training data.
        with torch.autograd.profiler.record_function('data_fetch'):
            phase_real_img, phase_real_c = next(training_set_iterator)
            phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
            phase_real_c = phase_real_c.to(device).split(batch_gpu)
            all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device)
            all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)]
            all_gen_c = [training_set.get_label(np.random.randint(len(training_set))) for _ in range(len(phases) * batch_size)]
            all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)
            all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)]

        # Execute training phases.
        for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c):
            if batch_idx % phase.interval != 0:
                continue

            # Initialize gradient accumulation.
            if phase.start_event is not None:
                phase.start_event.record(torch.cuda.current_stream(device))
            phase.opt.zero_grad(set_to_none=True)
            phase.module.requires_grad_(True)

            # Accumulate gradients over multiple rounds.
            for round_idx, (real_img, real_c, gen_z, gen_c) in enumerate(zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c)):
                sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1)
                gain = phase.interval
                loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, sync=sync, gain=gain)

            # Update weights.
            phase.module.requires_grad_(False)
            with torch.autograd.profiler.record_function(phase.name + '_opt'):
                for param in phase.module.parameters():
                    if param.grad is not None:
                        misc.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
                phase.opt.step()
            if phase.end_event is not None:
                phase.end_event.record(torch.cuda.current_stream(device))

        # Update G_ema.
        with torch.autograd.profiler.record_function('Gema'):
            ema_nimg = ema_kimg * 1000
            if ema_rampup is not None:
                ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
            ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
            for p_ema, p in zip(G_ema.parameters(), G.parameters()):
                p_ema.copy_(p.lerp(p_ema, ema_beta))
            for b_ema, b in zip(G_ema.buffers(), G.buffers()):
                b_ema.copy_(b)
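            # ema_beta realizes the stated half-life: after ema_nimg images, the
            # old weights retain a factor of 0.5. E.g. ema_kimg=10 with
            # batch_size=32 gives ema_beta = 0.5 ** (32 / 10000) ~= 0.9978 per
            # iteration.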

        # Update state.
        cur_nimg += batch_size
        batch_idx += 1

        # Execute ADA heuristic.
        if (ada_stats is not None) and (batch_idx % ada_interval == 0):
            ada_stats.update()
            adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (ada_kimg * 1000)
            augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device)))
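            # The heuristic nudges p by a fixed step whose sign depends on
            # whether E[sign(D(real))] overshoots ada_target. The step size,
            # batch_size * ada_interval / (ada_kimg * 1000), means p moves at
            # most one full unit per ada_kimg thousand real images.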

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in stats_collector.
        tick_end_time = time.time()
        fields = []
        fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
        fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"]
        fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
        fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
        fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
        fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
        fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
        fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
        torch.cuda.reset_peak_memory_stats()
        fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"]
        training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
        training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
        if rank == 0:
            print(' '.join(fields))

        # Check for abort.
        if (not done) and (abort_fn is not None) and abort_fn():
            done = True
            if rank == 0:
                print()
                print('Aborting...')

        # Save image snapshot.
        if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
            images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
            save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size)

        # Save network snapshot.
        snapshot_pkl = None
        snapshot_data = None
        if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
            snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
            for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]:
                if module is not None:
                    if num_gpus > 1:
                        misc.check_ddp_consistency(module, ignore_regex=r'.*\.w_avg')
                    module = copy.deepcopy(module).eval().requires_grad_(False).cpu()
                snapshot_data[name] = module
                del module # conserve memory
            snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl')
            if rank == 0:
                with open(snapshot_pkl, 'wb') as f:
                    pickle.dump(snapshot_data, f)

        # Evaluate metrics.
        if (snapshot_data is not None) and (len(metrics) > 0):
            if rank == 0:
                print('Evaluating metrics...')
            for metric in metrics:
                result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'],
                    dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device)
                if rank == 0:
                    metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
                stats_metrics.update(result_dict.results)
        del snapshot_data # conserve memory

        # Collect statistics.
        for phase in phases:
            value = []
            if (phase.start_event is not None) and (phase.end_event is not None):
                phase.end_event.synchronize()
                value = phase.start_event.elapsed_time(phase.end_event)
            training_stats.report0('Timing/' + phase.name, value)
        stats_collector.update()
        stats_dict = stats_collector.as_dict()

        # Update logs.
        timestamp = time.time()
        if stats_jsonl is not None:
            fields = dict(stats_dict, timestamp=timestamp)
            stats_jsonl.write(json.dumps(fields) + '\n')
            stats_jsonl.flush()
        if stats_tfevents is not None:
            global_step = int(cur_nimg / 1e3)
            walltime = timestamp - start_time
            for name, value in stats_dict.items():
                stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime)
            for name, value in stats_metrics.items():
                stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime)
            stats_tfevents.flush()
        if progress_fn is not None:
            progress_fn(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done.
    if rank == 0:
        print()
        print('Exiting...')

#----------------------------------------------------------------------------
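
# A minimal invocation sketch, not the repo's actual train.py: the class_name
# strings and hyperparameters below are illustrative assumptions, and a CUDA
# device is required because training_loop() pins itself to
# torch.device('cuda', rank). train.py normally builds these kwargs from CLI
# options and calls training_loop() once per rank.
def _training_loop_usage_sketch():
    training_loop(
        run_dir             = 'runs/example',
        training_set_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset',
                                              path='datasets/example.zip', use_labels=False),
        G_kwargs            = dnnlib.EasyDict(class_name='training.networks.Generator', z_dim=512, w_dim=512),
        D_kwargs            = dnnlib.EasyDict(class_name='training.networks.Discriminator'),
        G_opt_kwargs        = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=0.0025, betas=[0, 0.99], eps=1e-8),
        D_opt_kwargs        = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=0.0025, betas=[0, 0.99], eps=1e-8),
        loss_kwargs         = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss'),
        batch_size          = 8,
        batch_gpu           = 8,
        total_kimg          = 1,
    )

#----------------------------------------------------------------------------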