Zhendong committed on
Commit 2e04998 · 1 Parent(s): 0e1004f

Initial Commit

Files changed (50)
  1. README.md +172 -3
  2. checkpoints/diffusion-insgen-afhqcat.pkl +3 -0
  3. checkpoints/diffusion-insgen-afhqdog.pkl +3 -0
  4. checkpoints/diffusion-insgen-afhqwild.pkl +3 -0
  5. checkpoints/diffusion-projectedgan-cifar10.pkl +3 -0
  6. checkpoints/diffusion-projectedgan-lsun-bedroom.pkl +3 -0
  7. checkpoints/diffusion-projectedgan-lsun-church.pkl +3 -0
  8. checkpoints/diffusion-projectedgan-stl10.pkl +3 -0
  9. checkpoints/diffusion-stylegan2-celeba64.pkl +3 -0
  10. checkpoints/diffusion-stylegan2-cifar10.pkl +3 -0
  11. checkpoints/diffusion-stylegan2-ffhq.pkl +3 -0
  12. checkpoints/diffusion-stylegan2-lsun-bedroom.pkl +3 -0
  13. checkpoints/diffusion-stylegan2-lsun-church.pkl +3 -0
  14. checkpoints/diffusion-stylegan2-stl10.pkl +3 -0
  15. diffusion-insgen/calc_metrics.py +190 -0
  16. diffusion-insgen/dataset_tool.py +444 -0
  17. diffusion-insgen/dnnlib/__init__.py +9 -0
  18. diffusion-insgen/dnnlib/util.py +477 -0
  19. diffusion-insgen/generate.py +129 -0
  20. diffusion-insgen/legacy.py +332 -0
  21. diffusion-insgen/metrics/__init__.py +9 -0
  22. diffusion-insgen/metrics/frechet_inception_distance.py +41 -0
  23. diffusion-insgen/metrics/inception_score.py +38 -0
  24. diffusion-insgen/metrics/kernel_inception_distance.py +46 -0
  25. diffusion-insgen/metrics/metric_main.py +152 -0
  26. diffusion-insgen/metrics/metric_utils.py +275 -0
  27. diffusion-insgen/metrics/perceptual_path_length.py +131 -0
  28. diffusion-insgen/metrics/precision_recall.py +62 -0
  29. diffusion-insgen/projector.py +212 -0
  30. diffusion-insgen/style_mixing.py +118 -0
  31. diffusion-insgen/torch_utils/__init__.py +2 -0
  32. diffusion-insgen/torch_utils/custom_ops.py +119 -0
  33. diffusion-insgen/torch_utils/misc.py +260 -0
  34. diffusion-insgen/torch_utils/ops/__init__.py +2 -0
  35. diffusion-insgen/torch_utils/ops/bias_act.cpp +99 -0
  36. diffusion-insgen/torch_utils/ops/bias_act.cu +173 -0
  37. diffusion-insgen/torch_utils/ops/bias_act.h +38 -0
  38. diffusion-insgen/torch_utils/ops/bias_act.py +205 -0
  39. diffusion-insgen/torch_utils/ops/conv2d_gradfix.py +172 -0
  40. diffusion-insgen/torch_utils/ops/conv2d_resample.py +149 -0
  41. diffusion-insgen/torch_utils/ops/fma.py +53 -0
  42. diffusion-insgen/torch_utils/ops/grid_sample_gradfix.py +77 -0
  43. diffusion-insgen/torch_utils/ops/upfirdn2d.cpp +103 -0
  44. diffusion-insgen/torch_utils/ops/upfirdn2d.cu +350 -0
  45. diffusion-insgen/torch_utils/ops/upfirdn2d.h +59 -0
  46. diffusion-insgen/torch_utils/ops/upfirdn2d.py +377 -0
  47. diffusion-insgen/torch_utils/persistence.py +244 -0
  48. diffusion-insgen/torch_utils/training_stats.py +261 -0
  49. diffusion-insgen/train.py +605 -0
  50. diffusion-insgen/training/__init__.py +9 -0
README.md CHANGED
@@ -1,3 +1,172 @@
1
- ---
2
- license: apache-2.0
3
- ---
1
+ ## Diffusion-GAN — Official PyTorch implementation
2
+
3
+ **Diffusion-GAN: Training GANs with Diffusion**<br>
4
+ Zhendong Wang, Huangjie Zheng, Pengcheng He, Weizhu Chen and Mingyuan Zhou <br>
5
+ https://arxiv.org/abs/2206.02262 <br>
6
+
7
+ Abstract: *For stable training of generative adversarial networks (GANs), injecting instance
8
+ noise into the input of the discriminator is considered as a theoretically sound
9
+ solution, which, however, has not yet delivered on its promise in practice. This
10
+ paper introduces Diffusion-GAN that employs a Gaussian mixture distribution,
11
+ defined over all the diffusion steps of a forward diffusion chain, to inject instance
12
+ noise. A random sample from the mixture, which is diffused from an observed
13
+ or generated data, is fed as the input to the discriminator. The generator is
14
+ updated by backpropagating its gradient through the forward diffusion chain,
15
+ whose length is adaptively adjusted to control the maximum noise-to-data ratio
16
+ allowed at each training step. Theoretical analysis verifies the soundness of the
17
+ proposed Diffusion-GAN, which provides model- and domain-agnostic differentiable
18
+ augmentation. A rich set of experiments on diverse datasets show that Diffusion-GAN can
19
+ provide stable and data-efficient GAN training, bringing consistent
20
+ performance improvement over strong GAN baselines for synthesizing photorealistic images.*
21
+
22
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/diffusion-gan-training-gans-with-diffusion/image-generation-on-celeba-64x64)](https://paperswithcode.com/sota/image-generation-on-celeba-64x64?p=diffusion-gan-training-gans-with-diffusion)
23
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/diffusion-gan-training-gans-with-diffusion/image-generation-on-stl-10)](https://paperswithcode.com/sota/image-generation-on-stl-10?p=diffusion-gan-training-gans-with-diffusion)
24
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/diffusion-gan-training-gans-with-diffusion/image-generation-on-lsun-bedroom-256-x-256)](https://paperswithcode.com/sota/image-generation-on-lsun-bedroom-256-x-256?p=diffusion-gan-training-gans-with-diffusion)
25
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/diffusion-gan-training-gans-with-diffusion/image-generation-on-afhq-wild)](https://paperswithcode.com/sota/image-generation-on-afhq-wild?p=diffusion-gan-training-gans-with-diffusion)
26
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/diffusion-gan-training-gans-with-diffusion/image-generation-on-afhq-cat)](https://paperswithcode.com/sota/image-generation-on-afhq-cat?p=diffusion-gan-training-gans-with-diffusion)
27
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/diffusion-gan-training-gans-with-diffusion/image-generation-on-afhq-dog)](https://paperswithcode.com/sota/image-generation-on-afhq-dog?p=diffusion-gan-training-gans-with-diffusion)
28
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/diffusion-gan-training-gans-with-diffusion/image-generation-on-lsun-churches-256-x-256)](https://paperswithcode.com/sota/image-generation-on-lsun-churches-256-x-256?p=diffusion-gan-training-gans-with-diffusion)
29
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/diffusion-gan-training-gans-with-diffusion/image-generation-on-ffhq-1024-x-1024)](https://paperswithcode.com/sota/image-generation-on-ffhq-1024-x-1024?p=diffusion-gan-training-gans-with-diffusion)
30
+
31
+ ## ToDos
32
+ - [x] Initial code release
33
+ - [x] Providing pretrained models
34
+
35
+ ## Build your Diffusion-GAN
36
+ Here, we explain how to train general GANs with diffusion. We provide two ways:
37
+ a. a simple plug-in, used just like a data-augmentation method;
38
+ b. training GANs on diffusion chains with a timestep-dependent discriminator.
39
+ So far we have not found significant empirical differences between the two approaches,
40
+ while the second approach has stronger theoretical guarantees. We suspect that the second approach could become better
41
+ once a more advanced timestep-dependent structure is used in the discriminator, and we leave that for future study.
42
+
43
+ ### Simple Plug-in
44
+ * Design a proper diffusion process based on the ```diffusion.py``` file
45
+ * Apply diffusion to the inputs of the discriminator:
46
+ ```logits = Discriminator(Diffusion(gen/real_images))```
47
+ * Make the diffusion intensity adaptive inside your training loop (a full training-step sketch follows this list):
48
+ ```
49
+ if update_diffusion:  # every ada_interval minibatches, i.e. batch_idx % ada_interval == 0
50
+     adjust = np.sign(torch.sign(Discriminator(real_images)).mean().item() - ada_target) * C  # C = (batch_size * ada_interval) / (ada_kimg * 1000)
51
+     diffusion.p = (diffusion.p + adjust).clip(min=0., max=1.)
52
+     diffusion.update_T()
53
+ ```
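+
+ Putting the pieces above together, one purely illustrative training-step sketch could look as follows. Here `G`, `D`, `diffusion`, and the optimizers are placeholders for your own modules, and we assume `diffusion(x)` returns the diffused images as in the plug-in call above; this is a sketch under those assumptions, not the exact training loop used in this repo:
+ ```
+ import torch
+ import torch.nn.functional as F
+
+ def gan_training_step(G, D, diffusion, real_images, opt_G, opt_D):
+     # Hypothetical helper: one non-saturating GAN step with diffusion-based instance noise.
+     z = torch.randn(real_images.shape[0], G.z_dim, device=real_images.device)
+
+     # Discriminator update: diffuse both real and generated images before feeding D.
+     with torch.no_grad():
+         fake_images = G(z)
+     loss_D = F.softplus(D(diffusion(fake_images))).mean() + F.softplus(-D(diffusion(real_images))).mean()
+     opt_D.zero_grad(); loss_D.backward(); opt_D.step()
+
+     # Generator update: the gradient flows back through the forward diffusion chain.
+     loss_G = F.softplus(-D(diffusion(G(z)))).mean()
+     opt_G.zero_grad(); loss_G.backward(); opt_G.step()
+
+     # Every ada_interval minibatches, run the adaptiveness update shown above.
+     return loss_D.item(), loss_G.item()
+ ```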
54
+
55
+ ### Full Version
56
+ * Add the diffusion timestep `t` as an input to the discriminator: `logits = Discriminator(images, t)`.
57
+ This may require some modifications to your discriminator architecture (a sketch of one way to condition on `t` follows this list).
58
+ * The other steps are the same as in the Simple Plug-in. Note that since the discriminator now depends on the timestep,
59
+ you also need to collect `t`:
60
+ ```
61
+ diffused_images, t = Diffusion(images)
62
+ logits = Discriminator(diffused_images, t)
63
+ ```
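+
+ For illustration, one simple way to condition the discriminator on `t` is to embed the timestep and add it to an intermediate feature representation. The sketch below is an assumed example for exposition, not the exact architecture used by the discriminators in this repo:
+ ```
+ import torch
+ import torch.nn as nn
+
+ class TimestepDiscriminator(nn.Module):
+     # Hypothetical wrapper: `backbone` maps images to [N, feat_dim] features.
+     def __init__(self, backbone: nn.Module, feat_dim: int, max_t: int = 1000):
+         super().__init__()
+         self.backbone = backbone
+         self.t_embed = nn.Embedding(max_t, feat_dim)  # one learned embedding per diffusion timestep
+         self.head = nn.Linear(feat_dim, 1)
+
+     def forward(self, images: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+         feat = self.backbone(images)    # [N, feat_dim]
+         feat = feat + self.t_embed(t)   # condition the features on the timestep
+         return self.head(feat)          # [N, 1] logits
+ ```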
64
+
65
+ ## Train our Diffusion-GAN
66
+
67
+ ### Requirements
68
+ * 64-bit Python 3.7 and PyTorch 1.7.1/1.8.1. See [https://pytorch.org/](https://pytorch.org/) for PyTorch install instructions.
69
+ * CUDA toolkit 11.0 or later.
70
+ * Python libraries: `pip install click requests tqdm pyspng ninja imageio-ffmpeg==0.4.3`.
71
+
72
+ ### Data Preparation
73
+
74
+ In our paper, we trained our models on [CIFAR-10 (32 x 32)](https://www.cs.toronto.edu/~kriz/cifar.html), [STL-10 (64 x 64)](https://cs.stanford.edu/~acoates/stl10/),
75
+ [LSUN (256 x 256)](https://github.com/fyu/lsun), [AFHQ (512 x 512)](https://github.com/clovaai/stargan-v2) and [FFHQ (1024 x 1024)](https://github.com/NVlabs/ffhq-dataset).
76
+ You can download these datasets from their respective websites.
77
+ To prepare a dataset at the respective resolution, run, for example:
78
+ ```.bash
79
+ python dataset_tool.py --source=~/downloads/lsun/raw/bedroom_lmdb --dest=~/datasets/lsun_bedroom200k.zip \
80
+ --transform=center-crop --width=256 --height=256 --max_images=200000
81
+
82
+ python dataset_tool.py --source=~/downloads/lsun/raw/church_lmdb --dest=~/datasets/lsun_church200k.zip \
83
+ --transform=center-crop-wide --width=256 --height=256 --max_images=200000
84
+ ```
85
+
86
+ ### Training
87
+
88
+ We show the training commands we used below. The commands are similar across datasets, so we use the CIFAR-10 dataset
89
+ as an example:
90
+
91
+ For Diffusion-GAN,
92
+ ```.bash
93
+ python train.py --outdir=training-runs --data="~/cifar10.zip" --gpus=4 --cfg cifar --kimg 50000 --aug no --target 0.6 --noise_sd 0.05 --ts_dist priority
94
+ ```
95
+ For Diffusion-ProjectedGAN,
96
+ ```.bash
97
+ python train.py --outdir=training-runs --data="~/cifar10.zip" --gpus=4 --batch 64 --batch-gpu=16 --cfg fastgan --kimg 50000 --target 0.45 --d_pos first --noise_sd 0.5
98
+ ```
99
+ For Diffusion-InsGen,
100
+ ```.bash
101
+ python train.py --outdir=training-runs --data="~/afhq-wild.zip" --gpus=8 --cfg paper512 --kimg 25000
102
+ ```
103
+
104
+ We follow the `config` settings from [StyleGAN2-ADA](https://github.com/NVlabs/stylegan2-ada-pytorch)
105
+ and refer to it for more details. The other major hyperparameters are listed and discussed below:
106
+ * `--target` the discriminator target, which balances the level of diffusion intensity.
107
+ * `--aug` domain-specific image augmentation, such as ADA and Differentiable Augmentation, which we use to evaluate complementarity with diffusion.
108
+ * `--noise_sd` the diffusion noise standard deviation, which we set to 0.05 in our case.
109
+ * `--ts_dist` the `t` sampling distribution, $\pi(t)$ in the paper.
110
+
111
+ We evaluated two `t` sampling distributions, `['priority', 'uniform']`,
112
+ where `'priority'` denotes Equation (11) in the paper and `'uniform'` denotes uniform random sampling. In most cases, `'priority'` works slightly better, while in some cases, such as FFHQ,
113
+ `'uniform'` is better.
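+
+ For illustration, batched `t` sampling under the two options could look like the sketch below; note that the `'priority'` weights used here (proportional to `t`) are only an assumed stand-in for the exact form given in Equation (11):
+ ```
+ import numpy as np
+
+ def sample_t(T: int, batch_size: int, ts_dist: str = 'priority') -> np.ndarray:
+     # Illustrative only: 'priority' favors larger timesteps, 'uniform' samples all timesteps equally.
+     ts = np.arange(1, T + 1)
+     p = ts / ts.sum() if ts_dist == 'priority' else np.full(T, 1.0 / T)
+     return np.random.choice(ts, size=batch_size, p=p)
+ ```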
114
+
115
+ ## Sampling and Evaluation with our checkpoints
116
+ We report the FIDs of our Diffusion-GAN below and provide the trained checkpoints in the ``./checkpoints`` folder:
117
+
118
+ | Model | Dataset | Resolution | FID |
119
+ |:---------------------------:|:------------:|:----------:|:-----:|
120
+ | Diffusion-StyleGAN2 | CIFAR-10 | 32x32 | 3.19 |
121
+ | Diffusion-StyleGAN2 | CelebA | 64x64 | 1.69 |
122
+ | Diffusion-StyleGAN2 | STL-10 | 64x64 | 11.53 |
123
+ | Diffusion-StyleGAN2 | LSUN-Bedroom | 256x256 | 3.65 |
124
+ | Diffusion-StyleGAN2 | LSUN-Church | 256x256 | 3.17 |
125
+ | Diffusion-StyleGAN2 | FFHQ | 1024x1024 | 2.83 |
126
+ | Diffusion-ProjectedGAN | CIFAR-10 | 32x32 | 2.54 |
127
+ | Diffusion-ProjectedGAN | STL-10 | 64x64 | 6.91 |
128
+ | Diffusion-ProjectedGAN | LSUN-Bedroom | 256x256 | 1.43 |
129
+ | Diffusion-ProjectedGAN | LSUN-Church | 256x256 | 1.85 |
130
+ | Diffusion-InsGen | AFHQ-Cat | 512x512 | 2.40 |
131
+ | Diffusion-InsGen | AFHQ-Dog | 512x512 | 4.83 |
132
+ | Diffusion-InsGen | AFHQ-Wild | 512x512 | 1.51 |
133
+
134
+
135
+ To generate samples, run the following commands:
136
+
137
+ ```.bash
138
+ # Generate FFHQ with pretrained Diffusion-StyleGAN2
139
+ python generate.py --outdir=out --seeds=1-100 \
140
+ --network=https://tsciencescu.blob.core.windows.net/projectshzheng/DiffusionGAN/diffusion-stylegan2-ffhq.pkl
141
+
142
+ # Generate LSUN-Church with pretrained Diffusion-ProjectedGAN
143
+ python gen_images.py --outdir=out --seeds=1-100 \
144
+ --network=https://tsciencescu.blob.core.windows.net/projectshzheng/DiffusionGAN/diffusion-projectedgan-lsun-church.pkl
145
+ ```
146
+
147
+ The checkpoints can be replaced with any pre-trained Diffusion-GAN checkpoint path downloaded from the table above.
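+
+ For programmatic use, a minimal loading-and-sampling sketch is given below. It assumes the StyleGAN2-ADA-style pickle layout (a `G_ema` module exposing `z_dim` and `c_dim`) that `legacy.load_network_pkl` returns elsewhere in this repo:
+ ```
+ import torch
+ import PIL.Image
+ import dnnlib
+ import legacy
+
+ network_pkl = 'checkpoints/diffusion-stylegan2-ffhq.pkl'
+ device = torch.device('cuda')
+ with dnnlib.util.open_url(network_pkl) as f:
+     G = legacy.load_network_pkl(f)['G_ema'].eval().requires_grad_(False).to(device)  # torch.nn.Module
+
+ z = torch.randn([1, G.z_dim], device=device)
+ c = torch.empty([1, G.c_dim], device=device)             # empty label for unconditional models
+ img = G(z, c)                                            # [1, C, H, W], values roughly in [-1, 1]
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+ PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save('sample.png')
+ ```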
148
+
149
+
150
+ Similarly, the metrics can be calculated with the following command:
151
+
152
+ ```.bash
153
+ # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
154
+ python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \
155
+ --network=https://tsciencescu.blob.core.windows.net/projectshzheng/DiffusionGAN/diffusion-stylegan2-ffhq.pkl
156
+ ```
157
+
158
+ ## Citation
159
+
160
+ ```
161
+ @article{wang2022diffusiongan,
162
+ title = {Diffusion-GAN: Training GANs with Diffusion},
163
+ author = {Wang, Zhendong and Zheng, Huangjie and He, Pengcheng and Chen, Weizhu and Zhou, Mingyuan},
164
+ journal = {arXiv preprint arXiv:2206.02262},
165
+ year = {2022},
166
+ url = {https://arxiv.org/abs/2206.02262}
167
+ }
168
+ ```
169
+
170
+ ## Acknowledgements
171
+
172
+ Our code builds upon the awesome [StyleGAN2-ADA repo](https://github.com/NVlabs/stylegan2-ada-pytorch), [InsGen repo](https://github.com/genforce/insgen) and [ProjectedGAN repo](https://github.com/autonomousvision/projected_gan), respectively by Karras et al., Ceyuan Yang et al. and Axel Sauer et al.
checkpoints/diffusion-insgen-afhqcat.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c92c46b87bbaafc8fb914fb781b1c315d70a1bba99b99b54af7541e3669ca2f
3
+ size 365039489
checkpoints/diffusion-insgen-afhqdog.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0b1617c0af01a89795654337ca0b2510598c9f0c507760a9ddca63599f42039
3
+ size 365039489
checkpoints/diffusion-insgen-afhqwild.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7efc94c615be9cf76f3cde438bc8e832e397d421a2bbeb40e71918efd60a8e65
3
+ size 365039490
checkpoints/diffusion-projectedgan-cifar10.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3406ca783404806d7a8ee1b1daf9cf7936f143e94a2fa4a54057ed8c662679e0
3
+ size 1788846251
checkpoints/diffusion-projectedgan-lsun-bedroom.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98aaa6cbfc5cd115fc0afab4bfb8507f0bc0289b4422024f60ab321ae94f5938
3
+ size 1788705080
checkpoints/diffusion-projectedgan-lsun-church.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:573cfb3eefdd78ea1e41dbc2c03effca8f43cb8a794af3b29227894d8a9a0c83
3
+ size 1788704999
checkpoints/diffusion-projectedgan-stl10.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:902cbf4b2282dfa8a2ea5a326a08a6001a93eb766e3f7dbfe1ce55e9109b6d7e
3
+ size 1788846259
checkpoints/diffusion-stylegan2-celeba64.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:921e72e290870affb879bc90fe5334e8cb6d5f90ff486e4d6b540036a5606745
3
+ size 319333518
checkpoints/diffusion-stylegan2-cifar10.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b828b7cb95c13688256497f7789ecb8f9dc556df32aa9be8815f9ab0e0ffe6a
3
+ size 252092418
checkpoints/diffusion-stylegan2-ffhq.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d425b3b85dbd7b79bdde5e8366a8da4dd6bd8bac0e5a1bb7dd543d86ced685a
3
+ size 391116089
checkpoints/diffusion-stylegan2-lsun-bedroom.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:86805c3573922b686718957aac81729fc8e181d52796f6727b48719e87bbd7e0
3
+ size 305245901
checkpoints/diffusion-stylegan2-lsun-church.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90dbb7f40beeeac921764a283473e97f3ce06d984b4e394bbdc898f30e3ddf9c
3
+ size 305242727
checkpoints/diffusion-stylegan2-stl10.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:490017d05b39f3249c5a771bc1035bb2a697b0dba9adc820a081a51cf5fad0e1
3
+ size 319325822
diffusion-insgen/calc_metrics.py ADDED
@@ -0,0 +1,190 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Calculate quality metrics for previous training run or pretrained network pickle."""
10
+
11
+ import os
12
+ import click
13
+ import json
14
+ import tempfile
15
+ import copy
16
+ import torch
17
+ import dnnlib
18
+
19
+ import legacy
20
+ from metrics import metric_main
21
+ from metrics import metric_utils
22
+ from torch_utils import training_stats
23
+ from torch_utils import custom_ops
24
+ from torch_utils import misc
25
+
26
+ #----------------------------------------------------------------------------
27
+
28
+ def subprocess_fn(rank, args, temp_dir):
29
+ dnnlib.util.Logger(should_flush=True)
30
+
31
+ # Init torch.distributed.
32
+ if args.num_gpus > 1:
33
+ init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
34
+ if os.name == 'nt':
35
+ init_method = 'file:///' + init_file.replace('\\', '/')
36
+ torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
37
+ else:
38
+ init_method = f'file://{init_file}'
39
+ torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
40
+
41
+ # Init torch_utils.
42
+ sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
43
+ training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
44
+ if rank != 0 or not args.verbose:
45
+ custom_ops.verbosity = 'none'
46
+
47
+ # Print network summary.
48
+ device = torch.device('cuda', rank)
49
+ torch.backends.cudnn.benchmark = True
50
+ torch.backends.cuda.matmul.allow_tf32 = False
51
+ torch.backends.cudnn.allow_tf32 = False
52
+ G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device)
53
+ if rank == 0 and args.verbose:
54
+ z = torch.empty([1, G.z_dim], device=device)
55
+ c = torch.empty([1, G.c_dim], device=device)
56
+ misc.print_module_summary(G, [z, c])
57
+
58
+ # Calculate each metric.
59
+ for metric in args.metrics:
60
+ if rank == 0 and args.verbose:
61
+ print(f'Calculating {metric}...')
62
+ progress = metric_utils.ProgressMonitor(verbose=args.verbose)
63
+ result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs,
64
+ num_gpus=args.num_gpus, rank=rank, device=device, progress=progress)
65
+ if rank == 0:
66
+ metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl)
67
+ if rank == 0 and args.verbose:
68
+ print()
69
+
70
+ # Done.
71
+ if rank == 0 and args.verbose:
72
+ print('Exiting...')
73
+
74
+ #----------------------------------------------------------------------------
75
+
76
+ class CommaSeparatedList(click.ParamType):
77
+ name = 'list'
78
+
79
+ def convert(self, value, param, ctx):
80
+ _ = param, ctx
81
+ if value is None or value.lower() == 'none' or value == '':
82
+ return []
83
+ return value.split(',')
84
+
85
+ #----------------------------------------------------------------------------
86
+
87
+ @click.command()
88
+ @click.pass_context
89
+ @click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True)
90
+ @click.option('--metrics', help='Comma-separated list or "none"', type=CommaSeparatedList(), default='fid50k_full', show_default=True)
91
+ @click.option('--data', help='Dataset to evaluate metrics against (directory or zip) [default: same as training data]', metavar='PATH')
92
+ @click.option('--mirror', help='Whether the dataset was augmented with x-flips during training [default: look up]', type=bool, metavar='BOOL')
93
+ @click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True)
94
+ @click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True)
95
+
96
+ def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose):
97
+ """Calculate quality metrics for previous training run or pretrained network pickle.
98
+
99
+ Examples:
100
+
101
+ \b
102
+ # Previous training run: look up options automatically, save result to JSONL file.
103
+ python calc_metrics.py --metrics=pr50k3_full \\
104
+ --network=~/training-runs/00000-ffhq10k-res64-auto1/network-snapshot-000000.pkl
105
+
106
+ \b
107
+ # Pre-trained network pickle: specify dataset explicitly, print result to stdout.
108
+ python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \\
109
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
110
+
111
+ Available metrics:
112
+
113
+ \b
114
+ ADA paper:
115
+ fid50k_full Frechet inception distance against the full dataset.
116
+ kid50k_full Kernel inception distance against the full dataset.
117
+ pr50k3_full Precision and recall againt the full dataset.
118
+ is50k Inception score for CIFAR-10.
119
+
120
+ \b
121
+ StyleGAN and StyleGAN2 papers:
122
+ fid50k Frechet inception distance against 50k real images.
123
+ kid50k Kernel inception distance against 50k real images.
124
+ pr50k3 Precision and recall against 50k real images.
125
+ ppl2_wend Perceptual path length in W at path endpoints against full image.
126
+ ppl_zfull Perceptual path length in Z for full paths against cropped image.
127
+ ppl_wfull Perceptual path length in W for full paths against cropped image.
128
+ ppl_zend Perceptual path length in Z at path endpoints against cropped image.
129
+ ppl_wend Perceptual path length in W at path endpoints against cropped image.
130
+ """
131
+ dnnlib.util.Logger(should_flush=True)
132
+
133
+ # Validate arguments.
134
+ args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose)
135
+ if not all(metric_main.is_valid_metric(metric) for metric in args.metrics):
136
+ ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
137
+ if not args.num_gpus >= 1:
138
+ ctx.fail('--gpus must be at least 1')
139
+
140
+ # Load network.
141
+ if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl):
142
+ ctx.fail('--network must point to a file or URL')
143
+ if args.verbose:
144
+ print(f'Loading network from "{network_pkl}"...')
145
+ with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f:
146
+ network_dict = legacy.load_network_pkl(f)
147
+ args.G = network_dict['G_ema'] # subclass of torch.nn.Module
148
+
149
+ # Initialize dataset options.
150
+ if data is not None:
151
+ args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data)
152
+ elif network_dict['training_set_kwargs'] is not None:
153
+ args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs'])
154
+ else:
155
+ ctx.fail('Could not look up dataset options; please specify --data')
156
+
157
+ # Finalize dataset options.
158
+ args.dataset_kwargs.resolution = args.G.img_resolution
159
+ args.dataset_kwargs.use_labels = (args.G.c_dim != 0)
160
+ if mirror is not None:
161
+ args.dataset_kwargs.xflip = mirror
162
+
163
+ # Print dataset options.
164
+ if args.verbose:
165
+ print('Dataset options:')
166
+ print(json.dumps(args.dataset_kwargs, indent=2))
167
+
168
+ # Locate run dir.
169
+ args.run_dir = None
170
+ if os.path.isfile(network_pkl):
171
+ pkl_dir = os.path.dirname(network_pkl)
172
+ if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')):
173
+ args.run_dir = pkl_dir
174
+
175
+ # Launch processes.
176
+ if args.verbose:
177
+ print('Launching processes...')
178
+ torch.multiprocessing.set_start_method('spawn')
179
+ with tempfile.TemporaryDirectory() as temp_dir:
180
+ if args.num_gpus == 1:
181
+ subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
182
+ else:
183
+ torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
184
+
185
+ #----------------------------------------------------------------------------
186
+
187
+ if __name__ == "__main__":
188
+ calc_metrics() # pylint: disable=no-value-for-parameter
189
+
190
+ #----------------------------------------------------------------------------
diffusion-insgen/dataset_tool.py ADDED
@@ -0,0 +1,444 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import functools
10
+ import io
11
+ import json
12
+ import os
13
+ import pickle
14
+ import sys
15
+ import tarfile
16
+ import gzip
17
+ import zipfile
18
+ from pathlib import Path
19
+ from typing import Callable, Optional, Tuple, Union
20
+
21
+ import click
22
+ import numpy as np
23
+ import PIL.Image
24
+ from tqdm import tqdm
25
+
26
+ #----------------------------------------------------------------------------
27
+
28
+ def error(msg):
29
+ print('Error: ' + msg)
30
+ sys.exit(1)
31
+
32
+ #----------------------------------------------------------------------------
33
+
34
+ def maybe_min(a: int, b: Optional[int]) -> int:
35
+ if b is not None:
36
+ return min(a, b)
37
+ return a
38
+
39
+ #----------------------------------------------------------------------------
40
+
41
+ def file_ext(name: Union[str, Path]) -> str:
42
+ return str(name).split('.')[-1]
43
+
44
+ #----------------------------------------------------------------------------
45
+
46
+ def is_image_ext(fname: Union[str, Path]) -> bool:
47
+ ext = file_ext(fname).lower()
48
+ return f'.{ext}' in PIL.Image.EXTENSION # type: ignore
49
+
50
+ #----------------------------------------------------------------------------
51
+
52
+ def open_image_folder(source_dir, *, max_images: Optional[int]):
53
+ input_images = [str(f) for f in sorted(Path(source_dir).rglob('*')) if is_image_ext(f) and os.path.isfile(f)]
54
+
55
+ # Load labels.
56
+ labels = {}
57
+ meta_fname = os.path.join(source_dir, 'dataset.json')
58
+ if os.path.isfile(meta_fname):
59
+ with open(meta_fname, 'r') as file:
60
+ labels = json.load(file)['labels']
61
+ if labels is not None:
62
+ labels = { x[0]: x[1] for x in labels }
63
+ else:
64
+ labels = {}
65
+
66
+ max_idx = maybe_min(len(input_images), max_images)
67
+
68
+ def iterate_images():
69
+ for idx, fname in enumerate(input_images):
70
+ arch_fname = os.path.relpath(fname, source_dir)
71
+ arch_fname = arch_fname.replace('\\', '/')
72
+ img = np.array(PIL.Image.open(fname))
73
+ yield dict(img=img, label=labels.get(arch_fname))
74
+ if idx >= max_idx-1:
75
+ break
76
+ return max_idx, iterate_images()
77
+
78
+ #----------------------------------------------------------------------------
79
+
80
+ def open_image_zip(source, *, max_images: Optional[int]):
81
+ with zipfile.ZipFile(source, mode='r') as z:
82
+ input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
83
+
84
+ # Load labels.
85
+ labels = {}
86
+ if 'dataset.json' in z.namelist():
87
+ with z.open('dataset.json', 'r') as file:
88
+ labels = json.load(file)['labels']
89
+ if labels is not None:
90
+ labels = { x[0]: x[1] for x in labels }
91
+ else:
92
+ labels = {}
93
+
94
+ max_idx = maybe_min(len(input_images), max_images)
95
+
96
+ def iterate_images():
97
+ with zipfile.ZipFile(source, mode='r') as z:
98
+ for idx, fname in enumerate(input_images):
99
+ with z.open(fname, 'r') as file:
100
+ img = PIL.Image.open(file) # type: ignore
101
+ img = np.array(img)
102
+ yield dict(img=img, label=labels.get(fname))
103
+ if idx >= max_idx-1:
104
+ break
105
+ return max_idx, iterate_images()
106
+
107
+ #----------------------------------------------------------------------------
108
+
109
+ def open_lmdb(lmdb_dir: str, *, max_images: Optional[int]):
110
+ import cv2 # pip install opencv-python
111
+ import lmdb # pip install lmdb # pylint: disable=import-error
112
+
113
+ with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
114
+ max_idx = maybe_min(txn.stat()['entries'], max_images)
115
+
116
+ def iterate_images():
117
+ with lmdb.open(lmdb_dir, readonly=True, lock=False).begin(write=False) as txn:
118
+ for idx, (_key, value) in enumerate(txn.cursor()):
119
+ try:
120
+ try:
121
+ img = cv2.imdecode(np.frombuffer(value, dtype=np.uint8), 1)
122
+ if img is None:
123
+ raise IOError('cv2.imdecode failed')
124
+ img = img[:, :, ::-1] # BGR => RGB
125
+ except IOError:
126
+ img = np.array(PIL.Image.open(io.BytesIO(value)))
127
+ yield dict(img=img, label=None)
128
+ if idx >= max_idx-1:
129
+ break
130
+ except:
131
+ print(sys.exc_info()[1])
132
+
133
+ return max_idx, iterate_images()
134
+
135
+ #----------------------------------------------------------------------------
136
+
137
+ def open_cifar10(tarball: str, *, max_images: Optional[int]):
138
+ images = []
139
+ labels = []
140
+
141
+ with tarfile.open(tarball, 'r:gz') as tar:
142
+ for batch in range(1, 6):
143
+ member = tar.getmember(f'cifar-10-batches-py/data_batch_{batch}')
144
+ with tar.extractfile(member) as file:
145
+ data = pickle.load(file, encoding='latin1')
146
+ images.append(data['data'].reshape(-1, 3, 32, 32))
147
+ labels.append(data['labels'])
148
+
149
+ images = np.concatenate(images)
150
+ labels = np.concatenate(labels)
151
+ images = images.transpose([0, 2, 3, 1]) # NCHW -> NHWC
152
+ assert images.shape == (50000, 32, 32, 3) and images.dtype == np.uint8
153
+ assert labels.shape == (50000,) and labels.dtype in [np.int32, np.int64]
154
+ assert np.min(images) == 0 and np.max(images) == 255
155
+ assert np.min(labels) == 0 and np.max(labels) == 9
156
+
157
+ max_idx = maybe_min(len(images), max_images)
158
+
159
+ def iterate_images():
160
+ for idx, img in enumerate(images):
161
+ yield dict(img=img, label=int(labels[idx]))
162
+ if idx >= max_idx-1:
163
+ break
164
+
165
+ return max_idx, iterate_images()
166
+
167
+ #----------------------------------------------------------------------------
168
+
169
+ def open_mnist(images_gz: str, *, max_images: Optional[int]):
170
+ labels_gz = images_gz.replace('-images-idx3-ubyte.gz', '-labels-idx1-ubyte.gz')
171
+ assert labels_gz != images_gz
172
+ images = []
173
+ labels = []
174
+
175
+ with gzip.open(images_gz, 'rb') as f:
176
+ images = np.frombuffer(f.read(), np.uint8, offset=16)
177
+ with gzip.open(labels_gz, 'rb') as f:
178
+ labels = np.frombuffer(f.read(), np.uint8, offset=8)
179
+
180
+ images = images.reshape(-1, 28, 28)
181
+ images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
182
+ assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
183
+ assert labels.shape == (60000,) and labels.dtype == np.uint8
184
+ assert np.min(images) == 0 and np.max(images) == 255
185
+ assert np.min(labels) == 0 and np.max(labels) == 9
186
+
187
+ max_idx = maybe_min(len(images), max_images)
188
+
189
+ def iterate_images():
190
+ for idx, img in enumerate(images):
191
+ yield dict(img=img, label=int(labels[idx]))
192
+ if idx >= max_idx-1:
193
+ break
194
+
195
+ return max_idx, iterate_images()
196
+
197
+ #----------------------------------------------------------------------------
198
+
199
+ def make_transform(
200
+ transform: Optional[str],
201
+ output_width: Optional[int],
202
+ output_height: Optional[int],
203
+ resize_filter: str
204
+ ) -> Callable[[np.ndarray], Optional[np.ndarray]]:
205
+ resample = { 'box': PIL.Image.BOX, 'lanczos': PIL.Image.LANCZOS }[resize_filter]
206
+ def scale(width, height, img):
207
+ w = img.shape[1]
208
+ h = img.shape[0]
209
+ if width == w and height == h:
210
+ return img
211
+ img = PIL.Image.fromarray(img)
212
+ ww = width if width is not None else w
213
+ hh = height if height is not None else h
214
+ img = img.resize((ww, hh), resample)
215
+ return np.array(img)
216
+
217
+ def center_crop(width, height, img):
218
+ crop = np.min(img.shape[:2])
219
+ img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
220
+ img = PIL.Image.fromarray(img, 'RGB')
221
+ img = img.resize((width, height), resample)
222
+ return np.array(img)
223
+
224
+ def center_crop_wide(width, height, img):
225
+ ch = int(np.round(width * img.shape[0] / img.shape[1]))
226
+ if img.shape[1] < width or ch < height:
227
+ return None
228
+
229
+ img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
230
+ img = PIL.Image.fromarray(img, 'RGB')
231
+ img = img.resize((width, height), resample)
232
+ img = np.array(img)
233
+
234
+ canvas = np.zeros([width, width, 3], dtype=np.uint8)
235
+ canvas[(width - height) // 2 : (width + height) // 2, :] = img
236
+ return canvas
237
+
238
+ if transform is None:
239
+ return functools.partial(scale, output_width, output_height)
240
+ if transform == 'center-crop':
241
+ if (output_width is None) or (output_height is None):
242
+ error ('must specify --width and --height when using ' + transform + ' transform')
243
+ return functools.partial(center_crop, output_width, output_height)
244
+ if transform == 'center-crop-wide':
245
+ if (output_width is None) or (output_height is None):
246
+ error ('must specify --width and --height when using ' + transform + ' transform')
247
+ return functools.partial(center_crop_wide, output_width, output_height)
248
+ assert False, 'unknown transform'
249
+
250
+ #----------------------------------------------------------------------------
251
+
252
+ def open_dataset(source, *, max_images: Optional[int]):
253
+ if os.path.isdir(source):
254
+ if source.rstrip('/').endswith('_lmdb'):
255
+ return open_lmdb(source, max_images=max_images)
256
+ else:
257
+ return open_image_folder(source, max_images=max_images)
258
+ elif os.path.isfile(source):
259
+ if os.path.basename(source) == 'cifar-10-python.tar.gz':
260
+ return open_cifar10(source, max_images=max_images)
261
+ elif os.path.basename(source) == 'train-images-idx3-ubyte.gz':
262
+ return open_mnist(source, max_images=max_images)
263
+ elif file_ext(source) == 'zip':
264
+ return open_image_zip(source, max_images=max_images)
265
+ else:
266
+ assert False, 'unknown archive type'
267
+ else:
268
+ error(f'Missing input file or directory: {source}')
269
+
270
+ #----------------------------------------------------------------------------
271
+
272
+ def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
273
+ dest_ext = file_ext(dest)
274
+
275
+ if dest_ext == 'zip':
276
+ if os.path.dirname(dest) != '':
277
+ os.makedirs(os.path.dirname(dest), exist_ok=True)
278
+ zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
279
+ def zip_write_bytes(fname: str, data: Union[bytes, str]):
280
+ zf.writestr(fname, data)
281
+ return '', zip_write_bytes, zf.close
282
+ else:
283
+ # If the output folder already exists, check that it is
284
+ # empty.
285
+ #
286
+ # Note: creating the output directory is not strictly
287
+ # necessary as folder_write_bytes() also mkdirs, but it's better
288
+ # to give an error message earlier in case the dest folder
289
+ # somehow cannot be created.
290
+ if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
291
+ error('--dest folder must be empty')
292
+ os.makedirs(dest, exist_ok=True)
293
+
294
+ def folder_write_bytes(fname: str, data: Union[bytes, str]):
295
+ os.makedirs(os.path.dirname(fname), exist_ok=True)
296
+ with open(fname, 'wb') as fout:
297
+ if isinstance(data, str):
298
+ data = data.encode('utf8')
299
+ fout.write(data)
300
+ return dest, folder_write_bytes, lambda: None
301
+
302
+ #----------------------------------------------------------------------------
303
+
304
+ @click.command()
305
+ @click.pass_context
306
+ @click.option('--source', help='Directory or archive name for input dataset', required=True, metavar='PATH')
307
+ @click.option('--dest', help='Output directory or archive name for output dataset', required=True, metavar='PATH')
308
+ @click.option('--max-images', help='Output only up to `max-images` images', type=int, default=None)
309
+ @click.option('--resize-filter', help='Filter to use when resizing images for output resolution', type=click.Choice(['box', 'lanczos']), default='lanczos', show_default=True)
310
+ @click.option('--transform', help='Input crop/resize mode', type=click.Choice(['center-crop', 'center-crop-wide']))
311
+ @click.option('--width', help='Output width', type=int)
312
+ @click.option('--height', help='Output height', type=int)
313
+ def convert_dataset(
314
+ ctx: click.Context,
315
+ source: str,
316
+ dest: str,
317
+ max_images: Optional[int],
318
+ transform: Optional[str],
319
+ resize_filter: str,
320
+ width: Optional[int],
321
+ height: Optional[int]
322
+ ):
323
+ """Convert an image dataset into a dataset archive usable with StyleGAN2 ADA PyTorch.
324
+
325
+ The input dataset format is guessed from the --source argument:
326
+
327
+ \b
328
+ --source *_lmdb/ Load LSUN dataset
329
+ --source cifar-10-python.tar.gz Load CIFAR-10 dataset
330
+ --source train-images-idx3-ubyte.gz Load MNIST dataset
331
+ --source path/ Recursively load all images from path/
332
+ --source dataset.zip Recursively load all images from dataset.zip
333
+
334
+ Specifying the output format and path:
335
+
336
+ \b
337
+ --dest /path/to/dir Save output files under /path/to/dir
338
+ --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
339
+
340
+ The output dataset format can be either an image folder or an uncompressed zip archive.
341
+ Zip archives makes it easier to move datasets around file servers and clusters, and may
342
+ offer better training performance on network file systems.
343
+
344
+ Images within the dataset archive will be stored as uncompressed PNG.
345
+ Uncompressed PNGs can be efficiently decoded in the training loop.
346
+
347
+ Class labels are stored in a file called 'dataset.json' that is stored at the
348
+ dataset root folder. This file has the following structure:
349
+
350
+ \b
351
+ {
352
+ "labels": [
353
+ ["00000/img00000000.png",6],
354
+ ["00000/img00000001.png",9],
355
+ ... repeated for every image in the dataset
356
+ ["00049/img00049999.png",1]
357
+ ]
358
+ }
359
+
360
+ If the 'dataset.json' file cannot be found, the dataset is interpreted as
361
+ not containing class labels.
362
+
363
+ Image scale/crop and resolution requirements:
364
+
365
+ Output images must be square-shaped and they must all have the same power-of-two
366
+ dimensions.
367
+
368
+ To scale arbitrary input image size to a specific width and height, use the
369
+ --width and --height options. Output resolution will be either the original
370
+ input resolution (if --width/--height was not specified) or the one specified with
371
+ --width/height.
372
+
373
+ Use the --transform=center-crop or --transform=center-crop-wide options to apply a
374
+ center crop transform on the input image. These options should be used with the
375
+ --width and --height options. For example:
376
+
377
+ \b
378
+ python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \\
379
+ --transform=center-crop-wide --width 512 --height=384
380
+ """
381
+
382
+ PIL.Image.init() # type: ignore
383
+
384
+ if dest == '':
385
+ ctx.fail('--dest output filename or directory must not be an empty string')
386
+
387
+ num_files, input_iter = open_dataset(source, max_images=max_images)
388
+ archive_root_dir, save_bytes, close_dest = open_dest(dest)
389
+
390
+ transform_image = make_transform(transform, width, height, resize_filter)
391
+
392
+ dataset_attrs = None
393
+
394
+ labels = []
395
+ for idx, image in tqdm(enumerate(input_iter), total=num_files):
396
+ idx_str = f'{idx:08d}'
397
+ archive_fname = f'{idx_str[:5]}/img{idx_str}.png'
398
+
399
+ # Apply crop and resize.
400
+ img = transform_image(image['img'])
401
+
402
+ # Transform may drop images.
403
+ if img is None:
404
+ continue
405
+
406
+ # Error check to require uniform image attributes across
407
+ # the whole dataset.
408
+ channels = img.shape[2] if img.ndim == 3 else 1
409
+ cur_image_attrs = {
410
+ 'width': img.shape[1],
411
+ 'height': img.shape[0],
412
+ 'channels': channels
413
+ }
414
+ if dataset_attrs is None:
415
+ dataset_attrs = cur_image_attrs
416
+ width = dataset_attrs['width']
417
+ height = dataset_attrs['height']
418
+ if width != height:
419
+ error(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
420
+ if dataset_attrs['channels'] not in [1, 3]:
421
+ error('Input images must be stored as RGB or grayscale')
422
+ if width != 2 ** int(np.floor(np.log2(width))):
423
+ error('Image width/height after scale and crop are required to be power-of-two')
424
+ elif dataset_attrs != cur_image_attrs:
425
+ err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()]
426
+ error(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))
427
+
428
+ # Save the image as an uncompressed PNG.
429
+ img = PIL.Image.fromarray(img, { 1: 'L', 3: 'RGB' }[channels])
430
+ image_bits = io.BytesIO()
431
+ img.save(image_bits, format='png', compress_level=0, optimize=False)
432
+ save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
433
+ labels.append([archive_fname, image['label']] if image['label'] is not None else None)
434
+
435
+ metadata = {
436
+ 'labels': labels if all(x is not None for x in labels) else None
437
+ }
438
+ save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
439
+ close_dest()
440
+
441
+ #----------------------------------------------------------------------------
442
+
443
+ if __name__ == "__main__":
444
+ convert_dataset() # pylint: disable=no-value-for-parameter
diffusion-insgen/dnnlib/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from .util import EasyDict, make_cache_dir_path
diffusion-insgen/dnnlib/util.py ADDED
@@ -0,0 +1,477 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Miscellaneous utility classes and functions."""
10
+
11
+ import ctypes
12
+ import fnmatch
13
+ import importlib
14
+ import inspect
15
+ import numpy as np
16
+ import os
17
+ import shutil
18
+ import sys
19
+ import types
20
+ import io
21
+ import pickle
22
+ import re
23
+ import requests
24
+ import html
25
+ import hashlib
26
+ import glob
27
+ import tempfile
28
+ import urllib
29
+ import urllib.request
30
+ import uuid
31
+
32
+ from distutils.util import strtobool
33
+ from typing import Any, List, Tuple, Union
34
+
35
+
36
+ # Util classes
37
+ # ------------------------------------------------------------------------------------------
38
+
39
+
40
+ class EasyDict(dict):
41
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
42
+
43
+ def __getattr__(self, name: str) -> Any:
44
+ try:
45
+ return self[name]
46
+ except KeyError:
47
+ raise AttributeError(name)
48
+
49
+ def __setattr__(self, name: str, value: Any) -> None:
50
+ self[name] = value
51
+
52
+ def __delattr__(self, name: str) -> None:
53
+ del self[name]
54
+
55
+
56
+ class Logger(object):
57
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
58
+
59
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
60
+ self.file = None
61
+
62
+ if file_name is not None:
63
+ self.file = open(file_name, file_mode)
64
+
65
+ self.should_flush = should_flush
66
+ self.stdout = sys.stdout
67
+ self.stderr = sys.stderr
68
+
69
+ sys.stdout = self
70
+ sys.stderr = self
71
+
72
+ def __enter__(self) -> "Logger":
73
+ return self
74
+
75
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
76
+ self.close()
77
+
78
+ def write(self, text: Union[str, bytes]) -> None:
79
+ """Write text to stdout (and a file) and optionally flush."""
80
+ if isinstance(text, bytes):
81
+ text = text.decode()
82
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
83
+ return
84
+
85
+ if self.file is not None:
86
+ self.file.write(text)
87
+
88
+ self.stdout.write(text)
89
+
90
+ if self.should_flush:
91
+ self.flush()
92
+
93
+ def flush(self) -> None:
94
+ """Flush written text to both stdout and a file, if open."""
95
+ if self.file is not None:
96
+ self.file.flush()
97
+
98
+ self.stdout.flush()
99
+
100
+ def close(self) -> None:
101
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
102
+ self.flush()
103
+
104
+ # if using multiple loggers, prevent closing in wrong order
105
+ if sys.stdout is self:
106
+ sys.stdout = self.stdout
107
+ if sys.stderr is self:
108
+ sys.stderr = self.stderr
109
+
110
+ if self.file is not None:
111
+ self.file.close()
112
+ self.file = None
113
+
114
+
115
+ # Cache directories
116
+ # ------------------------------------------------------------------------------------------
117
+
118
+ _dnnlib_cache_dir = None
119
+
120
+ def set_cache_dir(path: str) -> None:
121
+ global _dnnlib_cache_dir
122
+ _dnnlib_cache_dir = path
123
+
124
+ def make_cache_dir_path(*paths: str) -> str:
125
+ if _dnnlib_cache_dir is not None:
126
+ return os.path.join(_dnnlib_cache_dir, *paths)
127
+ if 'DNNLIB_CACHE_DIR' in os.environ:
128
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
129
+ if 'HOME' in os.environ:
130
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
131
+ if 'USERPROFILE' in os.environ:
132
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
133
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
134
+
135
+ # Small util functions
136
+ # ------------------------------------------------------------------------------------------
137
+
138
+
139
+ def format_time(seconds: Union[int, float]) -> str:
140
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
141
+ s = int(np.rint(seconds))
142
+
143
+ if s < 60:
144
+ return "{0}s".format(s)
145
+ elif s < 60 * 60:
146
+ return "{0}m {1:02}s".format(s // 60, s % 60)
147
+ elif s < 24 * 60 * 60:
148
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
149
+ else:
150
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
151
+
152
+
153
+ def ask_yes_no(question: str) -> bool:
154
+ """Ask the user the question until the user inputs a valid answer."""
155
+ while True:
156
+ try:
157
+ print("{0} [y/n]".format(question))
158
+ return strtobool(input().lower())
159
+ except ValueError:
160
+ pass
161
+
162
+
163
+ def tuple_product(t: Tuple) -> Any:
164
+ """Calculate the product of the tuple elements."""
165
+ result = 1
166
+
167
+ for v in t:
168
+ result *= v
169
+
170
+ return result
171
+
172
+
173
+ _str_to_ctype = {
174
+ "uint8": ctypes.c_ubyte,
175
+ "uint16": ctypes.c_uint16,
176
+ "uint32": ctypes.c_uint32,
177
+ "uint64": ctypes.c_uint64,
178
+ "int8": ctypes.c_byte,
179
+ "int16": ctypes.c_int16,
180
+ "int32": ctypes.c_int32,
181
+ "int64": ctypes.c_int64,
182
+ "float32": ctypes.c_float,
183
+ "float64": ctypes.c_double
184
+ }
185
+
186
+
187
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
188
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
189
+ type_str = None
190
+
191
+ if isinstance(type_obj, str):
192
+ type_str = type_obj
193
+ elif hasattr(type_obj, "__name__"):
194
+ type_str = type_obj.__name__
195
+ elif hasattr(type_obj, "name"):
196
+ type_str = type_obj.name
197
+ else:
198
+ raise RuntimeError("Cannot infer type name from input")
199
+
200
+ assert type_str in _str_to_ctype.keys()
201
+
202
+ my_dtype = np.dtype(type_str)
203
+ my_ctype = _str_to_ctype[type_str]
204
+
205
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
206
+
207
+ return my_dtype, my_ctype
208
+
209
+
210
+ def is_pickleable(obj: Any) -> bool:
211
+ try:
212
+ with io.BytesIO() as stream:
213
+ pickle.dump(obj, stream)
214
+ return True
215
+ except:
216
+ return False
217
+
218
+
219
+ # Functionality to import modules/objects by name, and call functions by name
220
+ # ------------------------------------------------------------------------------------------
221
+
222
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
223
+ """Searches for the underlying module behind the name to some python object.
224
+ Returns the module and the object name (original name with module part removed)."""
225
+
226
+ # allow convenience shorthands, substitute them by full names
227
+ obj_name = re.sub("^np.", "numpy.", obj_name)
228
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
229
+
230
+ # list alternatives for (module_name, local_obj_name)
231
+ parts = obj_name.split(".")
232
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
233
+
234
+ # try each alternative in turn
235
+ for module_name, local_obj_name in name_pairs:
236
+ try:
237
+ module = importlib.import_module(module_name) # may raise ImportError
238
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
239
+ return module, local_obj_name
240
+ except:
241
+ pass
242
+
243
+ # maybe some of the modules themselves contain errors?
244
+ for module_name, _local_obj_name in name_pairs:
245
+ try:
246
+ importlib.import_module(module_name) # may raise ImportError
247
+ except ImportError:
248
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
249
+ raise
250
+
251
+ # maybe the requested attribute is missing?
252
+ for module_name, local_obj_name in name_pairs:
253
+ try:
254
+ module = importlib.import_module(module_name) # may raise ImportError
255
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
256
+ except ImportError:
257
+ pass
258
+
259
+ # we are out of luck, but we have no idea why
260
+ raise ImportError(obj_name)
261
+
262
+
263
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
264
+ """Traverses the object name and returns the last (rightmost) python object."""
265
+ if obj_name == '':
266
+ return module
267
+ obj = module
268
+ for part in obj_name.split("."):
269
+ obj = getattr(obj, part)
270
+ return obj
271
+
272
+
273
+ def get_obj_by_name(name: str) -> Any:
274
+ """Finds the python object with the given name."""
275
+ module, obj_name = get_module_from_obj_name(name)
276
+ return get_obj_from_module(module, obj_name)
277
+
278
+
279
+ def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
280
+ """Finds the python object with the given name and calls it as a function."""
281
+ assert func_name is not None
282
+ func_obj = get_obj_by_name(func_name)
283
+ assert callable(func_obj)
284
+ return func_obj(*args, **kwargs)
285
+
286
+
287
+ def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
288
+ """Finds the python class with the given name and constructs it with the given arguments."""
289
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
290
+
291
+
292
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
293
+ """Get the directory path of the module containing the given object name."""
294
+ module, _ = get_module_from_obj_name(obj_name)
295
+ return os.path.dirname(inspect.getfile(module))
296
+
297
+
298
+ def is_top_level_function(obj: Any) -> bool:
299
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
300
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
301
+
302
+
303
+ def get_top_level_function_name(obj: Any) -> str:
304
+ """Return the fully-qualified name of a top-level function."""
305
+ assert is_top_level_function(obj)
306
+ module = obj.__module__
307
+ if module == '__main__':
308
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
309
+ return module + "." + obj.__name__
310
+
311
+
312
+ # File system helpers
313
+ # ------------------------------------------------------------------------------------------
314
+
315
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
316
+ """List all files recursively in a given directory while ignoring given file and directory names.
317
+ Returns list of tuples containing both absolute and relative paths."""
318
+ assert os.path.isdir(dir_path)
319
+ base_name = os.path.basename(os.path.normpath(dir_path))
320
+
321
+ if ignores is None:
322
+ ignores = []
323
+
324
+ result = []
325
+
326
+ for root, dirs, files in os.walk(dir_path, topdown=True):
327
+ for ignore_ in ignores:
328
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
329
+
330
+ # dirs need to be edited in-place
331
+ for d in dirs_to_remove:
332
+ dirs.remove(d)
333
+
334
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
335
+
336
+ absolute_paths = [os.path.join(root, f) for f in files]
337
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
338
+
339
+ if add_base_to_relative:
340
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
341
+
342
+ assert len(absolute_paths) == len(relative_paths)
343
+ result += zip(absolute_paths, relative_paths)
344
+
345
+ return result
346
+
347
+
348
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
349
+ """Takes in a list of tuples of (src, dst) paths and copies files.
350
+ Will create all necessary directories."""
351
+ for file in files:
352
+ target_dir_name = os.path.dirname(file[1])
353
+
354
+ # will create all intermediate-level directories
355
+ if not os.path.exists(target_dir_name):
356
+ os.makedirs(target_dir_name)
357
+
358
+ shutil.copyfile(file[0], file[1])
359
+
360
+
361
+ # URL helpers
362
+ # ------------------------------------------------------------------------------------------
363
+
364
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
365
+ """Determine whether the given object is a valid URL string."""
366
+ if not isinstance(obj, str) or not "://" in obj:
367
+ return False
368
+ if allow_file_urls and obj.startswith('file://'):
369
+ return True
370
+ try:
371
+ res = requests.compat.urlparse(obj)
372
+ if not res.scheme or not res.netloc or not "." in res.netloc:
373
+ return False
374
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
375
+ if not res.scheme or not res.netloc or not "." in res.netloc:
376
+ return False
377
+ except:
378
+ return False
379
+ return True
380
+
381
+
382
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
383
+ """Download the given URL and return a binary-mode file object to access the data."""
384
+ assert num_attempts >= 1
385
+ assert not (return_filename and (not cache))
386
+
387
+ # Doesn't look like an URL scheme so interpret it as a local filename.
388
+ if not re.match('^[a-z]+://', url):
389
+ return url if return_filename else open(url, "rb")
390
+
391
+ # Handle file URLs. This code handles unusual file:// patterns that
392
+ # arise on Windows:
393
+ #
394
+ # file:///c:/foo.txt
395
+ #
396
+ # which would translate to a local '/c:/foo.txt' filename that's
397
+ # invalid. Drop the forward slash for such pathnames.
398
+ #
399
+ # If you touch this code path, you should test it on both Linux and
400
+ # Windows.
401
+ #
402
+ # Some internet resources suggest using urllib.request.url2pathname() but
403
+ # but that converts forward slashes to backslashes and this causes
404
+ # its own set of problems.
405
+ if url.startswith('file://'):
406
+ filename = urllib.parse.urlparse(url).path
407
+ if re.match(r'^/[a-zA-Z]:', filename):
408
+ filename = filename[1:]
409
+ return filename if return_filename else open(filename, "rb")
410
+
411
+ assert is_url(url)
412
+
413
+ # Lookup from cache.
414
+ if cache_dir is None:
415
+ cache_dir = make_cache_dir_path('downloads')
416
+
417
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
418
+ if cache:
419
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
420
+ if len(cache_files) == 1:
421
+ filename = cache_files[0]
422
+ return filename if return_filename else open(filename, "rb")
423
+
424
+ # Download.
425
+ url_name = None
426
+ url_data = None
427
+ with requests.Session() as session:
428
+ if verbose:
429
+ print("Downloading %s ..." % url, end="", flush=True)
430
+ for attempts_left in reversed(range(num_attempts)):
431
+ try:
432
+ with session.get(url) as res:
433
+ res.raise_for_status()
434
+ if len(res.content) == 0:
435
+ raise IOError("No data received")
436
+
437
+ if len(res.content) < 8192:
438
+ content_str = res.content.decode("utf-8")
439
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
440
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
441
+ if len(links) == 1:
442
+ url = requests.compat.urljoin(url, links[0])
443
+ raise IOError("Google Drive virus checker nag")
444
+ if "Google Drive - Quota exceeded" in content_str:
445
+ raise IOError("Google Drive download quota exceeded -- please try again later")
446
+
447
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
448
+ url_name = match[1] if match else url
449
+ url_data = res.content
450
+ if verbose:
451
+ print(" done")
452
+ break
453
+ except KeyboardInterrupt:
454
+ raise
455
+ except:
456
+ if not attempts_left:
457
+ if verbose:
458
+ print(" failed")
459
+ raise
460
+ if verbose:
461
+ print(".", end="", flush=True)
462
+
463
+ # Save to cache.
464
+ if cache:
465
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
466
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
467
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
468
+ os.makedirs(cache_dir, exist_ok=True)
469
+ with open(temp_file, "wb") as f:
470
+ f.write(url_data)
471
+ os.replace(temp_file, cache_file) # atomic
472
+ if return_filename:
473
+ return cache_file
474
+
475
+ # Return data as file object.
476
+ assert not return_filename
477
+ return io.BytesIO(url_data)
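The helpers above (`get_obj_by_name`, `construct_class_by_name`, `call_func_by_name`, `open_url`) are what let the training and metrics code refer to classes and checkpoints by dotted names and URLs. A minimal sketch of how they behave (not part of the commit; the path in the last call is a placeholder):

```python
# Illustrative sketch only -- assumes the dnnlib package from this repo is importable.
import dnnlib

# Dotted-name resolution; the 'np.' shorthand is expanded to 'numpy.' by get_module_from_obj_name().
rs_cls = dnnlib.util.get_obj_by_name('np.random.RandomState')                    # -> numpy.random.RandomState
rng = dnnlib.util.construct_class_by_name(class_name='numpy.random.RandomState', seed=0)
zeros = dnnlib.util.call_func_by_name(func_name='numpy.zeros', shape=(2, 3))

# open_url() handles local paths, file:// URLs, and http(s) downloads with on-disk caching.
with dnnlib.util.open_url('<path-or-url-to-file>') as f:                         # placeholder
    data = f.read()
```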
diffusion-insgen/generate.py ADDED
@@ -0,0 +1,129 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Generate images using pretrained network pickle."""
10
+
11
+ import os
12
+ import re
13
+ from typing import List, Optional
14
+
15
+ import click
16
+ import dnnlib
17
+ import numpy as np
18
+ import PIL.Image
19
+ import torch
20
+
21
+ import legacy
22
+
23
+ #----------------------------------------------------------------------------
24
+
25
+ def num_range(s: str) -> List[int]:
26
+ '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
27
+
28
+ range_re = re.compile(r'^(\d+)-(\d+)$')
29
+ m = range_re.match(s)
30
+ if m:
31
+ return list(range(int(m.group(1)), int(m.group(2))+1))
32
+ vals = s.split(',')
33
+ return [int(x) for x in vals]
34
+
35
+ #----------------------------------------------------------------------------
36
+
37
+ @click.command()
38
+ @click.pass_context
39
+ @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
40
+ @click.option('--seeds', type=num_range, help='List of random seeds')
41
+ @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
42
+ @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
43
+ @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
44
+ @click.option('--projected-w', help='Projection result file', type=str, metavar='FILE')
45
+ @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
46
+ def generate_images(
47
+ ctx: click.Context,
48
+ network_pkl: str,
49
+ seeds: Optional[List[int]],
50
+ truncation_psi: float,
51
+ noise_mode: str,
52
+ outdir: str,
53
+ class_idx: Optional[int],
54
+ projected_w: Optional[str]
55
+ ):
56
+ """Generate images using pretrained network pickle.
57
+
58
+ Examples:
59
+
60
+ \b
61
+ # Generate curated MetFaces images without truncation (Fig.10 left)
62
+ python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\
63
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
64
+
65
+ \b
66
+ # Generate uncurated MetFaces images with truncation (Fig.12 upper left)
67
+ python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\
68
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
69
+
70
+ \b
71
+ # Generate class conditional CIFAR-10 images (Fig.17 left, Car)
72
+ python generate.py --outdir=out --seeds=0-35 --class=1 \\
73
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl
74
+
75
+ \b
76
+ # Render an image from projected W
77
+ python generate.py --outdir=out --projected_w=projected_w.npz \\
78
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
79
+ """
80
+
81
+ print('Loading networks from "%s"...' % network_pkl)
82
+ device = torch.device('cuda')
83
+ with dnnlib.util.open_url(network_pkl) as f:
84
+ G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
85
+
86
+ os.makedirs(outdir, exist_ok=True)
87
+
88
+ # Synthesize the result of a W projection.
89
+ if projected_w is not None:
90
+ if seeds is not None:
91
+ print ('warn: --seeds is ignored when using --projected-w')
92
+ print(f'Generating images from projected W "{projected_w}"')
93
+ ws = np.load(projected_w)['w']
94
+ ws = torch.tensor(ws, device=device) # pylint: disable=not-callable
95
+ assert ws.shape[1:] == (G.num_ws, G.w_dim)
96
+ for idx, w in enumerate(ws):
97
+ img = G.synthesis(w.unsqueeze(0), noise_mode=noise_mode)
98
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
99
+ PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/proj{idx:02d}.png')
100
+ return
101
+
102
+ if seeds is None:
103
+ ctx.fail('--seeds option is required when not using --projected-w')
104
+
105
+ # Labels.
106
+ label = torch.zeros([1, G.c_dim], device=device)
107
+ if G.c_dim != 0:
108
+ if class_idx is None:
109
+ ctx.fail('Must specify class label with --class when using a conditional network')
110
+ label[:, class_idx] = 1
111
+ else:
112
+ if class_idx is not None:
113
+ print ('warn: --class=lbl ignored when running on an unconditional network')
114
+
115
+ # Generate images.
116
+ for seed_idx, seed in enumerate(seeds):
117
+ print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
118
+ z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
119
+ img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
120
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
121
+ PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')
122
+
123
+
124
+ #----------------------------------------------------------------------------
125
+
126
+ if __name__ == "__main__":
127
+ generate_images() # pylint: disable=no-value-for-parameter
128
+
129
+ #----------------------------------------------------------------------------
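For readers who want to call the generator directly instead of going through the CLI above, here is a minimal programmatic sketch of the same loop; the checkpoint path is a placeholder and a CUDA device is assumed:

```python
# Sketch mirroring generate.py: load G_ema from a pickle and render one seed.
import numpy as np
import PIL.Image
import torch

import dnnlib
import legacy

device = torch.device('cuda')
with dnnlib.util.open_url('<path-or-url-to-network-pkl>') as f:               # placeholder checkpoint
    G = legacy.load_network_pkl(f)['G_ema'].to(device)                        # EMA generator

z = torch.from_numpy(np.random.RandomState(0).randn(1, G.z_dim)).to(device)   # latent for seed 0
c = torch.zeros([1, G.c_dim], device=device)                                  # unconditional label
img = G(z, c, truncation_psi=1.0, noise_mode='const')
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save('seed0000.png')
```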
diffusion-insgen/legacy.py ADDED
@@ -0,0 +1,332 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import click
10
+ import pickle
11
+ import re
12
+ import copy
13
+ import numpy as np
14
+ import torch
15
+ import dnnlib
16
+ from torch_utils import misc
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ def load_network_pkl(f, force_fp16=False):
21
+ data = _LegacyUnpickler(f).load()
22
+
23
+ # Legacy TensorFlow pickle => convert.
24
+ if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
25
+ tf_G, tf_D, tf_Gs = data
26
+ G = convert_tf_generator(tf_G)
27
+ D = convert_tf_discriminator(tf_D)
28
+ G_ema = convert_tf_generator(tf_Gs)
29
+ data = dict(G=G, D=D, G_ema=G_ema)
30
+
31
+ # extract nn.module from ddp
32
+ for k, v in data.items():
33
+ if isinstance(v, _DDPNetworkStub):
34
+ data[k] = v._modules['module']
35
+
36
+ # Add missing fields.
37
+ if 'training_set_kwargs' not in data:
38
+ data['training_set_kwargs'] = None
39
+ if 'augment_pipe' not in data:
40
+ data['augment_pipe'] = None
41
+
42
+ # Validate contents.
43
+ assert isinstance(data['G'], torch.nn.Module)
44
+ assert isinstance(data['D'], torch.nn.Module)
45
+ assert isinstance(data['G_ema'], torch.nn.Module)
46
+ assert isinstance(data['training_set_kwargs'], (dict, type(None)))
47
+ assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
48
+
49
+ # Force FP16.
50
+ if force_fp16:
51
+ for key in ['G', 'D', 'G_ema']:
52
+ old = data[key]
53
+ kwargs = copy.deepcopy(old.init_kwargs)
54
+ if key.startswith('G'):
55
+ kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {}))
56
+ kwargs.synthesis_kwargs.num_fp16_res = 4
57
+ kwargs.synthesis_kwargs.conv_clamp = 256
58
+ if key.startswith('D'):
59
+ kwargs.num_fp16_res = 4
60
+ kwargs.conv_clamp = 256
61
+ if kwargs != old.init_kwargs:
62
+ new = type(old)(**kwargs).eval().requires_grad_(False)
63
+ misc.copy_params_and_buffers(old, new, require_all=True)
64
+ data[key] = new
65
+ return data
66
+
67
+ #----------------------------------------------------------------------------
68
+
69
+ class _DDPNetworkStub(dnnlib.EasyDict):
70
+ pass
71
+
72
+ class _TFNetworkStub(dnnlib.EasyDict):
73
+ pass
74
+
75
+ class _LegacyUnpickler(pickle.Unpickler):
76
+ def find_class(self, module, name):
77
+ if module == 'torch.nn.parallel.distributed' and name == 'DistributedDataParallel':
78
+ return _DDPNetworkStub
79
+ if module == 'dnnlib.tflib.network' and name == 'Network':
80
+ return _TFNetworkStub
81
+ if module == 'training.augment':
82
+ return _TFNetworkStub
83
+ return super().find_class(module, name)
84
+
85
+ #----------------------------------------------------------------------------
86
+
87
+ def _collect_tf_params(tf_net):
88
+ # pylint: disable=protected-access
89
+ tf_params = dict()
90
+ def recurse(prefix, tf_net):
91
+ for name, value in tf_net.variables:
92
+ tf_params[prefix + name] = value
93
+ for name, comp in tf_net.components.items():
94
+ recurse(prefix + name + '/', comp)
95
+ recurse('', tf_net)
96
+ return tf_params
97
+
98
+ #----------------------------------------------------------------------------
99
+
100
+ def _populate_module_params(module, *patterns):
101
+ for name, tensor in misc.named_params_and_buffers(module):
102
+ found = False
103
+ value = None
104
+ for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
105
+ match = re.fullmatch(pattern, name)
106
+ if match:
107
+ found = True
108
+ if value_fn is not None:
109
+ value = value_fn(*match.groups())
110
+ break
111
+ try:
112
+ assert found
113
+ if value is not None:
114
+ tensor.copy_(torch.from_numpy(np.array(value)))
115
+ except:
116
+ print(name, list(tensor.shape))
117
+ raise
118
+
119
+ #----------------------------------------------------------------------------
120
+
121
+ def convert_tf_generator(tf_G):
122
+ if tf_G.version < 4:
123
+ raise ValueError('TensorFlow pickle version too low')
124
+
125
+ # Collect kwargs.
126
+ tf_kwargs = tf_G.static_kwargs
127
+ known_kwargs = set()
128
+ def kwarg(tf_name, default=None, none=None):
129
+ known_kwargs.add(tf_name)
130
+ val = tf_kwargs.get(tf_name, default)
131
+ return val if val is not None else none
132
+
133
+ # Convert kwargs.
134
+ kwargs = dnnlib.EasyDict(
135
+ z_dim = kwarg('latent_size', 512),
136
+ c_dim = kwarg('label_size', 0),
137
+ w_dim = kwarg('dlatent_size', 512),
138
+ img_resolution = kwarg('resolution', 1024),
139
+ img_channels = kwarg('num_channels', 3),
140
+ mapping_kwargs = dnnlib.EasyDict(
141
+ num_layers = kwarg('mapping_layers', 8),
142
+ embed_features = kwarg('label_fmaps', None),
143
+ layer_features = kwarg('mapping_fmaps', None),
144
+ activation = kwarg('mapping_nonlinearity', 'lrelu'),
145
+ lr_multiplier = kwarg('mapping_lrmul', 0.01),
146
+ w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
147
+ ),
148
+ synthesis_kwargs = dnnlib.EasyDict(
149
+ channel_base = kwarg('fmap_base', 16384) * 2,
150
+ channel_max = kwarg('fmap_max', 512),
151
+ num_fp16_res = kwarg('num_fp16_res', 0),
152
+ conv_clamp = kwarg('conv_clamp', None),
153
+ architecture = kwarg('architecture', 'skip'),
154
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
155
+ use_noise = kwarg('use_noise', True),
156
+ activation = kwarg('nonlinearity', 'lrelu'),
157
+ ),
158
+ )
159
+
160
+ # Check for unknown kwargs.
161
+ kwarg('truncation_psi')
162
+ kwarg('truncation_cutoff')
163
+ kwarg('style_mixing_prob')
164
+ kwarg('structure')
165
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
166
+ if len(unknown_kwargs) > 0:
167
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
168
+
169
+ # Collect params.
170
+ tf_params = _collect_tf_params(tf_G)
171
+ for name, value in list(tf_params.items()):
172
+ match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
173
+ if match:
174
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
175
+ tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
176
+ kwargs.synthesis.kwargs.architecture = 'orig'
177
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
178
+
179
+ # Convert params.
180
+ from training import networks
181
+ G = networks.Generator(**kwargs).eval().requires_grad_(False)
182
+ # pylint: disable=unnecessary-lambda
183
+ _populate_module_params(G,
184
+ r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
185
+ r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
186
+ r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
187
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
188
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
189
+ r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
190
+ r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
191
+ r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
192
+ r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
193
+ r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
194
+ r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
195
+ r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
196
+ r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
197
+ r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
198
+ r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
199
+ r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
200
+ r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
201
+ r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
202
+ r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
203
+ r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
204
+ r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
205
+ r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
206
+ r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
207
+ r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
208
+ r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
209
+ r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
210
+ r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
211
+ r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
212
+ r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
213
+ r'.*\.resample_filter', None,
214
+ )
215
+ return G
216
+
217
+ #----------------------------------------------------------------------------
218
+
219
+ def convert_tf_discriminator(tf_D):
220
+ if tf_D.version < 4:
221
+ raise ValueError('TensorFlow pickle version too low')
222
+
223
+ # Collect kwargs.
224
+ tf_kwargs = tf_D.static_kwargs
225
+ known_kwargs = set()
226
+ def kwarg(tf_name, default=None):
227
+ known_kwargs.add(tf_name)
228
+ return tf_kwargs.get(tf_name, default)
229
+
230
+ # Convert kwargs.
231
+ kwargs = dnnlib.EasyDict(
232
+ c_dim = kwarg('label_size', 0),
233
+ img_resolution = kwarg('resolution', 1024),
234
+ img_channels = kwarg('num_channels', 3),
235
+ architecture = kwarg('architecture', 'resnet'),
236
+ channel_base = kwarg('fmap_base', 16384) * 2,
237
+ channel_max = kwarg('fmap_max', 512),
238
+ num_fp16_res = kwarg('num_fp16_res', 0),
239
+ conv_clamp = kwarg('conv_clamp', None),
240
+ cmap_dim = kwarg('mapping_fmaps', None),
241
+ block_kwargs = dnnlib.EasyDict(
242
+ activation = kwarg('nonlinearity', 'lrelu'),
243
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
244
+ freeze_layers = kwarg('freeze_layers', 0),
245
+ ),
246
+ mapping_kwargs = dnnlib.EasyDict(
247
+ num_layers = kwarg('mapping_layers', 0),
248
+ embed_features = kwarg('mapping_fmaps', None),
249
+ layer_features = kwarg('mapping_fmaps', None),
250
+ activation = kwarg('nonlinearity', 'lrelu'),
251
+ lr_multiplier = kwarg('mapping_lrmul', 0.1),
252
+ ),
253
+ epilogue_kwargs = dnnlib.EasyDict(
254
+ mbstd_group_size = kwarg('mbstd_group_size', None),
255
+ mbstd_num_channels = kwarg('mbstd_num_features', 1),
256
+ activation = kwarg('nonlinearity', 'lrelu'),
257
+ ),
258
+ )
259
+
260
+ # Check for unknown kwargs.
261
+ kwarg('structure')
262
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
263
+ if len(unknown_kwargs) > 0:
264
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
265
+
266
+ # Collect params.
267
+ tf_params = _collect_tf_params(tf_D)
268
+ for name, value in list(tf_params.items()):
269
+ match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
270
+ if match:
271
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
272
+ tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
273
+ kwargs.architecture = 'orig'
274
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
275
+
276
+ # Convert params.
277
+ from training import networks
278
+ D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
279
+ # pylint: disable=unnecessary-lambda
280
+ _populate_module_params(D,
281
+ r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
282
+ r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
283
+ r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
284
+ r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
285
+ r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
286
+ r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
287
+ r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
288
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
289
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
290
+ r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
291
+ r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
292
+ r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
293
+ r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
294
+ r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
295
+ r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
296
+ r'.*\.resample_filter', None,
297
+ )
298
+ return D
299
+
300
+ #----------------------------------------------------------------------------
301
+
302
+ @click.command()
303
+ @click.option('--source', help='Input pickle', required=True, metavar='PATH')
304
+ @click.option('--dest', help='Output pickle', required=True, metavar='PATH')
305
+ @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
306
+ def convert_network_pickle(source, dest, force_fp16):
307
+ """Convert legacy network pickle into the native PyTorch format.
308
+
309
+ The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
310
+ It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
311
+
312
+ Example:
313
+
314
+ \b
315
+ python legacy.py \\
316
+ --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
317
+ --dest=stylegan2-cat-config-f.pkl
318
+ """
319
+ print(f'Loading "{source}"...')
320
+ with dnnlib.util.open_url(source) as f:
321
+ data = load_network_pkl(f, force_fp16=force_fp16)
322
+ print(f'Saving "{dest}"...')
323
+ with open(dest, 'wb') as f:
324
+ pickle.dump(data, f)
325
+ print('Done.')
326
+
327
+ #----------------------------------------------------------------------------
328
+
329
+ if __name__ == "__main__":
330
+ convert_network_pickle() # pylint: disable=no-value-for-parameter
331
+
332
+ #----------------------------------------------------------------------------
diffusion-insgen/metrics/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
diffusion-insgen/metrics/frechet_inception_distance.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Frechet Inception Distance (FID) from the paper
10
+ "GANs trained by a two time-scale update rule converge to a local Nash
11
+ equilibrium". Matches the original implementation by Heusel et al. at
12
+ https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
13
+
14
+ import numpy as np
15
+ import scipy.linalg
16
+ from . import metric_utils
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ def compute_fid(opts, max_real, num_gen):
21
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
22
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
23
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
24
+
25
+ mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
26
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
27
+ rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
28
+
29
+ mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
30
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
31
+ rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
32
+
33
+ if opts.rank != 0:
34
+ return float('nan')
35
+
36
+ m = np.square(mu_gen - mu_real).sum()
37
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
38
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
39
+ return float(fid)
40
+
41
+ #----------------------------------------------------------------------------
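`compute_fid()` above evaluates the closed form FID = ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2(Sigma_r Sigma_g)^{1/2}) on Inception features. A self-contained numpy sketch of the same formula on toy feature arrays (not part of the commit; a real run would feed detector features instead of random numbers):

```python
# Sketch: FID between two feature arrays, following compute_fid() above.
import numpy as np
import scipy.linalg

def fid_from_features(real_feats, gen_feats):
    mu_r, sigma_r = real_feats.mean(axis=0), np.cov(real_feats, rowvar=False)
    mu_g, sigma_g = gen_feats.mean(axis=0), np.cov(gen_feats, rowvar=False)
    m = np.square(mu_g - mu_r).sum()
    s, _ = scipy.linalg.sqrtm(sigma_g @ sigma_r, disp=False)   # matrix square root
    return float(np.real(m + np.trace(sigma_g + sigma_r - 2 * s)))

rng = np.random.default_rng(0)
print(fid_from_features(rng.normal(size=(500, 64)), rng.normal(size=(500, 64))))  # toy 64-d "features"
```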
diffusion-insgen/metrics/inception_score.py ADDED
@@ -0,0 +1,38 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Inception Score (IS) from the paper "Improved techniques for training
10
+ GANs". Matches the original implementation by Salimans et al. at
11
+ https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
12
+
13
+ import numpy as np
14
+ from . import metric_utils
15
+
16
+ #----------------------------------------------------------------------------
17
+
18
+ def compute_is(opts, num_gen, num_splits):
19
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
21
+ detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
22
+
23
+ gen_probs = metric_utils.compute_feature_stats_for_generator(
24
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
25
+ capture_all=True, max_items=num_gen).get_all()
26
+
27
+ if opts.rank != 0:
28
+ return float('nan'), float('nan')
29
+
30
+ scores = []
31
+ for i in range(num_splits):
32
+ part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
33
+ kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
34
+ kl = np.mean(np.sum(kl, axis=1))
35
+ scores.append(np.exp(kl))
36
+ return float(np.mean(scores)), float(np.std(scores))
37
+
38
+ #----------------------------------------------------------------------------
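`compute_is()` implements IS = exp(E_x[KL(p(y|x) || p(y))]), averaged over splits. A toy sketch of the split-wise computation, assuming an [N, num_classes] array whose rows are softmax probabilities:

```python
# Sketch: Inception Score from a matrix of per-image class probabilities.
import numpy as np

def inception_score(probs, num_splits=10):
    n, scores = probs.shape[0], []
    for i in range(num_splits):
        part = probs[i * n // num_splits : (i + 1) * n // num_splits]
        kl = part * (np.log(part) - np.log(part.mean(axis=0, keepdims=True)))
        scores.append(np.exp(kl.sum(axis=1).mean()))
    return float(np.mean(scores)), float(np.std(scores))

rng = np.random.default_rng(0)
logits = rng.normal(size=(1000, 10))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # toy softmax outputs
print(inception_score(probs))
```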
diffusion-insgen/metrics/kernel_inception_distance.py ADDED
@@ -0,0 +1,46 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Kernel Inception Distance (KID) from the paper "Demystifying MMD
10
+ GANs". Matches the original implementation by Binkowski et al. at
11
+ https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
12
+
13
+ import numpy as np
14
+ from . import metric_utils
15
+
16
+ #----------------------------------------------------------------------------
17
+
18
+ def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
19
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20
+ detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
21
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
22
+
23
+ real_features = metric_utils.compute_feature_stats_for_dataset(
24
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
25
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
26
+
27
+ gen_features = metric_utils.compute_feature_stats_for_generator(
28
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
29
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
30
+
31
+ if opts.rank != 0:
32
+ return float('nan')
33
+
34
+ n = real_features.shape[1]
35
+ m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
36
+ t = 0
37
+ for _subset_idx in range(num_subsets):
38
+ x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
39
+ y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
40
+ a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
41
+ b = (x @ y.T / n + 1) ** 3
42
+ t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
43
+ kid = t / num_subsets / m
44
+ return float(kid)
45
+
46
+ #----------------------------------------------------------------------------
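`compute_kid()` estimates the squared MMD between real and generated features under the polynomial kernel k(x, y) = (x·y/d + 1)^3, averaged over random subsets. The same estimator as a standalone numpy sketch (not part of the commit):

```python
# Sketch: KID between two feature arrays, mirroring compute_kid() above.
import numpy as np

def kid_from_features(real, gen, num_subsets=100, max_subset_size=1000, seed=0):
    rng = np.random.default_rng(seed)
    n = real.shape[1]                                       # feature dimensionality d
    m = min(real.shape[0], gen.shape[0], max_subset_size)
    t = 0.0
    for _ in range(num_subsets):
        x = gen[rng.choice(gen.shape[0], m, replace=False)]
        y = real[rng.choice(real.shape[0], m, replace=False)]
        a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
        b = (x @ y.T / n + 1) ** 3
        t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
    return float(t / num_subsets / m)

rng = np.random.default_rng(0)
print(kid_from_features(rng.normal(size=(2000, 64)), rng.normal(size=(2000, 64))))
```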
diffusion-insgen/metrics/metric_main.py ADDED
@@ -0,0 +1,152 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import time
11
+ import json
12
+ import torch
13
+ import dnnlib
14
+
15
+ from . import metric_utils
16
+ from . import frechet_inception_distance
17
+ from . import kernel_inception_distance
18
+ from . import precision_recall
19
+ from . import perceptual_path_length
20
+ from . import inception_score
21
+
22
+ #----------------------------------------------------------------------------
23
+
24
+ _metric_dict = dict() # name => fn
25
+
26
+ def register_metric(fn):
27
+ assert callable(fn)
28
+ _metric_dict[fn.__name__] = fn
29
+ return fn
30
+
31
+ def is_valid_metric(metric):
32
+ return metric in _metric_dict
33
+
34
+ def list_valid_metrics():
35
+ return list(_metric_dict.keys())
36
+
37
+ #----------------------------------------------------------------------------
38
+
39
+ def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
40
+ assert is_valid_metric(metric)
41
+ opts = metric_utils.MetricOptions(**kwargs)
42
+
43
+ # Calculate.
44
+ start_time = time.time()
45
+ results = _metric_dict[metric](opts)
46
+ total_time = time.time() - start_time
47
+
48
+ # Broadcast results.
49
+ for key, value in list(results.items()):
50
+ if opts.num_gpus > 1:
51
+ value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
52
+ torch.distributed.broadcast(tensor=value, src=0)
53
+ value = float(value.cpu())
54
+ results[key] = value
55
+
56
+ # Decorate with metadata.
57
+ return dnnlib.EasyDict(
58
+ results = dnnlib.EasyDict(results),
59
+ metric = metric,
60
+ total_time = total_time,
61
+ total_time_str = dnnlib.util.format_time(total_time),
62
+ num_gpus = opts.num_gpus,
63
+ )
64
+
65
+ #----------------------------------------------------------------------------
66
+
67
+ def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
68
+ metric = result_dict['metric']
69
+ assert is_valid_metric(metric)
70
+ if run_dir is not None and snapshot_pkl is not None:
71
+ snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
72
+
73
+ jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
74
+ print(jsonl_line)
75
+ if run_dir is not None and os.path.isdir(run_dir):
76
+ with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
77
+ f.write(jsonl_line + '\n')
78
+
79
+ #----------------------------------------------------------------------------
80
+ # Primary metrics.
81
+
82
+ @register_metric
83
+ def fid50k_full(opts):
84
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
85
+ fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
86
+ return dict(fid50k_full=fid)
87
+
88
+ @register_metric
89
+ def kid50k_full(opts):
90
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
91
+ kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
92
+ return dict(kid50k_full=kid)
93
+
94
+ @register_metric
95
+ def pr50k3_full(opts):
96
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
97
+ precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
98
+ return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
99
+
100
+ @register_metric
101
+ def ppl2_wend(opts):
102
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
103
+ return dict(ppl2_wend=ppl)
104
+
105
+ @register_metric
106
+ def is50k(opts):
107
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
108
+ mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
109
+ return dict(is50k_mean=mean, is50k_std=std)
110
+
111
+ #----------------------------------------------------------------------------
112
+ # Legacy metrics.
113
+
114
+ @register_metric
115
+ def fid50k(opts):
116
+ opts.dataset_kwargs.update(max_size=None)
117
+ fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
118
+ return dict(fid50k=fid)
119
+
120
+ @register_metric
121
+ def kid50k(opts):
122
+ opts.dataset_kwargs.update(max_size=None)
123
+ kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
124
+ return dict(kid50k=kid)
125
+
126
+ @register_metric
127
+ def pr50k3(opts):
128
+ opts.dataset_kwargs.update(max_size=None)
129
+ precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
130
+ return dict(pr50k3_precision=precision, pr50k3_recall=recall)
131
+
132
+ @register_metric
133
+ def ppl_zfull(opts):
134
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, batch_size=2)
135
+ return dict(ppl_zfull=ppl)
136
+
137
+ @register_metric
138
+ def ppl_wfull(opts):
139
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, batch_size=2)
140
+ return dict(ppl_wfull=ppl)
141
+
142
+ @register_metric
143
+ def ppl_zend(opts):
144
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, batch_size=2)
145
+ return dict(ppl_zend=ppl)
146
+
147
+ @register_metric
148
+ def ppl_wend(opts):
149
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, batch_size=2)
150
+ return dict(ppl_wend=ppl)
151
+
152
+ #----------------------------------------------------------------------------
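Training and evaluation code drives these registered metrics through `calc_metric()` / `report_metric()`. A hedged sketch of a single-GPU call; the checkpoint path, dataset path, and the `training.dataset.ImageFolderDataset` class name are assumptions here (see `calc_metrics.py` in this commit for the actual driver):

```python
# Sketch only: evaluate FID for a loaded generator on one GPU.
import torch
import dnnlib
import legacy
from metrics import metric_main

device = torch.device('cuda')
with dnnlib.util.open_url('<path-to-network-pkl>') as f:                  # placeholder checkpoint
    G = legacy.load_network_pkl(f)['G_ema'].to(device)

dataset_kwargs = dict(class_name='training.dataset.ImageFolderDataset',   # assumed dataset class
                      path='<path-to-dataset.zip>')                       # placeholder dataset
result = metric_main.calc_metric(metric='fid50k_full', G=G, dataset_kwargs=dataset_kwargs,
                                 num_gpus=1, rank=0, device=device)
metric_main.report_metric(result, run_dir='out')  # prints a JSON line; appends to out/metric-*.jsonl if 'out' exists
print(result.results.fid50k_full)
```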
diffusion-insgen/metrics/metric_utils.py ADDED
@@ -0,0 +1,275 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import time
11
+ import hashlib
12
+ import pickle
13
+ import copy
14
+ import uuid
15
+ import numpy as np
16
+ import torch
17
+ import dnnlib
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ class MetricOptions:
22
+ def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
23
+ assert 0 <= rank < num_gpus
24
+ self.G = G
25
+ self.G_kwargs = dnnlib.EasyDict(G_kwargs)
26
+ self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
27
+ self.num_gpus = num_gpus
28
+ self.rank = rank
29
+ self.device = device if device is not None else torch.device('cuda', rank)
30
+ self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
31
+ self.cache = cache
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ _feature_detector_cache = dict()
36
+
37
+ def get_feature_detector_name(url):
38
+ return os.path.splitext(url.split('/')[-1])[0]
39
+
40
+ def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
41
+ assert 0 <= rank < num_gpus
42
+ key = (url, device)
43
+ if key not in _feature_detector_cache:
44
+ is_leader = (rank == 0)
45
+ if not is_leader and num_gpus > 1:
46
+ torch.distributed.barrier() # leader goes first
47
+ with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
48
+ _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
49
+ if is_leader and num_gpus > 1:
50
+ torch.distributed.barrier() # others follow
51
+ return _feature_detector_cache[key]
52
+
53
+ #----------------------------------------------------------------------------
54
+
55
+ class FeatureStats:
56
+ def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
57
+ self.capture_all = capture_all
58
+ self.capture_mean_cov = capture_mean_cov
59
+ self.max_items = max_items
60
+ self.num_items = 0
61
+ self.num_features = None
62
+ self.all_features = None
63
+ self.raw_mean = None
64
+ self.raw_cov = None
65
+
66
+ def set_num_features(self, num_features):
67
+ if self.num_features is not None:
68
+ assert num_features == self.num_features
69
+ else:
70
+ self.num_features = num_features
71
+ self.all_features = []
72
+ self.raw_mean = np.zeros([num_features], dtype=np.float64)
73
+ self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
74
+
75
+ def is_full(self):
76
+ return (self.max_items is not None) and (self.num_items >= self.max_items)
77
+
78
+ def append(self, x):
79
+ x = np.asarray(x, dtype=np.float32)
80
+ assert x.ndim == 2
81
+ if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
82
+ if self.num_items >= self.max_items:
83
+ return
84
+ x = x[:self.max_items - self.num_items]
85
+
86
+ self.set_num_features(x.shape[1])
87
+ self.num_items += x.shape[0]
88
+ if self.capture_all:
89
+ self.all_features.append(x)
90
+ if self.capture_mean_cov:
91
+ x64 = x.astype(np.float64)
92
+ self.raw_mean += x64.sum(axis=0)
93
+ self.raw_cov += x64.T @ x64
94
+
95
+ def append_torch(self, x, num_gpus=1, rank=0):
96
+ assert isinstance(x, torch.Tensor) and x.ndim == 2
97
+ assert 0 <= rank < num_gpus
98
+ if num_gpus > 1:
99
+ ys = []
100
+ for src in range(num_gpus):
101
+ y = x.clone()
102
+ torch.distributed.broadcast(y, src=src)
103
+ ys.append(y)
104
+ x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
105
+ self.append(x.cpu().numpy())
106
+
107
+ def get_all(self):
108
+ assert self.capture_all
109
+ return np.concatenate(self.all_features, axis=0)
110
+
111
+ def get_all_torch(self):
112
+ return torch.from_numpy(self.get_all())
113
+
114
+ def get_mean_cov(self):
115
+ assert self.capture_mean_cov
116
+ mean = self.raw_mean / self.num_items
117
+ cov = self.raw_cov / self.num_items
118
+ cov = cov - np.outer(mean, mean)
119
+ return mean, cov
120
+
121
+ def save(self, pkl_file):
122
+ with open(pkl_file, 'wb') as f:
123
+ pickle.dump(self.__dict__, f)
124
+
125
+ @staticmethod
126
+ def load(pkl_file):
127
+ with open(pkl_file, 'rb') as f:
128
+ s = dnnlib.EasyDict(pickle.load(f))
129
+ obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
130
+ obj.__dict__.update(s)
131
+ return obj
132
+
133
+ #----------------------------------------------------------------------------
134
+
135
+ class ProgressMonitor:
136
+ def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
137
+ self.tag = tag
138
+ self.num_items = num_items
139
+ self.verbose = verbose
140
+ self.flush_interval = flush_interval
141
+ self.progress_fn = progress_fn
142
+ self.pfn_lo = pfn_lo
143
+ self.pfn_hi = pfn_hi
144
+ self.pfn_total = pfn_total
145
+ self.start_time = time.time()
146
+ self.batch_time = self.start_time
147
+ self.batch_items = 0
148
+ if self.progress_fn is not None:
149
+ self.progress_fn(self.pfn_lo, self.pfn_total)
150
+
151
+ def update(self, cur_items):
152
+ assert (self.num_items is None) or (cur_items <= self.num_items)
153
+ if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
154
+ return
155
+ cur_time = time.time()
156
+ total_time = cur_time - self.start_time
157
+ time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
158
+ if (self.verbose) and (self.tag is not None):
159
+ print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
160
+ self.batch_time = cur_time
161
+ self.batch_items = cur_items
162
+
163
+ if (self.progress_fn is not None) and (self.num_items is not None):
164
+ self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
165
+
166
+ def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
167
+ return ProgressMonitor(
168
+ tag = tag,
169
+ num_items = num_items,
170
+ flush_interval = flush_interval,
171
+ verbose = self.verbose,
172
+ progress_fn = self.progress_fn,
173
+ pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
174
+ pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
175
+ pfn_total = self.pfn_total,
176
+ )
177
+
178
+ #----------------------------------------------------------------------------
179
+
180
+ def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
181
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
182
+ if data_loader_kwargs is None:
183
+ data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
184
+
185
+ # Try to lookup from cache.
186
+ cache_file = None
187
+ if opts.cache:
188
+ # Choose cache file name.
189
+ args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
190
+ md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
191
+ cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
192
+ cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
193
+
194
+ # Check if the file exists (all processes must agree).
195
+ flag = os.path.isfile(cache_file) if opts.rank == 0 else False
196
+ if opts.num_gpus > 1:
197
+ flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
198
+ torch.distributed.broadcast(tensor=flag, src=0)
199
+ flag = (float(flag.cpu()) != 0)
200
+
201
+ # Load.
202
+ if flag:
203
+ return FeatureStats.load(cache_file)
204
+
205
+ # Initialize.
206
+ num_items = len(dataset)
207
+ if max_items is not None:
208
+ num_items = min(num_items, max_items)
209
+ stats = FeatureStats(max_items=num_items, **stats_kwargs)
210
+ progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
211
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
212
+
213
+ # Main loop.
214
+ item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
215
+ for images, _labels in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
216
+ if images.shape[1] == 1:
217
+ images = images.repeat([1, 3, 1, 1])
218
+ features = detector(images.to(opts.device), **detector_kwargs)
219
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
220
+ progress.update(stats.num_items)
221
+
222
+ # Save to cache.
223
+ if cache_file is not None and opts.rank == 0:
224
+ os.makedirs(os.path.dirname(cache_file), exist_ok=True)
225
+ temp_file = cache_file + '.' + uuid.uuid4().hex
226
+ stats.save(temp_file)
227
+ os.replace(temp_file, cache_file) # atomic
228
+ return stats
229
+
230
+ #----------------------------------------------------------------------------
231
+
232
+ def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, jit=False, **stats_kwargs):
233
+ if batch_gen is None:
234
+ batch_gen = min(batch_size, 4)
235
+ assert batch_size % batch_gen == 0
236
+
237
+ # Setup generator and load labels.
238
+ G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
239
+ dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
240
+
241
+ # Image generation func.
242
+ def run_generator(z, c):
243
+ img = G(z=z, c=c, **opts.G_kwargs)
244
+ img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
245
+ return img
246
+
247
+ # JIT.
248
+ if jit:
249
+ z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
250
+ c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
251
+ run_generator = torch.jit.trace(run_generator, [z, c], check_trace=False)
252
+
253
+ # Initialize.
254
+ stats = FeatureStats(**stats_kwargs)
255
+ assert stats.max_items is not None
256
+ progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
257
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
258
+
259
+ # Main loop.
260
+ while not stats.is_full():
261
+ images = []
262
+ for _i in range(batch_size // batch_gen):
263
+ z = torch.randn([batch_gen, G.z_dim], device=opts.device)
264
+ c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_gen)]
265
+ c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
266
+ images.append(run_generator(z, c))
267
+ images = torch.cat(images)
268
+ if images.shape[1] == 1:
269
+ images = images.repeat([1, 3, 1, 1])
270
+ features = detector(images, **detector_kwargs)
271
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
272
+ progress.update(stats.num_items)
273
+ return stats
274
+
275
+ #----------------------------------------------------------------------------
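`FeatureStats` is the accumulator shared by all of the metrics above: it can keep every feature row (`capture_all`) or just running first and second moments (`capture_mean_cov`). A small sketch with random stand-in features, assuming the repository root is on PYTHONPATH:

```python
# Sketch: accumulating a running mean/covariance with FeatureStats from metric_utils above.
import numpy as np
from metrics.metric_utils import FeatureStats

stats = FeatureStats(capture_mean_cov=True, max_items=10000)
rng = np.random.default_rng(0)
while not stats.is_full():
    batch = rng.normal(size=(64, 128)).astype(np.float32)   # stand-in for detector features
    stats.append(batch)                                      # accepts [batch, num_features] arrays

mean, cov = stats.get_mean_cov()
print(stats.num_items, mean.shape, cov.shape)                # 10000 (128,) (128, 128)
```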
diffusion-insgen/metrics/perceptual_path_length.py ADDED
@@ -0,0 +1,131 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+ 
+ """Perceptual Path Length (PPL) from the paper "A Style-Based Generator
+ Architecture for Generative Adversarial Networks". Matches the original
+ implementation by Karras et al. at
+ https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
+ 
+ import copy
+ import numpy as np
+ import torch
+ import dnnlib
+ from . import metric_utils
+ 
+ #----------------------------------------------------------------------------
+ 
+ # Spherical interpolation of a batch of vectors.
+ def slerp(a, b, t):
+     a = a / a.norm(dim=-1, keepdim=True)
+     b = b / b.norm(dim=-1, keepdim=True)
+     d = (a * b).sum(dim=-1, keepdim=True)
+     p = t * torch.acos(d)
+     c = b - d * a
+     c = c / c.norm(dim=-1, keepdim=True)
+     d = a * torch.cos(p) + c * torch.sin(p)
+     d = d / d.norm(dim=-1, keepdim=True)
+     return d
+ 
+ #----------------------------------------------------------------------------
+ 
+ class PPLSampler(torch.nn.Module):
+     def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
+         assert space in ['z', 'w']
+         assert sampling in ['full', 'end']
+         super().__init__()
+         self.G = copy.deepcopy(G)
+         self.G_kwargs = G_kwargs
+         self.epsilon = epsilon
+         self.space = space
+         self.sampling = sampling
+         self.crop = crop
+         self.vgg16 = copy.deepcopy(vgg16)
+ 
+     def forward(self, c):
+         # Generate random latents and interpolation t-values.
+         t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
+         z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
+ 
+         # Interpolate in W or Z.
+         if self.space == 'w':
+             w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
+             wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
+             wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
+         else: # space == 'z'
+             zt0 = slerp(z0, z1, t.unsqueeze(1))
+             zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
+             wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
+ 
+         # Randomize noise buffers.
+         for name, buf in self.G.named_buffers():
+             if name.endswith('.noise_const'):
+                 buf.copy_(torch.randn_like(buf))
+ 
+         # Generate images.
+         img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
+ 
+         # Center crop.
+         if self.crop:
+             assert img.shape[2] == img.shape[3]
+             c = img.shape[2] // 8
+             img = img[:, :, c*3 : c*7, c*2 : c*6]
+ 
+         # Downsample to 256x256.
+         factor = self.G.img_resolution // 256
+         if factor > 1:
+             img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
+ 
+         # Scale dynamic range from [-1,1] to [0,255].
+         img = (img + 1) * (255 / 2)
+         if self.G.img_channels == 1:
+             img = img.repeat([1, 3, 1, 1])
+ 
+         # Evaluate differential LPIPS.
+         lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
+         dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
+         return dist
+ 
+ #----------------------------------------------------------------------------
+ 
+ def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size, jit=False):
+     dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
+     vgg16_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
+     vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
+ 
+     # Setup sampler.
+     sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
+     sampler.eval().requires_grad_(False).to(opts.device)
+     if jit:
+         c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
+         sampler = torch.jit.trace(sampler, [c], check_trace=False)
+ 
+     # Sampling loop.
+     dist = []
+     progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
+     for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
+         progress.update(batch_start)
+         c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
+         c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
+         x = sampler(c)
+         for src in range(opts.num_gpus):
+             y = x.clone()
+             if opts.num_gpus > 1:
+                 torch.distributed.broadcast(y, src=src)
+             dist.append(y)
+     progress.update(num_samples)
+ 
+     # Compute PPL.
+     if opts.rank != 0:
+         return float('nan')
+     dist = torch.cat(dist)[:num_samples].cpu().numpy()
+     lo = np.percentile(dist, 1, interpolation='lower')
+     hi = np.percentile(dist, 99, interpolation='higher')
+     ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
+     return float(ppl)
+ 
+ #----------------------------------------------------------------------------
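compute_ppl() turns the per-sample squared LPIPS differences into a single number by discarding the tails below the 1st and above the 99th percentile and averaging the rest. A self-contained sketch of that last step on synthetic distances (the log-normal samples stand in for real LPIPS measurements; note that newer NumPy versions rename the `interpolation` argument to `method`):

```python
import numpy as np

# Synthetic per-sample path lengths standing in for (lpips_t0 - lpips_t1)**2 / epsilon**2.
dist = np.random.lognormal(mean=0.0, sigma=1.0, size=50_000)

# Same outlier rejection as compute_ppl(): keep the [1st, 99th] percentile band, then average.
lo = np.percentile(dist, 1, interpolation='lower')
hi = np.percentile(dist, 99, interpolation='higher')
ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
print(float(ppl))
```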
diffusion-insgen/metrics/precision_recall.py ADDED
@@ -0,0 +1,62 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+ 
+ """Precision/Recall (PR) from the paper "Improved Precision and Recall
+ Metric for Assessing Generative Models". Matches the original implementation
+ by Kynkaanniemi et al. at
+ https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
+ 
+ import torch
+ from . import metric_utils
+ 
+ #----------------------------------------------------------------------------
+ 
+ def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
+     assert 0 <= rank < num_gpus
+     num_cols = col_features.shape[0]
+     num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
+     col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
+     dist_batches = []
+     for col_batch in col_batches[rank :: num_gpus]:
+         dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
+         for src in range(num_gpus):
+             dist_broadcast = dist_batch.clone()
+             if num_gpus > 1:
+                 torch.distributed.broadcast(dist_broadcast, src=src)
+             dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
+     return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
+ 
+ #----------------------------------------------------------------------------
+ 
+ def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
+     detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
+     detector_kwargs = dict(return_features=True)
+ 
+     real_features = metric_utils.compute_feature_stats_for_dataset(
+         opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+         rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
+ 
+     gen_features = metric_utils.compute_feature_stats_for_generator(
+         opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+         rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
+ 
+     results = dict()
+     for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
+         kth = []
+         for manifold_batch in manifold.split(row_batch_size):
+             dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
+             kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
+         kth = torch.cat(kth) if opts.rank == 0 else None
+         pred = []
+         for probes_batch in probes.split(row_batch_size):
+             dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
+             pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
+         results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
+     return results['precision'], results['recall']
+ 
+ #----------------------------------------------------------------------------
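compute_pr() approximates each manifold by a ball around every feature whose radius is the distance to its k-th nearest neighbour, and then counts how many probe features fall inside at least one ball. A single-process toy version of that rule on random features (purely illustrative; the helper and shapes below are made up and ignore the multi-GPU batching above):

```python
import torch

def toy_manifold_coverage(manifold, probes, nhood_size=3):
    # Radius of each manifold point = distance to its k-th nearest neighbour
    # (k+1-th value per row because the closest point is the point itself).
    kth = torch.cdist(manifold, manifold).kthvalue(nhood_size + 1, dim=1).values
    # A probe counts as covered if it lies within the radius of any manifold point.
    dist = torch.cdist(probes, manifold)
    return (dist <= kth.unsqueeze(0)).any(dim=1).float().mean().item()

real = torch.randn(1000, 64)
fake = torch.randn(1000, 64) * 1.2
precision = toy_manifold_coverage(real, fake)   # fraction of fakes inside the real manifold
recall = toy_manifold_coverage(fake, real)      # fraction of reals inside the fake manifold
print(precision, recall)
```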
diffusion-insgen/projector.py ADDED
@@ -0,0 +1,212 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Project given image to the latent space of pretrained network pickle."""
10
+
11
+ import copy
12
+ import os
13
+ from time import perf_counter
14
+
15
+ import click
16
+ import imageio
17
+ import numpy as np
18
+ import PIL.Image
19
+ import torch
20
+ import torch.nn.functional as F
21
+
22
+ import dnnlib
23
+ import legacy
24
+
25
+ def project(
26
+ G,
27
+ target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
28
+ *,
29
+ num_steps = 1000,
30
+ w_avg_samples = 10000,
31
+ initial_learning_rate = 0.1,
32
+ initial_noise_factor = 0.05,
33
+ lr_rampdown_length = 0.25,
34
+ lr_rampup_length = 0.05,
35
+ noise_ramp_length = 0.75,
36
+ regularize_noise_weight = 1e5,
37
+ verbose = False,
38
+ device: torch.device
39
+ ):
40
+ assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
41
+
42
+ def logprint(*args):
43
+ if verbose:
44
+ print(*args)
45
+
46
+ G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore
47
+
48
+ # Compute w stats.
49
+ logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
50
+ z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
51
+ w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None) # [N, L, C]
52
+ w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C]
53
+ w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C]
54
+ w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
55
+
56
+ # Setup noise inputs.
57
+ noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }
58
+
59
+ # Load VGG16 feature detector.
60
+ url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
61
+ with dnnlib.util.open_url(url) as f:
62
+ vgg16 = torch.jit.load(f).eval().to(device)
63
+
64
+ # Features for target image.
65
+ target_images = target.unsqueeze(0).to(device).to(torch.float32)
66
+ if target_images.shape[2] > 256:
67
+ target_images = F.interpolate(target_images, size=(256, 256), mode='area')
68
+ target_features = vgg16(target_images, resize_images=False, return_lpips=True)
69
+
70
+ w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
71
+ w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
72
+ optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)
73
+
74
+ # Init noise.
75
+ for buf in noise_bufs.values():
76
+ buf[:] = torch.randn_like(buf)
77
+ buf.requires_grad = True
78
+
79
+ for step in range(num_steps):
80
+ # Learning rate schedule.
81
+ t = step / num_steps
82
+ w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
83
+ lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
84
+ lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
85
+ lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
86
+ lr = initial_learning_rate * lr_ramp
87
+ for param_group in optimizer.param_groups:
88
+ param_group['lr'] = lr
89
+
90
+ # Synth images from opt_w.
91
+ w_noise = torch.randn_like(w_opt) * w_noise_scale
92
+ ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
93
+ synth_images = G.synthesis(ws, noise_mode='const')
94
+
95
+ # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
96
+ synth_images = (synth_images + 1) * (255/2)
97
+ if synth_images.shape[2] > 256:
98
+ synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
99
+
100
+ # Features for synth images.
101
+ synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
102
+ dist = (target_features - synth_features).square().sum()
103
+
104
+ # Noise regularization.
105
+ reg_loss = 0.0
106
+ for v in noise_bufs.values():
107
+ noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
108
+ while True:
109
+ reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
110
+ reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
111
+ if noise.shape[2] <= 8:
112
+ break
113
+ noise = F.avg_pool2d(noise, kernel_size=2)
114
+ loss = dist + reg_loss * regularize_noise_weight
115
+
116
+ # Step
117
+ optimizer.zero_grad(set_to_none=True)
118
+ loss.backward()
119
+ optimizer.step()
120
+ logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')
121
+
122
+ # Save projected W for each optimization step.
123
+ w_out[step] = w_opt.detach()[0]
124
+
125
+ # Normalize noise.
126
+ with torch.no_grad():
127
+ for buf in noise_bufs.values():
128
+ buf -= buf.mean()
129
+ buf *= buf.square().mean().rsqrt()
130
+
131
+ return w_out.repeat([1, G.mapping.num_ws, 1])
132
+
133
+ #----------------------------------------------------------------------------
134
+
135
+ @click.command()
136
+ @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
137
+ @click.option('--target', 'target_fname', help='Target image file to project to', required=True, metavar='FILE')
138
+ @click.option('--num-steps', help='Number of optimization steps', type=int, default=1000, show_default=True)
139
+ @click.option('--seed', help='Random seed', type=int, default=303, show_default=True)
140
+ @click.option('--save-video', help='Save an mp4 video of optimization progress', type=bool, default=True, show_default=True)
141
+ @click.option('--outdir', help='Where to save the output images', required=True, metavar='DIR')
142
+ def run_projection(
143
+ network_pkl: str,
144
+ target_fname: str,
145
+ outdir: str,
146
+ save_video: bool,
147
+ seed: int,
148
+ num_steps: int
149
+ ):
150
+ """Project given image to the latent space of pretrained network pickle.
151
+
152
+ Examples:
153
+
154
+ \b
155
+ python projector.py --outdir=out --target=~/mytargetimg.png \\
156
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
157
+ """
158
+ np.random.seed(seed)
159
+ torch.manual_seed(seed)
160
+
161
+ # Load networks.
162
+ print('Loading networks from "%s"...' % network_pkl)
163
+ device = torch.device('cuda')
164
+ with dnnlib.util.open_url(network_pkl) as fp:
165
+ G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore
166
+
167
+ # Load target image.
168
+ target_pil = PIL.Image.open(target_fname).convert('RGB')
169
+ w, h = target_pil.size
170
+ s = min(w, h)
171
+ target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
172
+ target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
173
+ target_uint8 = np.array(target_pil, dtype=np.uint8)
174
+
175
+ # Optimize projection.
176
+ start_time = perf_counter()
177
+ projected_w_steps = project(
178
+ G,
179
+ target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device), # pylint: disable=not-callable
180
+ num_steps=num_steps,
181
+ device=device,
182
+ verbose=True
183
+ )
184
+ print (f'Elapsed: {(perf_counter()-start_time):.1f} s')
185
+
186
+ # Render debug output: optional video and projected image and W vector.
187
+ os.makedirs(outdir, exist_ok=True)
188
+ if save_video:
189
+ video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
190
+ print (f'Saving optimization progress video "{outdir}/proj.mp4"')
191
+ for projected_w in projected_w_steps:
192
+ synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
193
+ synth_image = (synth_image + 1) * (255/2)
194
+ synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
195
+ video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
196
+ video.close()
197
+
198
+ # Save final projected frame and W vector.
199
+ target_pil.save(f'{outdir}/target.png')
200
+ projected_w = projected_w_steps[-1]
201
+ synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
202
+ synth_image = (synth_image + 1) * (255/2)
203
+ synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
204
+ PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png')
205
+ np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())
206
+
207
+ #----------------------------------------------------------------------------
208
+
209
+ if __name__ == "__main__":
210
+ run_projection() # pylint: disable=no-value-for-parameter
211
+
212
+ #----------------------------------------------------------------------------
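run_projection() writes the optimized latent to `projected_w.npz`. A hedged sketch of re-loading that file and re-synthesizing the projected image afterwards (the paths are placeholders; run from inside `diffusion-insgen/` so that `dnnlib` and `legacy` are importable):

```python
import numpy as np
import PIL.Image
import torch

import dnnlib
import legacy

network_pkl = '../checkpoints/diffusion-stylegan2-ffhq.pkl'  # placeholder path
device = torch.device('cuda')
with dnnlib.util.open_url(network_pkl) as f:
    G = legacy.load_network_pkl(f)['G_ema'].requires_grad_(False).to(device)

# Load the saved W vector and run the synthesis network, mirroring run_projection().
ws = torch.tensor(np.load('out/projected_w.npz')['w'], device=device)  # [1, num_ws, w_dim]
img = G.synthesis(ws, noise_mode='const')
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
PIL.Image.fromarray(img, 'RGB').save('out/reprojected.png')
```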
diffusion-insgen/style_mixing.py ADDED
@@ -0,0 +1,118 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Generate style mixing image matrix using pretrained network pickle."""
10
+
11
+ import os
12
+ import re
13
+ from typing import List
14
+
15
+ import click
16
+ import dnnlib
17
+ import numpy as np
18
+ import PIL.Image
19
+ import torch
20
+
21
+ import legacy
22
+
23
+ #----------------------------------------------------------------------------
24
+
25
+ def num_range(s: str) -> List[int]:
26
+ '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
27
+
28
+ range_re = re.compile(r'^(\d+)-(\d+)$')
29
+ m = range_re.match(s)
30
+ if m:
31
+ return list(range(int(m.group(1)), int(m.group(2))+1))
32
+ vals = s.split(',')
33
+ return [int(x) for x in vals]
34
+
35
+ #----------------------------------------------------------------------------
36
+
37
+ @click.command()
38
+ @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
39
+ @click.option('--rows', 'row_seeds', type=num_range, help='Random seeds to use for image rows', required=True)
40
+ @click.option('--cols', 'col_seeds', type=num_range, help='Random seeds to use for image columns', required=True)
41
+ @click.option('--styles', 'col_styles', type=num_range, help='Style layer range', default='0-6', show_default=True)
42
+ @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
43
+ @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
44
+ @click.option('--outdir', type=str, required=True)
45
+ def generate_style_mix(
46
+ network_pkl: str,
47
+ row_seeds: List[int],
48
+ col_seeds: List[int],
49
+ col_styles: List[int],
50
+ truncation_psi: float,
51
+ noise_mode: str,
52
+ outdir: str
53
+ ):
54
+ """Generate images using pretrained network pickle.
55
+
56
+ Examples:
57
+
58
+ \b
59
+ python style_mixing.py --outdir=out --rows=85,100,75,458,1500 --cols=55,821,1789,293 \\
60
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl
61
+ """
62
+ print('Loading networks from "%s"...' % network_pkl)
63
+ device = torch.device('cuda')
64
+ with dnnlib.util.open_url(network_pkl) as f:
65
+ G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore
66
+
67
+ os.makedirs(outdir, exist_ok=True)
68
+
69
+ print('Generating W vectors...')
70
+ all_seeds = list(set(row_seeds + col_seeds))
71
+ all_z = np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])
72
+ all_w = G.mapping(torch.from_numpy(all_z).to(device), None)
73
+ w_avg = G.mapping.w_avg
74
+ all_w = w_avg + (all_w - w_avg) * truncation_psi
75
+ w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))}
76
+
77
+ print('Generating images...')
78
+ all_images = G.synthesis(all_w, noise_mode=noise_mode)
79
+ all_images = (all_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
80
+ image_dict = {(seed, seed): image for seed, image in zip(all_seeds, list(all_images))}
81
+
82
+ print('Generating style-mixed images...')
83
+ for row_seed in row_seeds:
84
+ for col_seed in col_seeds:
85
+ w = w_dict[row_seed].clone()
86
+ w[col_styles] = w_dict[col_seed][col_styles]
87
+ image = G.synthesis(w[np.newaxis], noise_mode=noise_mode)
88
+ image = (image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
89
+ image_dict[(row_seed, col_seed)] = image[0].cpu().numpy()
90
+
91
+ print('Saving images...')
92
+ os.makedirs(outdir, exist_ok=True)
93
+ for (row_seed, col_seed), image in image_dict.items():
94
+ PIL.Image.fromarray(image, 'RGB').save(f'{outdir}/{row_seed}-{col_seed}.png')
95
+
96
+ print('Saving image grid...')
97
+ W = G.img_resolution
98
+ H = G.img_resolution
99
+ canvas = PIL.Image.new('RGB', (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), 'black')
100
+ for row_idx, row_seed in enumerate([0] + row_seeds):
101
+ for col_idx, col_seed in enumerate([0] + col_seeds):
102
+ if row_idx == 0 and col_idx == 0:
103
+ continue
104
+ key = (row_seed, col_seed)
105
+ if row_idx == 0:
106
+ key = (col_seed, col_seed)
107
+ if col_idx == 0:
108
+ key = (row_seed, row_seed)
109
+ canvas.paste(PIL.Image.fromarray(image_dict[key], 'RGB'), (W * col_idx, H * row_idx))
110
+ canvas.save(f'{outdir}/grid.png')
111
+
112
+
113
+ #----------------------------------------------------------------------------
114
+
115
+ if __name__ == "__main__":
116
+ generate_style_mix() # pylint: disable=no-value-for-parameter
117
+
118
+ #----------------------------------------------------------------------------
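The --rows, --cols and --styles options are parsed by num_range(), which accepts either a comma-separated list or an inclusive range; --styles then selects which of the generator's layers are copied from the column seed. A quick illustration (run from inside `diffusion-insgen/` so the module imports resolve):

```python
from style_mixing import num_range

assert num_range('0-6') == [0, 1, 2, 3, 4, 5, 6]   # inclusive range
assert num_range('85,100,75') == [85, 100, 75]     # explicit list
```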
diffusion-insgen/torch_utils/__init__.py ADDED
@@ -0,0 +1,2 @@
+ 
+ # empty
diffusion-insgen/torch_utils/custom_ops.py ADDED
@@ -0,0 +1,119 @@
1
+
2
+ import os
3
+ import glob
4
+ import torch
5
+ import torch.utils.cpp_extension
6
+ import importlib
7
+ import hashlib
8
+ import shutil
9
+ from pathlib import Path
10
+
11
+ from torch.utils.file_baton import FileBaton
12
+
13
+ #----------------------------------------------------------------------------
14
+ # Global options.
15
+
16
+ verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
17
+
18
+ #----------------------------------------------------------------------------
19
+ # Internal helper funcs.
20
+
21
+ def _find_compiler_bindir():
22
+ patterns = [
23
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
24
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
25
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
26
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
27
+ ]
28
+ for pattern in patterns:
29
+ matches = sorted(glob.glob(pattern))
30
+ if len(matches):
31
+ return matches[-1]
32
+ return None
33
+
34
+ #----------------------------------------------------------------------------
35
+ # Main entry point for compiling and loading C++/CUDA plugins.
36
+
37
+ _cached_plugins = dict()
38
+
39
+ def get_plugin(module_name, sources, **build_kwargs):
40
+ assert verbosity in ['none', 'brief', 'full']
41
+
42
+ # Already cached?
43
+ if module_name in _cached_plugins:
44
+ return _cached_plugins[module_name]
45
+
46
+ # Print status.
47
+ if verbosity == 'full':
48
+ print(f'Setting up PyTorch plugin "{module_name}"...')
49
+ elif verbosity == 'brief':
50
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
51
+
52
+ try: # pylint: disable=too-many-nested-blocks
53
+ # Make sure we can find the necessary compiler binaries.
54
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
55
+ compiler_bindir = _find_compiler_bindir()
56
+ if compiler_bindir is None:
57
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
58
+ os.environ['PATH'] += ';' + compiler_bindir
59
+
60
+ # Compile and load.
61
+ verbose_build = (verbosity == 'full')
62
+
63
+ # Incremental build md5sum trickery. Copies all the input source files
64
+ # into a cached build directory under a combined md5 digest of the input
65
+ # source files. Copying is done only if the combined digest has changed.
66
+ # This keeps input file timestamps and filenames the same as in previous
67
+ # extension builds, allowing for fast incremental rebuilds.
68
+ #
69
+ # This optimization is done only in case all the source files reside in
70
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
71
+ # environment variable is set (we take this as a signal that the user
72
+ # actually cares about this.)
73
+ source_dirs_set = set(os.path.dirname(source) for source in sources)
74
+ if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
75
+ all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))
76
+
77
+ # Compute a combined hash digest for all source files in the same
78
+ # custom op directory (usually .cu, .cpp, .py and .h files).
79
+ hash_md5 = hashlib.md5()
80
+ for src in all_source_files:
81
+ with open(src, 'rb') as f:
82
+ hash_md5.update(f.read())
83
+ build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
84
+ digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())
85
+
86
+ if not os.path.isdir(digest_build_dir):
87
+ os.makedirs(digest_build_dir, exist_ok=True)
88
+ baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
89
+ if baton.try_acquire():
90
+ try:
91
+ for src in all_source_files:
92
+ shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
93
+ finally:
94
+ baton.release()
95
+ else:
96
+ # Someone else is copying source files under the digest dir,
97
+ # wait until done and continue.
98
+ baton.wait()
99
+ digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
100
+ torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
101
+ verbose=verbose_build, sources=digest_sources, **build_kwargs)
102
+ else:
103
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
104
+ module = importlib.import_module(module_name)
105
+
106
+ except:
107
+ if verbosity == 'brief':
108
+ print('Failed!')
109
+ raise
110
+
111
+ # Print status and add to cache.
112
+ if verbosity == 'full':
113
+ print(f'Done setting up PyTorch plugin "{module_name}".')
114
+ elif verbosity == 'brief':
115
+ print('Done.')
116
+ _cached_plugins[module_name] = module
117
+ return module
118
+
119
+ #----------------------------------------------------------------------------
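The incremental-build path keys the cached copy of the sources on a combined MD5 digest of every file in the op directory, so unchanged sources reuse the previous build. A standalone sketch of that digest computation (the directory argument is a placeholder):

```python
import hashlib
from pathlib import Path

def combined_source_digest(source_dir):
    # Hash every file in the directory in sorted order, mirroring get_plugin().
    hash_md5 = hashlib.md5()
    for src in sorted(p for p in Path(source_dir).iterdir() if p.is_file()):
        hash_md5.update(src.read_bytes())
    return hash_md5.hexdigest()

print(combined_source_digest('torch_utils/ops'))  # e.g. the custom op sources in this repo
```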
diffusion-insgen/torch_utils/misc.py ADDED
@@ -0,0 +1,260 @@
1
+ 
2
+ import re
3
+ import contextlib
4
+ import numpy as np
5
+ import torch
6
+ import warnings
7
+ import dnnlib
8
+
9
+ #----------------------------------------------------------------------------
10
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
11
+ # same constant is used multiple times.
12
+
13
+ _constant_cache = dict()
14
+
15
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
16
+ value = np.asarray(value)
17
+ if shape is not None:
18
+ shape = tuple(shape)
19
+ if dtype is None:
20
+ dtype = torch.get_default_dtype()
21
+ if device is None:
22
+ device = torch.device('cpu')
23
+ if memory_format is None:
24
+ memory_format = torch.contiguous_format
25
+
26
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
27
+ tensor = _constant_cache.get(key, None)
28
+ if tensor is None:
29
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
30
+ if shape is not None:
31
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
32
+ tensor = tensor.contiguous(memory_format=memory_format)
33
+ _constant_cache[key] = tensor
34
+ return tensor
35
+
36
+ #----------------------------------------------------------------------------
37
+ # Replace NaN/Inf with specified numerical values.
38
+
39
+ try:
40
+ nan_to_num = torch.nan_to_num # 1.8.0a0
41
+ except AttributeError:
42
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
43
+ assert isinstance(input, torch.Tensor)
44
+ if posinf is None:
45
+ posinf = torch.finfo(input.dtype).max
46
+ if neginf is None:
47
+ neginf = torch.finfo(input.dtype).min
48
+ assert nan == 0
49
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
50
+
51
+ #----------------------------------------------------------------------------
52
+ # Symbolic assert.
53
+
54
+ try:
55
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
56
+ except AttributeError:
57
+ symbolic_assert = torch.Assert # 1.7.0
58
+
59
+ #----------------------------------------------------------------------------
60
+ # Context manager to suppress known warnings in torch.jit.trace().
61
+
62
+ class suppress_tracer_warnings(warnings.catch_warnings):
63
+ def __enter__(self):
64
+ super().__enter__()
65
+ warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
66
+ return self
67
+
68
+ #----------------------------------------------------------------------------
69
+ # Assert that the shape of a tensor matches the given list of integers.
70
+ # None indicates that the size of a dimension is allowed to vary.
71
+ # Performs symbolic assertion when used in torch.jit.trace().
72
+
73
+ def assert_shape(tensor, ref_shape):
74
+ if tensor.ndim != len(ref_shape):
75
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
76
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
77
+ if ref_size is None:
78
+ pass
79
+ elif isinstance(ref_size, torch.Tensor):
80
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
81
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
82
+ elif isinstance(size, torch.Tensor):
83
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
84
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
85
+ elif size != ref_size:
86
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
87
+
88
+ #----------------------------------------------------------------------------
89
+ # Function decorator that calls torch.autograd.profiler.record_function().
90
+
91
+ def profiled_function(fn):
92
+ def decorator(*args, **kwargs):
93
+ with torch.autograd.profiler.record_function(fn.__name__):
94
+ return fn(*args, **kwargs)
95
+ decorator.__name__ = fn.__name__
96
+ return decorator
97
+
98
+ #----------------------------------------------------------------------------
99
+ # Sampler for torch.utils.data.DataLoader that loops over the dataset
100
+ # indefinitely, shuffling items as it goes.
101
+
102
+ class InfiniteSampler(torch.utils.data.Sampler):
103
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
104
+ assert len(dataset) > 0
105
+ assert num_replicas > 0
106
+ assert 0 <= rank < num_replicas
107
+ assert 0 <= window_size <= 1
108
+ super().__init__(dataset)
109
+ self.dataset = dataset
110
+ self.rank = rank
111
+ self.num_replicas = num_replicas
112
+ self.shuffle = shuffle
113
+ self.seed = seed
114
+ self.window_size = window_size
115
+
116
+ def __iter__(self):
117
+ order = np.arange(len(self.dataset))
118
+ rnd = None
119
+ window = 0
120
+ if self.shuffle:
121
+ rnd = np.random.RandomState(self.seed)
122
+ rnd.shuffle(order)
123
+ window = int(np.rint(order.size * self.window_size))
124
+
125
+ idx = 0
126
+ while True:
127
+ i = idx % order.size
128
+ if idx % self.num_replicas == self.rank:
129
+ yield order[i]
130
+ if window >= 2:
131
+ j = (i - rnd.randint(window)) % order.size
132
+ order[i], order[j] = order[j], order[i]
133
+ idx += 1
134
+
135
+ #----------------------------------------------------------------------------
136
+ # Utilities for operating with torch.nn.Module parameters and buffers.
137
+
138
+ def params_and_buffers(module):
139
+ assert isinstance(module, torch.nn.Module)
140
+ return list(module.parameters()) + list(module.buffers())
141
+
142
+ def named_params_and_buffers(module):
143
+ assert isinstance(module, torch.nn.Module)
144
+ return list(module.named_parameters()) + list(module.named_buffers())
145
+
146
+ def copy_params_and_buffers(src_module, dst_module, require_all=False):
147
+ assert isinstance(src_module, torch.nn.Module)
148
+ assert isinstance(dst_module, torch.nn.Module)
149
+ src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)}
150
+ for name, tensor in named_params_and_buffers(dst_module):
151
+ assert (name in src_tensors) or (not require_all)
152
+ if name in src_tensors:
153
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
154
+
155
+ #----------------------------------------------------------------------------
156
+ # Context manager for easily enabling/disabling DistributedDataParallel
157
+ # synchronization.
158
+
159
+ @contextlib.contextmanager
160
+ def ddp_sync(module, sync):
161
+ assert isinstance(module, torch.nn.Module)
162
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
163
+ yield
164
+ else:
165
+ with module.no_sync():
166
+ yield
167
+
168
+ #----------------------------------------------------------------------------
169
+ # Check DistributedDataParallel consistency across processes.
170
+
171
+ def check_ddp_consistency(module, ignore_regex=None):
172
+ assert isinstance(module, torch.nn.Module)
173
+ for name, tensor in named_params_and_buffers(module):
174
+ fullname = type(module).__name__ + '.' + name
175
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
176
+ continue
177
+ tensor = tensor.detach()
178
+ other = tensor.clone()
179
+ torch.distributed.broadcast(tensor=other, src=0)
180
+ assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
181
+
182
+ #----------------------------------------------------------------------------
183
+ # Print summary table of module hierarchy.
184
+
185
+ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
186
+ assert isinstance(module, torch.nn.Module)
187
+ assert not isinstance(module, torch.jit.ScriptModule)
188
+ assert isinstance(inputs, (tuple, list))
189
+
190
+ # Register hooks.
191
+ entries = []
192
+ nesting = [0]
193
+ def pre_hook(_mod, _inputs):
194
+ nesting[0] += 1
195
+ def post_hook(mod, _inputs, outputs):
196
+ nesting[0] -= 1
197
+ if nesting[0] <= max_nesting:
198
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
199
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
200
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
201
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
202
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
203
+
204
+ # Run module.
205
+ outputs = module(*inputs)
206
+ for hook in hooks:
207
+ hook.remove()
208
+
209
+ # Identify unique outputs, parameters, and buffers.
210
+ tensors_seen = set()
211
+ for e in entries:
212
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
213
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
214
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
215
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
216
+
217
+ # Filter out redundant entries.
218
+ if skip_redundant:
219
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
220
+
221
+ # Construct table.
222
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
223
+ rows += [['---'] * len(rows[0])]
224
+ param_total = 0
225
+ buffer_total = 0
226
+ submodule_names = {mod: name for name, mod in module.named_modules()}
227
+ for e in entries:
228
+ name = '<top-level>' if e.mod is module else submodule_names[e.mod]
229
+ param_size = sum(t.numel() for t in e.unique_params)
230
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
231
+ output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs]
232
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
233
+ rows += [[
234
+ name + (':0' if len(e.outputs) >= 2 else ''),
235
+ str(param_size) if param_size else '-',
236
+ str(buffer_size) if buffer_size else '-',
237
+ (output_shapes + ['-'])[0],
238
+ (output_dtypes + ['-'])[0],
239
+ ]]
240
+ for idx in range(1, len(e.outputs)):
241
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
242
+ param_total += param_size
243
+ buffer_total += buffer_size
244
+ rows += [['---'] * len(rows[0])]
245
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
246
+
247
+ # Print table.
248
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
249
+ print()
250
+ for row in rows:
251
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
252
+ print()
253
+ return outputs
254
+
255
+ #----------------------------------------------------------------------------
256
+
257
+ import os
258
+
259
+ def get_ckpt_path(run_dir):
260
+ return os.path.join(run_dir, f'network-snapshot.pkl')
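InfiniteSampler is designed to be passed as the `sampler` of a standard DataLoader so that training can iterate without epoch boundaries. A minimal sketch with a dummy dataset (run from inside `diffusion-insgen/` so that `torch_utils` and `dnnlib` resolve; the dataset and batch size are placeholders):

```python
import torch
from torch_utils.misc import InfiniteSampler

dataset = torch.utils.data.TensorDataset(torch.arange(10))
sampler = InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True, seed=0)
loader = iter(torch.utils.data.DataLoader(dataset, batch_size=4, sampler=sampler))

for _ in range(3):          # the sampler never raises StopIteration
    (batch,) = next(loader)
    print(batch.tolist())
```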
diffusion-insgen/torch_utils/ops/__init__.py ADDED
@@ -0,0 +1,2 @@
+ 
+ # empty
diffusion-insgen/torch_utils/ops/bias_act.cpp ADDED
@@ -0,0 +1,99 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "bias_act.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17
+ {
18
+ if (x.dim() != y.dim())
19
+ return false;
20
+ for (int64_t i = 0; i < x.dim(); i++)
21
+ {
22
+ if (x.size(i) != y.size(i))
23
+ return false;
24
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25
+ return false;
26
+ }
27
+ return true;
28
+ }
29
+
30
+ //------------------------------------------------------------------------
31
+
32
+ static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33
+ {
34
+ // Validate arguments.
35
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
45
+
46
+ // Validate layout.
47
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52
+
53
+ // Create output tensor.
54
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55
+ torch::Tensor y = torch::empty_like(x);
56
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57
+
58
+ // Initialize CUDA kernel parameters.
59
+ bias_act_kernel_params p;
60
+ p.x = x.data_ptr();
61
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
62
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65
+ p.y = y.data_ptr();
66
+ p.grad = grad;
67
+ p.act = act;
68
+ p.alpha = alpha;
69
+ p.gain = gain;
70
+ p.clamp = clamp;
71
+ p.sizeX = (int)x.numel();
72
+ p.sizeB = (int)b.numel();
73
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74
+
75
+ // Choose CUDA kernel.
76
+ void* kernel;
77
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78
+ {
79
+ kernel = choose_bias_act_kernel<scalar_t>(p);
80
+ });
81
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82
+
83
+ // Launch CUDA kernel.
84
+ p.loopX = 4;
85
+ int blockSize = 4 * 32;
86
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87
+ void* args[] = {&p};
88
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89
+ return y;
90
+ }
91
+
92
+ //------------------------------------------------------------------------
93
+
94
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95
+ {
96
+ m.def("bias_act", &bias_act);
97
+ }
98
+
99
+ //------------------------------------------------------------------------
diffusion-insgen/torch_utils/ops/bias_act.cu ADDED
@@ -0,0 +1,173 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "bias_act.h"
11
+
12
+ //------------------------------------------------------------------------
13
+ // Helpers.
14
+
15
+ template <class T> struct InternalType;
16
+ template <> struct InternalType<double> { typedef double scalar_t; };
17
+ template <> struct InternalType<float> { typedef float scalar_t; };
18
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
+
20
+ //------------------------------------------------------------------------
21
+ // CUDA kernel.
22
+
23
+ template <class T, int A>
24
+ __global__ void bias_act_kernel(bias_act_kernel_params p)
25
+ {
26
+ typedef typename InternalType<T>::scalar_t scalar_t;
27
+ int G = p.grad;
28
+ scalar_t alpha = (scalar_t)p.alpha;
29
+ scalar_t gain = (scalar_t)p.gain;
30
+ scalar_t clamp = (scalar_t)p.clamp;
31
+ scalar_t one = (scalar_t)1;
32
+ scalar_t two = (scalar_t)2;
33
+ scalar_t expRange = (scalar_t)80;
34
+ scalar_t halfExpRange = (scalar_t)40;
35
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37
+
38
+ // Loop over elements.
39
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41
+ {
42
+ // Load.
43
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
44
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
49
+ scalar_t y = 0;
50
+
51
+ // Apply bias.
52
+ ((G == 0) ? x : xref) += b;
53
+
54
+ // linear
55
+ if (A == 1)
56
+ {
57
+ if (G == 0) y = x;
58
+ if (G == 1) y = x;
59
+ }
60
+
61
+ // relu
62
+ if (A == 2)
63
+ {
64
+ if (G == 0) y = (x > 0) ? x : 0;
65
+ if (G == 1) y = (yy > 0) ? x : 0;
66
+ }
67
+
68
+ // lrelu
69
+ if (A == 3)
70
+ {
71
+ if (G == 0) y = (x > 0) ? x : x * alpha;
72
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
73
+ }
74
+
75
+ // tanh
76
+ if (A == 4)
77
+ {
78
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79
+ if (G == 1) y = x * (one - yy * yy);
80
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81
+ }
82
+
83
+ // sigmoid
84
+ if (A == 5)
85
+ {
86
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87
+ if (G == 1) y = x * yy * (one - yy);
88
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89
+ }
90
+
91
+ // elu
92
+ if (A == 6)
93
+ {
94
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97
+ }
98
+
99
+ // selu
100
+ if (A == 7)
101
+ {
102
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105
+ }
106
+
107
+ // softplus
108
+ if (A == 8)
109
+ {
110
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111
+ if (G == 1) y = x * (one - exp(-yy));
112
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113
+ }
114
+
115
+ // swish
116
+ if (A == 9)
117
+ {
118
+ if (G == 0)
119
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120
+ else
121
+ {
122
+ scalar_t c = exp(xref);
123
+ scalar_t d = c + one;
124
+ if (G == 1)
125
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126
+ else
127
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129
+ }
130
+ }
131
+
132
+ // Apply gain.
133
+ y *= gain * dy;
134
+
135
+ // Clamp.
136
+ if (clamp >= 0)
137
+ {
138
+ if (G == 0)
139
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140
+ else
141
+ y = (yref > -clamp & yref < clamp) ? y : 0;
142
+ }
143
+
144
+ // Store.
145
+ ((T*)p.y)[xi] = (T)y;
146
+ }
147
+ }
148
+
149
+ //------------------------------------------------------------------------
150
+ // CUDA kernel selection.
151
+
152
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153
+ {
154
+ if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155
+ if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156
+ if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157
+ if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158
+ if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159
+ if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160
+ if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161
+ if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162
+ if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163
+ return NULL;
164
+ }
165
+
166
+ //------------------------------------------------------------------------
167
+ // Template specializations.
168
+
169
+ template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
170
+ template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
171
+ template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172
+
173
+ //------------------------------------------------------------------------
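Each activation's `G == 1` branch evaluates the first-order gradient from the saved reference value; for lrelu (`A == 3`) the gradient is `dy` where the output was positive and `alpha * dy` elsewhere. A quick autograd check of that relation in plain PyTorch, independent of the CUDA kernel:

```python
import torch

alpha = 0.2
x = torch.randn(1000, requires_grad=True)
y = torch.nn.functional.leaky_relu(x, alpha)
y.backward(torch.ones_like(y))

# Matches the lrelu G == 1 branch: gradient 1 where y > 0, alpha elsewhere.
expected = torch.where(y.detach() > 0, torch.ones_like(x), torch.full_like(x, alpha))
assert torch.allclose(x.grad, expected)
```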
diffusion-insgen/torch_utils/ops/bias_act.h ADDED
@@ -0,0 +1,38 @@
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ //
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
+ // and proprietary rights in and to this software, related documentation
+ // and any modifications thereto. Any use, reproduction, disclosure or
+ // distribution of this software and related documentation without an express
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
+ 
+ //------------------------------------------------------------------------
+ // CUDA kernel parameters.
+ 
+ struct bias_act_kernel_params
+ {
+     const void* x;      // [sizeX]
+     const void* b;      // [sizeB] or NULL
+     const void* xref;   // [sizeX] or NULL
+     const void* yref;   // [sizeX] or NULL
+     const void* dy;     // [sizeX] or NULL
+     void*       y;      // [sizeX]
+ 
+     int         grad;
+     int         act;
+     float       alpha;
+     float       gain;
+     float       clamp;
+ 
+     int         sizeX;
+     int         sizeB;
+     int         stepB;
+     int         loopX;
+ };
+ 
+ //------------------------------------------------------------------------
+ // CUDA kernel selection.
+ 
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
+ 
+ //------------------------------------------------------------------------
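These kernels are exposed to Python through `bias_act()` in the wrapper module added next (bias_act.py). A hedged usage sketch; on CPU, or whenever the CUDA build is unavailable, the call transparently falls back to the reference implementation (run from inside `diffusion-insgen/` so the imports resolve):

```python
import torch
from torch_utils.ops import bias_act

x = torch.randn(4, 8, 16, 16)   # activations
b = torch.zeros(8)              # per-channel bias along dim=1
y_ref = bias_act.bias_act(x, b, dim=1, act='lrelu', gain=1.0, impl='ref')
y_def = bias_act.bias_act(x, b, dim=1, act='lrelu', gain=1.0, impl='cuda')  # falls back on CPU
assert torch.allclose(y_ref, y_def)
```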
diffusion-insgen/torch_utils/ops/bias_act.py ADDED
@@ -0,0 +1,205 @@
1
+
2
+ """Custom PyTorch ops for efficient bias and activation."""
3
+
4
+ import os
5
+ import warnings
6
+ import numpy as np
7
+ import torch
8
+ import dnnlib
9
+ import traceback
10
+
11
+ from .. import custom_ops
12
+ from .. import misc
13
+
14
+ #----------------------------------------------------------------------------
15
+
16
+ activation_funcs = {
17
+ 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
18
+ 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
19
+ 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
20
+ 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
21
+ 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
22
+ 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
23
+ 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
24
+ 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
25
+ 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
26
+ }
27
+
28
+ #----------------------------------------------------------------------------
29
+
30
+ _inited = False
31
+ _plugin = None
32
+ _null_tensor = torch.empty([0])
33
+
34
+ def _init():
35
+ global _inited, _plugin
36
+ if not _inited:
37
+ _inited = True
38
+ sources = ['bias_act.cpp', 'bias_act.cu']
39
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
40
+ try:
41
+ _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
42
+ except:
43
+ warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
44
+ return _plugin is not None
45
+
46
+ #----------------------------------------------------------------------------
47
+
48
+ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
49
+ r"""Fused bias and activation function.
50
+
51
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
52
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
53
+ the fused op is considerably more efficient than performing the same calculation
54
+ using standard PyTorch ops. It supports first and second order gradients,
55
+ but not third order gradients.
56
+
57
+ Args:
58
+ x: Input activation tensor. Can be of any shape.
59
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
60
+ as `x`. The shape must be known, and it must match the dimension of `x`
61
+ corresponding to `dim`.
62
+ dim: The dimension in `x` corresponding to the elements of `b`.
63
+ The value of `dim` is ignored if `b` is not specified.
64
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
65
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
66
+ See `activation_funcs` for a full list. `None` is not allowed.
67
+ alpha: Shape parameter for the activation function, or `None` to use the default.
68
+ gain: Scaling factor for the output tensor, or `None` to use default.
69
+ See `activation_funcs` for the default scaling of each activation function.
70
+ If unsure, consider specifying 1.
71
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
72
+ the clamping (default).
73
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
74
+
75
+ Returns:
76
+ Tensor of the same shape and datatype as `x`.
77
+ """
78
+ assert isinstance(x, torch.Tensor)
79
+ assert impl in ['ref', 'cuda']
80
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
81
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
82
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
83
+
84
+ #----------------------------------------------------------------------------
85
+
86
+ @misc.profiled_function
87
+ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
88
+ """Slow reference implementation of `bias_act()` using standard PyTorch ops.
89
+ """
90
+ assert isinstance(x, torch.Tensor)
91
+ assert clamp is None or clamp >= 0
92
+ spec = activation_funcs[act]
93
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
94
+ gain = float(gain if gain is not None else spec.def_gain)
95
+ clamp = float(clamp if clamp is not None else -1)
96
+
97
+ # Add bias.
98
+ if b is not None:
99
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
100
+ assert 0 <= dim < x.ndim
101
+ assert b.shape[0] == x.shape[dim]
102
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
103
+
104
+ # Evaluate activation function.
105
+ alpha = float(alpha)
106
+ x = spec.func(x, alpha=alpha)
107
+
108
+ # Scale by gain.
109
+ gain = float(gain)
110
+ if gain != 1:
111
+ x = x * gain
112
+
113
+ # Clamp.
114
+ if clamp >= 0:
115
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
116
+ return x
117
+
118
+ #----------------------------------------------------------------------------
119
+
120
+ _bias_act_cuda_cache = dict()
121
+
122
+ def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
123
+ """Fast CUDA implementation of `bias_act()` using custom ops.
124
+ """
125
+ # Parse arguments.
126
+ assert clamp is None or clamp >= 0
127
+ spec = activation_funcs[act]
128
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
129
+ gain = float(gain if gain is not None else spec.def_gain)
130
+ clamp = float(clamp if clamp is not None else -1)
131
+
132
+ # Lookup from cache.
133
+ key = (dim, act, alpha, gain, clamp)
134
+ if key in _bias_act_cuda_cache:
135
+ return _bias_act_cuda_cache[key]
136
+
137
+ # Forward op.
138
+ class BiasActCuda(torch.autograd.Function):
139
+ @staticmethod
140
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
141
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
142
+ x = x.contiguous(memory_format=ctx.memory_format)
143
+ b = b.contiguous() if b is not None else _null_tensor
144
+ y = x
145
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
146
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
147
+ ctx.save_for_backward(
148
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
149
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
150
+ y if 'y' in spec.ref else _null_tensor)
151
+ return y
152
+
153
+ @staticmethod
154
+ def backward(ctx, dy): # pylint: disable=arguments-differ
155
+ dy = dy.contiguous(memory_format=ctx.memory_format)
156
+ x, b, y = ctx.saved_tensors
157
+ dx = None
158
+ db = None
159
+
160
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
161
+ dx = dy
162
+ if act != 'linear' or gain != 1 or clamp >= 0:
163
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
164
+
165
+ if ctx.needs_input_grad[1]:
166
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
167
+
168
+ return dx, db
169
+
170
+ # Backward op.
171
+ class BiasActCudaGrad(torch.autograd.Function):
172
+ @staticmethod
173
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
174
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
175
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
176
+ ctx.save_for_backward(
177
+ dy if spec.has_2nd_grad else _null_tensor,
178
+ x, b, y)
179
+ return dx
180
+
181
+ @staticmethod
182
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
183
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
184
+ dy, x, b, y = ctx.saved_tensors
185
+ d_dy = None
186
+ d_x = None
187
+ d_b = None
188
+ d_y = None
189
+
190
+ if ctx.needs_input_grad[0]:
191
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
192
+
193
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
194
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
195
+
196
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
197
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
198
+
199
+ return d_dy, d_x, d_b, d_y
200
+
201
+ # Add to cache.
202
+ _bias_act_cuda_cache[key] = BiasActCuda
203
+ return BiasActCuda
204
+
205
+ #----------------------------------------------------------------------------
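As a quick orientation for readers skimming the diff, here is a minimal usage sketch for the `bias_act()` op declared above. The tensor sizes, the `lrelu` choice, and the `clamp` value are illustrative only, and `impl='ref'` is used so the snippet also runs without the compiled CUDA plugin (assuming the repository's `torch_utils` package is importable).

```python
# Minimal bias_act() sketch with illustrative shapes and parameters.
import torch
from torch_utils.ops import bias_act

x = torch.randn(4, 512, 8, 8)   # activations: [batch, channels, height, width]
b = torch.zeros(512)            # one bias per channel (dim=1)

# Fused bias + leaky ReLU + clamp; impl='ref' works on CPU, impl='cuda'
# requires a CUDA tensor and the compiled plugin.
y = bias_act.bias_act(x, b, dim=1, act='lrelu', gain=1.0, clamp=256, impl='ref')
print(y.shape)                  # same shape and dtype as x
```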
diffusion-insgen/torch_utils/ops/conv2d_gradfix.py ADDED
@@ -0,0 +1,172 @@
1
+
2
+ """Custom replacement for `torch.nn.functional.conv2d` that supports
3
+ arbitrarily high order gradients with zero performance penalty."""
4
+
5
+ import warnings
6
+ import contextlib
7
+ import torch
8
+ from distutils.version import LooseVersion
9
+
10
+ # pylint: disable=redefined-builtin
11
+ # pylint: disable=arguments-differ
12
+ # pylint: disable=protected-access
13
+
14
+ #----------------------------------------------------------------------------
15
+
16
+ enabled = False # Enable the custom op by setting this to true.
17
+ weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
18
+ old_version = LooseVersion(torch.__version__) < LooseVersion('1.11.0')
19
+
20
+ @contextlib.contextmanager
21
+ def no_weight_gradients():
22
+ global weight_gradients_disabled
23
+ old = weight_gradients_disabled
24
+ weight_gradients_disabled = True
25
+ yield
26
+ weight_gradients_disabled = old
27
+
28
+ #----------------------------------------------------------------------------
29
+
30
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
31
+ if _should_use_custom_op(input):
32
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
33
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
34
+
35
+ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
36
+ if _should_use_custom_op(input):
37
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
38
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
39
+
40
+ #----------------------------------------------------------------------------
41
+
42
+ def _should_use_custom_op(input):
43
+ assert isinstance(input, torch.Tensor)
44
+ if (not enabled) or (not torch.backends.cudnn.enabled):
45
+ return False
46
+ if input.device.type != 'cuda':
47
+ return False
48
+ if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'):
49
+ return True
50
+ warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
51
+ return False
52
+
53
+ def _tuple_of_ints(xs, ndim):
54
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
55
+ assert len(xs) == ndim
56
+ assert all(isinstance(x, int) for x in xs)
57
+ return xs
58
+
59
+ #----------------------------------------------------------------------------
60
+
61
+ _conv2d_gradfix_cache = dict()
62
+
63
+ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
64
+ # Parse arguments.
65
+ ndim = 2
66
+ weight_shape = tuple(weight_shape)
67
+ stride = _tuple_of_ints(stride, ndim)
68
+ padding = _tuple_of_ints(padding, ndim)
69
+ output_padding = _tuple_of_ints(output_padding, ndim)
70
+ dilation = _tuple_of_ints(dilation, ndim)
71
+
72
+ # Lookup from cache.
73
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
74
+ if key in _conv2d_gradfix_cache:
75
+ return _conv2d_gradfix_cache[key]
76
+
77
+ # Validate arguments.
78
+ assert groups >= 1
79
+ assert len(weight_shape) == ndim + 2
80
+ assert all(stride[i] >= 1 for i in range(ndim))
81
+ assert all(padding[i] >= 0 for i in range(ndim))
82
+ assert all(dilation[i] >= 0 for i in range(ndim))
83
+ if not transpose:
84
+ assert all(output_padding[i] == 0 for i in range(ndim))
85
+ else: # transpose
86
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
87
+
88
+ # Helpers.
89
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
90
+ def calc_output_padding(input_shape, output_shape):
91
+ if transpose:
92
+ return [0, 0]
93
+ return [
94
+ input_shape[i + 2]
95
+ - (output_shape[i + 2] - 1) * stride[i]
96
+ - (1 - 2 * padding[i])
97
+ - dilation[i] * (weight_shape[i + 2] - 1)
98
+ for i in range(ndim)
99
+ ]
100
+
101
+ # Forward & backward.
102
+ class Conv2d(torch.autograd.Function):
103
+ @staticmethod
104
+ def forward(ctx, input, weight, bias):
105
+ assert weight.shape == weight_shape
106
+ if not transpose:
107
+ output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
108
+ else: # transpose
109
+ output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
110
+ ctx.save_for_backward(input, weight, bias)
111
+ return output
112
+
113
+ @staticmethod
114
+ def backward(ctx, grad_output):
115
+ input, weight, bias = ctx.saved_tensors
116
+ grad_input = None
117
+ grad_weight = None
118
+ grad_bias = None
119
+
120
+ if ctx.needs_input_grad[0]:
121
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
122
+ grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
123
+ assert grad_input.shape == input.shape
124
+
125
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
126
+ grad_weight = Conv2dGradWeight.apply(grad_output, input, bias)
127
+ assert grad_weight.shape == weight_shape
128
+
129
+ if ctx.needs_input_grad[2]:
130
+ grad_bias = grad_output.sum([0, 2, 3])
131
+
132
+ return grad_input, grad_weight, grad_bias
133
+
134
+ # Gradient with respect to the weights.
135
+ class Conv2dGradWeight(torch.autograd.Function):
136
+ @staticmethod
137
+ def forward(ctx, grad_output, input, bias):
138
+ if old_version:
139
+ op = torch._C._jit_get_operation(
140
+ 'aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
141
+ flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
142
+ torch.backends.cudnn.allow_tf32]
143
+ grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
144
+ else:
145
+ bias_shape = bias.shape if (bias is not None) else None
146
+ empty_weight = torch.empty(weight_shape, dtype=input.dtype, layout=input.layout, device=input.device)
147
+ grad_weight = torch.ops.aten.convolution_backward(grad_output, input, empty_weight, bias_sizes=bias_shape, stride=stride, padding=padding, dilation=dilation, transposed=transpose, output_padding=output_padding, groups=groups, output_mask=[0,1,0])[1]
148
+ assert grad_weight.shape == weight_shape
149
+ ctx.save_for_backward(grad_output, input)
150
+ return grad_weight
151
+
152
+ @staticmethod
153
+ def backward(ctx, grad2_grad_weight):
154
+ grad_output, input = ctx.saved_tensors
155
+ grad2_grad_output = None
156
+ grad2_input = None
157
+
158
+ if ctx.needs_input_grad[0]:
159
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
160
+ assert grad2_grad_output.shape == grad_output.shape
161
+
162
+ if ctx.needs_input_grad[1]:
163
+ p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
164
+ grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
165
+ assert grad2_input.shape == input.shape
166
+
167
+ return grad2_grad_output, grad2_input
168
+
169
+ _conv2d_gradfix_cache[key] = Conv2d
170
+ return Conv2d
171
+
172
+ #----------------------------------------------------------------------------
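A short, hedged sketch of how `conv2d_gradfix` is meant to be used: the custom op is opt-in via the module-level `enabled` flag, and `no_weight_gradients()` temporarily suppresses the weight gradient (which only takes effect while the custom CUDA path is active). The shapes below are examples, not values from the training configs.

```python
# Illustrative conv2d_gradfix usage; falls back to F.conv2d on CPU.
import torch
from torch_utils.ops import conv2d_gradfix

conv2d_gradfix.enabled = True                       # opt in to the custom op
x = torch.randn(2, 3, 32, 32, requires_grad=True)
w = torch.randn(8, 3, 3, 3, requires_grad=True)

y = conv2d_gradfix.conv2d(x, w, padding=1)

# Compute only the input gradient, e.g. for a regularization term;
# weight gradients are skipped only when the custom CUDA op is in use.
with conv2d_gradfix.no_weight_gradients():
    g = torch.autograd.grad(y.sum(), [x], create_graph=True)[0]
print(y.shape, g.shape)
```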
diffusion-insgen/torch_utils/ops/conv2d_resample.py ADDED
@@ -0,0 +1,149 @@
1
+
2
+ """2D convolution with optional up/downsampling."""
3
+
4
+ import torch
5
+
6
+ from .. import misc
7
+ from . import conv2d_gradfix
8
+ from . import upfirdn2d
9
+ from .upfirdn2d import _parse_padding
10
+ from .upfirdn2d import _get_filter_size
11
+
12
+ #----------------------------------------------------------------------------
13
+
14
+ def _get_weight_shape(w):
15
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
16
+ shape = [int(sz) for sz in w.shape]
17
+ misc.assert_shape(w, shape)
18
+ return shape
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
23
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
24
+ """
25
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
26
+
27
+ # Flip weight if requested.
28
+ if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
29
+ w = w.flip([2, 3])
30
+
31
+ # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
32
+ # 1x1 kernel + memory_format=channels_last + less than 64 channels.
33
+ if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
34
+ if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
35
+ if out_channels <= 4 and groups == 1:
36
+ in_shape = x.shape
37
+ x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
38
+ x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
39
+ else:
40
+ x = x.to(memory_format=torch.contiguous_format)
41
+ w = w.to(memory_format=torch.contiguous_format)
42
+ x = conv2d_gradfix.conv2d(x, w, groups=groups)
43
+ return x.to(memory_format=torch.channels_last)
44
+
45
+ # Otherwise => execute using conv2d_gradfix.
46
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
47
+ return op(x, w, stride=stride, padding=padding, groups=groups)
48
+
49
+ #----------------------------------------------------------------------------
50
+
51
+ @misc.profiled_function
52
+ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
53
+ r"""2D convolution with optional up/downsampling.
54
+
55
+ Padding is performed only once at the beginning, not between the operations.
56
+
57
+ Args:
58
+ x: Input tensor of shape
59
+ `[batch_size, in_channels, in_height, in_width]`.
60
+ w: Weight tensor of shape
61
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
62
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
63
+ calling upfirdn2d.setup_filter(). None = identity (default).
64
+ up: Integer upsampling factor (default: 1).
65
+ down: Integer downsampling factor (default: 1).
66
+ padding: Padding with respect to the upsampled image. Can be a single number
67
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
68
+ (default: 0).
69
+ groups: Split input channels into N groups (default: 1).
70
+ flip_weight: False = convolution, True = correlation (default: True).
71
+ flip_filter: False = convolution, True = correlation (default: False).
72
+
73
+ Returns:
74
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
75
+ """
76
+ # Validate arguments.
77
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
78
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
79
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
80
+ assert isinstance(up, int) and (up >= 1)
81
+ assert isinstance(down, int) and (down >= 1)
82
+ assert isinstance(groups, int) and (groups >= 1)
83
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
84
+ fw, fh = _get_filter_size(f)
85
+ px0, px1, py0, py1 = _parse_padding(padding)
86
+
87
+ # Adjust padding to account for up/downsampling.
88
+ if up > 1:
89
+ px0 += (fw + up - 1) // 2
90
+ px1 += (fw - up) // 2
91
+ py0 += (fh + up - 1) // 2
92
+ py1 += (fh - up) // 2
93
+ if down > 1:
94
+ px0 += (fw - down + 1) // 2
95
+ px1 += (fw - down) // 2
96
+ py0 += (fh - down + 1) // 2
97
+ py1 += (fh - down) // 2
98
+
99
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
100
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
101
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
102
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
103
+ return x
104
+
105
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
106
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
107
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
108
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
109
+ return x
110
+
111
+ # Fast path: downsampling only => use strided convolution.
112
+ if down > 1 and up == 1:
113
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
114
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
115
+ return x
116
+
117
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
118
+ if up > 1:
119
+ if groups == 1:
120
+ w = w.transpose(0, 1)
121
+ else:
122
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
123
+ w = w.transpose(1, 2)
124
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
125
+ px0 -= kw - 1
126
+ px1 -= kw - up
127
+ py0 -= kh - 1
128
+ py1 -= kh - up
129
+ pxt = max(min(-px0, -px1), 0)
130
+ pyt = max(min(-py0, -py1), 0)
131
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
132
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
133
+ if down > 1:
134
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
135
+ return x
136
+
137
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
138
+ if up == 1 and down == 1:
139
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
140
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
141
+
142
+ # Fallback: Generic reference implementation.
143
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
144
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
145
+ if down > 1:
146
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
147
+ return x
148
+
149
+ #----------------------------------------------------------------------------
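To make the padding and ordering logic above concrete, here is a small sketch of `conv2d_resample()` driving a 2x upsampling convolution. The `[1, 3, 3, 1]` filter taps and the layer shapes are example values (assuming the repository's `torch_utils` package is importable); the output resolution doubles because the padding adjustments above compensate for the filter footprint.

```python
# Illustrative 2x upsampling convolution via conv2d_resample().
import torch
from torch_utils.ops import upfirdn2d, conv2d_resample

f = upfirdn2d.setup_filter([1, 3, 3, 1])  # separable low-pass FIR filter
x = torch.randn(2, 16, 32, 32)            # [batch, in_channels, H, W]
w = torch.randn(8, 16, 3, 3)              # [out_channels, in_channels, kh, kw]

y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
print(y.shape)                            # -> torch.Size([2, 8, 64, 64])
```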
diffusion-insgen/torch_utils/ops/fma.py ADDED
@@ -0,0 +1,53 @@
1
+
2
+ """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
3
+
4
+ import torch
5
+
6
+ #----------------------------------------------------------------------------
7
+
8
+ def fma(a, b, c): # => a * b + c
9
+ return _FusedMultiplyAdd.apply(a, b, c)
10
+
11
+ #----------------------------------------------------------------------------
12
+
13
+ class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
14
+ @staticmethod
15
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
16
+ out = torch.addcmul(c, a, b)
17
+ ctx.save_for_backward(a, b)
18
+ ctx.c_shape = c.shape
19
+ return out
20
+
21
+ @staticmethod
22
+ def backward(ctx, dout): # pylint: disable=arguments-differ
23
+ a, b = ctx.saved_tensors
24
+ c_shape = ctx.c_shape
25
+ da = None
26
+ db = None
27
+ dc = None
28
+
29
+ if ctx.needs_input_grad[0]:
30
+ da = _unbroadcast(dout * b, a.shape)
31
+
32
+ if ctx.needs_input_grad[1]:
33
+ db = _unbroadcast(dout * a, b.shape)
34
+
35
+ if ctx.needs_input_grad[2]:
36
+ dc = _unbroadcast(dout, c_shape)
37
+
38
+ return da, db, dc
39
+
40
+ #----------------------------------------------------------------------------
41
+
42
+ def _unbroadcast(x, shape):
43
+ extra_dims = x.ndim - len(shape)
44
+ assert extra_dims >= 0
45
+ dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
46
+ if len(dim):
47
+ x = x.sum(dim=dim, keepdim=True)
48
+ if extra_dims:
49
+ x = x.reshape(-1, *x.shape[extra_dims+1:])
50
+ assert x.shape == shape
51
+ return x
52
+
53
+ #----------------------------------------------------------------------------
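A quick numerical check of the fused multiply-add above, mainly to illustrate what `_unbroadcast()` does in the backward pass: gradients are summed back down to the original (broadcast) shapes of `a`, `b`, and `c`. The shapes are arbitrary example values.

```python
# fma(a, b, c) == a * b + c, with gradients un-broadcast to the input shapes.
import torch
from torch_utils.ops import fma

a = torch.randn(4, 1, 8, requires_grad=True)   # broadcasts against b and c
b = torch.randn(1, 3, 8, requires_grad=True)
c = torch.randn(4, 3, 8, requires_grad=True)

out = fma.fma(a, b, c)
assert torch.allclose(out, a * b + c)

out.sum().backward()
print(a.grad.shape, b.grad.shape, c.grad.shape)  # gradients keep the input shapes
```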
diffusion-insgen/torch_utils/ops/grid_sample_gradfix.py ADDED
@@ -0,0 +1,77 @@
1
+
2
+ """Custom replacement for `torch.nn.functional.grid_sample` that
3
+ supports arbitrarily high order gradients between the input and output.
4
+ Only works on 2D images and assumes
5
+ `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
6
+
7
+ import warnings
8
+ import torch
9
+ from distutils.version import LooseVersion
10
+
11
+ # pylint: disable=redefined-builtin
12
+ # pylint: disable=arguments-differ
13
+ # pylint: disable=protected-access
14
+
15
+ #----------------------------------------------------------------------------
16
+
17
+ enabled = False # Enable the custom op by setting this to true.
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ def grid_sample(input, grid):
22
+ if _should_use_custom_op():
23
+ return _GridSample2dForward.apply(input, grid)
24
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
25
+
26
+ #----------------------------------------------------------------------------
27
+
28
+ def _should_use_custom_op():
29
+ if not enabled:
30
+ return False
31
+ if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'):
32
+ return True
33
+ warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
34
+ return False
35
+
36
+ #----------------------------------------------------------------------------
37
+
38
+ class _GridSample2dForward(torch.autograd.Function):
39
+ @staticmethod
40
+ def forward(ctx, input, grid):
41
+ assert input.ndim == 4
42
+ assert grid.ndim == 4
43
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
44
+ ctx.save_for_backward(input, grid)
45
+ return output
46
+
47
+ @staticmethod
48
+ def backward(ctx, grad_output):
49
+ input, grid = ctx.saved_tensors
50
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
51
+ return grad_input, grad_grid
52
+
53
+ #----------------------------------------------------------------------------
54
+
55
+ class _GridSample2dBackward(torch.autograd.Function):
56
+ @staticmethod
57
+ def forward(ctx, grad_output, input, grid):
58
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
59
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
60
+ ctx.save_for_backward(grid)
61
+ return grad_input, grad_grid
62
+
63
+ @staticmethod
64
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
65
+ _ = grad2_grad_grid # unused
66
+ grid, = ctx.saved_tensors
67
+ grad2_grad_output = None
68
+ grad2_input = None
69
+ grad2_grid = None
70
+
71
+ if ctx.needs_input_grad[0]:
72
+ grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
73
+
74
+ assert not ctx.needs_input_grad[2]
75
+ return grad2_grad_output, grad2_input, grad2_grid
76
+
77
+ #----------------------------------------------------------------------------
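For reference, a minimal sketch of the drop-in `grid_sample()` above, using an identity sampling grid so the output reproduces the input. The resolution and the use of `affine_grid` to build the grid are illustrative; the op always assumes bilinear sampling, zero padding, and `align_corners=False`, exactly as documented in the file header.

```python
# Identity-grid sanity check for grid_sample_gradfix.grid_sample().
import torch
from torch_utils.ops import grid_sample_gradfix

grid_sample_gradfix.enabled = True                 # opt in to the custom op
x = torch.randn(1, 3, 16, 16, requires_grad=True)

# Identity sampling grid in normalized [-1, 1] coordinates.
theta = torch.eye(2, 3).unsqueeze(0)
grid = torch.nn.functional.affine_grid(theta, size=(1, 3, 16, 16), align_corners=False)

y = grid_sample_gradfix.grid_sample(x, grid)
print(torch.allclose(y, x, atol=1e-5))             # identity grid reproduces the input
```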
diffusion-insgen/torch_utils/ops/upfirdn2d.cpp ADDED
@@ -0,0 +1,103 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "upfirdn2d.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17
+ {
18
+ // Validate arguments.
19
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20
+ TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21
+ TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23
+ TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
25
+ TORCH_CHECK(f.dim() == 2, "f must be rank 2");
26
+ TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
27
+ TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
28
+ TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
29
+
30
+ // Create output tensor.
31
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
32
+ int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
33
+ int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
34
+ TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
35
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
36
+ TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
37
+
38
+ // Initialize CUDA kernel parameters.
39
+ upfirdn2d_kernel_params p;
40
+ p.x = x.data_ptr();
41
+ p.f = f.data_ptr<float>();
42
+ p.y = y.data_ptr();
43
+ p.up = make_int2(upx, upy);
44
+ p.down = make_int2(downx, downy);
45
+ p.pad0 = make_int2(padx0, pady0);
46
+ p.flip = (flip) ? 1 : 0;
47
+ p.gain = gain;
48
+ p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
49
+ p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
50
+ p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
51
+ p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
52
+ p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
53
+ p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
54
+ p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
55
+ p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
56
+
57
+ // Choose CUDA kernel.
58
+ upfirdn2d_kernel_spec spec;
59
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
60
+ {
61
+ spec = choose_upfirdn2d_kernel<scalar_t>(p);
62
+ });
63
+
64
+ // Set looping options.
65
+ p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
66
+ p.loopMinor = spec.loopMinor;
67
+ p.loopX = spec.loopX;
68
+ p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
69
+ p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
70
+
71
+ // Compute grid size.
72
+ dim3 blockSize, gridSize;
73
+ if (spec.tileOutW < 0) // large
74
+ {
75
+ blockSize = dim3(4, 32, 1);
76
+ gridSize = dim3(
77
+ ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
78
+ (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
79
+ p.launchMajor);
80
+ }
81
+ else // small
82
+ {
83
+ blockSize = dim3(256, 1, 1);
84
+ gridSize = dim3(
85
+ ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
86
+ (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
87
+ p.launchMajor);
88
+ }
89
+
90
+ // Launch CUDA kernel.
91
+ void* args[] = {&p};
92
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
93
+ return y;
94
+ }
95
+
96
+ //------------------------------------------------------------------------
97
+
98
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
99
+ {
100
+ m.def("upfirdn2d", &upfirdn2d);
101
+ }
102
+
103
+ //------------------------------------------------------------------------
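The output size computed by the binding above follows the usual upfirdn relation. A small Python restatement of the same arithmetic (the concrete sizes, factors, and paddings below are just example values):

```python
# Output-size arithmetic used by the upfirdn2d binding, restated in Python.
def upfirdn2d_out_size(in_size, up, down, pad0, pad1, f_size):
    # upsample by zero-insertion -> pad -> valid convolution -> downsample
    return (in_size * up + pad0 + pad1 - f_size + down) // down

inW, inH = 32, 32
print(upfirdn2d_out_size(inW, up=2, down=1, pad0=2, pad1=1, f_size=4))  # -> 64
print(upfirdn2d_out_size(inH, up=1, down=2, pad0=1, pad1=1, f_size=4))  # -> 16
```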
diffusion-insgen/torch_utils/ops/upfirdn2d.cu ADDED
@@ -0,0 +1,350 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "upfirdn2d.h"
11
+
12
+ //------------------------------------------------------------------------
13
+ // Helpers.
14
+
15
+ template <class T> struct InternalType;
16
+ template <> struct InternalType<double> { typedef double scalar_t; };
17
+ template <> struct InternalType<float> { typedef float scalar_t; };
18
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
+
20
+ static __device__ __forceinline__ int floor_div(int a, int b)
21
+ {
22
+ int t = 1 - a / b;
23
+ return (a + t * b) / b - t;
24
+ }
25
+
26
+ //------------------------------------------------------------------------
27
+ // Generic CUDA implementation for large filters.
28
+
29
+ template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
30
+ {
31
+ typedef typename InternalType<T>::scalar_t scalar_t;
32
+
33
+ // Calculate thread index.
34
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
35
+ int outY = minorBase / p.launchMinor;
36
+ minorBase -= outY * p.launchMinor;
37
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
38
+ int majorBase = blockIdx.z * p.loopMajor;
39
+ if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
40
+ return;
41
+
42
+ // Setup Y receptive field.
43
+ int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
44
+ int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
45
+ int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
46
+ int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
47
+ if (p.flip)
48
+ filterY = p.filterSize.y - 1 - filterY;
49
+
50
+ // Loop over major, minor, and X.
51
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
52
+ for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
53
+ {
54
+ int nc = major * p.sizeMinor + minor;
55
+ int n = nc / p.inSize.z;
56
+ int c = nc - n * p.inSize.z;
57
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
58
+ {
59
+ // Setup X receptive field.
60
+ int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
61
+ int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
62
+ int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
63
+ int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
64
+ if (p.flip)
65
+ filterX = p.filterSize.x - 1 - filterX;
66
+
67
+ // Initialize pointers.
68
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
69
+ const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
70
+ int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
71
+ int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
72
+
73
+ // Inner loop.
74
+ scalar_t v = 0;
75
+ for (int y = 0; y < h; y++)
76
+ {
77
+ for (int x = 0; x < w; x++)
78
+ {
79
+ v += (scalar_t)(*xp) * (scalar_t)(*fp);
80
+ xp += p.inStride.x;
81
+ fp += filterStepX;
82
+ }
83
+ xp += p.inStride.y - w * p.inStride.x;
84
+ fp += filterStepY - w * filterStepX;
85
+ }
86
+
87
+ // Store result.
88
+ v *= p.gain;
89
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
90
+ }
91
+ }
92
+ }
93
+
94
+ //------------------------------------------------------------------------
95
+ // Specialized CUDA implementation for small filters.
96
+
97
+ template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
98
+ static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
99
+ {
100
+ typedef typename InternalType<T>::scalar_t scalar_t;
101
+ const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
102
+ const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
103
+ __shared__ volatile scalar_t sf[filterH][filterW];
104
+ __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
105
+
106
+ // Calculate tile index.
107
+ int minorBase = blockIdx.x;
108
+ int tileOutY = minorBase / p.launchMinor;
109
+ minorBase -= tileOutY * p.launchMinor;
110
+ minorBase *= loopMinor;
111
+ tileOutY *= tileOutH;
112
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
113
+ int majorBase = blockIdx.z * p.loopMajor;
114
+ if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
115
+ return;
116
+
117
+ // Load filter (flipped).
118
+ for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
119
+ {
120
+ int fy = tapIdx / filterW;
121
+ int fx = tapIdx - fy * filterW;
122
+ scalar_t v = 0;
123
+ if (fx < p.filterSize.x & fy < p.filterSize.y)
124
+ {
125
+ int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
126
+ int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
127
+ v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
128
+ }
129
+ sf[fy][fx] = v;
130
+ }
131
+
132
+ // Loop over major and X.
133
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
134
+ {
135
+ int baseNC = major * p.sizeMinor + minorBase;
136
+ int n = baseNC / p.inSize.z;
137
+ int baseC = baseNC - n * p.inSize.z;
138
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
139
+ {
140
+ // Load input pixels.
141
+ int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
142
+ int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
143
+ int tileInX = floor_div(tileMidX, upx);
144
+ int tileInY = floor_div(tileMidY, upy);
145
+ __syncthreads();
146
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
147
+ {
148
+ int relC = inIdx;
149
+ int relInX = relC / loopMinor;
150
+ int relInY = relInX / tileInW;
151
+ relC -= relInX * loopMinor;
152
+ relInX -= relInY * tileInW;
153
+ int c = baseC + relC;
154
+ int inX = tileInX + relInX;
155
+ int inY = tileInY + relInY;
156
+ scalar_t v = 0;
157
+ if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
158
+ v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
159
+ sx[relInY][relInX][relC] = v;
160
+ }
161
+
162
+ // Loop over output pixels.
163
+ __syncthreads();
164
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
165
+ {
166
+ int relC = outIdx;
167
+ int relOutX = relC / loopMinor;
168
+ int relOutY = relOutX / tileOutW;
169
+ relC -= relOutX * loopMinor;
170
+ relOutX -= relOutY * tileOutW;
171
+ int c = baseC + relC;
172
+ int outX = tileOutX + relOutX;
173
+ int outY = tileOutY + relOutY;
174
+
175
+ // Setup receptive field.
176
+ int midX = tileMidX + relOutX * downx;
177
+ int midY = tileMidY + relOutY * downy;
178
+ int inX = floor_div(midX, upx);
179
+ int inY = floor_div(midY, upy);
180
+ int relInX = inX - tileInX;
181
+ int relInY = inY - tileInY;
182
+ int filterX = (inX + 1) * upx - midX - 1; // flipped
183
+ int filterY = (inY + 1) * upy - midY - 1; // flipped
184
+
185
+ // Inner loop.
186
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
187
+ {
188
+ scalar_t v = 0;
189
+ #pragma unroll
190
+ for (int y = 0; y < filterH / upy; y++)
191
+ #pragma unroll
192
+ for (int x = 0; x < filterW / upx; x++)
193
+ v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
194
+ v *= p.gain;
195
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
196
+ }
197
+ }
198
+ }
199
+ }
200
+ }
201
+
202
+ //------------------------------------------------------------------------
203
+ // CUDA kernel selection.
204
+
205
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
206
+ {
207
+ int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
208
+
209
+ upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
210
+ if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
211
+
212
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
213
+ {
214
+ if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
215
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
216
+ if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
217
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
218
+ if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
219
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
220
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
221
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
222
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
223
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
224
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
225
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
226
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
227
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
228
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
229
+ }
230
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
231
+ {
232
+ if (fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
233
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
234
+ if (fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
235
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
236
+ if (fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
237
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
238
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
239
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
240
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
241
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
242
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
243
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
244
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
245
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
246
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
247
+ }
248
+ if (s != 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
249
+ {
250
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
251
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
252
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
253
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
254
+ }
255
+ if (s == 1 && p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
256
+ {
257
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
258
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
259
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
260
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
261
+ }
262
+ if (s != 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // contiguous
263
+ {
264
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
265
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,8,1>, 128,8,1, 1};
266
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
267
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,8,1>, 128,8,1, 1};
268
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
269
+ }
270
+ if (s == 1 && p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) // channels_last
271
+ {
272
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
273
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 20,1, 128,1,16>, 128,1,16, 1};
274
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
275
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 12,1, 128,1,16>, 128,1,16, 1};
276
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
277
+ }
278
+ if (s != 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // contiguous
279
+ {
280
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
281
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 32,32,1>, 32,32,1, 1};
282
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
283
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 32,32,1>, 32,32,1, 1};
284
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
285
+ }
286
+ if (s == 1 && p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) // channels_last
287
+ {
288
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
289
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,20, 1,128,16>, 1,128,16, 1};
290
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
291
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,12, 1,128,16>, 1,128,16, 1};
292
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
293
+ }
294
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // contiguous
295
+ {
296
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
297
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
298
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
299
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
300
+ }
301
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) // channels_last
302
+ {
303
+ if (fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
304
+ if (fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
305
+ if (fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
306
+ if (fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
307
+ }
308
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // contiguous
309
+ {
310
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
311
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,8,1>, 64,8,1, 1};
312
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
313
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,8,1>, 64,8,1, 1};
314
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
315
+ }
316
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) // channels_last
317
+ {
318
+ if (fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
319
+ if (fx <= 20 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 20,1, 64,1,8>, 64,1,8, 1};
320
+ if (fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
321
+ if (fx <= 12 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 12,1, 64,1,8>, 64,1,8, 1};
322
+ if (fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
323
+ }
324
+ if (s != 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // contiguous
325
+ {
326
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
327
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 32,16,1>, 32,16,1, 1};
328
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
329
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 32,16,1>, 32,16,1, 1};
330
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
331
+ }
332
+ if (s == 1 && p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) // channels_last
333
+ {
334
+ if (fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
335
+ if (fx <= 1 && fy <= 20) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,20, 1,64,8>, 1,64,8, 1};
336
+ if (fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
337
+ if (fx <= 1 && fy <= 12) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,12, 1,64,8>, 1,64,8, 1};
338
+ if (fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
339
+ }
340
+ return spec;
341
+ }
342
+
343
+ //------------------------------------------------------------------------
344
+ // Template specializations.
345
+
346
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
347
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
348
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
349
+
350
+ //------------------------------------------------------------------------
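One detail worth spelling out from the kernels above: `floor_div()` exists because C integer division truncates toward zero, while the receptive-field setup needs flooring division for possibly negative offsets. Below is a Python translation (with C-style truncation emulated explicitly) checked against Python's native flooring `//`; it is only a sketch for illustration.

```python
# floor_div() from upfirdn2d.cu, translated; for b > 0 it equals Python's a // b.
def c_div(a, b):
    # C integer division truncates toward zero.
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b > 0) else -q

def floor_div(a, b):
    t = 1 - c_div(a, b)             # shift a into positive range...
    return c_div(a + t * b, b) - t  # ...divide, then shift the quotient back

for a in (-7, -1, 0, 3, 8):
    assert floor_div(a, 3) == a // 3
print("ok")
```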
diffusion-insgen/torch_utils/ops/upfirdn2d.h ADDED
@@ -0,0 +1,59 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <cuda_runtime.h>
10
+
11
+ //------------------------------------------------------------------------
12
+ // CUDA kernel parameters.
13
+
14
+ struct upfirdn2d_kernel_params
15
+ {
16
+ const void* x;
17
+ const float* f;
18
+ void* y;
19
+
20
+ int2 up;
21
+ int2 down;
22
+ int2 pad0;
23
+ int flip;
24
+ float gain;
25
+
26
+ int4 inSize; // [width, height, channel, batch]
27
+ int4 inStride;
28
+ int2 filterSize; // [width, height]
29
+ int2 filterStride;
30
+ int4 outSize; // [width, height, channel, batch]
31
+ int4 outStride;
32
+ int sizeMinor;
33
+ int sizeMajor;
34
+
35
+ int loopMinor;
36
+ int loopMajor;
37
+ int loopX;
38
+ int launchMinor;
39
+ int launchMajor;
40
+ };
41
+
42
+ //------------------------------------------------------------------------
43
+ // CUDA kernel specialization.
44
+
45
+ struct upfirdn2d_kernel_spec
46
+ {
47
+ void* kernel;
48
+ int tileOutW;
49
+ int tileOutH;
50
+ int loopMinor;
51
+ int loopX;
52
+ };
53
+
54
+ //------------------------------------------------------------------------
55
+ // CUDA kernel selection.
56
+
57
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58
+
59
+ //------------------------------------------------------------------------
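Before the Python wrapper that follows, a brief usage sketch of the `upfirdn2d()` sequence (upsample, pad, filter, downsample) it exposes. The 4-tap filter, the asymmetric padding for the 2x case, and `impl='ref'` (so the snippet runs without the CUDA plugin) are illustrative choices, not training defaults.

```python
# Illustrative upfirdn2d() round trip: 2x upsample, then 2x downsample.
import torch
from torch_utils.ops import upfirdn2d

f = upfirdn2d.setup_filter([1, 3, 3, 1])   # normalized separable FIR filter
x = torch.randn(1, 3, 16, 16)

# padding=[x0, x1, y0, y1]; gain=4 keeps the signal magnitude after 2x upsampling.
up = upfirdn2d.upfirdn2d(x, f, up=2, padding=[2, 1, 2, 1], gain=4, impl='ref')   # -> 1x3x32x32
down = upfirdn2d.upfirdn2d(up, f, down=2, padding=1, impl='ref')                 # -> 1x3x16x16
print(up.shape, down.shape)
```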
diffusion-insgen/torch_utils/ops/upfirdn2d.py ADDED
@@ -0,0 +1,377 @@
1
+
2
+ """Custom PyTorch ops for efficient resampling of 2D images."""
3
+
4
+ import os
5
+ import warnings
6
+ import numpy as np
7
+ import torch
8
+ import traceback
9
+
10
+ from .. import custom_ops
11
+ from .. import misc
12
+ from . import conv2d_gradfix
13
+
14
+ #----------------------------------------------------------------------------
15
+
16
+ _inited = False
17
+ _plugin = None
18
+
19
+ def _init():
20
+ global _inited, _plugin
21
+ if not _inited:
22
+ sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
23
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
24
+ try:
25
+ _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
26
+ except:
27
+ warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
28
+ return _plugin is not None
29
+
30
+ def _parse_scaling(scaling):
31
+ if isinstance(scaling, int):
32
+ scaling = [scaling, scaling]
33
+ assert isinstance(scaling, (list, tuple))
34
+ assert all(isinstance(x, int) for x in scaling)
35
+ sx, sy = scaling
36
+ assert sx >= 1 and sy >= 1
37
+ return sx, sy
38
+
39
+ def _parse_padding(padding):
40
+ if isinstance(padding, int):
41
+ padding = [padding, padding]
42
+ assert isinstance(padding, (list, tuple))
43
+ assert all(isinstance(x, int) for x in padding)
44
+ if len(padding) == 2:
45
+ padx, pady = padding
46
+ padding = [padx, padx, pady, pady]
47
+ padx0, padx1, pady0, pady1 = padding
48
+ return padx0, padx1, pady0, pady1
49
+
50
+ def _get_filter_size(f):
51
+ if f is None:
52
+ return 1, 1
53
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
54
+ fw = f.shape[-1]
55
+ fh = f.shape[0]
56
+ with misc.suppress_tracer_warnings():
57
+ fw = int(fw)
58
+ fh = int(fh)
59
+ misc.assert_shape(f, [fh, fw][:f.ndim])
60
+ assert fw >= 1 and fh >= 1
61
+ return fw, fh
62
+
63
+ #----------------------------------------------------------------------------
64
+
65
+ def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
66
+ r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
67
+
68
+ Args:
69
+ f: Torch tensor, numpy array, or python list of the shape
70
+ `[filter_height, filter_width]` (non-separable),
71
+ `[filter_taps]` (separable),
72
+ `[]` (impulse), or
73
+ `None` (identity).
74
+ device: Result device (default: cpu).
75
+ normalize: Normalize the filter so that it retains the magnitude
76
+ for constant input signal (DC)? (default: True).
77
+ flip_filter: Flip the filter? (default: False).
78
+ gain: Overall scaling factor for signal magnitude (default: 1).
79
+ separable: Return a separable filter? (default: select automatically).
80
+
81
+ Returns:
82
+ Float32 tensor of the shape
83
+ `[filter_height, filter_width]` (non-separable) or
84
+ `[filter_taps]` (separable).
85
+ """
86
+ # Validate.
87
+ if f is None:
88
+ f = 1
89
+ f = torch.as_tensor(f, dtype=torch.float32)
90
+ assert f.ndim in [0, 1, 2]
91
+ assert f.numel() > 0
92
+ if f.ndim == 0:
93
+ f = f[np.newaxis]
94
+
95
+ # Separable?
96
+ if separable is None:
97
+ separable = (f.ndim == 1 and f.numel() >= 8)
98
+ if f.ndim == 1 and not separable:
99
+ f = f.ger(f)
100
+ assert f.ndim == (1 if separable else 2)
101
+
102
+ # Apply normalize, flip, gain, and device.
103
+ if normalize:
104
+ f /= f.sum()
105
+ if flip_filter:
106
+ f = f.flip(list(range(f.ndim)))
107
+ f = f * (gain ** (f.ndim / 2))
108
+ f = f.to(device=device)
109
+ return f
110
+
111
+ #----------------------------------------------------------------------------
112
+
113
+ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
114
+ r"""Pad, upsample, filter, and downsample a batch of 2D images.
115
+
116
+ Performs the following sequence of operations for each channel:
117
+
118
+ 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
119
+
120
+ 2. Pad the image with the specified number of zeros on each side (`padding`).
121
+ Negative padding corresponds to cropping the image.
122
+
123
+ 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
124
+ so that the footprint of all output pixels lies within the input image.
125
+
126
+ 4. Downsample the image by keeping every Nth pixel (`down`).
127
+
128
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
129
+ The fused op is considerably more efficient than performing the same calculation
130
+ using standard PyTorch ops. It supports gradients of arbitrary order.
131
+
132
+ Args:
133
+ x: Float32/float64/float16 input tensor of the shape
134
+ `[batch_size, num_channels, in_height, in_width]`.
135
+ f: Float32 FIR filter of the shape
136
+ `[filter_height, filter_width]` (non-separable),
137
+ `[filter_taps]` (separable), or
138
+ `None` (identity).
139
+ up: Integer upsampling factor. Can be a single int or a list/tuple
140
+ `[x, y]` (default: 1).
141
+ down: Integer downsampling factor. Can be a single int or a list/tuple
142
+ `[x, y]` (default: 1).
143
+ padding: Padding with respect to the upsampled image. Can be a single number
144
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
145
+ (default: 0).
146
+ flip_filter: False = convolution, True = correlation (default: False).
147
+ gain: Overall scaling factor for signal magnitude (default: 1).
148
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
149
+
150
+ Returns:
151
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
152
+ """
153
+ assert isinstance(x, torch.Tensor)
154
+ assert impl in ['ref', 'cuda']
155
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
156
+ return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
157
+ return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
158
+
159
+ #----------------------------------------------------------------------------
160
+
161
+ @misc.profiled_function
162
+ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
163
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
164
+ """
165
+ # Validate arguments.
166
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
167
+ if f is None:
168
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
169
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
170
+ assert f.dtype == torch.float32 and not f.requires_grad
171
+ batch_size, num_channels, in_height, in_width = x.shape
172
+ upx, upy = _parse_scaling(up)
173
+ downx, downy = _parse_scaling(down)
174
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
175
+
176
+ # Upsample by inserting zeros.
177
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
178
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
179
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
180
+
181
+ # Pad or crop.
182
+ x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
183
+ x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
184
+
185
+ # Setup filter.
186
+ f = f * (gain ** (f.ndim / 2))
187
+ f = f.to(x.dtype)
188
+ if not flip_filter:
189
+ f = f.flip(list(range(f.ndim)))
190
+
191
+ # Convolve with the filter.
192
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
193
+ if f.ndim == 4:
194
+ x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
195
+ else:
196
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
197
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
198
+
199
+ # Downsample by throwing away pixels.
200
+ x = x[:, :, ::downy, ::downx]
201
+ return x
202
+
203
+ #----------------------------------------------------------------------------
204
+
205
+ _upfirdn2d_cuda_cache = dict()
206
+
207
+ def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
208
+ """Fast CUDA implementation of `upfirdn2d()` using custom ops.
209
+ """
210
+ # Parse arguments.
211
+ upx, upy = _parse_scaling(up)
212
+ downx, downy = _parse_scaling(down)
213
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
214
+
215
+ # Lookup from cache.
216
+ key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
217
+ if key in _upfirdn2d_cuda_cache:
218
+ return _upfirdn2d_cuda_cache[key]
219
+
220
+ # Forward op.
221
+ class Upfirdn2dCuda(torch.autograd.Function):
222
+ @staticmethod
223
+ def forward(ctx, x, f): # pylint: disable=arguments-differ
224
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
225
+ if f is None:
226
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
227
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
228
+ y = x
229
+ if f.ndim == 2:
230
+ y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
231
+ else:
232
+ y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
233
+ y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
234
+ ctx.save_for_backward(f)
235
+ ctx.x_shape = x.shape
236
+ return y
237
+
238
+ @staticmethod
239
+ def backward(ctx, dy): # pylint: disable=arguments-differ
240
+ f, = ctx.saved_tensors
241
+ _, _, ih, iw = ctx.x_shape
242
+ _, _, oh, ow = dy.shape
243
+ fw, fh = _get_filter_size(f)
244
+ p = [
245
+ fw - padx0 - 1,
246
+ iw * upx - ow * downx + padx0 - upx + 1,
247
+ fh - pady0 - 1,
248
+ ih * upy - oh * downy + pady0 - upy + 1,
249
+ ]
250
+ dx = None
251
+ df = None
252
+
253
+ if ctx.needs_input_grad[0]:
254
+ dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
255
+
256
+ assert not ctx.needs_input_grad[1]
257
+ return dx, df
258
+
259
+ # Add to cache.
260
+ _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
261
+ return Upfirdn2dCuda
262
+
263
+ #----------------------------------------------------------------------------
264
+
265
+ def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
266
+ r"""Filter a batch of 2D images using the given 2D FIR filter.
267
+
268
+ By default, the result is padded so that its shape matches the input.
269
+ User-specified padding is applied on top of that, with negative values
270
+ indicating cropping. Pixels outside the image are assumed to be zero.
271
+
272
+ Args:
273
+ x: Float32/float64/float16 input tensor of the shape
274
+ `[batch_size, num_channels, in_height, in_width]`.
275
+ f: Float32 FIR filter of the shape
276
+ `[filter_height, filter_width]` (non-separable),
277
+ `[filter_taps]` (separable), or
278
+ `None` (identity).
279
+ padding: Padding with respect to the output. Can be a single number or a
280
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
281
+ (default: 0).
282
+ flip_filter: False = convolution, True = correlation (default: False).
283
+ gain: Overall scaling factor for signal magnitude (default: 1).
284
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
285
+
286
+ Returns:
287
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
288
+ """
289
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
290
+ fw, fh = _get_filter_size(f)
291
+ p = [
292
+ padx0 + fw // 2,
293
+ padx1 + (fw - 1) // 2,
294
+ pady0 + fh // 2,
295
+ pady1 + (fh - 1) // 2,
296
+ ]
297
+ return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
298
+
299
+ #----------------------------------------------------------------------------
300
+
301
+ def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
302
+ r"""Upsample a batch of 2D images using the given 2D FIR filter.
303
+
304
+ By default, the result is padded so that its shape is a multiple of the input.
305
+ User-specified padding is applied on top of that, with negative values
306
+ indicating cropping. Pixels outside the image are assumed to be zero.
307
+
308
+ Args:
309
+ x: Float32/float64/float16 input tensor of the shape
310
+ `[batch_size, num_channels, in_height, in_width]`.
311
+ f: Float32 FIR filter of the shape
312
+ `[filter_height, filter_width]` (non-separable),
313
+ `[filter_taps]` (separable), or
314
+ `None` (identity).
315
+ up: Integer upsampling factor. Can be a single int or a list/tuple
316
+ `[x, y]` (default: 1).
317
+ padding: Padding with respect to the output. Can be a single number or a
318
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
319
+ (default: 0).
320
+ flip_filter: False = convolution, True = correlation (default: False).
321
+ gain: Overall scaling factor for signal magnitude (default: 1).
322
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
323
+
324
+ Returns:
325
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
326
+ """
327
+ upx, upy = _parse_scaling(up)
328
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
329
+ fw, fh = _get_filter_size(f)
330
+ p = [
331
+ padx0 + (fw + upx - 1) // 2,
332
+ padx1 + (fw - upx) // 2,
333
+ pady0 + (fh + upy - 1) // 2,
334
+ pady1 + (fh - upy) // 2,
335
+ ]
336
+ return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
337
+
338
+ #----------------------------------------------------------------------------
339
+
340
+ def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
341
+ r"""Downsample a batch of 2D images using the given 2D FIR filter.
342
+
343
+ By default, the result is padded so that its shape is a fraction of the input.
344
+ User-specified padding is applied on top of that, with negative values
345
+ indicating cropping. Pixels outside the image are assumed to be zero.
346
+
347
+ Args:
348
+ x: Float32/float64/float16 input tensor of the shape
349
+ `[batch_size, num_channels, in_height, in_width]`.
350
+ f: Float32 FIR filter of the shape
351
+ `[filter_height, filter_width]` (non-separable),
352
+ `[filter_taps]` (separable), or
353
+ `None` (identity).
354
+ down: Integer downsampling factor. Can be a single int or a list/tuple
355
+ `[x, y]` (default: 1).
356
+ padding: Padding with respect to the input. Can be a single number or a
357
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
358
+ (default: 0).
359
+ flip_filter: False = convolution, True = correlation (default: False).
360
+ gain: Overall scaling factor for signal magnitude (default: 1).
361
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
362
+
363
+ Returns:
364
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
365
+ """
366
+ downx, downy = _parse_scaling(down)
367
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
368
+ fw, fh = _get_filter_size(f)
369
+ p = [
370
+ padx0 + (fw - downx + 1) // 2,
371
+ padx1 + (fw - downx) // 2,
372
+ pady0 + (fh - downy + 1) // 2,
373
+ pady1 + (fh - downy) // 2,
374
+ ]
375
+ return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
376
+
377
+ #----------------------------------------------------------------------------
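Usage sketch for the resampling ops above — a minimal, illustrative example only (not from the diff itself); it assumes the `diffusion-insgen` directory is on PYTHONPATH and uses `impl='ref'` so no CUDA plugin build is needed. The filter taps and tensor shapes are arbitrary.

import torch
from torch_utils.ops import upfirdn2d

x = torch.randn(4, 3, 64, 64)                    # [batch, channels, height, width]
f = upfirdn2d.setup_filter([1, 3, 3, 1])         # normalized 4x4 low-pass FIR filter

y_up   = upfirdn2d.upsample2d(x, f, up=2, impl='ref')       # -> [4, 3, 128, 128]
y_down = upfirdn2d.downsample2d(x, f, down=2, impl='ref')   # -> [4, 3, 32, 32]
y_same = upfirdn2d.filter2d(x, f, impl='ref')                # -> [4, 3, 64, 64]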
diffusion-insgen/torch_utils/persistence.py ADDED
@@ -0,0 +1,244 @@
1
+ 
2
+ """Facilities for pickling Python code alongside other data.
3
+
4
+ The pickled code is automatically imported into a separate Python module
5
+ during unpickling. This way, any previously exported pickles will remain
6
+ usable even if the original code is no longer available, or if the current
7
+ version of the code is not consistent with what was originally pickled."""
8
+
9
+ import sys
10
+ import pickle
11
+ import io
12
+ import inspect
13
+ import copy
14
+ import uuid
15
+ import types
16
+ import dnnlib
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ _version = 6 # internal version number
21
+ _decorators = set() # {decorator_class, ...}
22
+ _import_hooks = [] # [hook_function, ...]
23
+ _module_to_src_dict = dict() # {module: src, ...}
24
+ _src_to_module_dict = dict() # {src: module, ...}
25
+
26
+ #----------------------------------------------------------------------------
27
+
28
+ def persistent_class(orig_class):
29
+ r"""Class decorator that extends a given class to save its source code
30
+ when pickled.
31
+
32
+ Example:
33
+
34
+ from torch_utils import persistence
35
+
36
+ @persistence.persistent_class
37
+ class MyNetwork(torch.nn.Module):
38
+ def __init__(self, num_inputs, num_outputs):
39
+ super().__init__()
40
+ self.fc = MyLayer(num_inputs, num_outputs)
41
+ ...
42
+
43
+ @persistence.persistent_class
44
+ class MyLayer(torch.nn.Module):
45
+ ...
46
+
47
+ When pickled, any instance of `MyNetwork` and `MyLayer` will save its
48
+ source code alongside other internal state (e.g., parameters, buffers,
49
+ and submodules). This way, any previously exported pickle will remain
50
+ usable even if the class definitions have been modified or are no
51
+ longer available.
52
+
53
+ The decorator saves the source code of the entire Python module
54
+ containing the decorated class. It does *not* save the source code of
55
+ any imported modules. Thus, the imported modules must be available
56
+ during unpickling, also including `torch_utils.persistence` itself.
57
+
58
+ It is ok to call functions defined in the same module from the
59
+ decorated class. However, if the decorated class depends on other
60
+ classes defined in the same module, they must be decorated as well.
61
+ This is illustrated in the above example in the case of `MyLayer`.
62
+
63
+ It is also possible to employ the decorator just-in-time before
64
+ calling the constructor. For example:
65
+
66
+ cls = MyLayer
67
+ if want_to_make_it_persistent:
68
+ cls = persistence.persistent_class(cls)
69
+ layer = cls(num_inputs, num_outputs)
70
+
71
+ As an additional feature, the decorator also keeps track of the
72
+ arguments that were used to construct each instance of the decorated
73
+ class. The arguments can be queried via `obj.init_args` and
74
+ `obj.init_kwargs`, and they are automatically pickled alongside other
75
+ object state. A typical use case is to first unpickle a previous
76
+ instance of a persistent class, and then upgrade it to use the latest
77
+ version of the source code:
78
+
79
+ with open('old_pickle.pkl', 'rb') as f:
80
+ old_net = pickle.load(f)
81
+ new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
82
+ misc.copy_params_and_buffers(old_net, new_net, require_all=True)
83
+ """
84
+ assert isinstance(orig_class, type)
85
+ if is_persistent(orig_class):
86
+ return orig_class
87
+
88
+ assert orig_class.__module__ in sys.modules
89
+ orig_module = sys.modules[orig_class.__module__]
90
+ orig_module_src = _module_to_src(orig_module)
91
+
92
+ class Decorator(orig_class):
93
+ _orig_module_src = orig_module_src
94
+ _orig_class_name = orig_class.__name__
95
+
96
+ def __init__(self, *args, **kwargs):
97
+ super().__init__(*args, **kwargs)
98
+ self._init_args = copy.deepcopy(args)
99
+ self._init_kwargs = copy.deepcopy(kwargs)
100
+ assert orig_class.__name__ in orig_module.__dict__
101
+ _check_pickleable(self.__reduce__())
102
+
103
+ @property
104
+ def init_args(self):
105
+ return copy.deepcopy(self._init_args)
106
+
107
+ @property
108
+ def init_kwargs(self):
109
+ return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
110
+
111
+ def __reduce__(self):
112
+ fields = list(super().__reduce__())
113
+ fields += [None] * max(3 - len(fields), 0)
114
+ if fields[0] is not _reconstruct_persistent_obj:
115
+ meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
116
+ fields[0] = _reconstruct_persistent_obj # reconstruct func
117
+ fields[1] = (meta,) # reconstruct args
118
+ fields[2] = None # state dict
119
+ return tuple(fields)
120
+
121
+ Decorator.__name__ = orig_class.__name__
122
+ _decorators.add(Decorator)
123
+ return Decorator
124
+
125
+ #----------------------------------------------------------------------------
126
+
127
+ def is_persistent(obj):
128
+ r"""Test whether the given object or class is persistent, i.e.,
129
+ whether it will save its source code when pickled.
130
+ """
131
+ try:
132
+ if obj in _decorators:
133
+ return True
134
+ except TypeError:
135
+ pass
136
+ return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
137
+
138
+ #----------------------------------------------------------------------------
139
+
140
+ def import_hook(hook):
141
+ r"""Register an import hook that is called whenever a persistent object
142
+ is being unpickled. A typical use case is to patch the pickled source
143
+ code to avoid errors and inconsistencies when the API of some imported
144
+ module has changed.
145
+
146
+ The hook should have the following signature:
147
+
148
+ hook(meta) -> modified meta
149
+
150
+ `meta` is an instance of `dnnlib.EasyDict` with the following fields:
151
+
152
+ type: Type of the persistent object, e.g. `'class'`.
153
+ version: Internal version number of `torch_utils.persistence`.
154
+ module_src: Original source code of the Python module.
155
+ class_name: Class name in the original Python module.
156
+ state: Internal state of the object.
157
+
158
+ Example:
159
+
160
+ @persistence.import_hook
161
+ def wreck_my_network(meta):
162
+ if meta.class_name == 'MyNetwork':
163
+ print('MyNetwork is being imported. I will wreck it!')
164
+ meta.module_src = meta.module_src.replace("True", "False")
165
+ return meta
166
+ """
167
+ assert callable(hook)
168
+ _import_hooks.append(hook)
169
+
170
+ #----------------------------------------------------------------------------
171
+
172
+ def _reconstruct_persistent_obj(meta):
173
+ r"""Hook that is called internally by the `pickle` module to unpickle
174
+ a persistent object.
175
+ """
176
+ meta = dnnlib.EasyDict(meta)
177
+ meta.state = dnnlib.EasyDict(meta.state)
178
+ for hook in _import_hooks:
179
+ meta = hook(meta)
180
+ assert meta is not None
181
+
182
+ assert meta.version == _version
183
+ module = _src_to_module(meta.module_src)
184
+
185
+ assert meta.type == 'class'
186
+ orig_class = module.__dict__[meta.class_name]
187
+ decorator_class = persistent_class(orig_class)
188
+ obj = decorator_class.__new__(decorator_class)
189
+
190
+ setstate = getattr(obj, '__setstate__', None)
191
+ if callable(setstate):
192
+ setstate(meta.state) # pylint: disable=not-callable
193
+ else:
194
+ obj.__dict__.update(meta.state)
195
+ return obj
196
+
197
+ #----------------------------------------------------------------------------
198
+
199
+ def _module_to_src(module):
200
+ r"""Query the source code of a given Python module.
201
+ """
202
+ src = _module_to_src_dict.get(module, None)
203
+ if src is None:
204
+ src = inspect.getsource(module)
205
+ _module_to_src_dict[module] = src
206
+ _src_to_module_dict[src] = module
207
+ return src
208
+
209
+ def _src_to_module(src):
210
+ r"""Get or create a Python module for the given source code.
211
+ """
212
+ module = _src_to_module_dict.get(src, None)
213
+ if module is None:
214
+ module_name = "_imported_module_" + uuid.uuid4().hex
215
+ module = types.ModuleType(module_name)
216
+ sys.modules[module_name] = module
217
+ _module_to_src_dict[module] = src
218
+ _src_to_module_dict[src] = module
219
+ exec(src, module.__dict__) # pylint: disable=exec-used
220
+ return module
221
+
222
+ #----------------------------------------------------------------------------
223
+
224
+ def _check_pickleable(obj):
225
+ r"""Check that the given object is pickleable, raising an exception if
226
+ it is not. This function is expected to be considerably more efficient
227
+ than actually pickling the object.
228
+ """
229
+ def recurse(obj):
230
+ if isinstance(obj, (list, tuple, set)):
231
+ return [recurse(x) for x in obj]
232
+ if isinstance(obj, dict):
233
+ return [[recurse(x), recurse(y)] for x, y in obj.items()]
234
+ if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
235
+ return None # Python primitive types are pickleable.
236
+ if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
237
+ return None # NumPy arrays and PyTorch tensors are pickleable.
238
+ if is_persistent(obj):
239
+ return None # Persistent objects are pickleable, by virtue of the constructor check.
240
+ return obj
241
+ with io.BytesIO() as f:
242
+ pickle.dump(recurse(obj), f)
243
+
244
+ #----------------------------------------------------------------------------
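Usage sketch for `persistence.persistent_class` — an illustrative example based on the docstring above, not part of this commit; the class, layer sizes, and pickle file name are made up, and it assumes the code lives in an importable module (not an interactive session) so its source can be captured.

import pickle
import torch
from torch_utils import persistence

@persistence.persistent_class
class MyNetwork(torch.nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super().__init__()
        self.fc = torch.nn.Linear(num_inputs, num_outputs)
    def forward(self, x):
        return self.fc(x)

net = MyNetwork(16, 8)
with open('net.pkl', 'wb') as fp:
    pickle.dump(net, fp)        # module source is pickled alongside parameters and buffers

with open('net.pkl', 'rb') as fp:
    restored = pickle.load(fp)  # loads even if MyNetwork has since changed or disappeared
print(restored.init_args, dict(restored.init_kwargs))  # (16, 8) {}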
diffusion-insgen/torch_utils/training_stats.py ADDED
@@ -0,0 +1,261 @@
1
+
2
+ """Facilities for reporting and collecting training statistics across
3
+ multiple processes and devices. The interface is designed to minimize
4
+ synchronization overhead as well as the amount of boilerplate in user
5
+ code."""
6
+
7
+ import re
8
+ import numpy as np
9
+ import torch
10
+ import dnnlib
11
+
12
+ from . import misc
13
+
14
+ #----------------------------------------------------------------------------
15
+
16
+ _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
17
+ _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
18
+ _counter_dtype = torch.float64 # Data type to use for the internal counters.
19
+ _rank = 0 # Rank of the current process.
20
+ _sync_device = None # Device to use for multiprocess communication. None = single-process.
21
+ _sync_called = False # Has _sync() been called yet?
22
+ _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
23
+ _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
24
+
25
+ #----------------------------------------------------------------------------
26
+
27
+ def init_multiprocessing(rank, sync_device):
28
+ r"""Initializes `torch_utils.training_stats` for collecting statistics
29
+ across multiple processes.
30
+
31
+ This function must be called after
32
+ `torch.distributed.init_process_group()` and before `Collector.update()`.
33
+ The call is not necessary if multi-process collection is not needed.
34
+
35
+ Args:
36
+ rank: Rank of the current process.
37
+ sync_device: PyTorch device to use for inter-process
38
+ communication, or None to disable multi-process
39
+ collection. Typically `torch.device('cuda', rank)`.
40
+ """
41
+ global _rank, _sync_device
42
+ assert not _sync_called
43
+ _rank = rank
44
+ _sync_device = sync_device
45
+
46
+ #----------------------------------------------------------------------------
47
+
48
+ @misc.profiled_function
49
+ def report(name, value):
50
+ r"""Broadcasts the given set of scalars to all interested instances of
51
+ `Collector`, across device and process boundaries.
52
+
53
+ This function is expected to be extremely cheap and can be safely
54
+ called from anywhere in the training loop, loss function, or inside a
55
+ `torch.nn.Module`.
56
+
57
+ Warning: The current implementation expects the set of unique names to
58
+ be consistent across processes. Please make sure that `report()` is
59
+ called at least once for each unique name by each process, and in the
60
+ same order. If a given process has no scalars to broadcast, it can do
61
+ `report(name, [])` (empty list).
62
+
63
+ Args:
64
+ name: Arbitrary string specifying the name of the statistic.
65
+ Averages are accumulated separately for each unique name.
66
+ value: Arbitrary set of scalars. Can be a list, tuple,
67
+ NumPy array, PyTorch tensor, or Python scalar.
68
+
69
+ Returns:
70
+ The same `value` that was passed in.
71
+ """
72
+ if name not in _counters:
73
+ _counters[name] = dict()
74
+
75
+ elems = torch.as_tensor(value)
76
+ if elems.numel() == 0:
77
+ return value
78
+
79
+ elems = elems.detach().flatten().to(_reduce_dtype)
80
+ moments = torch.stack([
81
+ torch.ones_like(elems).sum(),
82
+ elems.sum(),
83
+ elems.square().sum(),
84
+ ])
85
+ assert moments.ndim == 1 and moments.shape[0] == _num_moments
86
+ moments = moments.to(_counter_dtype)
87
+
88
+ device = moments.device
89
+ if device not in _counters[name]:
90
+ _counters[name][device] = torch.zeros_like(moments)
91
+ _counters[name][device].add_(moments)
92
+ return value
93
+
94
+ #----------------------------------------------------------------------------
95
+
96
+ def report0(name, value):
97
+ r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
98
+ but ignores any scalars provided by the other processes.
99
+ See `report()` for further details.
100
+ """
101
+ report(name, value if _rank == 0 else [])
102
+ return value
103
+
104
+ #----------------------------------------------------------------------------
105
+
106
+ class Collector:
107
+ r"""Collects the scalars broadcasted by `report()` and `report0()` and
108
+ computes their long-term averages (mean and standard deviation) over
109
+ user-defined periods of time.
110
+
111
+ The averages are first collected into internal counters that are not
112
+ directly visible to the user. They are then copied to the user-visible
113
+ state as a result of calling `update()` and can then be queried using
114
+ `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
115
+ internal counters for the next round, so that the user-visible state
116
+ effectively reflects averages collected between the last two calls to
117
+ `update()`.
118
+
119
+ Args:
120
+ regex: Regular expression defining which statistics to
121
+ collect. The default is to collect everything.
122
+ keep_previous: Whether to retain the previous averages if no
123
+ scalars were collected on a given round
124
+ (default: True).
125
+ """
126
+ def __init__(self, regex='.*', keep_previous=True):
127
+ self._regex = re.compile(regex)
128
+ self._keep_previous = keep_previous
129
+ self._cumulative = dict()
130
+ self._moments = dict()
131
+ self.update()
132
+ self._moments.clear()
133
+
134
+ def names(self):
135
+ r"""Returns the names of all statistics broadcasted so far that
136
+ match the regular expression specified at construction time.
137
+ """
138
+ return [name for name in _counters if self._regex.fullmatch(name)]
139
+
140
+ def update(self):
141
+ r"""Copies current values of the internal counters to the
142
+ user-visible state and resets them for the next round.
143
+
144
+ If `keep_previous=True` was specified at construction time, the
145
+ operation is skipped for statistics that have received no scalars
146
+ since the last update, retaining their previous averages.
147
+
148
+ This method performs a number of GPU-to-CPU transfers and one
149
+ `torch.distributed.all_reduce()`. It is intended to be called
150
+ periodically in the main training loop, typically once every
151
+ N training steps.
152
+ """
153
+ if not self._keep_previous:
154
+ self._moments.clear()
155
+ for name, cumulative in _sync(self.names()):
156
+ if name not in self._cumulative:
157
+ self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
158
+ delta = cumulative - self._cumulative[name]
159
+ self._cumulative[name].copy_(cumulative)
160
+ if float(delta[0]) != 0:
161
+ self._moments[name] = delta
162
+
163
+ def _get_delta(self, name):
164
+ r"""Returns the raw moments that were accumulated for the given
165
+ statistic between the last two calls to `update()`, or zero if
166
+ no scalars were collected.
167
+ """
168
+ assert self._regex.fullmatch(name)
169
+ if name not in self._moments:
170
+ self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
171
+ return self._moments[name]
172
+
173
+ def num(self, name):
174
+ r"""Returns the number of scalars that were accumulated for the given
175
+ statistic between the last two calls to `update()`, or zero if
176
+ no scalars were collected.
177
+ """
178
+ delta = self._get_delta(name)
179
+ return int(delta[0])
180
+
181
+ def mean(self, name):
182
+ r"""Returns the mean of the scalars that were accumulated for the
183
+ given statistic between the last two calls to `update()`, or NaN if
184
+ no scalars were collected.
185
+ """
186
+ delta = self._get_delta(name)
187
+ if int(delta[0]) == 0:
188
+ return float('nan')
189
+ return float(delta[1] / delta[0])
190
+
191
+ def std(self, name):
192
+ r"""Returns the standard deviation of the scalars that were
193
+ accumulated for the given statistic between the last two calls to
194
+ `update()`, or NaN if no scalars were collected.
195
+ """
196
+ delta = self._get_delta(name)
197
+ if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
198
+ return float('nan')
199
+ if int(delta[0]) == 1:
200
+ return float(0)
201
+ mean = float(delta[1] / delta[0])
202
+ raw_var = float(delta[2] / delta[0])
203
+ return np.sqrt(max(raw_var - np.square(mean), 0))
204
+
205
+ def as_dict(self):
206
+ r"""Returns the averages accumulated between the last two calls to
207
+ `update()` as a `dnnlib.EasyDict`. The contents are as follows:
208
+
209
+ dnnlib.EasyDict(
210
+ NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
211
+ ...
212
+ )
213
+ """
214
+ stats = dnnlib.EasyDict()
215
+ for name in self.names():
216
+ stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
217
+ return stats
218
+
219
+ def __getitem__(self, name):
220
+ r"""Convenience getter.
221
+ `collector[name]` is a synonym for `collector.mean(name)`.
222
+ """
223
+ return self.mean(name)
224
+
225
+ #----------------------------------------------------------------------------
226
+
227
+ def _sync(names):
228
+ r"""Synchronize the global cumulative counters across devices and
229
+ processes. Called internally by `Collector.update()`.
230
+ """
231
+ if len(names) == 0:
232
+ return []
233
+ global _sync_called
234
+ _sync_called = True
235
+
236
+ # Collect deltas within current rank.
237
+ deltas = []
238
+ device = _sync_device if _sync_device is not None else torch.device('cpu')
239
+ for name in names:
240
+ delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
241
+ for counter in _counters[name].values():
242
+ delta.add_(counter.to(device))
243
+ counter.copy_(torch.zeros_like(counter))
244
+ deltas.append(delta)
245
+ deltas = torch.stack(deltas)
246
+
247
+ # Sum deltas across ranks.
248
+ if _sync_device is not None:
249
+ torch.distributed.all_reduce(deltas)
250
+
251
+ # Update cumulative values.
252
+ deltas = deltas.cpu()
253
+ for idx, name in enumerate(names):
254
+ if name not in _cumulative:
255
+ _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
256
+ _cumulative[name].add_(deltas[idx])
257
+
258
+ # Return name-value pairs.
259
+ return [(name, _cumulative[name]) for name in names]
260
+
261
+ #----------------------------------------------------------------------------
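Usage sketch for `training_stats` — a single-process illustration of the report/Collector cycle described above, not part of this commit; the statistic name and loss values are placeholders.

import torch
from torch_utils import training_stats

collector = training_stats.Collector(regex='Loss/.*')

for step in range(100):
    loss = torch.rand([8])                       # stand-in for a real per-sample loss
    training_stats.report('Loss/G/loss', loss)   # cheap; safe to call anywhere in the loop

collector.update()                               # copies internal counters to user-visible state
print(collector.num('Loss/G/loss'))              # 800 scalars since the previous update()
print(collector.mean('Loss/G/loss'), collector.std('Loss/G/loss'))
print(dict(collector.as_dict()))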
diffusion-insgen/train.py ADDED
@@ -0,0 +1,605 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Train a GAN using the techniques described in the paper
10
+ "Training Generative Adversarial Networks with Limited Data"."""
11
+
12
+ import os
13
+ import click
14
+ import re
15
+ import json
16
+ import tempfile
17
+ import torch
18
+ import dnnlib
19
+
20
+ from training import training_loop
21
+ from metrics import metric_main
22
+ from torch_utils import training_stats
23
+ from torch_utils import custom_ops
24
+
25
+ #----------------------------------------------------------------------------
26
+
27
+ class UserError(Exception):
28
+ pass
29
+
30
+ #----------------------------------------------------------------------------
31
+
32
+ def setup_training_loop_kwargs(
33
+ # General options (not included in desc).
34
+ gpus = None, # Number of GPUs: <int>, default = 1 gpu
35
+ snap = None, # Snapshot interval: <int>, default = 50 ticks
36
+ metrics = None, # List of metric names: [], ['fid50k_full'] (default), ...
37
+ seed = None, # Random seed: <int>, default = 0
38
+
39
+ # Dataset.
40
+ data = None, # Training dataset (required): <path>
41
+ cond = None, # Train conditional model based on dataset labels: <bool>, default = False
42
+ subset = None, # Train with only N images: <int>, default = all
43
+ mirror = None, # Augment dataset with x-flips: <bool>, default = False
44
+
45
+ # Base config.
46
+ cfg = None, # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'
47
+ gamma = None, # Override R1 gamma: <float>
48
+ kimg = None, # Override training duration: <int>
49
+ batch = None, # Override batch size: <int>
50
+
51
+ # Discriminator augmentation.
52
+ aug = None, # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
53
+ p = None, # Specify p for 'fixed' (required): <float>
54
+ target = None, # Override ADA target for 'ada': <float>, default = depends on aug
55
+
56
+ # Transfer learning.
57
+ resume = None, # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url>
58
+ freezed = None, # Freeze-D: <int>, default = 0 discriminator layers
59
+
60
+ # Performance options (not included in desc).
61
+ fp32 = None, # Disable mixed-precision training: <bool>, default = False
62
+ nhwc = None, # Use NHWC memory format with FP16: <bool>, default = False
63
+ allow_tf32 = None, # Allow PyTorch to use TF32 for matmul and convolutions: <bool>, default = False
64
+ nobench = None, # Disable cuDNN benchmarking: <bool>, default = False
65
+ workers = None, # Override number of DataLoader workers: <int>, default = 3
66
+ # InsGen related options
67
+ no_insgen = False, # Disable insgen for training: <bool>, default = False
68
+ rqs = None, # Size of real image queue: <int>, default = 5% * len(dataset)
69
+ fqs = None, # Size of fake image queue: <int>, default = 5% * len(dataset)
70
+ no_cl_on_g = False, # Disable fake instance discrimination for generator: <bool>, default = False
71
+ ada_linear = False, # Whether to linearly increase the strength of ADA: <bool>, default = False
72
+
73
+ # Added
74
+ exp = None,
75
+ daug = 'ADA',
76
+
77
+ # Adaptive Diffusion config.
78
+ beta_schedule = None,
79
+ beta_start = None,
80
+ beta_end = None,
81
+ t_min = None,
82
+ t_max = None,
83
+ noise_sd = None,
84
+ ts_dist = None,
85
+ ada_maxp = None,
86
+ ):
87
+ args = dnnlib.EasyDict()
88
+
89
+ # ------------------------------------------
90
+ # General options: gpus, snap, metrics, seed
91
+ # ------------------------------------------
92
+
93
+ if gpus is None:
94
+ gpus = 1
95
+ assert isinstance(gpus, int)
96
+ if not (gpus >= 1 and gpus & (gpus - 1) == 0):
97
+ raise UserError('--gpus must be a power of two')
98
+ args.num_gpus = gpus
99
+
100
+ if snap is None:
101
+ snap = 50
102
+ assert isinstance(snap, int)
103
+ if snap < 1:
104
+ raise UserError('--snap must be at least 1')
105
+ args.image_snapshot_ticks = snap
106
+ args.network_snapshot_ticks = snap
107
+
108
+ if metrics is None:
109
+ metrics = ['fid50k_full']
110
+ assert isinstance(metrics, list)
111
+ if not all(metric_main.is_valid_metric(metric) for metric in metrics):
112
+ raise UserError('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics()))
113
+ args.metrics = metrics
114
+
115
+ if seed is None:
116
+ seed = 0
117
+ assert isinstance(seed, int)
118
+ args.random_seed = seed
119
+
120
+ # -----------------------------------
121
+ # Dataset: data, cond, subset, mirror
122
+ # -----------------------------------
123
+
124
+ assert data is not None
125
+ assert isinstance(data, str)
126
+ args.training_set_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data, use_labels=True, max_size=None, xflip=False)
127
+ args.data_loader_kwargs = dnnlib.EasyDict(pin_memory=True, num_workers=3, prefetch_factor=2)
128
+ try:
129
+ training_set = dnnlib.util.construct_class_by_name(**args.training_set_kwargs) # subclass of training.dataset.Dataset
130
+ args.training_set_kwargs.resolution = training_set.resolution # be explicit about resolution
131
+ args.training_set_kwargs.use_labels = training_set.has_labels # be explicit about labels
132
+ args.training_set_kwargs.max_size = len(training_set) # be explicit about dataset size
133
+ desc = training_set.name
134
+ del training_set # conserve memory
135
+ except IOError as err:
136
+ raise UserError(f'--data: {err}')
137
+
138
+ if exp is not None:
139
+ desc += f'-{exp}'
140
+
141
+ if cond is None:
142
+ cond = False
143
+ assert isinstance(cond, bool)
144
+ if cond:
145
+ if not args.training_set_kwargs.use_labels:
146
+ raise UserError('--cond=True requires labels specified in dataset.json')
147
+ desc += '-cond'
148
+ else:
149
+ args.training_set_kwargs.use_labels = False
150
+
151
+ if subset is not None:
152
+ assert isinstance(subset, int)
153
+ if not 1 <= subset <= args.training_set_kwargs.max_size:
154
+ raise UserError(f'--subset must be between 1 and {args.training_set_kwargs.max_size}')
155
+ desc += f'-subset{subset}'
156
+ if subset < args.training_set_kwargs.max_size:
157
+ args.training_set_kwargs.max_size = subset
158
+ args.training_set_kwargs.random_seed = args.random_seed
159
+
160
+ if mirror is None:
161
+ mirror = False
162
+ assert isinstance(mirror, bool)
163
+ if mirror:
164
+ desc += '-mirror'
165
+ args.training_set_kwargs.xflip = True
166
+
167
+ # ------------------------------------
168
+ # Base config: cfg, gamma, kimg, batch
169
+ # ------------------------------------
170
+
171
+ if cfg is None:
172
+ cfg = 'auto'
173
+ assert isinstance(cfg, str)
174
+ desc += f'-{cfg}'
175
+
176
+ cfg_specs = {
177
+ 'auto': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2), # Populated dynamically based on resolution and GPU count.
178
+ 'stylegan2': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # Uses mixed-precision, unlike the original StyleGAN2.
179
+ 'paper256': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=0.5, lrate=0.0025, gamma=1, ema=20, ramp=None, map=8),
180
+ 'paper512': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1, lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8),
181
+ 'paper1024': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=2, ema=10, ramp=None, map=8),
182
+ 'cifar': dict(ref_gpus=4, kimg=100000, mb=64, mbstd=32, fmaps=1, lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2),
183
+ }
184
+
185
+ assert cfg in cfg_specs
186
+ spec = dnnlib.EasyDict(cfg_specs[cfg])
187
+ if cfg == 'auto':
188
+ desc += f'{gpus:d}'
189
+ spec.ref_gpus = gpus
190
+ res = args.training_set_kwargs.resolution
191
+ spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus) # keep gpu memory consumption at bay
192
+ spec.mbstd = min(spec.mb // gpus, 4) # other hyperparams behave more predictably if mbstd group size remains fixed
193
+ spec.fmaps = 1 if res >= 512 else 0.5
194
+ spec.lrate = 0.002 if res >= 1024 else 0.0025
195
+ spec.gamma = 0.0002 * (res ** 2) / spec.mb # heuristic formula
196
+ spec.ema = spec.mb * 10 / 32
197
+
198
+ args.G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator', z_dim=512, w_dim=512, mapping_kwargs=dnnlib.EasyDict(), synthesis_kwargs=dnnlib.EasyDict())
199
+ args.D_kwargs = dnnlib.EasyDict(class_name='training.networks.Discriminator', block_kwargs=dnnlib.EasyDict(), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict())
200
+ args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(spec.fmaps * 32768)
201
+ args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
202
+ args.G_kwargs.mapping_kwargs.num_layers = spec.map
203
+ args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4 # enable mixed-precision training
204
+ args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256 # clamp activations to avoid float16 overflow
205
+ args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd
206
+ args.D_kwargs.mapping_kwargs.num_layers = 0 # align with tensorflow implementation of ADA
207
+
208
+ args.G_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
209
+ args.D_opt_kwargs = dnnlib.EasyDict(class_name='torch.optim.Adam', lr=spec.lrate, betas=[0,0.99], eps=1e-8)
210
+ args.loss_kwargs = dnnlib.EasyDict(class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma)
211
+
212
+ args.total_kimg = spec.kimg
213
+ args.batch_size = spec.mb
214
+ args.batch_gpu = spec.mb // spec.ref_gpus
215
+ args.ema_kimg = spec.ema
216
+ args.ema_rampup = spec.ramp
217
+
218
+ if cfg == 'cifar':
219
+ args.loss_kwargs.pl_weight = 0 # disable path length regularization
220
+ args.loss_kwargs.style_mixing_prob = 0 # disable style mixing
221
+ args.D_kwargs.architecture = 'orig' # disable residual skip connections
222
+
223
+ if gamma is not None:
224
+ assert isinstance(gamma, float)
225
+ if not gamma >= 0:
226
+ raise UserError('--gamma must be non-negative')
227
+ desc += f'-gamma{gamma:g}'
228
+ args.loss_kwargs.r1_gamma = gamma
229
+
230
+ if kimg is not None:
231
+ assert isinstance(kimg, int)
232
+ if not kimg >= 1:
233
+ raise UserError('--kimg must be at least 1')
234
+ desc += f'-kimg{kimg:d}'
235
+ args.total_kimg = kimg
236
+
237
+ if batch is not None:
238
+ assert isinstance(batch, int)
239
+ if not (batch >= 1 and batch % gpus == 0):
240
+ raise UserError('--batch must be at least 1 and divisible by --gpus')
241
+ desc += f'-batch{batch}'
242
+ args.batch_size = batch
243
+ args.batch_gpu = batch // gpus
244
+
245
+ # ---------------------------------------------------
246
+ # Discriminator augmentation: aug, p, target, augpipe
247
+ # ---------------------------------------------------
248
+
249
+ if aug is None:
250
+ aug = 'ada'
251
+ else:
252
+ assert isinstance(aug, str)
253
+ desc += f'-{aug}'
254
+
255
+ if aug == 'ada':
256
+ args.ada_target = 0.6
257
+
258
+ elif aug == 'noaug':
259
+ pass
260
+
261
+ elif aug == 'fixed':
262
+ if p is None:
263
+ raise UserError(f'--aug={aug} requires specifying --p')
264
+
265
+ else:
266
+ raise UserError(f'--aug={aug} not supported')
267
+
268
+ if p is not None:
269
+ assert isinstance(p, float)
270
+ if aug != 'fixed':
271
+ raise UserError('--p can only be specified with --aug=fixed')
272
+ if not 0 <= p <= 1:
273
+ raise UserError('--p must be between 0 and 1')
274
+ desc += f'-p{p:g}'
275
+ args.augment_p = p
276
+
277
+ if target is not None:
278
+ assert isinstance(target, float)
279
+ if aug != 'ada':
280
+ raise UserError('--target can only be specified with --aug=ada')
281
+ if not 0 <= target <= 1:
282
+ raise UserError('--target must be between 0 and 1')
283
+ desc += f'-target{target:g}'
284
+ args.ada_target = target
285
+
286
+ diffusion_specs = dict(beta_schedule=beta_schedule, beta_start=beta_start, beta_end=beta_end,
287
+ t_min=t_min, t_max=t_max, noise_std=noise_sd,
288
+ aug=daug, ada_maxp=ada_maxp, ts_dist=ts_dist)
289
+
290
+ desc += f"-ts_dist-{ts_dist}"
291
+ if aug != 'noaug':
292
+ args.augment_kwargs = dnnlib.EasyDict(class_name='training.augment.AugmentPipe', **diffusion_specs)
293
+
294
+ # ----------------------------------
295
+ # Transfer learning: resume, freezed
296
+ # ----------------------------------
297
+
298
+ resume_specs = {
299
+ 'ffhq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl',
300
+ 'ffhq512': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl',
301
+ 'ffhq1024': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl',
302
+ 'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl',
303
+ 'lsundog256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl',
304
+ }
305
+
306
+ assert resume is None or isinstance(resume, str)
307
+ if resume is None:
308
+ resume = 'noresume'
309
+ elif resume == 'noresume':
310
+ desc += '-noresume'
311
+ elif resume in resume_specs:
312
+ desc += f'-resume{resume}'
313
+ args.resume_pkl = resume_specs[resume] # predefined url
314
+ else:
315
+ desc += '-resumecustom'
316
+ args.resume_pkl = resume # custom path or url
317
+
318
+ if resume != 'noresume':
319
+ args.ada_kimg = 100 # make ADA react faster at the beginning
320
+ args.ema_rampup = None # disable EMA rampup
321
+ args.ada_kimg = 100 # redundant; already set two lines above
322
+
323
+ if freezed is not None:
324
+ assert isinstance(freezed, int)
325
+ if not freezed >= 0:
326
+ raise UserError('--freezed must be non-negative')
327
+ desc += f'-freezed{freezed:d}'
328
+ args.D_kwargs.block_kwargs.freeze_layers = freezed
329
+
330
+ # -------------------------------------------------
331
+ # Performance options: fp32, nhwc, nobench, workers
332
+ # -------------------------------------------------
333
+
334
+ if fp32 is None:
335
+ fp32 = False
336
+ assert isinstance(fp32, bool)
337
+ if fp32:
338
+ args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
339
+ args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None
340
+
341
+ if nhwc is None:
342
+ nhwc = False
343
+ assert isinstance(nhwc, bool)
344
+ if nhwc:
345
+ args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True
346
+
347
+ if nobench is None:
348
+ nobench = False
349
+ assert isinstance(nobench, bool)
350
+ if nobench:
351
+ args.cudnn_benchmark = False
352
+
353
+ if allow_tf32 is None:
354
+ allow_tf32 = False
355
+ assert isinstance(allow_tf32, bool)
356
+ if allow_tf32:
357
+ args.allow_tf32 = True
358
+
359
+ if workers is not None:
360
+ assert isinstance(workers, int)
361
+ if not workers >= 1:
362
+ raise UserError('--workers must be at least 1')
363
+ args.data_loader_kwargs.num_workers = workers
364
+
365
+ # ----------------------------------------------------
366
+ # InsGen: no_insgen, rqs, fqs, no_cl_on_g, ada_linear
367
+ # ----------------------------------------------------
368
+ use_insgen = True
369
+ if no_insgen is not None:
370
+ assert isinstance(no_insgen, bool)
371
+ use_insgen = not no_insgen
372
+
373
+ if use_insgen:
374
+ # Overwrite class name of loss function
375
+ args.loss_kwargs.class_name = 'training.contrastive_loss.StyleGAN2LossCL'
376
+
377
+ args.DHead_kwargs = dnnlib.EasyDict(class_name='training.contrastive_head.CLHead', inplanes=512, temperature=0.2, momentum=0.999, queue_size=-1)
378
+ args.GHead_kwargs = dnnlib.EasyDict(class_name='training.contrastive_head.CLHead', inplanes=512, temperature=0.2, momentum=0.999, queue_size=-1)
379
+ # Default queue size is 0.05 * len(dataset)
380
+ default_queue_size = int(0.05 * args.training_set_kwargs.max_size)
381
+ if args.training_set_kwargs.xflip:
382
+ default_queue_size *= 2
383
+ args.DHead_kwargs.queue_size = default_queue_size if rqs is None else rqs
384
+ args.GHead_kwargs.queue_size = default_queue_size if fqs is None else fqs
385
+
386
+ if no_cl_on_g is not None:
387
+ assert isinstance(no_cl_on_g, bool)
388
+ args.no_cl_on_g = no_cl_on_g
389
+ if ada_linear is not None:
390
+ assert isinstance(ada_linear, bool)
391
+ args.ada_linear = ada_linear
392
+ # Default loss weight for real instance discrimination, fake instance discrimination and fake instance discrimination on g
393
+ args.cl_loss_weight = dnnlib.EasyDict(lw_real_cl=1.0, lw_fake_cl=1.0, lw_fake_cl_on_g=0.1)
394
+ else:
395
+ args.DHead_kwargs = None
396
+ args.GHead_kwargs = None
397
+
398
+ return desc, args
399
+
400
+ #----------------------------------------------------------------------------
401
+
402
+ def subprocess_fn(rank, args, temp_dir):
403
+ dnnlib.util.Logger(file_name=os.path.join(args.run_dir, 'log.txt'), file_mode='a', should_flush=True)
404
+
405
+ # Init torch.distributed.
406
+ if args.num_gpus > 1:
407
+ init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init'))
408
+ if os.name == 'nt':
409
+ init_method = 'file:///' + init_file.replace('\\', '/')
410
+ torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus)
411
+ else:
412
+ init_method = f'file://{init_file}'
413
+ torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus)
414
+
415
+ # Init torch_utils.
416
+ sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None
417
+ training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
418
+ if rank != 0:
419
+ custom_ops.verbosity = 'none'
420
+
421
+ # Execute training loop.
422
+ training_loop.training_loop(rank=rank, **args)
423
+
424
+ #----------------------------------------------------------------------------
425
+
426
+ class CommaSeparatedList(click.ParamType):
427
+ name = 'list'
428
+
429
+ def convert(self, value, param, ctx):
430
+ _ = param, ctx
431
+ if value is None or value.lower() == 'none' or value == '':
432
+ return []
433
+ return value.split(',')
434
+
435
+ #----------------------------------------------------------------------------
436
+
437
+ @click.command()
438
+ @click.pass_context
439
+
440
+ # General options.
441
+ @click.option('--outdir', help='Where to save the results', required=True, metavar='DIR')
442
+ @click.option('--gpus', help='Number of GPUs to use [default: 1]', type=int, metavar='INT')
443
+ @click.option('--snap', help='Snapshot interval [default: 50 ticks]', type=int, metavar='INT')
444
+ @click.option('--metrics', help='Comma-separated list or "none" [default: fid50k_full]', type=CommaSeparatedList())
445
+ @click.option('--seed', help='Random seed [default: 0]', type=int, metavar='INT')
446
+ @click.option('-n', '--dry-run', help='Print training options and exit', is_flag=True)
447
+ @click.option('--exp', help='exp id', type=str)
448
+
449
+ # Dataset.
450
+ @click.option('--data', help='Training data (directory or zip)', metavar='PATH', required=True)
451
+ @click.option('--cond', help='Train conditional model based on dataset labels [default: false]', type=bool, metavar='BOOL')
452
+ @click.option('--subset', help='Train with only N images [default: all]', type=int, metavar='INT')
453
+ @click.option('--mirror', help='Enable dataset x-flips [default: true]', type=bool, metavar='BOOL', default=1)
454
+
455
+ # Base config.
456
+ @click.option('--cfg', help='Base config [default: auto]', type=click.Choice(['auto', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar']))
457
+ @click.option('--gamma', help='Override R1 gamma', type=float)
458
+ @click.option('--kimg', help='Override training duration', type=int, metavar='INT')
459
+ @click.option('--batch', help='Override batch size', type=int, metavar='INT')
460
+
461
+ # Discriminator augmentation.
462
+ @click.option('--aug', help='Augmentation mode [default: ada]', type=click.Choice(['noaug', 'ada', 'fixed']))
463
+ @click.option('--daug', help='Diffusion augmentation mode [default: ADA]', type=click.Choice(['NO', 'ADA', 'DIFF']), default='ADA')
464
+ @click.option('--p', help='Augmentation probability for --aug=fixed', type=float)
465
+
466
+ # Adaptive diffusion config.
467
+ @click.option('--beta_schedule', help='Forward diffusion beta schedule (we always use linear)', type=str, default='linear')
468
+ @click.option('--beta_start', help='Forward diffusion process beta_start', type=float, default=1e-4)
469
+ @click.option('--beta_end', help='Forward diffusion process beta_end', type=float, default=2e-2)
470
+ @click.option('--t_min', help='Minimum number of timesteps for adaptive modification', type=int, default=10)
471
+ @click.option('--t_max', help='Maximum number of timesteps for adaptive modification', type=int, default=500)
472
+ @click.option('--noise_sd', help='Diffusion noise standard deviation', type=float, default=0.05)
473
+ @click.option('--ts_dist', help='Sampling distribution for the diffusion timestep t', type=click.Choice(['priority', 'uniform']), default='uniform')
474
+ @click.option('--target', help='Discriminator target value', type=float, default=0.6)
475
+
476
+ # Transfer learning.
477
+ @click.option('--resume', help='Resume training [default: noresume]', metavar='PKL')
478
+ @click.option('--freezed', help='Freeze-D [default: 0 layers]', type=int, metavar='INT')
479
+
+ # Performance options.
+ @click.option('--fp32', help='Disable mixed-precision training', type=bool, metavar='BOOL')
+ @click.option('--nhwc', help='Use NHWC memory format with FP16', type=bool, metavar='BOOL')
+ @click.option('--nobench', help='Disable cuDNN benchmarking', type=bool, metavar='BOOL')
+ @click.option('--allow-tf32', help='Allow PyTorch to use TF32 internally', type=bool, metavar='BOOL')
+ @click.option('--workers', help='Override number of DataLoader workers', type=int, metavar='INT')
+
+ # InsGen related options.
+ @click.option('--no_insgen', help='Disable InsGen and fall back to ADA [default: False]', type=bool, metavar='BOOL')
+ @click.option('--rqs', help='Size of real image queue [default: 5% * len(dataset)]', type=int, metavar='INT')
+ @click.option('--fqs', help='Size of fake image queue [default: 5% * len(dataset)]', type=int, metavar='INT')
+ @click.option('--no_cl_on_g', help='Disable fake instance discrimination for generator [default: False]', type=bool, metavar='BOOL')
+ @click.option('--ada_linear', help='Whether to linearly increase the strength of ADA [default: False]', type=bool, metavar='BOOL')
+
+
+ def main(ctx, outdir, dry_run, **config_kwargs):
+     """Train a GAN using the techniques described in the paper
+     "Training Generative Adversarial Networks with Limited Data".
+
+     Examples:
+
+     \b
+     # Train with custom dataset using 1 GPU.
+     python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1
+
+     \b
+     # Train class-conditional CIFAR-10 using 2 GPUs.
+     python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \\
+       --gpus=2 --cfg=cifar --cond=1
+
+     \b
+     # Transfer learn MetFaces from FFHQ using 4 GPUs.
+     python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \\
+       --gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10
+
+     \b
+     # Reproduce original StyleGAN2 config F.
+     python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \\
+       --gpus=8 --cfg=stylegan2 --mirror=1 --aug=noaug
+
+     \b
+     Base configs (--cfg):
+       auto       Automatically select reasonable defaults based on resolution
+                  and GPU count. Good starting point for new datasets.
+       stylegan2  Reproduce results for StyleGAN2 config F at 1024x1024.
+       paper256   Reproduce results for FFHQ and LSUN Cat at 256x256.
+       paper512   Reproduce results for BreCaHAD and AFHQ at 512x512.
+       paper1024  Reproduce results for MetFaces at 1024x1024.
+       cifar      Reproduce results for CIFAR-10 at 32x32.
+
+     \b
+     Transfer learning source networks (--resume):
+       ffhq256        FFHQ trained at 256x256 resolution.
+       ffhq512        FFHQ trained at 512x512 resolution.
+       ffhq1024       FFHQ trained at 1024x1024 resolution.
+       celebahq256    CelebA-HQ trained at 256x256 resolution.
+       lsundog256     LSUN Dog trained at 256x256 resolution.
+       <PATH or URL>  Custom network pickle.
+     """
+     dnnlib.util.Logger(should_flush=True)
+
+     # Setup training options.
+     try:
+         run_desc, args = setup_training_loop_kwargs(**config_kwargs)
+     except UserError as err:
+         ctx.fail(err)
+
+     # Pick output directory.
+     prev_run_dirs = []
+     if os.path.isdir(outdir):
+         prev_run_dirs = [x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))]
+
+     matching_dirs = [re.fullmatch(r'\d{5}' + f'-{run_desc}', x) for x in prev_run_dirs if
+                      re.fullmatch(r'\d{5}' + f'-{run_desc}', x) is not None]
+     if len(matching_dirs) > 0:  # expect unique desc, continue in this directory
+         assert len(matching_dirs) == 1, f'Multiple directories found for resuming: {matching_dirs}'
+         run_dir = os.path.join(outdir, matching_dirs[0].group())
+     else:  # fallback to standard
+         prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs]
+         prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
+         cur_run_id = max(prev_run_ids, default=-1) + 1
+         run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{run_desc}')
+         assert not os.path.exists(run_dir)
+     args.run_dir = run_dir
+
+     # Print options.
+     print()
+     print('Training options:')
+     print(json.dumps(args, indent=2))
+     print()
+     print(f'Output directory: {args.run_dir}')
+     print(f'Training data: {args.training_set_kwargs.path}')
+     print(f'Training duration: {args.total_kimg} kimg')
+     print(f'Number of GPUs: {args.num_gpus}')
+     print(f'Number of images: {args.training_set_kwargs.max_size}')
+     print(f'Image resolution: {args.training_set_kwargs.resolution}')
+     print(f'Conditional model: {args.training_set_kwargs.use_labels}')
+     print(f'Dataset x-flips: {args.training_set_kwargs.xflip}')
+     print()
+
+     # Dry run?
+     if dry_run:
+         print('Dry run; exiting.')
+         return
+
+     # Create output directory.
+     print('Creating output directory...')
+     os.makedirs(args.run_dir, exist_ok=True)
+     with open(os.path.join(args.run_dir, 'training_options.json'), 'wt') as f:
+         json.dump(args, f, indent=2)
+
+     # Launch processes.
+     print('Launching processes...')
+     torch.multiprocessing.set_start_method('spawn')
+     with tempfile.TemporaryDirectory() as temp_dir:
+         if args.num_gpus == 1:
+             subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
+         else:
+             torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus)
+
+ #----------------------------------------------------------------------------
+
+ if __name__ == "__main__":
+     main() # pylint: disable=no-value-for-parameter
+
+ #----------------------------------------------------------------------------
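Note: the adaptive-diffusion flags above (`--beta_start`, `--beta_end`, `--t_min`, `--t_max`, `--noise_sd`, `--ts_dist`, `--target`) configure the diffusion-based discriminator augmentation from the Diffusion-GAN paper. The sketch below only illustrates how a linear beta schedule and an adaptively bounded maximum timestep T could fit together; the names `make_alphas_bar`, `AdaptiveDiffusion`, `sample_t`, `diffuse`, and `update_T` are hypothetical and do not correspond to functions in this repository.

```python
import numpy as np

def make_alphas_bar(beta_start=1e-4, beta_end=2e-2, t_max=500):
    """Cumulative products of (1 - beta) for a linear beta schedule."""
    betas = np.linspace(beta_start, beta_end, t_max)
    return np.cumprod(1.0 - betas)

class AdaptiveDiffusion:
    """Toy illustration of Diffusion-GAN-style adaptive forward diffusion."""
    def __init__(self, t_min=10, t_max=500, noise_sd=0.05, ts_dist='uniform', target=0.6):
        self.alphas_bar = make_alphas_bar(t_max=t_max)
        self.t_min, self.t_max = t_min, t_max
        self.noise_sd, self.ts_dist, self.target = noise_sd, ts_dist, target
        self.T = t_min  # current maximum timestep; adapted during training

    def sample_t(self, batch_size):
        ts = np.arange(1, self.T + 1)
        if self.ts_dist == 'priority':
            p = ts / ts.sum()                    # larger t sampled more often
        else:
            p = np.full(len(ts), 1.0 / len(ts))  # uniform over 1..T
        return np.random.choice(ts, size=batch_size, p=p)

    def diffuse(self, x):
        """y_t = sqrt(a_bar_t) * x + sqrt(1 - a_bar_t) * noise_sd * eps."""
        t = self.sample_t(len(x))
        a_bar = self.alphas_bar[t - 1].reshape(-1, *([1] * (x.ndim - 1)))
        eps = np.random.randn(*x.shape)
        return np.sqrt(a_bar) * x + np.sqrt(1.0 - a_bar) * self.noise_sd * eps

    def update_T(self, r_d, step=4):
        """Nudge T so the discriminator heuristic r_d stays near --target."""
        self.T = int(np.clip(self.T + step * np.sign(r_d - self.target),
                             self.t_min, self.t_max))
```

In the paper, the statistic passed to the update step is an overfitting heuristic computed from the sign of the discriminator outputs on diffused real images, and T is re-evaluated every few iterations; the step size used here is an arbitrary placeholder.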
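Similarly, `--rqs` and `--fqs` size the real- and fake-image feature queues used by InsGen's instance-discrimination (contrastive) objectives, defaulting to roughly 5% of the dataset. The minimal MoCo-style ring-buffer sketch below shows what such a fixed-size queue of negative keys looks like; `FeatureQueue` is a hypothetical name, not a class from this repository.

```python
import torch
import torch.nn.functional as F

class FeatureQueue:
    """Fixed-size ring buffer of L2-normalized feature vectors (negatives)."""
    def __init__(self, dim, size):
        self.feats = F.normalize(torch.randn(size, dim), dim=1)
        self.ptr = 0

    @torch.no_grad()
    def enqueue(self, keys):
        # Overwrite the oldest entries with the newest encoded features.
        keys = F.normalize(keys, dim=1)
        idx = (self.ptr + torch.arange(keys.shape[0])) % self.feats.shape[0]
        self.feats[idx] = keys
        self.ptr = (self.ptr + keys.shape[0]) % self.feats.shape[0]

# With the defaults above, a 50k-image dataset would get queues of about
# int(0.05 * 50000) = 2500 entries each.
```

Contrastive keys produced by the discriminator's embedding head would typically be pushed into such a queue each iteration and serve as negatives for an InfoNCE-style loss.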
diffusion-insgen/training/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ # empty