akhaliq3 commited on
Commit
6d14893
1 Parent(s): 6747486

remove old files

Browse files
.gitattributes DELETED
@@ -1,27 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bin.* filter=lfs diff=lfs merge=lfs -text
5
- *.bz2 filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.model filter=lfs diff=lfs merge=lfs -text
12
- *.msgpack filter=lfs diff=lfs merge=lfs -text
13
- *.onnx filter=lfs diff=lfs merge=lfs -text
14
- *.ot filter=lfs diff=lfs merge=lfs -text
15
- *.parquet filter=lfs diff=lfs merge=lfs -text
16
- *.pb filter=lfs diff=lfs merge=lfs -text
17
- *.pt filter=lfs diff=lfs merge=lfs -text
18
- *.pth filter=lfs diff=lfs merge=lfs -text
19
- *.rar filter=lfs diff=lfs merge=lfs -text
20
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
- *.tar.* filter=lfs diff=lfs merge=lfs -text
22
- *.tflite filter=lfs diff=lfs merge=lfs -text
23
- *.tgz filter=lfs diff=lfs merge=lfs -text
24
- *.xz filter=lfs diff=lfs merge=lfs -text
25
- *.zip filter=lfs diff=lfs merge=lfs -text
26
- *.zstandard filter=lfs diff=lfs merge=lfs -text
27
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
LICENSE.txt DELETED
@@ -1,97 +0,0 @@
1
- Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
2
-
3
-
4
- NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA)
5
-
6
-
7
- =======================================================================
8
-
9
- 1. Definitions
10
-
11
- "Licensor" means any person or entity that distributes its Work.
12
-
13
- "Software" means the original work of authorship made available under
14
- this License.
15
-
16
- "Work" means the Software and any additions to or derivative works of
17
- the Software that are made available under this License.
18
-
19
- The terms "reproduce," "reproduction," "derivative works," and
20
- "distribution" have the meaning as provided under U.S. copyright law;
21
- provided, however, that for the purposes of this License, derivative
22
- works shall not include works that remain separable from, or merely
23
- link (or bind by name) to the interfaces of, the Work.
24
-
25
- Works, including the Software, are "made available" under this License
26
- by including in or with the Work either (a) a copyright notice
27
- referencing the applicability of this License to the Work, or (b) a
28
- copy of this License.
29
-
30
- 2. License Grants
31
-
32
- 2.1 Copyright Grant. Subject to the terms and conditions of this
33
- License, each Licensor grants to you a perpetual, worldwide,
34
- non-exclusive, royalty-free, copyright license to reproduce,
35
- prepare derivative works of, publicly display, publicly perform,
36
- sublicense and distribute its Work and any resulting derivative
37
- works in any form.
38
-
39
- 3. Limitations
40
-
41
- 3.1 Redistribution. You may reproduce or distribute the Work only
42
- if (a) you do so under this License, (b) you include a complete
43
- copy of this License with your distribution, and (c) you retain
44
- without modification any copyright, patent, trademark, or
45
- attribution notices that are present in the Work.
46
-
47
- 3.2 Derivative Works. You may specify that additional or different
48
- terms apply to the use, reproduction, and distribution of your
49
- derivative works of the Work ("Your Terms") only if (a) Your Terms
50
- provide that the use limitation in Section 3.3 applies to your
51
- derivative works, and (b) you identify the specific derivative
52
- works that are subject to Your Terms. Notwithstanding Your Terms,
53
- this License (including the redistribution requirements in Section
54
- 3.1) will continue to apply to the Work itself.
55
-
56
- 3.3 Use Limitation. The Work and any derivative works thereof only
57
- may be used or intended for use non-commercially. Notwithstanding
58
- the foregoing, NVIDIA and its affiliates may use the Work and any
59
- derivative works commercially. As used herein, "non-commercially"
60
- means for research or evaluation purposes only.
61
-
62
- 3.4 Patent Claims. If you bring or threaten to bring a patent claim
63
- against any Licensor (including any claim, cross-claim or
64
- counterclaim in a lawsuit) to enforce any patents that you allege
65
- are infringed by any Work, then your rights under this License from
66
- such Licensor (including the grant in Section 2.1) will terminate
67
- immediately.
68
-
69
- 3.5 Trademarks. This License does not grant any rights to use any
70
- Licensor’s or its affiliates’ names, logos, or trademarks, except
71
- as necessary to reproduce the notices described in this License.
72
-
73
- 3.6 Termination. If you violate any term of this License, then your
74
- rights under this License (including the grant in Section 2.1) will
75
- terminate immediately.
76
-
77
- 4. Disclaimer of Warranty.
78
-
79
- THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
80
- KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
81
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
82
- NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
83
- THIS LICENSE.
84
-
85
- 5. Limitation of Liability.
86
-
87
- EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
88
- THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
89
- SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
90
- INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
91
- OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
92
- (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
93
- LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
94
- COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
95
- THE POSSIBILITY OF SUCH DAMAGES.
96
-
97
- =======================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.md DELETED
@@ -1,37 +0,0 @@
1
- ---
2
- title: StyleGAN_CLIP
3
- emoji: 🌍
4
- colorFrom: yellow
5
- colorTo: yellow
6
- sdk: gradio
7
- app_file: app.py
8
- pinned: false
9
- ---
10
-
11
- # Configuration
12
-
13
- `title`: _string_
14
- Display title for the Space
15
-
16
- `emoji`: _string_
17
- Space emoji (emoji-only character allowed)
18
-
19
- `colorFrom`: _string_
20
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
21
-
22
- `colorTo`: _string_
23
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
24
-
25
- `sdk`: _string_
26
- Can be either `gradio` or `streamlit`
27
-
28
- `sdk_version` : _string_
29
- Only applicable for `streamlit` SDK.
30
- See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
31
-
32
- `app_file`: _string_
33
- Path to your main application file (which contains either `gradio` or `streamlit` Python code).
34
- Path is relative to the root of the repository.
35
-
36
- `pinned`: _boolean_
37
- Whether the Space stays on top of your list.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py DELETED
@@ -1,203 +0,0 @@
1
- import os
2
- os.system("pip install --upgrade torch==1.9.1+cu111 torchvision==0.10.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html")
3
- os.system("git clone https://github.com/NVlabs/stylegan3")
4
- os.system("git clone https://github.com/openai/CLIP")
5
- os.system("pip install -e ./CLIP")
6
- os.system("pip install einops ninja scipy numpy Pillow tqdm")
7
-
8
- import sys
9
- sys.path.append('./CLIP')
10
- sys.path.append('./stylegan3')
11
-
12
- import io
13
- import os, time
14
- import pickle
15
- import shutil
16
- import numpy as np
17
- from PIL import Image
18
- import torch
19
- import torch.nn.functional as F
20
- import requests
21
- import torchvision.transforms as transforms
22
- import torchvision.transforms.functional as TF
23
- import clip
24
- from tqdm.notebook import tqdm
25
- from torchvision.transforms import Compose, Resize, ToTensor, Normalize
26
- from einops import rearrange
27
-
28
- device = torch.device('cuda:0')
29
-
30
-
31
- def fetch(url_or_path):
32
- if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
33
- r = requests.get(url_or_path)
34
- r.raise_for_status()
35
- fd = io.BytesIO()
36
- fd.write(r.content)
37
- fd.seek(0)
38
- return fd
39
- return open(url_or_path, 'rb')
40
-
41
- def fetch_model(url_or_path):
42
- basename = os.path.basename(url_or_path)
43
- if os.path.exists(basename):
44
- return basename
45
- else:
46
- os.system("wget -c '{url_or_path}'")
47
- return basename
48
-
49
- def norm1(prompt):
50
- "Normalize to the unit sphere."
51
- return prompt / prompt.square().sum(dim=-1,keepdim=True).sqrt()
52
-
53
- def spherical_dist_loss(x, y):
54
- x = F.normalize(x, dim=-1)
55
- y = F.normalize(y, dim=-1)
56
- return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
57
-
58
- class MakeCutouts(torch.nn.Module):
59
- def __init__(self, cut_size, cutn, cut_pow=1.):
60
- super().__init__()
61
- self.cut_size = cut_size
62
- self.cutn = cutn
63
- self.cut_pow = cut_pow
64
-
65
- def forward(self, input):
66
- sideY, sideX = input.shape[2:4]
67
- max_size = min(sideX, sideY)
68
- min_size = min(sideX, sideY, self.cut_size)
69
- cutouts = []
70
- for _ in range(self.cutn):
71
- size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
72
- offsetx = torch.randint(0, sideX - size + 1, ())
73
- offsety = torch.randint(0, sideY - size + 1, ())
74
- cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
75
- cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
76
- return torch.cat(cutouts)
77
-
78
- make_cutouts = MakeCutouts(224, 32, 0.5)
79
-
80
- def embed_image(image):
81
- n = image.shape[0]
82
- cutouts = make_cutouts(image)
83
- embeds = clip_model.embed_cutout(cutouts)
84
- embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n)
85
- return embeds
86
-
87
- def embed_url(url):
88
- image = Image.open(fetch(url)).convert('RGB')
89
- return embed_image(TF.to_tensor(image).to(device).unsqueeze(0)).mean(0).squeeze(0)
90
-
91
- class CLIP(object):
92
- def __init__(self):
93
- clip_model = "ViT-B/32"
94
- self.model, _ = clip.load(clip_model)
95
- self.model = self.model.requires_grad_(False)
96
- self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
97
- std=[0.26862954, 0.26130258, 0.27577711])
98
-
99
- @torch.no_grad()
100
- def embed_text(self, prompt):
101
- "Normalized clip text embedding."
102
- return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())
103
-
104
- def embed_cutout(self, image):
105
- "Normalized clip image embedding."
106
- return norm1(self.model.encode_image(self.normalize(image)))
107
-
108
- clip_model = CLIP()
109
-
110
- # Load stylegan model
111
-
112
- base_url = "https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/"
113
- model_name = "stylegan3-t-ffhqu-1024x1024.pkl"
114
- #model_name = "stylegan3-r-metfacesu-1024x1024.pkl"
115
- #model_name = "stylegan3-t-afhqv2-512x512.pkl"
116
- network_url = base_url + model_name
117
-
118
- os.system("wget -c https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhqu-1024x1024.pkl")
119
-
120
- with open('stylegan3-t-ffhqu-1024x1024.pkl', 'rb') as fp:
121
- G = pickle.load(fp)['G_ema'].to(device)
122
-
123
- zs = torch.randn([10000, G.mapping.z_dim], device=device)
124
- w_stds = G.mapping(zs, None).std(0)
125
-
126
-
127
-
128
-
129
- def inference(text):
130
- target = clip_model.embed_text(text)
131
-
132
- steps = 600
133
- seed = 2
134
-
135
- tf = Compose([
136
- Resize(224),
137
- lambda x: torch.clamp((x+1)/2,min=0,max=1),
138
- ])
139
-
140
- torch.manual_seed(seed)
141
- timestring = time.strftime('%Y%m%d%H%M%S')
142
-
143
- with torch.no_grad():
144
- qs = []
145
- losses = []
146
- for _ in range(8):
147
- q = (G.mapping(torch.randn([4,G.mapping.z_dim], device=device), None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds
148
- images = G.synthesis(q * w_stds + G.mapping.w_avg)
149
- embeds = embed_image(images.add(1).div(2))
150
- loss = spherical_dist_loss(embeds, target).mean(0)
151
- i = torch.argmin(loss)
152
- qs.append(q[i])
153
- losses.append(loss[i])
154
- qs = torch.stack(qs)
155
- losses = torch.stack(losses)
156
- print(losses)
157
- print(losses.shape, qs.shape)
158
- i = torch.argmin(losses)
159
- q = qs[i].unsqueeze(0)
160
-
161
- q.requires_grad_()
162
-
163
- q_ema = q
164
- opt = torch.optim.AdamW([q], lr=0.03, betas=(0.0,0.999))
165
- loop = tqdm(range(steps))
166
- for i in loop:
167
- opt.zero_grad()
168
- w = q * w_stds
169
- image = G.synthesis(w + G.mapping.w_avg, noise_mode='const')
170
- embed = embed_image(image.add(1).div(2))
171
- loss = spherical_dist_loss(embed, target).mean()
172
- loss.backward()
173
- opt.step()
174
- loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())
175
-
176
- q_ema = q_ema * 0.9 + q * 0.1
177
- image = G.synthesis(q_ema * w_stds + G.mapping.w_avg, noise_mode='const')
178
-
179
- if i % 10 == 0:
180
- display(TF.to_pil_image(tf(image)[0]))
181
- pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0,1))
182
- #os.makedirs(f'samples/{timestring}', exist_ok=True)
183
- #pil_image.save(f'samples/{timestring}/{i:04}.jpg')
184
-
185
- return pil_image
186
-
187
-
188
-
189
- title = "StyleGAN+CLIP_with_Latent_Bootstraping"
190
- description = "Gradio demo for StyleGAN+CLIP_with_Latent_Bootstraping. To use it, simply add your text, or click one of the examples to load them. Read more at the links below."
191
- article = "<p style='text-align: center'>colab by https://twitter.com/EricHallahan <a href='https://colab.research.google.com/drive/1br7GP_D6XCgulxPTAFhwGaV-ijFe084X' target='_blank'>Colab</a></p>"
192
-
193
- examples = [['elon musk']]
194
- gr.Interface(
195
- inference,
196
- "text",
197
- gr.outputs.Image(type="pil", label="Output"),
198
- title=title,
199
- description=description,
200
- article=article,
201
- enable_queue=True,
202
- examples=examples
203
- ).launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
calc_metrics.py DELETED
@@ -1,190 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Calculate quality metrics for previous training run or pretrained network pickle."""
10
-
11
- import os
12
- import click
13
- import json
14
- import tempfile
15
- import copy
16
- import torch
17
- import dnnlib
18
-
19
- import legacy
20
- from metrics import metric_main
21
- from metrics import metric_utils
22
- from torch_utils import training_stats
23
- from torch_utils import custom_ops
24
- from torch_utils import misc
25
-
26
- #----------------------------------------------------------------------------
27
-
28
- def subprocess_fn(rank, args, temp_dir):
29
- dnnlib.util.Logger(should_flush=True)
30
-
31
- # Init torch.distributed.
32
- if args.num_gpus > 1:
33
- init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
34
- if os.name == 'nt':
35
- init_method = 'file:///' + init_file.replace('\\', '/')
36
- torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
37
- else:
38
- init_method = f'file://{init_file}'
39
- torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
40
-
41
- # Init torch_utils.
42
- sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
43
- training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
44
- if rank != 0 or not args.verbose:
45
- custom_ops.verbosity = 'none'
46
-
47
- # Print network summary.
48
- device = torch.device('cuda', rank)
49
- torch.backends.cudnn.benchmark = True
50
- torch.backends.cuda.matmul.allow_tf32 = False
51
- torch.backends.cudnn.allow_tf32 = False
52
- G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)
53
- if rank == 0 and args.verbose:
54
- z = torch.empty([1, G.z_dim], device=device)
55
- c = torch.empty([1, G.c_dim], device=device)
56
- misc.print_module_summary(G, [z, c])
57
-
58
- # Calculate each metric.
59
- for metric in args.metrics:
60
- if rank == 0 and args.verbose:
61
- print(f'Calculating {metric}...')
62
- progress = metric_utils.ProgressMonitor(verbose=args.verbose)
63
- result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs,
64
- num_gpus=args.num_gpus, rank=rank, device=device, progress=progress)
65
- if rank == 0:
66
- metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
67
- if rank == 0 and args.verbose:
68
- print()
69
-
70
- # Done.
71
- if rank == 0 and args.verbose:
72
- print('Exiting...')
73
-
74
- #----------------------------------------------------------------------------
75
-
76
- class CommaSeparatedList(click.ParamType):
77
- name = 'list'
78
-
79
- def convert(self, value, param, ctx):
80
- _ = param, ctx
81
- if value is None or value.lower() == 'none' or value == '':
82
- return []
83
- return value.split(',')
84
-
85
- #----------------------------------------------------------------------------
86
-
87
- @click.command()
88
- @click.pass_context
89
- @click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True)
90
- @click.option('--metrics', help='Comma-separated list or "none"', type=CommaSeparatedList(), default='fid50k_full', show_default=True)
91
- @click.option('--data', help='Dataset to evaluate metrics against (directory or zip) [default: same as training data]', metavar='PATH')
92
- @click.option('--mirror', help='Whether the dataset was augmented with x-flips during training [default: look up]', type=bool, metavar='BOOL')
93
- @click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
94
- @click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True)
95
-
96
- def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
97
- """Calculate quality metrics for previous training run or pretrained network pickle.
98
-
99
- Examples:
100
-
101
- \b
102
- # Previous training run: look up options automatically, save result to JSONL file.
103
- python calc_metrics.py --metrics=pr50k3_full \\
104
- --network=~/training-runs/00000-ffhq10k-res64-auto1/network-snapshot-000000.pkl
105
-
106
- \b
107
- # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
108
- python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \\
109
- --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
110
-
111
- Available metrics:
112
-
113
- \b
114
- ADA paper:
115
- fid50k_full Frechet inception distance against the full dataset.
116
- kid50k_full Kernel inception distance against the full dataset.
117
- pr50k3_full Precision and recall againt the full dataset.
118
- is50k Inception score for CIFAR-10.
119
-
120
- \b
121
- StyleGAN and StyleGAN2 papers:
122
- fid50k Frechet inception distance against 50k real images.
123
- kid50k Kernel inception distance against 50k real images.
124
- pr50k3 Precision and recall against 50k real images.
125
- ppl2_wend Perceptual path length in W at path endpoints against full image.
126
- ppl_zfull Perceptual path length in Z for full paths against cropped image.
127
- ppl_wfull Perceptual path length in W for full paths against cropped image.
128
- ppl_zend Perceptual path length in Z at path endpoints against cropped image.
129
- ppl_wend Perceptual path length in W at path endpoints against cropped image.
130
- """
131
- dnnlib.util.Logger(should_flush=True)
132
-
133
- # Validate arguments.
134
- args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose)
135
- if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
136
- ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
137
- if not args.num_gpus >= 1:
138
- ctx.fail('--gpus must be at least 1')
139
-
140
- # Load network.
141
- if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
142
- ctx.fail('--network must point to a file or URL')
143
- if args.verbose:
144
- print(f'Loading network from "{network_pkl}"...')
145
- with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
146
- network_dict = legacy.load_network_pkl(f)
147
- args.G = network_dict['G_ema'] # subclass of torch.nn.Module
148
-
149
- # Initialize dataset options.
150
- if data is not None:
151
- args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data)
152
- elif network_dict['training_set_kwargs'] is not None:
153
- args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
154
- else:
155
- ctx.fail('Could not look up dataset options; please specify --data')
156
-
157
- # Finalize dataset options.
158
- args.dataset_kwargs.resolution = args.G.img_resolution
159
- args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
160
- if mirror is not None:
161
- args.dataset_kwargs.xflip = mirror
162
-
163
- # Print dataset options.
164
- if args.verbose:
165
- print('Dataset options:')
166
- print(json.dumps(args.dataset_kwargs, indent=2))
167
-
168
- # Locate run dir.
169
- args.run_dir = None
170
- if os.path.isfile(network_pkl):
171
- pkl_dir = os.path.dirname(network_pkl)
172
- if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
173
- args.run_dir = pkl_dir
174
-
175
- # Launch processes.
176
- if args.verbose:
177
- print('Launching processes...')
178
- torch.multiprocessing.set_start_method('spawn')
179
- with tempfile.TemporaryDirectory() as temp_dir:
180
- if args.num_gpus == 1:
181
- subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
182
- else:
183
- torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
184
-
185
- #----------------------------------------------------------------------------
186
-
187
- if __name__ == "__main__":
188
- calc_metrics() # pylint: disable=no-value-for-parameter
189
-
190
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dataset_tool.py DELETED
@@ -1,444 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- import functools
10
- import io
11
- import json
12
- import os
13
- import pickle
14
- import sys
15
- import tarfile
16
- import gzip
17
- import zipfile
18
- from pathlib import Path
19
- from typing import Callable, Optional, Tuple, Union
20
-
21
- import click
22
- import numpy as np
23
- import PIL.Image
24
- from tqdm import tqdm
25
-
26
- #----------------------------------------------------------------------------
27
-
28
- def error(msg):
29
- print('Error: ' + msg)
30
- sys.exit(1)
31
-
32
- #----------------------------------------------------------------------------
33
-
34
- def maybe_min(a: int, b: Optional[int]) -> int:
35
- if b is not None:
36
- return min(a, b)
37
- return a
38
-
39
- #----------------------------------------------------------------------------
40
-
41
- def file_ext(name: Union[str, Path]) -> str:
42
- return str(name).split('.')[-1]
43
-
44
- #----------------------------------------------------------------------------
45
-
46
- def is_image_ext(fname: Union[str, Path]) -> bool:
47
- ext = file_ext(fname).lower()
48
- return f'.{ext}' in PIL.Image.EXTENSION # type: ignore
49
-
50
- #----------------------------------------------------------------------------
51
-
52
- def open_image_folder(source_dir, *, max_images: Optional[int]):
53
- input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)]
54
-
55
- # Load labels.
56
- labels = {}
57
- meta_fname = os.path.join(source_dir, 'dataset.json')
58
- if os.path.isfile(meta_fname):
59
- with open(meta_fname, 'r') as file:
60
- labels = json.load(file)['labels']
61
- if labels is not None:
62
- labels = { x[0]: x[1] for x in labels }
63
- else:
64
- labels = {}
65
-
66
- max_idx = maybe_min(len(input_images), max_images)
67
-
68
- def iterate_images():
69
- for idx, fname in enumerate(input_images):
70
- arch_fname = os.path.relpath(fname, source_dir)
71
- arch_fname = arch_fname.replace('\\', '/')
72
- img = np.array(PIL.Image.open(fname))
73
- yield dict(img=img, label=labels.get(arch_fname))
74
- if idx >= max_idx-1:
75
- break
76
- return max_idx, iterate_images()
77
-
78
- #----------------------------------------------------------------------------
79
-
80
- def open_image_zip(source, *, max_images: Optional[int]):
81
- with zipfile.ZipFile(source, mode='r') as z:
82
- input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
83
-
84
- # Load labels.
85
- labels = {}
86
- if 'dataset.json' in z.namelist():
87
- with z.open('dataset.json', 'r') as file:
88
- labels = json.load(file)['labels']
89
- if labels is not None:
90
- labels = { x[0]: x[1] for x in labels }
91
- else:
92
- labels = {}
93
-
94
- max_idx = maybe_min(len(input_images), max_images)
95
-
96
- def iterate_images():
97
- with zipfile.ZipFile(source, mode='r') as z:
98
- for idx, fname in enumerate(input_images):
99
- with z.open(fname, 'r') as file:
100
- img = PIL.Image.open(file) # type: ignore
101
- img = np.array(img)
102
- yield dict(img=img, label=labels.get(fname))
103
- if idx >= max_idx-1:
104
- break
105
- return max_idx, iterate_images()
106
-
107
- #----------------------------------------------------------------------------
108
-
109
- def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]):
110
- import cv2 # pip install opencv-python
111
- import lmdb # pip install lmdb # pylint: disable=import-error
112
-
113
- with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
114
- max_idx = maybe_min(txn.stat()['entries'], max_images)
115
-
116
- def iterate_images():
117
- with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
118
- for idx, (_key, value) in enumerate(txn.cursor()):
119
- try:
120
- try:
121
- img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
122
- if img is None:
123
- raise IOError('cv2.imdecode failed')
124
- img = img[:, :, ::-1] # BGR => RGB
125
- except IOError:
126
- img = np.array(PIL.Image.open(io.BytesIO(value)))
127
- yield dict(img=img, label=None)
128
- if idx >= max_idx-1:
129
- break
130
- except:
131
- print(sys.exc_info()[1])
132
-
133
- return max_idx, iterate_images()
134
-
135
- #----------------------------------------------------------------------------
136
-
137
- def open_cifar10(tarball: str, *, max_images: Optional[int]):
138
- images = []
139
- labels = []
140
-
141
- with tarfile.open(tarball, 'r:gz') as tar:
142
- for batch in range(1, 6):
143
- member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}')
144
- with tar.extractfile(member) as file:
145
- data = pickle.load(file, encoding='latin1')
146
- images.append(data['data'].reshape(-1, 3, 32, 32))
147
- labels.append(data['labels'])
148
-
149
- images = np.concatenate(images)
150
- labels = np.concatenate(labels)
151
- images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC
152
- assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8
153
- assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
154
- assert np.min(images) == 0 and np.max(images) == 255
155
- assert np.min(labels) == 0 and np.max(labels) == 9
156
-
157
- max_idx = maybe_min(len(images), max_images)
158
-
159
- def iterate_images():
160
- for idx, img in enumerate(images):
161
- yield dict(img=img, label=int(labels[idx]))
162
- if idx >= max_idx-1:
163
- break
164
-
165
- return max_idx, iterate_images()
166
-
167
- #----------------------------------------------------------------------------
168
-
169
- def open_mnist(images_gz: str, *, max_images: Optional[int]):
170
- labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
171
- assert labels_gz != images_gz
172
- images = []
173
- labels = []
174
-
175
- with gzip.open(images_gz, 'rb') as f:
176
- images = np.frombuffer(f.read(), np.uint8, offset=16)
177
- with gzip.open(labels_gz, 'rb') as f:
178
- labels = np.frombuffer(f.read(), np.uint8, offset=8)
179
-
180
- images = images.reshape(-1, 28, 28)
181
- images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
182
- assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
183
- assert labels.shape == (60000,) and labels.dtype == np.uint8
184
- assert np.min(images) == 0 and np.max(images) == 255
185
- assert np.min(labels) == 0 and np.max(labels) == 9
186
-
187
- max_idx = maybe_min(len(images), max_images)
188
-
189
- def iterate_images():
190
- for idx, img in enumerate(images):
191
- yield dict(img=img, label=int(labels[idx]))
192
- if idx >= max_idx-1:
193
- break
194
-
195
- return max_idx, iterate_images()
196
-
197
- #----------------------------------------------------------------------------
198
-
199
- def make_transform(
200
- transform: Optional[str],
201
- output_width: Optional[int],
202
- output_height: Optional[int],
203
- resize_filter: str
204
- ) -> Callable[[np.ndarray], Optional[np.ndarray]]:
205
- resample = { 'box': PIL.Image.BOX, 'lanczos': PIL.Image.LANCZOS }[resize_filter]
206
- def scale(width, height, img):
207
- w = img.shape[1]
208
- h = img.shape[0]
209
- if width == w and height == h:
210
- return img
211
- img = PIL.Image.fromarray(img)
212
- ww = width if width is not None else w
213
- hh = height if height is not None else h
214
- img = img.resize((ww, hh), resample)
215
- return np.array(img)
216
-
217
- def center_crop(width, height, img):
218
- crop = np.min(img.shape[:2])
219
- img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
220
- img = PIL.Image.fromarray(img, 'RGB')
221
- img = img.resize((width, height), resample)
222
- return np.array(img)
223
-
224
- def center_crop_wide(width, height, img):
225
- ch = int(np.round(width * img.shape[0] / img.shape[1]))
226
- if img.shape[1] < width or ch < height:
227
- return None
228
-
229
- img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
230
- img = PIL.Image.fromarray(img, 'RGB')
231
- img = img.resize((width, height), resample)
232
- img = np.array(img)
233
-
234
- canvas = np.zeros([width, width, 3], dtype=np.uint8)
235
- canvas[(width - height) // 2 : (width + height) // 2, :] = img
236
- return canvas
237
-
238
- if transform is None:
239
- return functools.partial(scale, output_width, output_height)
240
- if transform == 'center-crop':
241
- if (output_width is None) or (output_height is None):
242
- error ('must specify --width and --height when using ' + transform + 'transform')
243
- return functools.partial(center_crop, output_width, output_height)
244
- if transform == 'center-crop-wide':
245
- if (output_width is None) or (output_height is None):
246
- error ('must specify --width and --height when using ' + transform + ' transform')
247
- return functools.partial(center_crop_wide, output_width, output_height)
248
- assert False, 'unknown transform'
249
-
250
- #----------------------------------------------------------------------------
251
-
252
- def open_dataset(source, *, max_images: Optional[int]):
253
- if os.path.isdir(source):
254
- if source.rstrip('/').endswith('_lmdb'):
255
- return open_lmdb(source, max_images=max_images)
256
- else:
257
- return open_image_folder(source, max_images=max_images)
258
- elif os.path.isfile(source):
259
- if os.path.basename(source) == 'cifar-10-python.tar.gz':
260
- return open_cifar10(source, max_images=max_images)
261
- elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
262
- return open_mnist(source, max_images=max_images)
263
- elif file_ext(source) == 'zip':
264
- return open_image_zip(source, max_images=max_images)
265
- else:
266
- assert False, 'unknown archive type'
267
- else:
268
- error(f'Missing input file or directory: {source}')
269
-
270
- #----------------------------------------------------------------------------
271
-
272
- def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
273
- dest_ext = file_ext(dest)
274
-
275
- if dest_ext == 'zip':
276
- if os.path.dirname(dest) != '':
277
- os.makedirs(os.path.dirname(dest), exist_ok=True)
278
- zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
279
- def zip_write_bytes(fname: str, data: Union[bytes, str]):
280
- zf.writestr(fname, data)
281
- return '', zip_write_bytes, zf.close
282
- else:
283
- # If the output folder already exists, check that is is
284
- # empty.
285
- #
286
- # Note: creating the output directory is not strictly
287
- # necessary as folder_write_bytes() also mkdirs, but it's better
288
- # to give an error message earlier in case the dest folder
289
- # somehow cannot be created.
290
- if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
291
- error('--dest folder must be empty')
292
- os.makedirs(dest, exist_ok=True)
293
-
294
- def folder_write_bytes(fname: str, data: Union[bytes, str]):
295
- os.makedirs(os.path.dirname(fname), exist_ok=True)
296
- with open(fname, 'wb') as fout:
297
- if isinstance(data, str):
298
- data = data.encode('utf8')
299
- fout.write(data)
300
- return dest, folder_write_bytes, lambda: None
301
-
302
- #----------------------------------------------------------------------------
303
-
304
- @click.command()
305
- @click.pass_context
306
- @click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH')
307
- @click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH')
308
- @click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
309
- @click.option('--resize-filter', help='Filter to use when resizing images for output resolution', type=click.Choice(['box', 'lanczos']), default='lanczos', show_default=True)
310
- @click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
311
- @click.option('--width', help='Output width', type=int)
312
- @click.option('--height', help='Output height', type=int)
313
- def convert_dataset(
314
- ctx: click.Context,
315
- source: str,
316
- dest: str,
317
- max_images: Optional[int],
318
- transform: Optional[str],
319
- resize_filter: str,
320
- width: Optional[int],
321
- height: Optional[int]
322
- ):
323
- """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.
324
-
325
- The input dataset format is guessed from the --source argument:
326
-
327
- \b
328
- --source *_lmdb/ Load LSUN dataset
329
- --source cifar-10-python.tar.gz Load CIFAR-10 dataset
330
- --source train-images-idx3-ubyte.gz Load MNIST dataset
331
- --source path/ Recursively load all images from path/
332
- --source dataset.zip Recursively load all images from dataset.zip
333
-
334
- Specifying the output format and path:
335
-
336
- \b
337
- --dest /path/to/dir Save output files under /path/to/dir
338
- --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
339
-
340
- The output dataset format can be either an image folder or an uncompressed zip archive.
341
- Zip archives makes it easier to move datasets around file servers and clusters, and may
342
- offer better training performance on network file systems.
343
-
344
- Images within the dataset archive will be stored as uncompressed PNG.
345
- Uncompresed PNGs can be efficiently decoded in the training loop.
346
-
347
- Class labels are stored in a file called 'dataset.json' that is stored at the
348
- dataset root folder. This file has the following structure:
349
-
350
- \b
351
- {
352
- "labels": [
353
- ["00000/img00000000.png",6],
354
- ["00000/img00000001.png",9],
355
- ... repeated for every image in the datase
356
- ["00049/img00049999.png",1]
357
- ]
358
- }
359
-
360
- If the 'dataset.json' file cannot be found, the dataset is interpreted as
361
- not containing class labels.
362
-
363
- Image scale/crop and resolution requirements:
364
-
365
- Output images must be square-shaped and they must all have the same power-of-two
366
- dimensions.
367
-
368
- To scale arbitrary input image size to a specific width and height, use the
369
- --width and --height options. Output resolution will be either the original
370
- input resolution (if --width/--height was not specified) or the one specified with
371
- --width/height.
372
-
373
- Use the --transform=center-crop or --transform=center-crop-wide options to apply a
374
- center crop transform on the input image. These options should be used with the
375
- --width and --height options. For example:
376
-
377
- \b
378
- python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
379
- --transform=center-crop-wide --width 512 --height=384
380
- """
381
-
382
- PIL.Image.init() # type: ignore
383
-
384
- if dest == '':
385
- ctx.fail('--dest output filename or directory must not be an empty string')
386
-
387
- num_files, input_iter = open_dataset(source, max_images=max_images)
388
- archive_root_dir, save_bytes, close_dest = open_dest(dest)
389
-
390
- transform_image = make_transform(transform, width, height, resize_filter)
391
-
392
- dataset_attrs = None
393
-
394
- labels = []
395
- for idx, image in tqdm(enumerate(input_iter), total=num_files):
396
- idx_str = f'{idx:08d}'
397
- archive_fname = f'{idx_str[:5]}/img{idx_str}.png'
398
-
399
- # Apply crop and resize.
400
- img = transform_image(image['img'])
401
-
402
- # Transform may drop images.
403
- if img is None:
404
- continue
405
-
406
- # Error check to require uniform image attributes across
407
- # the whole dataset.
408
- channels = img.shape[2] if img.ndim == 3 else 1
409
- cur_image_attrs = {
410
- 'width': img.shape[1],
411
- 'height': img.shape[0],
412
- 'channels': channels
413
- }
414
- if dataset_attrs is None:
415
- dataset_attrs = cur_image_attrs
416
- width = dataset_attrs['width']
417
- height = dataset_attrs['height']
418
- if width != height:
419
- error(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
420
- if dataset_attrs['channels'] not in [1, 3]:
421
- error('Input images must be stored as RGB or grayscale')
422
- if width != 2 ** int(np.floor(np.log2(width))):
423
- error('Image width/height after scale and crop are required to be power-of-two')
424
- elif dataset_attrs != cur_image_attrs:
425
- err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()]
426
- error(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))
427
-
428
- # Save the image as an uncompressed PNG.
429
- img = PIL.Image.fromarray(img, { 1: 'L', 3: 'RGB' }[channels])
430
- image_bits = io.BytesIO()
431
- img.save(image_bits, format='png', compress_level=0, optimize=False)
432
- save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
433
- labels.append([archive_fname, image['label']] if image['label'] is not None else None)
434
-
435
- metadata = {
436
- 'labels': labels if all(x is not None for x in labels) else None
437
- }
438
- save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
439
- close_dest()
440
-
441
- #----------------------------------------------------------------------------
442
-
443
- if __name__ == "__main__":
444
- convert_dataset() # pylint: disable=no-value-for-parameter
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dnnlib/__init__.py DELETED
@@ -1,9 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- from .util import EasyDict, make_cache_dir_path
 
 
 
 
 
 
 
 
 
 
dnnlib/util.py DELETED
@@ -1,477 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Miscellaneous utility classes and functions."""
10
-
11
- import ctypes
12
- import fnmatch
13
- import importlib
14
- import inspect
15
- import numpy as np
16
- import os
17
- import shutil
18
- import sys
19
- import types
20
- import io
21
- import pickle
22
- import re
23
- import requests
24
- import html
25
- import hashlib
26
- import glob
27
- import tempfile
28
- import urllib
29
- import urllib.request
30
- import uuid
31
-
32
- from distutils.util import strtobool
33
- from typing import Any, List, Tuple, Union
34
-
35
-
36
- # Util classes
37
- # ------------------------------------------------------------------------------------------
38
-
39
-
40
- class EasyDict(dict):
41
- """Convenience class that behaves like a dict but allows access with the attribute syntax."""
42
-
43
- def __getattr__(self, name: str) -> Any:
44
- try:
45
- return self[name]
46
- except KeyError:
47
- raise AttributeError(name)
48
-
49
- def __setattr__(self, name: str, value: Any) -> None:
50
- self[name] = value
51
-
52
- def __delattr__(self, name: str) -> None:
53
- del self[name]
54
-
55
-
56
- class Logger(object):
57
- """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
58
-
59
- def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
60
- self.file = None
61
-
62
- if file_name is not None:
63
- self.file = open(file_name, file_mode)
64
-
65
- self.should_flush = should_flush
66
- self.stdout = sys.stdout
67
- self.stderr = sys.stderr
68
-
69
- sys.stdout = self
70
- sys.stderr = self
71
-
72
- def __enter__(self) -> "Logger":
73
- return self
74
-
75
- def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
76
- self.close()
77
-
78
- def write(self, text: Union[str, bytes]) -> None:
79
- """Write text to stdout (and a file) and optionally flush."""
80
- if isinstance(text, bytes):
81
- text = text.decode()
82
- if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
83
- return
84
-
85
- if self.file is not None:
86
- self.file.write(text)
87
-
88
- self.stdout.write(text)
89
-
90
- if self.should_flush:
91
- self.flush()
92
-
93
- def flush(self) -> None:
94
- """Flush written text to both stdout and a file, if open."""
95
- if self.file is not None:
96
- self.file.flush()
97
-
98
- self.stdout.flush()
99
-
100
- def close(self) -> None:
101
- """Flush, close possible files, and remove stdout/stderr mirroring."""
102
- self.flush()
103
-
104
- # if using multiple loggers, prevent closing in wrong order
105
- if sys.stdout is self:
106
- sys.stdout = self.stdout
107
- if sys.stderr is self:
108
- sys.stderr = self.stderr
109
-
110
- if self.file is not None:
111
- self.file.close()
112
- self.file = None
113
-
114
-
115
- # Cache directories
116
- # ------------------------------------------------------------------------------------------
117
-
118
- _dnnlib_cache_dir = None
119
-
120
- def set_cache_dir(path: str) -> None:
121
- global _dnnlib_cache_dir
122
- _dnnlib_cache_dir = path
123
-
124
- def make_cache_dir_path(*paths: str) -> str:
125
- if _dnnlib_cache_dir is not None:
126
- return os.path.join(_dnnlib_cache_dir, *paths)
127
- if 'DNNLIB_CACHE_DIR' in os.environ:
128
- return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
129
- if 'HOME' in os.environ:
130
- return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
131
- if 'USERPROFILE' in os.environ:
132
- return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
133
- return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
134
-
135
- # Small util functions
136
- # ------------------------------------------------------------------------------------------
137
-
138
-
139
- def format_time(seconds: Union[int, float]) -> str:
140
- """Convert the seconds to human readable string with days, hours, minutes and seconds."""
141
- s = int(np.rint(seconds))
142
-
143
- if s < 60:
144
- return "{0}s".format(s)
145
- elif s < 60 * 60:
146
- return "{0}m {1:02}s".format(s // 60, s % 60)
147
- elif s < 24 * 60 * 60:
148
- return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
149
- else:
150
- return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
151
-
152
-
153
- def ask_yes_no(question: str) -> bool:
154
- """Ask the user the question until the user inputs a valid answer."""
155
- while True:
156
- try:
157
- print("{0} [y/n]".format(question))
158
- return strtobool(input().lower())
159
- except ValueError:
160
- pass
161
-
162
-
163
- def tuple_product(t: Tuple) -> Any:
164
- """Calculate the product of the tuple elements."""
165
- result = 1
166
-
167
- for v in t:
168
- result *= v
169
-
170
- return result
171
-
172
-
173
- _str_to_ctype = {
174
- "uint8": ctypes.c_ubyte,
175
- "uint16": ctypes.c_uint16,
176
- "uint32": ctypes.c_uint32,
177
- "uint64": ctypes.c_uint64,
178
- "int8": ctypes.c_byte,
179
- "int16": ctypes.c_int16,
180
- "int32": ctypes.c_int32,
181
- "int64": ctypes.c_int64,
182
- "float32": ctypes.c_float,
183
- "float64": ctypes.c_double
184
- }
185
-
186
-
187
- def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
188
- """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
189
- type_str = None
190
-
191
- if isinstance(type_obj, str):
192
- type_str = type_obj
193
- elif hasattr(type_obj, "__name__"):
194
- type_str = type_obj.__name__
195
- elif hasattr(type_obj, "name"):
196
- type_str = type_obj.name
197
- else:
198
- raise RuntimeError("Cannot infer type name from input")
199
-
200
- assert type_str in _str_to_ctype.keys()
201
-
202
- my_dtype = np.dtype(type_str)
203
- my_ctype = _str_to_ctype[type_str]
204
-
205
- assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
206
-
207
- return my_dtype, my_ctype
208
-
209
-
210
- def is_pickleable(obj: Any) -> bool:
211
- try:
212
- with io.BytesIO() as stream:
213
- pickle.dump(obj, stream)
214
- return True
215
- except:
216
- return False
217
-
218
-
219
- # Functionality to import modules/objects by name, and call functions by name
220
- # ------------------------------------------------------------------------------------------
221
-
222
- def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
223
- """Searches for the underlying module behind the name to some python object.
224
- Returns the module and the object name (original name with module part removed)."""
225
-
226
- # allow convenience shorthands, substitute them by full names
227
- obj_name = re.sub("^np.", "numpy.", obj_name)
228
- obj_name = re.sub("^tf.", "tensorflow.", obj_name)
229
-
230
- # list alternatives for (module_name, local_obj_name)
231
- parts = obj_name.split(".")
232
- name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
233
-
234
- # try each alternative in turn
235
- for module_name, local_obj_name in name_pairs:
236
- try:
237
- module = importlib.import_module(module_name) # may raise ImportError
238
- get_obj_from_module(module, local_obj_name) # may raise AttributeError
239
- return module, local_obj_name
240
- except:
241
- pass
242
-
243
- # maybe some of the modules themselves contain errors?
244
- for module_name, _local_obj_name in name_pairs:
245
- try:
246
- importlib.import_module(module_name) # may raise ImportError
247
- except ImportError:
248
- if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
249
- raise
250
-
251
- # maybe the requested attribute is missing?
252
- for module_name, local_obj_name in name_pairs:
253
- try:
254
- module = importlib.import_module(module_name) # may raise ImportError
255
- get_obj_from_module(module, local_obj_name) # may raise AttributeError
256
- except ImportError:
257
- pass
258
-
259
- # we are out of luck, but we have no idea why
260
- raise ImportError(obj_name)
261
-
262
-
263
- def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
264
- """Traverses the object name and returns the last (rightmost) python object."""
265
- if obj_name == '':
266
- return module
267
- obj = module
268
- for part in obj_name.split("."):
269
- obj = getattr(obj, part)
270
- return obj
271
-
272
-
273
- def get_obj_by_name(name: str) -> Any:
274
- """Finds the python object with the given name."""
275
- module, obj_name = get_module_from_obj_name(name)
276
- return get_obj_from_module(module, obj_name)
277
-
278
-
279
- def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
280
- """Finds the python object with the given name and calls it as a function."""
281
- assert func_name is not None
282
- func_obj = get_obj_by_name(func_name)
283
- assert callable(func_obj)
284
- return func_obj(*args, **kwargs)
285
-
286
-
287
- def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
288
- """Finds the python class with the given name and constructs it with the given arguments."""
289
- return call_func_by_name(*args, func_name=class_name, **kwargs)
290
-
291
-
292
- def get_module_dir_by_obj_name(obj_name: str) -> str:
293
- """Get the directory path of the module containing the given object name."""
294
- module, _ = get_module_from_obj_name(obj_name)
295
- return os.path.dirname(inspect.getfile(module))
296
-
297
-
298
- def is_top_level_function(obj: Any) -> bool:
299
- """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
300
- return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
301
-
302
-
303
- def get_top_level_function_name(obj: Any) -> str:
304
- """Return the fully-qualified name of a top-level function."""
305
- assert is_top_level_function(obj)
306
- module = obj.__module__
307
- if module == '__main__':
308
- module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
309
- return module + "." + obj.__name__
310
-
311
-
312
- # File system helpers
313
- # ------------------------------------------------------------------------------------------
314
-
315
- def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
316
- """List all files recursively in a given directory while ignoring given file and directory names.
317
- Returns list of tuples containing both absolute and relative paths."""
318
- assert os.path.isdir(dir_path)
319
- base_name = os.path.basename(os.path.normpath(dir_path))
320
-
321
- if ignores is None:
322
- ignores = []
323
-
324
- result = []
325
-
326
- for root, dirs, files in os.walk(dir_path, topdown=True):
327
- for ignore_ in ignores:
328
- dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
329
-
330
- # dirs need to be edited in-place
331
- for d in dirs_to_remove:
332
- dirs.remove(d)
333
-
334
- files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
335
-
336
- absolute_paths = [os.path.join(root, f) for f in files]
337
- relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
338
-
339
- if add_base_to_relative:
340
- relative_paths = [os.path.join(base_name, p) for p in relative_paths]
341
-
342
- assert len(absolute_paths) == len(relative_paths)
343
- result += zip(absolute_paths, relative_paths)
344
-
345
- return result
346
-
347
-
348
- def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
349
- """Takes in a list of tuples of (src, dst) paths and copies files.
350
- Will create all necessary directories."""
351
- for file in files:
352
- target_dir_name = os.path.dirname(file[1])
353
-
354
- # will create all intermediate-level directories
355
- if not os.path.exists(target_dir_name):
356
- os.makedirs(target_dir_name)
357
-
358
- shutil.copyfile(file[0], file[1])
359
-
360
-
361
- # URL helpers
362
- # ------------------------------------------------------------------------------------------
363
-
364
- def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
365
- """Determine whether the given object is a valid URL string."""
366
- if not isinstance(obj, str) or not "://" in obj:
367
- return False
368
- if allow_file_urls and obj.startswith('file://'):
369
- return True
370
- try:
371
- res = requests.compat.urlparse(obj)
372
- if not res.scheme or not res.netloc or not "." in res.netloc:
373
- return False
374
- res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
375
- if not res.scheme or not res.netloc or not "." in res.netloc:
376
- return False
377
- except:
378
- return False
379
- return True
380
-
381
-
382
- def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
383
- """Download the given URL and return a binary-mode file object to access the data."""
384
- assert num_attempts >= 1
385
- assert not (return_filename and (not cache))
386
-
387
- # Doesn't look like an URL scheme so interpret it as a local filename.
388
- if not re.match('^[a-z]+://', url):
389
- return url if return_filename else open(url, "rb")
390
-
391
- # Handle file URLs. This code handles unusual file:// patterns that
392
- # arise on Windows:
393
- #
394
- # file:///c:/foo.txt
395
- #
396
- # which would translate to a local '/c:/foo.txt' filename that's
397
- # invalid. Drop the forward slash for such pathnames.
398
- #
399
- # If you touch this code path, you should test it on both Linux and
400
- # Windows.
401
- #
402
- # Some internet resources suggest using urllib.request.url2pathname() but
403
- # but that converts forward slashes to backslashes and this causes
404
- # its own set of problems.
405
- if url.startswith('file://'):
406
- filename = urllib.parse.urlparse(url).path
407
- if re.match(r'^/[a-zA-Z]:', filename):
408
- filename = filename[1:]
409
- return filename if return_filename else open(filename, "rb")
410
-
411
- assert is_url(url)
412
-
413
- # Lookup from cache.
414
- if cache_dir is None:
415
- cache_dir = make_cache_dir_path('downloads')
416
-
417
- url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
418
- if cache:
419
- cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
420
- if len(cache_files) == 1:
421
- filename = cache_files[0]
422
- return filename if return_filename else open(filename, "rb")
423
-
424
- # Download.
425
- url_name = None
426
- url_data = None
427
- with requests.Session() as session:
428
- if verbose:
429
- print("Downloading %s ..." % url, end="", flush=True)
430
- for attempts_left in reversed(range(num_attempts)):
431
- try:
432
- with session.get(url) as res:
433
- res.raise_for_status()
434
- if len(res.content) == 0:
435
- raise IOError("No data received")
436
-
437
- if len(res.content) < 8192:
438
- content_str = res.content.decode("utf-8")
439
- if "download_warning" in res.headers.get("Set-Cookie", ""):
440
- links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
441
- if len(links) == 1:
442
- url = requests.compat.urljoin(url, links[0])
443
- raise IOError("Google Drive virus checker nag")
444
- if "Google Drive - Quota exceeded" in content_str:
445
- raise IOError("Google Drive download quota exceeded -- please try again later")
446
-
447
- match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
448
- url_name = match[1] if match else url
449
- url_data = res.content
450
- if verbose:
451
- print(" done")
452
- break
453
- except KeyboardInterrupt:
454
- raise
455
- except:
456
- if not attempts_left:
457
- if verbose:
458
- print(" failed")
459
- raise
460
- if verbose:
461
- print(".", end="", flush=True)
462
-
463
- # Save to cache.
464
- if cache:
465
- safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
466
- cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
467
- temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
468
- os.makedirs(cache_dir, exist_ok=True)
469
- with open(temp_file, "wb") as f:
470
- f.write(url_data)
471
- os.replace(temp_file, cache_file) # atomic
472
- if return_filename:
473
- return cache_file
474
-
475
- # Return data as file object.
476
- assert not return_filename
477
- return io.BytesIO(url_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
generate.py DELETED
@@ -1,129 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Generate images using pretrained network pickle."""
10
-
11
- import os
12
- import re
13
- from typing import List, Optional
14
-
15
- import click
16
- import dnnlib
17
- import numpy as np
18
- import PIL.Image
19
- import torch
20
-
21
- import legacy
22
-
23
- #----------------------------------------------------------------------------
24
-
25
- def num_range(s: str) -> List[int]:
26
- '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
27
-
28
- range_re = re.compile(r'^(\d+)-(\d+)$')
29
- m = range_re.match(s)
30
- if m:
31
- return list(range(int(m.group(1)), int(m.group(2))+1))
32
- vals = s.split(',')
33
- return [int(x) for x in vals]
34
-
35
- #----------------------------------------------------------------------------
36
-
37
- @click.command()
38
- @click.pass_context
39
- @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
40
- @click.option('--seeds', type=num_range, help='List of random seeds')
41
- @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
42
- @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
43
- @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
44
- @click.option('--projected-w', help='Projection result file', type=str, metavar='FILE')
45
- @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
46
- def generate_images(
47
- ctx: click.Context,
48
- network_pkl: str,
49
- seeds: Optional[List[int]],
50
- truncation_psi: float,
51
- noise_mode: str,
52
- outdir: str,
53
- class_idx: Optional[int],
54
- projected_w: Optional[str]
55
- ):
56
- """Generate images using pretrained network pickle.
57
-
58
- Examples:
59
-
60
- \b
61
- # Generate curated MetFaces images without truncation (Fig.10 left)
62
- python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\
63
- --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
64
-
65
- \b
66
- # Generate uncurated MetFaces images with truncation (Fig.12 upper left)
67
- python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\
68
- --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
69
-
70
- \b
71
- # Generate class conditional CIFAR-10 images (Fig.17 left, Car)
72
- python generate.py --outdir=out --seeds=0-35 --class=1 \\
73
- --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl
74
-
75
- \b
76
- # Render an image from projected W
77
- python generate.py --outdir=out --projected_w=projected_w.npz \\
78
- --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
79
- """
80
-
81
- print('Loading networks from "%s"...' % network_pkl)
82
- device = torch.device('cuda')
83
- with dnnlib.util.open_url(network_pkl) as f:
84
- G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
85
-
86
- os.makedirs(outdir, exist_ok=True)
87
-
88
- # Synthesize the result of a W projection.
89
- if projected_w is not None:
90
- if seeds is not None:
91
- print ('warn: --seeds is ignored when using --projected-w')
92
- print(f'Generating images from projected W "{projected_w}"')
93
- ws = np.load(projected_w)['w']
94
- ws = torch.tensor(ws, device=device) # pylint: disable=not-callable
95
- assert ws.shape[1:] == (G.num_ws, G.w_dim)
96
- for idx, w in enumerate(ws):
97
- img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode)
98
- img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
99
- img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/proj{idx:02d}.png')
100
- return
101
-
102
- if seeds is None:
103
- ctx.fail('--seeds option is required when not using --projected-w')
104
-
105
- # Labels.
106
- label = torch.zeros([1, G.c_dim], device=device)
107
- if G.c_dim != 0:
108
- if class_idx is None:
109
- ctx.fail('Must specify class label with --class when using a conditional network')
110
- label[:, class_idx] = 1
111
- else:
112
- if class_idx is not None:
113
- print ('warn: --class=lbl ignored when running on an unconditional network')
114
-
115
- # Generate images.
116
- for seed_idx, seed in enumerate(seeds):
117
- print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
118
- z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
119
- img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
120
- img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
121
- PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
122
-
123
-
124
- #----------------------------------------------------------------------------
125
-
126
- if __name__ == "__main__":
127
- generate_images() # pylint: disable=no-value-for-parameter
128
-
129
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
legacy.py DELETED
@@ -1,320 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- import click
10
- import pickle
11
- import re
12
- import copy
13
- import numpy as np
14
- import torch
15
- import dnnlib
16
- from torch_utils import misc
17
-
18
- #----------------------------------------------------------------------------
19
-
20
- def load_network_pkl(f, force_fp16=False):
21
- data = _LegacyUnpickler(f).load()
22
-
23
- # Legacy TensorFlow pickle => convert.
24
- if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
25
- tf_G, tf_D, tf_Gs = data
26
- G = convert_tf_generator(tf_G)
27
- D = convert_tf_discriminator(tf_D)
28
- G_ema = convert_tf_generator(tf_Gs)
29
- data = dict(G=G, D=D, G_ema=G_ema)
30
-
31
- # Add missing fields.
32
- if 'training_set_kwargs' not in data:
33
- data['training_set_kwargs'] = None
34
- if 'augment_pipe' not in data:
35
- data['augment_pipe'] = None
36
-
37
- # Validate contents.
38
- assert isinstance(data['G'], torch.nn.Module)
39
- assert isinstance(data['D'], torch.nn.Module)
40
- assert isinstance(data['G_ema'], torch.nn.Module)
41
- assert isinstance(data['training_set_kwargs'], (dict, type(None)))
42
- assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
43
-
44
- # Force FP16.
45
- if force_fp16:
46
- for key in ['G', 'D', 'G_ema']:
47
- old = data[key]
48
- kwargs = copy.deepcopy(old.init_kwargs)
49
- if key.startswith('G'):
50
- kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {}))
51
- kwargs.synthesis_kwargs.num_fp16_res = 4
52
- kwargs.synthesis_kwargs.conv_clamp = 256
53
- if key.startswith('D'):
54
- kwargs.num_fp16_res = 4
55
- kwargs.conv_clamp = 256
56
- if kwargs != old.init_kwargs:
57
- new = type(old)(**kwargs).eval().requires_grad_(False)
58
- misc.copy_params_and_buffers(old, new, require_all=True)
59
- data[key] = new
60
- return data
61
-
62
- #----------------------------------------------------------------------------
63
-
64
- class _TFNetworkStub(dnnlib.EasyDict):
65
- pass
66
-
67
- class _LegacyUnpickler(pickle.Unpickler):
68
- def find_class(self, module, name):
69
- if module == 'dnnlib.tflib.network' and name == 'Network':
70
- return _TFNetworkStub
71
- return super().find_class(module, name)
72
-
73
- #----------------------------------------------------------------------------
74
-
75
- def _collect_tf_params(tf_net):
76
- # pylint: disable=protected-access
77
- tf_params = dict()
78
- def recurse(prefix, tf_net):
79
- for name, value in tf_net.variables:
80
- tf_params[prefix + name] = value
81
- for name, comp in tf_net.components.items():
82
- recurse(prefix + name + '/', comp)
83
- recurse('', tf_net)
84
- return tf_params
85
-
86
- #----------------------------------------------------------------------------
87
-
88
- def _populate_module_params(module, *patterns):
89
- for name, tensor in misc.named_params_and_buffers(module):
90
- found = False
91
- value = None
92
- for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
93
- match = re.fullmatch(pattern, name)
94
- if match:
95
- found = True
96
- if value_fn is not None:
97
- value = value_fn(*match.groups())
98
- break
99
- try:
100
- assert found
101
- if value is not None:
102
- tensor.copy_(torch.from_numpy(np.array(value)))
103
- except:
104
- print(name, list(tensor.shape))
105
- raise
106
-
107
- #----------------------------------------------------------------------------
108
-
109
- def convert_tf_generator(tf_G):
110
- if tf_G.version < 4:
111
- raise ValueError('TensorFlow pickle version too low')
112
-
113
- # Collect kwargs.
114
- tf_kwargs = tf_G.static_kwargs
115
- known_kwargs = set()
116
- def kwarg(tf_name, default=None, none=None):
117
- known_kwargs.add(tf_name)
118
- val = tf_kwargs.get(tf_name, default)
119
- return val if val is not None else none
120
-
121
- # Convert kwargs.
122
- kwargs = dnnlib.EasyDict(
123
- z_dim = kwarg('latent_size', 512),
124
- c_dim = kwarg('label_size', 0),
125
- w_dim = kwarg('dlatent_size', 512),
126
- img_resolution = kwarg('resolution', 1024),
127
- img_channels = kwarg('num_channels', 3),
128
- mapping_kwargs = dnnlib.EasyDict(
129
- num_layers = kwarg('mapping_layers', 8),
130
- embed_features = kwarg('label_fmaps', None),
131
- layer_features = kwarg('mapping_fmaps', None),
132
- activation = kwarg('mapping_nonlinearity', 'lrelu'),
133
- lr_multiplier = kwarg('mapping_lrmul', 0.01),
134
- w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
135
- ),
136
- synthesis_kwargs = dnnlib.EasyDict(
137
- channel_base = kwarg('fmap_base', 16384) * 2,
138
- channel_max = kwarg('fmap_max', 512),
139
- num_fp16_res = kwarg('num_fp16_res', 0),
140
- conv_clamp = kwarg('conv_clamp', None),
141
- architecture = kwarg('architecture', 'skip'),
142
- resample_filter = kwarg('resample_kernel', [1,3,3,1]),
143
- use_noise = kwarg('use_noise', True),
144
- activation = kwarg('nonlinearity', 'lrelu'),
145
- ),
146
- )
147
-
148
- # Check for unknown kwargs.
149
- kwarg('truncation_psi')
150
- kwarg('truncation_cutoff')
151
- kwarg('style_mixing_prob')
152
- kwarg('structure')
153
- unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
154
- if len(unknown_kwargs) > 0:
155
- raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
156
-
157
- # Collect params.
158
- tf_params = _collect_tf_params(tf_G)
159
- for name, value in list(tf_params.items()):
160
- match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
161
- if match:
162
- r = kwargs.img_resolution // (2 ** int(match.group(1)))
163
- tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
164
- kwargs.synthesis.kwargs.architecture = 'orig'
165
- #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
166
-
167
- # Convert params.
168
- from training import networks
169
- G = networks.Generator(**kwargs).eval().requires_grad_(False)
170
- # pylint: disable=unnecessary-lambda
171
- _populate_module_params(G,
172
- r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
173
- r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
174
- r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
175
- r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
176
- r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
177
- r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
178
- r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
179
- r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
180
- r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
181
- r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
182
- r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
183
- r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
184
- r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
185
- r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
186
- r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
187
- r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
188
- r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
189
- r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
190
- r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
191
- r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
192
- r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
193
- r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
194
- r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
195
- r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
196
- r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
197
- r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
198
- r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
199
- r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
200
- r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
201
- r'.*\.resample_filter', None,
202
- )
203
- return G
204
-
205
- #----------------------------------------------------------------------------
206
-
207
- def convert_tf_discriminator(tf_D):
208
- if tf_D.version < 4:
209
- raise ValueError('TensorFlow pickle version too low')
210
-
211
- # Collect kwargs.
212
- tf_kwargs = tf_D.static_kwargs
213
- known_kwargs = set()
214
- def kwarg(tf_name, default=None):
215
- known_kwargs.add(tf_name)
216
- return tf_kwargs.get(tf_name, default)
217
-
218
- # Convert kwargs.
219
- kwargs = dnnlib.EasyDict(
220
- c_dim = kwarg('label_size', 0),
221
- img_resolution = kwarg('resolution', 1024),
222
- img_channels = kwarg('num_channels', 3),
223
- architecture = kwarg('architecture', 'resnet'),
224
- channel_base = kwarg('fmap_base', 16384) * 2,
225
- channel_max = kwarg('fmap_max', 512),
226
- num_fp16_res = kwarg('num_fp16_res', 0),
227
- conv_clamp = kwarg('conv_clamp', None),
228
- cmap_dim = kwarg('mapping_fmaps', None),
229
- block_kwargs = dnnlib.EasyDict(
230
- activation = kwarg('nonlinearity', 'lrelu'),
231
- resample_filter = kwarg('resample_kernel', [1,3,3,1]),
232
- freeze_layers = kwarg('freeze_layers', 0),
233
- ),
234
- mapping_kwargs = dnnlib.EasyDict(
235
- num_layers = kwarg('mapping_layers', 0),
236
- embed_features = kwarg('mapping_fmaps', None),
237
- layer_features = kwarg('mapping_fmaps', None),
238
- activation = kwarg('nonlinearity', 'lrelu'),
239
- lr_multiplier = kwarg('mapping_lrmul', 0.1),
240
- ),
241
- epilogue_kwargs = dnnlib.EasyDict(
242
- mbstd_group_size = kwarg('mbstd_group_size', None),
243
- mbstd_num_channels = kwarg('mbstd_num_features', 1),
244
- activation = kwarg('nonlinearity', 'lrelu'),
245
- ),
246
- )
247
-
248
- # Check for unknown kwargs.
249
- kwarg('structure')
250
- unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
251
- if len(unknown_kwargs) > 0:
252
- raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
253
-
254
- # Collect params.
255
- tf_params = _collect_tf_params(tf_D)
256
- for name, value in list(tf_params.items()):
257
- match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
258
- if match:
259
- r = kwargs.img_resolution // (2 ** int(match.group(1)))
260
- tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
261
- kwargs.architecture = 'orig'
262
- #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
263
-
264
- # Convert params.
265
- from training import networks
266
- D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
267
- # pylint: disable=unnecessary-lambda
268
- _populate_module_params(D,
269
- r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
270
- r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
271
- r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
272
- r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
273
- r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
274
- r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
275
- r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
276
- r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
277
- r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
278
- r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
279
- r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
280
- r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
281
- r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
282
- r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
283
- r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
284
- r'.*\.resample_filter', None,
285
- )
286
- return D
287
-
288
- #----------------------------------------------------------------------------
289
-
290
- @click.command()
291
- @click.option('--source', help='Input pickle', required=True, metavar='PATH')
292
- @click.option('--dest', help='Output pickle', required=True, metavar='PATH')
293
- @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
294
- def convert_network_pickle(source, dest, force_fp16):
295
- """Convert legacy network pickle into the native PyTorch format.
296
-
297
- The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
298
- It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
299
-
300
- Example:
301
-
302
- \b
303
- python legacy.py \\
304
- --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
305
- --dest=stylegan2-cat-config-f.pkl
306
- """
307
- print(f'Loading "{source}"...')
308
- with dnnlib.util.open_url(source) as f:
309
- data = load_network_pkl(f, force_fp16=force_fp16)
310
- print(f'Saving "{dest}"...')
311
- with open(dest, 'wb') as f:
312
- pickle.dump(data, f)
313
- print('Done.')
314
-
315
- #----------------------------------------------------------------------------
316
-
317
- if __name__ == "__main__":
318
- convert_network_pickle() # pylint: disable=no-value-for-parameter
319
-
320
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
metrics/__init__.py DELETED
@@ -1,9 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- # empty
 
 
 
 
 
 
 
 
 
 
metrics/frechet_inception_distance.py DELETED
@@ -1,41 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Frechet Inception Distance (FID) from the paper
10
- "GANs trained by a two time-scale update rule converge to a local Nash
11
- equilibrium". Matches the original implementation by Heusel et al. at
12
- https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
13
-
14
- import numpy as np
15
- import scipy.linalg
16
- from . import metric_utils
17
-
18
- #----------------------------------------------------------------------------
19
-
20
- def compute_fid(opts, max_real, num_gen):
21
- # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
22
- detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
23
- detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
24
-
25
- mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
26
- opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
27
- rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
28
-
29
- mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
30
- opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
31
- rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
32
-
33
- if opts.rank != 0:
34
- return float('nan')
35
-
36
- m = np.square(mu_gen - mu_real).sum()
37
- s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
38
- fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
39
- return float(fid)
40
-
41
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
metrics/inception_score.py DELETED
@@ -1,38 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Inception Score (IS) from the paper "Improved techniques for training
10
- GANs". Matches the original implementation by Salimans et al. at
11
- https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
12
-
13
- import numpy as np
14
- from . import metric_utils
15
-
16
- #----------------------------------------------------------------------------
17
-
18
- def compute_is(opts, num_gen, num_splits):
19
- # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20
- detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
21
- detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
22
-
23
- gen_probs = metric_utils.compute_feature_stats_for_generator(
24
- opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
25
- capture_all=True, max_items=num_gen).get_all()
26
-
27
- if opts.rank != 0:
28
- return float('nan'), float('nan')
29
-
30
- scores = []
31
- for i in range(num_splits):
32
- part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
33
- kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
34
- kl = np.mean(np.sum(kl, axis=1))
35
- scores.append(np.exp(kl))
36
- return float(np.mean(scores)), float(np.std(scores))
37
-
38
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
metrics/kernel_inception_distance.py DELETED
@@ -1,46 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Kernel Inception Distance (KID) from the paper "Demystifying MMD
10
- GANs". Matches the original implementation by Binkowski et al. at
11
- https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
12
-
13
- import numpy as np
14
- from . import metric_utils
15
-
16
- #----------------------------------------------------------------------------
17
-
18
- def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
19
- # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20
- detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
21
- detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
22
-
23
- real_features = metric_utils.compute_feature_stats_for_dataset(
24
- opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
25
- rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
26
-
27
- gen_features = metric_utils.compute_feature_stats_for_generator(
28
- opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
29
- rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
30
-
31
- if opts.rank != 0:
32
- return float('nan')
33
-
34
- n = real_features.shape[1]
35
- m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
36
- t = 0
37
- for _subset_idx in range(num_subsets):
38
- x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
39
- y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
40
- a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
41
- b = (x @ y.T / n + 1) ** 3
42
- t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
43
- kid = t / num_subsets / m
44
- return float(kid)
45
-
46
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
metrics/metric_main.py DELETED
@@ -1,152 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- import os
10
- import time
11
- import json
12
- import torch
13
- import dnnlib
14
-
15
- from . import metric_utils
16
- from . import frechet_inception_distance
17
- from . import kernel_inception_distance
18
- from . import precision_recall
19
- from . import perceptual_path_length
20
- from . import inception_score
21
-
22
- #----------------------------------------------------------------------------
23
-
24
- _metric_dict = dict() # name => fn
25
-
26
- def register_metric(fn):
27
- assert callable(fn)
28
- _metric_dict[fn.__name__] = fn
29
- return fn
30
-
31
- def is_valid_metric(metric):
32
- return metric in _metric_dict
33
-
34
- def list_valid_metrics():
35
- return list(_metric_dict.keys())
36
-
37
- #----------------------------------------------------------------------------
38
-
39
- def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
40
- assert is_valid_metric(metric)
41
- opts = metric_utils.MetricOptions(**kwargs)
42
-
43
- # Calculate.
44
- start_time = time.time()
45
- results = _metric_dict[metric](opts)
46
- total_time = time.time() - start_time
47
-
48
- # Broadcast results.
49
- for key, value in list(results.items()):
50
- if opts.num_gpus > 1:
51
- value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
52
- torch.distributed.broadcast(tensor=value, src=0)
53
- value = float(value.cpu())
54
- results[key] = value
55
-
56
- # Decorate with metadata.
57
- return dnnlib.EasyDict(
58
- results = dnnlib.EasyDict(results),
59
- metric = metric,
60
- total_time = total_time,
61
- total_time_str = dnnlib.util.format_time(total_time),
62
- num_gpus = opts.num_gpus,
63
- )
64
-
65
- #----------------------------------------------------------------------------
66
-
67
- def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
68
- metric = result_dict['metric']
69
- assert is_valid_metric(metric)
70
- if run_dir is not None and snapshot_pkl is not None:
71
- snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
72
-
73
- jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
74
- print(jsonl_line)
75
- if run_dir is not None and os.path.isdir(run_dir):
76
- with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
77
- f.write(jsonl_line + '\n')
78
-
79
- #----------------------------------------------------------------------------
80
- # Primary metrics.
81
-
82
- @register_metric
83
- def fid50k_full(opts):
84
- opts.dataset_kwargs.update(max_size=None, xflip=False)
85
- fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
86
- return dict(fid50k_full=fid)
87
-
88
- @register_metric
89
- def kid50k_full(opts):
90
- opts.dataset_kwargs.update(max_size=None, xflip=False)
91
- kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
92
- return dict(kid50k_full=kid)
93
-
94
- @register_metric
95
- def pr50k3_full(opts):
96
- opts.dataset_kwargs.update(max_size=None, xflip=False)
97
- precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
98
- return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
99
-
100
- @register_metric
101
- def ppl2_wend(opts):
102
- ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
103
- return dict(ppl2_wend=ppl)
104
-
105
- @register_metric
106
- def is50k(opts):
107
- opts.dataset_kwargs.update(max_size=None, xflip=False)
108
- mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
109
- return dict(is50k_mean=mean, is50k_std=std)
110
-
111
- #----------------------------------------------------------------------------
112
- # Legacy metrics.
113
-
114
- @register_metric
115
- def fid50k(opts):
116
- opts.dataset_kwargs.update(max_size=None)
117
- fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
118
- return dict(fid50k=fid)
119
-
120
- @register_metric
121
- def kid50k(opts):
122
- opts.dataset_kwargs.update(max_size=None)
123
- kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
124
- return dict(kid50k=kid)
125
-
126
- @register_metric
127
- def pr50k3(opts):
128
- opts.dataset_kwargs.update(max_size=None)
129
- precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
130
- return dict(pr50k3_precision=precision, pr50k3_recall=recall)
131
-
132
- @register_metric
133
- def ppl_zfull(opts):
134
- ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, batch_size=2)
135
- return dict(ppl_zfull=ppl)
136
-
137
- @register_metric
138
- def ppl_wfull(opts):
139
- ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, batch_size=2)
140
- return dict(ppl_wfull=ppl)
141
-
142
- @register_metric
143
- def ppl_zend(opts):
144
- ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, batch_size=2)
145
- return dict(ppl_zend=ppl)
146
-
147
- @register_metric
148
- def ppl_wend(opts):
149
- ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, batch_size=2)
150
- return dict(ppl_wend=ppl)
151
-
152
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
metrics/metric_utils.py DELETED
@@ -1,275 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- import os
10
- import time
11
- import hashlib
12
- import pickle
13
- import copy
14
- import uuid
15
- import numpy as np
16
- import torch
17
- import dnnlib
18
-
19
- #----------------------------------------------------------------------------
20
-
21
- class MetricOptions:
22
- def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
23
- assert 0 <= rank < num_gpus
24
- self.G = G
25
- self.G_kwargs = dnnlib.EasyDict(G_kwargs)
26
- self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
27
- self.num_gpus = num_gpus
28
- self.rank = rank
29
- self.device = device if device is not None else torch.device('cuda', rank)
30
- self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
31
- self.cache = cache
32
-
33
- #----------------------------------------------------------------------------
34
-
35
- _feature_detector_cache = dict()
36
-
37
- def get_feature_detector_name(url):
38
- return os.path.splitext(url.split('/')[-1])[0]
39
-
40
- def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
41
- assert 0 <= rank < num_gpus
42
- key = (url, device)
43
- if key not in _feature_detector_cache:
44
- is_leader = (rank == 0)
45
- if not is_leader and num_gpus > 1:
46
- torch.distributed.barrier() # leader goes first
47
- with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
48
- _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
49
- if is_leader and num_gpus > 1:
50
- torch.distributed.barrier() # others follow
51
- return _feature_detector_cache[key]
52
-
53
- #----------------------------------------------------------------------------
54
-
55
- class FeatureStats:
56
- def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
57
- self.capture_all = capture_all
58
- self.capture_mean_cov = capture_mean_cov
59
- self.max_items = max_items
60
- self.num_items = 0
61
- self.num_features = None
62
- self.all_features = None
63
- self.raw_mean = None
64
- self.raw_cov = None
65
-
66
- def set_num_features(self, num_features):
67
- if self.num_features is not None:
68
- assert num_features == self.num_features
69
- else:
70
- self.num_features = num_features
71
- self.all_features = []
72
- self.raw_mean = np.zeros([num_features], dtype=np.float64)
73
- self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
74
-
75
- def is_full(self):
76
- return (self.max_items is not None) and (self.num_items >= self.max_items)
77
-
78
- def append(self, x):
79
- x = np.asarray(x, dtype=np.float32)
80
- assert x.ndim == 2
81
- if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
82
- if self.num_items >= self.max_items:
83
- return
84
- x = x[:self.max_items - self.num_items]
85
-
86
- self.set_num_features(x.shape[1])
87
- self.num_items += x.shape[0]
88
- if self.capture_all:
89
- self.all_features.append(x)
90
- if self.capture_mean_cov:
91
- x64 = x.astype(np.float64)
92
- self.raw_mean += x64.sum(axis=0)
93
- self.raw_cov += x64.T @ x64
94
-
95
- def append_torch(self, x, num_gpus=1, rank=0):
96
- assert isinstance(x, torch.Tensor) and x.ndim == 2
97
- assert 0 <= rank < num_gpus
98
- if num_gpus > 1:
99
- ys = []
100
- for src in range(num_gpus):
101
- y = x.clone()
102
- torch.distributed.broadcast(y, src=src)
103
- ys.append(y)
104
- x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
105
- self.append(x.cpu().numpy())
106
-
107
- def get_all(self):
108
- assert self.capture_all
109
- return np.concatenate(self.all_features, axis=0)
110
-
111
- def get_all_torch(self):
112
- return torch.from_numpy(self.get_all())
113
-
114
- def get_mean_cov(self):
115
- assert self.capture_mean_cov
116
- mean = self.raw_mean / self.num_items
117
- cov = self.raw_cov / self.num_items
118
- cov = cov - np.outer(mean, mean)
119
- return mean, cov
120
-
121
- def save(self, pkl_file):
122
- with open(pkl_file, 'wb') as f:
123
- pickle.dump(self.__dict__, f)
124
-
125
- @staticmethod
126
- def load(pkl_file):
127
- with open(pkl_file, 'rb') as f:
128
- s = dnnlib.EasyDict(pickle.load(f))
129
- obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
130
- obj.__dict__.update(s)
131
- return obj
132
-
133
- #----------------------------------------------------------------------------
134
-
135
- class ProgressMonitor:
136
- def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
137
- self.tag = tag
138
- self.num_items = num_items
139
- self.verbose = verbose
140
- self.flush_interval = flush_interval
141
- self.progress_fn = progress_fn
142
- self.pfn_lo = pfn_lo
143
- self.pfn_hi = pfn_hi
144
- self.pfn_total = pfn_total
145
- self.start_time = time.time()
146
- self.batch_time = self.start_time
147
- self.batch_items = 0
148
- if self.progress_fn is not None:
149
- self.progress_fn(self.pfn_lo, self.pfn_total)
150
-
151
- def update(self, cur_items):
152
- assert (self.num_items is None) or (cur_items <= self.num_items)
153
- if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
154
- return
155
- cur_time = time.time()
156
- total_time = cur_time - self.start_time
157
- time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
158
- if (self.verbose) and (self.tag is not None):
159
- print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
160
- self.batch_time = cur_time
161
- self.batch_items = cur_items
162
-
163
- if (self.progress_fn is not None) and (self.num_items is not None):
164
- self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
165
-
166
- def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
167
- return ProgressMonitor(
168
- tag = tag,
169
- num_items = num_items,
170
- flush_interval = flush_interval,
171
- verbose = self.verbose,
172
- progress_fn = self.progress_fn,
173
- pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
174
- pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
175
- pfn_total = self.pfn_total,
176
- )
177
-
178
- #----------------------------------------------------------------------------
179
-
180
- def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
181
- dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
182
- if data_loader_kwargs is None:
183
- data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
184
-
185
- # Try to lookup from cache.
186
- cache_file = None
187
- if opts.cache:
188
- # Choose cache file name.
189
- args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
190
- md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
191
- cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
192
- cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
193
-
194
- # Check if the file exists (all processes must agree).
195
- flag = os.path.isfile(cache_file) if opts.rank == 0 else False
196
- if opts.num_gpus > 1:
197
- flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
198
- torch.distributed.broadcast(tensor=flag, src=0)
199
- flag = (float(flag.cpu()) != 0)
200
-
201
- # Load.
202
- if flag:
203
- return FeatureStats.load(cache_file)
204
-
205
- # Initialize.
206
- num_items = len(dataset)
207
- if max_items is not None:
208
- num_items = min(num_items, max_items)
209
- stats = FeatureStats(max_items=num_items, **stats_kwargs)
210
- progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
211
- detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
212
-
213
- # Main loop.
214
- item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
215
- for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
216
- if images.shape[1] == 1:
217
- images = images.repeat([1, 3, 1, 1])
218
- features = detector(images.to(opts.device), **detector_kwargs)
219
- stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
220
- progress.update(stats.num_items)
221
-
222
- # Save to cache.
223
- if cache_file is not None and opts.rank == 0:
224
- os.makedirs(os.path.dirname(cache_file), exist_ok=True)
225
- temp_file = cache_file + '.' + uuid.uuid4().hex
226
- stats.save(temp_file)
227
- os.replace(temp_file, cache_file) # atomic
228
- return stats
229
-
230
- #----------------------------------------------------------------------------
231
-
232
- def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, **stats_kwargs):
233
- if batch_gen is None:
234
- batch_gen = min(batch_size, 4)
235
- assert batch_size % batch_gen == 0
236
-
237
- # Setup generator and load labels.
238
- G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
239
- dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
240
-
241
- # Image generation func.
242
- def run_generator(z, c):
243
- img = G(z=z, c=c, **opts.G_kwargs)
244
- img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
245
- return img
246
-
247
- # JIT.
248
- if jit:
249
- z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
250
- c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
251
- run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False)
252
-
253
- # Initialize.
254
- stats = FeatureStats(**stats_kwargs)
255
- assert stats.max_items is not None
256
- progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
257
- detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
258
-
259
- # Main loop.
260
- while not stats.is_full():
261
- images = []
262
- for _i in range(batch_size // batch_gen):
263
- z = torch.randn([batch_gen, G.z_dim], device=opts.device)
264
- c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)]
265
- c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
266
- images.append(run_generator(z, c))
267
- images = torch.cat(images)
268
- if images.shape[1] == 1:
269
- images = images.repeat([1, 3, 1, 1])
270
- features = detector(images, **detector_kwargs)
271
- stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
272
- progress.update(stats.num_items)
273
- return stats
274
-
275
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
metrics/perceptual_path_length.py DELETED
@@ -1,131 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Perceptual Path Length (PPL) from the paper "A Style-Based Generator
10
- Architecture for Generative Adversarial Networks". Matches the original
11
- implementation by Karras et al. at
12
- https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
13
-
14
- import copy
15
- import numpy as np
16
- import torch
17
- import dnnlib
18
- from . import metric_utils
19
-
20
- #----------------------------------------------------------------------------
21
-
22
- # Spherical interpolation of a batch of vectors.
23
- def slerp(a, b, t):
24
- a = a / a.norm(dim=-1, keepdim=True)
25
- b = b / b.norm(dim=-1, keepdim=True)
26
- d = (a * b).sum(dim=-1, keepdim=True)
27
- p = t * torch.acos(d)
28
- c = b - d * a
29
- c = c / c.norm(dim=-1, keepdim=True)
30
- d = a * torch.cos(p) + c * torch.sin(p)
31
- d = d / d.norm(dim=-1, keepdim=True)
32
- return d
33
-
34
- #----------------------------------------------------------------------------
35
-
36
- class PPLSampler(torch.nn.Module):
37
- def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
38
- assert space in ['z', 'w']
39
- assert sampling in ['full', 'end']
40
- super().__init__()
41
- self.G = copy.deepcopy(G)
42
- self.G_kwargs = G_kwargs
43
- self.epsilon = epsilon
44
- self.space = space
45
- self.sampling = sampling
46
- self.crop = crop
47
- self.vgg16 = copy.deepcopy(vgg16)
48
-
49
- def forward(self, c):
50
- # Generate random latents and interpolation t-values.
51
- t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
52
- z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
53
-
54
- # Interpolate in W or Z.
55
- if self.space == 'w':
56
- w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
57
- wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
58
- wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
59
- else: # space == 'z'
60
- zt0 = slerp(z0, z1, t.unsqueeze(1))
61
- zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
62
- wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
63
-
64
- # Randomize noise buffers.
65
- for name, buf in self.G.named_buffers():
66
- if name.endswith('.noise_const'):
67
- buf.copy_(torch.randn_like(buf))
68
-
69
- # Generate images.
70
- img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
71
-
72
- # Center crop.
73
- if self.crop:
74
- assert img.shape[2] == img.shape[3]
75
- c = img.shape[2] // 8
76
- img = img[:, :, c*3 : c*7, c*2 : c*6]
77
-
78
- # Downsample to 256x256.
79
- factor = self.G.img_resolution // 256
80
- if factor > 1:
81
- img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
82
-
83
- # Scale dynamic range from [-1,1] to [0,255].
84
- img = (img + 1) * (255 / 2)
85
- if self.G.img_channels == 1:
86
- img = img.repeat([1, 3, 1, 1])
87
-
88
- # Evaluate differential LPIPS.
89
- lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
90
- dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
91
- return dist
92
-
93
- #----------------------------------------------------------------------------
94
-
95
- def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size, jit=False):
96
- dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
97
- vgg16_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
98
- vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
99
-
100
- # Setup sampler.
101
- sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
102
- sampler.eval().requires_grad_(False).to(opts.device)
103
- if jit:
104
- c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
105
- sampler = torch.jit.trace(sampler, [c], check_trace=False)
106
-
107
- # Sampling loop.
108
- dist = []
109
- progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
110
- for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
111
- progress.update(batch_start)
112
- c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
113
- c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
114
- x = sampler(c)
115
- for src in range(opts.num_gpus):
116
- y = x.clone()
117
- if opts.num_gpus > 1:
118
- torch.distributed.broadcast(y, src=src)
119
- dist.append(y)
120
- progress.update(num_samples)
121
-
122
- # Compute PPL.
123
- if opts.rank != 0:
124
- return float('nan')
125
- dist = torch.cat(dist)[:num_samples].cpu().numpy()
126
- lo = np.percentile(dist, 1, interpolation='lower')
127
- hi = np.percentile(dist, 99, interpolation='higher')
128
- ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
129
- return float(ppl)
130
-
131
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
metrics/precision_recall.py DELETED
@@ -1,62 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Precision/Recall (PR) from the paper "Improved Precision and Recall
10
- Metric for Assessing Generative Models". Matches the original implementation
11
- by Kynkaanniemi et al. at
12
- https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
13
-
14
- import torch
15
- from . import metric_utils
16
-
17
- #----------------------------------------------------------------------------
18
-
19
- def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
20
- assert 0 <= rank < num_gpus
21
- num_cols = col_features.shape[0]
22
- num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
23
- col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
24
- dist_batches = []
25
- for col_batch in col_batches[rank :: num_gpus]:
26
- dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
27
- for src in range(num_gpus):
28
- dist_broadcast = dist_batch.clone()
29
- if num_gpus > 1:
30
- torch.distributed.broadcast(dist_broadcast, src=src)
31
- dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
32
- return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
33
-
34
- #----------------------------------------------------------------------------
35
-
36
- def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
37
- detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
38
- detector_kwargs = dict(return_features=True)
39
-
40
- real_features = metric_utils.compute_feature_stats_for_dataset(
41
- opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
42
- rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
43
-
44
- gen_features = metric_utils.compute_feature_stats_for_generator(
45
- opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
46
- rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
47
-
48
- results = dict()
49
- for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
50
- kth = []
51
- for manifold_batch in manifold.split(row_batch_size):
52
- dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
53
- kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
54
- kth = torch.cat(kth) if opts.rank == 0 else None
55
- pred = []
56
- for probes_batch in probes.split(row_batch_size):
57
- dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
58
- pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
59
- results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
60
- return results['precision'], results['recall']
61
-
62
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
packages.txt DELETED
@@ -1 +0,0 @@
1
- libtinfo5
 
 
projector.py DELETED
@@ -1,212 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Project given image to the latent space of pretrained network pickle."""
10
-
11
- import copy
12
- import os
13
- from time import perf_counter
14
-
15
- import click
16
- import imageio
17
- import numpy as np
18
- import PIL.Image
19
- import torch
20
- import torch.nn.functional as F
21
-
22
- import dnnlib
23
- import legacy
24
-
25
- def project(
26
- G,
27
- target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
28
- *,
29
- num_steps = 1000,
30
- w_avg_samples = 10000,
31
- initial_learning_rate = 0.1,
32
- initial_noise_factor = 0.05,
33
- lr_rampdown_length = 0.25,
34
- lr_rampup_length = 0.05,
35
- noise_ramp_length = 0.75,
36
- regularize_noise_weight = 1e5,
37
- verbose = False,
38
- device: torch.device
39
- ):
40
- assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
41
-
42
- def logprint(*args):
43
- if verbose:
44
- print(*args)
45
-
46
- G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore
47
-
48
- # Compute w stats.
49
- logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
50
- z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
51
- w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None) # [N, L, C]
52
- w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C]
53
- w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C]
54
- w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
55
-
56
- # Setup noise inputs.
57
- noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }
58
-
59
- # Load VGG16 feature detector.
60
- url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
61
- with dnnlib.util.open_url(url) as f:
62
- vgg16 = torch.jit.load(f).eval().to(device)
63
-
64
- # Features for target image.
65
- target_images = target.unsqueeze(0).to(device).to(torch.float32)
66
- if target_images.shape[2] > 256:
67
- target_images = F.interpolate(target_images, size=(256, 256), mode='area')
68
- target_features = vgg16(target_images, resize_images=False, return_lpips=True)
69
-
70
- w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
71
- w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
72
- optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)
73
-
74
- # Init noise.
75
- for buf in noise_bufs.values():
76
- buf[:] = torch.randn_like(buf)
77
- buf.requires_grad = True
78
-
79
- for step in range(num_steps):
80
- # Learning rate schedule.
81
- t = step / num_steps
82
- w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
83
- lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
84
- lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
85
- lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
86
- lr = initial_learning_rate * lr_ramp
87
- for param_group in optimizer.param_groups:
88
- param_group['lr'] = lr
89
-
90
- # Synth images from opt_w.
91
- w_noise = torch.randn_like(w_opt) * w_noise_scale
92
- ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
93
- synth_images = G.synthesis(ws, noise_mode='const')
94
-
95
- # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
96
- synth_images = (synth_images + 1) * (255/2)
97
- if synth_images.shape[2] > 256:
98
- synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
99
-
100
- # Features for synth images.
101
- synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
102
- dist = (target_features - synth_features).square().sum()
103
-
104
- # Noise regularization.
105
- reg_loss = 0.0
106
- for v in noise_bufs.values():
107
- noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
108
- while True:
109
- reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
110
- reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
111
- if noise.shape[2] <= 8:
112
- break
113
- noise = F.avg_pool2d(noise, kernel_size=2)
114
- loss = dist + reg_loss * regularize_noise_weight
115
-
116
- # Step
117
- optimizer.zero_grad(set_to_none=True)
118
- loss.backward()
119
- optimizer.step()
120
- logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')
121
-
122
- # Save projected W for each optimization step.
123
- w_out[step] = w_opt.detach()[0]
124
-
125
- # Normalize noise.
126
- with torch.no_grad():
127
- for buf in noise_bufs.values():
128
- buf -= buf.mean()
129
- buf *= buf.square().mean().rsqrt()
130
-
131
- return w_out.repeat([1, G.mapping.num_ws, 1])
132
-
133
- #----------------------------------------------------------------------------
134
-
135
- @click.command()
136
- @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
137
- @click.option('--target', 'target_fname', help='Target image file to project to', required=True, metavar='FILE')
138
- @click.option('--num-steps', help='Number of optimization steps', type=int, default=1000, show_default=True)
139
- @click.option('--seed', help='Random seed', type=int, default=303, show_default=True)
140
- @click.option('--save-video', help='Save an mp4 video of optimization progress', type=bool, default=True, show_default=True)
141
- @click.option('--outdir', help='Where to save the output images', required=True, metavar='DIR')
142
- def run_projection(
143
- network_pkl: str,
144
- target_fname: str,
145
- outdir: str,
146
- save_video: bool,
147
- seed: int,
148
- num_steps: int
149
- ):
150
- """Project given image to the latent space of pretrained network pickle.
151
-
152
- Examples:
153
-
154
- \b
155
- python projector.py --outdir=out --target=~/mytargetimg.png \\
156
- --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
157
- """
158
- np.random.seed(seed)
159
- torch.manual_seed(seed)
160
-
161
- # Load networks.
162
- print('Loading networks from "%s"...' % network_pkl)
163
- device = torch.device('cuda')
164
- with dnnlib.util.open_url(network_pkl) as fp:
165
- G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore
166
-
167
- # Load target image.
168
- target_pil = PIL.Image.open(target_fname).convert('RGB')
169
- w, h = target_pil.size
170
- s = min(w, h)
171
- target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
172
- target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
173
- target_uint8 = np.array(target_pil, dtype=np.uint8)
174
-
175
- # Optimize projection.
176
- start_time = perf_counter()
177
- projected_w_steps = project(
178
- G,
179
- target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device), # pylint: disable=not-callable
180
- num_steps=num_steps,
181
- device=device,
182
- verbose=True
183
- )
184
- print (f'Elapsed: {(perf_counter()-start_time):.1f} s')
185
-
186
- # Render debug output: optional video and projected image and W vector.
187
- os.makedirs(outdir, exist_ok=True)
188
- if save_video:
189
- video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
190
- print (f'Saving optimization progress video "{outdir}/proj.mp4"')
191
- for projected_w in projected_w_steps:
192
- synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
193
- synth_image = (synth_image + 1) * (255/2)
194
- synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
195
- video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
196
- video.close()
197
-
198
- # Save final projected frame and W vector.
199
- target_pil.save(f'{outdir}/target.png')
200
- projected_w = projected_w_steps[-1]
201
- synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
202
- synth_image = (synth_image + 1) * (255/2)
203
- synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
204
- PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png')
205
- np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())
206
-
207
- #----------------------------------------------------------------------------
208
-
209
- if __name__ == "__main__":
210
- run_projection() # pylint: disable=no-value-for-parameter
211
-
212
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
style_mixing.py DELETED
@@ -1,118 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Generate style mixing image matrix using pretrained network pickle."""
10
-
11
- import os
12
- import re
13
- from typing import List
14
-
15
- import click
16
- import dnnlib
17
- import numpy as np
18
- import PIL.Image
19
- import torch
20
-
21
- import legacy
22
-
23
- #----------------------------------------------------------------------------
24
-
25
- def num_range(s: str) -> List[int]:
26
- '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
27
-
28
- range_re = re.compile(r'^(\d+)-(\d+)$')
29
- m = range_re.match(s)
30
- if m:
31
- return list(range(int(m.group(1)), int(m.group(2))+1))
32
- vals = s.split(',')
33
- return [int(x) for x in vals]
34
-
35
- #----------------------------------------------------------------------------
36
-
37
- @click.command()
38
- @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
39
- @click.option('--rows', 'row_seeds', type=num_range, help='Random seeds to use for image rows', required=True)
40
- @click.option('--cols', 'col_seeds', type=num_range, help='Random seeds to use for image columns', required=True)
41
- @click.option('--styles', 'col_styles', type=num_range, help='Style layer range', default='0-6', show_default=True)
42
- @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
43
- @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
44
- @click.option('--outdir', type=str, required=True)
45
- def generate_style_mix(
46
- network_pkl: str,
47
- row_seeds: List[int],
48
- col_seeds: List[int],
49
- col_styles: List[int],
50
- truncation_psi: float,
51
- noise_mode: str,
52
- outdir: str
53
- ):
54
- """Generate images using pretrained network pickle.
55
-
56
- Examples:
57
-
58
- \b
59
- python style_mixing.py --outdir=out --rows=85,100,75,458,1500 --cols=55,821,1789,293 \\
60
- --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
61
- """
62
- print('Loading networks from "%s"...' % network_pkl)
63
- device = torch.device('cuda')
64
- with dnnlib.util.open_url(network_pkl) as f:
65
- G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
66
-
67
- os.makedirs(outdir, exist_ok=True)
68
-
69
- print('Generating W vectors...')
70
- all_seeds = list(set(row_seeds + col_seeds))
71
- all_z = np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])
72
- all_w = G.mapping(torch.from_numpy(all_z).to(device), None)
73
- w_avg = G.mapping.w_avg
74
- all_w = w_avg + (all_w - w_avg) * truncation_psi
75
- w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))}
76
-
77
- print('Generating images...')
78
- all_images = G.synthesis(all_w, noise_mode=noise_mode)
79
- all_images = (all_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
80
- image_dict = {(seed, seed): image for seed, image in zip(all_seeds, list(all_images))}
81
-
82
- print('Generating style-mixed images...')
83
- for row_seed in row_seeds:
84
- for col_seed in col_seeds:
85
- w = w_dict[row_seed].clone()
86
- w[col_styles] = w_dict[col_seed][col_styles]
87
- image = G.synthesis(w[np.newaxis], noise_mode=noise_mode)
88
- image = (image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
89
- image_dict[(row_seed, col_seed)] = image[0].cpu().numpy()
90
-
91
- print('Saving images...')
92
- os.makedirs(outdir, exist_ok=True)
93
- for (row_seed, col_seed), image in image_dict.items():
94
- PIL.Image.fromarray(image, 'RGB').save(f'{outdir}/{row_seed}-{col_seed}.png')
95
-
96
- print('Saving image grid...')
97
- W = G.img_resolution
98
- H = G.img_resolution
99
- canvas = PIL.Image.new('RGB', (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), 'black')
100
- for row_idx, row_seed in enumerate([0] + row_seeds):
101
- for col_idx, col_seed in enumerate([0] + col_seeds):
102
- if row_idx == 0 and col_idx == 0:
103
- continue
104
- key = (row_seed, col_seed)
105
- if row_idx == 0:
106
- key = (col_seed, col_seed)
107
- if col_idx == 0:
108
- key = (row_seed, row_seed)
109
- canvas.paste(PIL.Image.fromarray(image_dict[key], 'RGB'), (W * col_idx, H * row_idx))
110
- canvas.save(f'{outdir}/grid.png')
111
-
112
-
113
- #----------------------------------------------------------------------------
114
-
115
- if __name__ == "__main__":
116
- generate_style_mix() # pylint: disable=no-value-for-parameter
117
-
118
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch_utils/__init__.py DELETED
@@ -1,9 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- # empty
 
 
 
 
 
 
 
 
 
 
torch_utils/custom_ops.py DELETED
@@ -1,126 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- import os
10
- import glob
11
- import torch
12
- import torch.utils.cpp_extension
13
- import importlib
14
- import hashlib
15
- import shutil
16
- from pathlib import Path
17
-
18
- from torch.utils.file_baton import FileBaton
19
-
20
- #----------------------------------------------------------------------------
21
- # Global options.
22
-
23
- verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
24
-
25
- #----------------------------------------------------------------------------
26
- # Internal helper funcs.
27
-
28
- def _find_compiler_bindir():
29
- patterns = [
30
- 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
31
- 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
32
- 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
33
- 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
34
- ]
35
- for pattern in patterns:
36
- matches = sorted(glob.glob(pattern))
37
- if len(matches):
38
- return matches[-1]
39
- return None
40
-
41
- #----------------------------------------------------------------------------
42
- # Main entry point for compiling and loading C++/CUDA plugins.
43
-
44
- _cached_plugins = dict()
45
-
46
- def get_plugin(module_name, sources, **build_kwargs):
47
- assert verbosity in ['none', 'brief', 'full']
48
-
49
- # Already cached?
50
- if module_name in _cached_plugins:
51
- return _cached_plugins[module_name]
52
-
53
- # Print status.
54
- if verbosity == 'full':
55
- print(f'Setting up PyTorch plugin "{module_name}"...')
56
- elif verbosity == 'brief':
57
- print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
58
-
59
- try: # pylint: disable=too-many-nested-blocks
60
- # Make sure we can find the necessary compiler binaries.
61
- if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
62
- compiler_bindir = _find_compiler_bindir()
63
- if compiler_bindir is None:
64
- raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
65
- os.environ['PATH'] += ';' + compiler_bindir
66
-
67
- # Compile and load.
68
- verbose_build = (verbosity == 'full')
69
-
70
- # Incremental build md5sum trickery. Copies all the input source files
71
- # into a cached build directory under a combined md5 digest of the input
72
- # source files. Copying is done only if the combined digest has changed.
73
- # This keeps input file timestamps and filenames the same as in previous
74
- # extension builds, allowing for fast incremental rebuilds.
75
- #
76
- # This optimization is done only in case all the source files reside in
77
- # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
78
- # environment variable is set (we take this as a signal that the user
79
- # actually cares about this.)
80
- source_dirs_set = set(os.path.dirname(source) for source in sources)
81
- if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
82
- all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
83
-
84
- # Compute a combined hash digest for all source files in the same
85
- # custom op directory (usually .cu, .cpp, .py and .h files).
86
- hash_md5 = hashlib.md5()
87
- for src in all_source_files:
88
- with open(src, 'rb') as f:
89
- hash_md5.update(f.read())
90
- build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
91
- digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
92
-
93
- if not os.path.isdir(digest_build_dir):
94
- os.makedirs(digest_build_dir, exist_ok=True)
95
- baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
96
- if baton.try_acquire():
97
- try:
98
- for src in all_source_files:
99
- shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
100
- finally:
101
- baton.release()
102
- else:
103
- # Someone else is copying source files under the digest dir,
104
- # wait until done and continue.
105
- baton.wait()
106
- digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
107
- torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
108
- verbose=verbose_build, sources=digest_sources, **build_kwargs)
109
- else:
110
- torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
111
- module = importlib.import_module(module_name)
112
-
113
- except:
114
- if verbosity == 'brief':
115
- print('Failed!')
116
- raise
117
-
118
- # Print status and add to cache.
119
- if verbosity == 'full':
120
- print(f'Done setting up PyTorch plugin "{module_name}".')
121
- elif verbosity == 'brief':
122
- print('Done.')
123
- _cached_plugins[module_name] = module
124
- return module
125
-
126
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch_utils/misc.py DELETED
@@ -1,262 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- import re
10
- import contextlib
11
- import numpy as np
12
- import torch
13
- import warnings
14
- import dnnlib
15
-
16
- #----------------------------------------------------------------------------
17
- # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
18
- # same constant is used multiple times.
19
-
20
- _constant_cache = dict()
21
-
22
- def constant(value, shape=None, dtype=None, device=None, memory_format=None):
23
- value = np.asarray(value)
24
- if shape is not None:
25
- shape = tuple(shape)
26
- if dtype is None:
27
- dtype = torch.get_default_dtype()
28
- if device is None:
29
- device = torch.device('cpu')
30
- if memory_format is None:
31
- memory_format = torch.contiguous_format
32
-
33
- key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
34
- tensor = _constant_cache.get(key, None)
35
- if tensor is None:
36
- tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
37
- if shape is not None:
38
- tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
39
- tensor = tensor.contiguous(memory_format=memory_format)
40
- _constant_cache[key] = tensor
41
- return tensor
42
-
43
- #----------------------------------------------------------------------------
44
- # Replace NaN/Inf with specified numerical values.
45
-
46
- try:
47
- nan_to_num = torch.nan_to_num # 1.8.0a0
48
- except AttributeError:
49
- def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
50
- assert isinstance(input, torch.Tensor)
51
- if posinf is None:
52
- posinf = torch.finfo(input.dtype).max
53
- if neginf is None:
54
- neginf = torch.finfo(input.dtype).min
55
- assert nan == 0
56
- return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
57
-
58
- #----------------------------------------------------------------------------
59
- # Symbolic assert.
60
-
61
- try:
62
- symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
63
- except AttributeError:
64
- symbolic_assert = torch.Assert # 1.7.0
65
-
66
- #----------------------------------------------------------------------------
67
- # Context manager to suppress known warnings in torch.jit.trace().
68
-
69
- class suppress_tracer_warnings(warnings.catch_warnings):
70
- def __enter__(self):
71
- super().__enter__()
72
- warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
73
- return self
74
-
75
- #----------------------------------------------------------------------------
76
- # Assert that the shape of a tensor matches the given list of integers.
77
- # None indicates that the size of a dimension is allowed to vary.
78
- # Performs symbolic assertion when used in torch.jit.trace().
79
-
80
- def assert_shape(tensor, ref_shape):
81
- if tensor.ndim != len(ref_shape):
82
- raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
83
- for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
84
- if ref_size is None:
85
- pass
86
- elif isinstance(ref_size, torch.Tensor):
87
- with suppress_tracer_warnings(): # as_tensor results are registered as constants
88
- symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
89
- elif isinstance(size, torch.Tensor):
90
- with suppress_tracer_warnings(): # as_tensor results are registered as constants
91
- symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
92
- elif size != ref_size:
93
- raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
94
-
95
- #----------------------------------------------------------------------------
96
- # Function decorator that calls torch.autograd.profiler.record_function().
97
-
98
- def profiled_function(fn):
99
- def decorator(*args, **kwargs):
100
- with torch.autograd.profiler.record_function(fn.__name__):
101
- return fn(*args, **kwargs)
102
- decorator.__name__ = fn.__name__
103
- return decorator
104
-
105
- #----------------------------------------------------------------------------
106
- # Sampler for torch.utils.data.DataLoader that loops over the dataset
107
- # indefinitely, shuffling items as it goes.
108
-
109
- class InfiniteSampler(torch.utils.data.Sampler):
110
- def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
111
- assert len(dataset) > 0
112
- assert num_replicas > 0
113
- assert 0 <= rank < num_replicas
114
- assert 0 <= window_size <= 1
115
- super().__init__(dataset)
116
- self.dataset = dataset
117
- self.rank = rank
118
- self.num_replicas = num_replicas
119
- self.shuffle = shuffle
120
- self.seed = seed
121
- self.window_size = window_size
122
-
123
- def __iter__(self):
124
- order = np.arange(len(self.dataset))
125
- rnd = None
126
- window = 0
127
- if self.shuffle:
128
- rnd = np.random.RandomState(self.seed)
129
- rnd.shuffle(order)
130
- window = int(np.rint(order.size * self.window_size))
131
-
132
- idx = 0
133
- while True:
134
- i = idx % order.size
135
- if idx % self.num_replicas == self.rank:
136
- yield order[i]
137
- if window >= 2:
138
- j = (i - rnd.randint(window)) % order.size
139
- order[i], order[j] = order[j], order[i]
140
- idx += 1
141
-
142
- #----------------------------------------------------------------------------
143
- # Utilities for operating with torch.nn.Module parameters and buffers.
144
-
145
- def params_and_buffers(module):
146
- assert isinstance(module, torch.nn.Module)
147
- return list(module.parameters()) + list(module.buffers())
148
-
149
- def named_params_and_buffers(module):
150
- assert isinstance(module, torch.nn.Module)
151
- return list(module.named_parameters()) + list(module.named_buffers())
152
-
153
- def copy_params_and_buffers(src_module, dst_module, require_all=False):
154
- assert isinstance(src_module, torch.nn.Module)
155
- assert isinstance(dst_module, torch.nn.Module)
156
- src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
157
- for name, tensor in named_params_and_buffers(dst_module):
158
- assert (name in src_tensors) or (not require_all)
159
- if name in src_tensors:
160
- tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
161
-
162
- #----------------------------------------------------------------------------
163
- # Context manager for easily enabling/disabling DistributedDataParallel
164
- # synchronization.
165
-
166
- @contextlib.contextmanager
167
- def ddp_sync(module, sync):
168
- assert isinstance(module, torch.nn.Module)
169
- if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
170
- yield
171
- else:
172
- with module.no_sync():
173
- yield
174
-
175
- #----------------------------------------------------------------------------
176
- # Check DistributedDataParallel consistency across processes.
177
-
178
- def check_ddp_consistency(module, ignore_regex=None):
179
- assert isinstance(module, torch.nn.Module)
180
- for name, tensor in named_params_and_buffers(module):
181
- fullname = type(module).__name__ + '.' + name
182
- if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
183
- continue
184
- tensor = tensor.detach()
185
- other = tensor.clone()
186
- torch.distributed.broadcast(tensor=other, src=0)
187
- assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
188
-
189
- #----------------------------------------------------------------------------
190
- # Print summary table of module hierarchy.
191
-
192
- def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
193
- assert isinstance(module, torch.nn.Module)
194
- assert not isinstance(module, torch.jit.ScriptModule)
195
- assert isinstance(inputs, (tuple, list))
196
-
197
- # Register hooks.
198
- entries = []
199
- nesting = [0]
200
- def pre_hook(_mod, _inputs):
201
- nesting[0] += 1
202
- def post_hook(mod, _inputs, outputs):
203
- nesting[0] -= 1
204
- if nesting[0] <= max_nesting:
205
- outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
206
- outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
207
- entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
208
- hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
209
- hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
210
-
211
- # Run module.
212
- outputs = module(*inputs)
213
- for hook in hooks:
214
- hook.remove()
215
-
216
- # Identify unique outputs, parameters, and buffers.
217
- tensors_seen = set()
218
- for e in entries:
219
- e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
220
- e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
221
- e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
222
- tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
223
-
224
- # Filter out redundant entries.
225
- if skip_redundant:
226
- entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
227
-
228
- # Construct table.
229
- rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
230
- rows += [['---'] * len(rows[0])]
231
- param_total = 0
232
- buffer_total = 0
233
- submodule_names = {mod: name for name, mod in module.named_modules()}
234
- for e in entries:
235
- name = '<top-level>' if e.mod is module else submodule_names[e.mod]
236
- param_size = sum(t.numel() for t in e.unique_params)
237
- buffer_size = sum(t.numel() for t in e.unique_buffers)
238
- output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs]
239
- output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
240
- rows += [[
241
- name + (':0' if len(e.outputs) >= 2 else ''),
242
- str(param_size) if param_size else '-',
243
- str(buffer_size) if buffer_size else '-',
244
- (output_shapes + ['-'])[0],
245
- (output_dtypes + ['-'])[0],
246
- ]]
247
- for idx in range(1, len(e.outputs)):
248
- rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
249
- param_total += param_size
250
- buffer_total += buffer_size
251
- rows += [['---'] * len(rows[0])]
252
- rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
253
-
254
- # Print table.
255
- widths = [max(len(cell) for cell in column) for column in zip(*rows)]
256
- print()
257
- for row in rows:
258
- print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
259
- print()
260
- return outputs
261
-
262
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch_utils/ops/__init__.py DELETED
@@ -1,9 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- # empty
 
 
 
 
 
 
 
 
 
 
torch_utils/ops/bias_act.cpp DELETED
@@ -1,99 +0,0 @@
1
- // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- //
3
- // NVIDIA CORPORATION and its licensors retain all intellectual property
4
- // and proprietary rights in and to this software, related documentation
5
- // and any modifications thereto. Any use, reproduction, disclosure or
6
- // distribution of this software and related documentation without an express
7
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- #include <torch/extension.h>
10
- #include <ATen/cuda/CUDAContext.h>
11
- #include <c10/cuda/CUDAGuard.h>
12
- #include "bias_act.h"
13
-
14
- //------------------------------------------------------------------------
15
-
16
- static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17
- {
18
- if (x.dim() != y.dim())
19
- return false;
20
- for (int64_t i = 0; i < x.dim(); i++)
21
- {
22
- if (x.size(i) != y.size(i))
23
- return false;
24
- if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25
- return false;
26
- }
27
- return true;
28
- }
29
-
30
- //------------------------------------------------------------------------
31
-
32
- static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33
- {
34
- // Validate arguments.
35
- TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36
- TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37
- TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38
- TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39
- TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40
- TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41
- TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42
- TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43
- TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44
- TORCH_CHECK(grad >= 0, "grad must be non-negative");
45
-
46
- // Validate layout.
47
- TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48
- TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49
- TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50
- TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51
- TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52
-
53
- // Create output tensor.
54
- const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55
- torch::Tensor y = torch::empty_like(x);
56
- TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57
-
58
- // Initialize CUDA kernel parameters.
59
- bias_act_kernel_params p;
60
- p.x = x.data_ptr();
61
- p.b = (b.numel()) ? b.data_ptr() : NULL;
62
- p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63
- p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64
- p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65
- p.y = y.data_ptr();
66
- p.grad = grad;
67
- p.act = act;
68
- p.alpha = alpha;
69
- p.gain = gain;
70
- p.clamp = clamp;
71
- p.sizeX = (int)x.numel();
72
- p.sizeB = (int)b.numel();
73
- p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74
-
75
- // Choose CUDA kernel.
76
- void* kernel;
77
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78
- {
79
- kernel = choose_bias_act_kernel<scalar_t>(p);
80
- });
81
- TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82
-
83
- // Launch CUDA kernel.
84
- p.loopX = 4;
85
- int blockSize = 4 * 32;
86
- int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87
- void* args[] = {&p};
88
- AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89
- return y;
90
- }
91
-
92
- //------------------------------------------------------------------------
93
-
94
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95
- {
96
- m.def("bias_act", &bias_act);
97
- }
98
-
99
- //------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch_utils/ops/bias_act.cu DELETED
@@ -1,173 +0,0 @@
1
- // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- //
3
- // NVIDIA CORPORATION and its licensors retain all intellectual property
4
- // and proprietary rights in and to this software, related documentation
5
- // and any modifications thereto. Any use, reproduction, disclosure or
6
- // distribution of this software and related documentation without an express
7
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- #include <c10/util/Half.h>
10
- #include "bias_act.h"
11
-
12
- //------------------------------------------------------------------------
13
- // Helpers.
14
-
15
- template <class T> struct InternalType;
16
- template <> struct InternalType<double> { typedef double scalar_t; };
17
- template <> struct InternalType<float> { typedef float scalar_t; };
18
- template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
-
20
- //------------------------------------------------------------------------
21
- // CUDA kernel.
22
-
23
- template <class T, int A>
24
- __global__ void bias_act_kernel(bias_act_kernel_params p)
25
- {
26
- typedef typename InternalType<T>::scalar_t scalar_t;
27
- int G = p.grad;
28
- scalar_t alpha = (scalar_t)p.alpha;
29
- scalar_t gain = (scalar_t)p.gain;
30
- scalar_t clamp = (scalar_t)p.clamp;
31
- scalar_t one = (scalar_t)1;
32
- scalar_t two = (scalar_t)2;
33
- scalar_t expRange = (scalar_t)80;
34
- scalar_t halfExpRange = (scalar_t)40;
35
- scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36
- scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37
-
38
- // Loop over elements.
39
- int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40
- for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41
- {
42
- // Load.
43
- scalar_t x = (scalar_t)((const T*)p.x)[xi];
44
- scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45
- scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46
- scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47
- scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48
- scalar_t yy = (gain != 0) ? yref / gain : 0;
49
- scalar_t y = 0;
50
-
51
- // Apply bias.
52
- ((G == 0) ? x : xref) += b;
53
-
54
- // linear
55
- if (A == 1)
56
- {
57
- if (G == 0) y = x;
58
- if (G == 1) y = x;
59
- }
60
-
61
- // relu
62
- if (A == 2)
63
- {
64
- if (G == 0) y = (x > 0) ? x : 0;
65
- if (G == 1) y = (yy > 0) ? x : 0;
66
- }
67
-
68
- // lrelu
69
- if (A == 3)
70
- {
71
- if (G == 0) y = (x > 0) ? x : x * alpha;
72
- if (G == 1) y = (yy > 0) ? x : x * alpha;
73
- }
74
-
75
- // tanh
76
- if (A == 4)
77
- {
78
- if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79
- if (G == 1) y = x * (one - yy * yy);
80
- if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81
- }
82
-
83
- // sigmoid
84
- if (A == 5)
85
- {
86
- if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87
- if (G == 1) y = x * yy * (one - yy);
88
- if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89
- }
90
-
91
- // elu
92
- if (A == 6)
93
- {
94
- if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95
- if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96
- if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97
- }
98
-
99
- // selu
100
- if (A == 7)
101
- {
102
- if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103
- if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104
- if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105
- }
106
-
107
- // softplus
108
- if (A == 8)
109
- {
110
- if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111
- if (G == 1) y = x * (one - exp(-yy));
112
- if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113
- }
114
-
115
- // swish
116
- if (A == 9)
117
- {
118
- if (G == 0)
119
- y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120
- else
121
- {
122
- scalar_t c = exp(xref);
123
- scalar_t d = c + one;
124
- if (G == 1)
125
- y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126
- else
127
- y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128
- yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129
- }
130
- }
131
-
132
- // Apply gain.
133
- y *= gain * dy;
134
-
135
- // Clamp.
136
- if (clamp >= 0)
137
- {
138
- if (G == 0)
139
- y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140
- else
141
- y = (yref > -clamp & yref < clamp) ? y : 0;
142
- }
143
-
144
- // Store.
145
- ((T*)p.y)[xi] = (T)y;
146
- }
147
- }
148
-
149
- //------------------------------------------------------------------------
150
- // CUDA kernel selection.
151
-
152
- template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153
- {
154
- if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155
- if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156
- if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157
- if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158
- if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159
- if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160
- if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161
- if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162
- if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163
- return NULL;
164
- }
165
-
166
- //------------------------------------------------------------------------
167
- // Template specializations.
168
-
169
- template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
170
- template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
171
- template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172
-
173
- //------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch_utils/ops/bias_act.h DELETED
@@ -1,38 +0,0 @@
1
- // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- //
3
- // NVIDIA CORPORATION and its licensors retain all intellectual property
4
- // and proprietary rights in and to this software, related documentation
5
- // and any modifications thereto. Any use, reproduction, disclosure or
6
- // distribution of this software and related documentation without an express
7
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- //------------------------------------------------------------------------
10
- // CUDA kernel parameters.
11
-
12
- struct bias_act_kernel_params
13
- {
14
- const void* x; // [sizeX]
15
- const void* b; // [sizeB] or NULL
16
- const void* xref; // [sizeX] or NULL
17
- const void* yref; // [sizeX] or NULL
18
- const void* dy; // [sizeX] or NULL
19
- void* y; // [sizeX]
20
-
21
- int grad;
22
- int act;
23
- float alpha;
24
- float gain;
25
- float clamp;
26
-
27
- int sizeX;
28
- int sizeB;
29
- int stepB;
30
- int loopX;
31
- };
32
-
33
- //------------------------------------------------------------------------
34
- // CUDA kernel selection.
35
-
36
- template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37
-
38
- //------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch_utils/ops/bias_act.py DELETED
@@ -1,212 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Custom PyTorch ops for efficient bias and activation."""
10
-
11
- import os
12
- import warnings
13
- import numpy as np
14
- import torch
15
- import dnnlib
16
- import traceback
17
-
18
- from .. import custom_ops
19
- from .. import misc
20
-
21
- #----------------------------------------------------------------------------
22
-
23
- activation_funcs = {
24
- 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
25
- 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
26
- 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
27
- 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
28
- 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
29
- 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
30
- 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
31
- 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
32
- 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
33
- }
34
-
35
- #----------------------------------------------------------------------------
36
-
37
- _inited = False
38
- _plugin = None
39
- _null_tensor = torch.empty([0])
40
-
41
- def _init():
42
- global _inited, _plugin
43
- if not _inited:
44
- _inited = True
45
- sources = ['bias_act.cpp', 'bias_act.cu']
46
- sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
47
- try:
48
- _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
49
- except:
50
- warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
51
- return _plugin is not None
52
-
53
- #----------------------------------------------------------------------------
54
-
55
- def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
56
- r"""Fused bias and activation function.
57
-
58
- Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
59
- and scales the result by `gain`. Each of the steps is optional. In most cases,
60
- the fused op is considerably more efficient than performing the same calculation
61
- using standard PyTorch ops. It supports first and second order gradients,
62
- but not third order gradients.
63
-
64
- Args:
65
- x: Input activation tensor. Can be of any shape.
66
- b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
67
- as `x`. The shape must be known, and it must match the dimension of `x`
68
- corresponding to `dim`.
69
- dim: The dimension in `x` corresponding to the elements of `b`.
70
- The value of `dim` is ignored if `b` is not specified.
71
- act: Name of the activation function to evaluate, or `"linear"` to disable.
72
- Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
73
- See `activation_funcs` for a full list. `None` is not allowed.
74
- alpha: Shape parameter for the activation function, or `None` to use the default.
75
- gain: Scaling factor for the output tensor, or `None` to use default.
76
- See `activation_funcs` for the default scaling of each activation function.
77
- If unsure, consider specifying 1.
78
- clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
79
- the clamping (default).
80
- impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
81
-
82
- Returns:
83
- Tensor of the same shape and datatype as `x`.
84
- """
85
- assert isinstance(x, torch.Tensor)
86
- assert impl in ['ref', 'cuda']
87
- if impl == 'cuda' and x.device.type == 'cuda' and _init():
88
- return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
89
- return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
90
-
91
- #----------------------------------------------------------------------------
92
-
93
- @misc.profiled_function
94
- def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
95
- """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
96
- """
97
- assert isinstance(x, torch.Tensor)
98
- assert clamp is None or clamp >= 0
99
- spec = activation_funcs[act]
100
- alpha = float(alpha if alpha is not None else spec.def_alpha)
101
- gain = float(gain if gain is not None else spec.def_gain)
102
- clamp = float(clamp if clamp is not None else -1)
103
-
104
- # Add bias.
105
- if b is not None:
106
- assert isinstance(b, torch.Tensor) and b.ndim == 1
107
- assert 0 <= dim < x.ndim
108
- assert b.shape[0] == x.shape[dim]
109
- x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
110
-
111
- # Evaluate activation function.
112
- alpha = float(alpha)
113
- x = spec.func(x, alpha=alpha)
114
-
115
- # Scale by gain.
116
- gain = float(gain)
117
- if gain != 1:
118
- x = x * gain
119
-
120
- # Clamp.
121
- if clamp >= 0:
122
- x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
123
- return x
124
-
125
- #----------------------------------------------------------------------------
126
-
127
- _bias_act_cuda_cache = dict()
128
-
129
- def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
130
- """Fast CUDA implementation of `bias_act()` using custom ops.
131
- """
132
- # Parse arguments.
133
- assert clamp is None or clamp >= 0
134
- spec = activation_funcs[act]
135
- alpha = float(alpha if alpha is not None else spec.def_alpha)
136
- gain = float(gain if gain is not None else spec.def_gain)
137
- clamp = float(clamp if clamp is not None else -1)
138
-
139
- # Lookup from cache.
140
- key = (dim, act, alpha, gain, clamp)
141
- if key in _bias_act_cuda_cache:
142
- return _bias_act_cuda_cache[key]
143
-
144
- # Forward op.
145
- class BiasActCuda(torch.autograd.Function):
146
- @staticmethod
147
- def forward(ctx, x, b): # pylint: disable=arguments-differ
148
- ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
149
- x = x.contiguous(memory_format=ctx.memory_format)
150
- b = b.contiguous() if b is not None else _null_tensor
151
- y = x
152
- if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
153
- y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
154
- ctx.save_for_backward(
155
- x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
156
- b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
157
- y if 'y' in spec.ref else _null_tensor)
158
- return y
159
-
160
- @staticmethod
161
- def backward(ctx, dy): # pylint: disable=arguments-differ
162
- dy = dy.contiguous(memory_format=ctx.memory_format)
163
- x, b, y = ctx.saved_tensors
164
- dx = None
165
- db = None
166
-
167
- if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
168
- dx = dy
169
- if act != 'linear' or gain != 1 or clamp >= 0:
170
- dx = BiasActCudaGrad.apply(dy, x, b, y)
171
-
172
- if ctx.needs_input_grad[1]:
173
- db = dx.sum([i for i in range(dx.ndim) if i != dim])
174
-
175
- return dx, db
176
-
177
- # Backward op.
178
- class BiasActCudaGrad(torch.autograd.Function):
179
- @staticmethod
180
- def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
181
- ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
182
- dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
183
- ctx.save_for_backward(
184
- dy if spec.has_2nd_grad else _null_tensor,
185
- x, b, y)
186
- return dx
187
-
188
- @staticmethod
189
- def backward(ctx, d_dx): # pylint: disable=arguments-differ
190
- d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
191
- dy, x, b, y = ctx.saved_tensors
192
- d_dy = None
193
- d_x = None
194
- d_b = None
195
- d_y = None
196
-
197
- if ctx.needs_input_grad[0]:
198
- d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
199
-
200
- if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
201
- d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
202
-
203
- if spec.has_2nd_grad and ctx.needs_input_grad[2]:
204
- d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
205
-
206
- return d_dy, d_x, d_b, d_y
207
-
208
- # Add to cache.
209
- _bias_act_cuda_cache[key] = BiasActCuda
210
- return BiasActCuda
211
-
212
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch_utils/ops/conv2d_gradfix.py DELETED
@@ -1,170 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Custom replacement for `torch.nn.functional.conv2d` that supports
10
- arbitrarily high order gradients with zero performance penalty."""
11
-
12
- import warnings
13
- import contextlib
14
- import torch
15
-
16
- # pylint: disable=redefined-builtin
17
- # pylint: disable=arguments-differ
18
- # pylint: disable=protected-access
19
-
20
- #----------------------------------------------------------------------------
21
-
22
- enabled = False # Enable the custom op by setting this to true.
23
- weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
24
-
25
- @contextlib.contextmanager
26
- def no_weight_gradients():
27
- global weight_gradients_disabled
28
- old = weight_gradients_disabled
29
- weight_gradients_disabled = True
30
- yield
31
- weight_gradients_disabled = old
32
-
33
- #----------------------------------------------------------------------------
34
-
35
- def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
36
- if _should_use_custom_op(input):
37
- return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
38
- return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
39
-
40
- def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
41
- if _should_use_custom_op(input):
42
- return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
43
- return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
44
-
45
- #----------------------------------------------------------------------------
46
-
47
- def _should_use_custom_op(input):
48
- assert isinstance(input, torch.Tensor)
49
- if (not enabled) or (not torch.backends.cudnn.enabled):
50
- return False
51
- if input.device.type != 'cuda':
52
- return False
53
- if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
54
- return True
55
- warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
56
- return False
57
-
58
- def _tuple_of_ints(xs, ndim):
59
- xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
60
- assert len(xs) == ndim
61
- assert all(isinstance(x, int) for x in xs)
62
- return xs
63
-
64
- #----------------------------------------------------------------------------
65
-
66
- _conv2d_gradfix_cache = dict()
67
-
68
- def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
69
- # Parse arguments.
70
- ndim = 2
71
- weight_shape = tuple(weight_shape)
72
- stride = _tuple_of_ints(stride, ndim)
73
- padding = _tuple_of_ints(padding, ndim)
74
- output_padding = _tuple_of_ints(output_padding, ndim)
75
- dilation = _tuple_of_ints(dilation, ndim)
76
-
77
- # Lookup from cache.
78
- key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
79
- if key in _conv2d_gradfix_cache:
80
- return _conv2d_gradfix_cache[key]
81
-
82
- # Validate arguments.
83
- assert groups >= 1
84
- assert len(weight_shape) == ndim + 2
85
- assert all(stride[i] >= 1 for i in range(ndim))
86
- assert all(padding[i] >= 0 for i in range(ndim))
87
- assert all(dilation[i] >= 0 for i in range(ndim))
88
- if not transpose:
89
- assert all(output_padding[i] == 0 for i in range(ndim))
90
- else: # transpose
91
- assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
92
-
93
- # Helpers.
94
- common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
95
- def calc_output_padding(input_shape, output_shape):
96
- if transpose:
97
- return [0, 0]
98
- return [
99
- input_shape[i + 2]
100
- - (output_shape[i + 2] - 1) * stride[i]
101
- - (1 - 2 * padding[i])
102
- - dilation[i] * (weight_shape[i + 2] - 1)
103
- for i in range(ndim)
104
- ]
105
-
106
- # Forward & backward.
107
- class Conv2d(torch.autograd.Function):
108
- @staticmethod
109
- def forward(ctx, input, weight, bias):
110
- assert weight.shape == weight_shape
111
- if not transpose:
112
- output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
113
- else: # transpose
114
- output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
115
- ctx.save_for_backward(input, weight)
116
- return output
117
-
118
- @staticmethod
119
- def backward(ctx, grad_output):
120
- input, weight = ctx.saved_tensors
121
- grad_input = None
122
- grad_weight = None
123
- grad_bias = None
124
-
125
- if ctx.needs_input_grad[0]:
126
- p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
127
- grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
128
- assert grad_input.shape == input.shape
129
-
130
- if ctx.needs_input_grad[1] and not weight_gradients_disabled:
131
- grad_weight = Conv2dGradWeight.apply(grad_output, input)
132
- assert grad_weight.shape == weight_shape
133
-
134
- if ctx.needs_input_grad[2]:
135
- grad_bias = grad_output.sum([0, 2, 3])
136
-
137
- return grad_input, grad_weight, grad_bias
138
-
139
- # Gradient with respect to the weights.
140
- class Conv2dGradWeight(torch.autograd.Function):
141
- @staticmethod
142
- def forward(ctx, grad_output, input):
143
- op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
144
- flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
145
- grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
146
- assert grad_weight.shape == weight_shape
147
- ctx.save_for_backward(grad_output, input)
148
- return grad_weight
149
-
150
- @staticmethod
151
- def backward(ctx, grad2_grad_weight):
152
- grad_output, input = ctx.saved_tensors
153
- grad2_grad_output = None
154
- grad2_input = None
155
-
156
- if ctx.needs_input_grad[0]:
157
- grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
158
- assert grad2_grad_output.shape == grad_output.shape
159
-
160
- if ctx.needs_input_grad[1]:
161
- p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
162
- grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
163
- assert grad2_input.shape == input.shape
164
-
165
- return grad2_grad_output, grad2_input
166
-
167
- _conv2d_gradfix_cache[key] = Conv2d
168
- return Conv2d
169
-
170
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch_utils/ops/conv2d_resample.py DELETED
@@ -1,156 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """2D convolution with optional up/downsampling."""
10
-
11
- import torch
12
-
13
- from .. import misc
14
- from . import conv2d_gradfix
15
- from . import upfirdn2d
16
- from .upfirdn2d import _parse_padding
17
- from .upfirdn2d import _get_filter_size
18
-
19
- #----------------------------------------------------------------------------
20
-
21
- def _get_weight_shape(w):
22
- with misc.suppress_tracer_warnings(): # this value will be treated as a constant
23
- shape = [int(sz) for sz in w.shape]
24
- misc.assert_shape(w, shape)
25
- return shape
26
-
27
- #----------------------------------------------------------------------------
28
-
29
- def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
30
- """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
31
- """
32
- out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
33
-
34
- # Flip weight if requested.
35
- if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
36
- w = w.flip([2, 3])
37
-
38
- # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
39
- # 1x1 kernel + memory_format=channels_last + less than 64 channels.
40
- if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
41
- if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
42
- if out_channels <= 4 and groups == 1:
43
- in_shape = x.shape
44
- x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
45
- x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
46
- else:
47
- x = x.to(memory_format=torch.contiguous_format)
48
- w = w.to(memory_format=torch.contiguous_format)
49
- x = conv2d_gradfix.conv2d(x, w, groups=groups)
50
- return x.to(memory_format=torch.channels_last)
51
-
52
- # Otherwise => execute using conv2d_gradfix.
53
- op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
54
- return op(x, w, stride=stride, padding=padding, groups=groups)
55
-
56
- #----------------------------------------------------------------------------
57
-
58
- @misc.profiled_function
59
- def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
60
- r"""2D convolution with optional up/downsampling.
61
-
62
- Padding is performed only once at the beginning, not between the operations.
63
-
64
- Args:
65
- x: Input tensor of shape
66
- `[batch_size, in_channels, in_height, in_width]`.
67
- w: Weight tensor of shape
68
- `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
69
- f: Low-pass filter for up/downsampling. Must be prepared beforehand by
70
- calling upfirdn2d.setup_filter(). None = identity (default).
71
- up: Integer upsampling factor (default: 1).
72
- down: Integer downsampling factor (default: 1).
73
- padding: Padding with respect to the upsampled image. Can be a single number
74
- or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
75
- (default: 0).
76
- groups: Split input channels into N groups (default: 1).
77
- flip_weight: False = convolution, True = correlation (default: True).
78
- flip_filter: False = convolution, True = correlation (default: False).
79
-
80
- Returns:
81
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
82
- """
83
- # Validate arguments.
84
- assert isinstance(x, torch.Tensor) and (x.ndim == 4)
85
- assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
86
- assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
87
- assert isinstance(up, int) and (up >= 1)
88
- assert isinstance(down, int) and (down >= 1)
89
- assert isinstance(groups, int) and (groups >= 1)
90
- out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
91
- fw, fh = _get_filter_size(f)
92
- px0, px1, py0, py1 = _parse_padding(padding)
93
-
94
- # Adjust padding to account for up/downsampling.
95
- if up > 1:
96
- px0 += (fw + up - 1) // 2
97
- px1 += (fw - up) // 2
98
- py0 += (fh + up - 1) // 2
99
- py1 += (fh - up) // 2
100
- if down > 1:
101
- px0 += (fw - down + 1) // 2
102
- px1 += (fw - down) // 2
103
- py0 += (fh - down + 1) // 2
104
- py1 += (fh - down) // 2
105
-
106
- # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
107
- if kw == 1 and kh == 1 and (down > 1 and up == 1):
108
- x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
109
- x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
110
- return x
111
-
112
- # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
113
- if kw == 1 and kh == 1 and (up > 1 and down == 1):
114
- x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
115
- x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
116
- return x
117
-
118
- # Fast path: downsampling only => use strided convolution.
119
- if down > 1 and up == 1:
120
- x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
121
- x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
122
- return x
123
-
124
- # Fast path: upsampling with optional downsampling => use transpose strided convolution.
125
- if up > 1:
126
- if groups == 1:
127
- w = w.transpose(0, 1)
128
- else:
129
- w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
130
- w = w.transpose(1, 2)
131
- w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
132
- px0 -= kw - 1
133
- px1 -= kw - up
134
- py0 -= kh - 1
135
- py1 -= kh - up
136
- pxt = max(min(-px0, -px1), 0)
137
- pyt = max(min(-py0, -py1), 0)
138
- x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
139
- x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
140
- if down > 1:
141
- x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
142
- return x
143
-
144
- # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
145
- if up == 1 and down == 1:
146
- if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
147
- return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
148
-
149
- # Fallback: Generic reference implementation.
150
- x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
151
- x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
152
- if down > 1:
153
- x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
154
- return x
155
-
156
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch_utils/ops/fma.py DELETED
@@ -1,60 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
10
-
11
- import torch
12
-
13
- #----------------------------------------------------------------------------
14
-
15
- def fma(a, b, c): # => a * b + c
16
- return _FusedMultiplyAdd.apply(a, b, c)
17
-
18
- #----------------------------------------------------------------------------
19
-
20
- class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
21
- @staticmethod
22
- def forward(ctx, a, b, c): # pylint: disable=arguments-differ
23
- out = torch.addcmul(c, a, b)
24
- ctx.save_for_backward(a, b)
25
- ctx.c_shape = c.shape
26
- return out
27
-
28
- @staticmethod
29
- def backward(ctx, dout): # pylint: disable=arguments-differ
30
- a, b = ctx.saved_tensors
31
- c_shape = ctx.c_shape
32
- da = None
33
- db = None
34
- dc = None
35
-
36
- if ctx.needs_input_grad[0]:
37
- da = _unbroadcast(dout * b, a.shape)
38
-
39
- if ctx.needs_input_grad[1]:
40
- db = _unbroadcast(dout * a, b.shape)
41
-
42
- if ctx.needs_input_grad[2]:
43
- dc = _unbroadcast(dout, c_shape)
44
-
45
- return da, db, dc
46
-
47
- #----------------------------------------------------------------------------
48
-
49
- def _unbroadcast(x, shape):
50
- extra_dims = x.ndim - len(shape)
51
- assert extra_dims >= 0
52
- dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
53
- if len(dim):
54
- x = x.sum(dim=dim, keepdim=True)
55
- if extra_dims:
56
- x = x.reshape(-1, *x.shape[extra_dims+1:])
57
- assert x.shape == shape
58
- return x
59
-
60
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch_utils/ops/grid_sample_gradfix.py DELETED
@@ -1,83 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Custom replacement for `torch.nn.functional.grid_sample` that
10
- supports arbitrarily high order gradients between the input and output.
11
- Only works on 2D images and assumes
12
- `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
13
-
14
- import warnings
15
- import torch
16
-
17
- # pylint: disable=redefined-builtin
18
- # pylint: disable=arguments-differ
19
- # pylint: disable=protected-access
20
-
21
- #----------------------------------------------------------------------------
22
-
23
- enabled = False # Enable the custom op by setting this to true.
24
-
25
- #----------------------------------------------------------------------------
26
-
27
- def grid_sample(input, grid):
28
- if _should_use_custom_op():
29
- return _GridSample2dForward.apply(input, grid)
30
- return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
31
-
32
- #----------------------------------------------------------------------------
33
-
34
- def _should_use_custom_op():
35
- if not enabled:
36
- return False
37
- if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
38
- return True
39
- warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
40
- return False
41
-
42
- #----------------------------------------------------------------------------
43
-
44
- class _GridSample2dForward(torch.autograd.Function):
45
- @staticmethod
46
- def forward(ctx, input, grid):
47
- assert input.ndim == 4
48
- assert grid.ndim == 4
49
- output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
50
- ctx.save_for_backward(input, grid)
51
- return output
52
-
53
- @staticmethod
54
- def backward(ctx, grad_output):
55
- input, grid = ctx.saved_tensors
56
- grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
57
- return grad_input, grad_grid
58
-
59
- #----------------------------------------------------------------------------
60
-
61
- class _GridSample2dBackward(torch.autograd.Function):
62
- @staticmethod
63
- def forward(ctx, grad_output, input, grid):
64
- op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
65
- grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
66
- ctx.save_for_backward(grid)
67
- return grad_input, grad_grid
68
-
69
- @staticmethod
70
- def backward(ctx, grad2_grad_input, grad2_grad_grid):
71
- _ = grad2_grad_grid # unused
72
- grid, = ctx.saved_tensors
73
- grad2_grad_output = None
74
- grad2_input = None
75
- grad2_grid = None
76
-
77
- if ctx.needs_input_grad[0]:
78
- grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
79
-
80
- assert not ctx.needs_input_grad[2]
81
- return grad2_grad_output, grad2_input, grad2_grid
82
-
83
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch_utils/ops/upfirdn2d.cpp DELETED
@@ -1,103 +0,0 @@
1
- // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- //
3
- // NVIDIA CORPORATION and its licensors retain all intellectual property
4
- // and proprietary rights in and to this software, related documentation
5
- // and any modifications thereto. Any use, reproduction, disclosure or
6
- // distribution of this software and related documentation without an express
7
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- #include <torch/extension.h>
10
- #include <ATen/cuda/CUDAContext.h>
11
- #include <c10/cuda/CUDAGuard.h>
12
- #include "upfirdn2d.h"
13
-
14
- //------------------------------------------------------------------------
15
-
16
- static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17
- {
18
- // Validate arguments.
19
- TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20
- TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21
- TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22
- TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23
- TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24
- TORCH_CHECK(x.dim() == 4, "x must be rank 4");
25
- TORCH_CHECK(f.dim() == 2, "f must be rank 2");
26
- TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
27
- TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
28
- TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
29
-
30
- // Create output tensor.
31
- const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
32
- int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
33
- int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
34
- TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
35
- torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
36
- TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
37
-
38
- // Initialize CUDA kernel parameters.
39
- upfirdn2d_kernel_params p;
40
- p.x = x.data_ptr();
41
- p.f = f.data_ptr<float>();
42
- p.y = y.data_ptr();
43
- p.up = make_int2(upx, upy);
44
- p.down = make_int2(downx, downy);
45
- p.pad0 = make_int2(padx0, pady0);
46
- p.flip = (flip) ? 1 : 0;
47
- p.gain = gain;
48
- p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
49
- p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
50
- p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
51
- p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
52
- p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
53
- p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
54
- p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
55
- p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
56
-
57
- // Choose CUDA kernel.
58
- upfirdn2d_kernel_spec spec;
59
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
60
- {
61
- spec = choose_upfirdn2d_kernel<scalar_t>(p);
62
- });
63
-
64
- // Set looping options.
65
- p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
66
- p.loopMinor = spec.loopMinor;
67
- p.loopX = spec.loopX;
68
- p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
69
- p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
70
-
71
- // Compute grid size.
72
- dim3 blockSize, gridSize;
73
- if (spec.tileOutW < 0) // large
74
- {
75
- blockSize = dim3(4, 32, 1);
76
- gridSize = dim3(
77
- ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
78
- (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
79
- p.launchMajor);
80
- }
81
- else // small
82
- {
83
- blockSize = dim3(256, 1, 1);
84
- gridSize = dim3(
85
- ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
86
- (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
87
- p.launchMajor);
88
- }
89
-
90
- // Launch CUDA kernel.
91
- void* args[] = {&p};
92
- AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
93
- return y;
94
- }
95
-
96
- //------------------------------------------------------------------------
97
-
98
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
99
- {
100
- m.def("upfirdn2d", &upfirdn2d);
101
- }
102
-
103
- //------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch_utils/ops/upfirdn2d.cu DELETED
@@ -1,350 +0,0 @@
1
- // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- //
3
- // NVIDIA CORPORATION and its licensors retain all intellectual property
4
- // and proprietary rights in and to this software, related documentation
5
- // and any modifications thereto. Any use, reproduction, disclosure or
6
- // distribution of this software and related documentation without an express
7
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- #include <c10/util/Half.h>
10
- #include "upfirdn2d.h"
11
-
12
- //------------------------------------------------------------------------
13
- // Helpers.
14
-
15
- template <class T> struct InternalType;
16
- template <> struct InternalType<double> { typedef double scalar_t; };
17
- template <> struct InternalType<float> { typedef float scalar_t; };
18
- template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
-
20
- static __device__ __forceinline__ int floor_div(int a, int b)
21
- {
22
- int t = 1 - a / b;
23
- return (a + t * b) / b - t;
24
- }
25
-
26
- //------------------------------------------------------------------------
27
- // Generic CUDA implementation for large filters.
28
-
29
- template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
30
- {
31
- typedef typename InternalType<T>::scalar_t scalar_t;
32
-
33
- // Calculate thread index.
34
- int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
35
- int outY = minorBase / p.launchMinor;
36
- minorBase -= outY * p.launchMinor;
37
- int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
38
- int majorBase = blockIdx.z * p.loopMajor;
39
- if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
40
- return;
41
-
42
- // Setup Y receptive field.
43
- int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
44
- int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
45
- int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
46
- int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
47
- if (p.flip)
48
- filterY = p.filterSize.y - 1 - filterY;
49
-
50
- // Loop over major, minor, and X.
51
- for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
52
- for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
53
- {
54
- int nc = major * p.sizeMinor + minor;
55
- int n = nc / p.inSize.z;
56
- int c = nc - n * p.inSize.z;
57
- for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
58
- {
59
- // Setup X receptive field.
60
- int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
61
- int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
62
- int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
63
- int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
64
- if (p.flip)
65
- filterX = p.filterSize.x - 1 - filterX;
66
-
67
- // Initialize pointers.
68
- const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
69
- const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
70
- int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
71
- int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
72
-
73
- // Inner loop.
74
- scalar_t v = 0;
75
- for (int y = 0; y < h; y++)
76
- {
77
- for (int x = 0; x < w; x++)
78
- {
79
- v += (scalar_t)(*xp) * (scalar_t)(*fp);
80
- xp += p.inStride.x;
81
- fp += filterStepX;
82
- }
83
- xp += p.inStride.y - w * p.inStride.x;
84
- fp += filterStepY - w * filterStepX;
85
- }
86
-
87
- // Store result.
88
- v *= p.gain;
89
- ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
90
- }
91
- }
92
- }
93
-
94
- //------------------------------------------------------------------------
95
- // Specialized CUDA implementation for small filters.
96
-
97
- template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
98
- static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
99
- {
100
- typedef typename InternalType<T>::scalar_t scalar_t;
101
- const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
102
- const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
103
- __shared__ volatile scalar_t sf[filterH][filterW];
104
- __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
105
-
106
- // Calculate tile index.
107
- int minorBase = blockIdx.x;
108
- int tileOutY = minorBase / p.launchMinor;
109
- minorBase -= tileOutY * p.launchMinor;
110
- minorBase *= loopMinor;
111
- tileOutY *= tileOutH;
112
- int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
113
- int majorBase = blockIdx.z * p.loopMajor;
114
- if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
115
- return;
116
-
117
- // Load filter (flipped).
118
- for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
119
- {
120
- int fy = tapIdx / filterW;
121
- int fx = tapIdx - fy * filterW;
122
- scalar_t v = 0;
123
- if (fx < p.filterSize.x & fy < p.filterSize.y)
124
- {
125
- int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
126
- int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
127
- v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
128
- }
129
- sf[fy][fx] = v;
130
- }
131
-
132
- // Loop over major and X.
133
- for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
134
- {
135
- int baseNC = major * p.sizeMinor + minorBase;
136
- int n = baseNC / p.inSize.z;
137
- int baseC = baseNC - n * p.inSize.z;
138
- for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
139
- {
140
- // Load input pixels.
141
- int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
142
- int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
143
- int tileInX = floor_div(tileMidX, upx);
144
- int tileInY = floor_div(tileMidY, upy);
145
- __syncthreads();
146
- for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
147
- {
148
- int relC = inIdx;
149
- int relInX = relC / loopMinor;
150
- int relInY = relInX / tileInW;
151
- relC -= relInX * loopMinor;
152
- relInX -= relInY * tileInW;
153
- int c = baseC + relC;
154
- int inX = tileInX + relInX;
155
- int inY = tileInY + relInY;
156
- scalar_t v = 0;
157
- if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
158
- v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
159
- sx[relInY][relInX][relC] = v;
160
- }
161
-
162
- // Loop over output pixels.
163
- __syncthreads();
164
- for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
165
- {
166
- int relC = outIdx;
167
- int relOutX = relC / loopMinor;
168
- int relOutY = relOutX / tileOutW;
169
- relC -= relOutX * loopMinor;
170
- relOutX -= relOutY * tileOutW;
171
- int c = baseC + relC;
172
- int outX = tileOutX + relOutX;
173
- int outY = tileOutY + relOutY;
174
-
175
- // Setup receptive field.
176
- int midX = tileMidX + relOutX * downx;
177
- int midY = tileMidY + relOutY * downy;
178
- int inX = floor_div(midX, upx);
179
- int inY = floor_div(midY, upy);
180
- int relInX = inX - tileInX;
181
- int relInY = inY - tileInY;
182
- int filterX = (inX + 1) * upx - midX - 1; // flipped
183
- int filterY = (inY + 1) * upy - midY - 1; // flipped
184
-
185
- // Inner loop.
186
- if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
187
- {
188
- scalar_t v = 0;
189
- #pragma unroll
190
- for (int y = 0; y < filterH / upy; y++)
191
- #pragma unroll
192
- for (int x = 0; x < filterW / upx; x++)
193
- v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
194
- v *= p.gain;
195
- ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
196
- }
197
- }
198
- }
199
- }
200
- }
201
-
202
- //------------------------------------------------------------------------
203
- // CUDA kernel selection.
204
-
205
- template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
206
- {
207
- int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
208
-
209
- upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
210
- if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
211
-
212
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
213
- {
214
- if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
215
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
216
- if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
217
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
218
- if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
219
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
220
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
221
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
222
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
223
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
224
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
225
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
226
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
227
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
228
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
229
- }
230
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
231
- {
232
- if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
233
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
234
- if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
235
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
236
- if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
237
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
238
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
239
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
240
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
241
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
242
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
243
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
244
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
245
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
246
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
247
- }
248
- if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
249
- {
250
- if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
251
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
252
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
253
- if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
254
- }
255
- if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
256
- {
257
- if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
258
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
259
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
260
- if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
261
- }
262
- if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
263
- {
264
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
265
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
266
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
267
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
268
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
269
- }
270
- if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
271
- {
272
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
273
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
274
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
275
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
276
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
277
- }
278
- if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
279
- {
280
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
281
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
282
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
283
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
284
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
285
- }
286
- if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
287
- {
288
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
289
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
290
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
291
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
292
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
293
- }
294
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
295
- {
296
- if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
297
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
298
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
299
- if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
300
- }
301
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
302
- {
303
- if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
304
- if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
305
- if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
306
- if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
307
- }
308
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
309
- {
310
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
311
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
312
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
313
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
314
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
315
- }
316
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
317
- {
318
- if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
319
- if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
320
- if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
321
- if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
322
- if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
323
- }
324
- if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
325
- {
326
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
327
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
328
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
329
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
330
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
331
- }
332
- if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
333
- {
334
- if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
335
- if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
336
- if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
337
- if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
338
- if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
339
- }
340
- return spec;
341
- }
342
-
343
- //------------------------------------------------------------------------
344
- // Template specializations.
345
-
346
- template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
347
- template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
348
- template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
349
-
350
- //------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch_utils/ops/upfirdn2d.h DELETED
@@ -1,59 +0,0 @@
1
- // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- //
3
- // NVIDIA CORPORATION and its licensors retain all intellectual property
4
- // and proprietary rights in and to this software, related documentation
5
- // and any modifications thereto. Any use, reproduction, disclosure or
6
- // distribution of this software and related documentation without an express
7
- // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- #include <cuda_runtime.h>
10
-
11
- //------------------------------------------------------------------------
12
- // CUDA kernel parameters.
13
-
14
- struct upfirdn2d_kernel_params
15
- {
16
- const void* x;
17
- const float* f;
18
- void* y;
19
-
20
- int2 up;
21
- int2 down;
22
- int2 pad0;
23
- int flip;
24
- float gain;
25
-
26
- int4 inSize; // [width, height, channel, batch]
27
- int4 inStride;
28
- int2 filterSize; // [width, height]
29
- int2 filterStride;
30
- int4 outSize; // [width, height, channel, batch]
31
- int4 outStride;
32
- int sizeMinor;
33
- int sizeMajor;
34
-
35
- int loopMinor;
36
- int loopMajor;
37
- int loopX;
38
- int launchMinor;
39
- int launchMajor;
40
- };
41
-
42
- //------------------------------------------------------------------------
43
- // CUDA kernel specialization.
44
-
45
- struct upfirdn2d_kernel_spec
46
- {
47
- void* kernel;
48
- int tileOutW;
49
- int tileOutH;
50
- int loopMinor;
51
- int loopX;
52
- };
53
-
54
- //------------------------------------------------------------------------
55
- // CUDA kernel selection.
56
-
57
- template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58
-
59
- //------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch_utils/ops/upfirdn2d.py DELETED
@@ -1,384 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Custom PyTorch ops for efficient resampling of 2D images."""
10
-
11
- import os
12
- import warnings
13
- import numpy as np
14
- import torch
15
- import traceback
16
-
17
- from .. import custom_ops
18
- from .. import misc
19
- from . import conv2d_gradfix
20
-
21
- #----------------------------------------------------------------------------
22
-
23
- _inited = False
24
- _plugin = None
25
-
26
- def _init():
27
- global _inited, _plugin
28
- if not _inited:
29
- sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
30
- sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
31
- try:
32
- _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
33
- except:
34
- warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
35
- return _plugin is not None
36
-
37
- def _parse_scaling(scaling):
38
- if isinstance(scaling, int):
39
- scaling = [scaling, scaling]
40
- assert isinstance(scaling, (list, tuple))
41
- assert all(isinstance(x, int) for x in scaling)
42
- sx, sy = scaling
43
- assert sx >= 1 and sy >= 1
44
- return sx, sy
45
-
46
- def _parse_padding(padding):
47
- if isinstance(padding, int):
48
- padding = [padding, padding]
49
- assert isinstance(padding, (list, tuple))
50
- assert all(isinstance(x, int) for x in padding)
51
- if len(padding) == 2:
52
- padx, pady = padding
53
- padding = [padx, padx, pady, pady]
54
- padx0, padx1, pady0, pady1 = padding
55
- return padx0, padx1, pady0, pady1
56
-
57
- def _get_filter_size(f):
58
- if f is None:
59
- return 1, 1
60
- assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
61
- fw = f.shape[-1]
62
- fh = f.shape[0]
63
- with misc.suppress_tracer_warnings():
64
- fw = int(fw)
65
- fh = int(fh)
66
- misc.assert_shape(f, [fh, fw][:f.ndim])
67
- assert fw >= 1 and fh >= 1
68
- return fw, fh
69
-
70
- #----------------------------------------------------------------------------
71
-
72
- def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
73
- r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
74
-
75
- Args:
76
- f: Torch tensor, numpy array, or python list of the shape
77
- `[filter_height, filter_width]` (non-separable),
78
- `[filter_taps]` (separable),
79
- `[]` (impulse), or
80
- `None` (identity).
81
- device: Result device (default: cpu).
82
- normalize: Normalize the filter so that it retains the magnitude
83
- for constant input signal (DC)? (default: True).
84
- flip_filter: Flip the filter? (default: False).
85
- gain: Overall scaling factor for signal magnitude (default: 1).
86
- separable: Return a separable filter? (default: select automatically).
87
-
88
- Returns:
89
- Float32 tensor of the shape
90
- `[filter_height, filter_width]` (non-separable) or
91
- `[filter_taps]` (separable).
92
- """
93
- # Validate.
94
- if f is None:
95
- f = 1
96
- f = torch.as_tensor(f, dtype=torch.float32)
97
- assert f.ndim in [0, 1, 2]
98
- assert f.numel() > 0
99
- if f.ndim == 0:
100
- f = f[np.newaxis]
101
-
102
- # Separable?
103
- if separable is None:
104
- separable = (f.ndim == 1 and f.numel() >= 8)
105
- if f.ndim == 1 and not separable:
106
- f = f.ger(f)
107
- assert f.ndim == (1 if separable else 2)
108
-
109
- # Apply normalize, flip, gain, and device.
110
- if normalize:
111
- f /= f.sum()
112
- if flip_filter:
113
- f = f.flip(list(range(f.ndim)))
114
- f = f * (gain ** (f.ndim / 2))
115
- f = f.to(device=device)
116
- return f
117
-
118
- #----------------------------------------------------------------------------
119
-
120
- def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
121
- r"""Pad, upsample, filter, and downsample a batch of 2D images.
122
-
123
- Performs the following sequence of operations for each channel:
124
-
125
- 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
126
-
127
- 2. Pad the image with the specified number of zeros on each side (`padding`).
128
- Negative padding corresponds to cropping the image.
129
-
130
- 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
131
- so that the footprint of all output pixels lies within the input image.
132
-
133
- 4. Downsample the image by keeping every Nth pixel (`down`).
134
-
135
- This sequence of operations bears close resemblance to scipy.signal.upfirdn().
136
- The fused op is considerably more efficient than performing the same calculation
137
- using standard PyTorch ops. It supports gradients of arbitrary order.
138
-
139
- Args:
140
- x: Float32/float64/float16 input tensor of the shape
141
- `[batch_size, num_channels, in_height, in_width]`.
142
- f: Float32 FIR filter of the shape
143
- `[filter_height, filter_width]` (non-separable),
144
- `[filter_taps]` (separable), or
145
- `None` (identity).
146
- up: Integer upsampling factor. Can be a single int or a list/tuple
147
- `[x, y]` (default: 1).
148
- down: Integer downsampling factor. Can be a single int or a list/tuple
149
- `[x, y]` (default: 1).
150
- padding: Padding with respect to the upsampled image. Can be a single number
151
- or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
152
- (default: 0).
153
- flip_filter: False = convolution, True = correlation (default: False).
154
- gain: Overall scaling factor for signal magnitude (default: 1).
155
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
156
-
157
- Returns:
158
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
159
- """
160
- assert isinstance(x, torch.Tensor)
161
- assert impl in ['ref', 'cuda']
162
- if impl == 'cuda' and x.device.type == 'cuda' and _init():
163
- return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
164
- return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
165
-
166
- #----------------------------------------------------------------------------
167
-
168
- @misc.profiled_function
169
- def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
170
- """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
171
- """
172
- # Validate arguments.
173
- assert isinstance(x, torch.Tensor) and x.ndim == 4
174
- if f is None:
175
- f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
176
- assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
177
- assert f.dtype == torch.float32 and not f.requires_grad
178
- batch_size, num_channels, in_height, in_width = x.shape
179
- upx, upy = _parse_scaling(up)
180
- downx, downy = _parse_scaling(down)
181
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
182
-
183
- # Upsample by inserting zeros.
184
- x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
185
- x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
186
- x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
187
-
188
- # Pad or crop.
189
- x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
190
- x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
191
-
192
- # Setup filter.
193
- f = f * (gain ** (f.ndim / 2))
194
- f = f.to(x.dtype)
195
- if not flip_filter:
196
- f = f.flip(list(range(f.ndim)))
197
-
198
- # Convolve with the filter.
199
- f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
200
- if f.ndim == 4:
201
- x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
202
- else:
203
- x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
204
- x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
205
-
206
- # Downsample by throwing away pixels.
207
- x = x[:, :, ::downy, ::downx]
208
- return x
209
-
210
- #----------------------------------------------------------------------------
211
-
212
- _upfirdn2d_cuda_cache = dict()
213
-
214
- def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
215
- """Fast CUDA implementation of `upfirdn2d()` using custom ops.
216
- """
217
- # Parse arguments.
218
- upx, upy = _parse_scaling(up)
219
- downx, downy = _parse_scaling(down)
220
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
221
-
222
- # Lookup from cache.
223
- key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
224
- if key in _upfirdn2d_cuda_cache:
225
- return _upfirdn2d_cuda_cache[key]
226
-
227
- # Forward op.
228
- class Upfirdn2dCuda(torch.autograd.Function):
229
- @staticmethod
230
- def forward(ctx, x, f): # pylint: disable=arguments-differ
231
- assert isinstance(x, torch.Tensor) and x.ndim == 4
232
- if f is None:
233
- f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
234
- assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
235
- y = x
236
- if f.ndim == 2:
237
- y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
238
- else:
239
- y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
240
- y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
241
- ctx.save_for_backward(f)
242
- ctx.x_shape = x.shape
243
- return y
244
-
245
- @staticmethod
246
- def backward(ctx, dy): # pylint: disable=arguments-differ
247
- f, = ctx.saved_tensors
248
- _, _, ih, iw = ctx.x_shape
249
- _, _, oh, ow = dy.shape
250
- fw, fh = _get_filter_size(f)
251
- p = [
252
- fw - padx0 - 1,
253
- iw * upx - ow * downx + padx0 - upx + 1,
254
- fh - pady0 - 1,
255
- ih * upy - oh * downy + pady0 - upy + 1,
256
- ]
257
- dx = None
258
- df = None
259
-
260
- if ctx.needs_input_grad[0]:
261
- dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
262
-
263
- assert not ctx.needs_input_grad[1]
264
- return dx, df
265
-
266
- # Add to cache.
267
- _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
268
- return Upfirdn2dCuda
269
-
270
- #----------------------------------------------------------------------------
271
-
272
- def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
273
- r"""Filter a batch of 2D images using the given 2D FIR filter.
274
-
275
- By default, the result is padded so that its shape matches the input.
276
- User-specified padding is applied on top of that, with negative values
277
- indicating cropping. Pixels outside the image are assumed to be zero.
278
-
279
- Args:
280
- x: Float32/float64/float16 input tensor of the shape
281
- `[batch_size, num_channels, in_height, in_width]`.
282
- f: Float32 FIR filter of the shape
283
- `[filter_height, filter_width]` (non-separable),
284
- `[filter_taps]` (separable), or
285
- `None` (identity).
286
- padding: Padding with respect to the output. Can be a single number or a
287
- list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
288
- (default: 0).
289
- flip_filter: False = convolution, True = correlation (default: False).
290
- gain: Overall scaling factor for signal magnitude (default: 1).
291
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
292
-
293
- Returns:
294
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
295
- """
296
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
297
- fw, fh = _get_filter_size(f)
298
- p = [
299
- padx0 + fw // 2,
300
- padx1 + (fw - 1) // 2,
301
- pady0 + fh // 2,
302
- pady1 + (fh - 1) // 2,
303
- ]
304
- return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
305
-
306
- #----------------------------------------------------------------------------
307
-
308
- def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
309
- r"""Upsample a batch of 2D images using the given 2D FIR filter.
310
-
311
- By default, the result is padded so that its shape is a multiple of the input.
312
- User-specified padding is applied on top of that, with negative values
313
- indicating cropping. Pixels outside the image are assumed to be zero.
314
-
315
- Args:
316
- x: Float32/float64/float16 input tensor of the shape
317
- `[batch_size, num_channels, in_height, in_width]`.
318
- f: Float32 FIR filter of the shape
319
- `[filter_height, filter_width]` (non-separable),
320
- `[filter_taps]` (separable), or
321
- `None` (identity).
322
- up: Integer upsampling factor. Can be a single int or a list/tuple
323
- `[x, y]` (default: 1).
324
- padding: Padding with respect to the output. Can be a single number or a
325
- list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
326
- (default: 0).
327
- flip_filter: False = convolution, True = correlation (default: False).
328
- gain: Overall scaling factor for signal magnitude (default: 1).
329
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
330
-
331
- Returns:
332
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
333
- """
334
- upx, upy = _parse_scaling(up)
335
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
336
- fw, fh = _get_filter_size(f)
337
- p = [
338
- padx0 + (fw + upx - 1) // 2,
339
- padx1 + (fw - upx) // 2,
340
- pady0 + (fh + upy - 1) // 2,
341
- pady1 + (fh - upy) // 2,
342
- ]
343
- return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
344
-
345
- #----------------------------------------------------------------------------
346
-
347
- def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
348
- r"""Downsample a batch of 2D images using the given 2D FIR filter.
349
-
350
- By default, the result is padded so that its shape is a fraction of the input.
351
- User-specified padding is applied on top of that, with negative values
352
- indicating cropping. Pixels outside the image are assumed to be zero.
353
-
354
- Args:
355
- x: Float32/float64/float16 input tensor of the shape
356
- `[batch_size, num_channels, in_height, in_width]`.
357
- f: Float32 FIR filter of the shape
358
- `[filter_height, filter_width]` (non-separable),
359
- `[filter_taps]` (separable), or
360
- `None` (identity).
361
- down: Integer downsampling factor. Can be a single int or a list/tuple
362
- `[x, y]` (default: 1).
363
- padding: Padding with respect to the input. Can be a single number or a
364
- list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
365
- (default: 0).
366
- flip_filter: False = convolution, True = correlation (default: False).
367
- gain: Overall scaling factor for signal magnitude (default: 1).
368
- impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
369
-
370
- Returns:
371
- Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
372
- """
373
- downx, downy = _parse_scaling(down)
374
- padx0, padx1, pady0, pady1 = _parse_padding(padding)
375
- fw, fh = _get_filter_size(f)
376
- p = [
377
- padx0 + (fw - downx + 1) // 2,
378
- padx1 + (fw - downx) // 2,
379
- pady0 + (fh - downy + 1) // 2,
380
- pady1 + (fh - downy) // 2,
381
- ]
382
- return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
383
-
384
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch_utils/persistence.py DELETED
@@ -1,251 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Facilities for pickling Python code alongside other data.
10
-
11
- The pickled code is automatically imported into a separate Python module
12
- during unpickling. This way, any previously exported pickles will remain
13
- usable even if the original code is no longer available, or if the current
14
- version of the code is not consistent with what was originally pickled."""
15
-
16
- import sys
17
- import pickle
18
- import io
19
- import inspect
20
- import copy
21
- import uuid
22
- import types
23
- import dnnlib
24
-
25
- #----------------------------------------------------------------------------
26
-
27
- _version = 6 # internal version number
28
- _decorators = set() # {decorator_class, ...}
29
- _import_hooks = [] # [hook_function, ...]
30
- _module_to_src_dict = dict() # {module: src, ...}
31
- _src_to_module_dict = dict() # {src: module, ...}
32
-
33
- #----------------------------------------------------------------------------
34
-
35
- def persistent_class(orig_class):
36
- r"""Class decorator that extends a given class to save its source code
37
- when pickled.
38
-
39
- Example:
40
-
41
- from torch_utils import persistence
42
-
43
- @persistence.persistent_class
44
- class MyNetwork(torch.nn.Module):
45
- def __init__(self, num_inputs, num_outputs):
46
- super().__init__()
47
- self.fc = MyLayer(num_inputs, num_outputs)
48
- ...
49
-
50
- @persistence.persistent_class
51
- class MyLayer(torch.nn.Module):
52
- ...
53
-
54
- When pickled, any instance of `MyNetwork` and `MyLayer` will save its
55
- source code alongside other internal state (e.g., parameters, buffers,
56
- and submodules). This way, any previously exported pickle will remain
57
- usable even if the class definitions have been modified or are no
58
- longer available.
59
-
60
- The decorator saves the source code of the entire Python module
61
- containing the decorated class. It does *not* save the source code of
62
- any imported modules. Thus, the imported modules must be available
63
- during unpickling, also including `torch_utils.persistence` itself.
64
-
65
- It is ok to call functions defined in the same module from the
66
- decorated class. However, if the decorated class depends on other
67
- classes defined in the same module, they must be decorated as well.
68
- This is illustrated in the above example in the case of `MyLayer`.
69
-
70
- It is also possible to employ the decorator just-in-time before
71
- calling the constructor. For example:
72
-
73
- cls = MyLayer
74
- if want_to_make_it_persistent:
75
- cls = persistence.persistent_class(cls)
76
- layer = cls(num_inputs, num_outputs)
77
-
78
- As an additional feature, the decorator also keeps track of the
79
- arguments that were used to construct each instance of the decorated
80
- class. The arguments can be queried via `obj.init_args` and
81
- `obj.init_kwargs`, and they are automatically pickled alongside other
82
- object state. A typical use case is to first unpickle a previous
83
- instance of a persistent class, and then upgrade it to use the latest
84
- version of the source code:
85
-
86
- with open('old_pickle.pkl', 'rb') as f:
87
- old_net = pickle.load(f)
88
- new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
89
- misc.copy_params_and_buffers(old_net, new_net, require_all=True)
90
- """
91
- assert isinstance(orig_class, type)
92
- if is_persistent(orig_class):
93
- return orig_class
94
-
95
- assert orig_class.__module__ in sys.modules
96
- orig_module = sys.modules[orig_class.__module__]
97
- orig_module_src = _module_to_src(orig_module)
98
-
99
- class Decorator(orig_class):
100
- _orig_module_src = orig_module_src
101
- _orig_class_name = orig_class.__name__
102
-
103
- def __init__(self, *args, **kwargs):
104
- super().__init__(*args, **kwargs)
105
- self._init_args = copy.deepcopy(args)
106
- self._init_kwargs = copy.deepcopy(kwargs)
107
- assert orig_class.__name__ in orig_module.__dict__
108
- _check_pickleable(self.__reduce__())
109
-
110
- @property
111
- def init_args(self):
112
- return copy.deepcopy(self._init_args)
113
-
114
- @property
115
- def init_kwargs(self):
116
- return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
117
-
118
- def __reduce__(self):
119
- fields = list(super().__reduce__())
120
- fields += [None] * max(3 - len(fields), 0)
121
- if fields[0] is not _reconstruct_persistent_obj:
122
- meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
123
- fields[0] = _reconstruct_persistent_obj # reconstruct func
124
- fields[1] = (meta,) # reconstruct args
125
- fields[2] = None # state dict
126
- return tuple(fields)
127
-
128
- Decorator.__name__ = orig_class.__name__
129
- _decorators.add(Decorator)
130
- return Decorator
131
-
132
- #----------------------------------------------------------------------------
133
-
134
- def is_persistent(obj):
135
- r"""Test whether the given object or class is persistent, i.e.,
136
- whether it will save its source code when pickled.
137
- """
138
- try:
139
- if obj in _decorators:
140
- return True
141
- except TypeError:
142
- pass
143
- return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
144
-
145
- #----------------------------------------------------------------------------
146
-
147
- def import_hook(hook):
148
- r"""Register an import hook that is called whenever a persistent object
149
- is being unpickled. A typical use case is to patch the pickled source
150
- code to avoid errors and inconsistencies when the API of some imported
151
- module has changed.
152
-
153
- The hook should have the following signature:
154
-
155
- hook(meta) -> modified meta
156
-
157
- `meta` is an instance of `dnnlib.EasyDict` with the following fields:
158
-
159
- type: Type of the persistent object, e.g. `'class'`.
160
- version: Internal version number of `torch_utils.persistence`.
161
- module_src Original source code of the Python module.
162
- class_name: Class name in the original Python module.
163
- state: Internal state of the object.
164
-
165
- Example:
166
-
167
- @persistence.import_hook
168
- def wreck_my_network(meta):
169
- if meta.class_name == 'MyNetwork':
170
- print('MyNetwork is being imported. I will wreck it!')
171
- meta.module_src = meta.module_src.replace("True", "False")
172
- return meta
173
- """
174
- assert callable(hook)
175
- _import_hooks.append(hook)
176
-
177
- #----------------------------------------------------------------------------
178
-
179
- def _reconstruct_persistent_obj(meta):
180
- r"""Hook that is called internally by the `pickle` module to unpickle
181
- a persistent object.
182
- """
183
- meta = dnnlib.EasyDict(meta)
184
- meta.state = dnnlib.EasyDict(meta.state)
185
- for hook in _import_hooks:
186
- meta = hook(meta)
187
- assert meta is not None
188
-
189
- assert meta.version == _version
190
- module = _src_to_module(meta.module_src)
191
-
192
- assert meta.type == 'class'
193
- orig_class = module.__dict__[meta.class_name]
194
- decorator_class = persistent_class(orig_class)
195
- obj = decorator_class.__new__(decorator_class)
196
-
197
- setstate = getattr(obj, '__setstate__', None)
198
- if callable(setstate):
199
- setstate(meta.state) # pylint: disable=not-callable
200
- else:
201
- obj.__dict__.update(meta.state)
202
- return obj
203
-
204
- #----------------------------------------------------------------------------
205
-
206
- def _module_to_src(module):
207
- r"""Query the source code of a given Python module.
208
- """
209
- src = _module_to_src_dict.get(module, None)
210
- if src is None:
211
- src = inspect.getsource(module)
212
- _module_to_src_dict[module] = src
213
- _src_to_module_dict[src] = module
214
- return src
215
-
216
- def _src_to_module(src):
217
- r"""Get or create a Python module for the given source code.
218
- """
219
- module = _src_to_module_dict.get(src, None)
220
- if module is None:
221
- module_name = "_imported_module_" + uuid.uuid4().hex
222
- module = types.ModuleType(module_name)
223
- sys.modules[module_name] = module
224
- _module_to_src_dict[module] = src
225
- _src_to_module_dict[src] = module
226
- exec(src, module.__dict__) # pylint: disable=exec-used
227
- return module
228
-
229
- #----------------------------------------------------------------------------
230
-
231
- def _check_pickleable(obj):
232
- r"""Check that the given object is pickleable, raising an exception if
233
- it is not. This function is expected to be considerably more efficient
234
- than actually pickling the object.
235
- """
236
- def recurse(obj):
237
- if isinstance(obj, (list, tuple, set)):
238
- return [recurse(x) for x in obj]
239
- if isinstance(obj, dict):
240
- return [[recurse(x), recurse(y)] for x, y in obj.items()]
241
- if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
242
- return None # Python primitive types are pickleable.
243
- if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
244
- return None # NumPy arrays and PyTorch tensors are pickleable.
245
- if is_persistent(obj):
246
- return None # Persistent objects are pickleable, by virtue of the constructor check.
247
- return obj
248
- with io.BytesIO() as f:
249
- pickle.dump(recurse(obj), f)
250
-
251
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
torch_utils/training_stats.py DELETED
@@ -1,268 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Facilities for reporting and collecting training statistics across
10
- multiple processes and devices. The interface is designed to minimize
11
- synchronization overhead as well as the amount of boilerplate in user
12
- code."""
13
-
14
- import re
15
- import numpy as np
16
- import torch
17
- import dnnlib
18
-
19
- from . import misc
20
-
21
- #----------------------------------------------------------------------------
22
-
23
- _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
24
- _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
25
- _counter_dtype = torch.float64 # Data type to use for the internal counters.
26
- _rank = 0 # Rank of the current process.
27
- _sync_device = None # Device to use for multiprocess communication. None = single-process.
28
- _sync_called = False # Has _sync() been called yet?
29
- _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
30
- _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
31
-
32
- #----------------------------------------------------------------------------
33
-
34
- def init_multiprocessing(rank, sync_device):
35
- r"""Initializes `torch_utils.training_stats` for collecting statistics
36
- across multiple processes.
37
-
38
- This function must be called after
39
- `torch.distributed.init_process_group()` and before `Collector.update()`.
40
- The call is not necessary if multi-process collection is not needed.
41
-
42
- Args:
43
- rank: Rank of the current process.
44
- sync_device: PyTorch device to use for inter-process
45
- communication, or None to disable multi-process
46
- collection. Typically `torch.device('cuda', rank)`.
47
- """
48
- global _rank, _sync_device
49
- assert not _sync_called
50
- _rank = rank
51
- _sync_device = sync_device
52
-
53
- #----------------------------------------------------------------------------
54
-
55
- @misc.profiled_function
56
- def report(name, value):
57
- r"""Broadcasts the given set of scalars to all interested instances of
58
- `Collector`, across device and process boundaries.
59
-
60
- This function is expected to be extremely cheap and can be safely
61
- called from anywhere in the training loop, loss function, or inside a
62
- `torch.nn.Module`.
63
-
64
- Warning: The current implementation expects the set of unique names to
65
- be consistent across processes. Please make sure that `report()` is
66
- called at least once for each unique name by each process, and in the
67
- same order. If a given process has no scalars to broadcast, it can do
68
- `report(name, [])` (empty list).
69
-
70
- Args:
71
- name: Arbitrary string specifying the name of the statistic.
72
- Averages are accumulated separately for each unique name.
73
- value: Arbitrary set of scalars. Can be a list, tuple,
74
- NumPy array, PyTorch tensor, or Python scalar.
75
-
76
- Returns:
77
- The same `value` that was passed in.
78
- """
79
- if name not in _counters:
80
- _counters[name] = dict()
81
-
82
- elems = torch.as_tensor(value)
83
- if elems.numel() == 0:
84
- return value
85
-
86
- elems = elems.detach().flatten().to(_reduce_dtype)
87
- moments = torch.stack([
88
- torch.ones_like(elems).sum(),
89
- elems.sum(),
90
- elems.square().sum(),
91
- ])
92
- assert moments.ndim == 1 and moments.shape[0] == _num_moments
93
- moments = moments.to(_counter_dtype)
94
-
95
- device = moments.device
96
- if device not in _counters[name]:
97
- _counters[name][device] = torch.zeros_like(moments)
98
- _counters[name][device].add_(moments)
99
- return value
100
-
101
- #----------------------------------------------------------------------------
102
-
103
- def report0(name, value):
104
- r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
105
- but ignores any scalars provided by the other processes.
106
- See `report()` for further details.
107
- """
108
- report(name, value if _rank == 0 else [])
109
- return value
110
-
111
- #----------------------------------------------------------------------------
112
-
113
- class Collector:
114
- r"""Collects the scalars broadcasted by `report()` and `report0()` and
115
- computes their long-term averages (mean and standard deviation) over
116
- user-defined periods of time.
117
-
118
- The averages are first collected into internal counters that are not
119
- directly visible to the user. They are then copied to the user-visible
120
- state as a result of calling `update()` and can then be queried using
121
- `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
122
- internal counters for the next round, so that the user-visible state
123
- effectively reflects averages collected between the last two calls to
124
- `update()`.
125
-
126
- Args:
127
- regex: Regular expression defining which statistics to
128
- collect. The default is to collect everything.
129
- keep_previous: Whether to retain the previous averages if no
130
- scalars were collected on a given round
131
- (default: True).
132
- """
133
- def __init__(self, regex='.*', keep_previous=True):
134
- self._regex = re.compile(regex)
135
- self._keep_previous = keep_previous
136
- self._cumulative = dict()
137
- self._moments = dict()
138
- self.update()
139
- self._moments.clear()
140
-
141
- def names(self):
142
- r"""Returns the names of all statistics broadcasted so far that
143
- match the regular expression specified at construction time.
144
- """
145
- return [name for name in _counters if self._regex.fullmatch(name)]
146
-
147
- def update(self):
148
- r"""Copies current values of the internal counters to the
149
- user-visible state and resets them for the next round.
150
-
151
- If `keep_previous=True` was specified at construction time, the
152
- operation is skipped for statistics that have received no scalars
153
- since the last update, retaining their previous averages.
154
-
155
- This method performs a number of GPU-to-CPU transfers and one
156
- `torch.distributed.all_reduce()`. It is intended to be called
157
- periodically in the main training loop, typically once every
158
- N training steps.
159
- """
160
- if not self._keep_previous:
161
- self._moments.clear()
162
- for name, cumulative in _sync(self.names()):
163
- if name not in self._cumulative:
164
- self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
165
- delta = cumulative - self._cumulative[name]
166
- self._cumulative[name].copy_(cumulative)
167
- if float(delta[0]) != 0:
168
- self._moments[name] = delta
169
-
170
- def _get_delta(self, name):
171
- r"""Returns the raw moments that were accumulated for the given
172
- statistic between the last two calls to `update()`, or zero if
173
- no scalars were collected.
174
- """
175
- assert self._regex.fullmatch(name)
176
- if name not in self._moments:
177
- self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
178
- return self._moments[name]
179
-
180
- def num(self, name):
181
- r"""Returns the number of scalars that were accumulated for the given
182
- statistic between the last two calls to `update()`, or zero if
183
- no scalars were collected.
184
- """
185
- delta = self._get_delta(name)
186
- return int(delta[0])
187
-
188
- def mean(self, name):
189
- r"""Returns the mean of the scalars that were accumulated for the
190
- given statistic between the last two calls to `update()`, or NaN if
191
- no scalars were collected.
192
- """
193
- delta = self._get_delta(name)
194
- if int(delta[0]) == 0:
195
- return float('nan')
196
- return float(delta[1] / delta[0])
197
-
198
- def std(self, name):
199
- r"""Returns the standard deviation of the scalars that were
200
- accumulated for the given statistic between the last two calls to
201
- `update()`, or NaN if no scalars were collected.
202
- """
203
- delta = self._get_delta(name)
204
- if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
205
- return float('nan')
206
- if int(delta[0]) == 1:
207
- return float(0)
208
- mean = float(delta[1] / delta[0])
209
- raw_var = float(delta[2] / delta[0])
210
- return np.sqrt(max(raw_var - np.square(mean), 0))
211
-
212
- def as_dict(self):
213
- r"""Returns the averages accumulated between the last two calls to
214
- `update()` as an `dnnlib.EasyDict`. The contents are as follows:
215
-
216
- dnnlib.EasyDict(
217
- NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
218
- ...
219
- )
220
- """
221
- stats = dnnlib.EasyDict()
222
- for name in self.names():
223
- stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
224
- return stats
225
-
226
- def __getitem__(self, name):
227
- r"""Convenience getter.
228
- `collector[name]` is a synonym for `collector.mean(name)`.
229
- """
230
- return self.mean(name)
231
-
232
- #----------------------------------------------------------------------------
233
-
234
- def _sync(names):
235
- r"""Synchronize the global cumulative counters across devices and
236
- processes. Called internally by `Collector.update()`.
237
- """
238
- if len(names) == 0:
239
- return []
240
- global _sync_called
241
- _sync_called = True
242
-
243
- # Collect deltas within current rank.
244
- deltas = []
245
- device = _sync_device if _sync_device is not None else torch.device('cpu')
246
- for name in names:
247
- delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
248
- for counter in _counters[name].values():
249
- delta.add_(counter.to(device))
250
- counter.copy_(torch.zeros_like(counter))
251
- deltas.append(delta)
252
- deltas = torch.stack(deltas)
253
-
254
- # Sum deltas across ranks.
255
- if _sync_device is not None:
256
- torch.distributed.all_reduce(deltas)
257
-
258
- # Update cumulative values.
259
- deltas = deltas.cpu()
260
- for idx, name in enumerate(names):
261
- if name not in _cumulative:
262
- _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
263
- _cumulative[name].add_(deltas[idx])
264
-
265
- # Return name-value pairs.
266
- return [(name, _cumulative[name]) for name in names]
267
-
268
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
train.py DELETED
@@ -1,540 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- """Train a GAN using the techniques described in the paper
10
- "Training Generative Adversarial Networks with Limited Data"."""
11
-
12
- import os
13
- import click
14
- import re
15
- import json
16
- import tempfile
17
- import torch
18
- import dnnlib
19
-
20
- from training import training_loop
21
- from metrics import metric_main
22
- from torch_utils import training_stats
23
- from torch_utils import custom_ops
24
-
25
- #----------------------------------------------------------------------------
26
-
27
- class UserError(Exception):
28
- pass
29
-
30
- #----------------------------------------------------------------------------
31
-
32
- def setup_training_loop_kwargs(
33
- # General options (not included in desc).
34
- gpus = None, # Number of GPUs: <int>, default = 1 gpu
35
- snap = None, # Snapshot interval: <int>, default = 50 ticks
36
- metrics = None, # List of metric names: [], ['fid50k_full'] (default), ...
37
- seed = None, # Random seed: <int>, default = 0
38
-
39
- # Dataset.
40
- data = None, # Training dataset (required): <path>
41
- cond = None, # Train conditional model based on dataset labels: <bool>, default = False
42
- subset = None, # Train with only N images: <int>, default = all
43
- mirror = None, # Augment dataset with x-flips: <bool>, default = False
44
-
45
- # Base config.
46
- cfg = None, # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'
47
- gamma = None, # Override R1 gamma: <float>
48
- kimg = None, # Override training duration: <int>
49
- batch = None, # Override batch size: <int>
50
-
51
- # Discriminator augmentation.
52
- aug = None, # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
53
- p = None, # Specify p for 'fixed' (required): <float>
54
- target = None, # Override ADA target for 'ada': <float>, default = depends on aug
55
- augpipe = None, # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc'
56
-
57
- # Transfer learning.
58
- resume = None, # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
59
- freezed = None, # Freeze-D: <int>, default = 0 discriminator layers
60
-
61
- # Performance options (not included in desc).
62
- fp32 = None, # Disable mixed-precision training: <bool>, default = False
63
- nhwc = None, # Use NHWC memory format with FP16: <bool>, default = False
64
- allow_tf32 = None, # Allow PyTorch to use TF32 for matmul and convolutions: <bool>, default = False
65
- nobench = None, # Disable cuDNN benchmarking: <bool>, default = False
66
- workers = None, # Override number of DataLoader workers: <int>, default = 3
67
- ):
68
- args = dnnlib.EasyDict()
69
-
70
- # ------------------------------------------
71
- # General options: gpus, snap, metrics, seed
72
- # ------------------------------------------
73
-
74
- if gpus is None:
75
- gpus = 1
76
- assert isinstance(gpus, int)
77
- if not (gpus >= 1 and gpus & (gpus - 1) == 0):
78
- raise UserError('--gpus must be a power of two')
79
- args.num_gpus = gpus
80
-
81
- if snap is None:
82
- snap = 50
83
- assert isinstance(snap, int)
84
- if snap < 1:
85
- raise UserError('--snap must be at least 1')
86
- args.image_snapshot_ticks = snap
87
- args.network_snapshot_ticks = snap
88
-
89
- if metrics is None:
90
- metrics = ['fid50k_full']
91
- assert isinstance(metrics, list)
92
- if not all(metric_main.is_valid_metric(metric) for metric in metrics):
93
- raise UserError('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
94
- args.metrics = metrics
95
-
96
- if seed is None:
97
- seed = 0
98
- assert isinstance(seed, int)
99
- args.random_seed = seed
100
-
101
- # -----------------------------------
102
- # Dataset: data, cond, subset, mirror
103
- # -----------------------------------
104
-
105
- assert data is not None
106
- assert isinstance(data, str)
107
- args.training_set_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data, use_labels=True, max_size=None, xflip=False)
108
- args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=3, prefetch_factor=2)
109
- try:
110
- training_set = dnnlib.util.construct_class_by_name(**args.training_set_kwargs) # subclass of training.dataset.Dataset
111
- args.training_set_kwargs.resolution = training_set.resolution # be explicit about resolution
112
- args.training_set_kwargs.use_labels = training_set.has_labels # be explicit about labels
113
- args.training_set_kwargs.max_size = len(training_set) # be explicit about dataset size
114
- desc = training_set.name
115
- del training_set # conserve memory
116
- except IOError as err:
117
- raise UserError(f'--data: {err}')
118
-
119
- if cond is None:
120
- cond = False
121
- assert isinstance(cond, bool)
122
- if cond:
123
- if not args.training_set_kwargs.use_labels:
124
- raise UserError('--cond=True requires labels specified in dataset.json')
125
- desc += '-cond'
126
- else:
127
- args.training_set_kwargs.use_labels = False
128
-
129
- if subset is not None:
130
- assert isinstance(subset, int)
131
- if not 1 <= subset <= args.training_set_kwargs.max_size:
132
- raise UserError(f'--subset must be between 1 and {args.training_set_kwargs.max_size}')
133
- desc += f'-subset{subset}'
134
- if subset < args.training_set_kwargs.max_size:
135
- args.training_set_kwargs.max_size = subset
136
- args.training_set_kwargs.random_seed = args.random_seed
137
-
138
- if mirror is None:
139
- mirror = False
140
- assert isinstance(mirror, bool)
141
- if mirror:
142
- desc += '-mirror'
143
- args.training_set_kwargs.xflip = True
144
-
145
- # ------------------------------------
146
- # Base config: cfg, gamma, kimg, batch
147
- # ------------------------------------
148
-
149
- if cfg is None:
150
- cfg = 'auto'
151
- assert isinstance(cfg, str)
152
- desc += f'-{cfg}'
153
-
154
- cfg_specs = {
155
- 'auto': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2), # Populated dynamically based on resolution and GPU count.
156
- 'stylegan2': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # Uses mixed-precision, unlike the original StyleGAN2.
157
- 'paper256': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=0.5, lrate=0.0025, gamma=1, ema=20, ramp=None, map=8),
158
- 'paper512': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1, lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8),
159
- 'paper1024': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=2, ema=10, ramp=None, map=8),
160
- 'cifar': dict(ref_gpus=2, kimg=100000, mb=64, mbstd=32, fmaps=1, lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2),
161
- }
162
-
163
- assert cfg in cfg_specs
164
- spec = dnnlib.EasyDict(cfg_specs[cfg])
165
- if cfg == 'auto':
166
- desc += f'{gpus:d}'
167
- spec.ref_gpus = gpus
168
- res = args.training_set_kwargs.resolution
169
- spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus) # keep gpu memory consumption at bay
170
- spec.mbstd = min(spec.mb // gpus, 4) # other hyperparams behave more predictably if mbstd group size remains fixed
171
- spec.fmaps = 1 if res >= 512 else 0.5
172
- spec.lrate = 0.002 if res >= 1024 else 0.0025
173
- spec.gamma = 0.0002 * (res ** 2) / spec.mb # heuristic formula
174
- spec.ema = spec.mb * 10 / 32
175
-
176
- args.G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator', z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict(), synthesis_kwargs=dnnlib.EasyDict())
177
- args.D_kwargs = dnnlib.EasyDict(class_name='training.networks.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict())
178
- args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(spec.fmaps * 32768)
179
- args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
180
- args.G_kwargs.mapping_kwargs.num_layers = spec.map
181
- args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4 # enable mixed-precision training
182
- args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256 # clamp activations to avoid float16 overflow
183
- args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd
184
-
185
- args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
186
- args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
187
- args.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma)
188
-
189
- args.total_kimg = spec.kimg
190
- args.batch_size = spec.mb
191
- args.batch_gpu = spec.mb // spec.ref_gpus
192
- args.ema_kimg = spec.ema
193
- args.ema_rampup = spec.ramp
194
-
195
- if cfg == 'cifar':
196
- args.loss_kwargs.pl_weight = 0 # disable path length regularization
197
- args.loss_kwargs.style_mixing_prob = 0 # disable style mixing
198
- args.D_kwargs.architecture = 'orig' # disable residual skip connections
199
-
200
- if gamma is not None:
201
- assert isinstance(gamma, float)
202
- if not gamma >= 0:
203
- raise UserError('--gamma must be non-negative')
204
- desc += f'-gamma{gamma:g}'
205
- args.loss_kwargs.r1_gamma = gamma
206
-
207
- if kimg is not None:
208
- assert isinstance(kimg, int)
209
- if not kimg >= 1:
210
- raise UserError('--kimg must be at least 1')
211
- desc += f'-kimg{kimg:d}'
212
- args.total_kimg = kimg
213
-
214
- if batch is not None:
215
- assert isinstance(batch, int)
216
- if not (batch >= 1 and batch % gpus == 0):
217
- raise UserError('--batch must be at least 1 and divisible by --gpus')
218
- desc += f'-batch{batch}'
219
- args.batch_size = batch
220
- args.batch_gpu = batch // gpus
221
-
222
- # ---------------------------------------------------
223
- # Discriminator augmentation: aug, p, target, augpipe
224
- # ---------------------------------------------------
225
-
226
- if aug is None:
227
- aug = 'ada'
228
- else:
229
- assert isinstance(aug, str)
230
- desc += f'-{aug}'
231
-
232
- if aug == 'ada':
233
- args.ada_target = 0.6
234
-
235
- elif aug == 'noaug':
236
- pass
237
-
238
- elif aug == 'fixed':
239
- if p is None:
240
- raise UserError(f'--aug={aug} requires specifying --p')
241
-
242
- else:
243
- raise UserError(f'--aug={aug} not supported')
244
-
245
- if p is not None:
246
- assert isinstance(p, float)
247
- if aug != 'fixed':
248
- raise UserError('--p can only be specified with --aug=fixed')
249
- if not 0 <= p <= 1:
250
- raise UserError('--p must be between 0 and 1')
251
- desc += f'-p{p:g}'
252
- args.augment_p = p
253
-
254
- if target is not None:
255
- assert isinstance(target, float)
256
- if aug != 'ada':
257
- raise UserError('--target can only be specified with --aug=ada')
258
- if not 0 <= target <= 1:
259
- raise UserError('--target must be between 0 and 1')
260
- desc += f'-target{target:g}'
261
- args.ada_target = target
262
-
263
- assert augpipe is None or isinstance(augpipe, str)
264
- if augpipe is None:
265
- augpipe = 'bgc'
266
- else:
267
- if aug == 'noaug':
268
- raise UserError('--augpipe cannot be specified with --aug=noaug')
269
- desc += f'-{augpipe}'
270
-
271
- augpipe_specs = {
272
- 'blit': dict(xflip=1, rotate90=1, xint=1),
273
- 'geom': dict(scale=1, rotate=1, aniso=1, xfrac=1),
274
- 'color': dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
275
- 'filter': dict(imgfilter=1),
276
- 'noise': dict(noise=1),
277
- 'cutout': dict(cutout=1),
278
- 'bg': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1),
279
- 'bgc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1),
280
- 'bgcf': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1),
281
- 'bgcfn': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1),
282
- 'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1),
283
- }
284
-
285
- assert augpipe in augpipe_specs
286
- if aug != 'noaug':
287
- args.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', **augpipe_specs[augpipe])
288
-
289
- # ----------------------------------
290
- # Transfer learning: resume, freezed
291
- # ----------------------------------
292
-
293
- resume_specs = {
294
- 'ffhq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
295
- 'ffhq512': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
296
- 'ffhq1024': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
297
- 'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
298
- 'lsundog256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
299
- }
300
-
301
- assert resume is None or isinstance(resume, str)
302
- if resume is None:
303
- resume = 'noresume'
304
- elif resume == 'noresume':
305
- desc += '-noresume'
306
- elif resume in resume_specs:
307
- desc += f'-resume{resume}'
308
- args.resume_pkl = resume_specs[resume] # predefined url
309
- else:
310
- desc += '-resumecustom'
311
- args.resume_pkl = resume # custom path or url
312
-
313
- if resume != 'noresume':
314
- args.ada_kimg = 100 # make ADA react faster at the beginning
315
- args.ema_rampup = None # disable EMA rampup
316
-
317
- if freezed is not None:
318
- assert isinstance(freezed, int)
319
- if not freezed >= 0:
320
- raise UserError('--freezed must be non-negative')
321
- desc += f'-freezed{freezed:d}'
322
- args.D_kwargs.block_kwargs.freeze_layers = freezed
323
-
324
- # -------------------------------------------------
325
- # Performance options: fp32, nhwc, nobench, workers
326
- # -------------------------------------------------
327
-
328
- if fp32 is None:
329
- fp32 = False
330
- assert isinstance(fp32, bool)
331
- if fp32:
332
- args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
333
- args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None
334
-
335
- if nhwc is None:
336
- nhwc = False
337
- assert isinstance(nhwc, bool)
338
- if nhwc:
339
- args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True
340
-
341
- if nobench is None:
342
- nobench = False
343
- assert isinstance(nobench, bool)
344
- if nobench:
345
- args.cudnn_benchmark = False
346
-
347
- if allow_tf32 is None:
348
- allow_tf32 = False
349
- assert isinstance(allow_tf32, bool)
350
- if allow_tf32:
351
- args.allow_tf32 = True
352
-
353
- if workers is not None:
354
- assert isinstance(workers, int)
355
- if not workers >= 1:
356
- raise UserError('--workers must be at least 1')
357
- args.data_loader_kwargs.num_workers = workers
358
-
359
- return desc, args
360
-
361
- #----------------------------------------------------------------------------
362
-
363
- def subprocess_fn(rank, args, temp_dir):
364
- dnnlib.util.Logger(file_name=os.path.join(args.run_dir, 'log.txt'), file_mode='a', should_flush=True)
365
-
366
- # Init torch.distributed.
367
- if args.num_gpus > 1:
368
- init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
369
- if os.name == 'nt':
370
- init_method = 'file:///' + init_file.replace('\\', '/')
371
- torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
372
- else:
373
- init_method = f'file://{init_file}'
374
- torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
375
-
376
- # Init torch_utils.
377
- sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
378
- training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
379
- if rank != 0:
380
- custom_ops.verbosity = 'none'
381
-
382
- # Execute training loop.
383
- training_loop.training_loop(rank=rank, **args)
384
-
385
- #----------------------------------------------------------------------------
386
-
387
- class CommaSeparatedList(click.ParamType):
388
- name = 'list'
389
-
390
- def convert(self, value, param, ctx):
391
- _ = param, ctx
392
- if value is None or value.lower() == 'none' or value == '':
393
- return []
394
- return value.split(',')
395
-
396
- #----------------------------------------------------------------------------
397
-
398
- @click.command()
399
- @click.pass_context
400
-
401
- # General options.
402
- @click.option('--outdir', help='Where to save the results', required=True, metavar='DIR')
403
- @click.option('--gpus', help='Number of GPUs to use [default: 1]', type=int, metavar='INT')
404
- @click.option('--snap', help='Snapshot interval [default: 50 ticks]', type=int, metavar='INT')
405
- @click.option('--metrics', help='Comma-separated list or "none" [default: fid50k_full]', type=CommaSeparatedList())
406
- @click.option('--seed', help='Random seed [default: 0]', type=int, metavar='INT')
407
- @click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)
408
-
409
- # Dataset.
410
- @click.option('--data', help='Training data (directory or zip)', metavar='PATH', required=True)
411
- @click.option('--cond', help='Train conditional model based on dataset labels [default: false]', type=bool, metavar='BOOL')
412
- @click.option('--subset', help='Train with only N images [default: all]', type=int, metavar='INT')
413
- @click.option('--mirror', help='Enable dataset x-flips [default: false]', type=bool, metavar='BOOL')
414
-
415
- # Base config.
416
- @click.option('--cfg', help='Base config [default: auto]', type=click.Choice(['auto', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar']))
417
- @click.option('--gamma', help='Override R1 gamma', type=float)
418
- @click.option('--kimg', help='Override training duration', type=int, metavar='INT')
419
- @click.option('--batch', help='Override batch size', type=int, metavar='INT')
420
-
421
- # Discriminator augmentation.
422
- @click.option('--aug', help='Augmentation mode [default: ada]', type=click.Choice(['noaug', 'ada', 'fixed']))
423
- @click.option('--p', help='Augmentation probability for --aug=fixed', type=float)
424
- @click.option('--target', help='ADA target value for --aug=ada', type=float)
425
- @click.option('--augpipe', help='Augmentation pipeline [default: bgc]', type=click.Choice(['blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc', 'bgcf', 'bgcfn', 'bgcfnc']))
426
-
427
- # Transfer learning.
428
- @click.option('--resume', help='Resume training [default: noresume]', metavar='PKL')
429
- @click.option('--freezed', help='Freeze-D [default: 0 layers]', type=int, metavar='INT')
430
-
431
- # Performance options.
432
- @click.option('--fp32', help='Disable mixed-precision training', type=bool, metavar='BOOL')
433
- @click.option('--nhwc', help='Use NHWC memory format with FP16', type=bool, metavar='BOOL')
434
- @click.option('--nobench', help='Disable cuDNN benchmarking', type=bool, metavar='BOOL')
435
- @click.option('--allow-tf32', help='Allow PyTorch to use TF32 internally', type=bool, metavar='BOOL')
436
- @click.option('--workers', help='Override number of DataLoader workers', type=int, metavar='INT')
437
-
438
- def main(ctx, outdir, dry_run, **config_kwargs):
439
- """Train a GAN using the techniques described in the paper
440
- "Training Generative Adversarial Networks with Limited Data".
441
-
442
- Examples:
443
-
444
- \b
445
- # Train with custom dataset using 1 GPU.
446
- python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1
447
-
448
- \b
449
- # Train class-conditional CIFAR-10 using 2 GPUs.
450
- python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \\
451
- --gpus=2 --cfg=cifar --cond=1
452
-
453
- \b
454
- # Transfer learn MetFaces from FFHQ using 4 GPUs.
455
- python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \\
456
- --gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10
457
-
458
- \b
459
- # Reproduce original StyleGAN2 config F.
460
- python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \\
461
- --gpus=8 --cfg=stylegan2 --mirror=1 --aug=noaug
462
-
463
- \b
464
- Base configs (--cfg):
465
- auto Automatically select reasonable defaults based on resolution
466
- and GPU count. Good starting point for new datasets.
467
- stylegan2 Reproduce results for StyleGAN2 config F at 1024x1024.
468
- paper256 Reproduce results for FFHQ and LSUN Cat at 256x256.
469
- paper512 Reproduce results for BreCaHAD and AFHQ at 512x512.
470
- paper1024 Reproduce results for MetFaces at 1024x1024.
471
- cifar Reproduce results for CIFAR-10 at 32x32.
472
-
473
- \b
474
- Transfer learning source networks (--resume):
475
- ffhq256 FFHQ trained at 256x256 resolution.
476
- ffhq512 FFHQ trained at 512x512 resolution.
477
- ffhq1024 FFHQ trained at 1024x1024 resolution.
478
- celebahq256 CelebA-HQ trained at 256x256 resolution.
479
- lsundog256 LSUN Dog trained at 256x256 resolution.
480
- <PATH or URL> Custom network pickle.
481
- """
482
- dnnlib.util.Logger(should_flush=True)
483
-
484
- # Setup training options.
485
- try:
486
- run_desc, args = setup_training_loop_kwargs(**config_kwargs)
487
- except UserError as err:
488
- ctx.fail(err)
489
-
490
- # Pick output directory.
491
- prev_run_dirs = []
492
- if os.path.isdir(outdir):
493
- prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]
494
- prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
495
- prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
496
- cur_run_id = max(prev_run_ids, default=-1) + 1
497
- args.run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{run_desc}')
498
- assert not os.path.exists(args.run_dir)
499
-
500
- # Print options.
501
- print()
502
- print('Training options:')
503
- print(json.dumps(args, indent=2))
504
- print()
505
- print(f'Output directory: {args.run_dir}')
506
- print(f'Training data: {args.training_set_kwargs.path}')
507
- print(f'Training duration: {args.total_kimg} kimg')
508
- print(f'Number of GPUs: {args.num_gpus}')
509
- print(f'Number of images: {args.training_set_kwargs.max_size}')
510
- print(f'Image resolution: {args.training_set_kwargs.resolution}')
511
- print(f'Conditional model: {args.training_set_kwargs.use_labels}')
512
- print(f'Dataset x-flips: {args.training_set_kwargs.xflip}')
513
- print()
514
-
515
- # Dry run?
516
- if dry_run:
517
- print('Dry run; exiting.')
518
- return
519
-
520
- # Create output directory.
521
- print('Creating output directory...')
522
- os.makedirs(args.run_dir)
523
- with open(os.path.join(args.run_dir, 'training_options.json'), 'wt') as f:
524
- json.dump(args, f, indent=2)
525
-
526
- # Launch processes.
527
- print('Launching processes...')
528
- torch.multiprocessing.set_start_method('spawn')
529
- with tempfile.TemporaryDirectory() as temp_dir:
530
- if args.num_gpus == 1:
531
- subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
532
- else:
533
- torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
534
-
535
- #----------------------------------------------------------------------------
536
-
537
- if __name__ == "__main__":
538
- main() # pylint: disable=no-value-for-parameter
539
-
540
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training/__init__.py DELETED
@@ -1,9 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- # empty
 
 
 
 
 
 
 
 
 
 
training/augment.py DELETED
@@ -1,431 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- import numpy as np
10
- import scipy.signal
11
- import torch
12
- from torch_utils import persistence
13
- from torch_utils import misc
14
- from torch_utils.ops import upfirdn2d
15
- from torch_utils.ops import grid_sample_gradfix
16
- from torch_utils.ops import conv2d_gradfix
17
-
18
- #----------------------------------------------------------------------------
19
- # Coefficients of various wavelet decomposition low-pass filters.
20
-
21
- wavelets = {
22
- 'haar': [0.7071067811865476, 0.7071067811865476],
23
- 'db1': [0.7071067811865476, 0.7071067811865476],
24
- 'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
25
- 'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
26
- 'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
27
- 'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
28
- 'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
29
- 'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
30
- 'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
31
- 'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
32
- 'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
33
- 'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
34
- 'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
35
- 'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
36
- 'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
37
- 'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
38
- }
39
-
40
- #----------------------------------------------------------------------------
41
- # Helpers for constructing transformation matrices.
42
-
43
- def matrix(*rows, device=None):
44
- assert all(len(row) == len(rows[0]) for row in rows)
45
- elems = [x for row in rows for x in row]
46
- ref = [x for x in elems if isinstance(x, torch.Tensor)]
47
- if len(ref) == 0:
48
- return misc.constant(np.asarray(rows), device=device)
49
- assert device is None or device == ref[0].device
50
- elems = [x if isinstance(x, torch.Tensor) else misc.constant(x, shape=ref[0].shape, device=ref[0].device) for x in elems]
51
- return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1))
52
-
53
- def translate2d(tx, ty, **kwargs):
54
- return matrix(
55
- [1, 0, tx],
56
- [0, 1, ty],
57
- [0, 0, 1],
58
- **kwargs)
59
-
60
- def translate3d(tx, ty, tz, **kwargs):
61
- return matrix(
62
- [1, 0, 0, tx],
63
- [0, 1, 0, ty],
64
- [0, 0, 1, tz],
65
- [0, 0, 0, 1],
66
- **kwargs)
67
-
68
- def scale2d(sx, sy, **kwargs):
69
- return matrix(
70
- [sx, 0, 0],
71
- [0, sy, 0],
72
- [0, 0, 1],
73
- **kwargs)
74
-
75
- def scale3d(sx, sy, sz, **kwargs):
76
- return matrix(
77
- [sx, 0, 0, 0],
78
- [0, sy, 0, 0],
79
- [0, 0, sz, 0],
80
- [0, 0, 0, 1],
81
- **kwargs)
82
-
83
- def rotate2d(theta, **kwargs):
84
- return matrix(
85
- [torch.cos(theta), torch.sin(-theta), 0],
86
- [torch.sin(theta), torch.cos(theta), 0],
87
- [0, 0, 1],
88
- **kwargs)
89
-
90
- def rotate3d(v, theta, **kwargs):
91
- vx = v[..., 0]; vy = v[..., 1]; vz = v[..., 2]
92
- s = torch.sin(theta); c = torch.cos(theta); cc = 1 - c
93
- return matrix(
94
- [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0],
95
- [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0],
96
- [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0],
97
- [0, 0, 0, 1],
98
- **kwargs)
99
-
100
- def translate2d_inv(tx, ty, **kwargs):
101
- return translate2d(-tx, -ty, **kwargs)
102
-
103
- def scale2d_inv(sx, sy, **kwargs):
104
- return scale2d(1 / sx, 1 / sy, **kwargs)
105
-
106
- def rotate2d_inv(theta, **kwargs):
107
- return rotate2d(-theta, **kwargs)
108
-
109
- #----------------------------------------------------------------------------
110
- # Versatile image augmentation pipeline from the paper
111
- # "Training Generative Adversarial Networks with Limited Data".
112
- #
113
- # All augmentations are disabled by default; individual augmentations can
114
- # be enabled by setting their probability multipliers to 1.
115
-
116
- @persistence.persistent_class
117
- class AugmentPipe(torch.nn.Module):
118
- def __init__(self,
119
- xflip=0, rotate90=0, xint=0, xint_max=0.125,
120
- scale=0, rotate=0, aniso=0, xfrac=0, scale_std=0.2, rotate_max=1, aniso_std=0.2, xfrac_std=0.125,
121
- brightness=0, contrast=0, lumaflip=0, hue=0, saturation=0, brightness_std=0.2, contrast_std=0.5, hue_max=1, saturation_std=1,
122
- imgfilter=0, imgfilter_bands=[1,1,1,1], imgfilter_std=1,
123
- noise=0, cutout=0, noise_std=0.1, cutout_size=0.5,
124
- ):
125
- super().__init__()
126
- self.register_buffer('p', torch.ones([])) # Overall multiplier for augmentation probability.
127
-
128
- # Pixel blitting.
129
- self.xflip = float(xflip) # Probability multiplier for x-flip.
130
- self.rotate90 = float(rotate90) # Probability multiplier for 90 degree rotations.
131
- self.xint = float(xint) # Probability multiplier for integer translation.
132
- self.xint_max = float(xint_max) # Range of integer translation, relative to image dimensions.
133
-
134
- # General geometric transformations.
135
- self.scale = float(scale) # Probability multiplier for isotropic scaling.
136
- self.rotate = float(rotate) # Probability multiplier for arbitrary rotation.
137
- self.aniso = float(aniso) # Probability multiplier for anisotropic scaling.
138
- self.xfrac = float(xfrac) # Probability multiplier for fractional translation.
139
- self.scale_std = float(scale_std) # Log2 standard deviation of isotropic scaling.
140
- self.rotate_max = float(rotate_max) # Range of arbitrary rotation, 1 = full circle.
141
- self.aniso_std = float(aniso_std) # Log2 standard deviation of anisotropic scaling.
142
- self.xfrac_std = float(xfrac_std) # Standard deviation of frational translation, relative to image dimensions.
143
-
144
- # Color transformations.
145
- self.brightness = float(brightness) # Probability multiplier for brightness.
146
- self.contrast = float(contrast) # Probability multiplier for contrast.
147
- self.lumaflip = float(lumaflip) # Probability multiplier for luma flip.
148
- self.hue = float(hue) # Probability multiplier for hue rotation.
149
- self.saturation = float(saturation) # Probability multiplier for saturation.
150
- self.brightness_std = float(brightness_std) # Standard deviation of brightness.
151
- self.contrast_std = float(contrast_std) # Log2 standard deviation of contrast.
152
- self.hue_max = float(hue_max) # Range of hue rotation, 1 = full circle.
153
- self.saturation_std = float(saturation_std) # Log2 standard deviation of saturation.
154
-
155
- # Image-space filtering.
156
- self.imgfilter = float(imgfilter) # Probability multiplier for image-space filtering.
157
- self.imgfilter_bands = list(imgfilter_bands) # Probability multipliers for individual frequency bands.
158
- self.imgfilter_std = float(imgfilter_std) # Log2 standard deviation of image-space filter amplification.
159
-
160
- # Image-space corruptions.
161
- self.noise = float(noise) # Probability multiplier for additive RGB noise.
162
- self.cutout = float(cutout) # Probability multiplier for cutout.
163
- self.noise_std = float(noise_std) # Standard deviation of additive RGB noise.
164
- self.cutout_size = float(cutout_size) # Size of the cutout rectangle, relative to image dimensions.
165
-
166
- # Setup orthogonal lowpass filter for geometric augmentations.
167
- self.register_buffer('Hz_geom', upfirdn2d.setup_filter(wavelets['sym6']))
168
-
169
- # Construct filter bank for image-space filtering.
170
- Hz_lo = np.asarray(wavelets['sym2']) # H(z)
171
- Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
172
- Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2
173
- Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2
174
- Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i)
175
- for i in range(1, Hz_fbank.shape[0]):
176
- Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
177
- Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
178
- Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
179
- self.register_buffer('Hz_fbank', torch.as_tensor(Hz_fbank, dtype=torch.float32))
180
-
181
- def forward(self, images, debug_percentile=None):
182
- assert isinstance(images, torch.Tensor) and images.ndim == 4
183
- batch_size, num_channels, height, width = images.shape
184
- device = images.device
185
- if debug_percentile is not None:
186
- debug_percentile = torch.as_tensor(debug_percentile, dtype=torch.float32, device=device)
187
-
188
- # -------------------------------------
189
- # Select parameters for pixel blitting.
190
- # -------------------------------------
191
-
192
- # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
193
- I_3 = torch.eye(3, device=device)
194
- G_inv = I_3
195
-
196
- # Apply x-flip with probability (xflip * strength).
197
- if self.xflip > 0:
198
- i = torch.floor(torch.rand([batch_size], device=device) * 2)
199
- i = torch.where(torch.rand([batch_size], device=device) < self.xflip * self.p, i, torch.zeros_like(i))
200
- if debug_percentile is not None:
201
- i = torch.full_like(i, torch.floor(debug_percentile * 2))
202
- G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1)
203
-
204
- # Apply 90 degree rotations with probability (rotate90 * strength).
205
- if self.rotate90 > 0:
206
- i = torch.floor(torch.rand([batch_size], device=device) * 4)
207
- i = torch.where(torch.rand([batch_size], device=device) < self.rotate90 * self.p, i, torch.zeros_like(i))
208
- if debug_percentile is not None:
209
- i = torch.full_like(i, torch.floor(debug_percentile * 4))
210
- G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i)
211
-
212
- # Apply integer translation with probability (xint * strength).
213
- if self.xint > 0:
214
- t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
215
- t = torch.where(torch.rand([batch_size, 1], device=device) < self.xint * self.p, t, torch.zeros_like(t))
216
- if debug_percentile is not None:
217
- t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
218
- G_inv = G_inv @ translate2d_inv(torch.round(t[:,0] * width), torch.round(t[:,1] * height))
219
-
220
- # --------------------------------------------------------
221
- # Select parameters for general geometric transformations.
222
- # --------------------------------------------------------
223
-
224
- # Apply isotropic scaling with probability (scale * strength).
225
- if self.scale > 0:
226
- s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
227
- s = torch.where(torch.rand([batch_size], device=device) < self.scale * self.p, s, torch.ones_like(s))
228
- if debug_percentile is not None:
229
- s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std))
230
- G_inv = G_inv @ scale2d_inv(s, s)
231
-
232
- # Apply pre-rotation with probability p_rot.
233
- p_rot = 1 - torch.sqrt((1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p
234
- if self.rotate > 0:
235
- theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
236
- theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
237
- if debug_percentile is not None:
238
- theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max)
239
- G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling.
240
-
241
- # Apply anisotropic scaling with probability (aniso * strength).
242
- if self.aniso > 0:
243
- s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
244
- s = torch.where(torch.rand([batch_size], device=device) < self.aniso * self.p, s, torch.ones_like(s))
245
- if debug_percentile is not None:
246
- s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std))
247
- G_inv = G_inv @ scale2d_inv(s, 1 / s)
248
-
249
- # Apply post-rotation with probability p_rot.
250
- if self.rotate > 0:
251
- theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.rotate_max
252
- theta = torch.where(torch.rand([batch_size], device=device) < p_rot, theta, torch.zeros_like(theta))
253
- if debug_percentile is not None:
254
- theta = torch.zeros_like(theta)
255
- G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling.
256
-
257
- # Apply fractional translation with probability (xfrac * strength).
258
- if self.xfrac > 0:
259
- t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
260
- t = torch.where(torch.rand([batch_size, 1], device=device) < self.xfrac * self.p, t, torch.zeros_like(t))
261
- if debug_percentile is not None:
262
- t = torch.full_like(t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std)
263
- G_inv = G_inv @ translate2d_inv(t[:,0] * width, t[:,1] * height)
264
-
265
- # ----------------------------------
266
- # Execute geometric transformations.
267
- # ----------------------------------
268
-
269
- # Execute if the transform is not identity.
270
- if G_inv is not I_3:
271
-
272
- # Calculate padding.
273
- cx = (width - 1) / 2
274
- cy = (height - 1) / 2
275
- cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device) # [idx, xyz]
276
- cp = G_inv @ cp.t() # [batch, xyz, idx]
277
- Hz_pad = self.Hz_geom.shape[0] // 4
278
- margin = cp[:, :2, :].permute(1, 0, 2).flatten(1) # [xy, batch * idx]
279
- margin = torch.cat([-margin, margin]).max(dim=1).values # [x0, y0, x1, y1]
280
- margin = margin + misc.constant([Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device)
281
- margin = margin.max(misc.constant([0, 0] * 2, device=device))
282
- margin = margin.min(misc.constant([width-1, height-1] * 2, device=device))
283
- mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)
284
-
285
- # Pad image and adjust origin.
286
- images = torch.nn.functional.pad(input=images, pad=[mx0,mx1,my0,my1], mode='reflect')
287
- G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv
288
-
289
- # Upsample.
290
- images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
291
- G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
292
- G_inv = translate2d(-0.5, -0.5, device=device) @ G_inv @ translate2d_inv(-0.5, -0.5, device=device)
293
-
294
- # Execute transformation.
295
- shape = [batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2]
296
- G_inv = scale2d(2 / images.shape[3], 2 / images.shape[2], device=device) @ G_inv @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
297
- grid = torch.nn.functional.affine_grid(theta=G_inv[:,:2,:], size=shape, align_corners=False)
298
- images = grid_sample_gradfix.grid_sample(images, grid)
299
-
300
- # Downsample and crop.
301
- images = upfirdn2d.downsample2d(x=images, f=self.Hz_geom, down=2, padding=-Hz_pad*2, flip_filter=True)
302
-
303
- # --------------------------------------------
304
- # Select parameters for color transformations.
305
- # --------------------------------------------
306
-
307
- # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
308
- I_4 = torch.eye(4, device=device)
309
- C = I_4
310
-
311
- # Apply brightness with probability (brightness * strength).
312
- if self.brightness > 0:
313
- b = torch.randn([batch_size], device=device) * self.brightness_std
314
- b = torch.where(torch.rand([batch_size], device=device) < self.brightness * self.p, b, torch.zeros_like(b))
315
- if debug_percentile is not None:
316
- b = torch.full_like(b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std)
317
- C = translate3d(b, b, b) @ C
318
-
319
- # Apply contrast with probability (contrast * strength).
320
- if self.contrast > 0:
321
- c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
322
- c = torch.where(torch.rand([batch_size], device=device) < self.contrast * self.p, c, torch.ones_like(c))
323
- if debug_percentile is not None:
324
- c = torch.full_like(c, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std))
325
- C = scale3d(c, c, c) @ C
326
-
327
- # Apply luma flip with probability (lumaflip * strength).
328
- v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis.
329
- if self.lumaflip > 0:
330
- i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2)
331
- i = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.lumaflip * self.p, i, torch.zeros_like(i))
332
- if debug_percentile is not None:
333
- i = torch.full_like(i, torch.floor(debug_percentile * 2))
334
- C = (I_4 - 2 * v.ger(v) * i) @ C # Householder reflection.
335
-
336
- # Apply hue rotation with probability (hue * strength).
337
- if self.hue > 0 and num_channels > 1:
338
- theta = (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
339
- theta = torch.where(torch.rand([batch_size], device=device) < self.hue * self.p, theta, torch.zeros_like(theta))
340
- if debug_percentile is not None:
341
- theta = torch.full_like(theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max)
342
- C = rotate3d(v, theta) @ C # Rotate around v.
343
-
344
- # Apply saturation with probability (saturation * strength).
345
- if self.saturation > 0 and num_channels > 1:
346
- s = torch.exp2(torch.randn([batch_size, 1, 1], device=device) * self.saturation_std)
347
- s = torch.where(torch.rand([batch_size, 1, 1], device=device) < self.saturation * self.p, s, torch.ones_like(s))
348
- if debug_percentile is not None:
349
- s = torch.full_like(s, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std))
350
- C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C
351
-
352
- # ------------------------------
353
- # Execute color transformations.
354
- # ------------------------------
355
-
356
- # Execute if the transform is not identity.
357
- if C is not I_4:
358
- images = images.reshape([batch_size, num_channels, height * width])
359
- if num_channels == 3:
360
- images = C[:, :3, :3] @ images + C[:, :3, 3:]
361
- elif num_channels == 1:
362
- C = C[:, :3, :].mean(dim=1, keepdims=True)
363
- images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
364
- else:
365
- raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
366
- images = images.reshape([batch_size, num_channels, height, width])
367
-
368
- # ----------------------
369
- # Image-space filtering.
370
- # ----------------------
371
-
372
- if self.imgfilter > 0:
373
- num_bands = self.Hz_fbank.shape[0]
374
- assert len(self.imgfilter_bands) == num_bands
375
- expected_power = misc.constant(np.array([10, 1, 1, 1]) / 13, device=device) # Expected power spectrum (1/f).
376
-
377
- # Apply amplification for each band with probability (imgfilter * strength * band_strength).
378
- g = torch.ones([batch_size, num_bands], device=device) # Global gain vector (identity).
379
- for i, band_strength in enumerate(self.imgfilter_bands):
380
- t_i = torch.exp2(torch.randn([batch_size], device=device) * self.imgfilter_std)
381
- t_i = torch.where(torch.rand([batch_size], device=device) < self.imgfilter * self.p * band_strength, t_i, torch.ones_like(t_i))
382
- if debug_percentile is not None:
383
- t_i = torch.full_like(t_i, torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.imgfilter_std)) if band_strength > 0 else torch.ones_like(t_i)
384
- t = torch.ones([batch_size, num_bands], device=device) # Temporary gain vector.
385
- t[:, i] = t_i # Replace i'th element.
386
- t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
387
- g = g * t # Accumulate into global gain.
388
-
389
- # Construct combined amplification filter.
390
- Hz_prime = g @ self.Hz_fbank # [batch, tap]
391
- Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap]
392
- Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]
393
-
394
- # Apply filter.
395
- p = self.Hz_fbank.shape[1] // 2
396
- images = images.reshape([1, batch_size * num_channels, height, width])
397
- images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
398
- images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
399
- images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
400
- images = images.reshape([batch_size, num_channels, height, width])
401
-
402
- # ------------------------
403
- # Image-space corruptions.
404
- # ------------------------
405
-
406
- # Apply additive RGB noise with probability (noise * strength).
407
- if self.noise > 0:
408
- sigma = torch.randn([batch_size, 1, 1, 1], device=device).abs() * self.noise_std
409
- sigma = torch.where(torch.rand([batch_size, 1, 1, 1], device=device) < self.noise * self.p, sigma, torch.zeros_like(sigma))
410
- if debug_percentile is not None:
411
- sigma = torch.full_like(sigma, torch.erfinv(debug_percentile) * self.noise_std)
412
- images = images + torch.randn([batch_size, num_channels, height, width], device=device) * sigma
413
-
414
- # Apply cutout with probability (cutout * strength).
415
- if self.cutout > 0:
416
- size = torch.full([batch_size, 2, 1, 1, 1], self.cutout_size, device=device)
417
- size = torch.where(torch.rand([batch_size, 1, 1, 1, 1], device=device) < self.cutout * self.p, size, torch.zeros_like(size))
418
- center = torch.rand([batch_size, 2, 1, 1, 1], device=device)
419
- if debug_percentile is not None:
420
- size = torch.full_like(size, self.cutout_size)
421
- center = torch.full_like(center, debug_percentile)
422
- coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1])
423
- coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1])
424
- mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2)
425
- mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2)
426
- mask = torch.logical_or(mask_x, mask_y).to(torch.float32)
427
- images = images * mask
428
-
429
- return images
430
-
431
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training/dataset.py DELETED
@@ -1,236 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- import os
10
- import numpy as np
11
- import zipfile
12
- import PIL.Image
13
- import json
14
- import torch
15
- import dnnlib
16
-
17
- try:
18
- import pyspng
19
- except ImportError:
20
- pyspng = None
21
-
22
- #----------------------------------------------------------------------------
23
-
24
- class Dataset(torch.utils.data.Dataset):
25
- def __init__(self,
26
- name, # Name of the dataset.
27
- raw_shape, # Shape of the raw image data (NCHW).
28
- max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
29
- use_labels = False, # Enable conditioning labels? False = label dimension is zero.
30
- xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
31
- random_seed = 0, # Random seed to use when applying max_size.
32
- ):
33
- self._name = name
34
- self._raw_shape = list(raw_shape)
35
- self._use_labels = use_labels
36
- self._raw_labels = None
37
- self._label_shape = None
38
-
39
- # Apply max_size.
40
- self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
41
- if (max_size is not None) and (self._raw_idx.size > max_size):
42
- np.random.RandomState(random_seed).shuffle(self._raw_idx)
43
- self._raw_idx = np.sort(self._raw_idx[:max_size])
44
-
45
- # Apply xflip.
46
- self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
47
- if xflip:
48
- self._raw_idx = np.tile(self._raw_idx, 2)
49
- self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)])
50
-
51
- def _get_raw_labels(self):
52
- if self._raw_labels is None:
53
- self._raw_labels = self._load_raw_labels() if self._use_labels else None
54
- if self._raw_labels is None:
55
- self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
56
- assert isinstance(self._raw_labels, np.ndarray)
57
- assert self._raw_labels.shape[0] == self._raw_shape[0]
58
- assert self._raw_labels.dtype in [np.float32, np.int64]
59
- if self._raw_labels.dtype == np.int64:
60
- assert self._raw_labels.ndim == 1
61
- assert np.all(self._raw_labels >= 0)
62
- return self._raw_labels
63
-
64
- def close(self): # to be overridden by subclass
65
- pass
66
-
67
- def _load_raw_image(self, raw_idx): # to be overridden by subclass
68
- raise NotImplementedError
69
-
70
- def _load_raw_labels(self): # to be overridden by subclass
71
- raise NotImplementedError
72
-
73
- def __getstate__(self):
74
- return dict(self.__dict__, _raw_labels=None)
75
-
76
- def __del__(self):
77
- try:
78
- self.close()
79
- except:
80
- pass
81
-
82
- def __len__(self):
83
- return self._raw_idx.size
84
-
85
- def __getitem__(self, idx):
86
- image = self._load_raw_image(self._raw_idx[idx])
87
- assert isinstance(image, np.ndarray)
88
- assert list(image.shape) == self.image_shape
89
- assert image.dtype == np.uint8
90
- if self._xflip[idx]:
91
- assert image.ndim == 3 # CHW
92
- image = image[:, :, ::-1]
93
- return image.copy(), self.get_label(idx)
94
-
95
- def get_label(self, idx):
96
- label = self._get_raw_labels()[self._raw_idx[idx]]
97
- if label.dtype == np.int64:
98
- onehot = np.zeros(self.label_shape, dtype=np.float32)
99
- onehot[label] = 1
100
- label = onehot
101
- return label.copy()
102
-
103
- def get_details(self, idx):
104
- d = dnnlib.EasyDict()
105
- d.raw_idx = int(self._raw_idx[idx])
106
- d.xflip = (int(self._xflip[idx]) != 0)
107
- d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
108
- return d
109
-
110
- @property
111
- def name(self):
112
- return self._name
113
-
114
- @property
115
- def image_shape(self):
116
- return list(self._raw_shape[1:])
117
-
118
- @property
119
- def num_channels(self):
120
- assert len(self.image_shape) == 3 # CHW
121
- return self.image_shape[0]
122
-
123
- @property
124
- def resolution(self):
125
- assert len(self.image_shape) == 3 # CHW
126
- assert self.image_shape[1] == self.image_shape[2]
127
- return self.image_shape[1]
128
-
129
- @property
130
- def label_shape(self):
131
- if self._label_shape is None:
132
- raw_labels = self._get_raw_labels()
133
- if raw_labels.dtype == np.int64:
134
- self._label_shape = [int(np.max(raw_labels)) + 1]
135
- else:
136
- self._label_shape = raw_labels.shape[1:]
137
- return list(self._label_shape)
138
-
139
- @property
140
- def label_dim(self):
141
- assert len(self.label_shape) == 1
142
- return self.label_shape[0]
143
-
144
- @property
145
- def has_labels(self):
146
- return any(x != 0 for x in self.label_shape)
147
-
148
- @property
149
- def has_onehot_labels(self):
150
- return self._get_raw_labels().dtype == np.int64
151
-
152
- #----------------------------------------------------------------------------
153
-
154
- class ImageFolderDataset(Dataset):
155
- def __init__(self,
156
- path, # Path to directory or zip.
157
- resolution = None, # Ensure specific resolution, None = highest available.
158
- **super_kwargs, # Additional arguments for the Dataset base class.
159
- ):
160
- self._path = path
161
- self._zipfile = None
162
-
163
- if os.path.isdir(self._path):
164
- self._type = 'dir'
165
- self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
166
- elif self._file_ext(self._path) == '.zip':
167
- self._type = 'zip'
168
- self._all_fnames = set(self._get_zipfile().namelist())
169
- else:
170
- raise IOError('Path must point to a directory or zip')
171
-
172
- PIL.Image.init()
173
- self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
174
- if len(self._image_fnames) == 0:
175
- raise IOError('No image files found in the specified path')
176
-
177
- name = os.path.splitext(os.path.basename(self._path))[0]
178
- raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
179
- if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
180
- raise IOError('Image files do not match the specified resolution')
181
- super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
182
-
183
- @staticmethod
184
- def _file_ext(fname):
185
- return os.path.splitext(fname)[1].lower()
186
-
187
- def _get_zipfile(self):
188
- assert self._type == 'zip'
189
- if self._zipfile is None:
190
- self._zipfile = zipfile.ZipFile(self._path)
191
- return self._zipfile
192
-
193
- def _open_file(self, fname):
194
- if self._type == 'dir':
195
- return open(os.path.join(self._path, fname), 'rb')
196
- if self._type == 'zip':
197
- return self._get_zipfile().open(fname, 'r')
198
- return None
199
-
200
- def close(self):
201
- try:
202
- if self._zipfile is not None:
203
- self._zipfile.close()
204
- finally:
205
- self._zipfile = None
206
-
207
- def __getstate__(self):
208
- return dict(super().__getstate__(), _zipfile=None)
209
-
210
- def _load_raw_image(self, raw_idx):
211
- fname = self._image_fnames[raw_idx]
212
- with self._open_file(fname) as f:
213
- if pyspng is not None and self._file_ext(fname) == '.png':
214
- image = pyspng.load(f.read())
215
- else:
216
- image = np.array(PIL.Image.open(f))
217
- if image.ndim == 2:
218
- image = image[:, :, np.newaxis] # HW => HWC
219
- image = image.transpose(2, 0, 1) # HWC => CHW
220
- return image
221
-
222
- def _load_raw_labels(self):
223
- fname = 'dataset.json'
224
- if fname not in self._all_fnames:
225
- return None
226
- with self._open_file(fname) as f:
227
- labels = json.load(f)['labels']
228
- if labels is None:
229
- return None
230
- labels = dict(labels)
231
- labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
232
- labels = np.array(labels)
233
- labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
234
- return labels
235
-
236
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training/loss.py DELETED
@@ -1,133 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- import numpy as np
10
- import torch
11
- from torch_utils import training_stats
12
- from torch_utils import misc
13
- from torch_utils.ops import conv2d_gradfix
14
-
15
- #----------------------------------------------------------------------------
16
-
17
- class Loss:
18
- def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain): # to be overridden by subclass
19
- raise NotImplementedError()
20
-
21
- #----------------------------------------------------------------------------
22
-
23
- class StyleGAN2Loss(Loss):
24
- def __init__(self, device, G_mapping, G_synthesis, D, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2):
25
- super().__init__()
26
- self.device = device
27
- self.G_mapping = G_mapping
28
- self.G_synthesis = G_synthesis
29
- self.D = D
30
- self.augment_pipe = augment_pipe
31
- self.style_mixing_prob = style_mixing_prob
32
- self.r1_gamma = r1_gamma
33
- self.pl_batch_shrink = pl_batch_shrink
34
- self.pl_decay = pl_decay
35
- self.pl_weight = pl_weight
36
- self.pl_mean = torch.zeros([], device=device)
37
-
38
- def run_G(self, z, c, sync):
39
- with misc.ddp_sync(self.G_mapping, sync):
40
- ws = self.G_mapping(z, c)
41
- if self.style_mixing_prob > 0:
42
- with torch.autograd.profiler.record_function('style_mixing'):
43
- cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
44
- cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
45
- ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:]
46
- with misc.ddp_sync(self.G_synthesis, sync):
47
- img = self.G_synthesis(ws)
48
- return img, ws
49
-
50
- def run_D(self, img, c, sync):
51
- if self.augment_pipe is not None:
52
- img = self.augment_pipe(img)
53
- with misc.ddp_sync(self.D, sync):
54
- logits = self.D(img, c)
55
- return logits
56
-
57
- def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):
58
- assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
59
- do_Gmain = (phase in ['Gmain', 'Gboth'])
60
- do_Dmain = (phase in ['Dmain', 'Dboth'])
61
- do_Gpl = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
62
- do_Dr1 = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)
63
-
64
- # Gmain: Maximize logits for generated images.
65
- if do_Gmain:
66
- with torch.autograd.profiler.record_function('Gmain_forward'):
67
- gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl)) # May get synced by Gpl.
68
- gen_logits = self.run_D(gen_img, gen_c, sync=False)
69
- training_stats.report('Loss/scores/fake', gen_logits)
70
- training_stats.report('Loss/signs/fake', gen_logits.sign())
71
- loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
72
- training_stats.report('Loss/G/loss', loss_Gmain)
73
- with torch.autograd.profiler.record_function('Gmain_backward'):
74
- loss_Gmain.mean().mul(gain).backward()
75
-
76
- # Gpl: Apply path length regularization.
77
- if do_Gpl:
78
- with torch.autograd.profiler.record_function('Gpl_forward'):
79
- batch_size = gen_z.shape[0] // self.pl_batch_shrink
80
- gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync)
81
- pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
82
- with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
83
- pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
84
- pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
85
- pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
86
- self.pl_mean.copy_(pl_mean.detach())
87
- pl_penalty = (pl_lengths - pl_mean).square()
88
- training_stats.report('Loss/pl_penalty', pl_penalty)
89
- loss_Gpl = pl_penalty * self.pl_weight
90
- training_stats.report('Loss/G/reg', loss_Gpl)
91
- with torch.autograd.profiler.record_function('Gpl_backward'):
92
- (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()
93
-
94
- # Dmain: Minimize logits for generated images.
95
- loss_Dgen = 0
96
- if do_Dmain:
97
- with torch.autograd.profiler.record_function('Dgen_forward'):
98
- gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
99
- gen_logits = self.run_D(gen_img, gen_c, sync=False) # Gets synced by loss_Dreal.
100
- training_stats.report('Loss/scores/fake', gen_logits)
101
- training_stats.report('Loss/signs/fake', gen_logits.sign())
102
- loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
103
- with torch.autograd.profiler.record_function('Dgen_backward'):
104
- loss_Dgen.mean().mul(gain).backward()
105
-
106
- # Dmain: Maximize logits for real images.
107
- # Dr1: Apply R1 regularization.
108
- if do_Dmain or do_Dr1:
109
- name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
110
- with torch.autograd.profiler.record_function(name + '_forward'):
111
- real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
112
- real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
113
- training_stats.report('Loss/scores/real', real_logits)
114
- training_stats.report('Loss/signs/real', real_logits.sign())
115
-
116
- loss_Dreal = 0
117
- if do_Dmain:
118
- loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
119
- training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)
120
-
121
- loss_Dr1 = 0
122
- if do_Dr1:
123
- with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
124
- r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
125
- r1_penalty = r1_grads.square().sum([1,2,3])
126
- loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
127
- training_stats.report('Loss/r1_penalty', r1_penalty)
128
- training_stats.report('Loss/D/reg', loss_Dr1)
129
-
130
- with torch.autograd.profiler.record_function(name + '_backward'):
131
- (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()
132
-
133
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training/networks.py DELETED
@@ -1,729 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- import numpy as np
10
- import torch
11
- from torch_utils import misc
12
- from torch_utils import persistence
13
- from torch_utils.ops import conv2d_resample
14
- from torch_utils.ops import upfirdn2d
15
- from torch_utils.ops import bias_act
16
- from torch_utils.ops import fma
17
-
18
- #----------------------------------------------------------------------------
19
-
20
- @misc.profiled_function
21
- def normalize_2nd_moment(x, dim=1, eps=1e-8):
22
- return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
23
-
24
- #----------------------------------------------------------------------------
25
-
26
- @misc.profiled_function
27
- def modulated_conv2d(
28
- x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
29
- weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
30
- styles, # Modulation coefficients of shape [batch_size, in_channels].
31
- noise = None, # Optional noise tensor to add to the output activations.
32
- up = 1, # Integer upsampling factor.
33
- down = 1, # Integer downsampling factor.
34
- padding = 0, # Padding with respect to the upsampled image.
35
- resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
36
- demodulate = True, # Apply weight demodulation?
37
- flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
38
- fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation?
39
- ):
40
- batch_size = x.shape[0]
41
- out_channels, in_channels, kh, kw = weight.shape
42
- misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
43
- misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
44
- misc.assert_shape(styles, [batch_size, in_channels]) # [NI]
45
-
46
- # Pre-normalize inputs to avoid FP16 overflow.
47
- if x.dtype == torch.float16 and demodulate:
48
- weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
49
- styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I
50
-
51
- # Calculate per-sample weights and demodulation coefficients.
52
- w = None
53
- dcoefs = None
54
- if demodulate or fused_modconv:
55
- w = weight.unsqueeze(0) # [NOIkk]
56
- w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk]
57
- if demodulate:
58
- dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
59
- if demodulate and fused_modconv:
60
- w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
61
-
62
- # Execute by scaling the activations before and after the convolution.
63
- if not fused_modconv:
64
- x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
65
- x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight)
66
- if demodulate and noise is not None:
67
- x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
68
- elif demodulate:
69
- x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
70
- elif noise is not None:
71
- x = x.add_(noise.to(x.dtype))
72
- return x
73
-
74
- # Execute as one fused op using grouped convolution.
75
- with misc.suppress_tracer_warnings(): # this value will be treated as a constant
76
- batch_size = int(batch_size)
77
- misc.assert_shape(x, [batch_size, in_channels, None, None])
78
- x = x.reshape(1, -1, *x.shape[2:])
79
- w = w.reshape(-1, in_channels, kh, kw)
80
- x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
81
- x = x.reshape(batch_size, -1, *x.shape[2:])
82
- if noise is not None:
83
- x = x.add_(noise)
84
- return x
85
-
86
- #----------------------------------------------------------------------------
87
-
88
- @persistence.persistent_class
89
- class FullyConnectedLayer(torch.nn.Module):
90
- def __init__(self,
91
- in_features, # Number of input features.
92
- out_features, # Number of output features.
93
- bias = True, # Apply additive bias before the activation function?
94
- activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
95
- lr_multiplier = 1, # Learning rate multiplier.
96
- bias_init = 0, # Initial value for the additive bias.
97
- ):
98
- super().__init__()
99
- self.activation = activation
100
- self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
101
- self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
102
- self.weight_gain = lr_multiplier / np.sqrt(in_features)
103
- self.bias_gain = lr_multiplier
104
-
105
- def forward(self, x):
106
- w = self.weight.to(x.dtype) * self.weight_gain
107
- b = self.bias
108
- if b is not None:
109
- b = b.to(x.dtype)
110
- if self.bias_gain != 1:
111
- b = b * self.bias_gain
112
-
113
- if self.activation == 'linear' and b is not None:
114
- x = torch.addmm(b.unsqueeze(0), x, w.t())
115
- else:
116
- x = x.matmul(w.t())
117
- x = bias_act.bias_act(x, b, act=self.activation)
118
- return x
119
-
120
- #----------------------------------------------------------------------------
121
-
122
- @persistence.persistent_class
123
- class Conv2dLayer(torch.nn.Module):
124
- def __init__(self,
125
- in_channels, # Number of input channels.
126
- out_channels, # Number of output channels.
127
- kernel_size, # Width and height of the convolution kernel.
128
- bias = True, # Apply additive bias before the activation function?
129
- activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
130
- up = 1, # Integer upsampling factor.
131
- down = 1, # Integer downsampling factor.
132
- resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
133
- conv_clamp = None, # Clamp the output to +-X, None = disable clamping.
134
- channels_last = False, # Expect the input to have memory_format=channels_last?
135
- trainable = True, # Update the weights of this layer during training?
136
- ):
137
- super().__init__()
138
- self.activation = activation
139
- self.up = up
140
- self.down = down
141
- self.conv_clamp = conv_clamp
142
- self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
143
- self.padding = kernel_size // 2
144
- self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
145
- self.act_gain = bias_act.activation_funcs[activation].def_gain
146
-
147
- memory_format = torch.channels_last if channels_last else torch.contiguous_format
148
- weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
149
- bias = torch.zeros([out_channels]) if bias else None
150
- if trainable:
151
- self.weight = torch.nn.Parameter(weight)
152
- self.bias = torch.nn.Parameter(bias) if bias is not None else None
153
- else:
154
- self.register_buffer('weight', weight)
155
- if bias is not None:
156
- self.register_buffer('bias', bias)
157
- else:
158
- self.bias = None
159
-
160
- def forward(self, x, gain=1):
161
- w = self.weight * self.weight_gain
162
- b = self.bias.to(x.dtype) if self.bias is not None else None
163
- flip_weight = (self.up == 1) # slightly faster
164
- x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight)
165
-
166
- act_gain = self.act_gain * gain
167
- act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
168
- x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
169
- return x
170
-
171
- #----------------------------------------------------------------------------
172
-
173
- @persistence.persistent_class
174
- class MappingNetwork(torch.nn.Module):
175
- def __init__(self,
176
- z_dim, # Input latent (Z) dimensionality, 0 = no latent.
177
- c_dim, # Conditioning label (C) dimensionality, 0 = no label.
178
- w_dim, # Intermediate latent (W) dimensionality.
179
- num_ws, # Number of intermediate latents to output, None = do not broadcast.
180
- num_layers = 8, # Number of mapping layers.
181
- embed_features = None, # Label embedding dimensionality, None = same as w_dim.
182
- layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim.
183
- activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
184
- lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
185
- w_avg_beta = 0.995, # Decay for tracking the moving average of W during training, None = do not track.
186
- ):
187
- super().__init__()
188
- self.z_dim = z_dim
189
- self.c_dim = c_dim
190
- self.w_dim = w_dim
191
- self.num_ws = num_ws
192
- self.num_layers = num_layers
193
- self.w_avg_beta = w_avg_beta
194
-
195
- if embed_features is None:
196
- embed_features = w_dim
197
- if c_dim == 0:
198
- embed_features = 0
199
- if layer_features is None:
200
- layer_features = w_dim
201
- features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
202
-
203
- if c_dim > 0:
204
- self.embed = FullyConnectedLayer(c_dim, embed_features)
205
- for idx in range(num_layers):
206
- in_features = features_list[idx]
207
- out_features = features_list[idx + 1]
208
- layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
209
- setattr(self, f'fc{idx}', layer)
210
-
211
- if num_ws is not None and w_avg_beta is not None:
212
- self.register_buffer('w_avg', torch.zeros([w_dim]))
213
-
214
- def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
215
- # Embed, normalize, and concat inputs.
216
- x = None
217
- with torch.autograd.profiler.record_function('input'):
218
- if self.z_dim > 0:
219
- misc.assert_shape(z, [None, self.z_dim])
220
- x = normalize_2nd_moment(z.to(torch.float32))
221
- if self.c_dim > 0:
222
- misc.assert_shape(c, [None, self.c_dim])
223
- y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
224
- x = torch.cat([x, y], dim=1) if x is not None else y
225
-
226
- # Main layers.
227
- for idx in range(self.num_layers):
228
- layer = getattr(self, f'fc{idx}')
229
- x = layer(x)
230
-
231
- # Update moving average of W.
232
- if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
233
- with torch.autograd.profiler.record_function('update_w_avg'):
234
- self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
235
-
236
- # Broadcast.
237
- if self.num_ws is not None:
238
- with torch.autograd.profiler.record_function('broadcast'):
239
- x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
240
-
241
- # Apply truncation.
242
- if truncation_psi != 1:
243
- with torch.autograd.profiler.record_function('truncate'):
244
- assert self.w_avg_beta is not None
245
- if self.num_ws is None or truncation_cutoff is None:
246
- x = self.w_avg.lerp(x, truncation_psi)
247
- else:
248
- x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
249
- return x
250
-
251
- #----------------------------------------------------------------------------
252
-
253
- @persistence.persistent_class
254
- class SynthesisLayer(torch.nn.Module):
255
- def __init__(self,
256
- in_channels, # Number of input channels.
257
- out_channels, # Number of output channels.
258
- w_dim, # Intermediate latent (W) dimensionality.
259
- resolution, # Resolution of this layer.
260
- kernel_size = 3, # Convolution kernel size.
261
- up = 1, # Integer upsampling factor.
262
- use_noise = True, # Enable noise input?
263
- activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
264
- resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
265
- conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
266
- channels_last = False, # Use channels_last format for the weights?
267
- ):
268
- super().__init__()
269
- self.resolution = resolution
270
- self.up = up
271
- self.use_noise = use_noise
272
- self.activation = activation
273
- self.conv_clamp = conv_clamp
274
- self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
275
- self.padding = kernel_size // 2
276
- self.act_gain = bias_act.activation_funcs[activation].def_gain
277
-
278
- self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
279
- memory_format = torch.channels_last if channels_last else torch.contiguous_format
280
- self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
281
- if use_noise:
282
- self.register_buffer('noise_const', torch.randn([resolution, resolution]))
283
- self.noise_strength = torch.nn.Parameter(torch.zeros([]))
284
- self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
285
-
286
- def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1):
287
- assert noise_mode in ['random', 'const', 'none']
288
- in_resolution = self.resolution // self.up
289
- misc.assert_shape(x, [None, self.weight.shape[1], in_resolution, in_resolution])
290
- styles = self.affine(w)
291
-
292
- noise = None
293
- if self.use_noise and noise_mode == 'random':
294
- noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
295
- if self.use_noise and noise_mode == 'const':
296
- noise = self.noise_const * self.noise_strength
297
-
298
- flip_weight = (self.up == 1) # slightly faster
299
- x = modulated_conv2d(x=x, weight=self.weight, styles=styles, noise=noise, up=self.up,
300
- padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv)
301
-
302
- act_gain = self.act_gain * gain
303
- act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
304
- x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
305
- return x
306
-
307
- #----------------------------------------------------------------------------
308
-
309
- @persistence.persistent_class
310
- class ToRGBLayer(torch.nn.Module):
311
- def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
312
- super().__init__()
313
- self.conv_clamp = conv_clamp
314
- self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
315
- memory_format = torch.channels_last if channels_last else torch.contiguous_format
316
- self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
317
- self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
318
- self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
319
-
320
- def forward(self, x, w, fused_modconv=True):
321
- styles = self.affine(w) * self.weight_gain
322
- x = modulated_conv2d(x=x, weight=self.weight, styles=styles, demodulate=False, fused_modconv=fused_modconv)
323
- x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
324
- return x
325
-
326
- #----------------------------------------------------------------------------
327
-
328
- @persistence.persistent_class
329
- class SynthesisBlock(torch.nn.Module):
330
- def __init__(self,
331
- in_channels, # Number of input channels, 0 = first block.
332
- out_channels, # Number of output channels.
333
- w_dim, # Intermediate latent (W) dimensionality.
334
- resolution, # Resolution of this block.
335
- img_channels, # Number of output color channels.
336
- is_last, # Is this the last block?
337
- architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'.
338
- resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
339
- conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
340
- use_fp16 = False, # Use FP16 for this block?
341
- fp16_channels_last = False, # Use channels-last memory format with FP16?
342
- **layer_kwargs, # Arguments for SynthesisLayer.
343
- ):
344
- assert architecture in ['orig', 'skip', 'resnet']
345
- super().__init__()
346
- self.in_channels = in_channels
347
- self.w_dim = w_dim
348
- self.resolution = resolution
349
- self.img_channels = img_channels
350
- self.is_last = is_last
351
- self.architecture = architecture
352
- self.use_fp16 = use_fp16
353
- self.channels_last = (use_fp16 and fp16_channels_last)
354
- self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
355
- self.num_conv = 0
356
- self.num_torgb = 0
357
-
358
- if in_channels == 0:
359
- self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))
360
-
361
- if in_channels != 0:
362
- self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2,
363
- resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
364
- self.num_conv += 1
365
-
366
- self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution,
367
- conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
368
- self.num_conv += 1
369
-
370
- if is_last or architecture == 'skip':
371
- self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
372
- conv_clamp=conv_clamp, channels_last=self.channels_last)
373
- self.num_torgb += 1
374
-
375
- if in_channels != 0 and architecture == 'resnet':
376
- self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
377
- resample_filter=resample_filter, channels_last=self.channels_last)
378
-
379
- def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, **layer_kwargs):
380
- misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
381
- w_iter = iter(ws.unbind(dim=1))
382
- dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
383
- memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
384
- if fused_modconv is None:
385
- with misc.suppress_tracer_warnings(): # this value will be treated as a constant
386
- fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)
387
-
388
- # Input.
389
- if self.in_channels == 0:
390
- x = self.const.to(dtype=dtype, memory_format=memory_format)
391
- x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
392
- else:
393
- misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
394
- x = x.to(dtype=dtype, memory_format=memory_format)
395
-
396
- # Main layers.
397
- if self.in_channels == 0:
398
- x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
399
- elif self.architecture == 'resnet':
400
- y = self.skip(x, gain=np.sqrt(0.5))
401
- x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
402
- x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
403
- x = y.add_(x)
404
- else:
405
- x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
406
- x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
407
-
408
- # ToRGB.
409
- if img is not None:
410
- misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
411
- img = upfirdn2d.upsample2d(img, self.resample_filter)
412
- if self.is_last or self.architecture == 'skip':
413
- y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
414
- y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
415
- img = img.add_(y) if img is not None else y
416
-
417
- assert x.dtype == dtype
418
- assert img is None or img.dtype == torch.float32
419
- return x, img
420
-
421
- #----------------------------------------------------------------------------
422
-
423
- @persistence.persistent_class
424
- class SynthesisNetwork(torch.nn.Module):
425
- def __init__(self,
426
- w_dim, # Intermediate latent (W) dimensionality.
427
- img_resolution, # Output image resolution.
428
- img_channels, # Number of color channels.
429
- channel_base = 32768, # Overall multiplier for the number of channels.
430
- channel_max = 512, # Maximum number of channels in any layer.
431
- num_fp16_res = 0, # Use FP16 for the N highest resolutions.
432
- **block_kwargs, # Arguments for SynthesisBlock.
433
- ):
434
- assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
435
- super().__init__()
436
- self.w_dim = w_dim
437
- self.img_resolution = img_resolution
438
- self.img_resolution_log2 = int(np.log2(img_resolution))
439
- self.img_channels = img_channels
440
- self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
441
- channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
442
- fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
443
-
444
- self.num_ws = 0
445
- for res in self.block_resolutions:
446
- in_channels = channels_dict[res // 2] if res > 4 else 0
447
- out_channels = channels_dict[res]
448
- use_fp16 = (res >= fp16_resolution)
449
- is_last = (res == self.img_resolution)
450
- block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
451
- img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
452
- self.num_ws += block.num_conv
453
- if is_last:
454
- self.num_ws += block.num_torgb
455
- setattr(self, f'b{res}', block)
456
-
457
- def forward(self, ws, **block_kwargs):
458
- block_ws = []
459
- with torch.autograd.profiler.record_function('split_ws'):
460
- misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
461
- ws = ws.to(torch.float32)
462
- w_idx = 0
463
- for res in self.block_resolutions:
464
- block = getattr(self, f'b{res}')
465
- block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
466
- w_idx += block.num_conv
467
-
468
- x = img = None
469
- for res, cur_ws in zip(self.block_resolutions, block_ws):
470
- block = getattr(self, f'b{res}')
471
- x, img = block(x, img, cur_ws, **block_kwargs)
472
- return img
473
-
474
- #----------------------------------------------------------------------------
475
-
476
- @persistence.persistent_class
477
- class Generator(torch.nn.Module):
478
- def __init__(self,
479
- z_dim, # Input latent (Z) dimensionality.
480
- c_dim, # Conditioning label (C) dimensionality.
481
- w_dim, # Intermediate latent (W) dimensionality.
482
- img_resolution, # Output resolution.
483
- img_channels, # Number of output color channels.
484
- mapping_kwargs = {}, # Arguments for MappingNetwork.
485
- synthesis_kwargs = {}, # Arguments for SynthesisNetwork.
486
- ):
487
- super().__init__()
488
- self.z_dim = z_dim
489
- self.c_dim = c_dim
490
- self.w_dim = w_dim
491
- self.img_resolution = img_resolution
492
- self.img_channels = img_channels
493
- self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
494
- self.num_ws = self.synthesis.num_ws
495
- self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
496
-
497
- def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, **synthesis_kwargs):
498
- ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
499
- img = self.synthesis(ws, **synthesis_kwargs)
500
- return img
501
-
502
- #----------------------------------------------------------------------------
503
-
504
- @persistence.persistent_class
505
- class DiscriminatorBlock(torch.nn.Module):
506
- def __init__(self,
507
- in_channels, # Number of input channels, 0 = first block.
508
- tmp_channels, # Number of intermediate channels.
509
- out_channels, # Number of output channels.
510
- resolution, # Resolution of this block.
511
- img_channels, # Number of input color channels.
512
- first_layer_idx, # Index of the first layer.
513
- architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
514
- activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
515
- resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
516
- conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
517
- use_fp16 = False, # Use FP16 for this block?
518
- fp16_channels_last = False, # Use channels-last memory format with FP16?
519
- freeze_layers = 0, # Freeze-D: Number of layers to freeze.
520
- ):
521
- assert in_channels in [0, tmp_channels]
522
- assert architecture in ['orig', 'skip', 'resnet']
523
- super().__init__()
524
- self.in_channels = in_channels
525
- self.resolution = resolution
526
- self.img_channels = img_channels
527
- self.first_layer_idx = first_layer_idx
528
- self.architecture = architecture
529
- self.use_fp16 = use_fp16
530
- self.channels_last = (use_fp16 and fp16_channels_last)
531
- self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
532
-
533
- self.num_layers = 0
534
- def trainable_gen():
535
- while True:
536
- layer_idx = self.first_layer_idx + self.num_layers
537
- trainable = (layer_idx >= freeze_layers)
538
- self.num_layers += 1
539
- yield trainable
540
- trainable_iter = trainable_gen()
541
-
542
- if in_channels == 0 or architecture == 'skip':
543
- self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation,
544
- trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
545
-
546
- self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
547
- trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
548
-
549
- self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
550
- trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last)
551
-
552
- if architecture == 'resnet':
553
- self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
554
- trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last)
555
-
556
- def forward(self, x, img, force_fp32=False):
557
- dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
558
- memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
559
-
560
- # Input.
561
- if x is not None:
562
- misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
563
- x = x.to(dtype=dtype, memory_format=memory_format)
564
-
565
- # FromRGB.
566
- if self.in_channels == 0 or self.architecture == 'skip':
567
- misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
568
- img = img.to(dtype=dtype, memory_format=memory_format)
569
- y = self.fromrgb(img)
570
- x = x + y if x is not None else y
571
- img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None
572
-
573
- # Main layers.
574
- if self.architecture == 'resnet':
575
- y = self.skip(x, gain=np.sqrt(0.5))
576
- x = self.conv0(x)
577
- x = self.conv1(x, gain=np.sqrt(0.5))
578
- x = y.add_(x)
579
- else:
580
- x = self.conv0(x)
581
- x = self.conv1(x)
582
-
583
- assert x.dtype == dtype
584
- return x, img
585
-
586
- #----------------------------------------------------------------------------
587
-
588
- @persistence.persistent_class
589
- class MinibatchStdLayer(torch.nn.Module):
590
- def __init__(self, group_size, num_channels=1):
591
- super().__init__()
592
- self.group_size = group_size
593
- self.num_channels = num_channels
594
-
595
- def forward(self, x):
596
- N, C, H, W = x.shape
597
- with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants
598
- G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N
599
- F = self.num_channels
600
- c = C // F
601
-
602
- y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
603
- y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
604
- y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
605
- y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
606
- y = y.mean(dim=[2,3,4]) # [nF] Take average over channels and pixels.
607
- y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
608
- y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
609
- x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
610
- return x
611
-
612
- #----------------------------------------------------------------------------
613
-
614
- @persistence.persistent_class
615
- class DiscriminatorEpilogue(torch.nn.Module):
616
- def __init__(self,
617
- in_channels, # Number of input channels.
618
- cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.
619
- resolution, # Resolution of this block.
620
- img_channels, # Number of input color channels.
621
- architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
622
- mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, None = entire minibatch.
623
- mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable.
624
- activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
625
- conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
626
- ):
627
- assert architecture in ['orig', 'skip', 'resnet']
628
- super().__init__()
629
- self.in_channels = in_channels
630
- self.cmap_dim = cmap_dim
631
- self.resolution = resolution
632
- self.img_channels = img_channels
633
- self.architecture = architecture
634
-
635
- if architecture == 'skip':
636
- self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation)
637
- self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
638
- self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp)
639
- self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation)
640
- self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim)
641
-
642
- def forward(self, x, img, cmap, force_fp32=False):
643
- misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW]
644
- _ = force_fp32 # unused
645
- dtype = torch.float32
646
- memory_format = torch.contiguous_format
647
-
648
- # FromRGB.
649
- x = x.to(dtype=dtype, memory_format=memory_format)
650
- if self.architecture == 'skip':
651
- misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
652
- img = img.to(dtype=dtype, memory_format=memory_format)
653
- x = x + self.fromrgb(img)
654
-
655
- # Main layers.
656
- if self.mbstd is not None:
657
- x = self.mbstd(x)
658
- x = self.conv(x)
659
- x = self.fc(x.flatten(1))
660
- x = self.out(x)
661
-
662
- # Conditioning.
663
- if self.cmap_dim > 0:
664
- misc.assert_shape(cmap, [None, self.cmap_dim])
665
- x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
666
-
667
- assert x.dtype == dtype
668
- return x
669
-
670
- #----------------------------------------------------------------------------
671
-
672
- @persistence.persistent_class
673
- class Discriminator(torch.nn.Module):
674
- def __init__(self,
675
- c_dim, # Conditioning label (C) dimensionality.
676
- img_resolution, # Input resolution.
677
- img_channels, # Number of input color channels.
678
- architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
679
- channel_base = 32768, # Overall multiplier for the number of channels.
680
- channel_max = 512, # Maximum number of channels in any layer.
681
- num_fp16_res = 0, # Use FP16 for the N highest resolutions.
682
- conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
683
- cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
684
- block_kwargs = {}, # Arguments for DiscriminatorBlock.
685
- mapping_kwargs = {}, # Arguments for MappingNetwork.
686
- epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
687
- ):
688
- super().__init__()
689
- self.c_dim = c_dim
690
- self.img_resolution = img_resolution
691
- self.img_resolution_log2 = int(np.log2(img_resolution))
692
- self.img_channels = img_channels
693
- self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
694
- channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
695
- fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
696
-
697
- if cmap_dim is None:
698
- cmap_dim = channels_dict[4]
699
- if c_dim == 0:
700
- cmap_dim = 0
701
-
702
- common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
703
- cur_layer_idx = 0
704
- for res in self.block_resolutions:
705
- in_channels = channels_dict[res] if res < img_resolution else 0
706
- tmp_channels = channels_dict[res]
707
- out_channels = channels_dict[res // 2]
708
- use_fp16 = (res >= fp16_resolution)
709
- block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
710
- first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
711
- setattr(self, f'b{res}', block)
712
- cur_layer_idx += block.num_layers
713
- if c_dim > 0:
714
- self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
715
- self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
716
-
717
- def forward(self, img, c, **block_kwargs):
718
- x = None
719
- for res in self.block_resolutions:
720
- block = getattr(self, f'b{res}')
721
- x, img = block(x, img, **block_kwargs)
722
-
723
- cmap = None
724
- if self.c_dim > 0:
725
- cmap = self.mapping(None, c)
726
- x = self.b4(x, img, cmap)
727
- return x
728
-
729
- #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
training/training_loop.py DELETED
@@ -1,421 +0,0 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
- #
3
- # NVIDIA CORPORATION and its licensors retain all intellectual property
4
- # and proprietary rights in and to this software, related documentation
5
- # and any modifications thereto. Any use, reproduction, disclosure or
6
- # distribution of this software and related documentation without an express
7
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
-
9
- import os
10
- import time
11
- import copy
12
- import json
13
- import pickle
14
- import psutil
15
- import PIL.Image
16
- import numpy as np
17
- import torch
18
- import dnnlib
19
- from torch_utils import misc
20
- from torch_utils import training_stats
21
- from torch_utils.ops import conv2d_gradfix
22
- from torch_utils.ops import grid_sample_gradfix
23
-
24
- import legacy
25
- from metrics import metric_main
26
-
27
- #----------------------------------------------------------------------------
28
-
29
- def setup_snapshot_image_grid(training_set, random_seed=0):
30
- rnd = np.random.RandomState(random_seed)
31
- gw = np.clip(7680 // training_set.image_shape[2], 7, 32)
32
- gh = np.clip(4320 // training_set.image_shape[1], 4, 32)
33
-
34
- # No labels => show random subset of training samples.
35
- if not training_set.has_labels:
36
- all_indices = list(range(len(training_set)))
37
- rnd.shuffle(all_indices)
38
- grid_indices = [all_indices[i % len(all_indices)] for i in range(gw * gh)]
39
-
40
- else:
41
- # Group training samples by label.
42
- label_groups = dict() # label => [idx, ...]
43
- for idx in range(len(training_set)):
44
- label = tuple(training_set.get_details(idx).raw_label.flat[::-1])
45
- if label not in label_groups:
46
- label_groups[label] = []
47
- label_groups[label].append(idx)
48
-
49
- # Reorder.
50
- label_order = sorted(label_groups.keys())
51
- for label in label_order:
52
- rnd.shuffle(label_groups[label])
53
-
54
- # Organize into grid.
55
- grid_indices = []
56
- for y in range(gh):
57
- label = label_order[y % len(label_order)]
58
- indices = label_groups[label]
59
- grid_indices += [indices[x % len(indices)] for x in range(gw)]
60
- label_groups[label] = [indices[(i + gw) % len(indices)] for i in range(len(indices))]
61
-
62
- # Load data.
63
- images, labels = zip(*[training_set[i] for i in grid_indices])
64
- return (gw, gh), np.stack(images), np.stack(labels)
65
-
66
- #----------------------------------------------------------------------------
67
-
68
- def save_image_grid(img, fname, drange, grid_size):
69
- lo, hi = drange
70
- img = np.asarray(img, dtype=np.float32)
71
- img = (img - lo) * (255 / (hi - lo))
72
- img = np.rint(img).clip(0, 255).astype(np.uint8)
73
-
74
- gw, gh = grid_size
75
- _N, C, H, W = img.shape
76
- img = img.reshape(gh, gw, C, H, W)
77
- img = img.transpose(0, 3, 1, 4, 2)
78
- img = img.reshape(gh * H, gw * W, C)
79
-
80
- assert C in [1, 3]
81
- if C == 1:
82
- PIL.Image.fromarray(img[:, :, 0], 'L').save(fname)
83
- if C == 3:
84
- PIL.Image.fromarray(img, 'RGB').save(fname)
85
-
86
- #----------------------------------------------------------------------------
87
-
88
- def training_loop(
89
- run_dir = '.', # Output directory.
90
- training_set_kwargs = {}, # Options for training set.
91
- data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader.
92
- G_kwargs = {}, # Options for generator network.
93
- D_kwargs = {}, # Options for discriminator network.
94
- G_opt_kwargs = {}, # Options for generator optimizer.
95
- D_opt_kwargs = {}, # Options for discriminator optimizer.
96
- augment_kwargs = None, # Options for augmentation pipeline. None = disable.
97
- loss_kwargs = {}, # Options for loss function.
98
- metrics = [], # Metrics to evaluate during training.
99
- random_seed = 0, # Global random seed.
100
- num_gpus = 1, # Number of GPUs participating in the training.
101
- rank = 0, # Rank of the current process in [0, num_gpus[.
102
- batch_size = 4, # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus.
103
- batch_gpu = 4, # Number of samples processed at a time by one GPU.
104
- ema_kimg = 10, # Half-life of the exponential moving average (EMA) of generator weights.
105
- ema_rampup = None, # EMA ramp-up coefficient.
106
- G_reg_interval = 4, # How often to perform regularization for G? None = disable lazy regularization.
107
- D_reg_interval = 16, # How often to perform regularization for D? None = disable lazy regularization.
108
- augment_p = 0, # Initial value of augmentation probability.
109
- ada_target = None, # ADA target value. None = fixed p.
110
- ada_interval = 4, # How often to perform ADA adjustment?
111
- ada_kimg = 500, # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit.
112
- total_kimg = 25000, # Total length of the training, measured in thousands of real images.
113
- kimg_per_tick = 4, # Progress snapshot interval.
114
- image_snapshot_ticks = 50, # How often to save image snapshots? None = disable.
115
- network_snapshot_ticks = 50, # How often to save network snapshots? None = disable.
116
- resume_pkl = None, # Network pickle to resume training from.
117
- cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark?
118
- allow_tf32 = False, # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32?
119
- abort_fn = None, # Callback function for determining whether to abort training. Must return consistent results across ranks.
120
- progress_fn = None, # Callback function for updating training progress. Called for all ranks.
121
- ):
122
- # Initialize.
123
- start_time = time.time()
124
- device = torch.device('cuda', rank)
125
- np.random.seed(random_seed * num_gpus + rank)
126
- torch.manual_seed(random_seed * num_gpus + rank)
127
- torch.backends.cudnn.benchmark = cudnn_benchmark # Improves training speed.
128
- torch.backends.cuda.matmul.allow_tf32 = allow_tf32 # Allow PyTorch to internally use tf32 for matmul
129
- torch.backends.cudnn.allow_tf32 = allow_tf32 # Allow PyTorch to internally use tf32 for convolutions
130
- conv2d_gradfix.enabled = True # Improves training speed.
131
- grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe.
132
-
133
- # Load training set.
134
- if rank == 0:
135
- print('Loading training set...')
136
- training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset
137
- training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed)
138
- training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs))
139
- if rank == 0:
140
- print()
141
- print('Num images: ', len(training_set))
142
- print('Image shape:', training_set.image_shape)
143
- print('Label shape:', training_set.label_shape)
144
- print()
145
-
146
- # Construct networks.
147
- if rank == 0:
148
- print('Constructing networks...')
149
- common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels)
150
- G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
151
- D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
152
- G_ema = copy.deepcopy(G).eval()
153
-
154
- # Resume from existing pickle.
155
- if (resume_pkl is not None) and (rank == 0):
156
- print(f'Resuming from "{resume_pkl}"')
157
- with dnnlib.util.open_url(resume_pkl) as f:
158
- resume_data = legacy.load_network_pkl(f)
159
- for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
160
- misc.copy_params_and_buffers(resume_data[name], module, require_all=False)
161
-
162
- # Print network summary tables.
163
- if rank == 0:
164
- z = torch.empty([batch_gpu, G.z_dim], device=device)
165
- c = torch.empty([batch_gpu, G.c_dim], device=device)
166
- img = misc.print_module_summary(G, [z, c])
167
- misc.print_module_summary(D, [img, c])
168
-
169
- # Setup augmentation.
170
- if rank == 0:
171
- print('Setting up augmentation...')
172
- augment_pipe = None
173
- ada_stats = None
174
- if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None):
175
- augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
176
- augment_pipe.p.copy_(torch.as_tensor(augment_p))
177
- if ada_target is not None:
178
- ada_stats = training_stats.Collector(regex='Loss/signs/real')
179
-
180
- # Distribute across GPUs.
181
- if rank == 0:
182
- print(f'Distributing across {num_gpus} GPUs...')
183
- ddp_modules = dict()
184
- for name, module in [('G_mapping', G.mapping), ('G_synthesis', G.synthesis), ('D', D), (None, G_ema), ('augment_pipe', augment_pipe)]:
185
- if (num_gpus > 1) and (module is not None) and len(list(module.parameters())) != 0:
186
- module.requires_grad_(True)
187
- module = torch.nn.parallel.DistributedDataParallel(module, device_ids=[device], broadcast_buffers=False)
188
- module.requires_grad_(False)
189
- if name is not None:
190
- ddp_modules[name] = module
191
-
192
- # Setup training phases.
193
- if rank == 0:
194
- print('Setting up training phases...')
195
- loss = dnnlib.util.construct_class_by_name(device=device, **ddp_modules, **loss_kwargs) # subclass of training.loss.Loss
196
- phases = []
197
- for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]:
198
- if reg_interval is None:
199
- opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
200
- phases += [dnnlib.EasyDict(name=name+'both', module=module, opt=opt, interval=1)]
201
- else: # Lazy regularization.
202
- mb_ratio = reg_interval / (reg_interval + 1)
203
- opt_kwargs = dnnlib.EasyDict(opt_kwargs)
204
- opt_kwargs.lr = opt_kwargs.lr * mb_ratio
205
- opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas]
206
- opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer
207
- phases += [dnnlib.EasyDict(name=name+'main', module=module, opt=opt, interval=1)]
208
- phases += [dnnlib.EasyDict(name=name+'reg', module=module, opt=opt, interval=reg_interval)]
209
- for phase in phases:
210
- phase.start_event = None
211
- phase.end_event = None
212
- if rank == 0:
213
- phase.start_event = torch.cuda.Event(enable_timing=True)
214
- phase.end_event = torch.cuda.Event(enable_timing=True)
215
-
216
- # Export sample images.
217
- grid_size = None
218
- grid_z = None
219
- grid_c = None
220
- if rank == 0:
221
- print('Exporting sample images...')
222
- grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set)
223
- save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size)
224
- grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu)
225
- grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
226
- images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
227
- save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
228
-
229
- # Initialize logs.
230
- if rank == 0:
231
- print('Initializing logs...')
232
- stats_collector = training_stats.Collector(regex='.*')
233
- stats_metrics = dict()
234
- stats_jsonl = None
235
- stats_tfevents = None
236
- if rank == 0:
237
- stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt')
238
- try:
239
- import torch.utils.tensorboard as tensorboard
240
- stats_tfevents = tensorboard.SummaryWriter(run_dir)
241
- except ImportError as err:
242
- print('Skipping tfevents export:', err)
243
-
244
- # Train.
245
- if rank == 0:
246
- print(f'Training for {total_kimg} kimg...')
247
- print()
248
- cur_nimg = 0
249
- cur_tick = 0
250
- tick_start_nimg = cur_nimg
251
- tick_start_time = time.time()
252
- maintenance_time = tick_start_time - start_time
253
- batch_idx = 0
254
- if progress_fn is not None:
255
- progress_fn(0, total_kimg)
256
- while True:
257
-
258
- # Fetch training data.
259
- with torch.autograd.profiler.record_function('data_fetch'):
260
- phase_real_img, phase_real_c = next(training_set_iterator)
261
- phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
262
- phase_real_c = phase_real_c.to(device).split(batch_gpu)
263
- all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device)
264
- all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)]
265
- all_gen_c = [training_set.get_label(np.random.randint(len(training_set))) for _ in range(len(phases) * batch_size)]
266
- all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device)
267
- all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)]
268
-
269
- # Execute training phases.
270
- for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c):
271
- if batch_idx % phase.interval != 0:
272
- continue
273
-
274
- # Initialize gradient accumulation.
275
- if phase.start_event is not None:
276
- phase.start_event.record(torch.cuda.current_stream(device))
277
- phase.opt.zero_grad(set_to_none=True)
278
- phase.module.requires_grad_(True)
279
-
280
- # Accumulate gradients over multiple rounds.
281
- for round_idx, (real_img, real_c, gen_z, gen_c) in enumerate(zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c)):
282
- sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1)
283
- gain = phase.interval
284
- loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, sync=sync, gain=gain)
285
-
286
- # Update weights.
287
- phase.module.requires_grad_(False)
288
- with torch.autograd.profiler.record_function(phase.name + '_opt'):
289
- for param in phase.module.parameters():
290
- if param.grad is not None:
291
- misc.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
292
- phase.opt.step()
293
- if phase.end_event is not None:
294
- phase.end_event.record(torch.cuda.current_stream(device))
295
-
296
- # Update G_ema.
297
- with torch.autograd.profiler.record_function('Gema'):
298
- ema_nimg = ema_kimg * 1000
299
- if ema_rampup is not None:
300
- ema_nimg = min(ema_nimg, cur_nimg * ema_rampup)
301
- ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8))
302
- for p_ema, p in zip(G_ema.parameters(), G.parameters()):
303
- p_ema.copy_(p.lerp(p_ema, ema_beta))
304
- for b_ema, b in zip(G_ema.buffers(), G.buffers()):
305
- b_ema.copy_(b)
306
-
307
- # Update state.
308
- cur_nimg += batch_size
309
- batch_idx += 1
310
-
311
- # Execute ADA heuristic.
312
- if (ada_stats is not None) and (batch_idx % ada_interval == 0):
313
- ada_stats.update()
314
- adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (ada_kimg * 1000)
315
- augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device)))
316
-
317
- # Perform maintenance tasks once per tick.
318
- done = (cur_nimg >= total_kimg * 1000)
319
- if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
320
- continue
321
-
322
- # Print status line, accumulating the same information in stats_collector.
323
- tick_end_time = time.time()
324
- fields = []
325
- fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
326
- fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"]
327
- fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
328
- fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
329
- fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
330
- fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
331
- fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
332
- fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
333
- torch.cuda.reset_peak_memory_stats()
334
- fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"]
335
- training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60))
336
- training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60))
337
- if rank == 0:
338
- print(' '.join(fields))
339
-
340
- # Check for abort.
341
- if (not done) and (abort_fn is not None) and abort_fn():
342
- done = True
343
- if rank == 0:
344
- print()
345
- print('Aborting...')
346
-
347
- # Save image snapshot.
348
- if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
349
- images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
350
- save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size)
351
-
352
- # Save network snapshot.
353
- snapshot_pkl = None
354
- snapshot_data = None
355
- if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0):
356
- snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs))
357
- for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe)]:
358
- if module is not None:
359
- if num_gpus > 1:
360
- misc.check_ddp_consistency(module, ignore_regex=r'.*\.w_avg')
361
- module = copy.deepcopy(module).eval().requires_grad_(False).cpu()
362
- snapshot_data[name] = module
363
- del module # conserve memory
364
- snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl')
365
- if rank == 0:
366
- with open(snapshot_pkl, 'wb') as f:
367
- pickle.dump(snapshot_data, f)
368
-
369
- # Evaluate metrics.
370
- if (snapshot_data is not None) and (len(metrics) > 0):
371
- if rank == 0:
372
- print('Evaluating metrics...')
373
- for metric in metrics:
374
- result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'],
375
- dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device)
376
- if rank == 0:
377
- metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
378
- stats_metrics.update(result_dict.results)
379
- del snapshot_data # conserve memory
380
-
381
- # Collect statistics.
382
- for phase in phases:
383
- value = []
384
- if (phase.start_event is not None) and (phase.end_event is not None):
385
- phase.end_event.synchronize()
386
- value = phase.start_event.elapsed_time(phase.end_event)
387
- training_stats.report0('Timing/' + phase.name, value)
388
- stats_collector.update()
389
- stats_dict = stats_collector.as_dict()
390
-
391
- # Update logs.
392
- timestamp = time.time()
393
- if stats_jsonl is not None:
394
- fields = dict(stats_dict, timestamp=timestamp)
395
- stats_jsonl.write(json.dumps(fields) + '\n')
396
- stats_jsonl.flush()
397
- if stats_tfevents is not None:
398
- global_step = int(cur_nimg / 1e3)
399
- walltime = timestamp - start_time
400
- for name, value in stats_dict.items():
401
- stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime)
402
- for name, value in stats_metrics.items():
403
- stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime)
404
- stats_tfevents.flush()
405
- if progress_fn is not None:
406
- progress_fn(cur_nimg // 1000, total_kimg)
407
-
408
- # Update state.
409
- cur_tick += 1
410
- tick_start_nimg = cur_nimg
411
- tick_start_time = time.time()
412
- maintenance_time = tick_start_time - tick_end_time
413
- if done:
414
- break
415
-
416
- # Done.
417
- if rank == 0:
418
- print()
419
- print('Exiting...')
420
-
421
- #----------------------------------------------------------------------------