unknown commited on
Commit
a6f45eb
1 Parent(s): 375a436

Documentation

Browse files
README.md CHANGED
@@ -203,7 +203,7 @@ python dataset_tool.py --source=~/downloads/afhq/train/wild --dest=~/datasets/af
203
  python dataset_tool.py --source=~/downloads/cifar-10-python.tar.gz --dest=~/datasets/cifar10.zip
204
  ```
205
 
206
- **LSUN**: Download the desired LSUN categories in LMDB format from the [LSUN project page](https://www.yf.io/p/lsun/) and convert to ZIP archive:
207
 
208
  ```.bash
209
  python dataset_tool.py --source=~/downloads/lsun/raw/cat_lmdb --dest=~/datasets/lsuncat200k.zip \
@@ -262,7 +262,7 @@ The training configuration can be further customized with additional command lin
262
  * `--cond=1` enables class-conditional training (requires a dataset with labels).
263
  * `--mirror=1` amplifies the dataset with x-flips. Often beneficial, even with ADA.
264
  * `--resume=ffhq1024 --snap=10` performs transfer learning from FFHQ trained at 1024x1024.
265
- * `--resume=~/training-runs/<NAME>/network-snapshot-<KIMG>.pkl` resumes a previous training run where it left off.
266
  * `--gamma=10` overrides R1 gamma. We recommend trying a couple of different values for each new dataset.
267
  * `--aug=ada --target=0.7` adjusts ADA target value (default: 0.6).
268
  * `--augpipe=blit` enables pixel blitting but disables all other augmentations.
@@ -293,7 +293,7 @@ The total training time depends heavily on resolution, number of GPUs, dataset,
293
  | 1024x1024 | 4 | 11h 36m | 12d 02h | 40.1&ndash;40.8 | 8.4 GB | 21.9 GB
294
  | 1024x1024 | 8 | 5h 54m | 6d 03h | 20.2&ndash;20.6 | 8.3 GB | 44.7 GB
295
 
296
- The above measurements were done using NVIDIA Tesla V100 GPUs with default settings (`--cfg=auto --aug=ada --metrics=fid50k_full`). "sec/kimg" shows the expected range of variation in raw training performance, as reported in `log.txt`, and "GPU mem" and "CPU mem" show the peak memory consumption observed over the course of training.
297
 
298
  In typical cases, 25000 kimg or more is needed to reach convergence, but the results are already quite reasonable around 5000 kimg. 1000 kimg is often enough for transfer learning, which tends to converge significantly faster. The following figure shows example convergence curves for different datasets as a function of wallclock time, using the same settings as above:
299
 
@@ -325,23 +325,23 @@ We employ the following metrics in the ADA paper. Execution time and GPU memory
325
 
326
  | Metric | Time | GPU mem | Description |
327
  | :----- | :----: | :-----: | :---------- |
328
- | `fid50k_full` | 13 min | 1.8 GB | Fr&eacute;chet inception distance<sup>[1]</sup> against the full dataset.
329
- | `kid50k_full` | 13 min | 1.8 GB | Kernel inception distance<sup>[2]</sup> against the full dataset.
330
- | `pr50k3_full` | 13 min | 4.1 GB | Precision and recall<sup>[3]</sup> againt the full dataset.
331
- | `is50k` | 13 min | 1.8 GB | Inception score<sup>[4]</sup> for CIFAR-10.
332
 
