devzhk committed
Commit 972a35a · Parent: 6e0fb69

Add model files

Files changed (50)
  1. .DS_Store +0 -0
  2. Dockerfile +4 -0
  3. LICENSE +21 -0
  4. README.md +163 -3
  5. assets/figs/12samples_compressed.png +3 -0
  6. assets/figs/arch.png +3 -0
  7. assets/figs/bar_mem_256.png +3 -0
  8. assets/figs/bar_mem_512.png +3 -0
  9. assets/figs/bar_speed_256.png +3 -0
  10. assets/figs/bar_speed_512.png +3 -0
  11. assets/figs/bubble_gflops_wg.png +3 -0
  12. assets/figs/bubble_gflops_wog.png +3 -0
  13. assets/figs/maskdit_arch.png +3 -0
  14. assets/figs/repo_head.png +3 -0
  15. assets/figs/sample512-set1.png +3 -0
  16. assets/imagenet_label.json +1 -0
  17. autoencoder.py +522 -0
  18. checkpoints/.DS_Store +0 -0
  19. configs/finetune/imagenet256-latent-const.yaml +49 -0
  20. configs/finetune/imagenet256-latent-cos.yaml +49 -0
  21. configs/finetune/imagenet512-latent.yaml +47 -0
  22. configs/test/maskdit-256.yaml +45 -0
  23. configs/test/maskdit-512.yaml +46 -0
  24. configs/train/imagenet256-latent.yaml +48 -0
  25. configs/train/imagenet512-latent.yaml +47 -0
  26. eval_latent.py +132 -0
  27. evaluator.py +695 -0
  28. extract_latent.py +114 -0
  29. fid.py +177 -0
  30. generate.py +91 -0
  31. licenses/LICENSE_ADM.txt +21 -0
  32. licenses/LICENSE_DIT.txt +400 -0
  33. licenses/LICENSE_EDM.txt +439 -0
  34. licenses/LICENSE_UVIT.txt +21 -0
  35. lmdb2wds.py +39 -0
  36. models/maskdit.py +781 -0
  37. sample.py +397 -0
  38. scripts/download_assets.sh +8 -0
  39. scripts/finetune_latent512.sh +14 -0
  40. scripts/prepare_latent256.sh +3 -0
  41. scripts/prepare_latent512.sh +6 -0
  42. scripts/train_latent512.sh +11 -0
  43. torch_utils/__init__.py +0 -0
  44. torch_utils/persistence.py +276 -0
  45. train.py +336 -0
  46. train_utils/datasets.py +412 -0
  47. train_utils/helper.py +69 -0
  48. train_utils/loss.py +101 -0
  49. train_wds.py +400 -0
  50. utils.py +225 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
Dockerfile ADDED
@@ -0,0 +1,4 @@
+ FROM nvcr.io/nvidia/pytorch:23.03-py3
+ RUN pip install einops lmdb omegaconf wandb tqdm pyyaml accelerate
+ RUN pip install timm webdataset
+ RUN pip install diffusers["torch"] transformers
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) [2023] [Anima-Lab]
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,163 @@
- ---
- license: mit
- ---
+ # Fast Training of Diffusion Models with Masked Transformers
+
+ Official PyTorch implementation of the TMLR 2024 paper:<br>
+ **[Fast Training of Diffusion Models with Masked Transformers](https://openreview.net/pdf?id=vTBjBtGioE)**
+ <br>
+ Hongkai Zheng*, Weili Nie*, Arash Vahdat, Anima Anandkumar <br>
+ (*Equal contribution)<br>
+
+ Abstract: *While masked transformers have been extensively explored for representation learning, their application to
+ generative learning is less explored in the vision domain. Our work is the first to exploit masked training to reduce
+ the training cost of diffusion models significantly. Specifically, we randomly mask out a high proportion (e.g., 50%) of
+ patches in diffused input images during training. For masked training, we introduce an asymmetric encoder-decoder
+ architecture consisting of a transformer encoder that operates only on unmasked patches and a lightweight transformer
+ decoder on full patches. To promote a long-range understanding of full patches, we add an auxiliary task of
+ reconstructing masked patches to the denoising score matching objective that learns the score of unmasked patches.
+ Experiments on ImageNet-256x256 and ImageNet-512x512 show that our approach achieves competitive and even better
+ generative performance than the state-of-the-art Diffusion Transformer (DiT) model, using only around 30% of its
+ original training time. Thus, our method shows a promising way of efficiently training large transformer-based diffusion
+ models without sacrificing the generative performance.*
+
+ <div align='center'>
+ <img src="assets/figs/repo_head.png" alt="Architecture" width="900" height="500" style="display: block;"/>
+ </div>
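
To make the masked-training recipe described in the abstract concrete, here is a minimal, illustrative sketch of MAE-style random patch masking and the combined objective (denoising score matching on unmasked patches plus an auxiliary reconstruction loss on masked patches). Names such as `mask_ratio` and `lambda_rec` are placeholders; the actual implementation lives in `models/maskdit.py` and `train_utils/loss.py`.

```python
# Illustrative sketch only -- not the repo's implementation (see models/maskdit.py, train_utils/loss.py).
import torch

def random_patch_mask(num_patches: int, batch_size: int, mask_ratio: float = 0.5, device="cpu"):
    """MAE-style random masking: returns kept-patch indices and a binary mask (1 = masked)."""
    num_keep = int(num_patches * (1 - mask_ratio))
    noise = torch.rand(batch_size, num_patches, device=device)  # random score per patch
    ids_shuffle = torch.argsort(noise, dim=1)                   # random permutation of patches
    ids_keep = ids_shuffle[:, :num_keep]                        # patches fed to the encoder
    mask = torch.ones(batch_size, num_patches, device=device)
    mask.scatter_(1, ids_keep, 0.0)                             # 0 = kept, 1 = masked
    return ids_keep, mask

def masked_training_loss(score_pred, score_target, recon_pred, patch_target, mask, lambda_rec=0.1):
    """Denoising score-matching loss on unmasked patches + auxiliary reconstruction on masked patches.
    `lambda_rec` is a hypothetical weight for the auxiliary term."""
    unmasked = (1.0 - mask).unsqueeze(-1)
    masked = mask.unsqueeze(-1)
    dsm = ((score_pred - score_target) ** 2 * unmasked).sum() / unmasked.sum().clamp(min=1)
    rec = ((recon_pred - patch_target) ** 2 * masked).sum() / masked.sum().clamp(min=1)
    return dsm + lambda_rec * rec
```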
+
+ ## Requirements
+
+ - Training MaskDiT on ImageNet-256x256 takes around 260 hours with 8 A100 GPUs to perform 2M updates with a batch size of 1024.
+ - Training MaskDiT on ImageNet-512x512 takes around 210 A100 GPU days to perform 1M updates with a batch size of 1024.
+ - At least one high-end GPU is needed for sampling.
+ - A [Dockerfile](Dockerfile) is provided for the exact software environment.
+
+ ## Efficiency
+
+ Our MaskDiT uses Automatic Mixed Precision (AMP) by default. We also include MaskDiT without AMP (Ours_ft32) for
+ reference.
+
+ ### Training speed
+
+ <img src="assets/figs/bar_speed_256.png" width=45% style="display: inline-block;"><img src="assets/figs/bar_speed_512.png" width=46% style="display: inline-block;">
+
+ ### GPU memory
+
+ <img src="assets/figs/bar_mem_256.png" width=45% style="display: inline-block;"><img src="assets/figs/bar_mem_512.png" width=44.3% style="display: inline-block;">
+
+ ## Pretrained Models
+ We provide pretrained MaskDiT models for ImageNet256 and ImageNet512 in the following table. For FID with guidance, the guidance scale is set to 1.5 by default.
+ | Guidance | Resolution | FID   | Model |
+ | :------- | :--------- | :---- | :---- |
+ | Yes      | 256x256    | 2.28  | [imagenet256-guidance.pt](checkpoints/imagenet256_with_guidance.pt) |
+ | No       | 256x256    | 5.69  | [imagenet256-conditional.pt](checkpoints/imagenet256_without_guidance.pt) |
+ | Yes      | 512x512    | 2.50  | [imagenet512-guidance.pt](checkpoints/imagenet512_with_guidance.pt) |
+ | No       | 512x512    | 10.79 | [imagenet512-conditional.pt](checkpoints/imagenet512_without_guidance.pt) |
+
+ ## Generate from pretrained models
+
+ To generate samples from the provided checkpoints, run
+
+ ```bash
+ python3 generate.py --config configs/test/maskdit-512.yaml --ckpt_path [path to checkpoints] --class_idx [class index from 0-999] --cfg_scale [guidance scale]
+ ```
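
The `--class_idx` argument uses the standard ImageNet-1k class ordering. The mapping added in `assets/imagenet_label.json` can be used to look up an index by name; the helper below is only a convenience sketch (the function name is ours, not part of the repo).

```python
import json

def find_class_idx(query: str, label_path: str = "assets/imagenet_label.json") -> int:
    """Return the ImageNet class index whose human-readable label contains `query`."""
    with open(label_path) as f:
        labels = json.load(f)  # {"0": ["n01440764", "tench"], ...}
    for idx, (wnid, name) in labels.items():
        if query.lower() in name.lower():
            return int(idx)
    raise KeyError(f"No ImageNet class matching {query!r}")

# Example: find_class_idx("golden_retriever") -> 207, then pass --class_idx 207 to generate.py
```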
+
+ <img src="assets/figs/12samples_compressed.png" title="Generated samples from MaskDiT 256x256." width="850" style="display: block; margin: 0 auto;"/>
+ <p align='center'> Generated samples from MaskDiT 256x256. Upper panel: without CFG. Lower panel: with CFG (scale=1.5).
+ </p>
+
+ <img src="assets/figs/imagenet512.png" title="Generated samples from MaskDiT 512x512." width="850" style="display: block; margin: 0 auto;"/>
+ <p align='center'> Generated samples from MaskDiT 512x512 with CFG (scale=1.5).
+ </p>
+
+ ## Prepare dataset
+
+ We first encode the ImageNet dataset into latent space with a pre-trained VAE. You can download the ImageNet-256x256
+ and ImageNet-512x512 datasets that have already been encoded into latent space by running
+
+ ```bash
+ bash scripts/download_assets.sh
+ ```
+
+ `extract_latent.py` was used to encode ImageNet into latents.
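
For intuition, latent extraction with a pre-trained VAE looks roughly like the sketch below, using the `diffusers` `AutoencoderKL` that the Dockerfile installs. The exact VAE checkpoint and the 0.18215 latent scaling are assumptions borrowed from common latent-diffusion practice, not a description of `extract_latent.py`.

```python
import torch
from diffusers.models import AutoencoderKL

device = "cuda" if torch.cuda.is_available() else "cpu"
# Assumed VAE checkpoint; extract_latent.py may use a different one.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device).eval()

@torch.no_grad()
def encode_to_latent(images: torch.Tensor) -> torch.Tensor:
    """images: (B, 3, H, W) float tensor in [-1, 1] -> latents: (B, 4, H/8, W/8)."""
    posterior = vae.encode(images.to(device)).latent_dist
    return posterior.sample() * 0.18215  # latent scaling convention of latent-diffusion models (assumption here)
```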
+
+ ### LMDB to Webdataset
+
+ When training on ImageNet-256x256, we store our data in LMDB format. When training on ImageNet-512x512, we
+ use [webdataset](https://github.com/webdataset/webdataset) for faster IO performance. To convert an LMDB dataset into a
+ webdataset, run
+
+ ```bash
+ python3 lmdb2wds.py --datadir [path to lmdb] --outdir [path to save webdataset] --resolution [latent resolution] --num_channels [number of latent channels]
+ ```
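
For reference, webdataset shards of latents can be streamed roughly as follows; the shard pattern and per-sample key names are placeholders, so check the files that `lmdb2wds.py` actually writes.

```python
import numpy as np
import webdataset as wds

# Shard pattern and field names are hypothetical; inspect the output of lmdb2wds.py for the real keys.
dataset = (
    wds.WebDataset("imagenet512-latent-{000000..000127}.tar")
    .decode()                       # default decoders handle .npy and .cls entries
    .to_tuple("latent.npy", "cls")  # yields (latent array, class label) pairs
)

for latent, label in dataset:
    print(np.asarray(latent).shape, int(label))
    break
```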
+
+ ## Train
+
+ ### ImageNet-256x256
+
+ First, train MaskDiT with a 50% mask ratio and AMP enabled.
+
+ ```bash
+ accelerate launch --multi_gpu train.py --config configs/train/imagenet256-latent.yaml
+ ```
+
+ Then finetune with unmasking.
+
+ ```bash
+ accelerate launch --multi_gpu train.py --config configs/finetune/imagenet256-latent-const.yaml --ckpt_path [path to checkpoint] --use_ckpt_path False --use_strict_load False --no_amp
+ ```
+
+ ### ImageNet-512x512
+
+ Train MaskDiT with a 50% mask ratio and AMP enabled. Here is an example of a 4-node training script.
+
+ ```bash
+ bash scripts/train_latent512.sh
+ ```
+
+ Then finetune with unmasking.
+
+ ```bash
+ bash scripts/finetune_latent512.sh
+ ```
+
+ ## Evaluation
+
+ ### FID evaluation
+
+ To compute the FID of a pretrained model, run
+
+ ```bash
+ accelerate launch --multi_gpu eval_latent.py --config configs/test/maskdit-256.yaml --ckpt [path to the pretrained model] --cfg_scale [guidance scale]
+ ```
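
For background, FID compares Inception-feature statistics of generated and reference images; a minimal version of the number that `eval_latent.py` and `fid.py` report is sketched below (illustration only, not their code).

```python
# Minimal FID from two sets of Inception features: ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)).
import numpy as np
from scipy import linalg

def fid_from_features(feats_real: np.ndarray, feats_fake: np.ndarray) -> float:
    mu1, sigma1 = feats_real.mean(axis=0), np.cov(feats_real, rowvar=False)
    mu2, sigma2 = feats_fake.mean(axis=0), np.cov(feats_fake, rowvar=False)
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)  # matrix square root of the covariance product
    covmean = covmean.real                                  # drop tiny imaginary parts from numerics
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))
```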
+
+ ### Full evaluation
+
+ First, download the reference batch from the [ADM repo](https://github.com/openai/guided-diffusion/tree/main/evaluations)
+ directly. You can also use `download_assets.py` by running
+
+ ```bash
+ python3 download_assets.py --name imagenet256 --dest [destination directory]
+ ```
+
+ Then use the evaluator `evaluator.py`
+ from the [ADM repo](https://github.com/openai/guided-diffusion/tree/main/evaluations), or `fid.py`
+ from the [EDM repo](https://github.com/NVlabs/edm), to evaluate the generated samples.
+
+ ### Citation
+
+ ```
+ @inproceedings{Zheng2024MaskDiT,
+     title={Fast Training of Diffusion Models with Masked Transformers},
+     author={Zheng, Hongkai and Nie, Weili and Vahdat, Arash and Anandkumar, Anima},
+     booktitle = {Transactions on Machine Learning Research (TMLR)},
+     year={2024}
+ }
+ ```
+
+ ## Acknowledgements
+
+ Thanks to the open-source codebases [DiT](https://github.com/facebookresearch/DiT), [MAE](https://github.com/facebookresearch/mae), [U-ViT](https://github.com/baofff/U-ViT), [ADM](https://github.com/openai/guided-diffusion), and [EDM](https://github.com/NVlabs/edm), on which our codebase is built.
assets/figs/12samples_compressed.png ADDED

Git LFS Details

  • SHA256: 827525f00c9ec3de222f20082efc9e78f058cb80aae48712113a22df9819f24c
  • Pointer size: 132 Bytes
  • Size of remote file: 3.59 MB
assets/figs/arch.png ADDED

Git LFS Details

  • SHA256: 958ae1cf735092a826f0c0b227b2b3628066375b8b136af78485071565da59a5
  • Pointer size: 132 Bytes
  • Size of remote file: 3.13 MB
assets/figs/bar_mem_256.png ADDED

Git LFS Details

  • SHA256: a06b76bcaf818307c4c287569b354352f1d2d7e671f417df4d198b13108685c1
  • Pointer size: 130 Bytes
  • Size of remote file: 27.9 kB
assets/figs/bar_mem_512.png ADDED

Git LFS Details

  • SHA256: de707af00e56ef9bc9c42bf7066f79723a7bcc03c7b4a930d7d14971e95c00ac
  • Pointer size: 130 Bytes
  • Size of remote file: 74.2 kB
assets/figs/bar_speed_256.png ADDED

Git LFS Details

  • SHA256: 0e6a49fd1d7b9bf50b2e092f3ceceab5b87fa7487537887a1fdcb92a242e43f3
  • Pointer size: 130 Bytes
  • Size of remote file: 25.3 kB
assets/figs/bar_speed_512.png ADDED

Git LFS Details

  • SHA256: d68bc409e3232cf2075df6bdf5183959ad8269da2d6576b995b419967a3dee23
  • Pointer size: 130 Bytes
  • Size of remote file: 72.4 kB
assets/figs/bubble_gflops_wg.png ADDED

Git LFS Details

  • SHA256: 10f764e67f101e23a35d8ea989e6c97080a817102d3913a3b886baf638b8cd20
  • Pointer size: 131 Bytes
  • Size of remote file: 318 kB
assets/figs/bubble_gflops_wog.png ADDED

Git LFS Details

  • SHA256: 209d48bcf148572491e3004f488a6709acf3def862d8b071421e5609d5009c99
  • Pointer size: 131 Bytes
  • Size of remote file: 323 kB
assets/figs/maskdit_arch.png ADDED

Git LFS Details

  • SHA256: f0d4cda130e6c3ebfa39abd33685824e369149e2ed1fbc6d22cddb06f7192d41
  • Pointer size: 132 Bytes
  • Size of remote file: 2.09 MB
assets/figs/repo_head.png ADDED

Git LFS Details

  • SHA256: fd3d4db8cd7d8147ecaa07bf11b69ddae97ec935d876e07e078e2166fba404f6
  • Pointer size: 132 Bytes
  • Size of remote file: 3.21 MB
assets/figs/sample512-set1.png ADDED

Git LFS Details

  • SHA256: 21abf3b5ba37f436dc2baf5ce7003533e2e6fa62a7361e2478e68b3d7f3772be
  • Pointer size: 132 Bytes
  • Size of remote file: 4.4 MB
assets/imagenet_label.json ADDED
@@ -0,0 +1 @@
+ {"0": ["n01440764", "tench"], "1": ["n01443537", "goldfish"], "2": ["n01484850", "great_white_shark"], "3": ["n01491361", "tiger_shark"], "4": ["n01494475", "hammerhead"], "5": ["n01496331", "electric_ray"], "6": ["n01498041", "stingray"], "7": ["n01514668", "cock"], "8": ["n01514859", "hen"], "9": ["n01518878", "ostrich"], "10": ["n01530575", "brambling"], "11": ["n01531178", "goldfinch"], "12": ["n01532829", "house_finch"], "13": ["n01534433", "junco"], "14": ["n01537544", "indigo_bunting"], "15": ["n01558993", "robin"], "16": ["n01560419", "bulbul"], "17": ["n01580077", "jay"], "18": ["n01582220", "magpie"], "19": ["n01592084", "chickadee"], "20": ["n01601694", "water_ouzel"], "21": ["n01608432", "kite"], "22": ["n01614925", "bald_eagle"], "23": ["n01616318", "vulture"], "24": ["n01622779", "great_grey_owl"], "25": ["n01629819", "European_fire_salamander"], "26": ["n01630670", "common_newt"], "27": ["n01631663", "eft"], "28": ["n01632458", "spotted_salamander"], "29": ["n01632777", "axolotl"], "30": ["n01641577", "bullfrog"], "31": ["n01644373", "tree_frog"], "32": ["n01644900", "tailed_frog"], "33": ["n01664065", "loggerhead"], "34": ["n01665541", "leatherback_turtle"], "35": ["n01667114", "mud_turtle"], "36": ["n01667778", "terrapin"], "37": ["n01669191", "box_turtle"], "38": ["n01675722", "banded_gecko"], "39": ["n01677366", "common_iguana"], "40": ["n01682714", "American_chameleon"], "41": ["n01685808", "whiptail"], "42": ["n01687978", "agama"], "43": ["n01688243", "frilled_lizard"], "44": ["n01689811", "alligator_lizard"], "45": ["n01692333", "Gila_monster"], "46": ["n01693334", "green_lizard"], "47": ["n01694178", "African_chameleon"], "48": ["n01695060", "Komodo_dragon"], "49": ["n01697457", "African_crocodile"], "50": ["n01698640", "American_alligator"], "51": ["n01704323", "triceratops"], "52": ["n01728572", "thunder_snake"], "53": ["n01728920", "ringneck_snake"], "54": ["n01729322", "hognose_snake"], "55": ["n01729977", "green_snake"], "56": ["n01734418", "king_snake"], "57": ["n01735189", "garter_snake"], "58": ["n01737021", "water_snake"], "59": ["n01739381", "vine_snake"], "60": ["n01740131", "night_snake"], "61": ["n01742172", "boa_constrictor"], "62": ["n01744401", "rock_python"], "63": ["n01748264", "Indian_cobra"], "64": ["n01749939", "green_mamba"], "65": ["n01751748", "sea_snake"], "66": ["n01753488", "horned_viper"], "67": ["n01755581", "diamondback"], "68": ["n01756291", "sidewinder"], "69": ["n01768244", "trilobite"], "70": ["n01770081", "harvestman"], "71": ["n01770393", "scorpion"], "72": ["n01773157", "black_and_gold_garden_spider"], "73": ["n01773549", "barn_spider"], "74": ["n01773797", "garden_spider"], "75": ["n01774384", "black_widow"], "76": ["n01774750", "tarantula"], "77": ["n01775062", "wolf_spider"], "78": ["n01776313", "tick"], "79": ["n01784675", "centipede"], "80": ["n01795545", "black_grouse"], "81": ["n01796340", "ptarmigan"], "82": ["n01797886", "ruffed_grouse"], "83": ["n01798484", "prairie_chicken"], "84": ["n01806143", "peacock"], "85": ["n01806567", "quail"], "86": ["n01807496", "partridge"], "87": ["n01817953", "African_grey"], "88": ["n01818515", "macaw"], "89": ["n01819313", "sulphur-crested_cockatoo"], "90": ["n01820546", "lorikeet"], "91": ["n01824575", "coucal"], "92": ["n01828970", "bee_eater"], "93": ["n01829413", "hornbill"], "94": ["n01833805", "hummingbird"], "95": ["n01843065", "jacamar"], "96": ["n01843383", "toucan"], "97": ["n01847000", "drake"], "98": ["n01855032", "red-breasted_merganser"], "99": ["n01855672", "goose"], 
"100": ["n01860187", "black_swan"], "101": ["n01871265", "tusker"], "102": ["n01872401", "echidna"], "103": ["n01873310", "platypus"], "104": ["n01877812", "wallaby"], "105": ["n01882714", "koala"], "106": ["n01883070", "wombat"], "107": ["n01910747", "jellyfish"], "108": ["n01914609", "sea_anemone"], "109": ["n01917289", "brain_coral"], "110": ["n01924916", "flatworm"], "111": ["n01930112", "nematode"], "112": ["n01943899", "conch"], "113": ["n01944390", "snail"], "114": ["n01945685", "slug"], "115": ["n01950731", "sea_slug"], "116": ["n01955084", "chiton"], "117": ["n01968897", "chambered_nautilus"], "118": ["n01978287", "Dungeness_crab"], "119": ["n01978455", "rock_crab"], "120": ["n01980166", "fiddler_crab"], "121": ["n01981276", "king_crab"], "122": ["n01983481", "American_lobster"], "123": ["n01984695", "spiny_lobster"], "124": ["n01985128", "crayfish"], "125": ["n01986214", "hermit_crab"], "126": ["n01990800", "isopod"], "127": ["n02002556", "white_stork"], "128": ["n02002724", "black_stork"], "129": ["n02006656", "spoonbill"], "130": ["n02007558", "flamingo"], "131": ["n02009229", "little_blue_heron"], "132": ["n02009912", "American_egret"], "133": ["n02011460", "bittern"], "134": ["n02012849", "crane"], "135": ["n02013706", "limpkin"], "136": ["n02017213", "European_gallinule"], "137": ["n02018207", "American_coot"], "138": ["n02018795", "bustard"], "139": ["n02025239", "ruddy_turnstone"], "140": ["n02027492", "red-backed_sandpiper"], "141": ["n02028035", "redshank"], "142": ["n02033041", "dowitcher"], "143": ["n02037110", "oystercatcher"], "144": ["n02051845", "pelican"], "145": ["n02056570", "king_penguin"], "146": ["n02058221", "albatross"], "147": ["n02066245", "grey_whale"], "148": ["n02071294", "killer_whale"], "149": ["n02074367", "dugong"], "150": ["n02077923", "sea_lion"], "151": ["n02085620", "Chihuahua"], "152": ["n02085782", "Japanese_spaniel"], "153": ["n02085936", "Maltese_dog"], "154": ["n02086079", "Pekinese"], "155": ["n02086240", "Shih-Tzu"], "156": ["n02086646", "Blenheim_spaniel"], "157": ["n02086910", "papillon"], "158": ["n02087046", "toy_terrier"], "159": ["n02087394", "Rhodesian_ridgeback"], "160": ["n02088094", "Afghan_hound"], "161": ["n02088238", "basset"], "162": ["n02088364", "beagle"], "163": ["n02088466", "bloodhound"], "164": ["n02088632", "bluetick"], "165": ["n02089078", "black-and-tan_coonhound"], "166": ["n02089867", "Walker_hound"], "167": ["n02089973", "English_foxhound"], "168": ["n02090379", "redbone"], "169": ["n02090622", "borzoi"], "170": ["n02090721", "Irish_wolfhound"], "171": ["n02091032", "Italian_greyhound"], "172": ["n02091134", "whippet"], "173": ["n02091244", "Ibizan_hound"], "174": ["n02091467", "Norwegian_elkhound"], "175": ["n02091635", "otterhound"], "176": ["n02091831", "Saluki"], "177": ["n02092002", "Scottish_deerhound"], "178": ["n02092339", "Weimaraner"], "179": ["n02093256", "Staffordshire_bullterrier"], "180": ["n02093428", "American_Staffordshire_terrier"], "181": ["n02093647", "Bedlington_terrier"], "182": ["n02093754", "Border_terrier"], "183": ["n02093859", "Kerry_blue_terrier"], "184": ["n02093991", "Irish_terrier"], "185": ["n02094114", "Norfolk_terrier"], "186": ["n02094258", "Norwich_terrier"], "187": ["n02094433", "Yorkshire_terrier"], "188": ["n02095314", "wire-haired_fox_terrier"], "189": ["n02095570", "Lakeland_terrier"], "190": ["n02095889", "Sealyham_terrier"], "191": ["n02096051", "Airedale"], "192": ["n02096177", "cairn"], "193": ["n02096294", "Australian_terrier"], "194": ["n02096437", 
"Dandie_Dinmont"], "195": ["n02096585", "Boston_bull"], "196": ["n02097047", "miniature_schnauzer"], "197": ["n02097130", "giant_schnauzer"], "198": ["n02097209", "standard_schnauzer"], "199": ["n02097298", "Scotch_terrier"], "200": ["n02097474", "Tibetan_terrier"], "201": ["n02097658", "silky_terrier"], "202": ["n02098105", "soft-coated_wheaten_terrier"], "203": ["n02098286", "West_Highland_white_terrier"], "204": ["n02098413", "Lhasa"], "205": ["n02099267", "flat-coated_retriever"], "206": ["n02099429", "curly-coated_retriever"], "207": ["n02099601", "golden_retriever"], "208": ["n02099712", "Labrador_retriever"], "209": ["n02099849", "Chesapeake_Bay_retriever"], "210": ["n02100236", "German_short-haired_pointer"], "211": ["n02100583", "vizsla"], "212": ["n02100735", "English_setter"], "213": ["n02100877", "Irish_setter"], "214": ["n02101006", "Gordon_setter"], "215": ["n02101388", "Brittany_spaniel"], "216": ["n02101556", "clumber"], "217": ["n02102040", "English_springer"], "218": ["n02102177", "Welsh_springer_spaniel"], "219": ["n02102318", "cocker_spaniel"], "220": ["n02102480", "Sussex_spaniel"], "221": ["n02102973", "Irish_water_spaniel"], "222": ["n02104029", "kuvasz"], "223": ["n02104365", "schipperke"], "224": ["n02105056", "groenendael"], "225": ["n02105162", "malinois"], "226": ["n02105251", "briard"], "227": ["n02105412", "kelpie"], "228": ["n02105505", "komondor"], "229": ["n02105641", "Old_English_sheepdog"], "230": ["n02105855", "Shetland_sheepdog"], "231": ["n02106030", "collie"], "232": ["n02106166", "Border_collie"], "233": ["n02106382", "Bouvier_des_Flandres"], "234": ["n02106550", "Rottweiler"], "235": ["n02106662", "German_shepherd"], "236": ["n02107142", "Doberman"], "237": ["n02107312", "miniature_pinscher"], "238": ["n02107574", "Greater_Swiss_Mountain_dog"], "239": ["n02107683", "Bernese_mountain_dog"], "240": ["n02107908", "Appenzeller"], "241": ["n02108000", "EntleBucher"], "242": ["n02108089", "boxer"], "243": ["n02108422", "bull_mastiff"], "244": ["n02108551", "Tibetan_mastiff"], "245": ["n02108915", "French_bulldog"], "246": ["n02109047", "Great_Dane"], "247": ["n02109525", "Saint_Bernard"], "248": ["n02109961", "Eskimo_dog"], "249": ["n02110063", "malamute"], "250": ["n02110185", "Siberian_husky"], "251": ["n02110341", "dalmatian"], "252": ["n02110627", "affenpinscher"], "253": ["n02110806", "basenji"], "254": ["n02110958", "pug"], "255": ["n02111129", "Leonberg"], "256": ["n02111277", "Newfoundland"], "257": ["n02111500", "Great_Pyrenees"], "258": ["n02111889", "Samoyed"], "259": ["n02112018", "Pomeranian"], "260": ["n02112137", "chow"], "261": ["n02112350", "keeshond"], "262": ["n02112706", "Brabancon_griffon"], "263": ["n02113023", "Pembroke"], "264": ["n02113186", "Cardigan"], "265": ["n02113624", "toy_poodle"], "266": ["n02113712", "miniature_poodle"], "267": ["n02113799", "standard_poodle"], "268": ["n02113978", "Mexican_hairless"], "269": ["n02114367", "timber_wolf"], "270": ["n02114548", "white_wolf"], "271": ["n02114712", "red_wolf"], "272": ["n02114855", "coyote"], "273": ["n02115641", "dingo"], "274": ["n02115913", "dhole"], "275": ["n02116738", "African_hunting_dog"], "276": ["n02117135", "hyena"], "277": ["n02119022", "red_fox"], "278": ["n02119789", "kit_fox"], "279": ["n02120079", "Arctic_fox"], "280": ["n02120505", "grey_fox"], "281": ["n02123045", "tabby"], "282": ["n02123159", "tiger_cat"], "283": ["n02123394", "Persian_cat"], "284": ["n02123597", "Siamese_cat"], "285": ["n02124075", "Egyptian_cat"], "286": ["n02125311", "cougar"], "287": 
["n02127052", "lynx"], "288": ["n02128385", "leopard"], "289": ["n02128757", "snow_leopard"], "290": ["n02128925", "jaguar"], "291": ["n02129165", "lion"], "292": ["n02129604", "tiger"], "293": ["n02130308", "cheetah"], "294": ["n02132136", "brown_bear"], "295": ["n02133161", "American_black_bear"], "296": ["n02134084", "ice_bear"], "297": ["n02134418", "sloth_bear"], "298": ["n02137549", "mongoose"], "299": ["n02138441", "meerkat"], "300": ["n02165105", "tiger_beetle"], "301": ["n02165456", "ladybug"], "302": ["n02167151", "ground_beetle"], "303": ["n02168699", "long-horned_beetle"], "304": ["n02169497", "leaf_beetle"], "305": ["n02172182", "dung_beetle"], "306": ["n02174001", "rhinoceros_beetle"], "307": ["n02177972", "weevil"], "308": ["n02190166", "fly"], "309": ["n02206856", "bee"], "310": ["n02219486", "ant"], "311": ["n02226429", "grasshopper"], "312": ["n02229544", "cricket"], "313": ["n02231487", "walking_stick"], "314": ["n02233338", "cockroach"], "315": ["n02236044", "mantis"], "316": ["n02256656", "cicada"], "317": ["n02259212", "leafhopper"], "318": ["n02264363", "lacewing"], "319": ["n02268443", "dragonfly"], "320": ["n02268853", "damselfly"], "321": ["n02276258", "admiral"], "322": ["n02277742", "ringlet"], "323": ["n02279972", "monarch"], "324": ["n02280649", "cabbage_butterfly"], "325": ["n02281406", "sulphur_butterfly"], "326": ["n02281787", "lycaenid"], "327": ["n02317335", "starfish"], "328": ["n02319095", "sea_urchin"], "329": ["n02321529", "sea_cucumber"], "330": ["n02325366", "wood_rabbit"], "331": ["n02326432", "hare"], "332": ["n02328150", "Angora"], "333": ["n02342885", "hamster"], "334": ["n02346627", "porcupine"], "335": ["n02356798", "fox_squirrel"], "336": ["n02361337", "marmot"], "337": ["n02363005", "beaver"], "338": ["n02364673", "guinea_pig"], "339": ["n02389026", "sorrel"], "340": ["n02391049", "zebra"], "341": ["n02395406", "hog"], "342": ["n02396427", "wild_boar"], "343": ["n02397096", "warthog"], "344": ["n02398521", "hippopotamus"], "345": ["n02403003", "ox"], "346": ["n02408429", "water_buffalo"], "347": ["n02410509", "bison"], "348": ["n02412080", "ram"], "349": ["n02415577", "bighorn"], "350": ["n02417914", "ibex"], "351": ["n02422106", "hartebeest"], "352": ["n02422699", "impala"], "353": ["n02423022", "gazelle"], "354": ["n02437312", "Arabian_camel"], "355": ["n02437616", "llama"], "356": ["n02441942", "weasel"], "357": ["n02442845", "mink"], "358": ["n02443114", "polecat"], "359": ["n02443484", "black-footed_ferret"], "360": ["n02444819", "otter"], "361": ["n02445715", "skunk"], "362": ["n02447366", "badger"], "363": ["n02454379", "armadillo"], "364": ["n02457408", "three-toed_sloth"], "365": ["n02480495", "orangutan"], "366": ["n02480855", "gorilla"], "367": ["n02481823", "chimpanzee"], "368": ["n02483362", "gibbon"], "369": ["n02483708", "siamang"], "370": ["n02484975", "guenon"], "371": ["n02486261", "patas"], "372": ["n02486410", "baboon"], "373": ["n02487347", "macaque"], "374": ["n02488291", "langur"], "375": ["n02488702", "colobus"], "376": ["n02489166", "proboscis_monkey"], "377": ["n02490219", "marmoset"], "378": ["n02492035", "capuchin"], "379": ["n02492660", "howler_monkey"], "380": ["n02493509", "titi"], "381": ["n02493793", "spider_monkey"], "382": ["n02494079", "squirrel_monkey"], "383": ["n02497673", "Madagascar_cat"], "384": ["n02500267", "indri"], "385": ["n02504013", "Indian_elephant"], "386": ["n02504458", "African_elephant"], "387": ["n02509815", "lesser_panda"], "388": ["n02510455", "giant_panda"], "389": ["n02514041", 
"barracouta"], "390": ["n02526121", "eel"], "391": ["n02536864", "coho"], "392": ["n02606052", "rock_beauty"], "393": ["n02607072", "anemone_fish"], "394": ["n02640242", "sturgeon"], "395": ["n02641379", "gar"], "396": ["n02643566", "lionfish"], "397": ["n02655020", "puffer"], "398": ["n02666196", "abacus"], "399": ["n02667093", "abaya"], "400": ["n02669723", "academic_gown"], "401": ["n02672831", "accordion"], "402": ["n02676566", "acoustic_guitar"], "403": ["n02687172", "aircraft_carrier"], "404": ["n02690373", "airliner"], "405": ["n02692877", "airship"], "406": ["n02699494", "altar"], "407": ["n02701002", "ambulance"], "408": ["n02704792", "amphibian"], "409": ["n02708093", "analog_clock"], "410": ["n02727426", "apiary"], "411": ["n02730930", "apron"], "412": ["n02747177", "ashcan"], "413": ["n02749479", "assault_rifle"], "414": ["n02769748", "backpack"], "415": ["n02776631", "bakery"], "416": ["n02777292", "balance_beam"], "417": ["n02782093", "balloon"], "418": ["n02783161", "ballpoint"], "419": ["n02786058", "Band_Aid"], "420": ["n02787622", "banjo"], "421": ["n02788148", "bannister"], "422": ["n02790996", "barbell"], "423": ["n02791124", "barber_chair"], "424": ["n02791270", "barbershop"], "425": ["n02793495", "barn"], "426": ["n02794156", "barometer"], "427": ["n02795169", "barrel"], "428": ["n02797295", "barrow"], "429": ["n02799071", "baseball"], "430": ["n02802426", "basketball"], "431": ["n02804414", "bassinet"], "432": ["n02804610", "bassoon"], "433": ["n02807133", "bathing_cap"], "434": ["n02808304", "bath_towel"], "435": ["n02808440", "bathtub"], "436": ["n02814533", "beach_wagon"], "437": ["n02814860", "beacon"], "438": ["n02815834", "beaker"], "439": ["n02817516", "bearskin"], "440": ["n02823428", "beer_bottle"], "441": ["n02823750", "beer_glass"], "442": ["n02825657", "bell_cote"], "443": ["n02834397", "bib"], "444": ["n02835271", "bicycle-built-for-two"], "445": ["n02837789", "bikini"], "446": ["n02840245", "binder"], "447": ["n02841315", "binoculars"], "448": ["n02843684", "birdhouse"], "449": ["n02859443", "boathouse"], "450": ["n02860847", "bobsled"], "451": ["n02865351", "bolo_tie"], "452": ["n02869837", "bonnet"], "453": ["n02870880", "bookcase"], "454": ["n02871525", "bookshop"], "455": ["n02877765", "bottlecap"], "456": ["n02879718", "bow"], "457": ["n02883205", "bow_tie"], "458": ["n02892201", "brass"], "459": ["n02892767", "brassiere"], "460": ["n02894605", "breakwater"], "461": ["n02895154", "breastplate"], "462": ["n02906734", "broom"], "463": ["n02909870", "bucket"], "464": ["n02910353", "buckle"], "465": ["n02916936", "bulletproof_vest"], "466": ["n02917067", "bullet_train"], "467": ["n02927161", "butcher_shop"], "468": ["n02930766", "cab"], "469": ["n02939185", "caldron"], "470": ["n02948072", "candle"], "471": ["n02950826", "cannon"], "472": ["n02951358", "canoe"], "473": ["n02951585", "can_opener"], "474": ["n02963159", "cardigan"], "475": ["n02965783", "car_mirror"], "476": ["n02966193", "carousel"], "477": ["n02966687", "carpenter's_kit"], "478": ["n02971356", "carton"], "479": ["n02974003", "car_wheel"], "480": ["n02977058", "cash_machine"], "481": ["n02978881", "cassette"], "482": ["n02979186", "cassette_player"], "483": ["n02980441", "castle"], "484": ["n02981792", "catamaran"], "485": ["n02988304", "CD_player"], "486": ["n02992211", "cello"], "487": ["n02992529", "cellular_telephone"], "488": ["n02999410", "chain"], "489": ["n03000134", "chainlink_fence"], "490": ["n03000247", "chain_mail"], "491": ["n03000684", "chain_saw"], "492": ["n03014705", 
"chest"], "493": ["n03016953", "chiffonier"], "494": ["n03017168", "chime"], "495": ["n03018349", "china_cabinet"], "496": ["n03026506", "Christmas_stocking"], "497": ["n03028079", "church"], "498": ["n03032252", "cinema"], "499": ["n03041632", "cleaver"], "500": ["n03042490", "cliff_dwelling"], "501": ["n03045698", "cloak"], "502": ["n03047690", "clog"], "503": ["n03062245", "cocktail_shaker"], "504": ["n03063599", "coffee_mug"], "505": ["n03063689", "coffeepot"], "506": ["n03065424", "coil"], "507": ["n03075370", "combination_lock"], "508": ["n03085013", "computer_keyboard"], "509": ["n03089624", "confectionery"], "510": ["n03095699", "container_ship"], "511": ["n03100240", "convertible"], "512": ["n03109150", "corkscrew"], "513": ["n03110669", "cornet"], "514": ["n03124043", "cowboy_boot"], "515": ["n03124170", "cowboy_hat"], "516": ["n03125729", "cradle"], "517": ["n03126707", "crane"], "518": ["n03127747", "crash_helmet"], "519": ["n03127925", "crate"], "520": ["n03131574", "crib"], "521": ["n03133878", "Crock_Pot"], "522": ["n03134739", "croquet_ball"], "523": ["n03141823", "crutch"], "524": ["n03146219", "cuirass"], "525": ["n03160309", "dam"], "526": ["n03179701", "desk"], "527": ["n03180011", "desktop_computer"], "528": ["n03187595", "dial_telephone"], "529": ["n03188531", "diaper"], "530": ["n03196217", "digital_clock"], "531": ["n03197337", "digital_watch"], "532": ["n03201208", "dining_table"], "533": ["n03207743", "dishrag"], "534": ["n03207941", "dishwasher"], "535": ["n03208938", "disk_brake"], "536": ["n03216828", "dock"], "537": ["n03218198", "dogsled"], "538": ["n03220513", "dome"], "539": ["n03223299", "doormat"], "540": ["n03240683", "drilling_platform"], "541": ["n03249569", "drum"], "542": ["n03250847", "drumstick"], "543": ["n03255030", "dumbbell"], "544": ["n03259280", "Dutch_oven"], "545": ["n03271574", "electric_fan"], "546": ["n03272010", "electric_guitar"], "547": ["n03272562", "electric_locomotive"], "548": ["n03290653", "entertainment_center"], "549": ["n03291819", "envelope"], "550": ["n03297495", "espresso_maker"], "551": ["n03314780", "face_powder"], "552": ["n03325584", "feather_boa"], "553": ["n03337140", "file"], "554": ["n03344393", "fireboat"], "555": ["n03345487", "fire_engine"], "556": ["n03347037", "fire_screen"], "557": ["n03355925", "flagpole"], "558": ["n03372029", "flute"], "559": ["n03376595", "folding_chair"], "560": ["n03379051", "football_helmet"], "561": ["n03384352", "forklift"], "562": ["n03388043", "fountain"], "563": ["n03388183", "fountain_pen"], "564": ["n03388549", "four-poster"], "565": ["n03393912", "freight_car"], "566": ["n03394916", "French_horn"], "567": ["n03400231", "frying_pan"], "568": ["n03404251", "fur_coat"], "569": ["n03417042", "garbage_truck"], "570": ["n03424325", "gasmask"], "571": ["n03425413", "gas_pump"], "572": ["n03443371", "goblet"], "573": ["n03444034", "go-kart"], "574": ["n03445777", "golf_ball"], "575": ["n03445924", "golfcart"], "576": ["n03447447", "gondola"], "577": ["n03447721", "gong"], "578": ["n03450230", "gown"], "579": ["n03452741", "grand_piano"], "580": ["n03457902", "greenhouse"], "581": ["n03459775", "grille"], "582": ["n03461385", "grocery_store"], "583": ["n03467068", "guillotine"], "584": ["n03476684", "hair_slide"], "585": ["n03476991", "hair_spray"], "586": ["n03478589", "half_track"], "587": ["n03481172", "hammer"], "588": ["n03482405", "hamper"], "589": ["n03483316", "hand_blower"], "590": ["n03485407", "hand-held_computer"], "591": ["n03485794", "handkerchief"], "592": ["n03492542", 
"hard_disc"], "593": ["n03494278", "harmonica"], "594": ["n03495258", "harp"], "595": ["n03496892", "harvester"], "596": ["n03498962", "hatchet"], "597": ["n03527444", "holster"], "598": ["n03529860", "home_theater"], "599": ["n03530642", "honeycomb"], "600": ["n03532672", "hook"], "601": ["n03534580", "hoopskirt"], "602": ["n03535780", "horizontal_bar"], "603": ["n03538406", "horse_cart"], "604": ["n03544143", "hourglass"], "605": ["n03584254", "iPod"], "606": ["n03584829", "iron"], "607": ["n03590841", "jack-o'-lantern"], "608": ["n03594734", "jean"], "609": ["n03594945", "jeep"], "610": ["n03595614", "jersey"], "611": ["n03598930", "jigsaw_puzzle"], "612": ["n03599486", "jinrikisha"], "613": ["n03602883", "joystick"], "614": ["n03617480", "kimono"], "615": ["n03623198", "knee_pad"], "616": ["n03627232", "knot"], "617": ["n03630383", "lab_coat"], "618": ["n03633091", "ladle"], "619": ["n03637318", "lampshade"], "620": ["n03642806", "laptop"], "621": ["n03649909", "lawn_mower"], "622": ["n03657121", "lens_cap"], "623": ["n03658185", "letter_opener"], "624": ["n03661043", "library"], "625": ["n03662601", "lifeboat"], "626": ["n03666591", "lighter"], "627": ["n03670208", "limousine"], "628": ["n03673027", "liner"], "629": ["n03676483", "lipstick"], "630": ["n03680355", "Loafer"], "631": ["n03690938", "lotion"], "632": ["n03691459", "loudspeaker"], "633": ["n03692522", "loupe"], "634": ["n03697007", "lumbermill"], "635": ["n03706229", "magnetic_compass"], "636": ["n03709823", "mailbag"], "637": ["n03710193", "mailbox"], "638": ["n03710637", "maillot"], "639": ["n03710721", "maillot"], "640": ["n03717622", "manhole_cover"], "641": ["n03720891", "maraca"], "642": ["n03721384", "marimba"], "643": ["n03724870", "mask"], "644": ["n03729826", "matchstick"], "645": ["n03733131", "maypole"], "646": ["n03733281", "maze"], "647": ["n03733805", "measuring_cup"], "648": ["n03742115", "medicine_chest"], "649": ["n03743016", "megalith"], "650": ["n03759954", "microphone"], "651": ["n03761084", "microwave"], "652": ["n03763968", "military_uniform"], "653": ["n03764736", "milk_can"], "654": ["n03769881", "minibus"], "655": ["n03770439", "miniskirt"], "656": ["n03770679", "minivan"], "657": ["n03773504", "missile"], "658": ["n03775071", "mitten"], "659": ["n03775546", "mixing_bowl"], "660": ["n03776460", "mobile_home"], "661": ["n03777568", "Model_T"], "662": ["n03777754", "modem"], "663": ["n03781244", "monastery"], "664": ["n03782006", "monitor"], "665": ["n03785016", "moped"], "666": ["n03786901", "mortar"], "667": ["n03787032", "mortarboard"], "668": ["n03788195", "mosque"], "669": ["n03788365", "mosquito_net"], "670": ["n03791053", "motor_scooter"], "671": ["n03792782", "mountain_bike"], "672": ["n03792972", "mountain_tent"], "673": ["n03793489", "mouse"], "674": ["n03794056", "mousetrap"], "675": ["n03796401", "moving_van"], "676": ["n03803284", "muzzle"], "677": ["n03804744", "nail"], "678": ["n03814639", "neck_brace"], "679": ["n03814906", "necklace"], "680": ["n03825788", "nipple"], "681": ["n03832673", "notebook"], "682": ["n03837869", "obelisk"], "683": ["n03838899", "oboe"], "684": ["n03840681", "ocarina"], "685": ["n03841143", "odometer"], "686": ["n03843555", "oil_filter"], "687": ["n03854065", "organ"], "688": ["n03857828", "oscilloscope"], "689": ["n03866082", "overskirt"], "690": ["n03868242", "oxcart"], "691": ["n03868863", "oxygen_mask"], "692": ["n03871628", "packet"], "693": ["n03873416", "paddle"], "694": ["n03874293", "paddlewheel"], "695": ["n03874599", "padlock"], "696": 
["n03876231", "paintbrush"], "697": ["n03877472", "pajama"], "698": ["n03877845", "palace"], "699": ["n03884397", "panpipe"], "700": ["n03887697", "paper_towel"], "701": ["n03888257", "parachute"], "702": ["n03888605", "parallel_bars"], "703": ["n03891251", "park_bench"], "704": ["n03891332", "parking_meter"], "705": ["n03895866", "passenger_car"], "706": ["n03899768", "patio"], "707": ["n03902125", "pay-phone"], "708": ["n03903868", "pedestal"], "709": ["n03908618", "pencil_box"], "710": ["n03908714", "pencil_sharpener"], "711": ["n03916031", "perfume"], "712": ["n03920288", "Petri_dish"], "713": ["n03924679", "photocopier"], "714": ["n03929660", "pick"], "715": ["n03929855", "pickelhaube"], "716": ["n03930313", "picket_fence"], "717": ["n03930630", "pickup"], "718": ["n03933933", "pier"], "719": ["n03935335", "piggy_bank"], "720": ["n03937543", "pill_bottle"], "721": ["n03938244", "pillow"], "722": ["n03942813", "ping-pong_ball"], "723": ["n03944341", "pinwheel"], "724": ["n03947888", "pirate"], "725": ["n03950228", "pitcher"], "726": ["n03954731", "plane"], "727": ["n03956157", "planetarium"], "728": ["n03958227", "plastic_bag"], "729": ["n03961711", "plate_rack"], "730": ["n03967562", "plow"], "731": ["n03970156", "plunger"], "732": ["n03976467", "Polaroid_camera"], "733": ["n03976657", "pole"], "734": ["n03977966", "police_van"], "735": ["n03980874", "poncho"], "736": ["n03982430", "pool_table"], "737": ["n03983396", "pop_bottle"], "738": ["n03991062", "pot"], "739": ["n03992509", "potter's_wheel"], "740": ["n03995372", "power_drill"], "741": ["n03998194", "prayer_rug"], "742": ["n04004767", "printer"], "743": ["n04005630", "prison"], "744": ["n04008634", "projectile"], "745": ["n04009552", "projector"], "746": ["n04019541", "puck"], "747": ["n04023962", "punching_bag"], "748": ["n04026417", "purse"], "749": ["n04033901", "quill"], "750": ["n04033995", "quilt"], "751": ["n04037443", "racer"], "752": ["n04039381", "racket"], "753": ["n04040759", "radiator"], "754": ["n04041544", "radio"], "755": ["n04044716", "radio_telescope"], "756": ["n04049303", "rain_barrel"], "757": ["n04065272", "recreational_vehicle"], "758": ["n04067472", "reel"], "759": ["n04069434", "reflex_camera"], "760": ["n04070727", "refrigerator"], "761": ["n04074963", "remote_control"], "762": ["n04081281", "restaurant"], "763": ["n04086273", "revolver"], "764": ["n04090263", "rifle"], "765": ["n04099969", "rocking_chair"], "766": ["n04111531", "rotisserie"], "767": ["n04116512", "rubber_eraser"], "768": ["n04118538", "rugby_ball"], "769": ["n04118776", "rule"], "770": ["n04120489", "running_shoe"], "771": ["n04125021", "safe"], "772": ["n04127249", "safety_pin"], "773": ["n04131690", "saltshaker"], "774": ["n04133789", "sandal"], "775": ["n04136333", "sarong"], "776": ["n04141076", "sax"], "777": ["n04141327", "scabbard"], "778": ["n04141975", "scale"], "779": ["n04146614", "school_bus"], "780": ["n04147183", "schooner"], "781": ["n04149813", "scoreboard"], "782": ["n04152593", "screen"], "783": ["n04153751", "screw"], "784": ["n04154565", "screwdriver"], "785": ["n04162706", "seat_belt"], "786": ["n04179913", "sewing_machine"], "787": ["n04192698", "shield"], "788": ["n04200800", "shoe_shop"], "789": ["n04201297", "shoji"], "790": ["n04204238", "shopping_basket"], "791": ["n04204347", "shopping_cart"], "792": ["n04208210", "shovel"], "793": ["n04209133", "shower_cap"], "794": ["n04209239", "shower_curtain"], "795": ["n04228054", "ski"], "796": ["n04229816", "ski_mask"], "797": ["n04235860", "sleeping_bag"], "798": 
["n04238763", "slide_rule"], "799": ["n04239074", "sliding_door"], "800": ["n04243546", "slot"], "801": ["n04251144", "snorkel"], "802": ["n04252077", "snowmobile"], "803": ["n04252225", "snowplow"], "804": ["n04254120", "soap_dispenser"], "805": ["n04254680", "soccer_ball"], "806": ["n04254777", "sock"], "807": ["n04258138", "solar_dish"], "808": ["n04259630", "sombrero"], "809": ["n04263257", "soup_bowl"], "810": ["n04264628", "space_bar"], "811": ["n04265275", "space_heater"], "812": ["n04266014", "space_shuttle"], "813": ["n04270147", "spatula"], "814": ["n04273569", "speedboat"], "815": ["n04275548", "spider_web"], "816": ["n04277352", "spindle"], "817": ["n04285008", "sports_car"], "818": ["n04286575", "spotlight"], "819": ["n04296562", "stage"], "820": ["n04310018", "steam_locomotive"], "821": ["n04311004", "steel_arch_bridge"], "822": ["n04311174", "steel_drum"], "823": ["n04317175", "stethoscope"], "824": ["n04325704", "stole"], "825": ["n04326547", "stone_wall"], "826": ["n04328186", "stopwatch"], "827": ["n04330267", "stove"], "828": ["n04332243", "strainer"], "829": ["n04335435", "streetcar"], "830": ["n04336792", "stretcher"], "831": ["n04344873", "studio_couch"], "832": ["n04346328", "stupa"], "833": ["n04347754", "submarine"], "834": ["n04350905", "suit"], "835": ["n04355338", "sundial"], "836": ["n04355933", "sunglass"], "837": ["n04356056", "sunglasses"], "838": ["n04357314", "sunscreen"], "839": ["n04366367", "suspension_bridge"], "840": ["n04367480", "swab"], "841": ["n04370456", "sweatshirt"], "842": ["n04371430", "swimming_trunks"], "843": ["n04371774", "swing"], "844": ["n04372370", "switch"], "845": ["n04376876", "syringe"], "846": ["n04380533", "table_lamp"], "847": ["n04389033", "tank"], "848": ["n04392985", "tape_player"], "849": ["n04398044", "teapot"], "850": ["n04399382", "teddy"], "851": ["n04404412", "television"], "852": ["n04409515", "tennis_ball"], "853": ["n04417672", "thatch"], "854": ["n04418357", "theater_curtain"], "855": ["n04423845", "thimble"], "856": ["n04428191", "thresher"], "857": ["n04429376", "throne"], "858": ["n04435653", "tile_roof"], "859": ["n04442312", "toaster"], "860": ["n04443257", "tobacco_shop"], "861": ["n04447861", "toilet_seat"], "862": ["n04456115", "torch"], "863": ["n04458633", "totem_pole"], "864": ["n04461696", "tow_truck"], "865": ["n04462240", "toyshop"], "866": ["n04465501", "tractor"], "867": ["n04467665", "trailer_truck"], "868": ["n04476259", "tray"], "869": ["n04479046", "trench_coat"], "870": ["n04482393", "tricycle"], "871": ["n04483307", "trimaran"], "872": ["n04485082", "tripod"], "873": ["n04486054", "triumphal_arch"], "874": ["n04487081", "trolleybus"], "875": ["n04487394", "trombone"], "876": ["n04493381", "tub"], "877": ["n04501370", "turnstile"], "878": ["n04505470", "typewriter_keyboard"], "879": ["n04507155", "umbrella"], "880": ["n04509417", "unicycle"], "881": ["n04515003", "upright"], "882": ["n04517823", "vacuum"], "883": ["n04522168", "vase"], "884": ["n04523525", "vault"], "885": ["n04525038", "velvet"], "886": ["n04525305", "vending_machine"], "887": ["n04532106", "vestment"], "888": ["n04532670", "viaduct"], "889": ["n04536866", "violin"], "890": ["n04540053", "volleyball"], "891": ["n04542943", "waffle_iron"], "892": ["n04548280", "wall_clock"], "893": ["n04548362", "wallet"], "894": ["n04550184", "wardrobe"], "895": ["n04552348", "warplane"], "896": ["n04553703", "washbasin"], "897": ["n04554684", "washer"], "898": ["n04557648", "water_bottle"], "899": ["n04560804", "water_jug"], "900": 
["n04562935", "water_tower"], "901": ["n04579145", "whiskey_jug"], "902": ["n04579432", "whistle"], "903": ["n04584207", "wig"], "904": ["n04589890", "window_screen"], "905": ["n04590129", "window_shade"], "906": ["n04591157", "Windsor_tie"], "907": ["n04591713", "wine_bottle"], "908": ["n04592741", "wing"], "909": ["n04596742", "wok"], "910": ["n04597913", "wooden_spoon"], "911": ["n04599235", "wool"], "912": ["n04604644", "worm_fence"], "913": ["n04606251", "wreck"], "914": ["n04612504", "yawl"], "915": ["n04613696", "yurt"], "916": ["n06359193", "web_site"], "917": ["n06596364", "comic_book"], "918": ["n06785654", "crossword_puzzle"], "919": ["n06794110", "street_sign"], "920": ["n06874185", "traffic_light"], "921": ["n07248320", "book_jacket"], "922": ["n07565083", "menu"], "923": ["n07579787", "plate"], "924": ["n07583066", "guacamole"], "925": ["n07584110", "consomme"], "926": ["n07590611", "hot_pot"], "927": ["n07613480", "trifle"], "928": ["n07614500", "ice_cream"], "929": ["n07615774", "ice_lolly"], "930": ["n07684084", "French_loaf"], "931": ["n07693725", "bagel"], "932": ["n07695742", "pretzel"], "933": ["n07697313", "cheeseburger"], "934": ["n07697537", "hotdog"], "935": ["n07711569", "mashed_potato"], "936": ["n07714571", "head_cabbage"], "937": ["n07714990", "broccoli"], "938": ["n07715103", "cauliflower"], "939": ["n07716358", "zucchini"], "940": ["n07716906", "spaghetti_squash"], "941": ["n07717410", "acorn_squash"], "942": ["n07717556", "butternut_squash"], "943": ["n07718472", "cucumber"], "944": ["n07718747", "artichoke"], "945": ["n07720875", "bell_pepper"], "946": ["n07730033", "cardoon"], "947": ["n07734744", "mushroom"], "948": ["n07742313", "Granny_Smith"], "949": ["n07745940", "strawberry"], "950": ["n07747607", "orange"], "951": ["n07749582", "lemon"], "952": ["n07753113", "fig"], "953": ["n07753275", "pineapple"], "954": ["n07753592", "banana"], "955": ["n07754684", "jackfruit"], "956": ["n07760859", "custard_apple"], "957": ["n07768694", "pomegranate"], "958": ["n07802026", "hay"], "959": ["n07831146", "carbonara"], "960": ["n07836838", "chocolate_sauce"], "961": ["n07860988", "dough"], "962": ["n07871810", "meat_loaf"], "963": ["n07873807", "pizza"], "964": ["n07875152", "potpie"], "965": ["n07880968", "burrito"], "966": ["n07892512", "red_wine"], "967": ["n07920052", "espresso"], "968": ["n07930864", "cup"], "969": ["n07932039", "eggnog"], "970": ["n09193705", "alp"], "971": ["n09229709", "bubble"], "972": ["n09246464", "cliff"], "973": ["n09256479", "coral_reef"], "974": ["n09288635", "geyser"], "975": ["n09332890", "lakeside"], "976": ["n09399592", "promontory"], "977": ["n09421951", "sandbar"], "978": ["n09428293", "seashore"], "979": ["n09468604", "valley"], "980": ["n09472597", "volcano"], "981": ["n09835506", "ballplayer"], "982": ["n10148035", "groom"], "983": ["n10565667", "scuba_diver"], "984": ["n11879895", "rapeseed"], "985": ["n11939491", "daisy"], "986": ["n12057211", "yellow_lady's_slipper"], "987": ["n12144580", "corn"], "988": ["n12267677", "acorn"], "989": ["n12620546", "hip"], "990": ["n12768682", "buckeye"], "991": ["n12985857", "coral_fungus"], "992": ["n12998815", "agaric"], "993": ["n13037406", "gyromitra"], "994": ["n13040303", "stinkhorn"], "995": ["n13044778", "earthstar"], "996": ["n13052670", "hen-of-the-woods"], "997": ["n13054560", "bolete"], "998": ["n13133613", "ear"], "999": ["n15075141", "toilet_tissue"]}
autoencoder.py ADDED
@@ -0,0 +1,522 @@
1
+ # Borrowed from U-ViT: https://github.com/baofff/U-ViT/blob/main/libs/autoencoder.py.
2
+ # The original code is licensed under MIT License, which is can be found at licenses/LICENSE_UVIT.txt.
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import numpy as np
7
+ from einops import rearrange
8
+
9
+
10
+ class LinearAttention(nn.Module):
11
+ def __init__(self, dim, heads=4, dim_head=32):
12
+ super().__init__()
13
+ self.heads = heads
14
+ hidden_dim = dim_head * heads
15
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
16
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
17
+
18
+ def forward(self, x):
19
+ b, c, h, w = x.shape
20
+ qkv = self.to_qkv(x)
21
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
22
+ k = k.softmax(dim=-1)
23
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
24
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
25
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
26
+ return self.to_out(out)
27
+
28
+
29
+ def nonlinearity(x):
30
+ # swish
31
+ return x*torch.sigmoid(x)
32
+
33
+
34
+ def Normalize(in_channels, num_groups=32):
35
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
36
+
37
+
38
+ class Upsample(nn.Module):
39
+ def __init__(self, in_channels, with_conv):
40
+ super().__init__()
41
+ self.with_conv = with_conv
42
+ if self.with_conv:
43
+ self.conv = torch.nn.Conv2d(in_channels,
44
+ in_channels,
45
+ kernel_size=3,
46
+ stride=1,
47
+ padding=1)
48
+
49
+ def forward(self, x):
50
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
51
+ if self.with_conv:
52
+ x = self.conv(x)
53
+ return x
54
+
55
+
56
+ class Downsample(nn.Module):
57
+ def __init__(self, in_channels, with_conv):
58
+ super().__init__()
59
+ self.with_conv = with_conv
60
+ if self.with_conv:
61
+ # no asymmetric padding in torch conv, must do it ourselves
62
+ self.conv = torch.nn.Conv2d(in_channels,
63
+ in_channels,
64
+ kernel_size=3,
65
+ stride=2,
66
+ padding=0)
67
+
68
+ def forward(self, x):
69
+ if self.with_conv:
70
+ pad = (0,1,0,1)
71
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
72
+ x = self.conv(x)
73
+ else:
74
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
75
+ return x
76
+
77
+
78
+ class ResnetBlock(nn.Module):
79
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
80
+ dropout, temb_channels=512):
81
+ super().__init__()
82
+ self.in_channels = in_channels
83
+ out_channels = in_channels if out_channels is None else out_channels
84
+ self.out_channels = out_channels
85
+ self.use_conv_shortcut = conv_shortcut
86
+
87
+ self.norm1 = Normalize(in_channels)
88
+ self.conv1 = torch.nn.Conv2d(in_channels,
89
+ out_channels,
90
+ kernel_size=3,
91
+ stride=1,
92
+ padding=1)
93
+ if temb_channels > 0:
94
+ self.temb_proj = torch.nn.Linear(temb_channels,
95
+ out_channels)
96
+ self.norm2 = Normalize(out_channels)
97
+ self.dropout = torch.nn.Dropout(dropout)
98
+ self.conv2 = torch.nn.Conv2d(out_channels,
99
+ out_channels,
100
+                                      kernel_size=3,
+                                      stride=1,
+                                      padding=1)
+         if self.in_channels != self.out_channels:
+             if self.use_conv_shortcut:
+                 self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+             else:
+                 self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, x, temb):
+         h = x
+         h = self.norm1(h)
+         h = nonlinearity(h)
+         h = self.conv1(h)
+
+         if temb is not None:
+             h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+         h = self.norm2(h)
+         h = nonlinearity(h)
+         h = self.dropout(h)
+         h = self.conv2(h)
+
+         if self.in_channels != self.out_channels:
+             if self.use_conv_shortcut:
+                 x = self.conv_shortcut(x)
+             else:
+                 x = self.nin_shortcut(x)
+
+         return x + h
+
+
+ class LinAttnBlock(LinearAttention):
+     """to match AttnBlock usage"""
+     def __init__(self, in_channels):
+         super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+ class AttnBlock(nn.Module):
+     def __init__(self, in_channels):
+         super().__init__()
+         self.in_channels = in_channels
+
+         self.norm = Normalize(in_channels)
+         self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+         self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+         self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+         self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, x):
+         h_ = x
+         h_ = self.norm(h_)
+         q = self.q(h_)
+         k = self.k(h_)
+         v = self.v(h_)
+
+         # compute attention
+         b, c, h, w = q.shape
+         q = q.reshape(b, c, h * w)
+         q = q.permute(0, 2, 1)      # b,hw,c
+         k = k.reshape(b, c, h * w)  # b,c,hw
+         w_ = torch.bmm(q, k)        # b,hw,hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
+         w_ = w_ * (int(c) ** (-0.5))
+         w_ = torch.nn.functional.softmax(w_, dim=2)
+
+         # attend to values
+         v = v.reshape(b, c, h * w)
+         w_ = w_.permute(0, 2, 1)    # b,hw,hw (first hw of k, second of q)
+         h_ = torch.bmm(v, w_)       # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+         h_ = h_.reshape(b, c, h, w)
+
+         h_ = self.proj_out(h_)
+
+         return x + h_
+
+
+ def make_attn(in_channels, attn_type="vanilla"):
+     assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
+     print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+     if attn_type == "vanilla":
+         return AttnBlock(in_channels)
+     elif attn_type == "none":
+         return nn.Identity(in_channels)
+     else:
+         return LinAttnBlock(in_channels)
+
+
+ class Encoder(nn.Module):
+     def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
+                  attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                  resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+                  **ignore_kwargs):
+         super().__init__()
+         if use_linear_attn: attn_type = "linear"
+         self.ch = ch
+         self.temb_ch = 0
+         self.num_resolutions = len(ch_mult)
+         self.num_res_blocks = num_res_blocks
+         self.resolution = resolution
+         self.in_channels = in_channels
+
+         # downsampling
+         self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
+
+         curr_res = resolution
+         in_ch_mult = (1,) + tuple(ch_mult)
+         self.in_ch_mult = in_ch_mult
+         self.down = nn.ModuleList()
+         for i_level in range(self.num_resolutions):
+             block = nn.ModuleList()
+             attn = nn.ModuleList()
+             block_in = ch * in_ch_mult[i_level]
+             block_out = ch * ch_mult[i_level]
+             for i_block in range(self.num_res_blocks):
+                 block.append(ResnetBlock(in_channels=block_in, out_channels=block_out,
+                                          temb_channels=self.temb_ch, dropout=dropout))
+                 block_in = block_out
+                 if curr_res in attn_resolutions:
+                     attn.append(make_attn(block_in, attn_type=attn_type))
+             down = nn.Module()
+             down.block = block
+             down.attn = attn
+             if i_level != self.num_resolutions - 1:
+                 down.downsample = Downsample(block_in, resamp_with_conv)
+                 curr_res = curr_res // 2
+             self.down.append(down)
+
+         # middle
+         self.mid = nn.Module()
+         self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in,
+                                        temb_channels=self.temb_ch, dropout=dropout)
+         self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+         self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in,
+                                        temb_channels=self.temb_ch, dropout=dropout)
+
+         # end
+         self.norm_out = Normalize(block_in)
+         self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels if double_z else z_channels,
+                                         kernel_size=3, stride=1, padding=1)
+
+     def forward(self, x):
+         # timestep embedding
+         temb = None
+
+         # downsampling
+         hs = [self.conv_in(x)]
+         for i_level in range(self.num_resolutions):
+             for i_block in range(self.num_res_blocks):
+                 h = self.down[i_level].block[i_block](hs[-1], temb)
+                 if len(self.down[i_level].attn) > 0:
+                     h = self.down[i_level].attn[i_block](h)
+                 hs.append(h)
+             if i_level != self.num_resolutions - 1:
+                 hs.append(self.down[i_level].downsample(hs[-1]))
+
+         # middle
+         h = hs[-1]
+         h = self.mid.block_1(h, temb)
+         h = self.mid.attn_1(h)
+         h = self.mid.block_2(h, temb)
+
+         # end
+         h = self.norm_out(h)
+         h = nonlinearity(h)
+         h = self.conv_out(h)
+         return h
+
+
+ class Decoder(nn.Module):
+     def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
+                  attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+                  resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+                  attn_type="vanilla", **ignorekwargs):
+         super().__init__()
+         if use_linear_attn: attn_type = "linear"
+         self.ch = ch
+         self.temb_ch = 0
+         self.num_resolutions = len(ch_mult)
+         self.num_res_blocks = num_res_blocks
+         self.resolution = resolution
+         self.in_channels = in_channels
+         self.give_pre_end = give_pre_end
+         self.tanh_out = tanh_out
+
+         # compute in_ch_mult, block_in and curr_res at lowest res
+         in_ch_mult = (1,) + tuple(ch_mult)
+         block_in = ch * ch_mult[self.num_resolutions - 1]
+         curr_res = resolution // 2 ** (self.num_resolutions - 1)
+         self.z_shape = (1, z_channels, curr_res, curr_res)
+         print("Working with z of shape {} = {} dimensions.".format(
+             self.z_shape, np.prod(self.z_shape)))
+
+         # z to block_in
+         self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
+
+         # middle
+         self.mid = nn.Module()
+         self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in,
+                                        temb_channels=self.temb_ch, dropout=dropout)
+         self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+         self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in,
+                                        temb_channels=self.temb_ch, dropout=dropout)
+
+         # upsampling
+         self.up = nn.ModuleList()
+         for i_level in reversed(range(self.num_resolutions)):
+             block = nn.ModuleList()
+             attn = nn.ModuleList()
+             block_out = ch * ch_mult[i_level]
+             for i_block in range(self.num_res_blocks + 1):
+                 block.append(ResnetBlock(in_channels=block_in, out_channels=block_out,
+                                          temb_channels=self.temb_ch, dropout=dropout))
+                 block_in = block_out
+                 if curr_res in attn_resolutions:
+                     attn.append(make_attn(block_in, attn_type=attn_type))
+             up = nn.Module()
+             up.block = block
+             up.attn = attn
+             if i_level != 0:
+                 up.upsample = Upsample(block_in, resamp_with_conv)
+                 curr_res = curr_res * 2
+             self.up.insert(0, up)  # prepend to get consistent order
+
+         # end
+         self.norm_out = Normalize(block_in)
+         self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
+
+     def forward(self, z):
+         # assert z.shape[1:] == self.z_shape[1:]
+         self.last_z_shape = z.shape
+
+         # timestep embedding
+         temb = None
+
+         # z to block_in
+         h = self.conv_in(z)
+
+         # middle
+         h = self.mid.block_1(h, temb)
+         h = self.mid.attn_1(h)
+         h = self.mid.block_2(h, temb)
+
+         # upsampling
+         for i_level in reversed(range(self.num_resolutions)):
+             for i_block in range(self.num_res_blocks + 1):
+                 h = self.up[i_level].block[i_block](h, temb)
+                 if len(self.up[i_level].attn) > 0:
+                     h = self.up[i_level].attn[i_block](h)
+             if i_level != 0:
+                 h = self.up[i_level].upsample(h)
+
+         # end
+         if self.give_pre_end:
+             return h
+
+         h = self.norm_out(h)
+         h = nonlinearity(h)
+         h = self.conv_out(h)
+         if self.tanh_out:
+             h = torch.tanh(h)
+         return h
+
+
+ class FrozenAutoencoderKL(nn.Module):
+     def __init__(self, ddconfig, embed_dim, pretrained_path, scale_factor=0.18215):
+         super().__init__()
+         print(f'Create autoencoder with scale_factor={scale_factor}')
+         self.encoder = Encoder(**ddconfig)
+         self.decoder = Decoder(**ddconfig)
+         assert ddconfig["double_z"]
+         self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
+         self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+         self.embed_dim = embed_dim
+         self.scale_factor = scale_factor
+         m, u = self.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
+         assert len(m) == 0 and len(u) == 0
+         self.eval()
+         self.requires_grad_(False)
+
+     def encode_moments(self, x):
+         h = self.encoder(x)
+         moments = self.quant_conv(h)
+         return moments
+
+     def sample(self, moments):
+         mean, logvar = torch.chunk(moments, 2, dim=1)
+         logvar = torch.clamp(logvar, -30.0, 20.0)
+         std = torch.exp(0.5 * logvar)
+         z = mean + std * torch.randn_like(mean)
+         z = self.scale_factor * z
+         return z
+
+     def encode(self, x):
+         moments = self.encode_moments(x)
+         z = self.sample(moments)
+         return z
+
+     def decode(self, z):
+         z = (1. / self.scale_factor) * z
+         z = self.post_quant_conv(z)
+         dec = self.decoder(z)
+         return dec
+
+     def forward(self, inputs, fn):
+         if fn == 'encode_moments':
+             return self.encode_moments(inputs)
+         elif fn == 'encode':
+             return self.encode(inputs)
+         elif fn == 'decode':
+             return self.decode(inputs)
+         else:
+             raise NotImplementedError
+
+
+ def get_model(pretrained_path, scale_factor=0.18215):
+     ddconfig = dict(
+         double_z=True,
+         z_channels=4,
+         resolution=256,
+         in_channels=3,
+         out_ch=3,
+         ch=128,
+         ch_mult=[1, 2, 4, 4],
+         num_res_blocks=2,
+         attn_resolutions=[],
+         dropout=0.0
+     )
+     return FrozenAutoencoderKL(ddconfig, 4, pretrained_path, scale_factor)
+
+
+ def main():
+     import torchvision.transforms as transforms
+     from torchvision.utils import save_image
+     import os
+     from PIL import Image
+
+     model = get_model('assets/stable-diffusion/autoencoder_kl.pth')
+     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+     model = model.to(device)
+
+     scale_factor = 0.18215
+     T = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor()])
+     path = 'imgs'
+     fnames = os.listdir(path)
+     for fname in fnames:
+         p = os.path.join(path, fname)
+         img = Image.open(p)
+         img = T(img)
+         img = img * 2. - 1
+         img = img[None, ...]
+         img = img.to(device)
+
+         # with torch.cuda.amp.autocast():
+         #     moments = model.encode_moments(img)
+         #     mean, logvar = torch.chunk(moments, 2, dim=1)
+         #     logvar = torch.clamp(logvar, -30.0, 20.0)
+         #     std = torch.exp(0.5 * logvar)
+         #     zs = [(mean + std * torch.randn_like(mean)) * scale_factor for _ in range(4)]
+         #     recons = [model.decode(z) for z in zs]
+
+         with torch.cuda.amp.autocast():
+             print('test encode & decode')
+             recons = [model.decode(model.encode(img)) for _ in range(4)]
+
+         out = torch.cat([img, *recons], dim=0)
+         out = (out + 1) * 0.5
+         save_image(out, f'recons_{fname}')
+
+
+ if __name__ == "__main__":
+     main()
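Note: a minimal round-trip sketch of the frozen KL-VAE added above, using only the `get_model` / `fn`-dispatch interface it defines. The checkpoint path is a placeholder and must point to wherever the Stable Diffusion `autoencoder_kl.pth` weights were downloaded.

    import torch
    from autoencoder import get_model

    vae = get_model('assets/stable-diffusion/autoencoder_kl.pth').cuda()  # placeholder path

    x = torch.rand(1, 3, 256, 256, device='cuda') * 2 - 1   # images are expected in [-1, 1]
    with torch.no_grad():
        z = vae(x, fn='encode')       # (1, 4, 32, 32) latent, already scaled by 0.18215
        x_rec = vae(z, fn='decode')   # reconstruction in [-1, 1]
    print(z.shape, x_rec.shape)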
checkpoints/.DS_Store ADDED
Binary file (6.15 kB). View file
 
configs/finetune/imagenet256-latent-const.yaml ADDED
@@ -0,0 +1,49 @@
+ data:
+   dataset: imagenet256-latent
+   category: lmdb
+   resolution: 32
+   num_channels: 4
+   random_flip: True
+   root: ../data/imagenet256
+   feat_path: None
+
+ model:
+   precond: edm
+   model_type: DiT-XL/2
+   in_size: 32
+   in_channels: 4
+   num_classes: 1000
+   use_decoder: True
+   ext_feature_dim: 0
+   pad_cls_token: False
+   mask_ratio: 0.0
+   mask_ratio_fn: constant
+   mask_ratio_min: 0
+   mae_loss_coef: 0.1
+   class_dropout_prob: 0.1
+
+ train:
+   tf32: True
+   amp: False
+   batchsize: 64  # batchsize per GPU
+   grad_accum: 1
+   epochs: 1000
+   lr: 0.00005
+   lr_rampup_kimg: 0
+   xflip: False
+   max_num_steps: 100_000
+
+ eval:  # FID evaluation
+   cfg_scales: [1.5]
+   batchsize: 50
+   ref_path: assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz
+
+ log:
+   log_every: 500
+   ckpt_every: 12_500
+   tag: finetune-const
+
+ wandb:
+   entity: MaskDiT
+   project: MaskDiT-ImageNet256-latent-finetune
+   group: finetune-const
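Note: the training and evaluation entry points load these YAML files with OmegaConf (see eval_latent.py further down) and read the nested keys directly; a minimal sketch of how such a config is consumed:

    from omegaconf import OmegaConf

    config = OmegaConf.load('configs/finetune/imagenet256-latent-const.yaml')
    print(config.model.model_type)   # DiT-XL/2
    print(config.model.mask_ratio)   # 0.0 -> no masking during this finetuning stage
    print(config.train.batchsize)    # per-GPU batch size (64)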
configs/finetune/imagenet256-latent-cos.yaml ADDED
@@ -0,0 +1,49 @@
+ data:
+   dataset: imagenet256-latent
+   category: lmdb
+   resolution: 32
+   num_channels: 4
+   random_flip: True
+   root: ../data/imagenet256
+   feat_path: None
+
+ model:
+   precond: edm
+   model_type: DiT-XL/2
+   in_size: 32
+   in_channels: 4
+   num_classes: 1000
+   use_decoder: True
+   ext_feature_dim: 0
+   pad_cls_token: False
+   mask_ratio: 0.5
+   mask_ratio_fn: cos4
+   mask_ratio_min: 0
+   mae_loss_coef: 0.1
+   class_dropout_prob: 0.1
+
+ train:
+   tf32: True
+   amp: False
+   batchsize: 64  # batchsize per GPU
+   grad_accum: 1
+   epochs: 1000
+   lr: 0.00005
+   lr_rampup_kimg: 0
+   xflip: False
+   max_num_steps: 100_000
+
+ eval:  # FID evaluation
+   cfg_scales: [1.5]
+   batchsize: 50
+   ref_path: assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz
+
+ log:
+   log_every: 500
+   ckpt_every: 12_500
+   tag: finetune-cos
+
+ wandb:
+   entity: MaskDiT
+   project: MaskDiT-ImageNet256-latent-finetune
+   group: finetune-cos
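Note: for a 256x256 image the KL-VAE latent is 4x32x32; with the patch size 2 implied by the "DiT-XL/2" model name (an assumption taken from the name, not from this file), the backbone sees 16x16 = 256 tokens, so a mask ratio of 0.5 removes 128 of them from the encoder. A small sanity-check sketch:

    in_size, patch_size = 32, 2                 # latent resolution / assumed DiT-XL/2 patch size
    num_tokens = (in_size // patch_size) ** 2
    mask_ratio = 0.5
    print(num_tokens, int(num_tokens * mask_ratio))  # 256 tokens, 128 masked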
configs/finetune/imagenet512-latent.yaml ADDED
@@ -0,0 +1,47 @@
+ data:
+   dataset: imagenet512-latent
+   category: lmdb
+   resolution: 64
+   num_channels: 4
+   root: ../data/imagenet512-wds
+   total_num: 1281167
+
+ model:
+   precond: edm
+   model_type: DiT-XL/2
+   in_size: 64
+   in_channels: 4
+   num_classes: 1000
+   use_decoder: True
+   ext_feature_dim: 0
+   pad_cls_token: False
+   mask_ratio: 0.0
+   mask_ratio_fn: constant
+   mask_ratio_min: 0
+   mae_loss_coef: 0.1
+   class_dropout_prob: 0.1
+
+ train:
+   tf32: True
+   amp: False
+   batchsize: 16  # batchsize per GPU
+   grad_accum: 1
+   epochs: 2000
+   lr: 0.00005
+   lr_rampup_kimg: 0
+   xflip: False
+   max_num_steps: 50_000
+
+ eval:  # FID evaluation
+   batchsize: 50
+   ref_path: assets/fid_stats/VIRTUAL_imagenet512.npz
+
+ log:
+   log_every: 100
+   ckpt_every: 10_000
+   tag: finetune-4n-wds
+
+ wandb:
+   entity: MaskDiT
+   project: MaskDiT-ImageNet512-latent-finetune
+   group: finetune-wds-4nodes
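Note: the batch size here is per GPU, so the effective global batch depends on the launch topology. With the 4-node setup suggested by the tag and 8 GPUs per node (both assumptions about the launch, not stated in this file), it would be:

    per_gpu, grad_accum = 16, 1
    nodes, gpus_per_node = 4, 8          # assumed topology
    print(per_gpu * grad_accum * nodes * gpus_per_node)  # 512 samples per optimizer step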
configs/test/maskdit-256.yaml ADDED
@@ -0,0 +1,45 @@
+ data:
+   dataset: imagenet256-latent
+   category: lmdb
+   resolution: 32
+   num_channels: 4
+   root: /imagenet_256_latent_lmdb
+   total_num: 1281167
+
+ model:
+   precond: edm
+   model_type: DiT-XL/2
+   in_size: 32
+   in_channels: 4
+   num_classes: 1000
+   use_decoder: True
+   ext_feature_dim: 0
+   pad_cls_token: False
+   mask_ratio: 0.5
+   cond_mask_ratio: 0
+   mae_loss_coef: 0.1
+   class_dropout_prob: 0.1
+
+ train:
+   tf32: False
+   amp: True
+   batchsize: 32  # batchsize per GPU
+   grad_accum: 1
+   epochs: 2800
+   lr: 0.0001
+   lr_rampup_kimg: 0
+   xflip: False
+
+ eval:  # FID evaluation
+   batchsize: 50
+   ref_path: assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz
+
+ log:
+   log_every: 500
+   ckpt_every: 50_000
+   tag: baseline
+
+ wandb:
+   entity: MaskDiT
+   project: MaskDiT-ImageNet256-latent
+   group: baseline
configs/test/maskdit-512.yaml ADDED
@@ -0,0 +1,46 @@
+ data:
+   dataset: imagenet512-latent
+   category: webdataset
+   resolution: 64
+   num_channels: 4
+   root: ../data/imagenet-wds
+   total_num: 1281167
+
+ model:
+   precond: edm
+   model_type: DiT-XL/2
+   in_size: 64
+   in_channels: 4
+   num_classes: 1000
+   use_decoder: True
+   ext_feature_dim: 0
+   pad_cls_token: False
+   mask_ratio: 0.5
+   cond_mask_ratio: 0
+   mae_loss_coef: 0.1
+   class_dropout_prob: 0.1
+
+ train:
+   tf32: False
+   amp: True
+   batchsize: 32  # batchsize per GPU
+   grad_accum: 1
+   epochs: 2800
+   lr: 0.0001
+   lr_rampup_kimg: 0
+   xflip: False
+   max_num_steps: 2000000
+
+ eval:  # FID evaluation
+   batchsize: 50
+   ref_path: assets/fid_stats/VIRTUAL_imagenet512.npz
+
+ log:
+   log_every: 500
+   ckpt_every: 50_000
+   tag: pretrain
+
+ wandb:
+   entity: MaskDiT
+   project: MaskDiT-ImageNet256-latent
+   group: pretrain
configs/train/imagenet256-latent.yaml ADDED
@@ -0,0 +1,48 @@
+ data:
+   dataset: imagenet256-latent
+   category: lmdb
+   resolution: 32
+   num_channels: 4
+   random_flip: True
+   root: ../data/imagenet256
+   feat_path: None
+
+ model:
+   precond: edm
+   model_type: DiT-XL/2
+   in_size: 32
+   in_channels: 4
+   num_classes: 1000
+   use_decoder: True
+   ext_feature_dim: 0
+   pad_cls_token: False
+   mask_ratio: 0.5
+   mask_ratio_fn: constant
+   mask_ratio_min: 0
+   mae_loss_coef: 0.1
+   class_dropout_prob: 0.1
+
+ train:
+   tf32: False
+   amp: True
+   batchsize: 128  # batchsize per GPU
+   grad_accum: 1
+   epochs: 2800
+   lr: 0.0001
+   lr_rampup_kimg: 0
+   xflip: False
+   max_num_steps: 2000000
+
+ eval:  # FID evaluation
+   batchsize: 50
+   ref_path: assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz
+
+ log:
+   log_every: 500
+   ckpt_every: 50_000
+   tag: pretrain
+
+ wandb:
+   entity: MaskDiT
+   project: MaskDiT-ImageNet256-latent-train
+   group: pretrain
configs/train/imagenet512-latent.yaml ADDED
@@ -0,0 +1,47 @@
+ data:
+   dataset: imagenet512-latent
+   category: webdataset
+   resolution: 64
+   num_channels: 4
+   root: ../data/imagenet512-wds
+   total_num: 1281167
+
+ model:
+   precond: edm
+   model_type: DiT-XL/2
+   in_size: 64
+   in_channels: 4
+   num_classes: 1000
+   use_decoder: True
+   ext_feature_dim: 0
+   pad_cls_token: False
+   mask_ratio: 0.5
+   mask_ratio_fn: constant
+   mask_ratio_min: 0
+   mae_loss_coef: 0.1
+   class_dropout_prob: 0.1
+
+ train:
+   tf32: False
+   amp: True
+   batchsize: 32  # batchsize per GPU
+   grad_accum: 1
+   epochs: 2000
+   lr: 0.0001
+   lr_rampup_kimg: 0
+   xflip: False
+   max_num_steps: 2000000
+
+ eval:  # FID evaluation
+   batchsize: 50
+   ref_path: assets/fid_stats/VIRTUAL_imagenet512.npz
+
+ log:
+   log_every: 100
+   ckpt_every: 25_000
+   tag: pretrain-4nodes
+
+ wandb:
+   entity: MaskDiT
+   project: MaskDiT-ImageNet512-latent-train
+   group: pretrain-4nodes
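Note: the in_size: 64 and in_channels: 4 used at 512x512 follow from the KL-VAE in autoencoder.py, whose ch_mult=[1, 2, 4, 4] gives three stride-2 stages (a total downsampling factor of 8) and 4 latent channels, so a 512x512 RGB image becomes a 4x64x64 latent:

    image_res, vae_downsample, z_channels = 512, 8, 4
    latent_res = image_res // vae_downsample
    print((z_channels, latent_res, latent_res))  # (4, 64, 64)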
eval_latent.py ADDED
@@ -0,0 +1,132 @@
+ # MIT License
+
+ # Copyright (c) [2023] [Anima-Lab]
+
+ from argparse import ArgumentParser
+ import os
+ from collections import OrderedDict
+ from omegaconf import OmegaConf
+
+ import torch
+
+ import accelerate
+
+ from fid import calc
+ from models.maskdit import Precond_models
+ from sample import generate_with_net
+ from utils import dist, mprint, get_ckpt_paths, Logger, parse_int_list, parse_float_none
+
+
+ # ------------------------------------------------------------
+ # Training Helper Function
+
+ @torch.no_grad()
+ def update_ema(ema_model, model, decay=0.9999):
+     """
+     Step the EMA model towards the current model.
+     """
+     ema_params = OrderedDict(ema_model.named_parameters())
+     model_params = OrderedDict(model.named_parameters())
+
+     for name, param in model_params.items():
+         # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
+         ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
+
+
+ def requires_grad(model, flag=True):
+     """
+     Set requires_grad flag for all parameters in a model.
+     """
+     for p in model.parameters():
+         p.requires_grad = flag
+
+ # ------------------------------------------------------------
+
+
+ def eval_fn(model, args, device, rank, size):
+     generate_with_net(args, model, device, rank, size)
+     dist.barrier()
+     fid = calc(args.outdir, args.ref_path, args.num_expected, args.global_seed, args.fid_batch_size)
+     mprint(f'{args.num_expected} samples generated and saved in {args.outdir}')
+     mprint(f'guidance: {args.cfg_scale} FID: {fid}')
+     dist.barrier()
+     return fid
+
+
+ def eval_loop(args):
+     config = OmegaConf.load(args.config)
+     accelerator = accelerate.Accelerator()
+
+     device = accelerator.device
+     size = accelerator.num_processes
+     rank = accelerator.process_index
+     print(f'world_size: {size}, rank: {rank}')
+     experiment_dir = args.exp_dir
+
+     if accelerator.is_main_process:
+         logger = Logger(file_name=f'{experiment_dir}/log_eval.txt', file_mode="a+", should_flush=True)
+         # setup wandb
+
+     model = Precond_models[config.model.precond](
+         img_resolution=config.model.in_size,
+         img_channels=config.model.in_channels,
+         num_classes=config.model.num_classes,
+         model_type=config.model.model_type,
+         use_decoder=config.model.use_decoder,
+         mae_loss_coef=config.model.mae_loss_coef,
+         pad_cls_token=config.model.pad_cls_token,
+     ).to(device)
+     # Note that parameter initialization is done within the model constructor
+     model.eval()
+     mprint(f"{config.model.model_type} ((use_decoder: {config.model.use_decoder})) Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
+     mprint(f'extras: {model.model.extras}, cls_token: {model.model.cls_token}')
+
+     # model = torch.compile(model)
+     # Load checkpoints
+     mprint('start evaluating...')
+
+     args.outdir = os.path.join(experiment_dir, 'fid', f'edm-steps{args.num_steps}_cfg{args.cfg_scale}')
+     os.makedirs(args.outdir, exist_ok=True)
+     ckpt = torch.load(args.ckpt, map_location=device)
+     model.load_state_dict(ckpt['ema'])
+     fid = eval_fn(model, args, device, rank, size)
+     mprint(f'FID: {fid}')
+
+     if accelerator.is_main_process:
+         logger.close()
+     accelerator.end_training()
+
+
+ if __name__ == '__main__':
+     parser = ArgumentParser('training parameters')
+     # basic config
+     parser.add_argument('--config', type=str, required=True, help='path to config file')
+
+     # training
+     parser.add_argument("--exp_dir", type=str, required=True, help='The exp directory to evaluate, it must contain a checkpoints folder')
+     parser.add_argument('--ckpt', type=str, required=True, help='path to the checkpoint')
+
+     # sampling
+     parser.add_argument('--seeds', type=parse_int_list, default='100000-149999', help='Random seeds (e.g. 1,2,5-10)')
+     parser.add_argument('--subdirs', action='store_true', help='Create subdirectory for every 1000 seeds')
+     parser.add_argument('--class_idx', type=int, default=None, help='Class label [default: random]')
+     parser.add_argument('--max_batch_size', type=int, default=50, help='Maximum batch size per GPU during sampling, must be a factor of 50k if torch.compile is used')
+     parser.add_argument("--cfg_scale", type=parse_float_none, default=None, help='None = no guidance, by default = 4.0')
+
+     parser.add_argument('--num_steps', type=int, default=40, help='Number of sampling steps')
+     parser.add_argument('--S_churn', type=int, default=0, help='Stochasticity strength')
+     parser.add_argument('--solver', type=str, default=None, choices=['euler', 'heun'], help='Ablate ODE solver')
+     parser.add_argument('--discretization', type=str, default=None, choices=['vp', 've', 'iddpm', 'edm'], help='Ablate ODE solver')
+     parser.add_argument('--schedule', type=str, default=None, choices=['vp', 've', 'linear'], help='Ablate noise schedule sigma(t)')
+     parser.add_argument('--scaling', type=str, default=None, choices=['vp', 'none'], help='Ablate signal scaling s(t)')
+     parser.add_argument('--pretrained_path', type=str, default='assets/stable_diffusion/autoencoder_kl.pth', help='Autoencoder ckpt')
+
+     parser.add_argument('--ref_path', type=str, default='assets/fid_stats/VIRTUAL_imagenet512.npz', help='Dataset reference statistics')
+     parser.add_argument('--num_expected', type=int, default=50000, help='Number of images to use')
+     parser.add_argument("--global_seed", type=int, default=0)
+     parser.add_argument('--fid_batch_size', type=int, default=128, help='Maximum batch size per GPU')
+
+     args = parser.parse_args()
+
+     torch.backends.cudnn.benchmark = True
+     eval_loop(args)
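Note: since eval_loop builds on accelerate.Accelerator, the script is normally started through accelerate. A hypothetical multi-GPU invocation using only the flags defined above (paths and experiment names are placeholders):

    accelerate launch --multi_gpu eval_latent.py \
        --config configs/test/maskdit-256.yaml \
        --exp_dir results/maskdit256 \
        --ckpt results/maskdit256/checkpoints/ckpt.pt \
        --cfg_scale 1.5 --num_steps 40 \
        --ref_path assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz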
evaluator.py ADDED
@@ -0,0 +1,695 @@
+ # MIT License
+
+ # Copyright (c) [2023] [Anima-Lab]
+
+ # This code is adapted from https://github.com/openai/guided-diffusion/blob/main/evaluations/evaluator.py.
+ # The original code is licensed under a MIT License, which can be found at licenses/LICENSE_ADM.txt.
+
+
+ import argparse
+ import io
+ import os
+ import random
+ import warnings
+ import zipfile
+ from abc import ABC, abstractmethod
+ from contextlib import contextmanager
+ from functools import partial
+ from multiprocessing import cpu_count
+ from multiprocessing.pool import ThreadPool
+ from typing import Iterable, Optional, Tuple
+
+ import numpy as np
+ import requests
+ import tensorflow.compat.v1 as tf
+ from scipy import linalg
+ from tqdm.auto import tqdm
+ from PIL import Image
+
+
+ INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
+ INCEPTION_V3_PATH = "classify_image_graph_def.pb"
+
+ FID_POOL_NAME = "pool_3:0"
+ FID_SPATIAL_NAME = "mixed_6/conv:0"
+
+
+ def get_all_files(path):
+     path_list = []
+     for root, dirs, files in os.walk(path):
+         for file in files:
+             path_list.append(os.path.join(root, file))
+     return path_list
+
+
+ def isimg(filename):
+     if filename.endswith(".png") or filename.endswith(".jpg"):
+         return True
+     else:
+         return False
+
+
+ def png2npz(img_dir):
+     img_list = []
+     file_list = get_all_files(img_dir)
+     for filename in file_list:
+         if isimg(filename):
+             filepath = filename
+             img = np.asarray(Image.open(filepath).convert('RGB'))
+             img_list.append(img)
+     imgs = np.stack(img_list, axis=0)
+     npz_dir = os.path.join('tmp', 'fid')
+     os.makedirs(npz_dir, exist_ok=True)
+     npz_path = os.path.join(npz_dir, 'imgs.npz')
+     np.savez(npz_path, imgs)
+     return npz_path
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("ref_batch", help="path to reference batch npz file")
+     parser.add_argument("sample_batch", help="path to sample batch npz file")
+     args = parser.parse_args()
+
+     config = tf.ConfigProto(
+         allow_soft_placement=True  # allows DecodeJpeg to run on CPU in Inception graph
+     )
+     config.gpu_options.allow_growth = True
+     evaluator = Evaluator(tf.Session(config=config))
+
+     print("warming up TensorFlow...")
+     # This will cause TF to print a bunch of verbose stuff now rather
+     # than after the next print(), to help prevent confusion.
+     evaluator.warmup()
+
+     print("computing reference batch activations...")
+     ref_acts = evaluator.read_activations(args.ref_batch)
+     print("computing/reading reference batch statistics...")
+     ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
+
+     if os.path.isdir(args.sample_batch):
+         sample_batch = png2npz(args.sample_batch)
+     else:
+         sample_batch = args.sample_batch
+
+     print("computing sample batch activations...")
+     sample_acts = evaluator.read_activations(sample_batch)
+     print("computing/reading sample batch statistics...")
+     sample_stats, sample_stats_spatial = evaluator.read_statistics(sample_batch, sample_acts)
+
+     print("Computing evaluations...")
+     print("Inception Score:", evaluator.compute_inception_score(sample_acts[0]))
+     print("FID:", sample_stats.frechet_distance(ref_stats))
+     print("sFID:", sample_stats_spatial.frechet_distance(ref_stats_spatial))
+     prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
+     print("Precision:", prec)
+     print("Recall:", recall)
+
+
+ class InvalidFIDException(Exception):
+     pass
+
+
+ class FIDStatistics:
+     def __init__(self, mu: np.ndarray, sigma: np.ndarray):
+         self.mu = mu
+         self.sigma = sigma
+
+     def frechet_distance(self, other, eps=1e-6):
+         """
+         Compute the Frechet distance between two sets of statistics.
+         """
+         # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
+         mu1, sigma1 = self.mu, self.sigma
+         mu2, sigma2 = other.mu, other.sigma
+
+         mu1 = np.atleast_1d(mu1)
+         mu2 = np.atleast_1d(mu2)
+
+         sigma1 = np.atleast_2d(sigma1)
+         sigma2 = np.atleast_2d(sigma2)
+
+         assert (
+             mu1.shape == mu2.shape
+         ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
+         assert (
+             sigma1.shape == sigma2.shape
+         ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
+
+         diff = mu1 - mu2
+
+         # product might be almost singular
+         covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+         if not np.isfinite(covmean).all():
+             msg = (
+                 "fid calculation produces singular product; adding %s to diagonal of cov estimates"
+                 % eps
+             )
+             warnings.warn(msg)
+             offset = np.eye(sigma1.shape[0]) * eps
+             covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+         # numerical error might give slight imaginary component
+         if np.iscomplexobj(covmean):
+             if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+                 m = np.max(np.abs(covmean.imag))
+                 raise ValueError("Imaginary component {}".format(m))
+             covmean = covmean.real
+
+         tr_covmean = np.trace(covmean)
+
+         return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
+
+
+ class Evaluator:
+     def __init__(
+         self,
+         session,
+         batch_size=64,
+         softmax_batch_size=512,
+     ):
+         self.sess = session
+         self.batch_size = batch_size
+         self.softmax_batch_size = softmax_batch_size
+         self.manifold_estimator = ManifoldEstimator(session)
+         with self.sess.graph.as_default():
+             self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
+             self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
+             self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
+             self.softmax = _create_softmax_graph(self.softmax_input)
+
+     def warmup(self):
+         self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
+
+     def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
+         with open_npz_array(npz_path, "arr_0") as reader:
+             return self.compute_activations(reader.read_batches(self.batch_size))
+
+     def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
+         """
+         Compute image features for downstream evals.
+         :param batches: a iterator over NHWC numpy arrays in [0, 255].
+         :return: a tuple of numpy arrays of shape [N x X], where X is a feature
+                  dimension. The tuple is (pool_3, spatial).
+         """
+         preds = []
+         spatial_preds = []
+         for batch in tqdm(batches):
+             batch = batch.astype(np.float32)
+             pred, spatial_pred = self.sess.run(
+                 [self.pool_features, self.spatial_features], {self.image_input: batch}
+             )
+             preds.append(pred.reshape([pred.shape[0], -1]))
+             spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
+         return (
+             np.concatenate(preds, axis=0),
+             np.concatenate(spatial_preds, axis=0),
+         )
+
+     def read_statistics(
+         self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
+     ) -> Tuple[FIDStatistics, FIDStatistics]:
+         obj = np.load(npz_path)
+         if "mu" in list(obj.keys()):
+             return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
+                 obj["mu_s"], obj["sigma_s"]
+             )
+         return tuple(self.compute_statistics(x) for x in activations)
+
+     def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
+         mu = np.mean(activations, axis=0)
+         sigma = np.cov(activations, rowvar=False)
+         return FIDStatistics(mu, sigma)
+
+     def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
+         softmax_out = []
+         for i in range(0, len(activations), self.softmax_batch_size):
+             acts = activations[i : i + self.softmax_batch_size]
+             softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
+         preds = np.concatenate(softmax_out, axis=0)
+         # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
+         scores = []
+         for i in range(0, len(preds), split_size):
+             part = preds[i : i + split_size]
+             kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
+             kl = np.mean(np.sum(kl, 1))
+             scores.append(np.exp(kl))
+         return float(np.mean(scores))
+
+     def compute_prec_recall(
+         self, activations_ref: np.ndarray, activations_sample: np.ndarray
+     ) -> Tuple[float, float]:
+         radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
+         radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
+         pr = self.manifold_estimator.evaluate_pr(
+             activations_ref, radii_1, activations_sample, radii_2
+         )
+         return (float(pr[0][0]), float(pr[1][0]))
+
+
+ class ManifoldEstimator:
+     """
+     A helper for comparing manifolds of feature vectors.
+     Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
+     """
+
+     def __init__(
+         self,
+         session,
+         row_batch_size=10000,
+         col_batch_size=10000,
+         nhood_sizes=(3,),
+         clamp_to_percentile=None,
+         eps=1e-5,
+     ):
+         """
+         Estimate the manifold of given feature vectors.
+         :param session: the TensorFlow session.
+         :param row_batch_size: row batch size to compute pairwise distances
+                                (parameter to trade-off between memory usage and performance).
+         :param col_batch_size: column batch size to compute pairwise distances.
+         :param nhood_sizes: number of neighbors used to estimate the manifold.
+         :param clamp_to_percentile: prune hyperspheres that have radius larger than
+                                     the given percentile.
+         :param eps: small number for numerical stability.
+         """
+         self.distance_block = DistanceBlock(session)
+         self.row_batch_size = row_batch_size
+         self.col_batch_size = col_batch_size
+         self.nhood_sizes = nhood_sizes
+         self.num_nhoods = len(nhood_sizes)
+         self.clamp_to_percentile = clamp_to_percentile
+         self.eps = eps
+
+     def warmup(self):
+         feats, radii = (
+             np.zeros([1, 2048], dtype=np.float32),
+             np.zeros([1, 1], dtype=np.float32),
+         )
+         self.evaluate_pr(feats, radii, feats, radii)
+
+     def manifold_radii(self, features: np.ndarray) -> np.ndarray:
+         num_images = len(features)
+
+         # Estimate manifold of features by calculating distances to k-NN of each sample.
+         radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
+         distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
+         seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
+
+         for begin1 in range(0, num_images, self.row_batch_size):
+             end1 = min(begin1 + self.row_batch_size, num_images)
+             row_batch = features[begin1:end1]
+
+             for begin2 in range(0, num_images, self.col_batch_size):
+                 end2 = min(begin2 + self.col_batch_size, num_images)
+                 col_batch = features[begin2:end2]
+
+                 # Compute distances between batches.
+                 distance_batch[
+                     0 : end1 - begin1, begin2:end2
+                 ] = self.distance_block.pairwise_distances(row_batch, col_batch)
+
+             # Find the k-nearest neighbor from the current batch.
+             radii[begin1:end1, :] = np.concatenate(
+                 [
+                     x[:, self.nhood_sizes]
+                     for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
+                 ],
+                 axis=0,
+             )
+
+         if self.clamp_to_percentile is not None:
+             max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
+             radii[radii > max_distances] = 0
+         return radii
+
+     def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
+         """
+         Evaluate if new feature vectors are at the manifold.
+         """
+         num_eval_images = eval_features.shape[0]
+         num_ref_images = radii.shape[0]
+         distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
+         batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
+         max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
+         nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
+
+         for begin1 in range(0, num_eval_images, self.row_batch_size):
+             end1 = min(begin1 + self.row_batch_size, num_eval_images)
+             feature_batch = eval_features[begin1:end1]
+
+             for begin2 in range(0, num_ref_images, self.col_batch_size):
+                 end2 = min(begin2 + self.col_batch_size, num_ref_images)
+                 ref_batch = features[begin2:end2]
+
+                 distance_batch[
+                     0 : end1 - begin1, begin2:end2
+                 ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
+
+             # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
+             # If a feature vector is inside a hypersphere of some reference sample, then
+             # the new sample lies at the estimated manifold.
+             # The radii of the hyperspheres are determined from distances of neighborhood size k.
+             samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
+             batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
+
+             max_realism_score[begin1:end1] = np.max(
+                 radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
+             )
+             nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)
+
+         return {
+             "fraction": float(np.mean(batch_predictions)),
+             "batch_predictions": batch_predictions,
+             "max_realisim_score": max_realism_score,
+             "nearest_indices": nearest_indices,
+         }
+
+     def evaluate_pr(
+         self,
+         features_1: np.ndarray,
+         radii_1: np.ndarray,
+         features_2: np.ndarray,
+         radii_2: np.ndarray,
+     ) -> Tuple[np.ndarray, np.ndarray]:
+         """
+         Evaluate precision and recall efficiently.
+         :param features_1: [N1 x D] feature vectors for reference batch.
+         :param radii_1: [N1 x K1] radii for reference vectors.
+         :param features_2: [N2 x D] feature vectors for the other batch.
+         :param radii_2: [N x K2] radii for other vectors.
+         :return: a tuple of arrays for (precision, recall):
+                  - precision: an np.ndarray of length K1
+                  - recall: an np.ndarray of length K2
+         """
+         # np.bool was removed in NumPy >= 1.24; the builtin bool is equivalent here.
+         features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=bool)
+         features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=bool)
+         for begin_1 in range(0, len(features_1), self.row_batch_size):
+             end_1 = begin_1 + self.row_batch_size
+             batch_1 = features_1[begin_1:end_1]
+             for begin_2 in range(0, len(features_2), self.col_batch_size):
+                 end_2 = begin_2 + self.col_batch_size
+                 batch_2 = features_2[begin_2:end_2]
+                 batch_1_in, batch_2_in = self.distance_block.less_thans(
+                     batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
+                 )
+                 features_1_status[begin_1:end_1] |= batch_1_in
+                 features_2_status[begin_2:end_2] |= batch_2_in
+         return (
+             np.mean(features_2_status.astype(np.float64), axis=0),
+             np.mean(features_1_status.astype(np.float64), axis=0),
+         )
+
+
+ class DistanceBlock:
+     """
+     Calculate pairwise distances between vectors.
+     Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
+     """
+
+     def __init__(self, session):
+         self.session = session
+
+         # Initialize TF graph to calculate pairwise distances.
+         with session.graph.as_default():
+             self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
+             self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
+             distance_block_16 = _batch_pairwise_distances(
+                 tf.cast(self._features_batch1, tf.float16),
+                 tf.cast(self._features_batch2, tf.float16),
+             )
+             self.distance_block = tf.cond(
+                 tf.reduce_all(tf.math.is_finite(distance_block_16)),
+                 lambda: tf.cast(distance_block_16, tf.float32),
+                 lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
+             )
+
+             # Extra logic for less thans.
+             self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
+             self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
+             dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
+             self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
+             self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)
+
+     def pairwise_distances(self, U, V):
+         """
+         Evaluate pairwise distances between two batches of feature vectors.
+         """
+         return self.session.run(
+             self.distance_block,
+             feed_dict={self._features_batch1: U, self._features_batch2: V},
+         )
+
+     def less_thans(self, batch_1, radii_1, batch_2, radii_2):
+         return self.session.run(
+             [self._batch_1_in, self._batch_2_in],
+             feed_dict={
+                 self._features_batch1: batch_1,
+                 self._features_batch2: batch_2,
+                 self._radii1: radii_1,
+                 self._radii2: radii_2,
+             },
+         )
+
+
+ def _batch_pairwise_distances(U, V):
+     """
+     Compute pairwise distances between two batches of feature vectors.
+     """
+     with tf.variable_scope("pairwise_dist_block"):
+         # Squared norms of each row in U and V.
+         norm_u = tf.reduce_sum(tf.square(U), 1)
+         norm_v = tf.reduce_sum(tf.square(V), 1)
+
+         # norm_u as a column and norm_v as a row vectors.
+         norm_u = tf.reshape(norm_u, [-1, 1])
+         norm_v = tf.reshape(norm_v, [1, -1])
+
+         # Pairwise squared Euclidean distances.
+         D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
+
+     return D
+
+
+ class NpzArrayReader(ABC):
+     @abstractmethod
+     def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
+         pass
+
+     @abstractmethod
+     def remaining(self) -> int:
+         pass
+
+     def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
+         def gen_fn():
+             while True:
+                 batch = self.read_batch(batch_size)
+                 if batch is None:
+                     break
+                 yield batch
+
+         rem = self.remaining()
+         num_batches = rem // batch_size + int(rem % batch_size != 0)
+         return BatchIterator(gen_fn, num_batches)
+
+
+ class BatchIterator:
+     def __init__(self, gen_fn, length):
+         self.gen_fn = gen_fn
+         self.length = length
+
+     def __len__(self):
+         return self.length
+
+     def __iter__(self):
+         return self.gen_fn()
+
+
+ class StreamingNpzArrayReader(NpzArrayReader):
+     def __init__(self, arr_f, shape, dtype):
+         self.arr_f = arr_f
+         self.shape = shape
+         self.dtype = dtype
+         self.idx = 0
+
+     def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
+         if self.idx >= self.shape[0]:
+             return None
+
+         bs = min(batch_size, self.shape[0] - self.idx)
+         self.idx += bs
+
+         if self.dtype.itemsize == 0:
+             return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
+
+         read_count = bs * np.prod(self.shape[1:])
+         read_size = int(read_count * self.dtype.itemsize)
+         data = _read_bytes(self.arr_f, read_size, "array data")
+         return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
+
+     def remaining(self) -> int:
+         return max(0, self.shape[0] - self.idx)
+
+
+ class MemoryNpzArrayReader(NpzArrayReader):
+     def __init__(self, arr):
+         self.arr = arr
+         self.idx = 0
+
+     @classmethod
+     def load(cls, path: str, arr_name: str):
+         with open(path, "rb") as f:
+             arr = np.load(f)[arr_name]
+         return cls(arr)
+
+     def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
+         if self.idx >= self.arr.shape[0]:
+             return None
+
+         res = self.arr[self.idx : self.idx + batch_size]
+         self.idx += batch_size
+         return res
+
+     def remaining(self) -> int:
+         return max(0, self.arr.shape[0] - self.idx)
+
+
+ @contextmanager
+ def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
+     with _open_npy_file(path, arr_name) as arr_f:
+         version = np.lib.format.read_magic(arr_f)
+         if version == (1, 0):
+             header = np.lib.format.read_array_header_1_0(arr_f)
+         elif version == (2, 0):
+             header = np.lib.format.read_array_header_2_0(arr_f)
+         else:
+             yield MemoryNpzArrayReader.load(path, arr_name)
+             return
+         shape, fortran, dtype = header
+         if fortran or dtype.hasobject:
+             yield MemoryNpzArrayReader.load(path, arr_name)
+         else:
+             yield StreamingNpzArrayReader(arr_f, shape, dtype)
+
+
+ def _read_bytes(fp, size, error_template="ran out of data"):
+     """
+     Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
+     Read from file-like object until size bytes are read.
+     Raises ValueError if EOF is encountered before size bytes are read.
+     Non-blocking objects only supported if they derive from io objects.
+     Required as e.g. ZipExtFile in python 2.6 can return less data than
+     requested.
+     """
+     data = bytes()
+     while True:
+         # io files (default in python3) return None or raise on
+         # would-block, python2 file will truncate, probably nothing can be
+         # done about that. note that regular files can't be non-blocking
+         try:
+             r = fp.read(size - len(data))
+             data += r
+             if len(r) == 0 or len(data) == size:
+                 break
+         except io.BlockingIOError:
+             pass
+     if len(data) != size:
+         msg = "EOF: reading %s, expected %d bytes got %d"
+         raise ValueError(msg % (error_template, size, len(data)))
+     else:
+         return data
+
+
+ @contextmanager
+ def _open_npy_file(path: str, arr_name: str):
+     with open(path, "rb") as f:
+         with zipfile.ZipFile(f, "r") as zip_f:
+             if f"{arr_name}.npy" not in zip_f.namelist():
+                 raise ValueError(f"missing {arr_name} in npz file")
+             with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
+                 yield arr_f
+
+
+ def _download_inception_model():
+     if os.path.exists(INCEPTION_V3_PATH):
+         return
+     print("downloading InceptionV3 model...")
+     with requests.get(INCEPTION_V3_URL, stream=True) as r:
+         r.raise_for_status()
+         tmp_path = INCEPTION_V3_PATH + ".tmp"
+         with open(tmp_path, "wb") as f:
+             for chunk in tqdm(r.iter_content(chunk_size=8192)):
+                 f.write(chunk)
+         os.rename(tmp_path, INCEPTION_V3_PATH)
+
+
+ def _create_feature_graph(input_batch):
+     _download_inception_model()
+     prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
+     with open(INCEPTION_V3_PATH, "rb") as f:
+         graph_def = tf.GraphDef()
+         graph_def.ParseFromString(f.read())
+     pool3, spatial = tf.import_graph_def(
+         graph_def,
+         input_map={f"ExpandDims:0": input_batch},
+         return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
+         name=prefix,
+     )
+     _update_shapes(pool3)
+     spatial = spatial[..., :7]
+     return pool3, spatial
+
+
+ def _create_softmax_graph(input_batch):
+     _download_inception_model()
+     prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
+     with open(INCEPTION_V3_PATH, "rb") as f:
+         graph_def = tf.GraphDef()
+         graph_def.ParseFromString(f.read())
+     (matmul,) = tf.import_graph_def(
+         graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
+     )
+     w = matmul.inputs[1]
+     logits = tf.matmul(input_batch, w)
+     return tf.nn.softmax(logits)
+
+
+ def _update_shapes(pool3):
+     # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
+     ops = pool3.graph.get_operations()
+     for op in ops:
+         for o in op.outputs:
+             shape = o.get_shape()
+             if shape._dims is not None:  # pylint: disable=protected-access
+                 # shape = [s.value for s in shape]  # TF 1.x
+                 shape = [s for s in shape]  # TF 2.x
+                 new_shape = []
+                 for j, s in enumerate(shape):
+                     if s == 1 and j == 0:
+                         new_shape.append(None)
+                     else:
+                         new_shape.append(s)
+                 o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
+     return pool3
+
+
+ def _numpy_partition(arr, kth, **kwargs):
+     num_workers = min(cpu_count(), len(arr))
+     chunk_size = len(arr) // num_workers
+     extra = len(arr) % num_workers
+
+     start_idx = 0
+     batches = []
+     for i in range(num_workers):
+         size = chunk_size + (1 if i < extra else 0)
+         batches.append(arr[start_idx : start_idx + size])
+         start_idx += size
+
+     with ThreadPool(num_workers) as pool:
+         return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
+
+
+ if __name__ == "__main__":
+     main()
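Note: this script reports Inception Score, FID, sFID, Precision, and Recall against an ADM-style reference batch. It takes two positional arguments; the sample batch may be either an .npz file or a directory of PNG/JPG samples, which png2npz packs into an .npz first. A hypothetical invocation (the reference-batch filename and sample directory are placeholders):

    python evaluator.py assets/fid_stats/VIRTUAL_imagenet256_labeled.npz results/maskdit256/fid/edm-steps40_cfg1.5/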
extract_latent.py ADDED
@@ -0,0 +1,114 @@
+ import argparse
+ import os
+ import time
+
+ import lmdb
+ import torch
+ from torch import nn
+
+ import torchvision.transforms as transforms
+ from torch.utils.data import DataLoader
+
+ from autoencoder import get_model
+ from train_utils.datasets import imagenet_lmdb_dataset
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--data_name', default='imagenet', type=str)
+     parser.add_argument('--data_dir', default='../datasets', type=str)
+     parser.add_argument('--ckpt', default='assets/vae/autoencoder_kl.pth', type=str, help='checkpoint path')
+     parser.add_argument('--resolution', default=512, type=int)
+     parser.add_argument('--batch_size', default=128, type=int)
+     parser.add_argument('--split', default='train', type=str)
+     parser.add_argument('--xflip', action='store_true')
+     parser.add_argument('--outdir', type=str, default='../data/imagenet512-latent', help='output directory')
+     args = parser.parse_args()
+
+     assert args.split in ['train', 'val']
+
+     transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+     ])
+
+     dataset = imagenet_lmdb_dataset(root=f'{args.data_dir}/{args.split}',
+                                     transform=transform, resolution=args.resolution)
+
+     print(f'data size: {len(dataset)}')
+
+     model = get_model(args.ckpt)
+     print(f'load vae weights from autoencoder_kl.pth')
+     model = nn.DataParallel(model)
+     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+     model.to(device)
+
+     def extract_feature():
+         outdir = f'{args.data_name}_{args.resolution}_latent_lmdb'
+         target_db_dir = os.path.join(args.outdir, outdir, args.split)
+         os.makedirs(target_db_dir, exist_ok=True)
+         target_env = lmdb.open(target_db_dir, map_size=pow(2, 40), readahead=False)
+
+         dataset_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False,
+                                     num_workers=8, pin_memory=True, persistent_workers=True)
+
+         idx = 0
+         begin = time.time()
+         print('start...')
+         for batch in dataset_loader:
+             img, label = batch
+             assert img.min() >= -1 and img.max() <= 1
+
+             img = img.to(device)
+             moments = model(img, fn='encode_moments')
+             assert moments.shape[-1] == (args.resolution // 8)
+
+             moments = moments.detach().cpu().numpy()
+             label = label.detach().cpu().numpy()
+
+             with target_env.begin(write=True) as target_txn:
+                 for moment, lb in zip(moments, label):
+                     target_txn.put(f'z-{str(idx)}'.encode('utf-8'), moment)
+                     target_txn.put(f'y-{str(idx)}'.encode('utf-8'), str(lb).encode('utf-8'))
+                     idx += 1
+
+             if idx % 5120 == 0:
+                 cur_time = time.time()
+                 print(f'saved {idx} files with {cur_time - begin}s elapsed')
+                 begin = time.time()
+
+         # idx = 1_281_167
+         if args.xflip:
+             print('starting to store the xflip latents')
+             begin = time.time()
+             for batch in dataset_loader:
+                 img, label = batch
+                 assert img.min() >= -1 and img.max() <= 1
+
+                 img = img.to(device)
+                 moments = model(img.flip(dims=[-1]), fn='encode_moments')
+
+                 moments = moments.detach().cpu().numpy()
+                 label = label.detach().cpu().numpy()
+
+                 with target_env.begin(write=True) as target_txn:
+                     for moment, lb in zip(moments, label):
+                         target_txn.put(f'z-{str(idx)}'.encode('utf-8'), moment)
+                         target_txn.put(f'y-{str(idx)}'.encode('utf-8'), str(lb).encode('utf-8'))
+                         idx += 1
+
+                 if idx % 10000 == 0:
+                     cur_time = time.time()
+                     print(f'saved {idx} files with {cur_time - begin}s elapsed')
+                     begin = time.time()
+
+         with target_env.begin(write=True) as target_txn:
+             target_txn.put('length'.encode('utf-8'), str(idx).encode('utf-8'))
+
+         print(f'[finished] saved {idx} files')
+
+     extract_feature()
+
+
+ if __name__ == "__main__":
+     main()
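Note: the script writes one LMDB record per image, keyed as z-{idx} (the 2*4-channel encoder moments) and y-{idx} (the class label as a string), plus a final length entry. A hypothetical extraction run and a read-back sketch follow; the data paths are placeholders, and the float32 dtype of the stored moments is an assumption based on the VAE running in full precision here.

    python extract_latent.py --data_dir ../data/imagenet --resolution 256 --split train --outdir ../data/imagenet256-latent

    import lmdb
    import numpy as np

    env = lmdb.open('../data/imagenet256-latent/imagenet_256_latent_lmdb/train', readonly=True, lock=False)
    with env.begin() as txn:
        buf = txn.get(b'z-0')
        label = int(txn.get(b'y-0').decode('utf-8'))
        # 8 channels = mean and logvar of the 4-channel latent; 32 = 256 // 8
        moments = np.frombuffer(buf, dtype=np.float32).reshape(8, 32, 32)
    print(moments.shape, label)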
fid.py ADDED
@@ -0,0 +1,177 @@
+ # MIT License
+
+ # Copyright (c) [2023] [Anima-Lab]
+
+ # This code is adapted from https://github.com/NVlabs/edm/blob/main/fid.py.
+ # The original code is licensed under a Creative Commons
+ # Attribution-NonCommercial-ShareAlike 4.0 International License, which can be found at licenses/LICENSE_EDM.txt.
+
+ """Script for calculating Frechet Inception Distance (FID)."""
+ import argparse
+ from multiprocessing import Process
+
+ import click
+ import tqdm
+ import pickle
+ import numpy as np
+ import scipy.linalg
+ import torch
+ import torch.distributed as dist
+ from torch.utils.data import DataLoader
+
+ from utils import *
+ from train_utils.datasets import ImageFolderDataset
+
+
+ #----------------------------------------------------------------------------
+
+ def calculate_inception_stats(
+     image_path, num_expected=None, seed=0, max_batch_size=64,
+     num_workers=3, prefetch_factor=2, device=torch.device('cuda'),
+ ):
+     # Rank 0 goes first.
+     if dist.get_rank() != 0:
+         dist.barrier()
+
+     # Load Inception-v3 model.
+     # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+     detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+     mprint('Loading Inception-v3 model...')
+     detector_kwargs = dict(return_features=True)
+     feature_dim = 2048
+     with open(detector_url, 'rb') as f:
+         detector_net = pickle.load(f).to(device)
+
+     # List images.
+     mprint(f'Loading images from "{image_path}"...')
+     dataset_obj = ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed)
+     if num_expected is not None and len(dataset_obj) < num_expected:
+         raise click.ClickException(f'Found {len(dataset_obj)} images, but expected at least {num_expected}')
+     if len(dataset_obj) < 2:
+         raise click.ClickException(f'Found {len(dataset_obj)} images, but need at least 2 to compute statistics')
+
+     # Other ranks follow.
+     if dist.get_rank() == 0:
+         dist.barrier()
+
+     # Divide images into batches.
+     num_batches = ((len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
+     all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches)
+     rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]
+     data_loader = DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor)
+
+     # Accumulate statistics.
+     mprint(f'Calculating statistics for {len(dataset_obj)} images...')
+     mu = torch.zeros([feature_dim], dtype=torch.float64, device=device)
+     sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device)
+     for images, _labels in tqdm.tqdm(data_loader, unit='batch', disable=(dist.get_rank() != 0)):
+         dist.barrier()
+         if images.shape[0] == 0:
+             continue
+         if images.shape[1] == 1:
+             images = images.repeat([1, 3, 1, 1])
+         features = detector_net(images.to(device), **detector_kwargs).to(torch.float64)
+         mu += features.sum(0)
+         sigma += features.T @ features
+
+     # Calculate grand totals.
+     dist.all_reduce(mu)
+     dist.all_reduce(sigma)
+     mu /= len(dataset_obj)
+     sigma -= mu.ger(mu) * len(dataset_obj)
+     sigma /= len(dataset_obj) - 1
+     return mu.cpu().numpy(), sigma.cpu().numpy()
+
+ #----------------------------------------------------------------------------
+
+ def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref):
+     m = np.square(mu - mu_ref).sum()
+     s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)
+     fid = m + np.trace(sigma + sigma_ref - s * 2)
+     return float(np.real(fid))
+
93
+ #----------------------------------------------------------------------------
94
+
95
+
96
+ def calc(image_path, ref_path, num_expected, seed, batch):
97
+ """Calculate FID for a given set of images."""
98
+ if dist.get_rank() == 0:
99
+ logger = Logger(file_name=f'{image_path}/log_fid.txt', file_mode="a+", should_flush=True)
100
+
101
+ mprint(f'Loading dataset reference statistics from "{ref_path}"...')
102
+ ref = None
103
+ if dist.get_rank() == 0:
104
+ assert ref_path.endswith('.npz')
105
+ ref = dict(np.load(ref_path))
106
+
107
+ mu, sigma = calculate_inception_stats(image_path=image_path, num_expected=num_expected, seed=seed, max_batch_size=batch)
108
+ mprint('Calculating FID...')
109
+ fid = None
110
+ if dist.get_rank() == 0:
111
+ fid = calculate_fid_from_inception_stats(mu, sigma, ref['mu'], ref['sigma'])
112
+ print(f'{fid:g}')
113
+
114
+ dist.barrier()
115
+ if dist.get_rank() == 0:
116
+ logger.close()
117
+
118
+ return fid
119
+
120
+ #----------------------------------------------------------------------------
121
+
122
+
123
+ def ref(dataset_path, dest_path, batch):
124
+ """Calculate dataset reference statistics needed by 'calc'."""
125
+
126
+ mu, sigma = calculate_inception_stats(image_path=dataset_path, max_batch_size=batch)
127
+ mprint(f'Saving dataset reference statistics to "{dest_path}"...')
128
+ if dist.get_rank() == 0:
129
+ if os.path.dirname(dest_path):
130
+ os.makedirs(os.path.dirname(dest_path), exist_ok=True)
131
+ np.savez(dest_path, mu=mu, sigma=sigma)
132
+
133
+ dist.barrier()
134
+ mprint('Done.')
135
+
136
+
137
+ if __name__ == '__main__':
138
+ parser = argparse.ArgumentParser('fid parameters')
139
+
140
+ # ddp
141
+ parser.add_argument('--num_proc_node', type=int, default=1, help='The number of nodes in multi node env.')
142
+ parser.add_argument('--num_process_per_node', type=int, default=1, help='number of gpus')
143
+ parser.add_argument('--node_rank', type=int, default=0, help='The index of node.')
144
+ parser.add_argument('--local_rank', type=int, default=0, help='rank of process in the node')
145
+ parser.add_argument('--master_address', type=str, default='localhost', help='address for master')
146
+
147
+ # fid
+     parser.add_argument('--mode', type=str, required=True, choices=['calc', 'ref'], help='Calculate FID or store reference statistics')
+     parser.add_argument('--image_path', type=str, required=True, help='Path to the images')
+     parser.add_argument('--ref_path', type=str, default='assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz', help='Dataset reference statistics')
+     parser.add_argument('--num_expected', type=int, default=50000, help='Number of images to use')
+     parser.add_argument('--seed', type=int, default=0, help='Random seed for selecting the images')
+     parser.add_argument('--batch', type=int, default=64, help='Maximum batch size per GPU')
+
+     args = parser.parse_args()
+     args.global_size = args.num_proc_node * args.num_process_per_node
+     size = args.num_process_per_node
+
+     # Select the worker function according to the requested mode.
+     if args.mode == 'calc':
+         func = lambda args: calc(args.image_path, args.ref_path, args.num_expected, args.seed, args.batch)
+     else:
+         func = lambda args: ref(args.image_path, args.ref_path, args.batch)
+
+     if size > 1:
+         processes = []
+         for rank in range(size):
+             args.local_rank = rank
+             args.global_rank = rank + args.node_rank * args.num_process_per_node
+             p = Process(target=init_processes, args=(func, args))
+             p.start()
+             processes.append(p)
+
+         for p in processes:
+             p.join()
+     else:
+         print('Single GPU run')
+         assert args.global_size == 1 and args.local_rank == 0
+         args.global_rank = 0
+         init_processes(func, args)
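calculate_fid_from_inception_stats implements FID = ||mu - mu_ref||^2 + Tr(sigma + sigma_ref - 2 (sigma sigma_ref)^(1/2)). A quick sanity check (not part of this commit; toy feature dimensions instead of the real 2048-d Inception features, assuming it is run from the repo root with its dependencies installed) is that identical statistics give a distance of roughly zero:

import numpy as np
from fid import calculate_fid_from_inception_stats

rng = np.random.default_rng(0)
feats = rng.standard_normal((1000, 16))  # toy features; the real script uses 2048-d Inception features
mu, sigma = feats.mean(axis=0), np.cov(feats, rowvar=False)
print(calculate_fid_from_inception_stats(mu, sigma, mu, sigma))  # ~0 up to numerical error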
generate.py ADDED
@@ -0,0 +1,91 @@
+ # MIT License
+
+ # Copyright (c) [2023] [Anima-Lab]
+
+
+ from argparse import ArgumentParser
+ import os
+ import json
+ from omegaconf import OmegaConf
+
+ import torch
+ from models.maskdit import Precond_models
+
+ from sample import generate_with_net
+ from utils import parse_float_none, parse_int_list, init_processes
+
+
+ def generate(args):
+     rank = args.global_rank
+     size = args.global_size
+     config = OmegaConf.load(args.config)
+     label_dict = json.load(open(args.label_dict, 'r'))
+     class_label = label_dict[str(args.class_idx)][1]
+     print(f'start sampling class {class_label}...')
+     device = torch.device('cuda')
+     # setup directory
+     sample_dir = os.path.join(args.results_dir, class_label)
+     os.makedirs(sample_dir, exist_ok=True)
+     args.outdir = sample_dir
+     # setup model
+     model = Precond_models[config.model.precond](
+         img_resolution=config.model.in_size,
+         img_channels=config.model.in_channels,
+         num_classes=config.model.num_classes,
+         model_type=config.model.model_type,
+         use_decoder=config.model.use_decoder,
+         mae_loss_coef=config.model.mae_loss_coef,
+         pad_cls_token=config.model.pad_cls_token,
+         use_encoder_feat=config.model.self_cond,
+     ).to(device)
+
+     model.eval()
+     print(f"{config.model.model_type} (use_decoder: {config.model.use_decoder}) Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
+     print(f'extras: {model.model.extras}, cls_token: {model.model.cls_token}')
+
+     model = torch.compile(model)
+     ckpt = torch.load(args.ckpt_path, map_location=device)
+     model.load_state_dict(ckpt['ema'])
+     generate_with_net(args, model, device, rank, size)
+
+     print(f'sampling class {class_label} done!')
+
+
+ if __name__ == '__main__':
+     parser = ArgumentParser('Sample from a trained model')
+     # basic config
+     parser.add_argument('--config', type=str, required=True, help='path to config file')
+     parser.add_argument('--label_dict', type=str, default='assets/imagenet_label.json', help='path to label dict')
+     parser.add_argument("--results_dir", type=str, default="samples", help='path to save samples')
+     parser.add_argument('--ckpt_path', type=str, default=None, help='path to ckpt')
+
+     # sampling
+     parser.add_argument('--seeds', type=parse_int_list, default='100-131', help='Random seeds (e.g. 1,2,5-10)')
+     parser.add_argument('--subdirs', action='store_true', help='Create subdirectory for every 1000 seeds')
+     parser.add_argument('--class_idx', type=int, default=None, help='Class label [default: random]')
+     parser.add_argument("--cfg_scale", type=parse_float_none, default=None, help='None = no guidance, by default = 4.0')
+
+     parser.add_argument('--num_steps', type=int, default=40, help='Number of sampling steps')
+     parser.add_argument('--S_churn', type=int, default=0, help='Stochasticity strength')
+     parser.add_argument('--solver', type=str, default=None, choices=['euler', 'heun'], help='Ablate ODE solver')
+     parser.add_argument('--discretization', type=str, default=None, choices=['vp', 've', 'iddpm', 'edm'], help='Ablate time step discretization')
+     parser.add_argument('--schedule', type=str, default=None, choices=['vp', 've', 'linear'], help='Ablate noise schedule sigma(t)')
+     parser.add_argument('--scaling', type=str, default=None, choices=['vp', 'none'], help='Ablate signal scaling s(t)')
+     parser.add_argument('--pretrained_path', type=str, default='assets/autoencoder_kl.pth', help='Autoencoder ckpt')
+
+     parser.add_argument('--max_batch_size', type=int, default=32, help='Maximum batch size per GPU during sampling')
+     parser.add_argument('--num_expected', type=int, default=32, help='Number of images to use')
+     parser.add_argument("--global_seed", type=int, default=0)
+     parser.add_argument('--fid_batch_size', type=int, default=32, help='Maximum batch size')
+
+     # ddp
+     parser.add_argument('--num_proc_node', type=int, default=1, help='The number of nodes in multi node env.')
+     parser.add_argument('--num_process_per_node', type=int, default=1, help='number of gpus')
+     parser.add_argument('--node_rank', type=int, default=0, help='The index of node.')
+     parser.add_argument('--local_rank', type=int, default=0, help='rank of process in the node')
+     parser.add_argument('--master_address', type=str, default='localhost', help='address for master')
+     args = parser.parse_args()
+     args.global_rank = 0
+     args.local_rank = 0
+     args.global_size = 1
+     init_processes(generate, args)
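generate.py resolves --class_idx into a human-readable folder name through assets/imagenet_label.json. A small sketch of that lookup (not part of this commit; the JSON is assumed to follow the common ImageNet index format {"207": ["wnid", "class_name"], ...}, which is what label_dict[str(idx)][1] above implies):

import json

with open('assets/imagenet_label.json', 'r') as f:
    label_dict = json.load(f)

class_idx = 207  # hypothetical class index
print(label_dict[str(class_idx)][1])  # the name generate.py uses as the sample directory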
licenses/LICENSE_ADM.txt ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2021 OpenAI
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
licenses/LICENSE_DIT.txt ADDED
@@ -0,0 +1,400 @@
1
+
2
+ Attribution-NonCommercial 4.0 International
3
+
4
+ =======================================================================
5
+
6
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
7
+ does not provide legal services or legal advice. Distribution of
8
+ Creative Commons public licenses does not create a lawyer-client or
9
+ other relationship. Creative Commons makes its licenses and related
10
+ information available on an "as-is" basis. Creative Commons gives no
11
+ warranties regarding its licenses, any material licensed under their
12
+ terms and conditions, or any related information. Creative Commons
13
+ disclaims all liability for damages resulting from their use to the
14
+ fullest extent possible.
15
+
16
+ Using Creative Commons Public Licenses
17
+
18
+ Creative Commons public licenses provide a standard set of terms and
19
+ conditions that creators and other rights holders may use to share
20
+ original works of authorship and other material subject to copyright
21
+ and certain other rights specified in the public license below. The
22
+ following considerations are for informational purposes only, are not
23
+ exhaustive, and do not form part of our licenses.
24
+
25
+ Considerations for licensors: Our public licenses are
26
+ intended for use by those authorized to give the public
27
+ permission to use material in ways otherwise restricted by
28
+ copyright and certain other rights. Our licenses are
29
+ irrevocable. Licensors should read and understand the terms
30
+ and conditions of the license they choose before applying it.
31
+ Licensors should also secure all rights necessary before
32
+ applying our licenses so that the public can reuse the
33
+ material as expected. Licensors should clearly mark any
34
+ material not subject to the license. This includes other CC-
35
+ licensed material, or material used under an exception or
36
+ limitation to copyright. More considerations for licensors:
37
+ wiki.creativecommons.org/Considerations_for_licensors
38
+
39
+ Considerations for the public: By using one of our public
40
+ licenses, a licensor grants the public permission to use the
41
+ licensed material under specified terms and conditions. If
42
+ the licensor's permission is not necessary for any reason--for
43
+ example, because of any applicable exception or limitation to
44
+ copyright--then that use is not regulated by the license. Our
45
+ licenses grant only permissions under copyright and certain
46
+ other rights that a licensor has authority to grant. Use of
47
+ the licensed material may still be restricted for other
48
+ reasons, including because others have copyright or other
49
+ rights in the material. A licensor may make special requests,
50
+ such as asking that all changes be marked or described.
51
+ Although not required by our licenses, you are encouraged to
52
+ respect those requests where reasonable. More_considerations
53
+ for the public:
54
+ wiki.creativecommons.org/Considerations_for_licensees
55
+
56
+ =======================================================================
57
+
58
+ Creative Commons Attribution-NonCommercial 4.0 International Public
59
+ License
60
+
61
+ By exercising the Licensed Rights (defined below), You accept and agree
62
+ to be bound by the terms and conditions of this Creative Commons
63
+ Attribution-NonCommercial 4.0 International Public License ("Public
64
+ License"). To the extent this Public License may be interpreted as a
65
+ contract, You are granted the Licensed Rights in consideration of Your
66
+ acceptance of these terms and conditions, and the Licensor grants You
67
+ such rights in consideration of benefits the Licensor receives from
68
+ making the Licensed Material available under these terms and
69
+ conditions.
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Adapter's License means the license You apply to Your Copyright
84
+ and Similar Rights in Your contributions to Adapted Material in
85
+ accordance with the terms and conditions of this Public License.
86
+
87
+ c. Copyright and Similar Rights means copyright and/or similar rights
88
+ closely related to copyright including, without limitation,
89
+ performance, broadcast, sound recording, and Sui Generis Database
90
+ Rights, without regard to how the rights are labeled or
91
+ categorized. For purposes of this Public License, the rights
92
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
93
+ Rights.
94
+ d. Effective Technological Measures means those measures that, in the
95
+ absence of proper authority, may not be circumvented under laws
96
+ fulfilling obligations under Article 11 of the WIPO Copyright
97
+ Treaty adopted on December 20, 1996, and/or similar international
98
+ agreements.
99
+
100
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
101
+ any other exception or limitation to Copyright and Similar Rights
102
+ that applies to Your use of the Licensed Material.
103
+
104
+ f. Licensed Material means the artistic or literary work, database,
105
+ or other material to which the Licensor applied this Public
106
+ License.
107
+
108
+ g. Licensed Rights means the rights granted to You subject to the
109
+ terms and conditions of this Public License, which are limited to
110
+ all Copyright and Similar Rights that apply to Your use of the
111
+ Licensed Material and that the Licensor has authority to license.
112
+
113
+ h. Licensor means the individual(s) or entity(ies) granting rights
114
+ under this Public License.
115
+
116
+ i. NonCommercial means not primarily intended for or directed towards
117
+ commercial advantage or monetary compensation. For purposes of
118
+ this Public License, the exchange of the Licensed Material for
119
+ other material subject to Copyright and Similar Rights by digital
120
+ file-sharing or similar means is NonCommercial provided there is
121
+ no payment of monetary compensation in connection with the
122
+ exchange.
123
+
124
+ j. Share means to provide material to the public by any means or
125
+ process that requires permission under the Licensed Rights, such
126
+ as reproduction, public display, public performance, distribution,
127
+ dissemination, communication, or importation, and to make material
128
+ available to the public including in ways that members of the
129
+ public may access the material from a place and at a time
130
+ individually chosen by them.
131
+
132
+ k. Sui Generis Database Rights means rights other than copyright
133
+ resulting from Directive 96/9/EC of the European Parliament and of
134
+ the Council of 11 March 1996 on the legal protection of databases,
135
+ as amended and/or succeeded, as well as other essentially
136
+ equivalent rights anywhere in the world.
137
+
138
+ l. You means the individual or entity exercising the Licensed Rights
139
+ under this Public License. Your has a corresponding meaning.
140
+
141
+ Section 2 -- Scope.
142
+
143
+ a. License grant.
144
+
145
+ 1. Subject to the terms and conditions of this Public License,
146
+ the Licensor hereby grants You a worldwide, royalty-free,
147
+ non-sublicensable, non-exclusive, irrevocable license to
148
+ exercise the Licensed Rights in the Licensed Material to:
149
+
150
+ a. reproduce and Share the Licensed Material, in whole or
151
+ in part, for NonCommercial purposes only; and
152
+
153
+ b. produce, reproduce, and Share Adapted Material for
154
+ NonCommercial purposes only.
155
+
156
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
157
+ Exceptions and Limitations apply to Your use, this Public
158
+ License does not apply, and You do not need to comply with
159
+ its terms and conditions.
160
+
161
+ 3. Term. The term of this Public License is specified in Section
162
+ 6(a).
163
+
164
+ 4. Media and formats; technical modifications allowed. The
165
+ Licensor authorizes You to exercise the Licensed Rights in
166
+ all media and formats whether now known or hereafter created,
167
+ and to make technical modifications necessary to do so. The
168
+ Licensor waives and/or agrees not to assert any right or
169
+ authority to forbid You from making technical modifications
170
+ necessary to exercise the Licensed Rights, including
171
+ technical modifications necessary to circumvent Effective
172
+ Technological Measures. For purposes of this Public License,
173
+ simply making modifications authorized by this Section 2(a)
174
+ (4) never produces Adapted Material.
175
+
176
+ 5. Downstream recipients.
177
+
178
+ a. Offer from the Licensor -- Licensed Material. Every
179
+ recipient of the Licensed Material automatically
180
+ receives an offer from the Licensor to exercise the
181
+ Licensed Rights under the terms and conditions of this
182
+ Public License.
183
+
184
+ b. No downstream restrictions. You may not offer or impose
185
+ any additional or different terms or conditions on, or
186
+ apply any Effective Technological Measures to, the
187
+ Licensed Material if doing so restricts exercise of the
188
+ Licensed Rights by any recipient of the Licensed
189
+ Material.
190
+
191
+ 6. No endorsement. Nothing in this Public License constitutes or
192
+ may be construed as permission to assert or imply that You
193
+ are, or that Your use of the Licensed Material is, connected
194
+ with, or sponsored, endorsed, or granted official status by,
195
+ the Licensor or others designated to receive attribution as
196
+ provided in Section 3(a)(1)(A)(i).
197
+
198
+ b. Other rights.
199
+
200
+ 1. Moral rights, such as the right of integrity, are not
201
+ licensed under this Public License, nor are publicity,
202
+ privacy, and/or other similar personality rights; however, to
203
+ the extent possible, the Licensor waives and/or agrees not to
204
+ assert any such rights held by the Licensor to the limited
205
+ extent necessary to allow You to exercise the Licensed
206
+ Rights, but not otherwise.
207
+
208
+ 2. Patent and trademark rights are not licensed under this
209
+ Public License.
210
+
211
+ 3. To the extent possible, the Licensor waives any right to
212
+ collect royalties from You for the exercise of the Licensed
213
+ Rights, whether directly or through a collecting society
214
+ under any voluntary or waivable statutory or compulsory
215
+ licensing scheme. In all other cases the Licensor expressly
216
+ reserves any right to collect such royalties, including when
217
+ the Licensed Material is used other than for NonCommercial
218
+ purposes.
219
+
220
+ Section 3 -- License Conditions.
221
+
222
+ Your exercise of the Licensed Rights is expressly made subject to the
223
+ following conditions.
224
+
225
+ a. Attribution.
226
+
227
+ 1. If You Share the Licensed Material (including in modified
228
+ form), You must:
229
+
230
+ a. retain the following if it is supplied by the Licensor
231
+ with the Licensed Material:
232
+
233
+ i. identification of the creator(s) of the Licensed
234
+ Material and any others designated to receive
235
+ attribution, in any reasonable manner requested by
236
+ the Licensor (including by pseudonym if
237
+ designated);
238
+
239
+ ii. a copyright notice;
240
+
241
+ iii. a notice that refers to this Public License;
242
+
243
+ iv. a notice that refers to the disclaimer of
244
+ warranties;
245
+
246
+ v. a URI or hyperlink to the Licensed Material to the
247
+ extent reasonably practicable;
248
+
249
+ b. indicate if You modified the Licensed Material and
250
+ retain an indication of any previous modifications; and
251
+
252
+ c. indicate the Licensed Material is licensed under this
253
+ Public License, and include the text of, or the URI or
254
+ hyperlink to, this Public License.
255
+
256
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
257
+ reasonable manner based on the medium, means, and context in
258
+ which You Share the Licensed Material. For example, it may be
259
+ reasonable to satisfy the conditions by providing a URI or
260
+ hyperlink to a resource that includes the required
261
+ information.
262
+
263
+ 3. If requested by the Licensor, You must remove any of the
264
+ information required by Section 3(a)(1)(A) to the extent
265
+ reasonably practicable.
266
+
267
+ 4. If You Share Adapted Material You produce, the Adapter's
268
+ License You apply must not prevent recipients of the Adapted
269
+ Material from complying with this Public License.
270
+
271
+ Section 4 -- Sui Generis Database Rights.
272
+
273
+ Where the Licensed Rights include Sui Generis Database Rights that
274
+ apply to Your use of the Licensed Material:
275
+
276
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
277
+ to extract, reuse, reproduce, and Share all or a substantial
278
+ portion of the contents of the database for NonCommercial purposes
279
+ only;
280
+
281
+ b. if You include all or a substantial portion of the database
282
+ contents in a database in which You have Sui Generis Database
283
+ Rights, then the database in which You have Sui Generis Database
284
+ Rights (but not its individual contents) is Adapted Material; and
285
+
286
+ c. You must comply with the conditions in Section 3(a) if You Share
287
+ all or a substantial portion of the contents of the database.
288
+
289
+ For the avoidance of doubt, this Section 4 supplements and does not
290
+ replace Your obligations under this Public License where the Licensed
291
+ Rights include other Copyright and Similar Rights.
292
+
293
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
294
+
295
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
296
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
297
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
298
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
299
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
300
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
301
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
302
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
303
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
304
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
305
+
306
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
307
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
308
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
309
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
310
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
311
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
312
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
313
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
314
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
315
+
316
+ c. The disclaimer of warranties and limitation of liability provided
317
+ above shall be interpreted in a manner that, to the extent
318
+ possible, most closely approximates an absolute disclaimer and
319
+ waiver of all liability.
320
+
321
+ Section 6 -- Term and Termination.
322
+
323
+ a. This Public License applies for the term of the Copyright and
324
+ Similar Rights licensed here. However, if You fail to comply with
325
+ this Public License, then Your rights under this Public License
326
+ terminate automatically.
327
+
328
+ b. Where Your right to use the Licensed Material has terminated under
329
+ Section 6(a), it reinstates:
330
+
331
+ 1. automatically as of the date the violation is cured, provided
332
+ it is cured within 30 days of Your discovery of the
333
+ violation; or
334
+
335
+ 2. upon express reinstatement by the Licensor.
336
+
337
+ For the avoidance of doubt, this Section 6(b) does not affect any
338
+ right the Licensor may have to seek remedies for Your violations
339
+ of this Public License.
340
+
341
+ c. For the avoidance of doubt, the Licensor may also offer the
342
+ Licensed Material under separate terms or conditions or stop
343
+ distributing the Licensed Material at any time; however, doing so
344
+ will not terminate this Public License.
345
+
346
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
347
+ License.
348
+
349
+ Section 7 -- Other Terms and Conditions.
350
+
351
+ a. The Licensor shall not be bound by any additional or different
352
+ terms or conditions communicated by You unless expressly agreed.
353
+
354
+ b. Any arrangements, understandings, or agreements regarding the
355
+ Licensed Material not stated herein are separate from and
356
+ independent of the terms and conditions of this Public License.
357
+
358
+ Section 8 -- Interpretation.
359
+
360
+ a. For the avoidance of doubt, this Public License does not, and
361
+ shall not be interpreted to, reduce, limit, restrict, or impose
362
+ conditions on any use of the Licensed Material that could lawfully
363
+ be made without permission under this Public License.
364
+
365
+ b. To the extent possible, if any provision of this Public License is
366
+ deemed unenforceable, it shall be automatically reformed to the
367
+ minimum extent necessary to make it enforceable. If the provision
368
+ cannot be reformed, it shall be severed from this Public License
369
+ without affecting the enforceability of the remaining terms and
370
+ conditions.
371
+
372
+ c. No term or condition of this Public License will be waived and no
373
+ failure to comply consented to unless expressly agreed to by the
374
+ Licensor.
375
+
376
+ d. Nothing in this Public License constitutes or may be interpreted
377
+ as a limitation upon, or waiver of, any privileges and immunities
378
+ that apply to the Licensor or You, including from the legal
379
+ processes of any jurisdiction or authority.
380
+
381
+ =======================================================================
382
+
383
+ Creative Commons is not a party to its public
384
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
385
+ its public licenses to material it publishes and in those instances
386
+ will be considered the “Licensor.” The text of the Creative Commons
387
+ public licenses is dedicated to the public domain under the CC0 Public
388
+ Domain Dedication. Except for the limited purpose of indicating that
389
+ material is shared under a Creative Commons public license or as
390
+ otherwise permitted by the Creative Commons policies published at
391
+ creativecommons.org/policies, Creative Commons does not authorize the
392
+ use of the trademark "Creative Commons" or any other trademark or logo
393
+ of Creative Commons without its prior written consent including,
394
+ without limitation, in connection with any unauthorized modifications
395
+ to any of its public licenses or any other arrangements,
396
+ understandings, or agreements concerning use of licensed material. For
397
+ the avoidance of doubt, this paragraph does not form part of the
398
+ public licenses.
399
+
400
+ Creative Commons may be contacted at creativecommons.org.
licenses/LICENSE_EDM.txt ADDED
@@ -0,0 +1,439 @@
1
+ Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+
3
+ Attribution-NonCommercial-ShareAlike 4.0 International
4
+
5
+ =======================================================================
6
+
7
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
8
+ does not provide legal services or legal advice. Distribution of
9
+ Creative Commons public licenses does not create a lawyer-client or
10
+ other relationship. Creative Commons makes its licenses and related
11
+ information available on an "as-is" basis. Creative Commons gives no
12
+ warranties regarding its licenses, any material licensed under their
13
+ terms and conditions, or any related information. Creative Commons
14
+ disclaims all liability for damages resulting from their use to the
15
+ fullest extent possible.
16
+
17
+ Using Creative Commons Public Licenses
18
+
19
+ Creative Commons public licenses provide a standard set of terms and
20
+ conditions that creators and other rights holders may use to share
21
+ original works of authorship and other material subject to copyright
22
+ and certain other rights specified in the public license below. The
23
+ following considerations are for informational purposes only, are not
24
+ exhaustive, and do not form part of our licenses.
25
+
26
+ Considerations for licensors: Our public licenses are
27
+ intended for use by those authorized to give the public
28
+ permission to use material in ways otherwise restricted by
29
+ copyright and certain other rights. Our licenses are
30
+ irrevocable. Licensors should read and understand the terms
31
+ and conditions of the license they choose before applying it.
32
+ Licensors should also secure all rights necessary before
33
+ applying our licenses so that the public can reuse the
34
+ material as expected. Licensors should clearly mark any
35
+ material not subject to the license. This includes other CC-
36
+ licensed material, or material used under an exception or
37
+ limitation to copyright. More considerations for licensors:
38
+ wiki.creativecommons.org/Considerations_for_licensors
39
+
40
+ Considerations for the public: By using one of our public
41
+ licenses, a licensor grants the public permission to use the
42
+ licensed material under specified terms and conditions. If
43
+ the licensor's permission is not necessary for any reason--for
44
+ example, because of any applicable exception or limitation to
45
+ copyright--then that use is not regulated by the license. Our
46
+ licenses grant only permissions under copyright and certain
47
+ other rights that a licensor has authority to grant. Use of
48
+ the licensed material may still be restricted for other
49
+ reasons, including because others have copyright or other
50
+ rights in the material. A licensor may make special requests,
51
+ such as asking that all changes be marked or described.
52
+ Although not required by our licenses, you are encouraged to
53
+ respect those requests where reasonable. More considerations
54
+ for the public:
55
+ wiki.creativecommons.org/Considerations_for_licensees
56
+
57
+ =======================================================================
58
+
59
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
60
+ Public License
61
+
62
+ By exercising the Licensed Rights (defined below), You accept and agree
63
+ to be bound by the terms and conditions of this Creative Commons
64
+ Attribution-NonCommercial-ShareAlike 4.0 International Public License
65
+ ("Public License"). To the extent this Public License may be
66
+ interpreted as a contract, You are granted the Licensed Rights in
67
+ consideration of Your acceptance of these terms and conditions, and the
68
+ Licensor grants You such rights in consideration of benefits the
69
+ Licensor receives from making the Licensed Material available under
70
+ these terms and conditions.
71
+
72
+
73
+ Section 1 -- Definitions.
74
+
75
+ a. Adapted Material means material subject to Copyright and Similar
76
+ Rights that is derived from or based upon the Licensed Material
77
+ and in which the Licensed Material is translated, altered,
78
+ arranged, transformed, or otherwise modified in a manner requiring
79
+ permission under the Copyright and Similar Rights held by the
80
+ Licensor. For purposes of this Public License, where the Licensed
81
+ Material is a musical work, performance, or sound recording,
82
+ Adapted Material is always produced where the Licensed Material is
83
+ synched in timed relation with a moving image.
84
+
85
+ b. Adapter's License means the license You apply to Your Copyright
86
+ and Similar Rights in Your contributions to Adapted Material in
87
+ accordance with the terms and conditions of this Public License.
88
+
89
+ c. BY-NC-SA Compatible License means a license listed at
90
+ creativecommons.org/compatiblelicenses, approved by Creative
91
+ Commons as essentially the equivalent of this Public License.
92
+
93
+ d. Copyright and Similar Rights means copyright and/or similar rights
94
+ closely related to copyright including, without limitation,
95
+ performance, broadcast, sound recording, and Sui Generis Database
96
+ Rights, without regard to how the rights are labeled or
97
+ categorized. For purposes of this Public License, the rights
98
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
99
+ Rights.
100
+
101
+ e. Effective Technological Measures means those measures that, in the
102
+ absence of proper authority, may not be circumvented under laws
103
+ fulfilling obligations under Article 11 of the WIPO Copyright
104
+ Treaty adopted on December 20, 1996, and/or similar international
105
+ agreements.
106
+
107
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
108
+ any other exception or limitation to Copyright and Similar Rights
109
+ that applies to Your use of the Licensed Material.
110
+
111
+ g. License Elements means the license attributes listed in the name
112
+ of a Creative Commons Public License. The License Elements of this
113
+ Public License are Attribution, NonCommercial, and ShareAlike.
114
+
115
+ h. Licensed Material means the artistic or literary work, database,
116
+ or other material to which the Licensor applied this Public
117
+ License.
118
+
119
+ i. Licensed Rights means the rights granted to You subject to the
120
+ terms and conditions of this Public License, which are limited to
121
+ all Copyright and Similar Rights that apply to Your use of the
122
+ Licensed Material and that the Licensor has authority to license.
123
+
124
+ j. Licensor means the individual(s) or entity(ies) granting rights
125
+ under this Public License.
126
+
127
+ k. NonCommercial means not primarily intended for or directed towards
128
+ commercial advantage or monetary compensation. For purposes of
129
+ this Public License, the exchange of the Licensed Material for
130
+ other material subject to Copyright and Similar Rights by digital
131
+ file-sharing or similar means is NonCommercial provided there is
132
+ no payment of monetary compensation in connection with the
133
+ exchange.
134
+
135
+ l. Share means to provide material to the public by any means or
136
+ process that requires permission under the Licensed Rights, such
137
+ as reproduction, public display, public performance, distribution,
138
+ dissemination, communication, or importation, and to make material
139
+ available to the public including in ways that members of the
140
+ public may access the material from a place and at a time
141
+ individually chosen by them.
142
+
143
+ m. Sui Generis Database Rights means rights other than copyright
144
+ resulting from Directive 96/9/EC of the European Parliament and of
145
+ the Council of 11 March 1996 on the legal protection of databases,
146
+ as amended and/or succeeded, as well as other essentially
147
+ equivalent rights anywhere in the world.
148
+
149
+ n. You means the individual or entity exercising the Licensed Rights
150
+ under this Public License. Your has a corresponding meaning.
151
+
152
+
153
+ Section 2 -- Scope.
154
+
155
+ a. License grant.
156
+
157
+ 1. Subject to the terms and conditions of this Public License,
158
+ the Licensor hereby grants You a worldwide, royalty-free,
159
+ non-sublicensable, non-exclusive, irrevocable license to
160
+ exercise the Licensed Rights in the Licensed Material to:
161
+
162
+ a. reproduce and Share the Licensed Material, in whole or
163
+ in part, for NonCommercial purposes only; and
164
+
165
+ b. produce, reproduce, and Share Adapted Material for
166
+ NonCommercial purposes only.
167
+
168
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
169
+ Exceptions and Limitations apply to Your use, this Public
170
+ License does not apply, and You do not need to comply with
171
+ its terms and conditions.
172
+
173
+ 3. Term. The term of this Public License is specified in Section
174
+ 6(a).
175
+
176
+ 4. Media and formats; technical modifications allowed. The
177
+ Licensor authorizes You to exercise the Licensed Rights in
178
+ all media and formats whether now known or hereafter created,
179
+ and to make technical modifications necessary to do so. The
180
+ Licensor waives and/or agrees not to assert any right or
181
+ authority to forbid You from making technical modifications
182
+ necessary to exercise the Licensed Rights, including
183
+ technical modifications necessary to circumvent Effective
184
+ Technological Measures. For purposes of this Public License,
185
+ simply making modifications authorized by this Section 2(a)
186
+ (4) never produces Adapted Material.
187
+
188
+ 5. Downstream recipients.
189
+
190
+ a. Offer from the Licensor -- Licensed Material. Every
191
+ recipient of the Licensed Material automatically
192
+ receives an offer from the Licensor to exercise the
193
+ Licensed Rights under the terms and conditions of this
194
+ Public License.
195
+
196
+ b. Additional offer from the Licensor -- Adapted Material.
197
+ Every recipient of Adapted Material from You
198
+ automatically receives an offer from the Licensor to
199
+ exercise the Licensed Rights in the Adapted Material
200
+ under the conditions of the Adapter's License You apply.
201
+
202
+ c. No downstream restrictions. You may not offer or impose
203
+ any additional or different terms or conditions on, or
204
+ apply any Effective Technological Measures to, the
205
+ Licensed Material if doing so restricts exercise of the
206
+ Licensed Rights by any recipient of the Licensed
207
+ Material.
208
+
209
+ 6. No endorsement. Nothing in this Public License constitutes or
210
+ may be construed as permission to assert or imply that You
211
+ are, or that Your use of the Licensed Material is, connected
212
+ with, or sponsored, endorsed, or granted official status by,
213
+ the Licensor or others designated to receive attribution as
214
+ provided in Section 3(a)(1)(A)(i).
215
+
216
+ b. Other rights.
217
+
218
+ 1. Moral rights, such as the right of integrity, are not
219
+ licensed under this Public License, nor are publicity,
220
+ privacy, and/or other similar personality rights; however, to
221
+ the extent possible, the Licensor waives and/or agrees not to
222
+ assert any such rights held by the Licensor to the limited
223
+ extent necessary to allow You to exercise the Licensed
224
+ Rights, but not otherwise.
225
+
226
+ 2. Patent and trademark rights are not licensed under this
227
+ Public License.
228
+
229
+ 3. To the extent possible, the Licensor waives any right to
230
+ collect royalties from You for the exercise of the Licensed
231
+ Rights, whether directly or through a collecting society
232
+ under any voluntary or waivable statutory or compulsory
233
+ licensing scheme. In all other cases the Licensor expressly
234
+ reserves any right to collect such royalties, including when
235
+ the Licensed Material is used other than for NonCommercial
236
+ purposes.
237
+
238
+
239
+ Section 3 -- License Conditions.
240
+
241
+ Your exercise of the Licensed Rights is expressly made subject to the
242
+ following conditions.
243
+
244
+ a. Attribution.
245
+
246
+ 1. If You Share the Licensed Material (including in modified
247
+ form), You must:
248
+
249
+ a. retain the following if it is supplied by the Licensor
250
+ with the Licensed Material:
251
+
252
+ i. identification of the creator(s) of the Licensed
253
+ Material and any others designated to receive
254
+ attribution, in any reasonable manner requested by
255
+ the Licensor (including by pseudonym if
256
+ designated);
257
+
258
+ ii. a copyright notice;
259
+
260
+ iii. a notice that refers to this Public License;
261
+
262
+ iv. a notice that refers to the disclaimer of
263
+ warranties;
264
+
265
+ v. a URI or hyperlink to the Licensed Material to the
266
+ extent reasonably practicable;
267
+
268
+ b. indicate if You modified the Licensed Material and
269
+ retain an indication of any previous modifications; and
270
+
271
+ c. indicate the Licensed Material is licensed under this
272
+ Public License, and include the text of, or the URI or
273
+ hyperlink to, this Public License.
274
+
275
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
276
+ reasonable manner based on the medium, means, and context in
277
+ which You Share the Licensed Material. For example, it may be
278
+ reasonable to satisfy the conditions by providing a URI or
279
+ hyperlink to a resource that includes the required
280
+ information.
281
+ 3. If requested by the Licensor, You must remove any of the
282
+ information required by Section 3(a)(1)(A) to the extent
283
+ reasonably practicable.
284
+
285
+ b. ShareAlike.
286
+
287
+ In addition to the conditions in Section 3(a), if You Share
288
+ Adapted Material You produce, the following conditions also apply.
289
+
290
+ 1. The Adapter's License You apply must be a Creative Commons
291
+ license with the same License Elements, this version or
292
+ later, or a BY-NC-SA Compatible License.
293
+
294
+ 2. You must include the text of, or the URI or hyperlink to, the
295
+ Adapter's License You apply. You may satisfy this condition
296
+ in any reasonable manner based on the medium, means, and
297
+ context in which You Share Adapted Material.
298
+
299
+ 3. You may not offer or impose any additional or different terms
300
+ or conditions on, or apply any Effective Technological
301
+ Measures to, Adapted Material that restrict exercise of the
302
+ rights granted under the Adapter's License You apply.
303
+
304
+
305
+ Section 4 -- Sui Generis Database Rights.
306
+
307
+ Where the Licensed Rights include Sui Generis Database Rights that
308
+ apply to Your use of the Licensed Material:
309
+
310
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
311
+ to extract, reuse, reproduce, and Share all or a substantial
312
+ portion of the contents of the database for NonCommercial purposes
313
+ only;
314
+
315
+ b. if You include all or a substantial portion of the database
316
+ contents in a database in which You have Sui Generis Database
317
+ Rights, then the database in which You have Sui Generis Database
318
+ Rights (but not its individual contents) is Adapted Material,
319
+ including for purposes of Section 3(b); and
320
+
321
+ c. You must comply with the conditions in Section 3(a) if You Share
322
+ all or a substantial portion of the contents of the database.
323
+
324
+ For the avoidance of doubt, this Section 4 supplements and does not
325
+ replace Your obligations under this Public License where the Licensed
326
+ Rights include other Copyright and Similar Rights.
327
+
328
+
329
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
330
+
331
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
332
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
333
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
334
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
335
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
336
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
337
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
338
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
339
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
340
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
341
+
342
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
343
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
344
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
345
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
346
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
347
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
348
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
349
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
350
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
351
+
352
+ c. The disclaimer of warranties and limitation of liability provided
353
+ above shall be interpreted in a manner that, to the extent
354
+ possible, most closely approximates an absolute disclaimer and
355
+ waiver of all liability.
356
+
357
+
358
+ Section 6 -- Term and Termination.
359
+
360
+ a. This Public License applies for the term of the Copyright and
361
+ Similar Rights licensed here. However, if You fail to comply with
362
+ this Public License, then Your rights under this Public License
363
+ terminate automatically.
364
+
365
+ b. Where Your right to use the Licensed Material has terminated under
366
+ Section 6(a), it reinstates:
367
+
368
+ 1. automatically as of the date the violation is cured, provided
369
+ it is cured within 30 days of Your discovery of the
370
+ violation; or
371
+
372
+ 2. upon express reinstatement by the Licensor.
373
+
374
+ For the avoidance of doubt, this Section 6(b) does not affect any
375
+ right the Licensor may have to seek remedies for Your violations
376
+ of this Public License.
377
+
378
+ c. For the avoidance of doubt, the Licensor may also offer the
379
+ Licensed Material under separate terms or conditions or stop
380
+ distributing the Licensed Material at any time; however, doing so
381
+ will not terminate this Public License.
382
+
383
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
384
+ License.
385
+
386
+
387
+ Section 7 -- Other Terms and Conditions.
388
+
389
+ a. The Licensor shall not be bound by any additional or different
390
+ terms or conditions communicated by You unless expressly agreed.
391
+
392
+ b. Any arrangements, understandings, or agreements regarding the
393
+ Licensed Material not stated herein are separate from and
394
+ independent of the terms and conditions of this Public License.
395
+
396
+
397
+ Section 8 -- Interpretation.
398
+
399
+ a. For the avoidance of doubt, this Public License does not, and
400
+ shall not be interpreted to, reduce, limit, restrict, or impose
401
+ conditions on any use of the Licensed Material that could lawfully
402
+ be made without permission under this Public License.
403
+
404
+ b. To the extent possible, if any provision of this Public License is
405
+ deemed unenforceable, it shall be automatically reformed to the
406
+ minimum extent necessary to make it enforceable. If the provision
407
+ cannot be reformed, it shall be severed from this Public License
408
+ without affecting the enforceability of the remaining terms and
409
+ conditions.
410
+
411
+ c. No term or condition of this Public License will be waived and no
412
+ failure to comply consented to unless expressly agreed to by the
413
+ Licensor.
414
+
415
+ d. Nothing in this Public License constitutes or may be interpreted
416
+ as a limitation upon, or waiver of, any privileges and immunities
417
+ that apply to the Licensor or You, including from the legal
418
+ processes of any jurisdiction or authority.
419
+
420
+ =======================================================================
421
+
422
+ Creative Commons is not a party to its public
423
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
424
+ its public licenses to material it publishes and in those instances
425
+ will be considered the "Licensor." The text of the Creative Commons
426
+ public licenses is dedicated to the public domain under the CC0 Public
427
+ Domain Dedication. Except for the limited purpose of indicating that
428
+ material is shared under a Creative Commons public license or as
429
+ otherwise permitted by the Creative Commons policies published at
430
+ creativecommons.org/policies, Creative Commons does not authorize the
431
+ use of the trademark "Creative Commons" or any other trademark or logo
432
+ of Creative Commons without its prior written consent including,
433
+ without limitation, in connection with any unauthorized modifications
434
+ to any of its public licenses or any other arrangements,
435
+ understandings, or agreements concerning use of licensed material. For
436
+ the avoidance of doubt, this paragraph does not form part of the
437
+ public licenses.
438
+
439
+ Creative Commons may be contacted at creativecommons.org.
licenses/LICENSE_UVIT.txt ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Fan Bao
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
lmdb2wds.py ADDED
@@ -0,0 +1,39 @@
1
+ '''
2
+ This file implements the direction conversion from the latent ImageNet dataset to WebDataset.
3
+ '''
4
+ import os
5
+ from argparse import ArgumentParser
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+ import pickle
9
+
10
+ import webdataset as wds
11
+
12
+ from train_utils.datasets import ImageNetLatentDataset
13
+
14
+
15
+ def convert2wds(args):
16
+ os.makedirs(args.outdir, exist_ok=True)
17
+ wds_path = os.path.join(args.outdir, f'latent_imagenet_512_{args.split}-%04d.tar')
18
+ dataset = ImageNetLatentDataset(args.datadir, resolution=args.resolution, num_channels=args.num_channels, split=args.split)
19
+
20
+ with wds.ShardWriter(wds_path, maxcount=args.maxcount, maxsize=args.maxsize) as sink:
21
+ for i in tqdm(range(len(dataset)), dynamic_ncols=True):
22
+ if i % args.maxcount == 0:
23
+ print(f'writing to the {i // args.maxcount}th shard')
24
+ img, label = dataset[i] # C, H, W
25
+ label = np.argmax(label) # int
26
+ sink.write({'__key__': f'{i:07d}', 'latent': pickle.dumps(img), 'cls': label})
27
+
28
+
29
+ if __name__ == "__main__":
30
+ parser = ArgumentParser('Convert the latent imagenet dataset to WebDataset')
31
+ parser.add_argument('--maxcount', type=int, default=10010, help='max number of entries per shard')
32
+ parser.add_argument('--maxsize', type=int, default=10 ** 10, help='max size per shard')
33
+ parser.add_argument('--outdir', type=str, default='latent_imagenet_wds', help='path to save the converted dataset')
34
+ parser.add_argument('--datadir', type=str, default='latent_imagenet', help='path to the latent imagenet dataset')
35
+ parser.add_argument('--resolution', type=int, default=64, help='image resolution')
36
+ parser.add_argument('--num_channels', type=int, default=8, help='number of image channels')
37
+ parser.add_argument('--split', type=str, default='train', help='split of the dataset')
38
+ args = parser.parse_args()
39
+ convert2wds(args)
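Note: a minimal standalone sketch (not part of the commit) of how the shards written by convert2wds could be read back. The shard path and brace range below are assumptions; the 'latent' and 'cls' keys match the writer above.

import pickle
import webdataset as wds

shards = 'latent_imagenet_wds/latent_imagenet_512_train-{0000..0127}.tar'  # assumed shard range
for sample in wds.WebDataset(shards):
    latent = pickle.loads(sample['latent'])  # (C, H, W) latent array, as pickled by convert2wds
    label = int(sample['cls'])               # integer class id
    break                                    # inspect only the first sample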
models/maskdit.py ADDED
@@ -0,0 +1,781 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) [2023] [Anima-Lab]
4
+
5
+ # This code is adapted from https://github.com/facebookresearch/mae/blob/main/models_mae.py
6
+ # and https://github.com/facebookresearch/DiT/blob/main/models.py.
7
+ # The original code is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License,
8
+ # which can be found at licenses/LICENSE_MAE.txt and licenses/LICENSE_DIT.txt.
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import numpy as np
14
+ import math
15
+ from functools import partial
16
+ from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
17
+
18
+
19
+ def modulate(x, shift, scale):
20
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
21
+
22
+
23
+ #################################################################################
24
+ # Embedding Layers for Timesteps and Class Labels #
25
+ #################################################################################
26
+
27
+ class TimestepEmbedder(nn.Module):
28
+ """
29
+ Embeds scalar timesteps into vector representations.
30
+ """
31
+
32
+ def __init__(self, hidden_size, frequency_embedding_size=256):
33
+ super().__init__()
34
+ self.mlp = nn.Sequential(
35
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
36
+ nn.SiLU(),
37
+ nn.Linear(hidden_size, hidden_size, bias=True),
38
+ )
39
+ self.frequency_embedding_size = frequency_embedding_size
40
+
41
+ @staticmethod
42
+ def timestep_embedding(t, dim, max_period=10000):
43
+ """
44
+ Create sinusoidal timestep embeddings.
45
+ :param t: a 1-D Tensor of N indices, one per batch element.
46
+ These may be fractional.
47
+ :param dim: the dimension of the output.
48
+ :param max_period: controls the minimum frequency of the embeddings.
49
+ :return: an (N, D) Tensor of positional embeddings.
50
+ """
51
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
52
+ half = dim // 2
53
+ freqs = torch.exp(
54
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
55
+ ).to(device=t.device)
56
+ args = t[:, None].float() * freqs[None]
57
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
58
+ if dim % 2:
59
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
60
+ return embedding
61
+
62
+ def forward(self, t):
63
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
64
+ t_emb = self.mlp(t_freq)
65
+ return t_emb
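As a quick standalone sanity check of the shapes produced by the embedder above (the batch and hidden sizes are illustrative):

import torch
from models.maskdit import TimestepEmbedder

t = torch.tensor([0.5, 1.0, 80.0])                        # arbitrary (possibly fractional) noise levels
freqs = TimestepEmbedder.timestep_embedding(t, dim=256)   # (3, 256): first half cosines, second half sines
emb = TimestepEmbedder(hidden_size=1152)(t)               # (3, 1152) after the two-layer MLP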
66
+
67
+
68
+ class LabelEmbedder(nn.Module):
69
+ """
70
+ Embeds (one-hot) class labels into vector representations via a linear layer. (Unlike the original DiT, label dropout for classifier-free guidance is not applied inside this module.)
71
+ """
72
+
73
+ def __init__(self, num_classes, hidden_size, dropout_prob):
74
+ super().__init__()
75
+ self.embedding_table = nn.Linear(num_classes, hidden_size, bias=False)
76
+ self.num_classes = num_classes
77
+ self.dropout_prob = dropout_prob
78
+
79
+ def forward(self, y):
80
+ embeddings = self.embedding_table(y)
81
+ return embeddings
82
+
83
+
84
+ #################################################################################
85
+ # Token Masking and Unmasking #
86
+ #################################################################################
87
+
88
+ def get_mask(batch, length, mask_ratio, device):
89
+ """
90
+ Get the binary mask for the input sequence.
91
+ Args:
92
+ - batch: batch size
93
+ - length: sequence length
94
+ - mask_ratio: ratio of tokens to mask
95
+ return:
96
+ mask_dict with following keys:
97
+ - mask: binary mask, 0 is keep, 1 is remove
98
+ - ids_keep: indices of tokens to keep
99
+ - ids_restore: indices to restore the original order
100
+ """
101
+ len_keep = int(length * (1 - mask_ratio))
102
+ noise = torch.rand(batch, length, device=device) # noise in [0, 1]
103
+ ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
104
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
105
+ # keep the first subset
106
+ ids_keep = ids_shuffle[:, :len_keep]
107
+
108
+ mask = torch.ones([batch, length], device=device)
109
+ mask[:, :len_keep] = 0
110
+ mask = torch.gather(mask, dim=1, index=ids_restore)
111
+ return {'mask': mask,
112
+ 'ids_keep': ids_keep,
113
+ 'ids_restore': ids_restore}
114
+
115
+
116
+ def mask_out_token(x, ids_keep):
117
+ """
118
+ Keep only the tokens specified by ids_keep and drop (mask out) the rest.
119
+ Args:
120
+ - x: input sequence, [N, L, D]
121
+ - ids_keep: indices of tokens to keep
122
+ return:
123
+ - x_masked: masked sequence
124
+ """
125
+ N, L, D = x.shape # batch, length, dim
126
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
127
+ return x_masked
128
+
129
+
130
+ def mask_tokens(x, mask_ratio):
131
+ """
132
+ Perform per-sample random masking by per-sample shuffling.
133
+ Per-sample shuffling is done by argsort random noise.
134
+ x: [N, L, D], sequence
135
+ """
136
+ N, L, D = x.shape # batch, length, dim
137
+ len_keep = int(L * (1 - mask_ratio))
138
+
139
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
140
+
141
+ # sort noise for each sample
142
+ ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
143
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
144
+
145
+ # keep the first subset
146
+ ids_keep = ids_shuffle[:, :len_keep]
147
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
148
+
149
+ # generate the binary mask: 0 is keep, 1 is remove
150
+ mask = torch.ones([N, L], device=x.device)
151
+ mask[:, :len_keep] = 0
152
+ mask = torch.gather(mask, dim=1, index=ids_restore)
153
+
154
+ return x_masked, mask, ids_restore
155
+
156
+
157
+ def unmask_tokens(x, ids_restore, mask_token, extras=0):
158
+ # x: [N, T, D] if extras == 0 (i.e., no cls token) else x: [N, T+1, D]
159
+ mask_tokens = mask_token.repeat(x.shape[0], ids_restore.shape[1] + extras - x.shape[1], 1)
160
+ x_ = torch.cat([x[:, extras:, :], mask_tokens], dim=1) # no cls token
161
+ x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
162
+ x = torch.cat([x[:, :extras, :], x_], dim=1) # append cls token
163
+ return x
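A small standalone sketch of how the masking utilities above fit together (tensor sizes are illustrative):

import torch
from models.maskdit import get_mask, mask_out_token, unmask_tokens

x = torch.randn(2, 16, 8)                                       # (N, L, D) token sequence
md = get_mask(batch=2, length=16, mask_ratio=0.5, device=x.device)
x_vis = mask_out_token(x, md['ids_keep'])                       # (2, 8, 8): only the kept tokens
mask_token = torch.zeros(1, 1, 8)                               # placeholder for dropped positions
x_full = unmask_tokens(x_vis, md['ids_restore'], mask_token)    # (2, 16, 8): kept tokens restored to place
assert (md['mask'].sum(dim=1) == 8).all()                       # exactly half of the tokens are masked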
164
+
165
+
166
+ #################################################################################
167
+ # Core DiT Model #
168
+ #################################################################################
169
+
170
+ class DiTBlock(nn.Module):
171
+ """
172
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
173
+ """
174
+
175
+ def __init__(self, hidden_size, c_emb_size, num_heads, mlp_ratio=4.0, **block_kwargs):
176
+ super().__init__()
177
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
178
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
179
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
180
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
181
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
182
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
183
+ self.adaLN_modulation = nn.Sequential(
184
+ nn.SiLU(),
185
+ nn.Linear(c_emb_size, 6 * hidden_size, bias=True)
186
+ )
187
+
188
+ def forward(self, x, c):
189
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
190
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
191
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
192
+ return x
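A standalone shape sketch of the adaLN-Zero conditioning in DiTBlock (the sizes are illustrative):

import torch
from models.maskdit import DiTBlock

block = DiTBlock(384, 384, num_heads=6)   # hidden size, conditioning dim, heads
x = torch.randn(2, 256, 384)              # token sequence
c = torch.randn(2, 384)                   # pooled conditioning (timestep + label embedding)
y = block(x, c)                           # (2, 256, 384); shift/scale/gate come from a 6-way chunk of adaLN_modulation(c)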
193
+
194
+
195
+ class DecoderLayer(nn.Module):
196
+ """
197
+ The decoder input layer of MaskDiT: projects encoder features to the decoder hidden size with adaLN conditioning.
198
+ """
199
+
200
+ def __init__(self, hidden_size, decoder_hidden_size):
201
+ super().__init__()
202
+ self.norm_decoder = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
203
+ self.linear = nn.Linear(hidden_size, decoder_hidden_size, bias=True)
204
+ self.adaLN_modulation = nn.Sequential(
205
+ nn.SiLU(),
206
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
207
+ )
208
+
209
+ def forward(self, x, c):
210
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
211
+ x = modulate(self.norm_decoder(x), shift, scale)
212
+ x = self.linear(x)
213
+ return x
214
+
215
+
216
+ class FinalLayer(nn.Module):
217
+ """
218
+ The final layer of DiT.
219
+ """
220
+
221
+ def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels):
222
+ super().__init__()
223
+ self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
224
+ self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
225
+ self.adaLN_modulation = nn.Sequential(
226
+ nn.SiLU(),
227
+ nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True)
228
+ )
229
+
230
+ def forward(self, x, c):
231
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
232
+ x = modulate(self.norm_final(x), shift, scale)
233
+ x = self.linear(x)
234
+ return x
235
+
236
+
237
+ class DiT(nn.Module):
238
+ """
239
+ Diffusion model with a Transformer backbone.
240
+ """
241
+
242
+ def __init__(
243
+ self,
244
+ input_size=32,
245
+ patch_size=2,
246
+ in_channels=4,
247
+ hidden_size=1152,
248
+ depth=28,
249
+ num_heads=16,
250
+ mlp_ratio=4.0,
251
+ class_dropout_prob=0.1,
252
+ num_classes=1000, # 0 = unconditional
253
+ learn_sigma=False,
254
+ use_decoder=False, # decide whether to add a lightweight DiT decoder
255
+ mae_loss_coef=0, # 0 = no mae loss
256
+ pad_cls_token=False, # decide whether to use cls_token as the mask token for the decoder
257
+ direct_cls_token=False, # decide whether to pass cls_token directly to the decoder (False = not passed to the decoder)
258
+ ext_feature_dim=0, # decide whether to condition on external features (0 = no features)
259
+ use_encoder_feat=False, # decide whether to condition on the encoder output feature
260
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), # normalize the encoder output feature
261
+ ):
262
+ super().__init__()
263
+ self.learn_sigma = learn_sigma
264
+ self.in_channels = in_channels
265
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
266
+ self.patch_size = patch_size
267
+ self.num_heads = num_heads
268
+ self.class_dropout_prob = class_dropout_prob
269
+ self.num_classes = num_classes
270
+ self.use_decoder = use_decoder
271
+ self.mae_loss_coef = mae_loss_coef
272
+ self.pad_cls_token = pad_cls_token
273
+ self.direct_cls_token = direct_cls_token
274
+ self.ext_feature_dim = ext_feature_dim
275
+ self.use_encoder_feat = use_encoder_feat
276
+ self.feat_norm = norm_layer(hidden_size, elementwise_affine=False)
277
+
278
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size)
279
+ self.t_embedder = TimestepEmbedder(hidden_size)
280
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) if num_classes else None
281
+ num_patches = self.x_embedder.num_patches
282
+
283
+ self.cls_token = None
284
+ self.extras = 0
285
+ self.decoder_extras = 0
286
+ if self.pad_cls_token:
287
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
288
+ self.extras = 1
289
+ self.decoder_extras = 1
290
+
291
+ self.feat_embedder = None
292
+ if self.ext_feature_dim > 0:
293
+ self.feat_embedder = nn.Linear(self.ext_feature_dim, hidden_size, bias=True)
294
+
295
+ # Will use fixed sin-cos embedding:
296
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.extras, hidden_size), requires_grad=False)
297
+
298
+ self.blocks = nn.ModuleList([
299
+ DiTBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
300
+ ])
301
+
302
+ self.decoder_pos_embed = None
303
+ self.decoder_layer = None
304
+ self.decoder_blocks = None
305
+ self.mask_token = None
306
+ self.cls_token_embedder = None
307
+ self.enc_feat_embedder = None
308
+ final_hidden_size = hidden_size
309
+ if self.use_decoder:
310
+ decoder_hidden_size = 512
311
+ decoder_depth = 8
312
+ decoder_num_heads = 16
313
+ if not self.direct_cls_token:
314
+ self.decoder_extras = 0
315
+ self.decoder_pos_embed = nn.Parameter(
316
+ torch.zeros(1, num_patches + self.decoder_extras, decoder_hidden_size),
317
+ requires_grad=False)
318
+ self.decoder_layer = DecoderLayer(hidden_size, decoder_hidden_size)
319
+ self.decoder_blocks = nn.ModuleList([
320
+ DiTBlock(decoder_hidden_size, hidden_size, decoder_num_heads, mlp_ratio=mlp_ratio) for _ in
321
+ range(decoder_depth)
322
+ ])
323
+ if self.mae_loss_coef > 0:
324
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size)) # Similar to MAE
325
+ if self.pad_cls_token:
326
+ self.cls_token_embedder = nn.Linear(hidden_size, hidden_size, bias=True)
327
+ if self.use_encoder_feat:
328
+ self.enc_feat_embedder = nn.Linear(hidden_size, hidden_size, bias=True)
329
+ final_hidden_size = decoder_hidden_size
330
+
331
+ self.final_layer = FinalLayer(final_hidden_size, hidden_size, patch_size, self.out_channels)
332
+ self.initialize_weights()
333
+
334
+ def initialize_weights(self):
335
+ # Initialize transformer layers:
336
+ def _basic_init(module):
337
+ if isinstance(module, nn.Linear):
338
+ nn.init.xavier_uniform_(module.weight)
339
+ if module.bias is not None:
340
+ nn.init.constant_(module.bias, 0)
341
+
342
+ self.apply(_basic_init)
343
+
344
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
345
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5),
346
+ cls_token=self.pad_cls_token, extra_tokens=self.extras)
347
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
348
+
349
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
350
+ w = self.x_embedder.proj.weight.data
351
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
352
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
353
+
354
+ # Initialize label embedding table:
355
+ if self.y_embedder is not None:
356
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
357
+
358
+ # Initialize timestep embedding MLP:
359
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
360
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
361
+
362
+ # Initialize external feature embedding:
363
+ if self.feat_embedder is not None:
364
+ nn.init.normal_(self.feat_embedder.weight, std=0.02)
365
+
366
+ # Initialize cls token
367
+ if self.cls_token is not None:
368
+ nn.init.normal_(self.cls_token, std=.02)
369
+
370
+ # Initialize cls_token embedding:
371
+ if self.cls_token_embedder is not None:
372
+ nn.init.normal_(self.cls_token_embedder.weight, std=0.02)
373
+
374
+ # Zero-out adaLN modulation layers in DiT blocks:
375
+ for block in self.blocks:
376
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
377
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
378
+
379
+ # Zero-out output layers:
380
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
381
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
382
+ nn.init.constant_(self.final_layer.linear.weight, 0)
383
+ nn.init.constant_(self.final_layer.linear.bias, 0)
384
+
385
+ # --------------------------- decoder initialization ---------------------------
386
+ # Initialize (and freeze) decoder_pos_embed by sin-cos embedding:
387
+ if self.decoder_pos_embed is not None:
388
+ pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
389
+ int(self.x_embedder.num_patches ** 0.5),
390
+ cls_token=self.pad_cls_token, extra_tokens=self.decoder_extras)
391
+ self.decoder_pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
392
+
393
+ # Initialize mask token
394
+ if self.mae_loss_coef > 0 and self.mask_token is not None:
395
+ nn.init.normal_(self.mask_token, std=.02)
396
+
397
+ # Zero-out adaLN modulation layers in DiT decoder blocks:
398
+ if self.decoder_blocks is not None:
399
+ for block in self.decoder_blocks:
400
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
401
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
402
+
403
+ # Zero-out decoder layers (TODO: initialized the same way as the final layer; unclear whether this is ideal)
404
+ if self.decoder_layer is not None:
405
+ nn.init.constant_(self.decoder_layer.adaLN_modulation[-1].weight, 0)
406
+ nn.init.constant_(self.decoder_layer.adaLN_modulation[-1].bias, 0)
407
+ nn.init.constant_(self.decoder_layer.linear.weight, 0)
408
+ nn.init.constant_(self.decoder_layer.linear.bias, 0)
409
+ # ------------------------------------------------------------------------------
410
+
411
+ def unpatchify(self, x):
412
+ """
413
+ x: (N, L, patch_size**2 * C)
414
+ imgs: (N, C, H, W)
415
+ """
416
+ c = self.out_channels
417
+ p = self.x_embedder.patch_size[0]
418
+ h = w = int(x.shape[1] ** 0.5)
419
+ assert h * w == x.shape[1]
420
+
421
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
422
+ x = torch.einsum('nhwpqc->nchpwq', x)
423
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
424
+ return imgs
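For reference, a standalone shape check of unpatchify (the small config and sizes below are illustrative):

import torch
from models.maskdit import DiT_S_2

net = DiT_S_2(input_size=32, in_channels=4, num_classes=0)
tokens = torch.randn(1, 16 * 16, 2 * 2 * 4)   # (N, L, patch_size**2 * C) with L = (32 / 2) ** 2
imgs = net.unpatchify(tokens)                 # (1, 4, 32, 32), inverting the PatchEmbed tokenization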
425
+
426
+ def encode(self, x, t, y, mask_ratio=0, mask_dict=None, feat=None):
427
+ '''
428
+ Encode x and (t, y, feat) into a latent representation.
429
+ Return:
430
+ x_feat: feature
431
+ mask_dict with keys: 'mask', 'ids_keep', 'ids_restore'
432
+ '''
433
+ x = self.x_embedder(x) + self.pos_embed[:, self.extras:, :] # (N, T, D), where T = H * W / patch_size ** 2
434
+ if mask_ratio > 0 and mask_dict is None:
435
+ mask_dict = get_mask(x.shape[0], x.shape[1], mask_ratio, device=x.device)
436
+ if mask_ratio > 0:
437
+ ids_keep = mask_dict['ids_keep']
438
+ x = mask_out_token(x, ids_keep)
439
+ # append cls token
440
+ if self.cls_token is not None:
441
+ cls_token = self.cls_token + self.pos_embed[:, :self.extras, :]
442
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
443
+ x = torch.cat((cls_tokens, x), dim=1)
444
+ t = self.t_embedder(t) # (N, D)
445
+ c = t
446
+ if self.y_embedder is not None:
447
+ y = self.y_embedder(y) # (N, D)
448
+ c = c + y # (N, D)
449
+ assert (self.feat_embedder is None) or (self.enc_feat_embedder is None)
450
+ if self.feat_embedder is not None:
451
+ assert feat.shape[-1] == self.ext_feature_dim
452
+ feat_embed = self.feat_embedder(feat) # (N, D)
453
+ c = c + feat_embed # (N, D)
454
+ if self.enc_feat_embedder is not None and feat is not None:
455
+ assert feat.shape[-1] == c.shape[-1]
456
+ feat_embed = self.enc_feat_embedder(feat) # (N, D)
457
+ c = c + feat_embed # (N, D)
458
+
459
+ for block in self.blocks:
460
+ x = block(x, c) # (N, T, D)
461
+
462
+ x_feat = x[:, self.extras:, :].mean(dim=1) # global pool without cls token
463
+ x_feat = self.feat_norm(x_feat)
464
+ return x_feat, mask_dict
465
+
466
+
467
+ def forward_encoder(self, x, t, y, mask_ratio=0, mask_dict=None, feat=None, train=True):
468
+ '''
469
+ Encode x and (t, y, feat) into a latent representation.
470
+ Return:
471
+ - out_enc: dict, containing the following keys: x, x_feat
472
+ - c: the conditional embedding
473
+ '''
474
+ out_enc = dict()
475
+ x = self.x_embedder(x) + self.pos_embed[:, self.extras:, :] # (N, T, D), where T = H * W / patch_size ** 2
476
+ if mask_ratio > 0 and mask_dict is None:
477
+ mask_dict = get_mask(x.shape[0], x.shape[1], mask_ratio=mask_ratio, device=x.device)
478
+
479
+ if mask_ratio > 0:
480
+ ids_keep = mask_dict['ids_keep']
481
+ ids_restore = mask_dict['ids_restore']
482
+ if train:
483
+ x = mask_out_token(x, ids_keep)
484
+
485
+ # append cls token
486
+ if self.cls_token is not None:
487
+ cls_token = self.cls_token + self.pos_embed[:, :self.extras, :]
488
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
489
+ x = torch.cat((cls_tokens, x), dim=1)
490
+
491
+ t = self.t_embedder(t) # (N, D)
492
+ c = t
493
+ if self.y_embedder is not None:
494
+ y = self.y_embedder(y) # (N, D)
495
+ c = c + y # (N, D)
496
+ assert (self.feat_embedder is None) or (self.enc_feat_embedder is None)
497
+ if self.feat_embedder is not None:
498
+ assert feat.shape[-1] == self.ext_feature_dim
499
+ feat_embed = self.feat_embedder(feat) # (N, D)
500
+ c = c + feat_embed # (N, D)
501
+ if self.enc_feat_embedder is not None and feat is not None:
502
+ assert feat.shape[-1] == c.shape[-1]
503
+ feat_embed = self.enc_feat_embedder(feat) # (N, D)
504
+ c = c + feat_embed # (N, D)
505
+ for block in self.blocks:
506
+ x = block(x, c) # (N, T, D)
507
+ out_enc['x'] = x
508
+
509
+ return out_enc, c, mask_dict
510
+
511
+ def forward(self, x, t, y, mask_ratio=0, mask_dict=None, feat=None):
512
+ """
513
+ Forward pass of DiT.
514
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
515
+ t: (N,) tensor of diffusion timesteps
516
+ y: (N,) tensor of class labels
517
+ """
518
+ if not self.training and self.use_encoder_feat:
519
+ feat, _ = self.encode(x, t, y, feat=feat)
520
+ out, c, mask_dict = self.forward_encoder(x, t, y, mask_ratio=mask_ratio, mask_dict=mask_dict, feat=feat, train=self.training)
521
+ if mask_ratio > 0:
522
+ ids_keep = mask_dict['ids_keep']
523
+ ids_restore = mask_dict['ids_restore']
524
+ out['mask'] = mask_dict['mask']
525
+ else:
526
+ ids_keep = ids_restore = None
527
+ x = out['x']
528
+ # Pass to a DiT decoder (if available)
529
+ if self.use_decoder:
530
+ if self.cls_token_embedder is not None:
531
+ # cls_token_output = x[:, :self.extras, :].squeeze(dim=1).detach().clone() # stop gradient
532
+ cls_token_output = x[:, :self.extras, :].squeeze(dim=1)
533
+ cls_token_embed = self.cls_token_embedder(self.feat_norm(cls_token_output)) # normalize cls token
534
+ c = c + cls_token_embed # pad cls_token output's embedding as feature conditioning
535
+
536
+ assert self.decoder_layer is not None
537
+ diff_extras = self.extras - self.decoder_extras
538
+ x = self.decoder_layer(x[:, diff_extras:, :], c) # remove cls token (if necessary)
539
+ if self.training and mask_ratio > 0:
540
+ mask_token = self.mask_token
541
+ if mask_token is None:
542
+ mask_token = torch.zeros(1, 1, x.shape[2]).to(x) # concat zeros to match shape
543
+ x = unmask_tokens(x, ids_restore, mask_token, extras=self.decoder_extras)
544
+ assert self.decoder_pos_embed is not None
545
+ x = x + self.decoder_pos_embed
546
+ assert self.decoder_blocks is not None
547
+ for block in self.decoder_blocks:
548
+ x = block(x, c) # (N, T, D)
549
+
550
+ x = self.final_layer(x, c) # (N, T or T+1, patch_size ** 2 * out_channels)
551
+ if not self.use_decoder and (self.training and mask_ratio > 0):
552
+ mask_token = torch.zeros(1, 1, x.shape[2]).to(x) # concat zeros to match shape
553
+ x = unmask_tokens(x, ids_restore, mask_token, extras=self.extras)
554
+ x = x[:, self.decoder_extras:, :] # remove cls token (if necessary)
555
+ x = self.unpatchify(x) # (N, out_channels, H, W)
556
+ out['x'] = x
557
+ return out
558
+
559
+ def forward_with_cfg(self, x, t, y, cfg_scale, feat=None, **model_kwargs):
560
+ """
561
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
562
+ """
563
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
564
+ out = dict()
565
+
566
+ # Setup classifier-free guidance
567
+ x = torch.cat([x, x], 0)
568
+ y_null = torch.zeros_like(y)
569
+ y = torch.cat([y, y_null], 0)
570
+ if feat is not None:
571
+ feat = torch.cat([feat, feat], 0)
572
+
573
+ half = x[: len(x) // 2]
574
+ combined = torch.cat([half, half], dim=0)
575
+ assert self.num_classes and y is not None
576
+ model_out = self.forward(combined, t, y, feat=feat)['x']
577
+ # By default, classifier-free guidance is applied to all channels here. The original DiT applies it to only
578
+ # the first three channels for exact reproducibility with its released samples.
579
+ # That behavior can be restored by commenting out the following line and uncommenting the line after it.
580
+ eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
581
+ # eps, rest = model_out[:, :3], model_out[:, 3:]
582
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
583
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
584
+ half_rest = rest[: len(rest) // 2]
585
+ x = torch.cat([half_eps, half_rest], dim=1)
586
+ out['x'] = x
587
+ return out
588
+
589
+
590
+ #################################################################################
591
+ # Sine/Cosine Positional Embedding Functions #
592
+ #################################################################################
593
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
594
+
595
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=1):
596
+ """
597
+ grid_size: int of the grid height and width
598
+ return:
599
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
600
+ """
601
+ grid_h = np.arange(grid_size, dtype=np.float32)
602
+ grid_w = np.arange(grid_size, dtype=np.float32)
603
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
604
+ grid = np.stack(grid, axis=0)
605
+
606
+ grid = grid.reshape([2, 1, grid_size, grid_size])
607
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
608
+ if cls_token and extra_tokens > 0:
609
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
610
+ return pos_embed
611
+
612
+
613
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
614
+ assert embed_dim % 2 == 0
615
+
616
+ # use half of dimensions to encode grid_h
617
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
618
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
619
+
620
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
621
+ return emb
622
+
623
+
624
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
625
+ """
626
+ embed_dim: output dimension for each position
627
+ pos: a list of positions to be encoded: size (M,)
628
+ out: (M, D)
629
+ """
630
+ assert embed_dim % 2 == 0
631
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
632
+ omega /= embed_dim / 2.
633
+ omega = 1. / 10000 ** omega # (D/2,)
634
+
635
+ pos = pos.reshape(-1) # (M,)
636
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
637
+
638
+ emb_sin = np.sin(out) # (M, D/2)
639
+ emb_cos = np.cos(out) # (M, D/2)
640
+
641
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
642
+ return emb
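A standalone sketch of how these helpers build the frozen positional table (the grid size is an assumed example):

from models.maskdit import get_2d_sincos_pos_embed

pos = get_2d_sincos_pos_embed(embed_dim=384, grid_size=16)         # (256, 384), no extra tokens
pos_cls = get_2d_sincos_pos_embed(embed_dim=384, grid_size=16,
                                  cls_token=True, extra_tokens=1)  # (257, 384), first row is zeros for the cls slot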
643
+
644
+
645
+ #################################################################################
646
+ # DiT Configs #
647
+ #################################################################################
648
+
649
+ def DiT_H_2(**kwargs):
650
+ return DiT(depth=32, hidden_size=1280, patch_size=2, num_heads=16, **kwargs)
651
+
652
+
653
+ def DiT_H_4(**kwargs):
654
+ return DiT(depth=32, hidden_size=1280, patch_size=4, num_heads=16, **kwargs)
655
+
656
+
657
+ def DiT_H_8(**kwargs):
658
+ return DiT(depth=32, hidden_size=1280, patch_size=8, num_heads=16, **kwargs)
659
+
660
+
661
+ def DiT_XL_2(**kwargs):
662
+ return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
663
+
664
+
665
+ def DiT_XL_4(**kwargs):
666
+ return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
667
+
668
+
669
+ def DiT_XL_8(**kwargs):
670
+ return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
671
+
672
+
673
+ def DiT_L_2(**kwargs):
674
+ return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
675
+
676
+
677
+ def DiT_L_4(**kwargs):
678
+ return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
679
+
680
+
681
+ def DiT_L_8(**kwargs):
682
+ return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
683
+
684
+
685
+ def DiT_B_2(**kwargs):
686
+ return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
687
+
688
+
689
+ def DiT_B_4(**kwargs):
690
+ return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
691
+
692
+
693
+ def DiT_B_8(**kwargs):
694
+ return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
695
+
696
+
697
+ def DiT_S_2(**kwargs):
698
+ return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
699
+
700
+
701
+ def DiT_S_4(**kwargs):
702
+ return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
703
+
704
+
705
+ def DiT_S_8(**kwargs):
706
+ return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
707
+
708
+
709
+ DiT_models = {
710
+ 'DiT-H/2': DiT_H_2, 'DiT-H/4': DiT_H_4, 'DiT-H/8': DiT_H_8,
711
+ 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
712
+ 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
713
+ 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
714
+ 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
715
+ }
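A hedged usage sketch of the model registry above (the chosen config, batch, and labels are illustrative only):

import torch
from models.maskdit import DiT_models

model = DiT_models['DiT-S/2'](input_size=32, in_channels=4, num_classes=1000)
x = torch.randn(2, 4, 32, 32)                                                     # latent batch
t = torch.rand(2)                                                                 # noise levels
y = torch.nn.functional.one_hot(torch.tensor([3, 7]), num_classes=1000).float()   # one-hot labels
out = model(x, t, y, mask_ratio=0.5)                                              # training-style call with 50% masking
print(out['x'].shape, out['mask'].shape)  # torch.Size([2, 4, 32, 32]) torch.Size([2, 256])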
716
+
717
+
718
+ # ----------------------------------------------------------------------------
719
+ # Improved preconditioning proposed in the paper "Elucidating the Design
720
+ # Space of Diffusion-Based Generative Models" (EDM).
721
+
722
+ class EDMPrecond(nn.Module):
723
+ def __init__(self,
724
+ img_resolution, # Image resolution.
725
+ img_channels, # Number of color channels.
726
+ num_classes=0, # Number of class labels, 0 = unconditional.
727
+ sigma_min=0, # Minimum supported noise level.
728
+ sigma_max=float('inf'), # Maximum supported noise level.
729
+ sigma_data=0.5, # Expected standard deviation of the training data.
730
+ model_type='DiT-B/2', # Class name of the underlying model.
731
+ **model_kwargs, # Keyword arguments for the underlying model.
732
+ ):
733
+ super().__init__()
734
+ self.img_resolution = img_resolution
735
+ self.img_channels = img_channels
736
+ self.num_classes = num_classes
737
+ self.sigma_min = sigma_min
738
+ self.sigma_max = sigma_max
739
+ self.sigma_data = sigma_data
740
+ self.model = DiT_models[model_type](input_size=img_resolution, in_channels=img_channels,
741
+ num_classes=num_classes, **model_kwargs)
742
+
743
+ def encode(self, x, sigma, class_labels=None, **model_kwargs):
744
+
745
+ sigma = sigma.to(x.dtype).reshape(-1, 1, 1, 1)
746
+ class_labels = None if self.num_classes == 0 else \
747
+ torch.zeros([x.shape[0], self.num_classes], device=x.device) if class_labels is None else \
748
+ class_labels.to(x.dtype).reshape(-1, self.num_classes)
749
+
750
+ c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
751
+ c_noise = sigma.log() / 4
752
+
753
+ feat, mask_dict = self.model.encode((c_in * x).to(x.dtype), c_noise.flatten(), y=class_labels, **model_kwargs)
754
+ return feat
755
+
756
+ def forward(self, x, sigma, class_labels=None, cfg_scale=None, **model_kwargs):
757
+ model_fn = self.model if cfg_scale is None else partial(self.model.forward_with_cfg, cfg_scale=cfg_scale)
758
+
759
+ sigma = sigma.to(x.dtype).reshape(-1, 1, 1, 1)
760
+ class_labels = None if self.num_classes == 0 else \
761
+ torch.zeros([x.shape[0], self.num_classes], device=x.device) if class_labels is None else \
762
+ class_labels.to(x.dtype).reshape(-1, self.num_classes)
763
+
764
+ c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
765
+ c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
766
+ c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
767
+ c_noise = sigma.log() / 4
768
+
769
+ model_out = model_fn((c_in * x).to(x.dtype), c_noise.flatten(), y=class_labels, **model_kwargs)
770
+ F_x = model_out['x']
771
+ D_x = c_skip * x + c_out * F_x
772
+ model_out['x'] = D_x
773
+ return model_out
774
+
775
+ def round_sigma(self, sigma):
776
+ return torch.as_tensor(sigma)
777
+
778
+
779
+ Precond_models = {
780
+ 'edm': EDMPrecond
781
+ }
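A standalone sketch of the EDM preconditioning wrapper in use (model size and inputs are illustrative):

import torch
from models.maskdit import Precond_models

net = Precond_models['edm'](img_resolution=32, img_channels=4, num_classes=1000, model_type='DiT-S/2')
net.eval()
x = torch.randn(2, 4, 32, 32)
sigma = torch.full((2,), 1.5)
labels = torch.nn.functional.one_hot(torch.tensor([1, 2]), num_classes=1000).float()
with torch.no_grad():
    D_x = net(x, sigma, class_labels=labels)['x']  # denoised estimate c_skip * x + c_out * F_x, same shape as x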
sample.py ADDED
@@ -0,0 +1,397 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) [2023] [Anima-Lab]
4
+
5
+
6
+ # This code is adapted from https://github.com/NVlabs/edm/blob/main/generate.py.
7
+ # The original code is licensed under a Creative Commons
8
+ # Attribution-NonCommercial-ShareAlike 4.0 International License, which can be found at licenses/LICENSE_EDM.txt.
9
+
10
+ import argparse
11
+ import random
12
+
13
+ import PIL.Image
14
+ import lmdb
15
+ import numpy as np
16
+
17
+ import torch
18
+ import torch.distributed as dist
19
+ from torch.multiprocessing import Process
20
+ from tqdm import tqdm
21
+
22
+ from models.maskdit import Precond_models, DiT_models
23
+ from utils import *
24
+ import autoencoder
25
+
26
+
27
+ # ----------------------------------------------------------------------------
28
+ # Proposed EDM sampler (Algorithm 2).
29
+
30
+ def edm_sampler(
31
+ net, latents, class_labels=None, cfg_scale=None, feat=None, randn_like=torch.randn_like,
32
+ num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
33
+ S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
34
+ ):
35
+ # Adjust noise levels based on what's supported by the network.
36
+ sigma_min = max(sigma_min, net.sigma_min)
37
+ sigma_max = min(sigma_max, net.sigma_max)
38
+
39
+ # Time step discretization.
40
+ step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
41
+ t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
42
+ sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
43
+ t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
44
+
45
+ # Main sampling loop.
46
+ x_next = latents.to(torch.float64) * t_steps[0]
47
+ for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
48
+ x_cur = x_next
49
+
50
+ # Increase noise temporarily.
51
+ gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
52
+ t_hat = net.round_sigma(t_cur + gamma * t_cur)
53
+ x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
54
+
55
+ # Euler step.
56
+ denoised = net(x_hat.float(), t_hat, class_labels, cfg_scale, feat=feat)['x'].to(torch.float64)
57
+ d_cur = (x_hat - denoised) / t_hat
58
+ x_next = x_hat + (t_next - t_hat) * d_cur
59
+
60
+ # Apply 2nd order correction.
61
+ if i < num_steps - 1:
62
+ denoised = net(x_next.float(), t_next, class_labels, cfg_scale, feat=feat)['x'].to(torch.float64)
63
+ d_prime = (x_next - denoised) / t_next
64
+ x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
65
+
66
+ return x_next
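The time steps above follow the EDM (Karras et al.) rho-schedule; as a standalone illustration of the discretization:

import torch
num_steps, sigma_min, sigma_max, rho = 18, 0.002, 80.0, 7
i = torch.arange(num_steps, dtype=torch.float64)
t_steps = (sigma_max ** (1 / rho) + i / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
# t_steps decreases monotonically from sigma_max (80.0) to sigma_min (0.002); the sampler appends a final t_N = 0.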
67
+
68
+
69
+ # ----------------------------------------------------------------------------
70
+ # Generalized ablation sampler, representing the superset of all sampling
71
+ # methods discussed in the paper.
72
+
73
+ def ablation_sampler(
74
+ net, latents, class_labels=None, cfg_scale=None, feat=None, randn_like=torch.randn_like,
75
+ num_steps=18, sigma_min=None, sigma_max=None, rho=7,
76
+ solver='heun', discretization='edm', schedule='linear', scaling='none',
77
+ epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
78
+ S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
79
+ ):
80
+ assert solver in ['euler', 'heun']
81
+ assert discretization in ['vp', 've', 'iddpm', 'edm']
82
+ assert schedule in ['vp', 've', 'linear']
83
+ assert scaling in ['vp', 'none']
84
+
85
+ # Helper functions for VP & VE noise level schedules.
86
+ vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
87
+ vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
88
+ vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (
89
+ sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
90
+ ve_sigma = lambda t: t.sqrt()
91
+ ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
92
+ ve_sigma_inv = lambda sigma: sigma ** 2
93
+
94
+ # Select default noise level range based on the specified time step discretization.
95
+ if sigma_min is None:
96
+ vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s)
97
+ sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
98
+ if sigma_max is None:
99
+ vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1)
100
+ sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]
101
+
102
+ # Adjust noise levels based on what's supported by the network.
103
+ sigma_min = max(sigma_min, net.sigma_min)
104
+ sigma_max = min(sigma_max, net.sigma_max)
105
+
106
+ # Compute corresponding betas for VP.
107
+ vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
108
+ vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d
109
+
110
+ # Define time steps in terms of noise level.
111
+ step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
112
+ if discretization == 'vp':
113
+ orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
114
+ sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
115
+ elif discretization == 've':
116
+ orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
117
+ sigma_steps = ve_sigma(orig_t_steps)
118
+ elif discretization == 'iddpm':
119
+ u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
120
+ alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
121
+ for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1
122
+ u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
123
+ u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
124
+ sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
125
+ else:
126
+ assert discretization == 'edm'
127
+ sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
128
+ sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
129
+
130
+ # Define noise level schedule.
131
+ if schedule == 'vp':
132
+ sigma = vp_sigma(vp_beta_d, vp_beta_min)
133
+ sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
134
+ sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
135
+ elif schedule == 've':
136
+ sigma = ve_sigma
137
+ sigma_deriv = ve_sigma_deriv
138
+ sigma_inv = ve_sigma_inv
139
+ else:
140
+ assert schedule == 'linear'
141
+ sigma = lambda t: t
142
+ sigma_deriv = lambda t: 1
143
+ sigma_inv = lambda sigma: sigma
144
+
145
+ # Define scaling schedule.
146
+ if scaling == 'vp':
147
+ s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
148
+ s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
149
+ else:
150
+ assert scaling == 'none'
151
+ s = lambda t: 1
152
+ s_deriv = lambda t: 0
153
+
154
+ # Compute final time steps based on the corresponding noise levels.
155
+ t_steps = sigma_inv(net.round_sigma(sigma_steps))
156
+ t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0
157
+
158
+ # Main sampling loop.
159
+ t_next = t_steps[0]
160
+ x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
161
+ for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
162
+ x_cur = x_next
163
+
164
+ # Increase noise temporarily.
165
+ gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
166
+ t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
167
+ x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(
168
+ t_hat) * S_noise * randn_like(x_cur)
169
+
170
+ # Euler step.
171
+ h = t_next - t_hat
172
+ denoised = net(x_hat.float() / s(t_hat), sigma(t_hat), class_labels, cfg_scale, feat=feat)['x'].to(torch.float64)
173
+ d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(
174
+ t_hat) / sigma(t_hat) * denoised
175
+ x_prime = x_hat + alpha * h * d_cur
176
+ t_prime = t_hat + alpha * h
177
+
178
+ # Apply 2nd order correction.
179
+ if solver == 'euler' or i == num_steps - 1:
180
+ x_next = x_hat + h * d_cur
181
+ else:
182
+ assert solver == 'heun'
183
+ denoised = net(x_prime.float() / s(t_prime), sigma(t_prime), class_labels, cfg_scale, feat=feat)['x'].to(torch.float64)
184
+ d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(
185
+ t_prime) * s(t_prime) / sigma(t_prime) * denoised
186
+ x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)
187
+
188
+ return x_next
189
+
190
+
191
+ # ----------------------------------------------------------------------------
192
+
193
+ def retrieve_n_features(batch_size, feat_path, feat_dim, num_classes, device, split='train', sample_mode='rand_full'):
194
+ env = lmdb.open(os.path.join(feat_path, split), readonly=True, lock=False, create=False)
195
+
196
+ # Start a new read transaction
197
+ with env.begin() as txn:
198
+ # Read all images in one single transaction, with one lock
199
+ # We could split this up into multiple transactions if needed
200
+ length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
201
+ if sample_mode == 'rand_full':
202
+ image_ids = random.sample(range(length // 2), batch_size)
203
+ image_ids_y = image_ids
204
+ elif sample_mode == 'rand_repeat':
205
+ image_ids = random.sample(range(length // 2), 1) * batch_size
206
+ image_ids_y = image_ids
207
+ elif sample_mode == 'rand_y':
208
+ image_ids = random.sample(range(length // 2), 1) * batch_size
209
+ image_ids_y = random.sample(range(length // 2), batch_size)
210
+ else:
211
+ raise NotImplementedError
212
+ features, labels = [], []
213
+ for image_id, image_id_y in zip(image_ids, image_ids_y):
214
+ feat_bytes = txn.get(f'feat-{str(image_id)}'.encode('utf-8'))
215
+ y_bytes = txn.get(f'y-{str(image_id_y)}'.encode('utf-8'))
216
+ feat = np.frombuffer(feat_bytes, dtype=np.float32).reshape([feat_dim]).copy()
217
+ y = int(y_bytes.decode('utf-8'))
218
+ features.append(feat)
219
+ labels.append(y)
220
+ features = torch.from_numpy(np.stack(features)).to(device)
221
+ labels = torch.from_numpy(np.array(labels)).to(device)
222
+ class_labels = torch.zeros([batch_size, num_classes], device=device)
223
+ if num_classes > 0:
224
+ class_labels = torch.eye(num_classes, device=device)[labels]
225
+ assert features.shape[0] == class_labels.shape[0] == batch_size
226
+ return features, class_labels
227
+
228
+
229
+
230
+ @torch.no_grad()
231
+ def generate_with_net(args, net, device, rank, size):
232
+ seeds = args.seeds
233
+ num_batches = ((len(seeds) - 1) // (args.max_batch_size * size) + 1) * size
234
+ all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
235
+ rank_batches = all_batches[rank:: size]
236
+
237
+ net.eval()
238
+
239
+ # Setup sampler
240
+ sampler_kwargs = dict(num_steps=args.num_steps, S_churn=args.S_churn,
241
+ solver=args.solver, discretization=args.discretization,
242
+ schedule=args.schedule, scaling=args.scaling)
243
+ sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None}
244
+ have_ablation_kwargs = any(x in sampler_kwargs for x in ['solver', 'discretization', 'schedule', 'scaling'])
245
+ sampler_fn = ablation_sampler if have_ablation_kwargs else edm_sampler
246
+ mprint(f"sampler_kwargs: {sampler_kwargs}, \nsampler fn: {sampler_fn.__name__}")
247
+ # Setup autoencoder
248
+ vae = autoencoder.get_model(args.pretrained_path).to(device)
249
+
250
+ # generate images
251
+ mprint(f'Generating {len(seeds)} images to "{args.outdir}"...')
252
+ for batch_seeds in tqdm(rank_batches, unit='batch', disable=(rank != 0)):
253
+ dist.barrier()
254
+ batch_size = len(batch_seeds)
255
+ if batch_size == 0:
256
+ continue
257
+
258
+ # Pick latents and labels.
259
+ rnd = StackedRandomGenerator(device, batch_seeds)
260
+ latents = rnd.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device)
261
+ class_labels = torch.zeros([batch_size, net.num_classes], device=device)
262
+ if net.num_classes:
263
+ class_labels = torch.eye(net.num_classes, device=device)[
264
+ rnd.randint(net.num_classes, size=[batch_size], device=device)]
265
+ if args.class_idx is not None:
266
+ class_labels[:, :] = 0
267
+ class_labels[:, args.class_idx] = 1
268
+
269
+ # retrieve features from training set [support random only]
270
+ feat = None
271
+
272
+ # Generate images.
273
+ def recur_decode(z):
274
+ try:
275
+ return vae.decode(z)
276
+ except: # halve the batch for the VAE decoder (two forward passes) when an occasional OOM happens
277
+ assert z.shape[2] % 2 == 0
278
+ z1, z2 = z.tensor_split(2)
279
+ return torch.cat([recur_decode(z1), recur_decode(z2)])
280
+
281
+ with torch.no_grad():
282
+ z = sampler_fn(net, latents.float(), class_labels.float(), randn_like=rnd.randn_like,
283
+ cfg_scale=args.cfg_scale, feat=feat, **sampler_kwargs).float()
284
+ images = recur_decode(z)
285
+
286
+ # Save images.
287
+ images_np = images.add_(1).mul(127.5).clamp_(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
288
+ # images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
289
+ for seed, image_np in zip(batch_seeds, images_np):
290
+ image_dir = os.path.join(args.outdir, f'{seed - seed % 1000:06d}') if args.subdirs else args.outdir
291
+ os.makedirs(image_dir, exist_ok=True)
292
+ image_path = os.path.join(image_dir, f'{seed:06d}.png')
293
+ if image_np.shape[2] == 1:
294
+ PIL.Image.fromarray(image_np[:, :, 0], 'L').save(image_path)
295
+ else:
296
+ PIL.Image.fromarray(image_np, 'RGB').save(image_path)
297
+
298
+
299
+ def generate(args):
300
+ device = torch.device("cuda")
301
+
302
+ mprint(f'cfg_scale: {args.cfg_scale}')
303
+ if args.global_rank == 0:
304
+ os.makedirs(args.outdir, exist_ok=True)
305
+ logger = Logger(file_name=f'{args.outdir}/log.txt', file_mode="a+", should_flush=True)
306
+
307
+ # Create model:
308
+ net = Precond_models[args.precond](
309
+ img_resolution=args.image_size,
310
+ img_channels=args.image_channels,
311
+ num_classes=args.num_classes,
312
+ model_type=args.model_type,
313
+ use_decoder=args.use_decoder,
314
+ mae_loss_coef=args.mae_loss_coef,
315
+ pad_cls_token=args.pad_cls_token,
316
+ ext_feature_dim=args.ext_feature_dim
317
+ ).to(device)
318
+ mprint(
319
+ f"{args.model_type} (use_decoder: {args.use_decoder}) Model Parameters: {sum(p.numel() for p in net.parameters()):,}")
320
+
321
+ # Load checkpoints
322
+ ckpt = torch.load(args.ckpt_path, map_location=device)
323
+ net.load_state_dict(ckpt['ema'])
324
+ mprint(f'Load weights from {args.ckpt_path}')
325
+
326
+ generate_with_net(args, net, device, args.global_rank, args.global_size)
327
+
328
+ # Done.
329
+ cleanup()
330
+ if args.global_rank == 0:
331
+ logger.close()
332
+
333
+
334
+ if __name__ == '__main__':
335
+ parser = argparse.ArgumentParser('sampling parameters')
336
+
337
+ # ddp
338
+ parser.add_argument('--num_proc_node', type=int, default=1, help='The number of nodes in multi node env.')
339
+ parser.add_argument('--num_process_per_node', type=int, default=1, help='number of gpus')
340
+ parser.add_argument('--node_rank', type=int, default=0, help='The index of node.')
341
+ parser.add_argument('--local_rank', type=int, default=0, help='rank of process in the node')
342
+ parser.add_argument('--master_address', type=str, default='localhost', help='address for master')
343
+
344
+ # sampling
345
+ parser.add_argument("--feat_path", type=str, default='')
346
+ parser.add_argument("--ext_feature_dim", type=int, default=0)
347
+ parser.add_argument('--ckpt_path', type=str, required=True, help='Network pickle filename')
348
+ parser.add_argument('--outdir', type=str, required=True, help='sampling results save filename')
349
+ parser.add_argument('--seeds', type=parse_int_list, default='0-63', help='Random seeds (e.g. 1,2,5-10)')
350
+ parser.add_argument('--subdirs', action='store_true', help='Create subdirectory for every 1000 seeds')
351
+ parser.add_argument('--class_idx', type=int, default=None, help='Class label [default: random]')
352
+ parser.add_argument('--max_batch_size', type=int, default=64, help='Maximum batch size per GPU')
353
+
354
+ parser.add_argument("--cfg_scale", type=parse_float_none, default=None, help='None = no guidance, by default = 4.0')
355
+
356
+ parser.add_argument('--num_steps', type=int, default=18, help='Number of sampling steps')
357
+ parser.add_argument('--S_churn', type=int, default=0, help='Stochasticity strength')
358
+ parser.add_argument('--solver', type=str, default=None, choices=['euler', 'heun'], help='Ablate ODE solver')
359
+ parser.add_argument('--discretization', type=str, default=None, choices=['vp', 've', 'iddpm', 'edm'],
360
+ help='Ablate time step discretization {t_i}')
361
+ parser.add_argument('--schedule', type=str, default=None, choices=['vp', 've', 'linear'],
362
+ help='Ablate noise schedule sigma(t)')
363
+ parser.add_argument('--scaling', type=str, default=None, choices=['vp', 'none'], help='Ablate signal scaling s(t)')
364
+ parser.add_argument('--pretrained_path', type=str, default='assets/stable_diffusion/autoencoder_kl.pth',
365
+ help='Autoencoder ckpt')
366
+
367
+ # model
368
+ parser.add_argument("--image_size", type=int, default=32)
369
+ parser.add_argument("--image_channels", type=int, default=4)
370
+ parser.add_argument("--num_classes", type=int, default=1000, help='0 means unconditional')
371
+ parser.add_argument("--model_type", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2")
372
+ parser.add_argument('--precond', type=str, choices=['vp', 've', 'edm'], default='edm', help='precond train & loss')
373
+ parser.add_argument("--use_decoder", type=str2bool, default=False)
374
+ parser.add_argument("--pad_cls_token", type=str2bool, default=False)
375
+ parser.add_argument('--mae_loss_coef', type=float, default=0, help='0 means no MAE loss')
376
+ parser.add_argument('--sample_mode', type=str, default='rand_full', help='[rand_full, rand_repeat, rand_y]')
377
+
378
+ args = parser.parse_args()
379
+ args.global_size = args.num_proc_node * args.num_process_per_node
380
+ size = args.num_process_per_node
381
+
382
+ if size > 1:
383
+ processes = []
384
+ for rank in range(size):
385
+ args.local_rank = rank
386
+ args.global_rank = rank + args.node_rank * args.num_process_per_node
387
+ p = Process(target=init_processes, args=(generate, args))
388
+ p.start()
389
+ processes.append(p)
390
+
391
+ for p in processes:
392
+ p.join()
393
+ else:
394
+ print('Single GPU run')
395
+ assert args.global_size == 1 and args.local_rank == 0
396
+ args.global_rank = 0
397
+ init_processes(generate, args)
scripts/download_assets.sh ADDED
@@ -0,0 +1,8 @@
1
+ # download pretrained VAE
2
+ python3 download_assets.py --name vae --dest assets/stable_diffusion
3
+
4
+ # download ImageNet256 training set
5
+ python3 download_assets.py --name imagenet256-latent-lmdb --dest ../data/imagenet256
6
+
7
+ # download ImageNet512 training set
8
+ python3 download_assets.py --name imagenet512-latent-wds --dest ../data/imagenet512-wds
scripts/finetune_latent512.sh ADDED
@@ -0,0 +1,14 @@
1
+ accelerate launch \
2
+ --main_process_ip $MASTER_ADDR \
3
+ --main_process_port $MASTER_PORT \
4
+ --num_machines 4 \
5
+ --machine_rank $NODE_RANK \
6
+ --num_processes 32 \
7
+ train_wds.py \
8
+ --config configs/finetune/imagenet512-latent.yaml \
9
+ --resample \
10
+ --ckpt_path checkpoints/1050000.pt \
11
+ --use_ckpt_path False --use_strict_load False \
12
+ --no_amp
13
+
14
+
scripts/prepare_latent256.sh ADDED
@@ -0,0 +1,3 @@
1
+ # Encode ImageNet 256x256 into latent space
2
+
3
+ python3 extract_latent.py --resolution 256 --ckpt assets/vae/autoencoder_kl.pth --batch_size 64 --outdir ../data/imagenet256-latent
scripts/prepare_latent512.sh ADDED
@@ -0,0 +1,6 @@
1
+ # Encode ImageNet 512x512 into latent space
2
+
3
+ python3 extract_latent.py --resolution 512 --ckpt assets/vae/autoencoder_kl.pth --batch_size 64 --outdir ../data/imagenet512-latent
4
+
5
+ # Convert lmdb to webdataset
6
+ python3 lmdb2wds.py --maxcount 10010 --datadir ../data/imagenet512-latent --outdir ../data/imagenet512-latent-wds --resolution 64 --num_channels 8
scripts/train_latent512.sh ADDED
@@ -0,0 +1,11 @@
1
+ accelerate launch \
2
+ --main_process_ip $MASTER_ADDR \
3
+ --main_process_port $MASTER_PORT \
4
+ --num_machines 4 \
5
+ --machine_rank $NODE_RANK \
6
+ --num_processes 32 \
7
+ train_wds.py \
8
+ --config configs/train/imagenet512-latent.yaml \
9
+ --resample
10
+
11
+
torch_utils/__init__.py ADDED
File without changes
torch_utils/persistence.py ADDED
@@ -0,0 +1,276 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) [2023] [Anima-Lab]
4
+
5
+ # This code is adapted from https://github.com/NVlabs/edm/blob/main/torch_utils/persistence.py.
6
+ # The original code is licensed under a Creative Commons
7
+ # Attribution-NonCommercial-ShareAlike 4.0 International License, which can be found at licenses/LICENSE_EDM.txt.
8
+
9
+
10
+ """Facilities for pickling Python code alongside other data.
11
+
12
+ The pickled code is automatically imported into a separate Python module
13
+ during unpickling. This way, any previously exported pickles will remain
14
+ usable even if the original code is no longer available, or if the current
15
+ version of the code is not consistent with what was originally pickled."""
16
+
17
+ import sys
18
+ import pickle
19
+ import io
20
+ import inspect
21
+ import copy
22
+ import uuid
23
+ import types
24
+
25
+
26
+ #----------------------------------------------------------------------------
27
+
28
+ class EasyDict(dict):
29
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
30
+
31
+ def __getattr__(self, name):
32
+ try:
33
+ return self[name]
34
+ except KeyError:
35
+ raise AttributeError(name)
36
+
37
+ def __setattr__(self, name, value):
38
+ self[name] = value
39
+
40
+ def __delattr__(self, name):
41
+ del self[name]
42
+
43
+ #----------------------------------------------------------------------------
44
+
45
+ _version = 6 # internal version number
46
+ _decorators = set() # {decorator_class, ...}
47
+ _import_hooks = [] # [hook_function, ...]
48
+ _module_to_src_dict = dict() # {module: src, ...}
49
+ _src_to_module_dict = dict() # {src: module, ...}
50
+
51
+ #----------------------------------------------------------------------------
52
+
53
+ def persistent_class(orig_class):
54
+ r"""Class decorator that extends a given class to save its source code
55
+ when pickled.
56
+
57
+ Example:
58
+
59
+ from torch_utils import persistence
60
+
61
+ @persistence.persistent_class
62
+ class MyNetwork(torch.nn.Module):
63
+ def __init__(self, num_inputs, num_outputs):
64
+ super().__init__()
65
+ self.fc = MyLayer(num_inputs, num_outputs)
66
+ ...
67
+
68
+ @persistence.persistent_class
69
+ class MyLayer(torch.nn.Module):
70
+ ...
71
+
72
+ When pickled, any instance of `MyNetwork` and `MyLayer` will save its
73
+ source code alongside other internal state (e.g., parameters, buffers,
74
+ and submodules). This way, any previously exported pickle will remain
75
+ usable even if the class definitions have been modified or are no
76
+ longer available.
77
+
78
+ The decorator saves the source code of the entire Python module
79
+ containing the decorated class. It does *not* save the source code of
80
+ any imported modules. Thus, the imported modules must be available
81
+ during unpickling, also including `torch_utils.persistence` itself.
82
+
83
+ It is ok to call functions defined in the same module from the
84
+ decorated class. However, if the decorated class depends on other
85
+ classes defined in the same module, they must be decorated as well.
86
+ This is illustrated in the above example in the case of `MyLayer`.
87
+
88
+ It is also possible to employ the decorator just-in-time before
89
+ calling the constructor. For example:
90
+
91
+ cls = MyLayer
92
+ if want_to_make_it_persistent:
93
+ cls = persistence.persistent_class(cls)
94
+ layer = cls(num_inputs, num_outputs)
95
+
96
+ As an additional feature, the decorator also keeps track of the
97
+ arguments that were used to construct each instance of the decorated
98
+ class. The arguments can be queried via `obj.init_args` and
99
+ `obj.init_kwargs`, and they are automatically pickled alongside other
100
+ object state. This feature can be disabled on a per-instance basis
101
+ by setting `self._record_init_args = False` in the constructor.
102
+
103
+ A typical use case is to first unpickle a previous instance of a
104
+ persistent class, and then upgrade it to use the latest version of
105
+ the source code:
106
+
107
+ with open('old_pickle.pkl', 'rb') as f:
108
+ old_net = pickle.load(f)
109
+ new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
110
+ misc.copy_params_and_buffers(old_net, new_net, require_all=True)
111
+ """
112
+ assert isinstance(orig_class, type)
113
+ if is_persistent(orig_class):
114
+ return orig_class
115
+
116
+ assert orig_class.__module__ in sys.modules
117
+ orig_module = sys.modules[orig_class.__module__]
118
+ orig_module_src = _module_to_src(orig_module)
119
+
120
+ class Decorator(orig_class):
121
+ _orig_module_src = orig_module_src
122
+ _orig_class_name = orig_class.__name__
123
+
124
+ def __init__(self, *args, **kwargs):
125
+ super().__init__(*args, **kwargs)
126
+ record_init_args = getattr(self, '_record_init_args', True)
127
+ self._init_args = copy.deepcopy(args) if record_init_args else None
128
+ self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
129
+ assert orig_class.__name__ in orig_module.__dict__
130
+ _check_pickleable(self.__reduce__())
131
+
132
+ @property
133
+ def init_args(self):
134
+ assert self._init_args is not None
135
+ return copy.deepcopy(self._init_args)
136
+
137
+ @property
138
+ def init_kwargs(self):
139
+ assert self._init_kwargs is not None
140
+ return EasyDict(copy.deepcopy(self._init_kwargs))
141
+
142
+ def __reduce__(self):
143
+ fields = list(super().__reduce__())
144
+ fields += [None] * max(3 - len(fields), 0)
145
+ if fields[0] is not _reconstruct_persistent_obj:
146
+ meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
147
+ fields[0] = _reconstruct_persistent_obj # reconstruct func
148
+ fields[1] = (meta,) # reconstruct args
149
+ fields[2] = None # state dict
150
+ return tuple(fields)
151
+
152
+ Decorator.__name__ = orig_class.__name__
153
+ Decorator.__module__ = orig_class.__module__
154
+ _decorators.add(Decorator)
155
+ return Decorator
156
+
157
+ #----------------------------------------------------------------------------
158
+
159
+ def is_persistent(obj):
160
+ r"""Test whether the given object or class is persistent, i.e.,
161
+ whether it will save its source code when pickled.
162
+ """
163
+ try:
164
+ if obj in _decorators:
165
+ return True
166
+ except TypeError:
167
+ pass
168
+ return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
169
+
170
+ #----------------------------------------------------------------------------
171
+
172
+ def import_hook(hook):
173
+ r"""Register an import hook that is called whenever a persistent object
174
+ is being unpickled. A typical use case is to patch the pickled source
175
+ code to avoid errors and inconsistencies when the API of some imported
176
+ module has changed.
177
+
178
+ The hook should have the following signature:
179
+
180
+ hook(meta) -> modified meta
181
+
182
+ `meta` is an instance of `EasyDict` with the following fields:
183
+
184
+ type: Type of the persistent object, e.g. `'class'`.
185
+ version: Internal version number of `torch_utils.persistence`.
186
+ module_src: Original source code of the Python module.
187
+ class_name: Class name in the original Python module.
188
+ state: Internal state of the object.
189
+
190
+ Example:
191
+
192
+ @persistence.import_hook
193
+ def wreck_my_network(meta):
194
+ if meta.class_name == 'MyNetwork':
195
+ print('MyNetwork is being imported. I will wreck it!')
196
+ meta.module_src = meta.module_src.replace("True", "False")
197
+ return meta
198
+ """
199
+ assert callable(hook)
200
+ _import_hooks.append(hook)
201
+
202
+ #----------------------------------------------------------------------------
203
+
204
+ def _reconstruct_persistent_obj(meta):
205
+ r"""Hook that is called internally by the `pickle` module to unpickle
206
+ a persistent object.
207
+ """
208
+ meta = EasyDict(meta)
209
+ meta.state = EasyDict(meta.state)
210
+ for hook in _import_hooks:
211
+ meta = hook(meta)
212
+ assert meta is not None
213
+
214
+ assert meta.version == _version
215
+ module = _src_to_module(meta.module_src)
216
+
217
+ assert meta.type == 'class'
218
+ orig_class = module.__dict__[meta.class_name]
219
+ decorator_class = persistent_class(orig_class)
220
+ obj = decorator_class.__new__(decorator_class)
221
+
222
+ setstate = getattr(obj, '__setstate__', None)
223
+ if callable(setstate):
224
+ setstate(meta.state) # pylint: disable=not-callable
225
+ else:
226
+ obj.__dict__.update(meta.state)
227
+ return obj
228
+
229
+ #----------------------------------------------------------------------------
230
+
231
+ def _module_to_src(module):
232
+ r"""Query the source code of a given Python module.
233
+ """
234
+ src = _module_to_src_dict.get(module, None)
235
+ if src is None:
236
+ src = inspect.getsource(module)
237
+ _module_to_src_dict[module] = src
238
+ _src_to_module_dict[src] = module
239
+ return src
240
+
241
+ def _src_to_module(src):
242
+ r"""Get or create a Python module for the given source code.
243
+ """
244
+ module = _src_to_module_dict.get(src, None)
245
+ if module is None:
246
+ module_name = "_imported_module_" + uuid.uuid4().hex
247
+ module = types.ModuleType(module_name)
248
+ sys.modules[module_name] = module
249
+ _module_to_src_dict[module] = src
250
+ _src_to_module_dict[src] = module
251
+ exec(src, module.__dict__) # pylint: disable=exec-used
252
+ return module
253
+
254
+ #----------------------------------------------------------------------------
255
+
256
+ def _check_pickleable(obj):
257
+ r"""Check that the given object is pickleable, raising an exception if
258
+ it is not. This function is expected to be considerably more efficient
259
+ than actually pickling the object.
260
+ """
261
+ def recurse(obj):
262
+ if isinstance(obj, (list, tuple, set)):
263
+ return [recurse(x) for x in obj]
264
+ if isinstance(obj, dict):
265
+ return [[recurse(x), recurse(y)] for x, y in obj.items()]
266
+ if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
267
+ return None # Python primitive types are pickleable.
268
+ if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
269
+ return None # NumPy arrays and PyTorch tensors are pickleable.
270
+ if is_persistent(obj):
271
+ return None # Persistent objects are pickleable, by virtue of the constructor check.
272
+ return obj
273
+ with io.BytesIO() as f:
274
+ pickle.dump(recurse(obj), f)
275
+
276
+ #----------------------------------------------------------------------------
train.py ADDED
@@ -0,0 +1,336 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) [2023] [Anima-Lab]
4
+ '''
5
+ Training MaskDiT on a latent dataset in LMDB format. Used for the experiments on ImageNet 256x256.
6
+ '''
7
+
8
+ import argparse
9
+ import os.path
10
+ from copy import deepcopy
11
+ from time import time
12
+ from omegaconf import OmegaConf
13
+
14
+ import apex
15
+ import torch
16
+ import accelerate
17
+
18
+ from torch.utils.data import DataLoader
19
+
20
+ from fid import calc
21
+ from models.maskdit import Precond_models
22
+ from train_utils.loss import Losses
23
+ from train_utils.datasets import ImageNetLatentDataset
24
+
25
+ from train_utils.helper import get_mask_ratio_fn, requires_grad, update_ema, unwrap_model
26
+
27
+ from sample import generate_with_net
28
+ from utils import dist, mprint, get_latest_ckpt, Logger, sample, \
29
+ str2bool, parse_str_none, parse_int_list, parse_float_none
30
+
31
+
32
+ # ------------------------------------------------------------
33
+
34
+
35
+ def train_loop(args):
36
+ # load configuration
37
+ config = OmegaConf.load(args.config)
38
+
39
+ if not args.no_amp:
40
+ config.train.amp = 'fp16'
41
+ else:
42
+ config.train.amp = 'no'
43
+
44
+ if config.train.tf32:
45
+ torch.backends.cudnn.allow_tf32 = True
46
+ torch.set_float32_matmul_precision('high')
47
+
48
+ accelerator = accelerate.Accelerator(mixed_precision=config.train.amp,
49
+ gradient_accumulation_steps=config.train.grad_accum,
50
+ log_with='wandb')
51
+ # setup wandb
52
+ if args.use_wandb:
53
+ wandb_init_kwargs = {
54
+ 'entity': config.wandb.entity,
55
+ 'project': config.wandb.project,
56
+ 'group': config.wandb.group,
57
+ }
58
+ accelerator.init_trackers(config.wandb.project, config=OmegaConf.to_container(config), init_kwargs=wandb_init_kwargs)
59
+
60
+ mprint('start training...')
61
+ size = accelerator.num_processes
62
+ rank = accelerator.process_index
63
+
64
+ print(f'global_rank: {rank}, global_size: {size}')
65
+ device = accelerator.device
66
+
67
+ seed = args.global_seed
68
+ torch.manual_seed(seed)
69
+
70
+ mprint(f"enable_amp: {not args.no_amp}, TF32: {config.train.tf32}")
71
+ # Select batch size per GPU
72
+ num_accumulation_rounds = config.train.grad_accum
73
+ micro_batch = config.train.batchsize
74
+ batch_gpu_total = micro_batch * num_accumulation_rounds
75
+ global_batch_size = batch_gpu_total * size
76
+ mprint(f"Global batchsize: {global_batch_size}, batchsize per GPU: {batch_gpu_total}, micro_batch: {micro_batch}.")
77
+
78
+ class_dropout_prob = config.model.class_dropout_prob
79
+ log_every = config.log.log_every
80
+ ckpt_every = config.log.ckpt_every
81
+
82
+ mask_ratio_fn = get_mask_ratio_fn(config.model.mask_ratio_fn, config.model.mask_ratio, config.model.mask_ratio_min)
83
+
84
+ # Setup an experiment folder
85
+ model_name = config.model.model_type.replace("/", "-") # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders)
86
+ data_name = config.data.dataset
87
+ if args.ckpt_path is not None and args.use_ckpt_path: # use the existing exp path (mainly used for fine-tuning)
88
+ checkpoint_dir = os.path.dirname(args.ckpt_path)
89
+ experiment_dir = os.path.dirname(checkpoint_dir)
90
+ exp_name = os.path.basename(experiment_dir)
91
+ else: # start a new exp path (and resume from the latest checkpoint if possible)
92
+ cond_gen = 'cond' if config.model.num_classes else 'uncond'
93
+ exp_name = f'{model_name}-{config.model.precond}-{data_name}-{cond_gen}-m{config.model.mask_ratio}-de{int(config.model.use_decoder)}' \
94
+ f'-mae{config.model.mae_loss_coef}-bs-{global_batch_size}-lr{config.train.lr}{config.log.tag}'
95
+ experiment_dir = f"{args.results_dir}/{exp_name}"
96
+ checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
97
+ os.makedirs(checkpoint_dir, exist_ok=True)
98
+ if args.ckpt_path is None:
99
+ args.ckpt_path = get_latest_ckpt(checkpoint_dir) # Resumes from the latest checkpoint if it exists
100
+ mprint(f"Experiment directory created at {experiment_dir}")
101
+
102
+ if accelerator.is_main_process:
103
+ logger = Logger(file_name=f'{experiment_dir}/log.txt', file_mode="a+", should_flush=True)
104
+
105
+ mprint(f"Experiment directory created at {experiment_dir}")
106
+ # Setup dataset
107
+ dataset = ImageNetLatentDataset(
108
+ config.data.root, resolution=config.data.resolution,
109
+ num_channels=config.data.num_channels, xflip=config.train.xflip,
110
+ feat_path=config.data.feat_path, feat_dim=config.model.ext_feature_dim)
111
+
112
+ loader = DataLoader(
113
+ dataset, batch_size=batch_gpu_total, shuffle=False,
114
+ num_workers=args.num_workers,
115
+ pin_memory=True, persistent_workers=True,
116
+ drop_last=True
117
+ )
118
+ mprint(f"Dataset contains {len(dataset):,} images ({config.data.root})")
119
+
120
+ steps_per_epoch = len(dataset) // global_batch_size
121
+ mprint(f"{steps_per_epoch} steps per epoch")
122
+
123
+ model = Precond_models[config.model.precond](
124
+ img_resolution=config.model.in_size,
125
+ img_channels=config.model.in_channels,
126
+ num_classes=config.model.num_classes,
127
+ model_type=config.model.model_type,
128
+ use_decoder=config.model.use_decoder,
129
+ mae_loss_coef=config.model.mae_loss_coef,
130
+ pad_cls_token=config.model.pad_cls_token
131
+ ).to(device)
132
+
133
+ # Note that parameter initialization is done within the model constructor
134
+ ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
135
+ requires_grad(ema, False)
136
+
137
+ mprint(f"{config.model.model_type} ((use_decoder: {config.model.use_decoder})) Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
138
+ mprint(f'extras: {model.model.extras}, cls_token: {model.model.cls_token}')
139
+
140
+ # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
141
+ optimizer = apex.optimizers.FusedAdam(model.parameters(), lr=config.train.lr, adam_w_mode=True, weight_decay=0)
142
+
143
+ # Load checkpoints
144
+ train_steps_start = 0
145
+ epoch_start = 0
146
+
147
+ if args.ckpt_path is not None:
148
+ ckpt = torch.load(args.ckpt_path, map_location=device)
149
+ model.load_state_dict(ckpt['model'], strict=args.use_strict_load)
150
+ ema.load_state_dict(ckpt['ema'], strict=args.use_strict_load)
151
+ mprint(f'Load weights from {args.ckpt_path}')
152
+ if args.use_strict_load:
153
+ optimizer.load_state_dict(ckpt['opt'])
154
+ for state in optimizer.state.values():
155
+ for k, v in state.items():
156
+ if isinstance(v, torch.Tensor):
157
+ state[k] = v.cuda()
158
+ mprint(f'Load optimizer state..')
159
+ train_steps_start = int(os.path.basename(args.ckpt_path).split('.pt')[0])
160
+ epoch_start = train_steps_start // steps_per_epoch
161
+ mprint(f"train_steps_start: {train_steps_start}")
162
+ del ckpt # conserve memory
163
+
164
+ # FID evaluation for the loaded weights
165
+ if args.enable_eval:
166
+ start_time = time()
167
+ args.outdir = os.path.join(experiment_dir, 'fid', f'edm-steps{args.num_steps}-ckpt{train_steps_start}_cfg{args.cfg_scale}')
168
+ os.makedirs(args.outdir, exist_ok=True)
169
+ generate_with_net(args, ema, device, rank, size)
170
+ dist.barrier()
171
+ fid = calc(args.outdir, config.eval.ref_path, args.num_expected, args.global_seed, args.fid_batch_size)
172
+ mprint(f"time for fid calc: {time() - start_time}")
173
+ if args.use_wandb:
174
+ accelerator.log({f'eval/fid': fid}, step=train_steps_start)
175
+ mprint(f'guidance: {args.cfg_scale} FID: {fid}')
176
+ dist.barrier()
177
+
178
+ model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
179
+ model = torch.compile(model)
180
+
181
+ # Setup loss
182
+ loss_fn = Losses[config.model.precond]()
183
+
184
+ # Prepare models for training:
185
+ if args.ckpt_path is None:
186
+ assert train_steps_start == 0
187
+ raw_model = unwrap_model(model)
188
+ update_ema(ema, raw_model, decay=0) # Ensure EMA is initialized with synced weights
189
+ model.train() # important! This enables embedding dropout for classifier-free guidance
190
+ ema.eval() # EMA model should always be in eval mode
191
+
192
+ # Variables for monitoring/logging purposes:
193
+ train_steps = train_steps_start
194
+ log_steps = 0
195
+ running_loss = 0
196
+ start_time = time()
197
+ mprint(f"Training for {config.train.epochs} epochs...")
198
+ for epoch in range(epoch_start, config.train.epochs):
199
+ mprint(f"Beginning epoch {epoch}...")
200
+ for x, cond in loader:
201
+ x = x.to(device)
202
+ y = cond.to(device)
203
+ x = sample(x)
204
+ # Accumulate gradients.
205
+ loss_batch = 0
206
+ model.zero_grad(set_to_none=True)
207
+ curr_mask_ratio = mask_ratio_fn((train_steps - train_steps_start) / config.train.max_num_steps)
208
+ if class_dropout_prob > 0:
209
+ y = y * (torch.rand([y.shape[0], 1], device=device) >= class_dropout_prob)
210
+
211
+ for round_idx in range(num_accumulation_rounds):
212
+ x_ = x[round_idx * micro_batch: (round_idx + 1) * micro_batch]
213
+ y_ = y[round_idx * micro_batch: (round_idx + 1) * micro_batch]
214
+
215
+ with accelerator.accumulate(model):
216
+ loss = loss_fn(net=model, images=x_, labels=y_,
217
+ mask_ratio=curr_mask_ratio,
218
+ mae_loss_coef=config.model.mae_loss_coef)
219
+ loss_mean = loss.mean()
220
+ accelerator.backward(loss_mean)
221
+
222
+ # Update weights with lr warmup.
223
+ lr_cur = config.train.lr * min(train_steps * global_batch_size / max(config.train.lr_rampup_kimg * 1000, 1e-8), 1)
224
+ for g in optimizer.param_groups:
225
+ g['lr'] = lr_cur
226
+ optimizer.step()
227
+ loss_batch += loss_mean.item()
228
+
229
+ raw_model = unwrap_model(model)
230
+ update_ema(ema, model.module)
231
+
232
+ # Log loss values:
233
+ running_loss += loss_batch
234
+ log_steps += 1
235
+ train_steps += 1
236
+ if train_steps > (train_steps_start + config.train.max_num_steps):
237
+ break
238
+ if train_steps % log_every == 0:
239
+ # Measure training speed:
240
+ torch.cuda.synchronize()
241
+ end_time = time()
242
+ steps_per_sec = log_steps / (end_time - start_time)
243
+ # Reduce loss history over all processes:
244
+ avg_loss = torch.tensor(running_loss / log_steps, device=device)
245
+ dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
246
+ avg_loss = avg_loss.item() / size
247
+ mprint(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
248
+ mprint(f'Peak GPU memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GB')
249
+ mprint(f'Reserved GPU memory: {torch.cuda.memory_reserved() / 1024 ** 3:.2f} GB')
250
+
251
+ if args.use_wandb:
252
+ accelerator.log({f'train/loss': avg_loss, 'train/lr': lr_cur}, step=train_steps)
253
+ # Reset monitoring variables:
254
+ running_loss = 0
255
+ log_steps = 0
256
+ start_time = time()
257
+
258
+ # Save checkpoint:
259
+ if train_steps % ckpt_every == 0 and train_steps > train_steps_start:
260
+ if rank == 0:
261
+ checkpoint = {
262
+ "model": raw_model.state_dict(),
263
+ "ema": ema.state_dict(),
264
+ "opt": optimizer.state_dict(),
265
+ "args": args
266
+ }
267
+ checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
268
+ torch.save(checkpoint, checkpoint_path)
269
+ mprint(f"Saved checkpoint to {checkpoint_path}")
270
+ del checkpoint # conserve memory
271
+ dist.barrier()
272
+
273
+ # FID evaluation during training
274
+ if args.enable_eval:
275
+ start_time = time()
276
+ args.outdir = os.path.join(experiment_dir, 'fid', f'edm-steps{args.num_steps}-ckpt{train_steps}_cfg{args.cfg_scale}')
277
+ os.makedirs(args.outdir, exist_ok=True)
278
+ generate_with_net(args, ema, device, rank, size)
279
+
280
+ dist.barrier()
281
+ fid = calc(args.outdir, args.ref_path, args.num_expected, args.global_seed, args.fid_batch_size)
282
+ mprint(f"time for fid calc: {time() - start_time}, fid: {fid}")
283
+ if args.use_wandb:
284
+ accelerator.log({f'eval/fid': fid}, step=train_steps)
285
+ mprint(f'Guidance: {args.cfg_scale}, FID: {fid}')
286
+ dist.barrier()
287
+ start_time = time()
288
+
289
+ if accelerator.is_main_process:
290
+ logger.close()
291
+ accelerator.end_training()
292
+
293
+
294
+ if __name__ == '__main__':
295
+ parser = argparse.ArgumentParser('training parameters')
296
+ # basic config
297
+ parser.add_argument('--config', type=str, required=True, help='path to config file')
298
+
299
+ # training
300
+ parser.add_argument("--results_dir", type=str, default="results")
301
+ parser.add_argument("--ckpt_path", type=parse_str_none, default=None)
302
+
303
+ parser.add_argument("--global_seed", type=int, default=0)
304
+ parser.add_argument("--num_workers", type=int, default=4)
305
+ parser.add_argument('--no_amp', action='store_true', help="Disable automatic mixed precision.")
306
+
307
+ parser.add_argument("--use_wandb", action='store_true', help='enable wandb logging')
308
+ parser.add_argument("--use_ckpt_path", type=str2bool, default=True)
309
+ parser.add_argument("--use_strict_load", type=str2bool, default=True)
310
+ parser.add_argument("--tag", type=str, default='')
311
+
312
+ # sampling
313
+ parser.add_argument('--enable_eval', action='store_true', help='enable fid calc during training')
314
+ parser.add_argument('--seeds', type=parse_int_list, default='0-49999', help='Random seeds (e.g. 1,2,5-10)')
315
+ parser.add_argument('--subdirs', action='store_true', help='Create subdirectory for every 1000 seeds')
316
+ parser.add_argument('--class_idx', type=int, default=None, help='Class label [default: random]')
317
+ parser.add_argument('--max_batch_size', type=int, default=50, help='Maximum batch size per GPU during sampling, must be a factor of 50k if torch.compile is used')
318
+
319
+ parser.add_argument("--cfg_scale", type=parse_float_none, default=None, help='None = no guidance, by default = 4.0')
320
+
321
+ parser.add_argument('--num_steps', type=int, default=40, help='Number of sampling steps')
322
+ parser.add_argument('--S_churn', type=int, default=0, help='Stochasticity strength')
323
+ parser.add_argument('--solver', type=str, default=None, choices=['euler', 'heun'], help='Ablate ODE solver')
324
+ parser.add_argument('--discretization', type=str, default=None, choices=['vp', 've', 'iddpm', 'edm'], help='Ablate ODE solver')
325
+ parser.add_argument('--schedule', type=str, default=None, choices=['vp', 've', 'linear'], help='Ablate noise schedule sigma(t)')
326
+ parser.add_argument('--scaling', type=str, default=None, choices=['vp', 'none'], help='Ablate signal scaling s(t)')
327
+ parser.add_argument('--pretrained_path', type=str, default='assets/stable_diffusion/autoencoder_kl.pth', help='Autoencoder ckpt')
328
+
329
+ parser.add_argument('--ref_path', type=str, default='assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz', help='Dataset reference statistics')
330
+ parser.add_argument('--num_expected', type=int, default=50000, help='Number of images to use')
331
+ parser.add_argument('--fid_batch_size', type=int, default=64, help='Maximum batch size per GPU')
332
+
333
+ args = parser.parse_args()
334
+
335
+ torch.backends.cudnn.benchmark = True
336
+ train_loop(args)
train_utils/datasets.py ADDED
@@ -0,0 +1,412 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) [2023] [Anima-Lab]
4
+
5
+
6
+ import io
7
+ import os
8
+ import json
9
+ import zipfile
10
+
11
+ import lmdb
12
+ import numpy as np
13
+ from PIL import Image
14
+ import torch
15
+ from torchvision.datasets import ImageFolder, VisionDataset
16
+
17
+
18
+
19
+ def center_crop_arr(pil_image, image_size):
20
+ """
21
+ Center cropping implementation from ADM.
22
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
23
+ """
24
+ while min(*pil_image.size) >= 2 * image_size:
25
+ pil_image = pil_image.resize(
26
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
27
+ )
28
+
29
+ scale = image_size / min(*pil_image.size)
30
+ pil_image = pil_image.resize(
31
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
32
+ )
33
+
34
+ arr = np.array(pil_image)
35
+ crop_y = (arr.shape[0] - image_size) // 2
36
+ crop_x = (arr.shape[1] - image_size) // 2
37
+ return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
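+ # Illustrative walk-through of the cropping above (numbers are made up, not taken
+ # from the dataset): a 1024x768 input is halved with BOX resampling to 512x384
+ # (its short side is now < 2*image_size for image_size=256), BICUBIC-rescaled by
+ # 256/384 to roughly 341x256, and finally center-cropped to 256x256.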
38
+
39
+
40
+ ################################################################################
41
+ # ImageNet - LMDB
42
+ ###############################################################################
43
+
44
+ def lmdb_loader(path, lmdb_data, resolution):
45
+ # In-memory binary streams
46
+ with lmdb_data.begin(write=False, buffers=True) as txn:
47
+ bytedata = txn.get(path.encode('ascii'))
48
+ img = Image.open(io.BytesIO(bytedata)).convert('RGB')
49
+ arr = center_crop_arr(img, resolution)
50
+ # arr = arr.astype(np.float32) / 127.5 - 1
51
+ # arr = np.transpose(arr, [2, 0, 1]) # CHW
52
+ return arr
53
+
54
+
55
+ def imagenet_lmdb_dataset(
56
+ root,
57
+ transform=None, target_transform=None,
58
+ resolution=256):
59
+ """
60
+ You can create this dataloader using:
61
+ train_data = imagenet_lmdb_dataset(traindir, transform=train_transform)
62
+ valid_data = imagenet_lmdb_dataset(validdir, transform=val_transform)
63
+ """
64
+
65
+ if root.endswith('/'):
66
+ root = root[:-1]
67
+ pt_path = os.path.join(
68
+ root + '_faster_imagefolder.lmdb.pt')
69
+ lmdb_path = os.path.join(
70
+ root + '_faster_imagefolder.lmdb')
71
+ if os.path.isfile(pt_path) and os.path.isdir(lmdb_path):
72
+ print('Loading pt {} and lmdb {}'.format(pt_path, lmdb_path))
73
+ data_set = torch.load(pt_path)
74
+ else:
75
+ data_set = ImageFolder(
76
+ root, None, None, None)
77
+ torch.save(data_set, pt_path, pickle_protocol=4)
78
+ print('Saving pt to {}'.format(pt_path))
79
+ print('Building lmdb to {}'.format(lmdb_path))
80
+ env = lmdb.open(lmdb_path, map_size=1e12)
81
+ with env.begin(write=True) as txn:
82
+ for path, class_index in data_set.imgs:
83
+ with open(path, 'rb') as f:
84
+ data = f.read()
85
+ txn.put(path.encode('ascii'), data)
86
+
87
+ lmdb_dataset = ImageLMDB(lmdb_path, transform, target_transform, resolution, data_set.imgs, data_set.class_to_idx, data_set.classes)
88
+ return lmdb_dataset
89
+
90
+
91
+ ################################################################################
92
+ # ImageNet Dataset class- LMDB
93
+ ###############################################################################
94
+
95
+ class ImageLMDB(VisionDataset):
96
+ """
97
+ A data loader for ImageNet LMDB dataset, which is faster than the original ImageFolder.
98
+ """
99
+ def __init__(self, root, transform=None, target_transform=None,
100
+ resolution=256, samples=None, class_to_idx=None, classes=None):
101
+ super().__init__(root, transform=transform,
102
+ target_transform=target_transform)
103
+ self.root = root
104
+ self.resolution = resolution
105
+ self.samples = samples
106
+ self.class_to_idx = class_to_idx
107
+ self.classes = classes
108
+
109
+ def __getitem__(self, index: int):
110
+ path, target = self.samples[index]
111
+
112
+ # load image from path
113
+ if not hasattr(self, 'txn'):
114
+ self.open_db()
115
+ bytedata = self.txn.get(path.encode('ascii'))
116
+ img = Image.open(io.BytesIO(bytedata)).convert('RGB')
117
+ arr = center_crop_arr(img, self.resolution)
118
+ if self.transform is not None:
119
+ arr = self.transform(arr)
120
+ if self.target_transform is not None:
121
+ target = self.target_transform(target)
122
+ return arr, target
123
+
124
+ def __len__(self) -> int:
125
+ return len(self.samples)
126
+
127
+ def open_db(self):
128
+ self.env = lmdb.open(self.root, readonly=True, max_readers=256, lock=False, readahead=False, meminit=False)
129
+ self.txn = self.env.begin(write=False, buffers=True)
130
+
131
+
132
+
133
+ ################################################################################
134
+ # ImageNet - LMDB - latent space
135
+ ###############################################################################
136
+
137
+
138
+
139
+ # ----------------------------------------------------------------------------
140
+ # Abstract base class for datasets.
141
+
142
+ class Dataset(torch.utils.data.Dataset):
143
+ def __init__(self,
144
+ name, # Name of the dataset.
145
+ raw_shape, # Shape of the raw image data (NCHW).
146
+ max_size=None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
147
+ label_dim=1000, # Ensure specific number of classes
148
+ xflip=False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
149
+ random_seed=0, # Random seed to use when applying max_size.
150
+ ):
151
+ self._name = name
152
+ self._raw_shape = list(raw_shape)
153
+ self._label_dim = label_dim
154
+ self._label_shape = None
155
+
156
+ # Apply max_size.
157
+ self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
158
+ if (max_size is not None) and (self._raw_idx.size > max_size):
159
+ np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
160
+ self._raw_idx = np.sort(self._raw_idx[:max_size])
161
+
162
+ # Apply xflip. (Assume the dataset already contains the same number of xflipped samples)
163
+ if xflip:
164
+ self._raw_idx = np.concatenate([self._raw_idx, self._raw_idx + self._raw_shape[0]])
165
+
166
+ def close(self): # to be overridden by subclass
167
+ pass
168
+
169
+ def _load_raw_data(self, raw_idx): # to be overridden by subclass
170
+ raise NotImplementedError
171
+
172
+ def __getstate__(self):
173
+ return dict(self.__dict__, _raw_labels=None)
174
+
175
+ def __del__(self):
176
+ try:
177
+ self.close()
178
+ except:
179
+ pass
180
+
181
+ def __len__(self):
182
+ return self._raw_idx.size
183
+
184
+ def __getitem__(self, idx):
185
+ raw_idx = self._raw_idx[idx]
186
+ image, cond = self._load_raw_data(raw_idx)
187
+ assert isinstance(image, np.ndarray)
188
+ if isinstance(cond, list): # [label, feature]
189
+ cond[0] = self._get_onehot(cond[0])
190
+ else: # label
191
+ cond = self._get_onehot(cond)
192
+ return image.copy(), cond
193
+
194
+ def _get_onehot(self, label):
195
+ if isinstance(label, int) or label.dtype == np.int64:
196
+ onehot = np.zeros(self.label_shape, dtype=np.float32)
197
+ onehot[label] = 1
198
+ label = onehot
199
+ assert isinstance(label, np.ndarray)
200
+ return label.copy()
201
+
202
+ @property
203
+ def name(self):
204
+ return self._name
205
+
206
+ @property
207
+ def image_shape(self):
208
+ return list(self._raw_shape[1:])
209
+
210
+ @property
211
+ def num_channels(self):
212
+ assert len(self.image_shape) == 3 # CHW
213
+ return self.image_shape[0]
214
+
215
+ @property
216
+ def resolution(self):
217
+ assert len(self.image_shape) == 3 # CHW
218
+ assert self.image_shape[1] == self.image_shape[2]
219
+ return self.image_shape[1]
220
+
221
+ @property
222
+ def label_shape(self):
223
+ if self._label_shape is None:
224
+ self._label_shape = [self._label_dim]
225
+ return list(self._label_shape)
226
+
227
+ @property
228
+ def label_dim(self):
229
+ assert len(self.label_shape) == 1
230
+ return self.label_shape[0]
231
+
232
+ @property
233
+ def has_labels(self):
234
+ return True
235
+
236
+
237
+ # ----------------------------------------------------------------------------
238
+ # Dataset subclass that loads latent images recursively from the specified lmdb file.
239
+
240
+ class ImageNetLatentDataset(Dataset):
241
+ def __init__(self,
242
+ path, # Path to directory or zip.
243
+ resolution=32, # Ensure specific resolution, default 32.
244
+ num_channels=4, # Ensure specific number of channels, default 4.
245
+ split='train', # train or val split
246
+ feat_path=None, # Path to features lmdb file (only works when feat_cond=True)
247
+ feat_dim=0, # feature dim
248
+ **super_kwargs, # Additional arguments for the Dataset base class.
249
+ ):
250
+ self._path = os.path.join(path, split)
251
+ self.feat_dim = feat_dim
252
+ if not hasattr(self, 'txn'):
253
+ self.open_lmdb()
254
+ self.feat_txn = None
255
+ if feat_path is not None and os.path.isdir(feat_path):
256
+ assert self.feat_dim > 0
257
+ self._feat_path = os.path.join(feat_path, split)
258
+ self.open_feat_lmdb()
259
+
260
+ length = int(self.txn.get('length'.encode('utf-8')).decode('utf-8'))
261
+ name = os.path.basename(path)
262
+ raw_shape = [length, num_channels, resolution, resolution] # 1281167 x 4 x 32 x 32
263
+ if raw_shape[2] != resolution or raw_shape[3] != resolution:
264
+ raise IOError('Image files do not match the specified resolution')
265
+
266
+ super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
267
+
268
+ def open_lmdb(self):
269
+ self.env = lmdb.open(self._path, readonly=True, lock=False, create=False)
270
+ self.txn = self.env.begin(write=False)
271
+
272
+ def open_feat_lmdb(self):
273
+ self.feat_env = lmdb.open(self._feat_path, readonly=True, lock=False, create=False)
274
+ self.feat_txn = self.feat_env.begin(write=False)
275
+
276
+ def _load_raw_data(self, idx):
277
+ if not hasattr(self, 'txn'):
278
+ self.open_lmdb()
279
+
280
+ z_bytes = self.txn.get(f'z-{str(idx)}'.encode('utf-8'))
281
+ y_bytes = self.txn.get(f'y-{str(idx)}'.encode('utf-8'))
282
+ z = np.frombuffer(z_bytes, dtype=np.float32).reshape([-1, self.resolution, self.resolution]).copy()
283
+ y = int(y_bytes.decode('utf-8'))
284
+
285
+ cond = y
286
+ if self.feat_txn is not None:
287
+ feat_bytes = self.feat_txn.get(f'feat-{str(idx)}'.encode('utf-8'))
288
+ feat_y_bytes = self.feat_txn.get(f'y-{str(idx)}'.encode('utf-8'))
289
+ feat = np.frombuffer(feat_bytes, dtype=np.float32).reshape([self.feat_dim]).copy()
290
+ feat_y = int(feat_y_bytes.decode('utf-8'))
291
+ assert y == feat_y, 'Ordering mismatch between txn and feat_txn!'
292
+ cond = [y, feat]
293
+
294
+ return z, cond
295
+
296
+ def close(self):
297
+ try:
298
+ if self.env is not None:
299
+ self.env.close()
300
+ if self.feat_env is not None:
301
+ self.feat_env.close()
302
+ finally:
303
+ self.env = None
304
+ self.feat_env = None
305
+
306
+
307
+ # ----------------------------------------------------------------------------
308
+ # Dataset subclass that loads images recursively from the specified directory or zip file.
309
+
310
+ class ImageFolderDataset(Dataset):
311
+ def __init__(self,
312
+ path, # Path to directory or zip.
313
+ resolution=None, # Ensure specific resolution, None = highest available.
314
+ use_labels=False, # Enable conditioning labels? False = label dimension is zero.
315
+ **super_kwargs, # Additional arguments for the Dataset base class.
316
+ ):
317
+ self._path = path
318
+ self._zipfile = None
319
+ self._raw_labels = None
320
+ self._use_labels = use_labels
321
+
322
+ if os.path.isdir(self._path):
323
+ self._type = 'dir'
324
+ self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in
325
+ os.walk(self._path) for fname in files}
326
+ elif self._file_ext(self._path) == '.zip':
327
+ self._type = 'zip'
328
+ self._all_fnames = set(self._get_zipfile().namelist())
329
+ else:
330
+ raise IOError('Path must point to a directory or zip')
331
+
332
+ Image.init()
333
+ self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in Image.EXTENSION)
334
+ if len(self._image_fnames) == 0:
335
+ raise IOError('No image files found in the specified path')
336
+
337
+ name = os.path.splitext(os.path.basename(self._path))[0]
338
+ raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
339
+ if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
340
+ raise IOError('Image files do not match the specified resolution')
341
+ super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
342
+
343
+ @staticmethod
344
+ def _file_ext(fname):
345
+ return os.path.splitext(fname)[1].lower()
346
+
347
+ def _get_zipfile(self):
348
+ assert self._type == 'zip'
349
+ if self._zipfile is None:
350
+ self._zipfile = zipfile.ZipFile(self._path)
351
+ return self._zipfile
352
+
353
+ def _open_file(self, fname):
354
+ if self._type == 'dir':
355
+ return open(os.path.join(self._path, fname), 'rb')
356
+ if self._type == 'zip':
357
+ return self._get_zipfile().open(fname, 'r')
358
+ return None
359
+
360
+ def close(self):
361
+ try:
362
+ if self._zipfile is not None:
363
+ self._zipfile.close()
364
+ finally:
365
+ self._zipfile = None
366
+
367
+ def __getstate__(self):
368
+ return dict(super().__getstate__(), _zipfile=None)
369
+
370
+ def _load_raw_data(self, raw_idx):
371
+ image = self._load_raw_image(raw_idx)
372
+ assert image.dtype == np.uint8
373
+ label = self._get_raw_labels()[raw_idx]
374
+ return image, label
375
+
376
+ def _load_raw_image(self, raw_idx):
377
+ fname = self._image_fnames[raw_idx]
378
+ with self._open_file(fname) as f:
379
+ image = np.array(Image.open(f))
380
+ if image.ndim == 2:
381
+ image = image[:, :, np.newaxis] # HW => HWC
382
+ image = image.transpose(2, 0, 1) # HWC => CHW
383
+ return image
384
+
385
+ def _get_raw_labels(self):
386
+ if self._raw_labels is None:
387
+ self._raw_labels = self._load_raw_labels() if self._use_labels else None
388
+ if self._raw_labels is None:
389
+ self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
390
+ assert isinstance(self._raw_labels, np.ndarray)
391
+ assert self._raw_labels.shape[0] == self._raw_shape[0]
392
+ assert self._raw_labels.dtype in [np.float32, np.int64]
393
+ if self._raw_labels.dtype == np.int64:
394
+ assert self._raw_labels.ndim == 1
395
+ assert np.all(self._raw_labels >= 0)
396
+ return self._raw_labels
397
+
398
+ def _load_raw_labels(self):
399
+ fname = 'dataset.json'
400
+ if fname not in self._all_fnames:
401
+ return None
402
+ with self._open_file(fname) as f:
403
+ labels = json.load(f)['labels']
404
+ if labels is None:
405
+ return None
406
+ labels = dict(labels)
407
+ labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
408
+ labels = np.array(labels)
409
+ labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
410
+ return labels
411
+
412
+ # ----------------------------------------------------------------------------
train_utils/helper.py ADDED
@@ -0,0 +1,69 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) [2023] [Anima-Lab]
4
+ from collections import OrderedDict
5
+ import torch
6
+ import numpy as np
7
+
8
+
9
+ def get_mask_ratio_fn(name='constant', ratio_scale=0.5, ratio_min=0.0):
10
+ if name == 'cosine2':
11
+ return lambda x: (ratio_scale - ratio_min) * np.cos(np.pi * x / 2) ** 2 + ratio_min
12
+ elif name == 'cosine3':
13
+ return lambda x: (ratio_scale - ratio_min) * np.cos(np.pi * x / 2) ** 3 + ratio_min
14
+ elif name == 'cosine4':
15
+ return lambda x: (ratio_scale - ratio_min) * np.cos(np.pi * x / 2) ** 4 + ratio_min
16
+ elif name == 'cosine5':
17
+ return lambda x: (ratio_scale - ratio_min) * np.cos(np.pi * x / 2) ** 5 + ratio_min
18
+ elif name == 'cosine6':
19
+ return lambda x: (ratio_scale - ratio_min) * np.cos(np.pi * x / 2) ** 6 + ratio_min
20
+ elif name == 'exp':
21
+ return lambda x: (ratio_scale - ratio_min) * np.exp(-x * 7) + ratio_min
22
+ elif name == 'linear':
23
+ return lambda x: (ratio_scale - ratio_min) * x + ratio_min
24
+ elif name == 'constant':
25
+ return lambda x: ratio_scale
26
+ else:
27
+ raise ValueError('Unknown mask ratio function: {}'.format(name))
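+ # Rough sketch of how the schedules above behave, where x is the training
+ # progress in [0, 1] (see the call site in train.py):
+ #   fn = get_mask_ratio_fn('cosine2', ratio_scale=0.5, ratio_min=0.1)
+ #   fn(0.0) -> 0.5   # starts at ratio_scale
+ #   fn(1.0) -> 0.1   # decays to ratio_min
+ #   get_mask_ratio_fn('constant', 0.5)(x) -> 0.5 for every x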
28
+
29
+
30
+ def get_one_hot(labels, num_classes=1000):
31
+ one_hot = torch.zeros(labels.shape[0], num_classes, device=labels.device)
32
+ one_hot.scatter_(1, labels.view(-1, 1), 1)
33
+ return one_hot
34
+
35
+
36
+ def requires_grad(model, flag=True):
37
+ """
38
+ Set requires_grad flag for all parameters in a model.
39
+ """
40
+ for p in model.parameters():
41
+ p.requires_grad = flag
42
+
43
+
44
+ # ------------------------------------------------------------
45
+ # Training Helper Function
46
+
47
+ @torch.no_grad()
48
+ def update_ema(ema_model, model, decay=0.9999):
49
+ """
50
+ Step the EMA model towards the current model.
51
+ """
52
+ ema_params = OrderedDict(ema_model.named_parameters())
53
+ model_params = OrderedDict(model.named_parameters())
54
+
55
+ for name, param in model_params.items():
56
+ if param.requires_grad:
57
+ ema_name = name.replace('_orig_mod.', '')
58
+ ema_params[ema_name].mul_(decay).add_(param.data, alpha=1 - decay)
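+ # i.e. ema = decay * ema + (1 - decay) * param for every trainable parameter;
+ # stripping the '_orig_mod.' prefix keeps names aligned when `model` has been
+ # wrapped by torch.compile.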
59
+
60
+
61
+ def unwrap_model(model):
62
+ """
63
+ Unwrap a model from any distributed or compiled wrappers.
64
+ """
65
+ if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
66
+ model = model._orig_mod
67
+ if isinstance(model, (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)):
68
+ model = model.module
69
+ return model
train_utils/loss.py ADDED
@@ -0,0 +1,101 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) [2023] [Anima-Lab]
4
+
5
+ # This code is adapted from https://github.com/NVlabs/edm/blob/main/training/loss.py.
6
+ # The original code is licensed under a Creative Commons
7
+ # Attribution-NonCommercial-ShareAlike 4.0 International License, which can be found at licenses/LICENSE_EDM.txt.
8
+
9
+ """Loss functions used in the paper
10
+ "Elucidating the Design Space of Diffusion-Based Generative Models"."""
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+
15
+ from utils import *
16
+ from train_utils.helper import unwrap_model
17
+
18
+
19
+ # Improved loss function proposed in the paper "Elucidating the Design Space
20
+ # of Diffusion-Based Generative Models" (EDM).
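+ # In short, the preconditioned denoiser D is trained with the sigma-weighted objective
+ #   L = E_{sigma, n} [ lambda(sigma) * || D(y + n; sigma) - y ||^2 ],
+ # where n ~ N(0, sigma^2 I), log(sigma) ~ N(P_mean, P_std^2), and
+ #   lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2
+ # is the `weight` computed in __call__ below.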
21
+
22
+ class EDMLoss:
23
+ def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
24
+ self.P_mean = P_mean
25
+ self.P_std = P_std
26
+ self.sigma_data = sigma_data
27
+
28
+ def __call__(self, net,
29
+ images,
30
+ labels=None,
31
+ mask_ratio=0,
32
+ mae_loss_coef=0,
33
+ feat=None, augment_pipe=None):
34
+ # sample x_t
35
+ rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
36
+ sigma = (rnd_normal * self.P_std + self.P_mean).exp()
37
+ weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
38
+ y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
39
+ n = torch.randn_like(y) * sigma
40
+
41
+ model_out = net(y + n, sigma, labels, mask_ratio=mask_ratio, mask_dict=None, feat=feat)
42
+ D_yn = model_out['x']
43
+ assert D_yn.shape == y.shape
44
+ loss = weight * ((D_yn - y) ** 2) # (N, C, H, W)
45
+ if mask_ratio > 0:
46
+ assert net.training and 'mask' in model_out
47
+ loss = F.avg_pool2d(loss.mean(dim=1), net.module.model.patch_size).flatten(1) # (N, L)
48
+ unmask = 1 - model_out['mask']
49
+ loss = (loss * unmask).sum(dim=1) / unmask.sum(dim=1) # (N)
50
+ assert loss.ndim == 1
51
+ if mae_loss_coef > 0:
52
+ loss += mae_loss_coef * mae_loss(net.module, y + n, D_yn, 1 - unmask)
53
+ else:
54
+ loss = mean_flat(loss) # (N)
55
+
56
+ raw_net = unwrap_model(net)
57
+ if mask_ratio == 0.0 and raw_net.model.mask_token is not None:
58
+ loss += 0 * torch.sum(raw_net.model.mask_token)
59
+ assert loss.ndim == 1
60
+ return loss
61
+
62
+
63
+ # ----------------------------------------------------------------------------
64
+
65
+
66
+ Losses = {
67
+ 'edm': EDMLoss
68
+ }
69
+
70
+
71
+ # ----------------------------------------------------------------------------
72
+
73
+ def patchify(imgs, patch_size=2, num_channels=4):
74
+ """
75
+ imgs: (N, C, H, W)
76
+ x: (N, L, patch_size**2 * C)
77
+ """
78
+ p, c = patch_size, num_channels
79
+ assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
80
+
81
+ h = w = imgs.shape[2] // p
82
+ x = imgs.reshape(shape=(imgs.shape[0], c, h, p, w, p))
83
+ x = torch.einsum('nchpwq->nhwpqc', x)
84
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * c))
85
+ return x
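+ # e.g. with patch_size=2 and num_channels=4 (the latent-space defaults here), a
+ # batch of (N, 4, 32, 32) latents becomes (N, 16*16, 2*2*4) = (N, 256, 16) tokens.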
86
+
87
+
88
+ def mae_loss(net, target, pred, mask, norm_pix_loss=True):
89
+ target = patchify(target, net.model.patch_size, net.model.out_channels)
90
+ pred = patchify(pred, net.model.patch_size, net.model.out_channels)
91
+ if norm_pix_loss:
92
+ mean = target.mean(dim=-1, keepdim=True)
93
+ var = target.var(dim=-1, keepdim=True)
94
+ target = (target - mean) / (var + 1.e-6)**.5
95
+
96
+ loss = (pred - target) ** 2
97
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
98
+
99
+ loss = (loss * mask).sum(dim=1) / mask.sum(dim=1) # mean loss on removed patches, (N)
100
+ assert loss.ndim == 1
101
+ return loss
train_wds.py ADDED
@@ -0,0 +1,400 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) [2023] [Anima-Lab]
4
+ '''
5
+ Training MaskDiT on a latent dataset in WebDataset format. Used for the experiments on ImageNet 512x512.
6
+ '''
7
+
8
+ import argparse
9
+ import os.path
10
+ from copy import deepcopy
11
+ from time import time
12
+ from omegaconf import OmegaConf
13
+ import pickle
14
+ from itertools import islice
15
+
16
+ import apex
17
+ import torch
18
+ import webdataset as wds
19
+
20
+ import accelerate
21
+
22
+ from fid import calc
23
+ from models.maskdit import Precond_models
24
+ from train_utils.loss import Losses
25
+
26
+ from train_utils.helper import get_mask_ratio_fn, get_one_hot, requires_grad, update_ema, unwrap_model
27
+
28
+ from sample import generate_with_net
29
+ from utils import dist, mprint, get_latest_ckpt, Logger, sample, \
30
+ str2bool, parse_str_none, parse_int_list, parse_float_none
31
+
32
+
33
+ # ------------------------------------------------------------
34
+ # WebDataset Helper Function
35
+ def nodesplitter(src, group=None):
36
+ rank, world_size, worker, num_workers = wds.utils.pytorch_worker_info()
37
+ if world_size > 1:
38
+ for s in islice(src, rank, None, world_size):
39
+ yield s
40
+ else:
41
+ for s in src:
42
+ yield s
43
+
44
+
45
+ def get_file_paths(dir):
46
+ return [os.path.join(dir, file) for file in os.listdir(dir)]
47
+
48
+
49
+ def split_by_proc(data_list, global_rank, total_size):
50
+ '''
51
+ Evenly split the data_list into total_size parts and return the part indexed by global_rank.
52
+ '''
53
+ assert len(data_list) >= total_size
54
+ assert global_rank < total_size
55
+ return data_list[global_rank::total_size]
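+ # e.g. with 8 shard files and total_size=4, global_rank=1 receives files [1, 5].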
56
+
57
+
58
+ def decode_data(item):
59
+ output = {}
60
+ img = pickle.loads(item['latent'])
61
+ output['latent'] = img
62
+ label = int(item['cls'].decode('utf-8'))
63
+ output['label'] = label
64
+ return output
65
+
66
+
67
+ def make_loader(root, mode='train', batch_size=32,
68
+ num_workers=4, cache_dir=None,
69
+ resampled=False, world_size=1, total_num=1281167,
70
+ bufsize=1000, initial=100):
71
+ data_list = get_file_paths(root)
72
+ num_batches_in_total = total_num // (batch_size * world_size)
73
+ if resampled:
74
+ repeat = True
75
+ splitter = False
76
+ else:
77
+ repeat = False
78
+ splitter = nodesplitter
79
+ dataset = (
80
+ wds.WebDataset(
81
+ data_list,
82
+ cache_dir=cache_dir,
83
+ repeat=repeat,
84
+ resampled=resampled,
85
+ handler=wds.handlers.warn_and_stop,
86
+ nodesplitter=splitter,
87
+ )
88
+ .shuffle(bufsize, initial=initial)
89
+ .map(decode_data, handler=wds.handlers.warn_and_stop)
90
+ .to_tuple('latent label')
91
+ .batched(batch_size, partial=False)
92
+ )
93
+
94
+ loader = wds.WebLoader(dataset, batch_size=None, num_workers=num_workers, shuffle=False, persistent_workers=True)
95
+ if resampled:
96
+ loader = loader.with_epoch(num_batches_in_total)
97
+ return loader
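+ # Note on the two modes above: with resampled=True, shards are sampled with
+ # replacement and `.with_epoch(num_batches_in_total)` defines a nominal epoch of
+ # total_num // (batch_size * world_size) batches; otherwise `nodesplitter` hands
+ # whole shards to each rank and the natural end of the shards ends the epoch.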
98
+
99
+
100
+ # ------------------------------------------------------------
101
+
102
+
103
+ def train_loop(args):
104
+ # load configuration
105
+ config = OmegaConf.load(args.config)
106
+ if not args.no_amp:
107
+ config.train.amp = 'fp16'
108
+ else:
109
+ config.train.amp = 'no'
110
+ if config.train.tf32:
111
+ torch.set_float32_matmul_precision('high')
112
+
113
+ accelerator = accelerate.Accelerator(mixed_precision=config.train.amp,
114
+ gradient_accumulation_steps=config.train.grad_accum,
115
+ log_with='wandb')
116
+ # setup wandb
117
+ if args.use_wandb:
118
+ wandb_init_kwargs = {
119
+ 'entity': config.wandb.entity,
120
+ 'project': config.wandb.project,
121
+ 'group': config.wandb.group,
122
+ }
123
+ accelerator.init_trackers(config.wandb.project, config=OmegaConf.to_container(config), init_kwargs=wandb_init_kwargs)
124
+
125
+ mprint('start training...')
126
+ size = accelerator.num_processes
127
+ rank = accelerator.process_index
128
+
129
+ print(f'global_rank: {rank}, global_size: {size}')
130
+ device = accelerator.device
131
+
132
+ seed = args.global_seed
133
+ torch.manual_seed(seed)
134
+
135
+ mprint(f"enable_amp: {not args.no_amp}, TF32: {config.train.tf32}")
136
+ # Select batch size per GPU
137
+ num_accumulation_rounds = config.train.grad_accum
138
+
139
+ micro_batch = config.train.batchsize
140
+ batch_gpu_total = micro_batch * num_accumulation_rounds
141
+ global_batch_size = batch_gpu_total * size
142
+ mprint(f"Global batchsize: {global_batch_size}, batchsize per GPU: {batch_gpu_total}, micro_batch: {micro_batch}.")
143
+
144
+ class_dropout_prob = config.model.class_dropout_prob
145
+ log_every = config.log.log_every
146
+ ckpt_every = config.log.ckpt_every
147
+
148
+ mask_ratio_fn = get_mask_ratio_fn(config.model.mask_ratio_fn, config.model.mask_ratio, config.model.mask_ratio_min)
149
+
150
+ # Setup an experiment folder
151
+ model_name = config.model.model_type.replace("/", "-") # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders)
152
+ data_name = config.data.dataset
153
+ if args.ckpt_path is not None and args.use_ckpt_path: # use the existing exp path (mainly used for fine-tuning)
154
+ checkpoint_dir = os.path.dirname(args.ckpt_path)
155
+ experiment_dir = os.path.dirname(checkpoint_dir)
156
+ exp_name = os.path.basename(experiment_dir)
157
+ else: # start a new exp path (and resume from the latest checkpoint if possible)
158
+ cond_gen = 'cond' if config.model.num_classes else 'uncond'
159
+ exp_name = f'{model_name}-{config.model.precond}-{data_name}-{cond_gen}-m{config.model.mask_ratio}-de{int(config.model.use_decoder)}' \
160
+ f'-mae{config.model.mae_loss_coef}-bs-{global_batch_size}-lr{config.train.lr}{config.log.tag}'
161
+ experiment_dir = f"{args.results_dir}/{exp_name}"
162
+ checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
163
+ os.makedirs(checkpoint_dir, exist_ok=True)
164
+ if args.ckpt_path is None:
165
+ args.ckpt_path = get_latest_ckpt(checkpoint_dir) # Resumes from the latest checkpoint if it exists
166
+
167
+ if accelerator.is_main_process:
168
+ logger = Logger(file_name=f'{experiment_dir}/log.txt', file_mode="a+", should_flush=True)
169
+ mprint(f"Experiment directory created at {experiment_dir}")
170
+
171
+ # Setup dataset
172
+ loader = make_loader(config.data.root,
173
+ mode='train',
174
+ batch_size=batch_gpu_total,
175
+ num_workers=args.num_workers,
176
+ resampled=args.resample,
177
+ world_size=size,
178
+ total_num=config.data.total_num)
179
+
180
+ steps_per_epoch = config.data.total_num // global_batch_size
181
+ mprint(f"{steps_per_epoch} steps per epoch")
182
+
183
+ model = Precond_models[config.model.precond](
184
+ img_resolution=config.model.in_size,
185
+ img_channels=config.model.in_channels,
186
+ num_classes=config.model.num_classes,
187
+ model_type=config.model.model_type,
188
+ use_decoder=config.model.use_decoder,
189
+ mae_loss_coef=config.model.mae_loss_coef,
190
+ pad_cls_token=config.model.pad_cls_token
191
+ ).to(device)
192
+ # Note that parameter initialization is done within the model constructor
193
+ ema = deepcopy(model) # Create an EMA of the model for use after training
194
+ requires_grad(ema, False)
195
+ ema = ema.to(device)
196
+
197
+ mprint(f"{config.model.model_type} ((use_decoder: {config.model.use_decoder})) Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
198
+ mprint(f'extras: {model.model.extras}, cls_token: {model.model.cls_token}')
199
+
200
+ # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
201
+ optimizer = apex.optimizers.FusedAdam(model.parameters(), lr=config.train.lr, adam_w_mode=True, weight_decay=0)
202
+ # optimizer = torch.optim.Adam(model.parameters(), lr=config.train.lr)
203
+
204
+ # Load checkpoints
205
+ train_steps_start = 0
206
+ epoch_start = 0
207
+
208
+ if args.ckpt_path is not None:
209
+ ckpt = torch.load(args.ckpt_path, map_location=device)
210
+ model.load_state_dict(ckpt['model'], strict=args.use_strict_load)
211
+ ema.load_state_dict(ckpt['ema'], strict=args.use_strict_load)
212
+ mprint(f'Load weights from {args.ckpt_path}')
213
+ if args.use_strict_load:
214
+ optimizer.load_state_dict(ckpt['opt'])
215
+ for state in optimizer.state.values():
216
+ for k, v in state.items():
217
+ if isinstance(v, torch.Tensor):
218
+ state[k] = v.cuda()
219
+ mprint(f'Load optimizer state..')
220
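+ # Checkpoint filenames encode the global training step (e.g. 0050000.pt), so the resumed step counter is parsed from the path.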
+ train_steps_start = int(os.path.basename(args.ckpt_path).split('.pt')[0])
221
+ epoch_start = train_steps_start // steps_per_epoch
222
+ mprint(f"train_steps_start: {train_steps_start}")
223
+ del ckpt # conserve memory
224
+
225
+ # FID evaluation for the loaded weights
226
+ if args.enable_eval:
227
+ start_time = time()
228
+ args.outdir = os.path.join(experiment_dir, 'fid', f'edm-steps{args.num_steps}-ckpt{train_steps_start}_cfg{args.cfg_scale}')
229
+ os.makedirs(args.outdir, exist_ok=True)
230
+ generate_with_net(args, ema, device, rank, size)
231
+
232
+ dist.barrier()
233
+ fid = calc(args.outdir, config.eval.ref_path, args.num_expected, args.global_seed, args.fid_batch_size)
234
+ mprint(f"time for fid calc: {time() - start_time}")
235
+ if args.use_wandb:
236
+ accelerator.log({"eval/fid": fid}, step=train_steps_start)
237
+ mprint(f'guidance: {args.cfg_scale} FID: {fid}')
238
+ dist.barrier()
239
+
240
+ model, optimizer = accelerator.prepare(model, optimizer)
241
+ model = torch.compile(model)
242
+
243
+ # Setup loss
244
+ loss_fn = Losses[config.model.precond]()
245
+
246
+ # Prepare models for training:
247
+ if args.ckpt_path is None:
248
+ assert train_steps_start == 0
249
+ raw_model = unwrap_model(model)
250
+ update_ema(ema, raw_model, decay=0) # Ensure EMA is initialized with synced weights
251
+ model.train() # important! This enables embedding dropout for classifier-free guidance
252
+ ema.eval() # EMA model should always be in eval mode
253
+
254
+ # Variables for monitoring/logging purposes:
255
+ train_steps = train_steps_start
256
+ log_steps = 0
257
+ running_loss = 0
258
+ start_time = time()
259
+ mprint(f"Training for {config.train.epochs} epochs...")
260
+ for epoch in range(epoch_start, config.train.epochs):
261
+ mprint(f"Beginning epoch {epoch}...")
262
+ for x, cond in loader:
263
+ x = x.to(device)
264
+ y = cond.to(device)
265
+
266
+ y = get_one_hot(y, num_classes=config.model.num_classes)
267
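+ # x holds pre-extracted VAE posterior moments; sample() in utils.py draws a scaled latent from them.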
+ x = sample(x)
268
+
269
+ loss_batch = 0
270
+ model.zero_grad(set_to_none=True)
271
+ curr_mask_ratio = mask_ratio_fn((train_steps - train_steps_start) / config.train.max_num_steps)
272
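+ # Label dropout for classifier-free guidance: zero the one-hot labels of a random subset of samples so the model also learns the unconditional distribution.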
+ if class_dropout_prob > 0:
273
+ y = y * (torch.rand([y.shape[0], 1], device=device) >= class_dropout_prob)
274
+
275
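+ # Gradient accumulation: the per-GPU batch is sliced into micro-batches; accelerator.accumulate() defers gradient synchronization until all rounds are processed.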
+ for round_idx in range(num_accumulation_rounds):
276
+ x_ = x[round_idx * micro_batch:(round_idx + 1) * micro_batch]
277
+ y_ = y[round_idx * micro_batch:(round_idx + 1) * micro_batch]
278
+
279
+ with accelerator.accumulate(model):
280
+ loss = loss_fn(net=model, images=x_, labels=y_,
281
+ mask_ratio=curr_mask_ratio,
282
+ mae_loss_coef=config.model.mae_loss_coef)
283
+ loss_mean = loss.mean()
284
+ accelerator.backward(loss_mean)
285
+ # Update weights with lr warmup.
286
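+ # Linear warmup: lr ramps from 0 to config.train.lr over the first lr_rampup_kimg thousand training images, then stays constant.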
+ lr_cur = config.train.lr * min(train_steps * global_batch_size / max(config.train.lr_rampup_kimg * 1000, 1), 1)
287
+ for g in optimizer.param_groups:
288
+ g['lr'] = lr_cur
289
+ optimizer.step()
290
+ loss_batch = loss_mean.item()
291
+
292
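+ # Update the EMA copy of the weights after each optimizer step; the EMA model is the one used for FID sampling.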
+ raw_model = unwrap_model(model)
293
+ update_ema(ema, raw_model)
294
+
295
+ # Log loss values:
296
+ running_loss += loss_batch
297
+ log_steps += 1
298
+ train_steps += 1
299
+ if train_steps > (train_steps_start + config.train.max_num_steps):
300
+ break
301
+ if train_steps % log_every == 0:
302
+ # Measure training speed:
303
+ torch.cuda.synchronize()
304
+ end_time = time()
305
+ steps_per_sec = log_steps / (end_time - start_time)
306
+ # Reduce loss history over all processes:
307
+ avg_loss = torch.tensor(running_loss / log_steps, device=device)
308
+
309
+ dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
310
+ avg_loss = avg_loss.item() / size
311
+ mprint(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
312
+ mprint(f'Peak GPU memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GB')
313
+ mprint(f'Reserved GPU memory: {torch.cuda.memory_reserved() / 1024 ** 3:.2f} GB')
314
+ if args.use_wandb:
315
+ accelerator.log({"train/loss": avg_loss, "train/lr": lr_cur}, step=train_steps)
316
+ # Reset monitoring variables:
317
+ running_loss = 0
318
+ log_steps = 0
319
+ start_time = time()
320
+
321
+ # Save checkpoint:
322
+ if train_steps % ckpt_every == 0 and train_steps > train_steps_start:
323
+ if accelerator.is_main_process:
324
+ checkpoint = {
325
+ "model": raw_model.state_dict(),
326
+ "ema": ema.state_dict(),
327
+ "opt": optimizer.state_dict(),
328
+ "args": args
329
+ }
330
+ checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
331
+ torch.save(checkpoint, checkpoint_path)
332
+ mprint(f"Saved checkpoint to {checkpoint_path}")
333
+ del checkpoint # conserve memory
334
+ dist.barrier()
335
+
336
+ # FID evaluation during training
337
+ if args.enable_eval:
338
+ start_time = time()
339
+ args.outdir = os.path.join(experiment_dir, 'fid', f'edm-steps{args.num_steps}-ckpt{train_steps}_cfg{args.cfg_scale}')
340
+ os.makedirs(args.outdir, exist_ok=True)
341
+ generate_with_net(args, ema, device, rank, size)
342
+
343
+ dist.barrier()
344
+ fid = calc(args.outdir, args.ref_path, args.num_expected, args.global_seed, args.fid_batch_size)
345
+ mprint(f"time for fid calc: {time() - start_time}, fid: {fid}")
346
+ if args.use_wandb:
347
+ accelerator.log({"eval/fid": fid}, step=train_steps)
348
+ mprint(f'Guidance: {args.cfg_scale}, FID: {fid}')
349
+ dist.barrier()
350
+ start_time = time()
351
+
352
+ if accelerator.is_main_process:
353
+ logger.close()
354
+ accelerator.end_training()
355
+
356
+
357
+
358
+ if __name__ == '__main__':
359
+ parser = argparse.ArgumentParser('training parameters')
360
+ # basic config
361
+ parser.add_argument('--config', type=str, required=True, help='path to config file')
362
+ # training
363
+ parser.add_argument("--results_dir", type=str, default="results")
364
+ parser.add_argument("--ckpt_path", type=parse_str_none, default=None)
365
+
366
+ parser.add_argument("--global_seed", type=int, default=0)
367
+ parser.add_argument("--num_workers", type=int, default=4)
368
+ parser.add_argument('--no_amp', action='store_true', help="Disable automatic mixed precision.")
369
+
370
+ parser.add_argument("--use_wandb", action='store_true', help='enable wandb logging')
371
+ parser.add_argument("--use_ckpt_path", type=str2bool, default=True)
372
+ parser.add_argument("--use_strict_load", type=str2bool, default=True)
373
+ parser.add_argument("--tag", type=str, default='')
374
+ parser.add_argument("--resample", action='store_true', help='enable shard resample')
375
+
376
+ # sampling
377
+ parser.add_argument('--enable_eval', action='store_true', help='enable fid calc during training')
378
+ parser.add_argument('--seeds', type=parse_int_list, default='100000-104999', help='Random seeds (e.g. 1,2,5-10)')
379
+ parser.add_argument('--subdirs', action='store_true', help='Create subdirectory for every 1000 seeds')
380
+ parser.add_argument('--class_idx', type=int, default=None, help='Class label [default: random]')
381
+ parser.add_argument('--max_batch_size', type=int, default=25, help='Maximum batch size per GPU during sampling, must be a factor of 50k if torch.compile is used')
382
+
383
+ parser.add_argument("--cfg_scale", type=parse_float_none, default=None, help='None = no guidance, by default = 4.0')
384
+
385
+ parser.add_argument('--num_steps', type=int, default=40, help='Number of sampling steps')
386
+ parser.add_argument('--S_churn', type=int, default=0, help='Stochasticity strength')
387
+ parser.add_argument('--solver', type=str, default=None, choices=['euler', 'heun'], help='Ablate ODE solver')
388
+ parser.add_argument('--discretization', type=str, default=None, choices=['vp', 've', 'iddpm', 'edm'], help='Ablate time step discretization')
389
+ parser.add_argument('--schedule', type=str, default=None, choices=['vp', 've', 'linear'], help='Ablate noise schedule sigma(t)')
390
+ parser.add_argument('--scaling', type=str, default=None, choices=['vp', 'none'], help='Ablate signal scaling s(t)')
391
+ parser.add_argument('--pretrained_path', type=str, default='assets/stable_diffusion/autoencoder_kl.pth', help='Autoencoder ckpt')
392
+
393
+ parser.add_argument('--ref_path', type=str, default='assets/fid_stats/VIRTUAL_imagenet512.npz', help='Dataset reference statistics')
394
+ parser.add_argument('--num_expected', type=int, default=5000, help='Number of images to use')
395
+ parser.add_argument('--fid_batch_size', type=int, default=64, help='Maximum batch size per GPU')
396
+
397
+ args = parser.parse_args()
398
+
399
+ torch.backends.cudnn.benchmark = True
400
+ train_loop(args)
utils.py ADDED
@@ -0,0 +1,225 @@
1
+ # MIT License
2
+
3
+ # Copyright (c) [2023] [Anima-Lab]
4
+
5
+ # This code is adapted from https://github.com/NVlabs/edm/blob/main/generate.py.
6
+ # The original code is licensed under a Creative Commons
7
+ # Attribution-NonCommercial-ShareAlike 4.0 International License, which can be found at licenses/LICENSE_EDM.txt.
8
+
9
+
10
+ import os
11
+ import re
12
+ import sys
13
+ import contextlib
14
+
15
+ import torch
16
+ import torch.distributed as dist
17
+
18
+
19
+ #----------------------------------------------------------------------------
20
+ # Get the latest checkpoint from the save dir
21
+
22
+ def get_latest_ckpt(dir):
23
+ latest_id = -1
24
+ for file in os.listdir(dir):
25
+ if file.endswith('.pt'):
26
+ m = re.search(r'(\d+)\.pt', file)
27
+ if m:
28
+ ckpt_id = int(m.group(1))
29
+ latest_id = max(latest_id, ckpt_id)
30
+ if latest_id == -1:
31
+ return None
32
+ else:
33
+ ckpt_path = os.path.join(dir, f'{latest_id:07d}.pt')
34
+ return ckpt_path
35
+
36
+
37
+ def get_ckpt_paths(dir, id_min, id_max):
38
+ ckpt_dict = {}
39
+ for file in os.listdir(dir):
40
+ if file.endswith('.pt'):
41
+ m = re.search(r'(\d+)\.pt', file)
42
+ if m:
43
+ ckpt_id = int(m.group(1))
44
+ if id_min <= ckpt_id <= id_max:
45
+ ckpt_dict[ckpt_id] = os.path.join(dir, f'{ckpt_id:07d}.pt')
46
+ return ckpt_dict
47
+
48
+
49
+ #----------------------------------------------------------------------------
50
+ # Take the mean over all non-batch dimensions.
51
+
52
+ def mean_flat(tensor):
53
+ return tensor.mean(dim=list(range(1, tensor.ndim)))
54
+
55
+
56
+ #----------------------------------------------------------------------------
57
+ # Convert latent (mean, logvar) to latent variable (inherited from autoencoder.py)
58
+
59
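+ # moments is the concatenated (mean, logvar) of the VAE posterior; a latent is drawn via the reparameterization trick
+ # z = mean + std * eps and multiplied by the Stable Diffusion latent scale factor (0.18215 by default).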
+ def sample(moments, scale_factor=0.18215):
60
+ mean, logvar = torch.chunk(moments, 2, dim=1)
61
+ logvar = torch.clamp(logvar, -30.0, 20.0)
62
+ std = torch.exp(0.5 * logvar)
63
+ z = mean + std * torch.randn_like(mean)
64
+ z = scale_factor * z
65
+ return z
66
+
67
+
68
+ #----------------------------------------------------------------------------
69
+ # Context manager for easily enabling/disabling DistributedDataParallel
70
+ # synchronization.
71
+
72
+ @contextlib.contextmanager
73
+ def ddp_sync(module, sync):
74
+ assert isinstance(module, torch.nn.Module)
75
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
76
+ yield
77
+ else:
78
+ with module.no_sync():
79
+ yield
80
+
81
+ #----------------------------------------------------------------------------
82
+ # Distributed training helper functions
83
+
84
+ def init_processes(fn, args):
85
+ """ Initialize the distributed environment. """
86
+ os.environ['MASTER_ADDR'] = args.master_address
87
+ os.environ['MASTER_PORT'] = '6020'
88
+ print(f'MASTER_ADDR = {os.environ["MASTER_ADDR"]}')
89
+ print(f'MASTER_PORT = {os.environ["MASTER_PORT"]}')
90
+ torch.cuda.set_device(args.local_rank)
91
+ dist.init_process_group(backend='nccl', init_method='env://', rank=args.global_rank, world_size=args.global_size)
92
+ fn(args)
93
+ if args.global_size > 1:
94
+ cleanup()
95
+
96
+
97
+ def mprint(*args, **kwargs):
98
+ """
99
+ Print only from rank 0.
100
+ """
101
+ if dist.get_rank() == 0:
102
+ print(*args, **kwargs)
103
+
104
+
105
+ def cleanup():
106
+ """
107
+ End DDP training.
108
+ """
109
+ dist.barrier()
110
+ mprint("Done!")
111
+ dist.barrier()
112
+ dist.destroy_process_group()
113
+
114
+
115
+ #----------------------------------------------------------------------------
116
+ # Wrapper for torch.Generator that allows specifying a different random seed
117
+ # for each sample in a minibatch.
118
+
119
+ class StackedRandomGenerator:
120
+ def __init__(self, device, seeds):
121
+ super().__init__()
122
+ self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]
123
+
124
+ def randn(self, size, **kwargs):
125
+ assert size[0] == len(self.generators)
126
+ return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])
127
+
128
+ def randn_like(self, input):
129
+ return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)
130
+
131
+ def randint(self, *args, size, **kwargs):
132
+ assert size[0] == len(self.generators)
133
+ return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])
134
+
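+ # Usage sketch (values hypothetical):
+ #   rnd = StackedRandomGenerator(device, seeds=[0, 1, 2])
+ #   latents = rnd.randn([3, 4, 32, 32], device=device)
+ # Each row of latents is reproducible from its own seed, independent of batch composition.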
135
+
136
+ #----------------------------------------------------------------------------
137
+ # Parse a comma separated list of numbers or ranges and return a list of ints.
138
+ # Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
139
+
140
+ def parse_int_list(s):
141
+ if isinstance(s, list): return s
142
+ ranges = []
143
+ range_re = re.compile(r'^(\d+)-(\d+)$')
144
+ for p in s.split(','):
145
+ m = range_re.match(p)
146
+ if m:
147
+ ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
148
+ else:
149
+ ranges.append(int(p))
150
+ return ranges
151
+
152
+ # Parse 'None' to None and others to float value
153
+ def parse_float_none(s):
154
+ assert isinstance(s, str)
155
+ return None if s == 'None' else float(s)
156
+
157
+ # Parse 'None' to None and others to str
158
+ def parse_str_none(s):
159
+ assert isinstance(s, str)
160
+ return None if s == 'None' else s
161
+
162
+ # Parse 'true' to True
163
+ def str2bool(s):
164
+ return s.lower() in ['true', '1', 'yes']
165
+
166
+
167
+ #----------------------------------------------------------------------------
168
+ # logging info.
169
+ class Logger(object):
170
+ """
171
+ Redirect stderr to stdout, optionally print stdout to a file,
172
+ and optionally force flushing on both stdout and the file.
173
+ """
174
+
175
+ def __init__(self, file_name=None, file_mode="w", should_flush=True):
176
+ self.file = None
177
+
178
+ if file_name is not None:
179
+ self.file = open(file_name, file_mode)
180
+
181
+ self.should_flush = should_flush
182
+ self.stdout = sys.stdout
183
+ self.stderr = sys.stderr
184
+
185
+ sys.stdout = self
186
+ sys.stderr = self
187
+
188
+ def __enter__(self):
189
+ return self
190
+
191
+ def __exit__(self, exc_type, exc_value, traceback):
192
+ self.close()
193
+
194
+ def write(self, text):
195
+ """Write text to stdout (and a file) and optionally flush."""
196
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
197
+ return
198
+
199
+ if self.file is not None:
200
+ self.file.write(text)
201
+
202
+ self.stdout.write(text)
203
+
204
+ if self.should_flush:
205
+ self.flush()
206
+
207
+ def flush(self):
208
+ """Flush written text to both stdout and a file, if open."""
209
+ if self.file is not None:
210
+ self.file.flush()
211
+
212
+ self.stdout.flush()
213
+
214
+ def close(self):
215
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
216
+ self.flush()
217
+
218
+ # if using multiple loggers, prevent closing in wrong order
219
+ if sys.stdout is self:
220
+ sys.stdout = self.stdout
221
+ if sys.stderr is self:
222
+ sys.stderr = self.stderr
223
+
224
+ if self.file is not None:
225
+ self.file.close()