333
  In addition, the following metrics from the [StyleGAN](https://github.com/NVlabs/stylegan) and [StyleGAN2](https://github.com/NVlabs/stylegan2) papers are also supported:
334
 
335
  | Metric | Time | GPU mem | Description |
336
  | :------------ | :----: | :-----: | :---------- |
337
- | `fid50k` | 13 min | 1.8 GB | Fr&eacute;chet inception distance against 50k real images.
338
- | `kid50k` | 13 min | 1.8 GB | Kernel inception distance against 50k real images.
339
- | `pr50k3` | 13 min | 4.1 GB | Precision and recall against 50k real images.
340
- | `ppl2_wend` | 36 min | 2.4 GB | Perceptual path length<sup>[5]</sup> in W at path endpoints against full image.
341
- | `ppl_zfull` | 36 min | 2.4 GB | Perceptual path length in Z for full paths against cropped image.
342
- | `ppl_wfull` | 36 min | 2.4 GB | Perceptual path length in W for full paths against cropped image.
343
- | `ppl_zend` | 36 min | 2.4 GB | Perceptual path length in Z at path endpoints against cropped image.
344
- | `ppl_wend` | 36 min | 2.4 GB | Perceptual path length in W at path endpoints against cropped image.
345
 
346
  References:
347
  1. [GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium](https://arxiv.org/abs/1706.08500), Heusel et al. 2017
203
  python dataset_tool.py --source=~/downloads/cifar-10-python.tar.gz --dest=~/datasets/cifar10.zip
204
  ```
205
 
206
+ **LSUN**: Download the desired categories from the [LSUN project page](https://www.yf.io/p/lsun/) and convert to ZIP archive:
207
 
208
  ```.bash
209
  python dataset_tool.py --source=~/downloads/lsun/raw/cat_lmdb --dest=~/datasets/lsuncat200k.zip \
262
  * `--cond=1` enables class-conditional training (requires a dataset with labels).
263
  * `--mirror=1` amplifies the dataset with x-flips. Often beneficial, even with ADA.
264
  * `--resume=ffhq1024 --snap=10` performs transfer learning from FFHQ trained at 1024x1024.
265
+ * `--resume=~/training-runs/<NAME>/network-snapshot-<INT>.pkl` resumes a previous training run.
266
  * `--gamma=10` overrides R1 gamma. We recommend trying a couple of different values for each new dataset.
267
  * `--aug=ada --target=0.7` adjusts ADA target value (default: 0.6).
268
  * `--augpipe=blit` enables pixel blitting but disables all other augmentations.
293
  | 1024x1024 | 4 | 11h 36m | 12d 02h | 40.1&ndash;40.8 | 8.4 GB | 21.9 GB
294
  | 1024x1024 | 8 | 5h 54m | 6d 03h | 20.2&ndash;20.6 | 8.3 GB | 44.7 GB
295
 
296
+ The above measurements were done using NVIDIA Tesla V100 GPUs with default settings (`--cfg=auto --aug=ada --metrics=fid50k_full`). "sec/kimg" shows the expected range of variation in raw training performance, as reported in `log.txt`. "GPU mem" and "CPU mem" show the highest observed memory consumption, excluding the peak at the beginning caused by `torch.backends.cudnn.benchmark`.
297
 
298
  In typical cases, 25000 kimg or more is needed to reach convergence, but the results are already quite reasonable around 5000 kimg. 1000 kimg is often enough for transfer learning, which tends to converge significantly faster. The following figure shows example convergence curves for different datasets as a function of wallclock time, using the same settings as above:
299
 
325
 
326
  | Metric | Time | GPU mem | Description |
327
  | :----- | :----: | :-----: | :---------- |
328
+ | `fid50k_full` | 13 min | 1.8 GB | Fr&eacute;chet inception distance<sup>[1]</sup> against the full dataset
329
+ | `kid50k_full` | 13 min | 1.8 GB | Kernel inception distance<sup>[2]</sup> against the full dataset
330
+ | `pr50k3_full` | 13 min | 4.1 GB | Precision and recall<sup>[3]</sup> againt the full dataset
331
+ | `is50k` | 13 min | 1.8 GB | Inception score<sup>[4]</sup> for CIFAR-10
332
 
333
  In addition, the following metrics from the [StyleGAN](https://github.com/NVlabs/stylegan) and [StyleGAN2](https://github.com/NVlabs/stylegan2) papers are also supported:
334
 
335
  | Metric | Time | GPU mem | Description |
336
  | :------------ | :----: | :-----: | :---------- |
337
+ | `fid50k` | 13 min | 1.8 GB | Fr&eacute;chet inception distance against 50k real images
338
+ | `kid50k` | 13 min | 1.8 GB | Kernel inception distance against 50k real images
339
+ | `pr50k3` | 13 min | 4.1 GB | Precision and recall against 50k real images
340
+ | `ppl2_wend` | 36 min | 2.4 GB | Perceptual path length<sup>[5]</sup> in W, endpoints, full image
341
+ | `ppl_zfull` | 36 min | 2.4 GB | Perceptual path length in Z, full paths, cropped image
342
+ | `ppl_wfull` | 36 min | 2.4 GB | Perceptual path length in W, full paths, cropped image
343
+ | `ppl_zend` | 36 min | 2.4 GB | Perceptual path length in Z, endpoints, cropped image
344
+ | `ppl_wend` | 36 min | 2.4 GB | Perceptual path length in W, endpoints, cropped image
345
 
346
  References:
347
  1. [GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium](https://arxiv.org/abs/1706.08500), Heusel et al. 2017
dnnlib/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
  #
3
  # NVIDIA CORPORATION and its licensors retain all intellectual property
4
  # and proprietary rights in and to this software, related documentation
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
  #
3
  # NVIDIA CORPORATION and its licensors retain all intellectual property
4
  # and proprietary rights in and to this software, related documentation
docs/dataset-tool-help.txt CHANGED
@@ -1,50 +1,50 @@
1
- Usage: dataset_tool.py [OPTIONS]
2
-
3
- Convert an image dataset into a dataset archive usable with StyleGAN2 ADA
4
- PyTorch.
5
-
6
- The input dataset format is guessed from the --source argument:
7
-
8
- --source *_lmdb/ - Load LSUN dataset
9
- --source cifar-10-python.tar.gz - Load CIFAR-10 dataset
10
- --source path/ - Recursively load all images from path/
11
- --source dataset.zip - Recursively load all images from dataset.zip
12
-
13
- The output dataset format can be either an image folder or a zip archive.
14
- Specifying the output format and path:
15
-
16
- --dest /path/to/dir - Save output files under /path/to/dir
17
- --dest /path/to/dataset.zip - Save output files into /path/to/dataset.zip archive
18
-
19
- Images within the dataset archive will be stored as uncompressed PNG.
20
-
21
- Image scale/crop and resolution requirements:
22
-
23
- Output images must be square-shaped and they must all have the same power-
24
- of-two dimensions.
25
-
26
- To scale arbitrary input image size to a specific width and height, use
27
- the --width and --height options. Output resolution will be either the
28
- original input resolution (if --width/--height was not specified) or the
29
- one specified with --width/height.
30
-
31
- Use the --transform=center-crop or --transform=center-crop-wide options to
32
- apply a center crop transform on the input image. These options should be
33
- used with the --width and --height options. For example:
34
-
35
- python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \
36
- --transform=center-crop-wide --width 512 --height=384
37
-
38
- Options:
39
- --source PATH Directory or archive name for input dataset
40
- [required]
41
- --dest PATH Output directory or archive name for output
42
- dataset [required]
43
- --max-images INTEGER Output only up to `max-images` images
44
- --resize-filter [box|lanczos] Filter to use when resizing images for
45
- output resolution [default: lanczos]
46
- --transform [center-crop|center-crop-wide]
47
- Input crop/resize mode
48
- --width INTEGER Output width
49
- --height INTEGER Output height
50
- --help Show this message and exit.
1
+ Usage: dataset_tool.py [OPTIONS]
2
+
3
+ Convert an image dataset into a dataset archive usable with StyleGAN2 ADA
4
+ PyTorch.
5
+
6
+ The input dataset format is guessed from the --source argument:
7
+
8
+ --source *_lmdb/ - Load LSUN dataset
9
+ --source cifar-10-python.tar.gz - Load CIFAR-10 dataset
10
+ --source path/ - Recursively load all images from path/
11
+ --source dataset.zip - Recursively load all images from dataset.zip
12
+
13
+ The output dataset format can be either an image folder or a zip archive.
14
+ Specifying the output format and path:
15
+
16
+ --dest /path/to/dir - Save output files under /path/to/dir
17
+ --dest /path/to/dataset.zip - Save output files into /path/to/dataset.zip archive
18
+
19
+ Images within the dataset archive will be stored as uncompressed PNG.
20
+
21
+ Image scale/crop and resolution requirements:
22
+
23
+ Output images must be square-shaped and they must all have the same power-
24
+ of-two dimensions.
25
+
26
+ To scale arbitrary input image size to a specific width and height, use
27
+ the --width and --height options. Output resolution will be either the
28
+ original input resolution (if --width/--height was not specified) or the
29
+ one specified with --width/height.
30
+
31
+ Use the --transform=center-crop or --transform=center-crop-wide options to
32
+ apply a center crop transform on the input image. These options should be
33
+ used with the --width and --height options. For example:
34
+
35
+ python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \
36
+ --transform=center-crop-wide --width 512 --height=384
37
+
38
+ Options:
39
+ --source PATH Directory or archive name for input dataset
40
+ [required]
41
+ --dest PATH Output directory or archive name for output
42
+ dataset [required]
43
+ --max-images INTEGER Output only up to `max-images` images
44
+ --resize-filter [box|lanczos] Filter to use when resizing images for
45
+ output resolution [default: lanczos]
46
+ --transform [center-crop|center-crop-wide]
47
+ Input crop/resize mode
48
+ --width INTEGER Output width
49
+ --height INTEGER Output height
50
+ --help Show this message and exit.
docs/train-help.txt CHANGED
@@ -1,69 +1,69 @@
1
- Usage: train.py [OPTIONS]
2
-
3
- Train a GAN using the techniques described in the paper "Training
4
- Generative Adversarial Networks with Limited Data".
5
-
6
- Examples:
7
-
8
- # Train with custom images using 1 GPU.
9
- python train.py --outdir=~/training-runs --data=~/my-image-folder
10
-
11
- # Train class-conditional CIFAR-10 using 2 GPUs.
12
- python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \
13
- --gpus=2 --cfg=cifar --cond=1
14
-
15
- # Transfer learn MetFaces from FFHQ using 4 GPUs.
16
- python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \
17
- --gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10
18
-
19
- # Reproduce original StyleGAN2 config F.
20
- python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \
21
- --gpus=8 --cfg=stylegan2 --mirror=1 --aug=noaug
22
-
23
- Base configs (--cfg):
24
- auto Automatically select reasonable defaults based on resolution
25
- and GPU count. Good starting point for new datasets.
26
- stylegan2 Reproduce results for StyleGAN2 config F at 1024x1024.
27
- paper256 Reproduce results for FFHQ and LSUN Cat at 256x256.
28
- paper512 Reproduce results for BreCaHAD and AFHQ at 512x512.
29
- paper1024 Reproduce results for MetFaces at 1024x1024.
30
- cifar Reproduce results for CIFAR-10 at 32x32.
31
-
32
- Transfer learning source networks (--resume):
33
- ffhq256 FFHQ trained at 256x256 resolution.
34
- ffhq512 FFHQ trained at 512x512 resolution.
35
- ffhq1024 FFHQ trained at 1024x1024 resolution.
36
- celebahq256 CelebA-HQ trained at 256x256 resolution.
37
- lsundog256 LSUN Dog trained at 256x256 resolution.
38
- <PATH or URL> Custom network pickle.
39
-
40
- Options:
41
- --outdir DIR Where to save the results [required]
42
- --gpus INT Number of GPUs to use [default: 1]
43
- --snap INT Snapshot interval [default: 50 ticks]
44
- --metrics LIST Comma-separated list or "none" [default:
45
- fid50k_full]
46
- --seed INT Random seed [default: 0]
47
- -n, --dry-run Print training options and exit
48
- --data PATH Training data (directory or zip) [required]
49
- --cond BOOL Train conditional model based on dataset
50
- labels [default: false]
51
- --subset INT Train with only N images [default: all]
52
- --mirror BOOL Enable dataset x-flips [default: false]
53
- --cfg [auto|stylegan2|paper256|paper512|paper1024|cifar]
54
- Base config [default: auto]
55
- --gamma FLOAT Override R1 gamma
56
- --kimg INT Override training duration
57
- --batch INT Override batch size
58
- --aug [noaug|ada|fixed] Augmentation mode [default: ada]
59
- --p FLOAT Augmentation probability for --aug=fixed
60
- --target FLOAT ADA target value for --aug=ada
61
- --augpipe [blit|geom|color|filter|noise|cutout|bg|bgc|bgcf|bgcfn|bgcfnc]
62
- Augmentation pipeline [default: bgc]
63
- --resume PKL Resume training [default: noresume]
64
- --freezed INT Freeze-D [default: 0 layers]
65
- --fp32 BOOL Disable mixed-precision training
66
- --nhwc BOOL Use NHWC memory format with FP16
67
- --nobench BOOL Disable cuDNN benchmarking
68
- --workers INT Override number of DataLoader workers
69
- --help Show this message and exit.
1
+ Usage: train.py [OPTIONS]
2
+
3
+ Train a GAN using the techniques described in the paper "Training
4
+ Generative Adversarial Networks with Limited Data".
5
+
6
+ Examples:
7
+
8
+ # Train with custom images using 1 GPU.
9
+ python train.py --outdir=~/training-runs --data=~/my-image-folder
10
+
11
+ # Train class-conditional CIFAR-10 using 2 GPUs.
12
+ python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \
13
+ --gpus=2 --cfg=cifar --cond=1
14
+
15
+ # Transfer learn MetFaces from FFHQ using 4 GPUs.
16
+ python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \
17
+ --gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10
18
+
19
+ # Reproduce original StyleGAN2 config F.
20
+ python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \
21
+ --gpus=8 --cfg=stylegan2 --mirror=1 --aug=noaug
22
+
23
+ Base configs (--cfg):
24
+ auto Automatically select reasonable defaults based on resolution
25
+ and GPU count. Good starting point for new datasets.
26
+ stylegan2 Reproduce results for StyleGAN2 config F at 1024x1024.
27
+ paper256 Reproduce results for FFHQ and LSUN Cat at 256x256.
28
+ paper512 Reproduce results for BreCaHAD and AFHQ at 512x512.
29
+ paper1024 Reproduce results for MetFaces at 1024x1024.
30
+ cifar Reproduce results for CIFAR-10 at 32x32.
31
+
32
+ Transfer learning source networks (--resume):
33
+ ffhq256 FFHQ trained at 256x256 resolution.
34
+ ffhq512 FFHQ trained at 512x512 resolution.
35
+ ffhq1024 FFHQ trained at 1024x1024 resolution.
36
+ celebahq256 CelebA-HQ trained at 256x256 resolution.
37
+ lsundog256 LSUN Dog trained at 256x256 resolution.
38
+ <PATH or URL> Custom network pickle.
39
+
40
+ Options:
41
+ --outdir DIR Where to save the results [required]
42
+ --gpus INT Number of GPUs to use [default: 1]
43
+ --snap INT Snapshot interval [default: 50 ticks]
44
+ --metrics LIST Comma-separated list or "none" [default:
45
+ fid50k_full]
46
+ --seed INT Random seed [default: 0]
47
+ -n, --dry-run Print training options and exit
48
+ --data PATH Training data (directory or zip) [required]
49
+ --cond BOOL Train conditional model based on dataset
50
+ labels [default: false]
51
+ --subset INT Train with only N images [default: all]
52
+ --mirror BOOL Enable dataset x-flips [default: false]
53
+ --cfg [auto|stylegan2|paper256|paper512|paper1024|cifar]
54
+ Base config [default: auto]
55
+ --gamma FLOAT Override R1 gamma
56
+ --kimg INT Override training duration
57
+ --batch INT Override batch size
58
+ --aug [noaug|ada|fixed] Augmentation mode [default: ada]
59
+ --p FLOAT Augmentation probability for --aug=fixed
60
+ --target FLOAT ADA target value for --aug=ada
61
+ --augpipe [blit|geom|color|filter|noise|cutout|bg|bgc|bgcf|bgcfn|bgcfnc]
62
+ Augmentation pipeline [default: bgc]
63
+ --resume PKL Resume training [default: noresume]
64
+ --freezed INT Freeze-D [default: 0 layers]
65
+ --fp32 BOOL Disable mixed-precision training
66
+ --nhwc BOOL Use NHWC memory format with FP16
67
+ --nobench BOOL Disable cuDNN benchmarking
68
+ --workers INT Override number of DataLoader workers
69
+ --help Show this message and exit.
metrics/frechet_inception_distance.py CHANGED
@@ -6,16 +6,21 @@
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
 
 
 
 
 
 
9
  import numpy as np
10
  import scipy.linalg
11
-
12
  from . import metric_utils
13
 
14
  #----------------------------------------------------------------------------
15
 
16
  def compute_fid(opts, max_real, num_gen):
 
17
  detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
18
- detector_kwargs = dict(return_features=True)
19
 
20
  mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
21
  opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
 
9
+ """Frechet Inception Distance (FID) from the paper
10
+ "GANs trained by a two time-scale update rule converge to a local Nash
11
+ equilibrium". Matches the original implementation by Heusel et al. at
12
+ https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
13
+
14
  import numpy as np
15
  import scipy.linalg
 
16
  from . import metric_utils
17
 
18
  #----------------------------------------------------------------------------
19
 
20
  def compute_fid(opts, max_real, num_gen):
21
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
22
  detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
23
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
24
 
25
  mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
26
  opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
metrics/inception_score.py CHANGED
@@ -6,15 +6,19 @@
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
 
9
- import numpy as np
 
 
10
 
 
11
  from . import metric_utils
12
 
13
  #----------------------------------------------------------------------------
14
 
15
  def compute_is(opts, num_gen, num_splits):
 
16
  detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
17
- detector_kwargs = dict(no_output_bias=True)
18
 
19
  gen_probs = metric_utils.compute_feature_stats_for_generator(
20
  opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
 
9
+ """Inception Score (IS) from the paper "Improved techniques for training
10
+ GANs". Matches the original implementation by Salimans et al. at
11
+ https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
12
 
13
+ import numpy as np
14
  from . import metric_utils
15
 
16
  #----------------------------------------------------------------------------
17
 
18
  def compute_is(opts, num_gen, num_splits):
19
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20
  detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
21
+ detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
22
 
23
  gen_probs = metric_utils.compute_feature_stats_for_generator(
24
  opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
metrics/kernel_inception_distance.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
  #
3
  # NVIDIA CORPORATION and its licensors retain all intellectual property
4
  # and proprietary rights in and to this software, related documentation
@@ -6,15 +6,19 @@
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
 
9
- import numpy as np
 
 
10
 
 
11
  from . import metric_utils
12
 
13
  #----------------------------------------------------------------------------
14
 
15
  def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
 
16
  detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
17
- detector_kwargs = dict(return_features=True)
18
 
19
  real_features = metric_utils.compute_feature_stats_for_dataset(
20
  opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
  #
3
  # NVIDIA CORPORATION and its licensors retain all intellectual property
4
  # and proprietary rights in and to this software, related documentation
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
 
9
+ """Kernel Inception Distance (KID) from the paper "Demystifying MMD
10
+ GANs". Matches the original implementation by Binkowski et al. at
11
+ https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
12
 
13
+ import numpy as np
14
  from . import metric_utils
15
 
16
  #----------------------------------------------------------------------------
17
 
18
  def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
19
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
20
  detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
21
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
22
 
23
  real_features = metric_utils.compute_feature_stats_for_dataset(
24
  opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
metrics/perceptual_path_length.py CHANGED
@@ -6,11 +6,15 @@
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
 
 
 
 
 
 
9
  import copy
10
  import numpy as np
11
  import torch
12
  import dnnlib
13
-
14
  from . import metric_utils
15
 
16
  #----------------------------------------------------------------------------
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
 
9
+ """Perceptual Path Length (PPL) from the paper "A Style-Based Generator
10
+ Architecture for Generative Adversarial Networks". Matches the original
11
+ implementation by Karras et al. at
12
+ https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
13
+
14
  import copy
15
  import numpy as np
16
  import torch
17
  import dnnlib
 
18
  from . import metric_utils
19
 
20
  #----------------------------------------------------------------------------
metrics/precision_recall.py CHANGED
@@ -6,8 +6,12 @@
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
 
9
- import torch
 
 
 
10
 
 
11
  from . import metric_utils
12
 
13
  #----------------------------------------------------------------------------
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
 
9
+ """Precision/Recall (PR) from the paper "Improved Precision and Recall
10
+ Metric for Assessing Generative Models". Matches the original implementation
11
+ by Kynkaanniemi et al. at
12
+ https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
13
 
14
+ import torch
15
  from . import metric_utils
16
 
17
  #----------------------------------------------------------------------------
torch_utils/persistence.py CHANGED
@@ -6,6 +6,13 @@
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
 
 
 
 
 
 
 
 
9
  import sys
10
  import pickle
11
  import io
@@ -17,29 +24,70 @@ import dnnlib
17
 
18
  #----------------------------------------------------------------------------
19
 
20
- _version = 6
21
- _decorators = set() # {decorator_class}
22
- _import_hooks = [] # [function]
23
- _module_to_src_dict = dict() # {module: src}
24
- _src_to_module_dict = dict() # {src: module}
25
-
26
- #----------------------------------------------------------------------------
27
-
28
- def is_persistent(obj):
29
- try:
30
- if obj in _decorators:
31
- return True
32
- except TypeError:
33
- pass
34
- return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
35
-
36
- def import_hook(func):
37
- assert callable(func)
38
- _import_hooks.append(func)
39
 
40
  #----------------------------------------------------------------------------
41
 
42
  def persistent_class(orig_class):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  assert isinstance(orig_class, type)
44
  if is_persistent(orig_class):
45
  return orig_class
@@ -83,7 +131,55 @@ def persistent_class(orig_class):
83
 
84
  #----------------------------------------------------------------------------
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  def _reconstruct_persistent_obj(meta):
 
 
 
87
  meta = dnnlib.EasyDict(meta)
88
  meta.state = dnnlib.EasyDict(meta.state)
89
  for hook in _import_hooks:
@@ -108,6 +204,8 @@ def _reconstruct_persistent_obj(meta):
108
  #----------------------------------------------------------------------------
109
 
110
  def _module_to_src(module):
 
 
111
  src = _module_to_src_dict.get(module, None)
112
  if src is None:
113
  src = inspect.getsource(module)
@@ -116,6 +214,8 @@ def _module_to_src(module):
116
  return src
117
 
118
  def _src_to_module(src):
 
 
119
  module = _src_to_module_dict.get(src, None)
120
  if module is None:
121
  module_name = "_imported_module_" + uuid.uuid4().hex
@@ -129,15 +229,19 @@ def _src_to_module(src):
129
  #----------------------------------------------------------------------------
130
 
131
  def _check_pickleable(obj):
 
 
 
 
132
  def recurse(obj):
133
  if isinstance(obj, (list, tuple, set)):
134
  return [recurse(x) for x in obj]
135
  if isinstance(obj, dict):
136
  return [[recurse(x), recurse(y)] for x, y in obj.items()]
137
  if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
138
- return None # Primitive types are pickleable.
139
  if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
140
- return None # Tensors are pickleable.
141
  if is_persistent(obj):
142
  return None # Persistent objects are pickleable, by virtue of the constructor check.
143
  return obj
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
 
9
+ """Facilities for pickling Python code alongside other data.
10
+
11
+ The pickled code is automatically imported into a separate Python module
12
+ during unpickling. This way, any previously exported pickles will remain
13
+ usable even if the original code is no longer available, or if the current
14
+ version of the code is not consistent with what was originally pickled."""
15
+
16
  import sys
17
  import pickle
18
  import io
24
 
25
  #----------------------------------------------------------------------------
26
 
27
+ _version = 6 # internal version number
28
+ _decorators = set() # {decorator_class, ...}
29
+ _import_hooks = [] # [hook_function, ...]
30
+ _module_to_src_dict = dict() # {module: src, ...}
31
+ _src_to_module_dict = dict() # {src: module, ...}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  #----------------------------------------------------------------------------
34
 
35
  def persistent_class(orig_class):
36
+ r"""Class decorator that extends a given class to save its source code
37
+ when pickled.
38
+
39
+ Example:
40
+
41
+ from torch_utils import persistence
42
+
43
+ @persistence.persistent_class
44
+ class MyNetwork(torch.nn.Module):
45
+ def __init__(self, num_inputs, num_outputs):
46
+ super().__init__()
47
+ self.fc = MyLayer(num_inputs, num_outputs)
48
+ ...
49
+
50
+ @persistence.persistent_class
51
+ class MyLayer(torch.nn.Module):
52
+ ...
53
+
54
+ When pickled, any instance of `MyNetwork` and `MyLayer` will save its
55
+ source code alongside other internal state (e.g., parameters, buffers,
56
+ and submodules). This way, any previously exported pickle will remain
57
+ usable even if the class definitions have been modified or are no
58
+ longer available.
59
+
60
+ The decorator saves the source code of the entire Python module
61
+ containing the decorated class. It does *not* save the source code of
62
+ any imported modules. Thus, the imported modules must be available
63
+ during unpickling, also including `torch_utils.persistence` itself.
64
+
65
+ It is ok to call functions defined in the same module from the
66
+ decorated class. However, if the decorated class depends on other
67
+ classes defined in the same module, they must be decorated as well.
68
+ This is illustrated in the above example in the case of `MyLayer`.
69
+
70
+ It is also possible to employ the decorator just-in-time before
71
+ calling the constructor. For example:
72
+
73
+ cls = MyLayer
74
+ if want_to_make_it_persistent:
75
+ cls = persistence.persistent_class(cls)
76
+ layer = cls(num_inputs, num_outputs)
77
+
78
+ As an additional feature, the decorator also keeps track of the
79
+ arguments that were used to construct each instance of the decorated
80
+ class. The arguments can be queried via `obj.init_args` and
81
+ `obj.init_kwargs`, and they are automatically pickled alongside other
82
+ object state. A typical use case is to first unpickle a previous
83
+ instance of a persistent class, and then upgrade it to use the latest
84
+ version of the source code:
85
+
86
+ with open('old_pickle.pkl', 'rb') as f:
87
+ old_net = pickle.load(f)
88
+ new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
89
+ misc.copy_params_and_buffers(old_net, new_net, require_all=True)
90
+ """
91
  assert isinstance(orig_class, type)
92
  if is_persistent(orig_class):
93
  return orig_class
131
 
132
  #----------------------------------------------------------------------------
133
 
134
+ def is_persistent(obj):
135
+ r"""Test whether the given object or class is persistent, i.e.,
136
+ whether it will save its source code when pickled.
137
+ """
138
+ try:
139
+ if obj in _decorators:
140
+ return True
141
+ except TypeError:
142
+ pass
143
+ return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
144
+
145
+ #----------------------------------------------------------------------------
146
+
147
+ def import_hook(hook):
148
+ r"""Register an import hook that is called whenever a persistent object
149
+ is being unpickled. A typical use case is to patch the pickled source
150
+ code to avoid errors and inconsistencies when the API of some imported
151
+ module has changed.
152
+
153
+ The hook should have the following signature:
154
+
155
+ hook(meta) -> modified meta
156
+
157
+ `meta` is an instance of `dnnlib.EasyDict` with the following fields:
158
+
159
+ type: Type of the persistent object, e.g. `'class'`.
160
+ version: Internal version number of `torch_utils.persistence`.
161
+ module_src Original source code of the Python module.
162
+ class_name: Class name in the original Python module.
163
+ state: Internal state of the object.
164
+
165
+ Example:
166
+
167
+ @persistence.import_hook
168
+ def wreck_my_network(meta):
169
+ if meta.class_name == 'MyNetwork':
170
+ print('MyNetwork is being imported. I will wreck it!')
171
+ meta.module_src = meta.module_src.replace("True", "False")
172
+ return meta
173
+ """
174
+ assert callable(hook)
175
+ _import_hooks.append(hook)
176
+
177
+ #----------------------------------------------------------------------------
178
+
179
  def _reconstruct_persistent_obj(meta):
180
+ r"""Hook that is called internally by the `pickle` module to unpickle
181
+ a persistent object.
182
+ """
183
  meta = dnnlib.EasyDict(meta)
184
  meta.state = dnnlib.EasyDict(meta.state)
185
  for hook in _import_hooks:
204
  #----------------------------------------------------------------------------
205
 
206
  def _module_to_src(module):
207
+ r"""Query the source code of a given Python module.
208
+ """
209
  src = _module_to_src_dict.get(module, None)
210
  if src is None:
211
  src = inspect.getsource(module)
214
  return src
215
 
216
  def _src_to_module(src):
217
+ r"""Get or create a Python module for the given source code.
218
+ """
219
  module = _src_to_module_dict.get(src, None)
220
  if module is None:
221
  module_name = "_imported_module_" + uuid.uuid4().hex
229
  #----------------------------------------------------------------------------
230
 
231
  def _check_pickleable(obj):
232
+ r"""Check that the given object is pickleable, raising an exception if
233
+ it is not. This function is expected to be considerably more efficient
234
+ than actually pickling the object.
235
+ """
236
  def recurse(obj):
237
  if isinstance(obj, (list, tuple, set)):
238
  return [recurse(x) for x in obj]
239
  if isinstance(obj, dict):
240
  return [[recurse(x), recurse(y)] for x, y in obj.items()]
241
  if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
242
+ return None # Python primitive types are pickleable.
243
  if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
244
+ return None # NumPy arrays and PyTorch tensors are pickleable.
245
  if is_persistent(obj):
246
  return None # Persistent objects are pickleable, by virtue of the constructor check.
247
  return obj
torch_utils/training_stats.py CHANGED
@@ -6,6 +6,11 @@
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
 
 
 
 
 
 
9
  import re
10
  import numpy as np
11
  import torch
@@ -15,19 +20,31 @@ from . import misc
15
 
16
  #----------------------------------------------------------------------------
17
 
18
- _num_moments = 3 # [num_scalars, sum_scalars, sum_squares]
19
  _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
20
- _counter_dtype = torch.float64 # Data type to use for the counters.
21
-
22
  _rank = 0 # Rank of the current process.
23
  _sync_device = None # Device to use for multiprocess communication. None = single-process.
24
  _sync_called = False # Has _sync() been called yet?
25
- _counters = dict() # Running counter on each device, updated by report(): name => device => torch.Tensor
26
- _cumulative = dict() # Cumulative counter on the CPU, updated by _sync(): name => torch.Tensor
27
 
28
  #----------------------------------------------------------------------------
29
 
30
  def init_multiprocessing(rank, sync_device):
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  global _rank, _sync_device
32
  assert not _sync_called
33
  _rank = rank
@@ -37,6 +54,28 @@ def init_multiprocessing(rank, sync_device):
37
 
38
  @misc.profiled_function
39
  def report(name, value):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  if name not in _counters:
41
  _counters[name] = dict()
42
 
@@ -45,7 +84,11 @@ def report(name, value):
45
  return value
46
 
47
  elems = elems.detach().flatten().to(_reduce_dtype)
48
- moments = torch.stack([torch.ones_like(elems).sum(), elems.sum(), elems.square().sum()])
 
 
 
 
49
  assert moments.ndim == 1 and moments.shape[0] == _num_moments
50
  moments = moments.to(_counter_dtype)
51
 
@@ -58,45 +101,35 @@ def report(name, value):
58
  #----------------------------------------------------------------------------
59
 
60
  def report0(name, value):
 
 
 
 
61
  report(name, value if _rank == 0 else [])
62
  return value
63
 
64
  #----------------------------------------------------------------------------
65
 
66
- def _sync(names):
67
- if len(names) == 0:
68
- return []
69
- global _sync_called
70
- _sync_called = True
71
-
72
- # Collect deltas within current rank.
73
- deltas = []
74
- device = _sync_device if _sync_device is not None else torch.device('cpu')
75
- for name in names:
76
- delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
77
- for counter in _counters[name].values():
78
- delta.add_(counter.to(device))
79
- counter.copy_(torch.zeros_like(counter))
80
- deltas.append(delta)
81
- deltas = torch.stack(deltas)
82
-
83
- # Sum deltas across ranks.
84
- if _sync_device is not None:
85
- torch.distributed.all_reduce(deltas)
86
-
87
- # Update cumulative values.
88
- deltas = deltas.cpu()
89
- for idx, name in enumerate(names):
90
- if name not in _cumulative:
91
- _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
92
- _cumulative[name].add_(deltas[idx])
93
 
94
- # Return name-value pairs.
95
- return [(name, _cumulative[name]) for name in names]
 
 
 
 
 
96
 
97
- #----------------------------------------------------------------------------
98
-
99
- class Collector:
 
 
 
 
100
  def __init__(self, regex='.*', keep_previous=True):
101
  self._regex = re.compile(regex)
102
  self._keep_previous = keep_previous
@@ -106,9 +139,24 @@ class Collector:
106
  self._moments.clear()
107
 
108
  def names(self):
 
 
 
109
  return [name for name in _counters if self._regex.fullmatch(name)]
110
 
111
  def update(self):
 
 
 
 
 
 
 
 
 
 
 
 
112
  if not self._keep_previous:
113
  self._moments.clear()
114
  for name, cumulative in _sync(self.names()):
@@ -120,22 +168,38 @@ class Collector:
120
  self._moments[name] = delta
121
 
122
  def _get_delta(self, name):
 
 
 
 
123
  assert self._regex.fullmatch(name)
124
  if name not in self._moments:
125
  self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
126
  return self._moments[name]
127
 
128
  def num(self, name):
 
 
 
 
129
  delta = self._get_delta(name)
130
  return int(delta[0])
131
 
132
  def mean(self, name):
 
 
 
 
133
  delta = self._get_delta(name)
134
  if int(delta[0]) == 0:
135
  return float('nan')
136
  return float(delta[1] / delta[0])
137
 
138
  def std(self, name):
 
 
 
 
139
  delta = self._get_delta(name)
140
  if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
141
  return float('nan')
@@ -146,12 +210,59 @@ class Collector:
146
  return np.sqrt(max(raw_var - np.square(mean), 0))
147
 
148
  def as_dict(self):
 
 
 
 
 
 
 
 
149
  stats = dnnlib.EasyDict()
150
  for name in self.names():
151
  stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
152
  return stats
153
 
154
  def __getitem__(self, name):
 
 
 
155
  return self.mean(name)
156
 
157
  #----------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  # distribution of this software and related documentation without an express
7
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
 
9
+ """Facilities for reporting and collecting training statistics across
10
+ multiple processes and devices. The interface is designed to minimize
11
+ synchronization overhead as well as the amount of boilerplate in user
12
+ code."""
13
+
14
  import re
15
  import numpy as np
16
  import torch
20
 
21
  #----------------------------------------------------------------------------
22
 
23
+ _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
24
  _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
25
+ _counter_dtype = torch.float64 # Data type to use for the internal counters.
 
26
  _rank = 0 # Rank of the current process.
27
  _sync_device = None # Device to use for multiprocess communication. None = single-process.
28
  _sync_called = False # Has _sync() been called yet?
29
+ _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
30
+ _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
31
 
32
  #----------------------------------------------------------------------------
33
 
34
  def init_multiprocessing(rank, sync_device):
35
+ r"""Initializes `torch_utils.training_stats` for collecting statistics
36
+ across multiple processes.
37
+
38
+ This function must be called after
39
+ `torch.distributed.init_process_group()` and before `Collector.update()`.
40
+ The call is not necessary if multi-process collection is not needed.
41
+
42
+ Args:
43
+ rank: Rank of the current process.
44
+ sync_device: PyTorch device to use for inter-process
45
+ communication, or None to disable multi-process
46
+ collection. Typically `torch.device('cuda', rank)`.
47
+ """
48
  global _rank, _sync_device
49
  assert not _sync_called
50
  _rank = rank
54
 
55
  @misc.profiled_function
56
  def report(name, value):
57
+ r"""Broadcasts the given set of scalars to all interested instances of
58
+ `Collector`, across device and process boundaries.
59
+
60
+ This function is expected to be extremely cheap and can be safely
61
+ called from anywhere in the training loop, loss function, or inside a
62
+ `torch.nn.Module`.
63
+
64
+ Warning: The current implementation expects the set of unique names to
65
+ be consistent across processes. Please make sure that `report()` is
66
+ called at least once for each unique name by each process, and in the
67
+ same order. If a given process has no scalars to broadcast, it can do
68
+ `report(name, [])` (empty list).
69
+
70
+ Args:
71
+ name: Arbitrary string specifying the name of the statistic.
72
+ Averages are accumulated separately for each unique name.
73
+ value: Arbitrary set of scalars. Can be a list, tuple,
74
+ NumPy array, PyTorch tensor, or Python scalar.
75
+
76
+ Returns:
77
+ The same `value` that was passed in.
78
+ """
79
  if name not in _counters:
80
  _counters[name] = dict()
81
 
84
  return value
85
 
86
  elems = elems.detach().flatten().to(_reduce_dtype)
87
+ moments = torch.stack([
88
+ torch.ones_like(elems).sum(),
89
+ elems.sum(),
90
+ elems.square().sum(),
91
+ ])
92
  assert moments.ndim == 1 and moments.shape[0] == _num_moments
93
  moments = moments.to(_counter_dtype)
94
 
101
  #----------------------------------------------------------------------------
102
 
103
  def report0(name, value):
104
+ r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
105
+ but ignores any scalars provided by the other processes.
106
+ See `report()` for further details.
107
+ """
108
  report(name, value if _rank == 0 else [])
109
  return value
110
 
111
  #----------------------------------------------------------------------------
112
 
113
+ class Collector:
114
+ r"""Collects the scalars broadcasted by `report()` and `report0()` and
115
+ computes their long-term averages (mean and standard deviation) over
116
+ user-defined periods of time.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
+ The averages are first collected into internal counters that are not
119
+ directly visible to the user. They are then copied to the user-visible
120
+ state as a result of calling `update()` and can then be queried using
121
+ `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
122
+ internal counters for the next round, so that the user-visible state
123
+ effectively reflects averages collected between the last two calls to
124
+ `update()`.
125
 
126
+ Args:
127
+ regex: Regular expression defining which statistics to
128
+ collect. The default is to collect everything.
129
+ keep_previous: Whether to retain the previous averages if no
130
+ scalars were collected on a given round
131
+ (default: True).
132
+ """
133
  def __init__(self, regex='.*', keep_previous=True):
134
  self._regex = re.compile(regex)
135
  self._keep_previous = keep_previous
139
  self._moments.clear()
140
 
141
  def names(self):
142
+ r"""Returns the names of all statistics broadcasted so far that
143
+ match the regular expression specified at construction time.
144
+ """
145
  return [name for name in _counters if self._regex.fullmatch(name)]
146
 
147
  def update(self):
148
+ r"""Copies current values of the internal counters to the
149
+ user-visible state and resets them for the next round.
150
+
151
+ If `keep_previous=True` was specified at construction time, the
152
+ operation is skipped for statistics that have received no scalars
153
+ since the last update, retaining their previous averages.
154
+
155
+ This method performs a number of GPU-to-CPU transfers and one
156
+ `torch.distributed.all_reduce()`. It is intended to be called
157
+ periodically in the main training loop, typically once every
158
+ N training steps.
159
+ """
160
  if not self._keep_previous:
161
  self._moments.clear()
162
  for name, cumulative in _sync(self.names()):
168
  self._moments[name] = delta
169
 
170
  def _get_delta(self, name):
171
+ r"""Returns the raw moments that were accumulated for the given
172
+ statistic between the last two calls to `update()`, or zero if
173
+ no scalars were collected.
174
+ """
175
  assert self._regex.fullmatch(name)
176
  if name not in self._moments:
177
  self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
178
  return self._moments[name]
179
 
180
  def num(self, name):
181
+ r"""Returns the number of scalars that were accumulated for the given
182
+ statistic between the last two calls to `update()`, or zero if
183
+ no scalars were collected.
184
+ """
185
  delta = self._get_delta(name)
186
  return int(delta[0])
187
 
188
  def mean(self, name):
189
+ r"""Returns the mean of the scalars that were accumulated for the
190
+ given statistic between the last two calls to `update()`, or NaN if
191
+ no scalars were collected.
192
+ """
193
  delta = self._get_delta(name)
194
  if int(delta[0]) == 0:
195
  return float('nan')
196
  return float(delta[1] / delta[0])
197
 
198
  def std(self, name):
199
+ r"""Returns the standard deviation of the scalars that were
200
+ accumulated for the given statistic between the last two calls to
201
+ `update()`, or NaN if no scalars were collected.
202
+ """
203
  delta = self._get_delta(name)
204
  if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
205
  return float('nan')
210
  return np.sqrt(max(raw_var - np.square(mean), 0))
211
 
212
  def as_dict(self):
213
+ r"""Returns the averages accumulated between the last two calls to
214
+ `update()` as an `dnnlib.EasyDict`. The contents are as follows:
215
+
216
+ dnnlib.EasyDict(
217
+ NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
218
+ ...
219
+ )
220
+ """
221
  stats = dnnlib.EasyDict()
222
  for name in self.names():
223
  stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
224
  return stats
225
 
226
  def __getitem__(self, name):
227
+ r"""Convenience getter.
228
+ `collector[name]` is a synonym for `collector.mean(name)`.
229
+ """
230
  return self.mean(name)
231
 
232
  #----------------------------------------------------------------------------
233
+
234
+ def _sync(names):
235
+ r"""Synchronize the global cumulative counters across devices and
236
+ processes. Called internally by `Collector.update()`.
237
+ """
238
+ if len(names) == 0:
239
+ return []
240
+ global _sync_called
241
+ _sync_called = True
242
+
243
+ # Collect deltas within current rank.
244
+ deltas = []
245
+ device = _sync_device if _sync_device is not None else torch.device('cpu')
246
+ for name in names:
247
+ delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
248
+ for counter in _counters[name].values():
249
+ delta.add_(counter.to(device))
250
+ counter.copy_(torch.zeros_like(counter))
251
+ deltas.append(delta)
252
+ deltas = torch.stack(deltas)
253
+
254
+ # Sum deltas across ranks.
255
+ if _sync_device is not None:
256
+ torch.distributed.all_reduce(deltas)
257
+
258
+ # Update cumulative values.
259
+ deltas = deltas.cpu()
260
+ for idx, name in enumerate(names):
261
+ if name not in _cumulative:
262
+ _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
263
+ _cumulative[name].add_(deltas[idx])
264
+
265
+ # Return name-value pairs.
266
+ return [(name, _cumulative[name]) for name in names]
267
+
268
+ #----------------------------------------------------------------------------