devzhk committed on
Commit · 972a35a
Parent(s): 6e0fb69
Add model files
- .DS_Store +0 -0
- Dockerfile +4 -0
- LICENSE +21 -0
- README.md +163 -3
- assets/figs/12samples_compressed.png +3 -0
- assets/figs/arch.png +3 -0
- assets/figs/bar_mem_256.png +3 -0
- assets/figs/bar_mem_512.png +3 -0
- assets/figs/bar_speed_256.png +3 -0
- assets/figs/bar_speed_512.png +3 -0
- assets/figs/bubble_gflops_wg.png +3 -0
- assets/figs/bubble_gflops_wog.png +3 -0
- assets/figs/maskdit_arch.png +3 -0
- assets/figs/repo_head.png +3 -0
- assets/figs/sample512-set1.png +3 -0
- assets/imagenet_label.json +1 -0
- autoencoder.py +522 -0
- checkpoints/.DS_Store +0 -0
- configs/finetune/imagenet256-latent-const.yaml +49 -0
- configs/finetune/imagenet256-latent-cos.yaml +49 -0
- configs/finetune/imagenet512-latent.yaml +47 -0
- configs/test/maskdit-256.yaml +45 -0
- configs/test/maskdit-512.yaml +46 -0
- configs/train/imagenet256-latent.yaml +48 -0
- configs/train/imagenet512-latent.yaml +47 -0
- eval_latent.py +132 -0
- evaluator.py +695 -0
- extract_latent.py +114 -0
- fid.py +177 -0
- generate.py +91 -0
- licenses/LICENSE_ADM.txt +21 -0
- licenses/LICENSE_DIT.txt +400 -0
- licenses/LICENSE_EDM.txt +439 -0
- licenses/LICENSE_UVIT.txt +21 -0
- lmdb2wds.py +39 -0
- models/maskdit.py +781 -0
- sample.py +397 -0
- scripts/download_assets.sh +8 -0
- scripts/finetune_latent512.sh +14 -0
- scripts/prepare_latent256.sh +3 -0
- scripts/prepare_latent512.sh +6 -0
- scripts/train_latent512.sh +11 -0
- torch_utils/__init__.py +0 -0
- torch_utils/persistence.py +276 -0
- train.py +336 -0
- train_utils/datasets.py +412 -0
- train_utils/helper.py +69 -0
- train_utils/loss.py +101 -0
- train_wds.py +400 -0
- utils.py +225 -0
.DS_Store
ADDED
Binary file (6.15 kB)
Dockerfile
ADDED
@@ -0,0 +1,4 @@
FROM nvcr.io/nvidia/pytorch:23.03-py3
RUN pip install einops lmdb omegaconf wandb tqdm pyyaml accelerate
RUN pip install timm webdataset
RUN pip install diffusers["torch"] transformers
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) [2023] [Anima-Lab]

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,3 +1,163 @@
# Fast Training of Diffusion Models with Masked Transformers

Official PyTorch implementation of the TMLR 2024 paper:<br>
**[Fast Training of Diffusion Models with Masked Transformers](https://openreview.net/pdf?id=vTBjBtGioE)**
<br>
Hongkai Zheng*, Weili Nie*, Arash Vahdat, Anima Anandkumar <br>
(*Equal contribution)<br>

Abstract: *While masked transformers have been extensively explored for representation learning, their application to generative learning is less explored in the vision domain. Our work is the first to exploit masked training to reduce the training cost of diffusion models significantly. Specifically, we randomly mask out a high proportion (e.g., 50%) of patches in diffused input images during training. For masked training, we introduce an asymmetric encoder-decoder architecture consisting of a transformer encoder that operates only on unmasked patches and a lightweight transformer decoder on full patches. To promote a long-range understanding of full patches, we add an auxiliary task of reconstructing masked patches to the denoising score matching objective that learns the score of unmasked patches. Experiments on ImageNet-256x256 and ImageNet-512x512 show that our approach achieves competitive and even better generative performance than the state-of-the-art Diffusion Transformer (DiT) model, using only around 30% of its original training time. Thus, our method shows a promising way of efficiently training large transformer-based diffusion models without sacrificing the generative performance.*

<div align='center'>
<img src="assets/figs/repo_head.png" alt="Architecture" width="900" height="500" style="display: block;"/>
</div>

## Requirements

- Training MaskDiT on ImageNet-256x256 takes around 260 hours with 8 A100 GPUs to perform 2M updates with a batch size of 1024.
- Training MaskDiT on ImageNet-512x512 takes around 210 A100 GPU days to perform 1M updates with a batch size of 1024.
- At least one high-end GPU is needed for sampling.
- A [Dockerfile](Dockerfile) is provided for the exact software environment.

## Efficiency

Our MaskDiT applies Automatic Mixed Precision (AMP) by default. We also include MaskDiT without AMP (Ours_fp32) for reference.

### Training speed

<img src="assets/figs/bar_speed_256.png" width=45% style="display: inline-block;"><img src="assets/figs/bar_speed_512.png" width=46% style="display: inline-block;">

### GPU memory

<img src="assets/figs/bar_mem_256.png" width=45% style="display: inline-block;"><img src="assets/figs/bar_mem_512.png" width=44.3% style="display: inline-block;">

## Pretrained Models

We provide pretrained MaskDiT models for ImageNet-256x256 and ImageNet-512x512 in the following table. For FID with guidance, the guidance scale is set to 1.5 by default.

| Guidance | Resolution | FID   | Model                                                                      |
| :------- | :--------- | :---- | :------------------------------------------------------------------------- |
| Yes      | 256x256    | 2.28  | [imagenet256-guidance.pt](checkpoints/imagenet256_with_guidance.pt)        |
| No       | 256x256    | 5.69  | [imagenet256-conditional.pt](checkpoints/imagenet256_without_guidance.pt)  |
| Yes      | 512x512    | 2.50  | [imagenet512-guidance.pt](checkpoints/imagenet512_with_guidance.pt)        |
| No       | 512x512    | 10.79 | [imagenet512-conditional.pt](checkpoints/imagenet512_without_guidance.pt)  |

## Generate from pretrained models

To generate samples from the provided checkpoints, run

```bash
python3 generate.py --config configs/test/maskdit-512.yaml --ckpt_path [path to checkpoints] --class_idx [class index from 0-999] --cfg_scale [guidance scale]
```
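For reference, `--cfg_scale` is the classifier-free guidance scale. Below is a minimal sketch of the usual guidance rule, with hypothetical names; the repository's actual sampling logic lives in `generate.py` and `sample.py`.

```python
# Sketch only: mixes the conditional and unconditional predictions.
# cfg_scale = 1.0 recovers the purely class-conditional model; larger values
# trade diversity for fidelity (the FID numbers above use 1.5).
def guided_prediction(model, x, t, y, y_null, cfg_scale=1.5):
    cond = model(x, t, y)          # prediction with the real class label
    uncond = model(x, t, y_null)   # prediction with the "null" (unconditional) label
    return uncond + cfg_scale * (cond - uncond)
```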
<img src="assets/figs/12samples_compressed.png" title="Generated samples from MaskDiT 256x256." width="850" style="display: block; margin: 0 auto;"/>
<p align='center'> Generated samples from MaskDiT 256x256. Upper panel: without CFG. Lower panel: with CFG (scale=1.5). </p>

<img src="assets/figs/imagenet512.png" title="Generated samples from MaskDiT 512x512." width="850" style="display: block; margin: 0 auto;"/>
<p align='center'> Generated samples from MaskDiT 512x512 with CFG (scale=1.5). </p>

## Prepare dataset

We use a pre-trained VAE to first encode the ImageNet dataset into latent space. You can download the ImageNet-256x256 and ImageNet-512x512 datasets that have already been encoded into latent space by running

```bash
bash scripts/download_assets.sh
```

`extract_latent.py` was used to encode ImageNet into these latents.
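As a rough illustration of what that step does (a sketch, not the repository's script; it assumes the Stable Diffusion VAE from `diffusers` and the usual 0.18215 latent scaling), each image is encoded once and the resulting latent is stored for training:

```python
import torch
from diffusers import AutoencoderKL

# Assumption: the sd-vae-ft-ema checkpoint commonly used for latent DiT-style models.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").eval().to("cuda")

@torch.no_grad()
def encode_batch(images):                      # images: (B, 3, 256, 256) in [-1, 1]
    posterior = vae.encode(images.to("cuda")).latent_dist
    return posterior.sample() * 0.18215        # (B, 4, 32, 32) latents
```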
### LMDB to WebDataset

When training on ImageNet-256x256, we store our data in LMDB format. When training on ImageNet-512x512, we use [webdataset](https://github.com/webdataset/webdataset) for faster IO performance. To convert an LMDB dataset into a webdataset, run

```bash
python3 lmdb2wds.py --datadir [path to lmdb] --outdir [path to save webdataset] --resolution [latent resolution] --num_channels [number of latent channels]
```
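The conversion amounts to re-serializing each (latent, label) pair into `.tar` shards. A minimal sketch with `webdataset.ShardWriter` (hypothetical key names; see `lmdb2wds.py` for the actual fields and sharding):

```python
import io
import numpy as np
import webdataset as wds

def write_shards(samples, outdir, maxcount=10000):
    # samples yields (index, latent_array, class_label)
    with wds.ShardWriter(f"{outdir}/imagenet-%06d.tar", maxcount=maxcount) as sink:
        for idx, latent, label in samples:
            buf = io.BytesIO()
            np.save(buf, latent.astype(np.float32))          # serialize the latent
            sink.write({
                "__key__": f"{idx:08d}",
                "latent.npy": buf.getvalue(),                # raw .npy bytes
                "cls": str(label).encode("utf-8"),           # class index as text
            })
```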
## Train

### ImageNet-256x256

First, train MaskDiT with a 50% mask ratio and AMP enabled.

```bash
accelerate launch --multi_gpu train.py --config configs/train/imagenet256-latent.yaml
```
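The key ingredient of this stage is the random patch masking described in the abstract: the encoder only sees a random 50% of the patch tokens, and the loss combines denoising score matching on the visible patches with an auxiliary reconstruction of the masked ones. Below is a minimal sketch of MAE-style random masking (illustrative only; the real implementation is in `models/maskdit.py` and `train_utils/loss.py`).

```python
import torch

def random_masking(tokens, mask_ratio=0.5):
    # tokens: (B, N, D) patchified latent tokens.
    B, N, D = tokens.shape
    num_keep = int(N * (1 - mask_ratio))
    noise = torch.rand(B, N, device=tokens.device)        # one random score per patch
    ids_shuffle = torch.argsort(noise, dim=1)             # random permutation per sample
    ids_keep = ids_shuffle[:, :num_keep]
    visible = torch.gather(tokens, 1, ids_keep.unsqueeze(-1).repeat(1, 1, D))
    mask = torch.ones(B, N, device=tokens.device)         # 1 = masked, 0 = visible
    mask.scatter_(1, ids_keep, 0.0)
    return visible, mask, ids_keep                        # the encoder runs on `visible` only
```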
Then finetune with unmasking.

```bash
accelerate launch --multi_gpu train.py --config configs/finetune/imagenet256-latent-const.yaml --ckpt_path [path to checkpoint] --use_ckpt_path False --use_strict_load False --no_amp
```
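A guess at what `--use_strict_load False` corresponds to (hedged; the actual checkpoint handling lives in `train.py` and `train_utils/helper.py`): when moving from masked training to unmasked finetuning, the state dict is loaded non-strictly so that layers whose names or shapes changed are simply left at their fresh initialization.

```python
import torch

def load_for_finetune(model, ckpt_path):
    state = torch.load(ckpt_path, map_location="cpu")
    # Hypothetical: checkpoints may wrap the weights under a key such as "ema" or "model".
    for key in ("ema", "model"):
        if isinstance(state, dict) and key in state:
            state = state[key]
            break
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
```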
### ImageNet-512x512

Train MaskDiT with a 50% mask ratio and AMP enabled. Here is an example 4-node training script.

```bash
bash scripts/train_latent512.sh
```

Then finetune with unmasking.

```bash
bash scripts/finetune_latent512.sh
```

## Evaluation

### FID evaluation

To compute the FID of a pretrained model, run

```bash
accelerate launch --multi_gpu eval_latent.py --config configs/test/maskdit-256.yaml --ckpt [path to the pretrained model] --cfg_scale [guidance scale]
```
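For reference, FID is the Fréchet distance between Gaussian fits of Inception features for real and generated images. A small sketch of that final formula (feature extraction and the full pipeline are handled by `fid.py` and `evaluator.py`):

```python
import numpy as np
import scipy.linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # mu*, sigma*: mean vector and covariance matrix of Inception features.
    diff = mu1 - mu2
    covmean, _ = scipy.linalg.sqrtm(sigma1 @ sigma2, disp=False)
    covmean = covmean.real                      # discard tiny imaginary parts from numerics
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))
```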
### Full evaluation

First, download the reference batch from the [ADM repo](https://github.com/openai/guided-diffusion/tree/main/evaluations) directly. You can also use `download_assets.py` by running

```bash
python3 download_assets.py --name imagenet256 --dest [destination directory]
```

Then use the evaluator `evaluator.py` from the [ADM repo](https://github.com/openai/guided-diffusion/tree/main/evaluations), or `fid.py` from the [EDM repo](https://github.com/NVlabs/edm), to evaluate the generated samples.

### Citation

```
@inproceedings{Zheng2024MaskDiT,
  title={Fast Training of Diffusion Models with Masked Transformers},
  author={Zheng, Hongkai and Nie, Weili and Vahdat, Arash and Anandkumar, Anima},
  booktitle={Transactions on Machine Learning Research (TMLR)},
  year={2024}
}
```

## Acknowledgements

Thanks to open-source codebases such as [DiT](https://github.com/facebookresearch/DiT), [MAE](https://github.com/facebookresearch/mae), [U-ViT](https://github.com/baofff/U-ViT), [ADM](https://github.com/openai/guided-diffusion), and [EDM](https://github.com/NVlabs/edm). Our codebase is built on them.
assets/figs/12samples_compressed.png
ADDED
Git LFS Details
assets/figs/arch.png
ADDED
Git LFS Details
assets/figs/bar_mem_256.png
ADDED
Git LFS Details
assets/figs/bar_mem_512.png
ADDED
Git LFS Details
assets/figs/bar_speed_256.png
ADDED
Git LFS Details
assets/figs/bar_speed_512.png
ADDED
Git LFS Details
assets/figs/bubble_gflops_wg.png
ADDED
Git LFS Details
assets/figs/bubble_gflops_wog.png
ADDED
Git LFS Details
assets/figs/maskdit_arch.png
ADDED
Git LFS Details
assets/figs/repo_head.png
ADDED
Git LFS Details
assets/figs/sample512-set1.png
ADDED
Git LFS Details
assets/imagenet_label.json
ADDED
@@ -0,0 +1 @@
{"0": ["n01440764", "tench"], "1": ["n01443537", "goldfish"], "2": ["n01484850", "great_white_shark"], "3": ["n01491361", "tiger_shark"], "4": ["n01494475", "hammerhead"], "5": ["n01496331", "electric_ray"], "6": ["n01498041", "stingray"], "7": ["n01514668", "cock"], "8": ["n01514859", "hen"], "9": ["n01518878", "ostrich"], "10": ["n01530575", "brambling"], "11": ["n01531178", "goldfinch"], "12": ["n01532829", "house_finch"], "13": ["n01534433", "junco"], "14": ["n01537544", "indigo_bunting"], "15": ["n01558993", "robin"], "16": ["n01560419", "bulbul"], "17": ["n01580077", "jay"], "18": ["n01582220", "magpie"], "19": ["n01592084", "chickadee"], "20": ["n01601694", "water_ouzel"], "21": ["n01608432", "kite"], "22": ["n01614925", "bald_eagle"], "23": ["n01616318", "vulture"], "24": ["n01622779", "great_grey_owl"], "25": ["n01629819", "European_fire_salamander"], "26": ["n01630670", "common_newt"], "27": ["n01631663", "eft"], "28": ["n01632458", "spotted_salamander"], "29": ["n01632777", "axolotl"], "30": ["n01641577", "bullfrog"], "31": ["n01644373", "tree_frog"], "32": ["n01644900", "tailed_frog"], "33": ["n01664065", "loggerhead"], "34": ["n01665541", "leatherback_turtle"], "35": ["n01667114", "mud_turtle"], "36": ["n01667778", "terrapin"], "37": ["n01669191", "box_turtle"], "38": ["n01675722", "banded_gecko"], "39": ["n01677366", "common_iguana"], "40": ["n01682714", "American_chameleon"], "41": ["n01685808", "whiptail"], "42": ["n01687978", "agama"], "43": ["n01688243", "frilled_lizard"], "44": ["n01689811", "alligator_lizard"], "45": ["n01692333", "Gila_monster"], "46": ["n01693334", "green_lizard"], "47": ["n01694178", "African_chameleon"], "48": ["n01695060", "Komodo_dragon"], "49": ["n01697457", "African_crocodile"], "50": ["n01698640", "American_alligator"], "51": ["n01704323", "triceratops"], "52": ["n01728572", "thunder_snake"], "53": ["n01728920", "ringneck_snake"], "54": ["n01729322", "hognose_snake"], "55": ["n01729977", "green_snake"], "56": ["n01734418", "king_snake"], "57": ["n01735189", "garter_snake"], "58": ["n01737021", "water_snake"], "59": ["n01739381", "vine_snake"], "60": ["n01740131", "night_snake"], "61": ["n01742172", "boa_constrictor"], "62": ["n01744401", "rock_python"], "63": ["n01748264", "Indian_cobra"], "64": ["n01749939", "green_mamba"], "65": ["n01751748", "sea_snake"], "66": ["n01753488", "horned_viper"], "67": ["n01755581", "diamondback"], "68": ["n01756291", "sidewinder"], "69": ["n01768244", "trilobite"], "70": ["n01770081", "harvestman"], "71": ["n01770393", "scorpion"], "72": ["n01773157", "black_and_gold_garden_spider"], "73": ["n01773549", "barn_spider"], "74": ["n01773797", "garden_spider"], "75": ["n01774384", "black_widow"], "76": ["n01774750", "tarantula"], "77": ["n01775062", "wolf_spider"], "78": ["n01776313", "tick"], "79": ["n01784675", "centipede"], "80": ["n01795545", "black_grouse"], "81": ["n01796340", "ptarmigan"], "82": ["n01797886", "ruffed_grouse"], "83": ["n01798484", "prairie_chicken"], "84": ["n01806143", "peacock"], "85": ["n01806567", "quail"], "86": ["n01807496", "partridge"], "87": ["n01817953", "African_grey"], "88": ["n01818515", "macaw"], "89": ["n01819313", "sulphur-crested_cockatoo"], "90": ["n01820546", "lorikeet"], "91": ["n01824575", "coucal"], "92": ["n01828970", "bee_eater"], "93": ["n01829413", "hornbill"], "94": ["n01833805", "hummingbird"], "95": ["n01843065", "jacamar"], "96": ["n01843383", "toucan"], "97": ["n01847000", "drake"], "98": ["n01855032", "red-breasted_merganser"], "99": ["n01855672", "goose"], 
"100": ["n01860187", "black_swan"], "101": ["n01871265", "tusker"], "102": ["n01872401", "echidna"], "103": ["n01873310", "platypus"], "104": ["n01877812", "wallaby"], "105": ["n01882714", "koala"], "106": ["n01883070", "wombat"], "107": ["n01910747", "jellyfish"], "108": ["n01914609", "sea_anemone"], "109": ["n01917289", "brain_coral"], "110": ["n01924916", "flatworm"], "111": ["n01930112", "nematode"], "112": ["n01943899", "conch"], "113": ["n01944390", "snail"], "114": ["n01945685", "slug"], "115": ["n01950731", "sea_slug"], "116": ["n01955084", "chiton"], "117": ["n01968897", "chambered_nautilus"], "118": ["n01978287", "Dungeness_crab"], "119": ["n01978455", "rock_crab"], "120": ["n01980166", "fiddler_crab"], "121": ["n01981276", "king_crab"], "122": ["n01983481", "American_lobster"], "123": ["n01984695", "spiny_lobster"], "124": ["n01985128", "crayfish"], "125": ["n01986214", "hermit_crab"], "126": ["n01990800", "isopod"], "127": ["n02002556", "white_stork"], "128": ["n02002724", "black_stork"], "129": ["n02006656", "spoonbill"], "130": ["n02007558", "flamingo"], "131": ["n02009229", "little_blue_heron"], "132": ["n02009912", "American_egret"], "133": ["n02011460", "bittern"], "134": ["n02012849", "crane"], "135": ["n02013706", "limpkin"], "136": ["n02017213", "European_gallinule"], "137": ["n02018207", "American_coot"], "138": ["n02018795", "bustard"], "139": ["n02025239", "ruddy_turnstone"], "140": ["n02027492", "red-backed_sandpiper"], "141": ["n02028035", "redshank"], "142": ["n02033041", "dowitcher"], "143": ["n02037110", "oystercatcher"], "144": ["n02051845", "pelican"], "145": ["n02056570", "king_penguin"], "146": ["n02058221", "albatross"], "147": ["n02066245", "grey_whale"], "148": ["n02071294", "killer_whale"], "149": ["n02074367", "dugong"], "150": ["n02077923", "sea_lion"], "151": ["n02085620", "Chihuahua"], "152": ["n02085782", "Japanese_spaniel"], "153": ["n02085936", "Maltese_dog"], "154": ["n02086079", "Pekinese"], "155": ["n02086240", "Shih-Tzu"], "156": ["n02086646", "Blenheim_spaniel"], "157": ["n02086910", "papillon"], "158": ["n02087046", "toy_terrier"], "159": ["n02087394", "Rhodesian_ridgeback"], "160": ["n02088094", "Afghan_hound"], "161": ["n02088238", "basset"], "162": ["n02088364", "beagle"], "163": ["n02088466", "bloodhound"], "164": ["n02088632", "bluetick"], "165": ["n02089078", "black-and-tan_coonhound"], "166": ["n02089867", "Walker_hound"], "167": ["n02089973", "English_foxhound"], "168": ["n02090379", "redbone"], "169": ["n02090622", "borzoi"], "170": ["n02090721", "Irish_wolfhound"], "171": ["n02091032", "Italian_greyhound"], "172": ["n02091134", "whippet"], "173": ["n02091244", "Ibizan_hound"], "174": ["n02091467", "Norwegian_elkhound"], "175": ["n02091635", "otterhound"], "176": ["n02091831", "Saluki"], "177": ["n02092002", "Scottish_deerhound"], "178": ["n02092339", "Weimaraner"], "179": ["n02093256", "Staffordshire_bullterrier"], "180": ["n02093428", "American_Staffordshire_terrier"], "181": ["n02093647", "Bedlington_terrier"], "182": ["n02093754", "Border_terrier"], "183": ["n02093859", "Kerry_blue_terrier"], "184": ["n02093991", "Irish_terrier"], "185": ["n02094114", "Norfolk_terrier"], "186": ["n02094258", "Norwich_terrier"], "187": ["n02094433", "Yorkshire_terrier"], "188": ["n02095314", "wire-haired_fox_terrier"], "189": ["n02095570", "Lakeland_terrier"], "190": ["n02095889", "Sealyham_terrier"], "191": ["n02096051", "Airedale"], "192": ["n02096177", "cairn"], "193": ["n02096294", "Australian_terrier"], "194": ["n02096437", 
"Dandie_Dinmont"], "195": ["n02096585", "Boston_bull"], "196": ["n02097047", "miniature_schnauzer"], "197": ["n02097130", "giant_schnauzer"], "198": ["n02097209", "standard_schnauzer"], "199": ["n02097298", "Scotch_terrier"], "200": ["n02097474", "Tibetan_terrier"], "201": ["n02097658", "silky_terrier"], "202": ["n02098105", "soft-coated_wheaten_terrier"], "203": ["n02098286", "West_Highland_white_terrier"], "204": ["n02098413", "Lhasa"], "205": ["n02099267", "flat-coated_retriever"], "206": ["n02099429", "curly-coated_retriever"], "207": ["n02099601", "golden_retriever"], "208": ["n02099712", "Labrador_retriever"], "209": ["n02099849", "Chesapeake_Bay_retriever"], "210": ["n02100236", "German_short-haired_pointer"], "211": ["n02100583", "vizsla"], "212": ["n02100735", "English_setter"], "213": ["n02100877", "Irish_setter"], "214": ["n02101006", "Gordon_setter"], "215": ["n02101388", "Brittany_spaniel"], "216": ["n02101556", "clumber"], "217": ["n02102040", "English_springer"], "218": ["n02102177", "Welsh_springer_spaniel"], "219": ["n02102318", "cocker_spaniel"], "220": ["n02102480", "Sussex_spaniel"], "221": ["n02102973", "Irish_water_spaniel"], "222": ["n02104029", "kuvasz"], "223": ["n02104365", "schipperke"], "224": ["n02105056", "groenendael"], "225": ["n02105162", "malinois"], "226": ["n02105251", "briard"], "227": ["n02105412", "kelpie"], "228": ["n02105505", "komondor"], "229": ["n02105641", "Old_English_sheepdog"], "230": ["n02105855", "Shetland_sheepdog"], "231": ["n02106030", "collie"], "232": ["n02106166", "Border_collie"], "233": ["n02106382", "Bouvier_des_Flandres"], "234": ["n02106550", "Rottweiler"], "235": ["n02106662", "German_shepherd"], "236": ["n02107142", "Doberman"], "237": ["n02107312", "miniature_pinscher"], "238": ["n02107574", "Greater_Swiss_Mountain_dog"], "239": ["n02107683", "Bernese_mountain_dog"], "240": ["n02107908", "Appenzeller"], "241": ["n02108000", "EntleBucher"], "242": ["n02108089", "boxer"], "243": ["n02108422", "bull_mastiff"], "244": ["n02108551", "Tibetan_mastiff"], "245": ["n02108915", "French_bulldog"], "246": ["n02109047", "Great_Dane"], "247": ["n02109525", "Saint_Bernard"], "248": ["n02109961", "Eskimo_dog"], "249": ["n02110063", "malamute"], "250": ["n02110185", "Siberian_husky"], "251": ["n02110341", "dalmatian"], "252": ["n02110627", "affenpinscher"], "253": ["n02110806", "basenji"], "254": ["n02110958", "pug"], "255": ["n02111129", "Leonberg"], "256": ["n02111277", "Newfoundland"], "257": ["n02111500", "Great_Pyrenees"], "258": ["n02111889", "Samoyed"], "259": ["n02112018", "Pomeranian"], "260": ["n02112137", "chow"], "261": ["n02112350", "keeshond"], "262": ["n02112706", "Brabancon_griffon"], "263": ["n02113023", "Pembroke"], "264": ["n02113186", "Cardigan"], "265": ["n02113624", "toy_poodle"], "266": ["n02113712", "miniature_poodle"], "267": ["n02113799", "standard_poodle"], "268": ["n02113978", "Mexican_hairless"], "269": ["n02114367", "timber_wolf"], "270": ["n02114548", "white_wolf"], "271": ["n02114712", "red_wolf"], "272": ["n02114855", "coyote"], "273": ["n02115641", "dingo"], "274": ["n02115913", "dhole"], "275": ["n02116738", "African_hunting_dog"], "276": ["n02117135", "hyena"], "277": ["n02119022", "red_fox"], "278": ["n02119789", "kit_fox"], "279": ["n02120079", "Arctic_fox"], "280": ["n02120505", "grey_fox"], "281": ["n02123045", "tabby"], "282": ["n02123159", "tiger_cat"], "283": ["n02123394", "Persian_cat"], "284": ["n02123597", "Siamese_cat"], "285": ["n02124075", "Egyptian_cat"], "286": ["n02125311", "cougar"], "287": 
["n02127052", "lynx"], "288": ["n02128385", "leopard"], "289": ["n02128757", "snow_leopard"], "290": ["n02128925", "jaguar"], "291": ["n02129165", "lion"], "292": ["n02129604", "tiger"], "293": ["n02130308", "cheetah"], "294": ["n02132136", "brown_bear"], "295": ["n02133161", "American_black_bear"], "296": ["n02134084", "ice_bear"], "297": ["n02134418", "sloth_bear"], "298": ["n02137549", "mongoose"], "299": ["n02138441", "meerkat"], "300": ["n02165105", "tiger_beetle"], "301": ["n02165456", "ladybug"], "302": ["n02167151", "ground_beetle"], "303": ["n02168699", "long-horned_beetle"], "304": ["n02169497", "leaf_beetle"], "305": ["n02172182", "dung_beetle"], "306": ["n02174001", "rhinoceros_beetle"], "307": ["n02177972", "weevil"], "308": ["n02190166", "fly"], "309": ["n02206856", "bee"], "310": ["n02219486", "ant"], "311": ["n02226429", "grasshopper"], "312": ["n02229544", "cricket"], "313": ["n02231487", "walking_stick"], "314": ["n02233338", "cockroach"], "315": ["n02236044", "mantis"], "316": ["n02256656", "cicada"], "317": ["n02259212", "leafhopper"], "318": ["n02264363", "lacewing"], "319": ["n02268443", "dragonfly"], "320": ["n02268853", "damselfly"], "321": ["n02276258", "admiral"], "322": ["n02277742", "ringlet"], "323": ["n02279972", "monarch"], "324": ["n02280649", "cabbage_butterfly"], "325": ["n02281406", "sulphur_butterfly"], "326": ["n02281787", "lycaenid"], "327": ["n02317335", "starfish"], "328": ["n02319095", "sea_urchin"], "329": ["n02321529", "sea_cucumber"], "330": ["n02325366", "wood_rabbit"], "331": ["n02326432", "hare"], "332": ["n02328150", "Angora"], "333": ["n02342885", "hamster"], "334": ["n02346627", "porcupine"], "335": ["n02356798", "fox_squirrel"], "336": ["n02361337", "marmot"], "337": ["n02363005", "beaver"], "338": ["n02364673", "guinea_pig"], "339": ["n02389026", "sorrel"], "340": ["n02391049", "zebra"], "341": ["n02395406", "hog"], "342": ["n02396427", "wild_boar"], "343": ["n02397096", "warthog"], "344": ["n02398521", "hippopotamus"], "345": ["n02403003", "ox"], "346": ["n02408429", "water_buffalo"], "347": ["n02410509", "bison"], "348": ["n02412080", "ram"], "349": ["n02415577", "bighorn"], "350": ["n02417914", "ibex"], "351": ["n02422106", "hartebeest"], "352": ["n02422699", "impala"], "353": ["n02423022", "gazelle"], "354": ["n02437312", "Arabian_camel"], "355": ["n02437616", "llama"], "356": ["n02441942", "weasel"], "357": ["n02442845", "mink"], "358": ["n02443114", "polecat"], "359": ["n02443484", "black-footed_ferret"], "360": ["n02444819", "otter"], "361": ["n02445715", "skunk"], "362": ["n02447366", "badger"], "363": ["n02454379", "armadillo"], "364": ["n02457408", "three-toed_sloth"], "365": ["n02480495", "orangutan"], "366": ["n02480855", "gorilla"], "367": ["n02481823", "chimpanzee"], "368": ["n02483362", "gibbon"], "369": ["n02483708", "siamang"], "370": ["n02484975", "guenon"], "371": ["n02486261", "patas"], "372": ["n02486410", "baboon"], "373": ["n02487347", "macaque"], "374": ["n02488291", "langur"], "375": ["n02488702", "colobus"], "376": ["n02489166", "proboscis_monkey"], "377": ["n02490219", "marmoset"], "378": ["n02492035", "capuchin"], "379": ["n02492660", "howler_monkey"], "380": ["n02493509", "titi"], "381": ["n02493793", "spider_monkey"], "382": ["n02494079", "squirrel_monkey"], "383": ["n02497673", "Madagascar_cat"], "384": ["n02500267", "indri"], "385": ["n02504013", "Indian_elephant"], "386": ["n02504458", "African_elephant"], "387": ["n02509815", "lesser_panda"], "388": ["n02510455", "giant_panda"], "389": ["n02514041", 
"barracouta"], "390": ["n02526121", "eel"], "391": ["n02536864", "coho"], "392": ["n02606052", "rock_beauty"], "393": ["n02607072", "anemone_fish"], "394": ["n02640242", "sturgeon"], "395": ["n02641379", "gar"], "396": ["n02643566", "lionfish"], "397": ["n02655020", "puffer"], "398": ["n02666196", "abacus"], "399": ["n02667093", "abaya"], "400": ["n02669723", "academic_gown"], "401": ["n02672831", "accordion"], "402": ["n02676566", "acoustic_guitar"], "403": ["n02687172", "aircraft_carrier"], "404": ["n02690373", "airliner"], "405": ["n02692877", "airship"], "406": ["n02699494", "altar"], "407": ["n02701002", "ambulance"], "408": ["n02704792", "amphibian"], "409": ["n02708093", "analog_clock"], "410": ["n02727426", "apiary"], "411": ["n02730930", "apron"], "412": ["n02747177", "ashcan"], "413": ["n02749479", "assault_rifle"], "414": ["n02769748", "backpack"], "415": ["n02776631", "bakery"], "416": ["n02777292", "balance_beam"], "417": ["n02782093", "balloon"], "418": ["n02783161", "ballpoint"], "419": ["n02786058", "Band_Aid"], "420": ["n02787622", "banjo"], "421": ["n02788148", "bannister"], "422": ["n02790996", "barbell"], "423": ["n02791124", "barber_chair"], "424": ["n02791270", "barbershop"], "425": ["n02793495", "barn"], "426": ["n02794156", "barometer"], "427": ["n02795169", "barrel"], "428": ["n02797295", "barrow"], "429": ["n02799071", "baseball"], "430": ["n02802426", "basketball"], "431": ["n02804414", "bassinet"], "432": ["n02804610", "bassoon"], "433": ["n02807133", "bathing_cap"], "434": ["n02808304", "bath_towel"], "435": ["n02808440", "bathtub"], "436": ["n02814533", "beach_wagon"], "437": ["n02814860", "beacon"], "438": ["n02815834", "beaker"], "439": ["n02817516", "bearskin"], "440": ["n02823428", "beer_bottle"], "441": ["n02823750", "beer_glass"], "442": ["n02825657", "bell_cote"], "443": ["n02834397", "bib"], "444": ["n02835271", "bicycle-built-for-two"], "445": ["n02837789", "bikini"], "446": ["n02840245", "binder"], "447": ["n02841315", "binoculars"], "448": ["n02843684", "birdhouse"], "449": ["n02859443", "boathouse"], "450": ["n02860847", "bobsled"], "451": ["n02865351", "bolo_tie"], "452": ["n02869837", "bonnet"], "453": ["n02870880", "bookcase"], "454": ["n02871525", "bookshop"], "455": ["n02877765", "bottlecap"], "456": ["n02879718", "bow"], "457": ["n02883205", "bow_tie"], "458": ["n02892201", "brass"], "459": ["n02892767", "brassiere"], "460": ["n02894605", "breakwater"], "461": ["n02895154", "breastplate"], "462": ["n02906734", "broom"], "463": ["n02909870", "bucket"], "464": ["n02910353", "buckle"], "465": ["n02916936", "bulletproof_vest"], "466": ["n02917067", "bullet_train"], "467": ["n02927161", "butcher_shop"], "468": ["n02930766", "cab"], "469": ["n02939185", "caldron"], "470": ["n02948072", "candle"], "471": ["n02950826", "cannon"], "472": ["n02951358", "canoe"], "473": ["n02951585", "can_opener"], "474": ["n02963159", "cardigan"], "475": ["n02965783", "car_mirror"], "476": ["n02966193", "carousel"], "477": ["n02966687", "carpenter's_kit"], "478": ["n02971356", "carton"], "479": ["n02974003", "car_wheel"], "480": ["n02977058", "cash_machine"], "481": ["n02978881", "cassette"], "482": ["n02979186", "cassette_player"], "483": ["n02980441", "castle"], "484": ["n02981792", "catamaran"], "485": ["n02988304", "CD_player"], "486": ["n02992211", "cello"], "487": ["n02992529", "cellular_telephone"], "488": ["n02999410", "chain"], "489": ["n03000134", "chainlink_fence"], "490": ["n03000247", "chain_mail"], "491": ["n03000684", "chain_saw"], "492": ["n03014705", 
"chest"], "493": ["n03016953", "chiffonier"], "494": ["n03017168", "chime"], "495": ["n03018349", "china_cabinet"], "496": ["n03026506", "Christmas_stocking"], "497": ["n03028079", "church"], "498": ["n03032252", "cinema"], "499": ["n03041632", "cleaver"], "500": ["n03042490", "cliff_dwelling"], "501": ["n03045698", "cloak"], "502": ["n03047690", "clog"], "503": ["n03062245", "cocktail_shaker"], "504": ["n03063599", "coffee_mug"], "505": ["n03063689", "coffeepot"], "506": ["n03065424", "coil"], "507": ["n03075370", "combination_lock"], "508": ["n03085013", "computer_keyboard"], "509": ["n03089624", "confectionery"], "510": ["n03095699", "container_ship"], "511": ["n03100240", "convertible"], "512": ["n03109150", "corkscrew"], "513": ["n03110669", "cornet"], "514": ["n03124043", "cowboy_boot"], "515": ["n03124170", "cowboy_hat"], "516": ["n03125729", "cradle"], "517": ["n03126707", "crane"], "518": ["n03127747", "crash_helmet"], "519": ["n03127925", "crate"], "520": ["n03131574", "crib"], "521": ["n03133878", "Crock_Pot"], "522": ["n03134739", "croquet_ball"], "523": ["n03141823", "crutch"], "524": ["n03146219", "cuirass"], "525": ["n03160309", "dam"], "526": ["n03179701", "desk"], "527": ["n03180011", "desktop_computer"], "528": ["n03187595", "dial_telephone"], "529": ["n03188531", "diaper"], "530": ["n03196217", "digital_clock"], "531": ["n03197337", "digital_watch"], "532": ["n03201208", "dining_table"], "533": ["n03207743", "dishrag"], "534": ["n03207941", "dishwasher"], "535": ["n03208938", "disk_brake"], "536": ["n03216828", "dock"], "537": ["n03218198", "dogsled"], "538": ["n03220513", "dome"], "539": ["n03223299", "doormat"], "540": ["n03240683", "drilling_platform"], "541": ["n03249569", "drum"], "542": ["n03250847", "drumstick"], "543": ["n03255030", "dumbbell"], "544": ["n03259280", "Dutch_oven"], "545": ["n03271574", "electric_fan"], "546": ["n03272010", "electric_guitar"], "547": ["n03272562", "electric_locomotive"], "548": ["n03290653", "entertainment_center"], "549": ["n03291819", "envelope"], "550": ["n03297495", "espresso_maker"], "551": ["n03314780", "face_powder"], "552": ["n03325584", "feather_boa"], "553": ["n03337140", "file"], "554": ["n03344393", "fireboat"], "555": ["n03345487", "fire_engine"], "556": ["n03347037", "fire_screen"], "557": ["n03355925", "flagpole"], "558": ["n03372029", "flute"], "559": ["n03376595", "folding_chair"], "560": ["n03379051", "football_helmet"], "561": ["n03384352", "forklift"], "562": ["n03388043", "fountain"], "563": ["n03388183", "fountain_pen"], "564": ["n03388549", "four-poster"], "565": ["n03393912", "freight_car"], "566": ["n03394916", "French_horn"], "567": ["n03400231", "frying_pan"], "568": ["n03404251", "fur_coat"], "569": ["n03417042", "garbage_truck"], "570": ["n03424325", "gasmask"], "571": ["n03425413", "gas_pump"], "572": ["n03443371", "goblet"], "573": ["n03444034", "go-kart"], "574": ["n03445777", "golf_ball"], "575": ["n03445924", "golfcart"], "576": ["n03447447", "gondola"], "577": ["n03447721", "gong"], "578": ["n03450230", "gown"], "579": ["n03452741", "grand_piano"], "580": ["n03457902", "greenhouse"], "581": ["n03459775", "grille"], "582": ["n03461385", "grocery_store"], "583": ["n03467068", "guillotine"], "584": ["n03476684", "hair_slide"], "585": ["n03476991", "hair_spray"], "586": ["n03478589", "half_track"], "587": ["n03481172", "hammer"], "588": ["n03482405", "hamper"], "589": ["n03483316", "hand_blower"], "590": ["n03485407", "hand-held_computer"], "591": ["n03485794", "handkerchief"], "592": ["n03492542", 
"hard_disc"], "593": ["n03494278", "harmonica"], "594": ["n03495258", "harp"], "595": ["n03496892", "harvester"], "596": ["n03498962", "hatchet"], "597": ["n03527444", "holster"], "598": ["n03529860", "home_theater"], "599": ["n03530642", "honeycomb"], "600": ["n03532672", "hook"], "601": ["n03534580", "hoopskirt"], "602": ["n03535780", "horizontal_bar"], "603": ["n03538406", "horse_cart"], "604": ["n03544143", "hourglass"], "605": ["n03584254", "iPod"], "606": ["n03584829", "iron"], "607": ["n03590841", "jack-o'-lantern"], "608": ["n03594734", "jean"], "609": ["n03594945", "jeep"], "610": ["n03595614", "jersey"], "611": ["n03598930", "jigsaw_puzzle"], "612": ["n03599486", "jinrikisha"], "613": ["n03602883", "joystick"], "614": ["n03617480", "kimono"], "615": ["n03623198", "knee_pad"], "616": ["n03627232", "knot"], "617": ["n03630383", "lab_coat"], "618": ["n03633091", "ladle"], "619": ["n03637318", "lampshade"], "620": ["n03642806", "laptop"], "621": ["n03649909", "lawn_mower"], "622": ["n03657121", "lens_cap"], "623": ["n03658185", "letter_opener"], "624": ["n03661043", "library"], "625": ["n03662601", "lifeboat"], "626": ["n03666591", "lighter"], "627": ["n03670208", "limousine"], "628": ["n03673027", "liner"], "629": ["n03676483", "lipstick"], "630": ["n03680355", "Loafer"], "631": ["n03690938", "lotion"], "632": ["n03691459", "loudspeaker"], "633": ["n03692522", "loupe"], "634": ["n03697007", "lumbermill"], "635": ["n03706229", "magnetic_compass"], "636": ["n03709823", "mailbag"], "637": ["n03710193", "mailbox"], "638": ["n03710637", "maillot"], "639": ["n03710721", "maillot"], "640": ["n03717622", "manhole_cover"], "641": ["n03720891", "maraca"], "642": ["n03721384", "marimba"], "643": ["n03724870", "mask"], "644": ["n03729826", "matchstick"], "645": ["n03733131", "maypole"], "646": ["n03733281", "maze"], "647": ["n03733805", "measuring_cup"], "648": ["n03742115", "medicine_chest"], "649": ["n03743016", "megalith"], "650": ["n03759954", "microphone"], "651": ["n03761084", "microwave"], "652": ["n03763968", "military_uniform"], "653": ["n03764736", "milk_can"], "654": ["n03769881", "minibus"], "655": ["n03770439", "miniskirt"], "656": ["n03770679", "minivan"], "657": ["n03773504", "missile"], "658": ["n03775071", "mitten"], "659": ["n03775546", "mixing_bowl"], "660": ["n03776460", "mobile_home"], "661": ["n03777568", "Model_T"], "662": ["n03777754", "modem"], "663": ["n03781244", "monastery"], "664": ["n03782006", "monitor"], "665": ["n03785016", "moped"], "666": ["n03786901", "mortar"], "667": ["n03787032", "mortarboard"], "668": ["n03788195", "mosque"], "669": ["n03788365", "mosquito_net"], "670": ["n03791053", "motor_scooter"], "671": ["n03792782", "mountain_bike"], "672": ["n03792972", "mountain_tent"], "673": ["n03793489", "mouse"], "674": ["n03794056", "mousetrap"], "675": ["n03796401", "moving_van"], "676": ["n03803284", "muzzle"], "677": ["n03804744", "nail"], "678": ["n03814639", "neck_brace"], "679": ["n03814906", "necklace"], "680": ["n03825788", "nipple"], "681": ["n03832673", "notebook"], "682": ["n03837869", "obelisk"], "683": ["n03838899", "oboe"], "684": ["n03840681", "ocarina"], "685": ["n03841143", "odometer"], "686": ["n03843555", "oil_filter"], "687": ["n03854065", "organ"], "688": ["n03857828", "oscilloscope"], "689": ["n03866082", "overskirt"], "690": ["n03868242", "oxcart"], "691": ["n03868863", "oxygen_mask"], "692": ["n03871628", "packet"], "693": ["n03873416", "paddle"], "694": ["n03874293", "paddlewheel"], "695": ["n03874599", "padlock"], "696": 
["n03876231", "paintbrush"], "697": ["n03877472", "pajama"], "698": ["n03877845", "palace"], "699": ["n03884397", "panpipe"], "700": ["n03887697", "paper_towel"], "701": ["n03888257", "parachute"], "702": ["n03888605", "parallel_bars"], "703": ["n03891251", "park_bench"], "704": ["n03891332", "parking_meter"], "705": ["n03895866", "passenger_car"], "706": ["n03899768", "patio"], "707": ["n03902125", "pay-phone"], "708": ["n03903868", "pedestal"], "709": ["n03908618", "pencil_box"], "710": ["n03908714", "pencil_sharpener"], "711": ["n03916031", "perfume"], "712": ["n03920288", "Petri_dish"], "713": ["n03924679", "photocopier"], "714": ["n03929660", "pick"], "715": ["n03929855", "pickelhaube"], "716": ["n03930313", "picket_fence"], "717": ["n03930630", "pickup"], "718": ["n03933933", "pier"], "719": ["n03935335", "piggy_bank"], "720": ["n03937543", "pill_bottle"], "721": ["n03938244", "pillow"], "722": ["n03942813", "ping-pong_ball"], "723": ["n03944341", "pinwheel"], "724": ["n03947888", "pirate"], "725": ["n03950228", "pitcher"], "726": ["n03954731", "plane"], "727": ["n03956157", "planetarium"], "728": ["n03958227", "plastic_bag"], "729": ["n03961711", "plate_rack"], "730": ["n03967562", "plow"], "731": ["n03970156", "plunger"], "732": ["n03976467", "Polaroid_camera"], "733": ["n03976657", "pole"], "734": ["n03977966", "police_van"], "735": ["n03980874", "poncho"], "736": ["n03982430", "pool_table"], "737": ["n03983396", "pop_bottle"], "738": ["n03991062", "pot"], "739": ["n03992509", "potter's_wheel"], "740": ["n03995372", "power_drill"], "741": ["n03998194", "prayer_rug"], "742": ["n04004767", "printer"], "743": ["n04005630", "prison"], "744": ["n04008634", "projectile"], "745": ["n04009552", "projector"], "746": ["n04019541", "puck"], "747": ["n04023962", "punching_bag"], "748": ["n04026417", "purse"], "749": ["n04033901", "quill"], "750": ["n04033995", "quilt"], "751": ["n04037443", "racer"], "752": ["n04039381", "racket"], "753": ["n04040759", "radiator"], "754": ["n04041544", "radio"], "755": ["n04044716", "radio_telescope"], "756": ["n04049303", "rain_barrel"], "757": ["n04065272", "recreational_vehicle"], "758": ["n04067472", "reel"], "759": ["n04069434", "reflex_camera"], "760": ["n04070727", "refrigerator"], "761": ["n04074963", "remote_control"], "762": ["n04081281", "restaurant"], "763": ["n04086273", "revolver"], "764": ["n04090263", "rifle"], "765": ["n04099969", "rocking_chair"], "766": ["n04111531", "rotisserie"], "767": ["n04116512", "rubber_eraser"], "768": ["n04118538", "rugby_ball"], "769": ["n04118776", "rule"], "770": ["n04120489", "running_shoe"], "771": ["n04125021", "safe"], "772": ["n04127249", "safety_pin"], "773": ["n04131690", "saltshaker"], "774": ["n04133789", "sandal"], "775": ["n04136333", "sarong"], "776": ["n04141076", "sax"], "777": ["n04141327", "scabbard"], "778": ["n04141975", "scale"], "779": ["n04146614", "school_bus"], "780": ["n04147183", "schooner"], "781": ["n04149813", "scoreboard"], "782": ["n04152593", "screen"], "783": ["n04153751", "screw"], "784": ["n04154565", "screwdriver"], "785": ["n04162706", "seat_belt"], "786": ["n04179913", "sewing_machine"], "787": ["n04192698", "shield"], "788": ["n04200800", "shoe_shop"], "789": ["n04201297", "shoji"], "790": ["n04204238", "shopping_basket"], "791": ["n04204347", "shopping_cart"], "792": ["n04208210", "shovel"], "793": ["n04209133", "shower_cap"], "794": ["n04209239", "shower_curtain"], "795": ["n04228054", "ski"], "796": ["n04229816", "ski_mask"], "797": ["n04235860", "sleeping_bag"], "798": 
["n04238763", "slide_rule"], "799": ["n04239074", "sliding_door"], "800": ["n04243546", "slot"], "801": ["n04251144", "snorkel"], "802": ["n04252077", "snowmobile"], "803": ["n04252225", "snowplow"], "804": ["n04254120", "soap_dispenser"], "805": ["n04254680", "soccer_ball"], "806": ["n04254777", "sock"], "807": ["n04258138", "solar_dish"], "808": ["n04259630", "sombrero"], "809": ["n04263257", "soup_bowl"], "810": ["n04264628", "space_bar"], "811": ["n04265275", "space_heater"], "812": ["n04266014", "space_shuttle"], "813": ["n04270147", "spatula"], "814": ["n04273569", "speedboat"], "815": ["n04275548", "spider_web"], "816": ["n04277352", "spindle"], "817": ["n04285008", "sports_car"], "818": ["n04286575", "spotlight"], "819": ["n04296562", "stage"], "820": ["n04310018", "steam_locomotive"], "821": ["n04311004", "steel_arch_bridge"], "822": ["n04311174", "steel_drum"], "823": ["n04317175", "stethoscope"], "824": ["n04325704", "stole"], "825": ["n04326547", "stone_wall"], "826": ["n04328186", "stopwatch"], "827": ["n04330267", "stove"], "828": ["n04332243", "strainer"], "829": ["n04335435", "streetcar"], "830": ["n04336792", "stretcher"], "831": ["n04344873", "studio_couch"], "832": ["n04346328", "stupa"], "833": ["n04347754", "submarine"], "834": ["n04350905", "suit"], "835": ["n04355338", "sundial"], "836": ["n04355933", "sunglass"], "837": ["n04356056", "sunglasses"], "838": ["n04357314", "sunscreen"], "839": ["n04366367", "suspension_bridge"], "840": ["n04367480", "swab"], "841": ["n04370456", "sweatshirt"], "842": ["n04371430", "swimming_trunks"], "843": ["n04371774", "swing"], "844": ["n04372370", "switch"], "845": ["n04376876", "syringe"], "846": ["n04380533", "table_lamp"], "847": ["n04389033", "tank"], "848": ["n04392985", "tape_player"], "849": ["n04398044", "teapot"], "850": ["n04399382", "teddy"], "851": ["n04404412", "television"], "852": ["n04409515", "tennis_ball"], "853": ["n04417672", "thatch"], "854": ["n04418357", "theater_curtain"], "855": ["n04423845", "thimble"], "856": ["n04428191", "thresher"], "857": ["n04429376", "throne"], "858": ["n04435653", "tile_roof"], "859": ["n04442312", "toaster"], "860": ["n04443257", "tobacco_shop"], "861": ["n04447861", "toilet_seat"], "862": ["n04456115", "torch"], "863": ["n04458633", "totem_pole"], "864": ["n04461696", "tow_truck"], "865": ["n04462240", "toyshop"], "866": ["n04465501", "tractor"], "867": ["n04467665", "trailer_truck"], "868": ["n04476259", "tray"], "869": ["n04479046", "trench_coat"], "870": ["n04482393", "tricycle"], "871": ["n04483307", "trimaran"], "872": ["n04485082", "tripod"], "873": ["n04486054", "triumphal_arch"], "874": ["n04487081", "trolleybus"], "875": ["n04487394", "trombone"], "876": ["n04493381", "tub"], "877": ["n04501370", "turnstile"], "878": ["n04505470", "typewriter_keyboard"], "879": ["n04507155", "umbrella"], "880": ["n04509417", "unicycle"], "881": ["n04515003", "upright"], "882": ["n04517823", "vacuum"], "883": ["n04522168", "vase"], "884": ["n04523525", "vault"], "885": ["n04525038", "velvet"], "886": ["n04525305", "vending_machine"], "887": ["n04532106", "vestment"], "888": ["n04532670", "viaduct"], "889": ["n04536866", "violin"], "890": ["n04540053", "volleyball"], "891": ["n04542943", "waffle_iron"], "892": ["n04548280", "wall_clock"], "893": ["n04548362", "wallet"], "894": ["n04550184", "wardrobe"], "895": ["n04552348", "warplane"], "896": ["n04553703", "washbasin"], "897": ["n04554684", "washer"], "898": ["n04557648", "water_bottle"], "899": ["n04560804", "water_jug"], "900": 
["n04562935", "water_tower"], "901": ["n04579145", "whiskey_jug"], "902": ["n04579432", "whistle"], "903": ["n04584207", "wig"], "904": ["n04589890", "window_screen"], "905": ["n04590129", "window_shade"], "906": ["n04591157", "Windsor_tie"], "907": ["n04591713", "wine_bottle"], "908": ["n04592741", "wing"], "909": ["n04596742", "wok"], "910": ["n04597913", "wooden_spoon"], "911": ["n04599235", "wool"], "912": ["n04604644", "worm_fence"], "913": ["n04606251", "wreck"], "914": ["n04612504", "yawl"], "915": ["n04613696", "yurt"], "916": ["n06359193", "web_site"], "917": ["n06596364", "comic_book"], "918": ["n06785654", "crossword_puzzle"], "919": ["n06794110", "street_sign"], "920": ["n06874185", "traffic_light"], "921": ["n07248320", "book_jacket"], "922": ["n07565083", "menu"], "923": ["n07579787", "plate"], "924": ["n07583066", "guacamole"], "925": ["n07584110", "consomme"], "926": ["n07590611", "hot_pot"], "927": ["n07613480", "trifle"], "928": ["n07614500", "ice_cream"], "929": ["n07615774", "ice_lolly"], "930": ["n07684084", "French_loaf"], "931": ["n07693725", "bagel"], "932": ["n07695742", "pretzel"], "933": ["n07697313", "cheeseburger"], "934": ["n07697537", "hotdog"], "935": ["n07711569", "mashed_potato"], "936": ["n07714571", "head_cabbage"], "937": ["n07714990", "broccoli"], "938": ["n07715103", "cauliflower"], "939": ["n07716358", "zucchini"], "940": ["n07716906", "spaghetti_squash"], "941": ["n07717410", "acorn_squash"], "942": ["n07717556", "butternut_squash"], "943": ["n07718472", "cucumber"], "944": ["n07718747", "artichoke"], "945": ["n07720875", "bell_pepper"], "946": ["n07730033", "cardoon"], "947": ["n07734744", "mushroom"], "948": ["n07742313", "Granny_Smith"], "949": ["n07745940", "strawberry"], "950": ["n07747607", "orange"], "951": ["n07749582", "lemon"], "952": ["n07753113", "fig"], "953": ["n07753275", "pineapple"], "954": ["n07753592", "banana"], "955": ["n07754684", "jackfruit"], "956": ["n07760859", "custard_apple"], "957": ["n07768694", "pomegranate"], "958": ["n07802026", "hay"], "959": ["n07831146", "carbonara"], "960": ["n07836838", "chocolate_sauce"], "961": ["n07860988", "dough"], "962": ["n07871810", "meat_loaf"], "963": ["n07873807", "pizza"], "964": ["n07875152", "potpie"], "965": ["n07880968", "burrito"], "966": ["n07892512", "red_wine"], "967": ["n07920052", "espresso"], "968": ["n07930864", "cup"], "969": ["n07932039", "eggnog"], "970": ["n09193705", "alp"], "971": ["n09229709", "bubble"], "972": ["n09246464", "cliff"], "973": ["n09256479", "coral_reef"], "974": ["n09288635", "geyser"], "975": ["n09332890", "lakeside"], "976": ["n09399592", "promontory"], "977": ["n09421951", "sandbar"], "978": ["n09428293", "seashore"], "979": ["n09468604", "valley"], "980": ["n09472597", "volcano"], "981": ["n09835506", "ballplayer"], "982": ["n10148035", "groom"], "983": ["n10565667", "scuba_diver"], "984": ["n11879895", "rapeseed"], "985": ["n11939491", "daisy"], "986": ["n12057211", "yellow_lady's_slipper"], "987": ["n12144580", "corn"], "988": ["n12267677", "acorn"], "989": ["n12620546", "hip"], "990": ["n12768682", "buckeye"], "991": ["n12985857", "coral_fungus"], "992": ["n12998815", "agaric"], "993": ["n13037406", "gyromitra"], "994": ["n13040303", "stinkhorn"], "995": ["n13044778", "earthstar"], "996": ["n13052670", "hen-of-the-woods"], "997": ["n13054560", "bolete"], "998": ["n13133613", "ear"], "999": ["n15075141", "toilet_tissue"]}
autoencoder.py
ADDED
@@ -0,0 +1,522 @@
# Borrowed from U-ViT: https://github.com/baofff/U-ViT/blob/main/libs/autoencoder.py.
# The original code is licensed under MIT License, which can be found at licenses/LICENSE_UVIT.txt.

import torch
import torch.nn as nn
import numpy as np
from einops import rearrange


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels,
                                             out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""
    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)       # b,hw,c
        k = k.reshape(b, c, h * w)   # b,c,hw
        w_ = torch.bmm(q, k)         # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)     # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)        # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


def make_attn(in_channels, attn_type="vanilla"):
    assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type == "none":
        return nn.Identity(in_channels)
    else:
        return LinAttnBlock(in_channels)


class Encoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
|
| 224 |
+
self.in_channels = in_channels
|
| 225 |
+
|
| 226 |
+
# downsampling
|
| 227 |
+
self.conv_in = torch.nn.Conv2d(in_channels,
|
| 228 |
+
self.ch,
|
| 229 |
+
kernel_size=3,
|
| 230 |
+
stride=1,
|
| 231 |
+
padding=1)
|
| 232 |
+
|
| 233 |
+
curr_res = resolution
|
| 234 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
| 235 |
+
self.in_ch_mult = in_ch_mult
|
| 236 |
+
self.down = nn.ModuleList()
|
| 237 |
+
for i_level in range(self.num_resolutions):
|
| 238 |
+
block = nn.ModuleList()
|
| 239 |
+
attn = nn.ModuleList()
|
| 240 |
+
block_in = ch*in_ch_mult[i_level]
|
| 241 |
+
block_out = ch*ch_mult[i_level]
|
| 242 |
+
for i_block in range(self.num_res_blocks):
|
| 243 |
+
block.append(ResnetBlock(in_channels=block_in,
|
| 244 |
+
out_channels=block_out,
|
| 245 |
+
temb_channels=self.temb_ch,
|
| 246 |
+
dropout=dropout))
|
| 247 |
+
block_in = block_out
|
| 248 |
+
if curr_res in attn_resolutions:
|
| 249 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
| 250 |
+
down = nn.Module()
|
| 251 |
+
down.block = block
|
| 252 |
+
down.attn = attn
|
| 253 |
+
if i_level != self.num_resolutions-1:
|
| 254 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
| 255 |
+
curr_res = curr_res // 2
|
| 256 |
+
self.down.append(down)
|
| 257 |
+
|
| 258 |
+
# middle
|
| 259 |
+
self.mid = nn.Module()
|
| 260 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
| 261 |
+
out_channels=block_in,
|
| 262 |
+
temb_channels=self.temb_ch,
|
| 263 |
+
dropout=dropout)
|
| 264 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
| 265 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
| 266 |
+
out_channels=block_in,
|
| 267 |
+
temb_channels=self.temb_ch,
|
| 268 |
+
dropout=dropout)
|
| 269 |
+
|
| 270 |
+
# end
|
| 271 |
+
self.norm_out = Normalize(block_in)
|
| 272 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
| 273 |
+
2*z_channels if double_z else z_channels,
|
| 274 |
+
kernel_size=3,
|
| 275 |
+
stride=1,
|
| 276 |
+
padding=1)
|
| 277 |
+
|
| 278 |
+
def forward(self, x):
|
| 279 |
+
# timestep embedding
|
| 280 |
+
temb = None
|
| 281 |
+
|
| 282 |
+
# downsampling
|
| 283 |
+
hs = [self.conv_in(x)]
|
| 284 |
+
for i_level in range(self.num_resolutions):
|
| 285 |
+
for i_block in range(self.num_res_blocks):
|
| 286 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
| 287 |
+
if len(self.down[i_level].attn) > 0:
|
| 288 |
+
h = self.down[i_level].attn[i_block](h)
|
| 289 |
+
hs.append(h)
|
| 290 |
+
if i_level != self.num_resolutions-1:
|
| 291 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
| 292 |
+
|
| 293 |
+
# middle
|
| 294 |
+
h = hs[-1]
|
| 295 |
+
h = self.mid.block_1(h, temb)
|
| 296 |
+
h = self.mid.attn_1(h)
|
| 297 |
+
h = self.mid.block_2(h, temb)
|
| 298 |
+
|
| 299 |
+
# end
|
| 300 |
+
h = self.norm_out(h)
|
| 301 |
+
h = nonlinearity(h)
|
| 302 |
+
h = self.conv_out(h)
|
| 303 |
+
return h
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class Decoder(nn.Module):
|
| 307 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
| 308 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
| 309 |
+
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
| 310 |
+
attn_type="vanilla", **ignorekwargs):
|
| 311 |
+
super().__init__()
|
| 312 |
+
if use_linear_attn: attn_type = "linear"
|
| 313 |
+
self.ch = ch
|
| 314 |
+
self.temb_ch = 0
|
| 315 |
+
self.num_resolutions = len(ch_mult)
|
| 316 |
+
self.num_res_blocks = num_res_blocks
|
| 317 |
+
self.resolution = resolution
|
| 318 |
+
self.in_channels = in_channels
|
| 319 |
+
self.give_pre_end = give_pre_end
|
| 320 |
+
self.tanh_out = tanh_out
|
| 321 |
+
|
| 322 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
| 323 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
| 324 |
+
block_in = ch*ch_mult[self.num_resolutions-1]
|
| 325 |
+
curr_res = resolution // 2**(self.num_resolutions-1)
|
| 326 |
+
self.z_shape = (1,z_channels,curr_res,curr_res)
|
| 327 |
+
print("Working with z of shape {} = {} dimensions.".format(
|
| 328 |
+
self.z_shape, np.prod(self.z_shape)))
|
| 329 |
+
|
| 330 |
+
# z to block_in
|
| 331 |
+
self.conv_in = torch.nn.Conv2d(z_channels,
|
| 332 |
+
block_in,
|
| 333 |
+
kernel_size=3,
|
| 334 |
+
stride=1,
|
| 335 |
+
padding=1)
|
| 336 |
+
|
| 337 |
+
# middle
|
| 338 |
+
self.mid = nn.Module()
|
| 339 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
| 340 |
+
out_channels=block_in,
|
| 341 |
+
temb_channels=self.temb_ch,
|
| 342 |
+
dropout=dropout)
|
| 343 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
| 344 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
| 345 |
+
out_channels=block_in,
|
| 346 |
+
temb_channels=self.temb_ch,
|
| 347 |
+
dropout=dropout)
|
| 348 |
+
|
| 349 |
+
# upsampling
|
| 350 |
+
self.up = nn.ModuleList()
|
| 351 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 352 |
+
block = nn.ModuleList()
|
| 353 |
+
attn = nn.ModuleList()
|
| 354 |
+
block_out = ch*ch_mult[i_level]
|
| 355 |
+
for i_block in range(self.num_res_blocks+1):
|
| 356 |
+
block.append(ResnetBlock(in_channels=block_in,
|
| 357 |
+
out_channels=block_out,
|
| 358 |
+
temb_channels=self.temb_ch,
|
| 359 |
+
dropout=dropout))
|
| 360 |
+
block_in = block_out
|
| 361 |
+
if curr_res in attn_resolutions:
|
| 362 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
| 363 |
+
up = nn.Module()
|
| 364 |
+
up.block = block
|
| 365 |
+
up.attn = attn
|
| 366 |
+
if i_level != 0:
|
| 367 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
| 368 |
+
curr_res = curr_res * 2
|
| 369 |
+
self.up.insert(0, up) # prepend to get consistent order
|
| 370 |
+
|
| 371 |
+
# end
|
| 372 |
+
self.norm_out = Normalize(block_in)
|
| 373 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
| 374 |
+
out_ch,
|
| 375 |
+
kernel_size=3,
|
| 376 |
+
stride=1,
|
| 377 |
+
padding=1)
|
| 378 |
+
|
| 379 |
+
def forward(self, z):
|
| 380 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
| 381 |
+
self.last_z_shape = z.shape
|
| 382 |
+
|
| 383 |
+
# timestep embedding
|
| 384 |
+
temb = None
|
| 385 |
+
|
| 386 |
+
# z to block_in
|
| 387 |
+
h = self.conv_in(z)
|
| 388 |
+
|
| 389 |
+
# middle
|
| 390 |
+
h = self.mid.block_1(h, temb)
|
| 391 |
+
h = self.mid.attn_1(h)
|
| 392 |
+
h = self.mid.block_2(h, temb)
|
| 393 |
+
|
| 394 |
+
# upsampling
|
| 395 |
+
for i_level in reversed(range(self.num_resolutions)):
|
| 396 |
+
for i_block in range(self.num_res_blocks+1):
|
| 397 |
+
h = self.up[i_level].block[i_block](h, temb)
|
| 398 |
+
if len(self.up[i_level].attn) > 0:
|
| 399 |
+
h = self.up[i_level].attn[i_block](h)
|
| 400 |
+
if i_level != 0:
|
| 401 |
+
h = self.up[i_level].upsample(h)
|
| 402 |
+
|
| 403 |
+
# end
|
| 404 |
+
if self.give_pre_end:
|
| 405 |
+
return h
|
| 406 |
+
|
| 407 |
+
h = self.norm_out(h)
|
| 408 |
+
h = nonlinearity(h)
|
| 409 |
+
h = self.conv_out(h)
|
| 410 |
+
if self.tanh_out:
|
| 411 |
+
h = torch.tanh(h)
|
| 412 |
+
return h
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
class FrozenAutoencoderKL(nn.Module):
|
| 416 |
+
def __init__(self, ddconfig, embed_dim, pretrained_path, scale_factor=0.18215):
|
| 417 |
+
super().__init__()
|
| 418 |
+
print(f'Create autoencoder with scale_factor={scale_factor}')
|
| 419 |
+
self.encoder = Encoder(**ddconfig)
|
| 420 |
+
self.decoder = Decoder(**ddconfig)
|
| 421 |
+
assert ddconfig["double_z"]
|
| 422 |
+
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
|
| 423 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
| 424 |
+
self.embed_dim = embed_dim
|
| 425 |
+
self.scale_factor = scale_factor
|
| 426 |
+
m, u = self.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
|
| 427 |
+
assert len(m) == 0 and len(u) == 0
|
| 428 |
+
self.eval()
|
| 429 |
+
self.requires_grad_(False)
|
| 430 |
+
|
| 431 |
+
def encode_moments(self, x):
|
| 432 |
+
h = self.encoder(x)
|
| 433 |
+
moments = self.quant_conv(h)
|
| 434 |
+
return moments
|
| 435 |
+
|
| 436 |
+
def sample(self, moments):
|
| 437 |
+
mean, logvar = torch.chunk(moments, 2, dim=1)
|
| 438 |
+
logvar = torch.clamp(logvar, -30.0, 20.0)
|
| 439 |
+
std = torch.exp(0.5 * logvar)
|
| 440 |
+
z = mean + std * torch.randn_like(mean)
|
| 441 |
+
z = self.scale_factor * z
|
| 442 |
+
return z
|
| 443 |
+
|
| 444 |
+
def encode(self, x):
|
| 445 |
+
moments = self.encode_moments(x)
|
| 446 |
+
z = self.sample(moments)
|
| 447 |
+
return z
|
| 448 |
+
|
| 449 |
+
def decode(self, z):
|
| 450 |
+
z = (1. / self.scale_factor) * z
|
| 451 |
+
z = self.post_quant_conv(z)
|
| 452 |
+
dec = self.decoder(z)
|
| 453 |
+
return dec
|
| 454 |
+
|
| 455 |
+
def forward(self, inputs, fn):
|
| 456 |
+
if fn == 'encode_moments':
|
| 457 |
+
return self.encode_moments(inputs)
|
| 458 |
+
elif fn == 'encode':
|
| 459 |
+
return self.encode(inputs)
|
| 460 |
+
elif fn == 'decode':
|
| 461 |
+
return self.decode(inputs)
|
| 462 |
+
else:
|
| 463 |
+
raise NotImplementedError
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def get_model(pretrained_path, scale_factor=0.18215):
|
| 467 |
+
ddconfig = dict(
|
| 468 |
+
double_z=True,
|
| 469 |
+
z_channels=4,
|
| 470 |
+
resolution=256,
|
| 471 |
+
in_channels=3,
|
| 472 |
+
out_ch=3,
|
| 473 |
+
ch=128,
|
| 474 |
+
ch_mult=[1, 2, 4, 4],
|
| 475 |
+
num_res_blocks=2,
|
| 476 |
+
attn_resolutions=[],
|
| 477 |
+
dropout=0.0
|
| 478 |
+
)
|
| 479 |
+
return FrozenAutoencoderKL(ddconfig, 4, pretrained_path, scale_factor)
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def main():
|
| 483 |
+
import torchvision.transforms as transforms
|
| 484 |
+
from torchvision.utils import save_image
|
| 485 |
+
import os
|
| 486 |
+
from PIL import Image
|
| 487 |
+
|
| 488 |
+
model = get_model('assets/stable-diffusion/autoencoder_kl.pth')
|
| 489 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 490 |
+
model = model.to(device)
|
| 491 |
+
|
| 492 |
+
scale_factor = 0.18215
|
| 493 |
+
T = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor()])
|
| 494 |
+
path = 'imgs'
|
| 495 |
+
fnames = os.listdir(path)
|
| 496 |
+
for fname in fnames:
|
| 497 |
+
p = os.path.join(path, fname)
|
| 498 |
+
img = Image.open(p)
|
| 499 |
+
img = T(img)
|
| 500 |
+
img = img * 2. - 1
|
| 501 |
+
img = img[None, ...]
|
| 502 |
+
img = img.to(device)
|
| 503 |
+
|
| 504 |
+
# with torch.cuda.amp.autocast():
|
| 505 |
+
# moments = model.encode_moments(img)
|
| 506 |
+
# mean, logvar = torch.chunk(moments, 2, dim=1)
|
| 507 |
+
# logvar = torch.clamp(logvar, -30.0, 20.0)
|
| 508 |
+
# std = torch.exp(0.5 * logvar)
|
| 509 |
+
# zs = [(mean + std * torch.randn_like(mean)) * scale_factor for _ in range(4)]
|
| 510 |
+
# recons = [model.decode(z) for z in zs]
|
| 511 |
+
|
| 512 |
+
with torch.cuda.amp.autocast():
|
| 513 |
+
print('test encode & decode')
|
| 514 |
+
recons = [model.decode(model.encode(img)) for _ in range(4)]
|
| 515 |
+
|
| 516 |
+
out = torch.cat([img, *recons], dim=0)
|
| 517 |
+
out = (out + 1) * 0.5
|
| 518 |
+
save_image(out, f'recons_{fname}')
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
if __name__ == "__main__":
|
| 522 |
+
main()
|
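For orientation, a minimal usage sketch of the frozen KL autoencoder defined above: encode images to scaled latents and decode back. The checkpoint path mirrors the one used in main() and is an assumption here; shapes assume the default 256-resolution ddconfig from get_model.

# Hypothetical usage sketch of FrozenAutoencoderKL via get_model (not part of the repository files).
import torch
from autoencoder import get_model

vae = get_model('assets/stable-diffusion/autoencoder_kl.pth')  # assumed checkpoint location
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vae = vae.to(device)

x = torch.randn(8, 3, 256, 256, device=device)  # stand-in for a batch of images scaled to [-1, 1]
with torch.no_grad():
    z = vae(x, fn='encode')      # (8, 4, 32, 32) latents, already multiplied by scale_factor
    x_rec = vae(z, fn='decode')  # (8, 3, 256, 256) reconstruction
print(z.shape, x_rec.shape)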
checkpoints/.DS_Store
ADDED
Binary file (6.15 kB)
configs/finetune/imagenet256-latent-const.yaml
ADDED
@@ -0,0 +1,49 @@
data:
  dataset: imagenet256-latent
  category: lmdb
  resolution: 32
  num_channels: 4
  random_flip: True
  root: ../data/imagenet256
  feat_path: None

model:
  precond: edm
  model_type: DiT-XL/2
  in_size: 32
  in_channels: 4
  num_classes: 1000
  use_decoder: True
  ext_feature_dim: 0
  pad_cls_token: False
  mask_ratio: 0.0
  mask_ratio_fn: constant
  mask_ratio_min: 0
  mae_loss_coef: 0.1
  class_dropout_prob: 0.1

train:
  tf32: True
  amp: False
  batchsize: 64   # batchsize per GPU
  grad_accum: 1
  epochs: 1000
  lr: 0.00005
  lr_rampup_kimg: 0
  xflip: False
  max_num_steps: 100_000

eval:   # FID evaluation
  cfg_scales: [1.5]
  batchsize: 50
  ref_path: assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz

log:
  log_every: 500
  ckpt_every: 12_500
  tag: finetune-const

wandb:
  entity: MaskDiT
  project: MaskDiT-ImageNet256-latent-finetune
  group: finetune-const
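These YAML files are consumed through OmegaConf (eval_latent.py further down loads its config the same way). A minimal sketch of reading the file above; the field names are exactly those in the config:

# Sketch: inspect a config with OmegaConf.
from omegaconf import OmegaConf

config = OmegaConf.load('configs/finetune/imagenet256-latent-const.yaml')
print(config.model.model_type)   # DiT-XL/2
print(config.train.batchsize)    # 64 (per GPU)
print(config.model.mask_ratio)   # 0.0 -> no masking during this finetuning stage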
configs/finetune/imagenet256-latent-cos.yaml
ADDED
@@ -0,0 +1,49 @@
data:
  dataset: imagenet256-latent
  category: lmdb
  resolution: 32
  num_channels: 4
  random_flip: True
  root: ../data/imagenet256
  feat_path: None

model:
  precond: edm
  model_type: DiT-XL/2
  in_size: 32
  in_channels: 4
  num_classes: 1000
  use_decoder: True
  ext_feature_dim: 0
  pad_cls_token: False
  mask_ratio: 0.5
  mask_ratio_fn: cos4
  mask_ratio_min: 0
  mae_loss_coef: 0.1
  class_dropout_prob: 0.1

train:
  tf32: True
  amp: False
  batchsize: 64   # batchsize per GPU
  grad_accum: 1
  epochs: 1000
  lr: 0.00005
  lr_rampup_kimg: 0
  xflip: False
  max_num_steps: 100_000

eval:   # FID evaluation
  cfg_scales: [1.5]
  batchsize: 50
  ref_path: assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz

log:
  log_every: 500
  ckpt_every: 12_500
  tag: finetune-cos

wandb:
  entity: MaskDiT
  project: MaskDiT-ImageNet256-latent-finetune
  group: finetune-cos
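The only substantive difference from the constant-ratio file above is mask_ratio: 0.5 with mask_ratio_fn: cos4, i.e. the masking ratio is annealed during finetuning rather than held at zero. The actual schedule lives in train_utils/helper.py, which is not shown in this excerpt; the snippet below is only a generic illustration of a cosine-power anneal from mask_ratio down to mask_ratio_min, not necessarily the repository's exact formula.

# Illustrative mask-ratio anneal (assumption; the real cos4 schedule is defined in train_utils/helper.py).
import math

def cosine_power_mask_ratio(step, max_steps, mask_ratio=0.5, mask_ratio_min=0.0, power=4):
    t = min(step / max_steps, 1.0)  # training progress in [0, 1]
    return mask_ratio_min + (mask_ratio - mask_ratio_min) * math.cos(0.5 * math.pi * t) ** power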
configs/finetune/imagenet512-latent.yaml
ADDED
@@ -0,0 +1,47 @@
data:
  dataset: imagenet512-latent
  category: lmdb
  resolution: 64
  num_channels: 4
  root: ../data/imagenet512-wds
  total_num: 1281167

model:
  precond: edm
  model_type: DiT-XL/2
  in_size: 64
  in_channels: 4
  num_classes: 1000
  use_decoder: True
  ext_feature_dim: 0
  pad_cls_token: False
  mask_ratio: 0.0
  mask_ratio_fn: constant
  mask_ratio_min: 0
  mae_loss_coef: 0.1
  class_dropout_prob: 0.1

train:
  tf32: True
  amp: False
  batchsize: 16   # batchsize per GPU
  grad_accum: 1
  epochs: 2000
  lr: 0.00005
  lr_rampup_kimg: 0
  xflip: False
  max_num_steps: 50_000

eval:   # FID evaluation
  batchsize: 50
  ref_path: assets/fid_stats/VIRTUAL_imagenet512.npz

log:
  log_every: 100
  ckpt_every: 10_000
  tag: finetune-4n-wds

wandb:
  entity: MaskDiT
  project: MaskDiT-ImageNet512-latent-finetune
  group: finetune-wds-4nodes
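Note on the resolution fields: the KL autoencoder in autoencoder.py downsamples by a factor of 8 (ch_mult [1, 2, 4, 4] gives three stride-2 stages) and produces 4-channel latents, so 512-pixel images map to 64x64x4 latents and 256-pixel images to 32x32x4. That is why resolution/in_size are 64 in this file and 32 in the 256-pixel configs.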
configs/test/maskdit-256.yaml
ADDED
@@ -0,0 +1,45 @@
data:
  dataset: imagenet256-latent
  category: lmdb
  resolution: 32
  num_channels: 4
  root: /imagenet_256_latent_lmdb
  total_num: 1281167

model:
  precond: edm
  model_type: DiT-XL/2
  in_size: 32
  in_channels: 4
  num_classes: 1000
  use_decoder: True
  ext_feature_dim: 0
  pad_cls_token: False
  mask_ratio: 0.5
  cond_mask_ratio: 0
  mae_loss_coef: 0.1
  class_dropout_prob: 0.1

train:
  tf32: False
  amp: True
  batchsize: 32   # batchsize per GPU
  grad_accum: 1
  epochs: 2800
  lr: 0.0001
  lr_rampup_kimg: 0
  xflip: False

eval:   # FID evaluation
  batchsize: 50
  ref_path: assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz

log:
  log_every: 500
  ckpt_every: 50_000
  tag: baseline

wandb:
  entity: MaskDiT
  project: MaskDiT-ImageNet256-latent
  group: baseline
configs/test/maskdit-512.yaml
ADDED
@@ -0,0 +1,46 @@
data:
  dataset: imagenet512-latent
  category: webdataset
  resolution: 64
  num_channels: 4
  root: ../data/imagenet-wds
  total_num: 1281167

model:
  precond: edm
  model_type: DiT-XL/2
  in_size: 64
  in_channels: 4
  num_classes: 1000
  use_decoder: True
  ext_feature_dim: 0
  pad_cls_token: False
  mask_ratio: 0.5
  cond_mask_ratio: 0
  mae_loss_coef: 0.1
  class_dropout_prob: 0.1

train:
  tf32: False
  amp: True
  batchsize: 32   # batchsize per GPU
  grad_accum: 1
  epochs: 2800
  lr: 0.0001
  lr_rampup_kimg: 0
  xflip: False
  max_num_steps: 2000000

eval:   # FID evaluation
  batchsize: 50
  ref_path: assets/fid_stats/VIRTUAL_imagenet512.npz

log:
  log_every: 500
  ckpt_every: 50_000
  tag: pretrain

wandb:
  entity: MaskDiT
  project: MaskDiT-ImageNet256-latent
  group: pretrain
configs/train/imagenet256-latent.yaml
ADDED
@@ -0,0 +1,48 @@
data:
  dataset: imagenet256-latent
  category: lmdb
  resolution: 32
  num_channels: 4
  random_flip: True
  root: ../data/imagenet256
  feat_path: None

model:
  precond: edm
  model_type: DiT-XL/2
  in_size: 32
  in_channels: 4
  num_classes: 1000
  use_decoder: True
  ext_feature_dim: 0
  pad_cls_token: False
  mask_ratio: 0.5
  mask_ratio_fn: constant
  mask_ratio_min: 0
  mae_loss_coef: 0.1
  class_dropout_prob: 0.1

train:
  tf32: False
  amp: True
  batchsize: 128   # batchsize per GPU
  grad_accum: 1
  epochs: 2800
  lr: 0.0001
  lr_rampup_kimg: 0
  xflip: False
  max_num_steps: 2000000

eval:   # FID evaluation
  batchsize: 50
  ref_path: assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz

log:
  log_every: 500
  ckpt_every: 50_000
  tag: pretrain

wandb:
  entity: MaskDiT
  project: MaskDiT-ImageNet256-latent-train
  group: pretrain
configs/train/imagenet512-latent.yaml
ADDED
@@ -0,0 +1,47 @@
data:
  dataset: imagenet512-latent
  category: webdataset
  resolution: 64
  num_channels: 4
  root: ../data/imagenet512-wds
  total_num: 1281167

model:
  precond: edm
  model_type: DiT-XL/2
  in_size: 64
  in_channels: 4
  num_classes: 1000
  use_decoder: True
  ext_feature_dim: 0
  pad_cls_token: False
  mask_ratio: 0.5
  mask_ratio_fn: constant
  mask_ratio_min: 0
  mae_loss_coef: 0.1
  class_dropout_prob: 0.1

train:
  tf32: False
  amp: True
  batchsize: 32   # batchsize per GPU
  grad_accum: 1
  epochs: 2000
  lr: 0.0001
  lr_rampup_kimg: 0
  xflip: False
  max_num_steps: 2000000

eval:   # FID evaluation
  batchsize: 50
  ref_path: assets/fid_stats/VIRTUAL_imagenet512.npz

log:
  log_every: 100
  ckpt_every: 25_000
  tag: pretrain-4nodes

wandb:
  entity: MaskDiT
  project: MaskDiT-ImageNet512-latent-train
  group: pretrain-4nodes
eval_latent.py
ADDED
@@ -0,0 +1,132 @@
# MIT License

# Copyright (c) [2023] [Anima-Lab]

from argparse import ArgumentParser
import os
from collections import OrderedDict
from omegaconf import OmegaConf

import torch

import accelerate

from fid import calc
from models.maskdit import Precond_models
from sample import generate_with_net
from utils import dist, mprint, get_ckpt_paths, Logger, parse_int_list, parse_float_none


# ------------------------------------------------------------
# Training Helper Function

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    model_params = OrderedDict(model.named_parameters())

    for name, param in model_params.items():
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)


def requires_grad(model, flag=True):
    """
    Set requires_grad flag for all parameters in a model.
    """
    for p in model.parameters():
        p.requires_grad = flag

# ------------------------------------------------------------


def eval_fn(model, args, device, rank, size):
    generate_with_net(args, model, device, rank, size)
    dist.barrier()
    fid = calc(args.outdir, args.ref_path, args.num_expected, args.global_seed, args.fid_batch_size)
    mprint(f'{args.num_expected} samples generated and saved in {args.outdir}')
    mprint(f'guidance: {args.cfg_scale} FID: {fid}')
    dist.barrier()
    return fid


def eval_loop(args):
    config = OmegaConf.load(args.config)
    accelerator = accelerate.Accelerator()

    device = accelerator.device
    size = accelerator.num_processes
    rank = accelerator.process_index
    print(f'world_size: {size}, rank: {rank}')
    experiment_dir = args.exp_dir

    if accelerator.is_main_process:
        logger = Logger(file_name=f'{experiment_dir}/log_eval.txt', file_mode="a+", should_flush=True)
        # setup wandb

    model = Precond_models[config.model.precond](
        img_resolution=config.model.in_size,
        img_channels=config.model.in_channels,
        num_classes=config.model.num_classes,
        model_type=config.model.model_type,
        use_decoder=config.model.use_decoder,
        mae_loss_coef=config.model.mae_loss_coef,
        pad_cls_token=config.model.pad_cls_token,
    ).to(device)
    # Note that parameter initialization is done within the model constructor
    model.eval()
    mprint(f"{config.model.model_type} ((use_decoder: {config.model.use_decoder})) Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
    mprint(f'extras: {model.model.extras}, cls_token: {model.model.cls_token}')

    # model = torch.compile(model)
    # Load checkpoints
    mprint('start evaluating...')

    args.outdir = os.path.join(experiment_dir, 'fid', f'edm-steps{args.num_steps}_cfg{args.cfg_scale}')
    os.makedirs(args.outdir, exist_ok=True)
    ckpt = torch.load(args.ckpt, map_location=device)
    model.load_state_dict(ckpt['ema'])
    fid = eval_fn(model, args, device, rank, size)
    mprint(f'FID: {fid}')

    if accelerator.is_main_process:
        logger.close()
    accelerator.end_training()


if __name__ == '__main__':
    parser = ArgumentParser('training parameters')
    # basic config
    parser.add_argument('--config', type=str, required=True, help='path to config file')

    # training
    parser.add_argument("--exp_dir", type=str, required=True, help='The exp directory to evaluate, it must contain a checkpoints folder')
    parser.add_argument('--ckpt', type=str, required=True, help='path to the checkpoint')

    # sampling
    parser.add_argument('--seeds', type=parse_int_list, default='100000-149999', help='Random seeds (e.g. 1,2,5-10)')
    parser.add_argument('--subdirs', action='store_true', help='Create subdirectory for every 1000 seeds')
    parser.add_argument('--class_idx', type=int, default=None, help='Class label  [default: random]')
    parser.add_argument('--max_batch_size', type=int, default=50, help='Maximum batch size per GPU during sampling, must be a factor of 50k if torch.compile is used')
    parser.add_argument("--cfg_scale", type=parse_float_none, default=None, help='None = no guidance, by default = 4.0')

    parser.add_argument('--num_steps', type=int, default=40, help='Number of sampling steps')
    parser.add_argument('--S_churn', type=int, default=0, help='Stochasticity strength')
    parser.add_argument('--solver', type=str, default=None, choices=['euler', 'heun'], help='Ablate ODE solver')
    parser.add_argument('--discretization', type=str, default=None, choices=['vp', 've', 'iddpm', 'edm'], help='Ablate ODE solver')
    parser.add_argument('--schedule', type=str, default=None, choices=['vp', 've', 'linear'], help='Ablate noise schedule sigma(t)')
    parser.add_argument('--scaling', type=str, default=None, choices=['vp', 'none'], help='Ablate signal scaling s(t)')
    parser.add_argument('--pretrained_path', type=str, default='assets/stable_diffusion/autoencoder_kl.pth', help='Autoencoder ckpt')

    parser.add_argument('--ref_path', type=str, default='assets/fid_stats/VIRTUAL_imagenet512.npz', help='Dataset reference statistics')
    parser.add_argument('--num_expected', type=int, default=50000, help='Number of images to use')
    parser.add_argument("--global_seed", type=int, default=0)
    parser.add_argument('--fid_batch_size', type=int, default=128, help='Maximum batch size per GPU')

    args = parser.parse_args()

    torch.backends.cudnn.benchmark = True
    eval_loop(args)
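One plausible way to launch the evaluation script above, since it builds an accelerate.Accelerator (the authors' exact invocation is not part of this excerpt; the experiment and checkpoint paths below are placeholders, while --config, --exp_dir and --ckpt are the required arguments defined in the argparse block):

accelerate launch eval_latent.py \
    --config configs/test/maskdit-256.yaml \
    --exp_dir results/maskdit256 \
    --ckpt results/maskdit256/checkpoints/ema.pt \
    --cfg_scale 1.5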
evaluator.py
ADDED
@@ -0,0 +1,695 @@
| 1 |
+
# MIT License
|
| 2 |
+
|
| 3 |
+
# Copyright (c) [2023] [Anima-Lab]
|
| 4 |
+
|
| 5 |
+
# This code is adapted from https://github.com/openai/guided-diffusion/blob/main/evaluations/evaluator.py.
|
| 6 |
+
# The original code is licensed under a MIT License, which is can be found at licenses/LICENSE_ADM.txt.
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import io
|
| 11 |
+
import os
|
| 12 |
+
import random
|
| 13 |
+
import warnings
|
| 14 |
+
import zipfile
|
| 15 |
+
from abc import ABC, abstractmethod
|
| 16 |
+
from contextlib import contextmanager
|
| 17 |
+
from functools import partial
|
| 18 |
+
from multiprocessing import cpu_count
|
| 19 |
+
from multiprocessing.pool import ThreadPool
|
| 20 |
+
from typing import Iterable, Optional, Tuple
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import requests
|
| 24 |
+
import tensorflow.compat.v1 as tf
|
| 25 |
+
from scipy import linalg
|
| 26 |
+
from tqdm.auto import tqdm
|
| 27 |
+
from PIL import Image
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
|
| 33 |
+
INCEPTION_V3_PATH = "classify_image_graph_def.pb"
|
| 34 |
+
|
| 35 |
+
FID_POOL_NAME = "pool_3:0"
|
| 36 |
+
FID_SPATIAL_NAME = "mixed_6/conv:0"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_all_files(path):
|
| 40 |
+
path_list = []
|
| 41 |
+
for root, dirs, files in os.walk(path):
|
| 42 |
+
for file in files:
|
| 43 |
+
path_list.append(os.path.join(root, file))
|
| 44 |
+
return path_list
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def isimg(filename):
|
| 48 |
+
if filename.endswith(".png") or filename.endswith(".jpg"):
|
| 49 |
+
return True
|
| 50 |
+
else:
|
| 51 |
+
return False
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def png2npz(img_dir):
|
| 55 |
+
img_list = []
|
| 56 |
+
file_list = get_all_files(img_dir)
|
| 57 |
+
for filename in file_list:
|
| 58 |
+
if isimg(filename):
|
| 59 |
+
filepath = filename
|
| 60 |
+
img = np.asarray(Image.open(filepath).convert('RGB'))
|
| 61 |
+
img_list.append(img)
|
| 62 |
+
imgs = np.stack(img_list, axis=0)
|
| 63 |
+
npz_dir = os.path.join('tmp', 'fid')
|
| 64 |
+
os.makedirs(npz_dir, exist_ok=True)
|
| 65 |
+
npz_path = os.path.join(npz_dir, 'imgs.npz')
|
| 66 |
+
np.savez(npz_path, imgs)
|
| 67 |
+
return npz_path
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def main():
|
| 71 |
+
parser = argparse.ArgumentParser()
|
| 72 |
+
parser.add_argument("ref_batch", help="path to reference batch npz file")
|
| 73 |
+
parser.add_argument("sample_batch", help="path to sample batch npz file")
|
| 74 |
+
args = parser.parse_args()
|
| 75 |
+
|
| 76 |
+
config = tf.ConfigProto(
|
| 77 |
+
allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph
|
| 78 |
+
)
|
| 79 |
+
config.gpu_options.allow_growth = True
|
| 80 |
+
evaluator = Evaluator(tf.Session(config=config))
|
| 81 |
+
|
| 82 |
+
print("warming up TensorFlow...")
|
| 83 |
+
# This will cause TF to print a bunch of verbose stuff now rather
|
| 84 |
+
# than after the next print(), to help prevent confusion.
|
| 85 |
+
evaluator.warmup()
|
| 86 |
+
|
| 87 |
+
print("computing reference batch activations...")
|
| 88 |
+
ref_acts = evaluator.read_activations(args.ref_batch)
|
| 89 |
+
print("computing/reading reference batch statistics...")
|
| 90 |
+
ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
|
| 91 |
+
|
| 92 |
+
if os.path.isdir(args.sample_batch):
|
| 93 |
+
sample_batch = png2npz(args.sample_batch)
|
| 94 |
+
else:
|
| 95 |
+
sample_batch = args.sample_batch
|
| 96 |
+
|
| 97 |
+
print("computing sample batch activations...")
|
| 98 |
+
sample_acts = evaluator.read_activations(sample_batch)
|
| 99 |
+
print("computing/reading sample batch statistics...")
|
| 100 |
+
sample_stats, sample_stats_spatial = evaluator.read_statistics(sample_batch, sample_acts)
|
| 101 |
+
|
| 102 |
+
print("Computing evaluations...")
|
| 103 |
+
print("Inception Score:", evaluator.compute_inception_score(sample_acts[0]))
|
| 104 |
+
print("FID:", sample_stats.frechet_distance(ref_stats))
|
| 105 |
+
print("sFID:", sample_stats_spatial.frechet_distance(ref_stats_spatial))
|
| 106 |
+
prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
|
| 107 |
+
print("Precision:", prec)
|
| 108 |
+
print("Recall:", recall)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class InvalidFIDException(Exception):
|
| 112 |
+
pass
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class FIDStatistics:
|
| 116 |
+
def __init__(self, mu: np.ndarray, sigma: np.ndarray):
|
| 117 |
+
self.mu = mu
|
| 118 |
+
self.sigma = sigma
|
| 119 |
+
|
| 120 |
+
def frechet_distance(self, other, eps=1e-6):
|
| 121 |
+
"""
|
| 122 |
+
Compute the Frechet distance between two sets of statistics.
|
| 123 |
+
"""
|
| 124 |
+
# https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
|
| 125 |
+
mu1, sigma1 = self.mu, self.sigma
|
| 126 |
+
mu2, sigma2 = other.mu, other.sigma
|
| 127 |
+
|
| 128 |
+
mu1 = np.atleast_1d(mu1)
|
| 129 |
+
mu2 = np.atleast_1d(mu2)
|
| 130 |
+
|
| 131 |
+
sigma1 = np.atleast_2d(sigma1)
|
| 132 |
+
sigma2 = np.atleast_2d(sigma2)
|
| 133 |
+
|
| 134 |
+
assert (
|
| 135 |
+
mu1.shape == mu2.shape
|
| 136 |
+
), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
|
| 137 |
+
assert (
|
| 138 |
+
sigma1.shape == sigma2.shape
|
| 139 |
+
), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
|
| 140 |
+
|
| 141 |
+
diff = mu1 - mu2
|
| 142 |
+
|
| 143 |
+
# product might be almost singular
|
| 144 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
| 145 |
+
if not np.isfinite(covmean).all():
|
| 146 |
+
msg = (
|
| 147 |
+
"fid calculation produces singular product; adding %s to diagonal of cov estimates"
|
| 148 |
+
% eps
|
| 149 |
+
)
|
| 150 |
+
warnings.warn(msg)
|
| 151 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
| 152 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
| 153 |
+
|
| 154 |
+
# numerical error might give slight imaginary component
|
| 155 |
+
if np.iscomplexobj(covmean):
|
| 156 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
| 157 |
+
m = np.max(np.abs(covmean.imag))
|
| 158 |
+
raise ValueError("Imaginary component {}".format(m))
|
| 159 |
+
covmean = covmean.real
|
| 160 |
+
|
| 161 |
+
tr_covmean = np.trace(covmean)
|
| 162 |
+
|
| 163 |
+
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class Evaluator:
|
| 167 |
+
def __init__(
|
| 168 |
+
self,
|
| 169 |
+
session,
|
| 170 |
+
batch_size=64,
|
| 171 |
+
softmax_batch_size=512,
|
| 172 |
+
):
|
| 173 |
+
self.sess = session
|
| 174 |
+
self.batch_size = batch_size
|
| 175 |
+
self.softmax_batch_size = softmax_batch_size
|
| 176 |
+
self.manifold_estimator = ManifoldEstimator(session)
|
| 177 |
+
with self.sess.graph.as_default():
|
| 178 |
+
self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
|
| 179 |
+
self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
|
| 180 |
+
self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
|
| 181 |
+
self.softmax = _create_softmax_graph(self.softmax_input)
|
| 182 |
+
|
| 183 |
+
def warmup(self):
|
| 184 |
+
self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
|
| 185 |
+
|
| 186 |
+
def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
|
| 187 |
+
with open_npz_array(npz_path, "arr_0") as reader:
|
| 188 |
+
return self.compute_activations(reader.read_batches(self.batch_size))
|
| 189 |
+
|
| 190 |
+
def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
|
| 191 |
+
"""
|
| 192 |
+
Compute image features for downstream evals.
|
| 193 |
+
:param batches: a iterator over NHWC numpy arrays in [0, 255].
|
| 194 |
+
:return: a tuple of numpy arrays of shape [N x X], where X is a feature
|
| 195 |
+
dimension. The tuple is (pool_3, spatial).
|
| 196 |
+
"""
|
| 197 |
+
preds = []
|
| 198 |
+
spatial_preds = []
|
| 199 |
+
for batch in tqdm(batches):
|
| 200 |
+
batch = batch.astype(np.float32)
|
| 201 |
+
pred, spatial_pred = self.sess.run(
|
| 202 |
+
[self.pool_features, self.spatial_features], {self.image_input: batch}
|
| 203 |
+
)
|
| 204 |
+
preds.append(pred.reshape([pred.shape[0], -1]))
|
| 205 |
+
spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
|
| 206 |
+
return (
|
| 207 |
+
np.concatenate(preds, axis=0),
|
| 208 |
+
np.concatenate(spatial_preds, axis=0),
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
def read_statistics(
|
| 212 |
+
self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
|
| 213 |
+
) -> Tuple[FIDStatistics, FIDStatistics]:
|
| 214 |
+
obj = np.load(npz_path)
|
| 215 |
+
if "mu" in list(obj.keys()):
|
| 216 |
+
return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
|
| 217 |
+
obj["mu_s"], obj["sigma_s"]
|
| 218 |
+
)
|
| 219 |
+
return tuple(self.compute_statistics(x) for x in activations)
|
| 220 |
+
|
| 221 |
+
def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
|
| 222 |
+
mu = np.mean(activations, axis=0)
|
| 223 |
+
sigma = np.cov(activations, rowvar=False)
|
| 224 |
+
return FIDStatistics(mu, sigma)
|
| 225 |
+
|
| 226 |
+
def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
|
| 227 |
+
softmax_out = []
|
| 228 |
+
for i in range(0, len(activations), self.softmax_batch_size):
|
| 229 |
+
acts = activations[i : i + self.softmax_batch_size]
|
| 230 |
+
softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
|
| 231 |
+
preds = np.concatenate(softmax_out, axis=0)
|
| 232 |
+
# https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
|
| 233 |
+
scores = []
|
| 234 |
+
for i in range(0, len(preds), split_size):
|
| 235 |
+
part = preds[i : i + split_size]
|
| 236 |
+
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
|
| 237 |
+
kl = np.mean(np.sum(kl, 1))
|
| 238 |
+
scores.append(np.exp(kl))
|
| 239 |
+
return float(np.mean(scores))
|
| 240 |
+
|
| 241 |
+
def compute_prec_recall(
|
| 242 |
+
self, activations_ref: np.ndarray, activations_sample: np.ndarray
|
| 243 |
+
) -> Tuple[float, float]:
|
| 244 |
+
radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
|
| 245 |
+
radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
|
| 246 |
+
pr = self.manifold_estimator.evaluate_pr(
|
| 247 |
+
activations_ref, radii_1, activations_sample, radii_2
|
| 248 |
+
)
|
| 249 |
+
return (float(pr[0][0]), float(pr[1][0]))
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class ManifoldEstimator:
|
| 253 |
+
"""
|
| 254 |
+
A helper for comparing manifolds of feature vectors.
|
| 255 |
+
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
|
| 256 |
+
"""
|
| 257 |
+
|
| 258 |
+
def __init__(
|
| 259 |
+
self,
|
| 260 |
+
session,
|
| 261 |
+
row_batch_size=10000,
|
| 262 |
+
col_batch_size=10000,
|
| 263 |
+
nhood_sizes=(3,),
|
| 264 |
+
clamp_to_percentile=None,
|
| 265 |
+
eps=1e-5,
|
| 266 |
+
):
|
| 267 |
+
"""
|
| 268 |
+
Estimate the manifold of given feature vectors.
|
| 269 |
+
:param session: the TensorFlow session.
|
| 270 |
+
:param row_batch_size: row batch size to compute pairwise distances
|
| 271 |
+
(parameter to trade-off between memory usage and performance).
|
| 272 |
+
:param col_batch_size: column batch size to compute pairwise distances.
|
| 273 |
+
:param nhood_sizes: number of neighbors used to estimate the manifold.
|
| 274 |
+
:param clamp_to_percentile: prune hyperspheres that have radius larger than
|
| 275 |
+
the given percentile.
|
| 276 |
+
:param eps: small number for numerical stability.
|
| 277 |
+
"""
|
| 278 |
+
self.distance_block = DistanceBlock(session)
|
| 279 |
+
self.row_batch_size = row_batch_size
|
| 280 |
+
self.col_batch_size = col_batch_size
|
| 281 |
+
self.nhood_sizes = nhood_sizes
|
| 282 |
+
self.num_nhoods = len(nhood_sizes)
|
| 283 |
+
self.clamp_to_percentile = clamp_to_percentile
|
| 284 |
+
self.eps = eps
|
| 285 |
+
|
| 286 |
+
def warmup(self):
|
| 287 |
+
feats, radii = (
|
| 288 |
+
np.zeros([1, 2048], dtype=np.float32),
|
| 289 |
+
np.zeros([1, 1], dtype=np.float32),
|
| 290 |
+
)
|
| 291 |
+
self.evaluate_pr(feats, radii, feats, radii)
|
| 292 |
+
|
| 293 |
+
def manifold_radii(self, features: np.ndarray) -> np.ndarray:
|
| 294 |
+
num_images = len(features)
|
| 295 |
+
|
| 296 |
+
# Estimate manifold of features by calculating distances to k-NN of each sample.
|
| 297 |
+
radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
|
| 298 |
+
distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
|
| 299 |
+
seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
|
| 300 |
+
|
| 301 |
+
for begin1 in range(0, num_images, self.row_batch_size):
|
| 302 |
+
end1 = min(begin1 + self.row_batch_size, num_images)
|
| 303 |
+
row_batch = features[begin1:end1]
|
| 304 |
+
|
| 305 |
+
for begin2 in range(0, num_images, self.col_batch_size):
|
| 306 |
+
end2 = min(begin2 + self.col_batch_size, num_images)
|
| 307 |
+
col_batch = features[begin2:end2]
|
| 308 |
+
|
| 309 |
+
# Compute distances between batches.
|
| 310 |
+
distance_batch[
|
| 311 |
+
0 : end1 - begin1, begin2:end2
|
| 312 |
+
] = self.distance_block.pairwise_distances(row_batch, col_batch)
|
| 313 |
+
|
| 314 |
+
# Find the k-nearest neighbor from the current batch.
|
| 315 |
+
radii[begin1:end1, :] = np.concatenate(
|
| 316 |
+
[
|
| 317 |
+
x[:, self.nhood_sizes]
|
| 318 |
+
for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
|
| 319 |
+
],
|
| 320 |
+
axis=0,
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
if self.clamp_to_percentile is not None:
|
| 324 |
+
max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
|
| 325 |
+
radii[radii > max_distances] = 0
|
| 326 |
+
return radii
|
| 327 |
+
|
| 328 |
+
def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
|
| 329 |
+
"""
|
| 330 |
+
Evaluate if new feature vectors are at the manifold.
|
| 331 |
+
"""
|
| 332 |
+
num_eval_images = eval_features.shape[0]
|
| 333 |
+
num_ref_images = radii.shape[0]
|
| 334 |
+
distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
|
| 335 |
+
batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
|
| 336 |
+
max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
|
| 337 |
+
nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
|
| 338 |
+
|
| 339 |
+
for begin1 in range(0, num_eval_images, self.row_batch_size):
|
| 340 |
+
end1 = min(begin1 + self.row_batch_size, num_eval_images)
|
| 341 |
+
feature_batch = eval_features[begin1:end1]
|
| 342 |
+
|
| 343 |
+
for begin2 in range(0, num_ref_images, self.col_batch_size):
|
| 344 |
+
end2 = min(begin2 + self.col_batch_size, num_ref_images)
|
| 345 |
+
ref_batch = features[begin2:end2]
|
| 346 |
+
|
| 347 |
+
distance_batch[
|
| 348 |
+
0 : end1 - begin1, begin2:end2
|
| 349 |
+
] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
|
| 350 |
+
|
| 351 |
+
# From the minibatch of new feature vectors, determine if they are in the estimated manifold.
|
| 352 |
+
# If a feature vector is inside a hypersphere of some reference sample, then
|
| 353 |
+
# the new sample lies at the estimated manifold.
|
| 354 |
+
# The radii of the hyperspheres are determined from distances of neighborhood size k.
|
| 355 |
+
samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
|
| 356 |
+
batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
|
| 357 |
+
|
| 358 |
+
max_realism_score[begin1:end1] = np.max(
|
| 359 |
+
radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
|
| 360 |
+
)
|
| 361 |
+
nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)
|
| 362 |
+
|
| 363 |
+
return {
|
| 364 |
+
"fraction": float(np.mean(batch_predictions)),
|
| 365 |
+
"batch_predictions": batch_predictions,
|
| 366 |
+
"max_realisim_score": max_realism_score,
|
| 367 |
+
"nearest_indices": nearest_indices,
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
def evaluate_pr(
|
| 371 |
+
self,
|
| 372 |
+
features_1: np.ndarray,
|
| 373 |
+
radii_1: np.ndarray,
|
| 374 |
+
features_2: np.ndarray,
|
| 375 |
+
radii_2: np.ndarray,
|
| 376 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
| 377 |
+
"""
|
| 378 |
+
Evaluate precision and recall efficiently.
|
| 379 |
+
:param features_1: [N1 x D] feature vectors for reference batch.
|
| 380 |
+
:param radii_1: [N1 x K1] radii for reference vectors.
|
| 381 |
+
:param features_2: [N2 x D] feature vectors for the other batch.
|
| 382 |
+
:param radii_2: [N x K2] radii for other vectors.
|
| 383 |
+
:return: a tuple of arrays for (precision, recall):
|
| 384 |
+
- precision: an np.ndarray of length K1
|
| 385 |
+
- recall: an np.ndarray of length K2
|
| 386 |
+
"""
|
| 387 |
+
features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool)
|
| 388 |
+
features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool)
|
| 389 |
+
for begin_1 in range(0, len(features_1), self.row_batch_size):
|
| 390 |
+
end_1 = begin_1 + self.row_batch_size
|
| 391 |
+
batch_1 = features_1[begin_1:end_1]
|
| 392 |
+
for begin_2 in range(0, len(features_2), self.col_batch_size):
|
| 393 |
+
end_2 = begin_2 + self.col_batch_size
|
| 394 |
+
batch_2 = features_2[begin_2:end_2]
|
| 395 |
+
batch_1_in, batch_2_in = self.distance_block.less_thans(
|
| 396 |
+
batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
|
| 397 |
+
)
|
| 398 |
+
features_1_status[begin_1:end_1] |= batch_1_in
|
| 399 |
+
features_2_status[begin_2:end_2] |= batch_2_in
|
| 400 |
+
return (
|
| 401 |
+
np.mean(features_2_status.astype(np.float64), axis=0),
|
| 402 |
+
np.mean(features_1_status.astype(np.float64), axis=0),
|
| 403 |
+
)
|


class DistanceBlock:
    """
    Calculate pairwise distances between vectors.

    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
    """

    def __init__(self, session):
        self.session = session

        # Initialize TF graph to calculate pairwise distances.
        with session.graph.as_default():
            self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
            self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
            distance_block_16 = _batch_pairwise_distances(
                tf.cast(self._features_batch1, tf.float16),
                tf.cast(self._features_batch2, tf.float16),
            )
            # Compute in float16 for speed; fall back to float32 when the result overflows.
            self.distance_block = tf.cond(
                tf.reduce_all(tf.math.is_finite(distance_block_16)),
                lambda: tf.cast(distance_block_16, tf.float32),
                lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
            )

            # Extra logic for less thans.
            self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
            self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
            dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
            self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
            self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)

    def pairwise_distances(self, U, V):
        """
        Evaluate pairwise distances between two batches of feature vectors.
        """
        return self.session.run(
            self.distance_block,
            feed_dict={self._features_batch1: U, self._features_batch2: V},
        )

    def less_thans(self, batch_1, radii_1, batch_2, radii_2):
        return self.session.run(
            [self._batch_1_in, self._batch_2_in],
            feed_dict={
                self._features_batch1: batch_1,
                self._features_batch2: batch_2,
                self._radii1: radii_1,
                self._radii2: radii_2,
            },
        )


def _batch_pairwise_distances(U, V):
    """
    Compute pairwise squared Euclidean distances between two batches of feature vectors.
    """
    with tf.variable_scope("pairwise_dist_block"):
        # Squared norms of each row in U and V.
        norm_u = tf.reduce_sum(tf.square(U), 1)
        norm_v = tf.reduce_sum(tf.square(V), 1)

        # norm_u as a column and norm_v as a row vector.
        norm_u = tf.reshape(norm_u, [-1, 1])
        norm_v = tf.reshape(norm_v, [1, -1])

        # Pairwise squared Euclidean distances.
        D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)

    return D

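# Quick NumPy sanity check of the expansion used in _batch_pairwise_distances,
# ||u - v||^2 = ||u||^2 - 2 u.v + ||v||^2 (clamped at zero against round-off).
# Illustrative only; not called anywhere in the evaluator.
def _check_pairwise_identity(seed=0):
    rng = np.random.default_rng(seed)
    U, V = rng.normal(size=(4, 8)), rng.normal(size=(5, 8))
    expanded = np.maximum((U ** 2).sum(1)[:, None] - 2 * U @ V.T + (V ** 2).sum(1)[None, :], 0.0)
    direct = ((U[:, None, :] - V[None, :, :]) ** 2).sum(-1)
    return np.allclose(expanded, direct)  # expected: True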

class NpzArrayReader(ABC):
    @abstractmethod
    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        pass

    @abstractmethod
    def remaining(self) -> int:
        pass

    def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
        def gen_fn():
            while True:
                batch = self.read_batch(batch_size)
                if batch is None:
                    break
                yield batch

        rem = self.remaining()
        num_batches = rem // batch_size + int(rem % batch_size != 0)
        return BatchIterator(gen_fn, num_batches)


class BatchIterator:
    def __init__(self, gen_fn, length):
        self.gen_fn = gen_fn
        self.length = length

    def __len__(self):
        return self.length

    def __iter__(self):
        return self.gen_fn()


class StreamingNpzArrayReader(NpzArrayReader):
    def __init__(self, arr_f, shape, dtype):
        self.arr_f = arr_f
        self.shape = shape
        self.dtype = dtype
        self.idx = 0

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        if self.idx >= self.shape[0]:
            return None

        bs = min(batch_size, self.shape[0] - self.idx)
        self.idx += bs

        if self.dtype.itemsize == 0:
            return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)

        read_count = bs * np.prod(self.shape[1:])
        read_size = int(read_count * self.dtype.itemsize)
        data = _read_bytes(self.arr_f, read_size, "array data")
        return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])

    def remaining(self) -> int:
        return max(0, self.shape[0] - self.idx)


class MemoryNpzArrayReader(NpzArrayReader):
    def __init__(self, arr):
        self.arr = arr
        self.idx = 0

    @classmethod
    def load(cls, path: str, arr_name: str):
        with open(path, "rb") as f:
            arr = np.load(f)[arr_name]
        return cls(arr)

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        if self.idx >= self.arr.shape[0]:
            return None

        res = self.arr[self.idx : self.idx + batch_size]
        self.idx += batch_size
        return res

    def remaining(self) -> int:
        return max(0, self.arr.shape[0] - self.idx)


@contextmanager
def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
    with _open_npy_file(path, arr_name) as arr_f:
        version = np.lib.format.read_magic(arr_f)
        if version == (1, 0):
            header = np.lib.format.read_array_header_1_0(arr_f)
        elif version == (2, 0):
            header = np.lib.format.read_array_header_2_0(arr_f)
        else:
            yield MemoryNpzArrayReader.load(path, arr_name)
            return
        shape, fortran, dtype = header
        if fortran or dtype.hasobject:
            yield MemoryNpzArrayReader.load(path, arr_name)
        else:
            yield StreamingNpzArrayReader(arr_f, shape, dtype)


def _read_bytes(fp, size, error_template="ran out of data"):
    """
    Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
    Read from file-like object until size bytes are read.
    Raises ValueError if EOF is encountered before size bytes are read.
    Non-blocking objects only supported if they derive from io objects.
    Required as e.g. ZipExtFile in python 2.6 can return less data than
    requested.
    """
    data = bytes()
    while True:
        # io files (default in python3) return None or raise on
        # would-block, python2 file will truncate, probably nothing can be
        # done about that. note that regular files can't be non-blocking
        try:
            r = fp.read(size - len(data))
            data += r
            if len(r) == 0 or len(data) == size:
                break
        except io.BlockingIOError:
            pass
    if len(data) != size:
        msg = "EOF: reading %s, expected %d bytes got %d"
        raise ValueError(msg % (error_template, size, len(data)))
    else:
        return data


@contextmanager
def _open_npy_file(path: str, arr_name: str):
    with open(path, "rb") as f:
        with zipfile.ZipFile(f, "r") as zip_f:
            if f"{arr_name}.npy" not in zip_f.namelist():
                raise ValueError(f"missing {arr_name} in npz file")
            with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
                yield arr_f


def _download_inception_model():
    if os.path.exists(INCEPTION_V3_PATH):
        return
    print("downloading InceptionV3 model...")
    with requests.get(INCEPTION_V3_URL, stream=True) as r:
        r.raise_for_status()
        tmp_path = INCEPTION_V3_PATH + ".tmp"
        with open(tmp_path, "wb") as f:
            for chunk in tqdm(r.iter_content(chunk_size=8192)):
                f.write(chunk)
        os.rename(tmp_path, INCEPTION_V3_PATH)


def _create_feature_graph(input_batch):
    _download_inception_model()
    prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
    with open(INCEPTION_V3_PATH, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    pool3, spatial = tf.import_graph_def(
        graph_def,
        input_map={"ExpandDims:0": input_batch},
        return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
        name=prefix,
    )
    _update_shapes(pool3)
    spatial = spatial[..., :7]
    return pool3, spatial


def _create_softmax_graph(input_batch):
    _download_inception_model()
    prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
    with open(INCEPTION_V3_PATH, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    (matmul,) = tf.import_graph_def(
        graph_def, return_elements=["softmax/logits/MatMul"], name=prefix
    )
    w = matmul.inputs[1]
    logits = tf.matmul(input_batch, w)
    return tf.nn.softmax(logits)


def _update_shapes(pool3):
    # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
    ops = pool3.graph.get_operations()
    for op in ops:
        for o in op.outputs:
            shape = o.get_shape()
            if shape._dims is not None:  # pylint: disable=protected-access
                # shape = [s.value for s in shape]  # TF 1.x
                shape = [s for s in shape]  # TF 2.x
                new_shape = []
                for j, s in enumerate(shape):
                    if s == 1 and j == 0:
                        new_shape.append(None)
                    else:
                        new_shape.append(s)
                o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
    return pool3


def _numpy_partition(arr, kth, **kwargs):
    num_workers = min(cpu_count(), len(arr))
    chunk_size = len(arr) // num_workers
    extra = len(arr) % num_workers

    start_idx = 0
    batches = []
    for i in range(num_workers):
        size = chunk_size + (1 if i < extra else 0)
        batches.append(arr[start_idx : start_idx + size])
        start_idx += size

    with ThreadPool(num_workers) as pool:
        return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))


if __name__ == "__main__":
    main()
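The NpzArrayReader classes above let Inception activations be streamed from an .npz archive in fixed-size batches instead of loading the whole array into memory. A minimal usage sketch, assuming a hypothetical acts.npz written with np.savez and that open_npz_array is importable from this module (e.g. from evaluator import open_npz_array):

import numpy as np
from evaluator import open_npz_array  # this file

np.savez("acts.npz", np.random.rand(1000, 2048).astype(np.float32))  # stored under key 'arr_0'

with open_npz_array("acts.npz", "arr_0") as reader:
    n_rows = 0
    for batch in reader.read_batches(256):  # 4 batches: 256, 256, 256, 232
        n_rows += batch.shape[0]
print(n_rows)  # 1000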
extract_latent.py ADDED
@@ -0,0 +1,114 @@
import argparse
import os
import time

import lmdb
import torch
from torch import nn

import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from autoencoder import get_model
from train_utils.datasets import imagenet_lmdb_dataset


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_name', default='imagenet', type=str)
    parser.add_argument('--data_dir', default='../datasets', type=str)
    parser.add_argument('--ckpt', default='assets/vae/autoencoder_kl.pth', type=str, help='checkpoint path')
    parser.add_argument('--resolution', default=512, type=int)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--split', default='train', type=str)
    parser.add_argument('--xflip', action='store_true')
    parser.add_argument('--outdir', type=str, default='../data/imagenet512-latent', help='output directory')
    args = parser.parse_args()

    assert args.split in ['train', 'val']

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    dataset = imagenet_lmdb_dataset(root=f'{args.data_dir}/{args.split}',
                                    transform=transform, resolution=args.resolution)

    print(f'data size: {len(dataset)}')

    model = get_model(args.ckpt)
    print(f'loaded VAE weights from {args.ckpt}')
    model = nn.DataParallel(model)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    def extract_feature():
        outdir = f'{args.data_name}_{args.resolution}_latent_lmdb'
        target_db_dir = os.path.join(args.outdir, outdir, args.split)
        os.makedirs(target_db_dir, exist_ok=True)
        target_env = lmdb.open(target_db_dir, map_size=pow(2, 40), readahead=False)

        dataset_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False,
                                    num_workers=8, pin_memory=True, persistent_workers=True)

        idx = 0
        begin = time.time()
        print('start...')
        for batch in dataset_loader:
            img, label = batch
            assert img.min() >= -1 and img.max() <= 1

            img = img.to(device)
            moments = model(img, fn='encode_moments')
            assert moments.shape[-1] == (args.resolution // 8)

            moments = moments.detach().cpu().numpy()
            label = label.detach().cpu().numpy()

            with target_env.begin(write=True) as target_txn:
                for moment, lb in zip(moments, label):
                    target_txn.put(f'z-{str(idx)}'.encode('utf-8'), moment)
                    target_txn.put(f'y-{str(idx)}'.encode('utf-8'), str(lb).encode('utf-8'))
                    idx += 1

            if idx % 5120 == 0:
                cur_time = time.time()
                print(f'saved {idx} files with {cur_time - begin}s elapsed')
                begin = time.time()

        # idx = 1_281_167
        if args.xflip:
            print('starting to store the xflip latents')
            begin = time.time()
            for batch in dataset_loader:
                img, label = batch
                assert img.min() >= -1 and img.max() <= 1

                img = img.to(device)
                moments = model(img.flip(dims=[-1]), fn='encode_moments')

                moments = moments.detach().cpu().numpy()
                label = label.detach().cpu().numpy()

                with target_env.begin(write=True) as target_txn:
                    for moment, lb in zip(moments, label):
                        target_txn.put(f'z-{str(idx)}'.encode('utf-8'), moment)
                        target_txn.put(f'y-{str(idx)}'.encode('utf-8'), str(lb).encode('utf-8'))
                        idx += 1

                if idx % 10000 == 0:
                    cur_time = time.time()
                    print(f'saved {idx} files with {cur_time - begin}s elapsed')
                    begin = time.time()

        with target_env.begin(write=True) as target_txn:
            target_txn.put('length'.encode('utf-8'), str(idx).encode('utf-8'))

        print(f'[finished] saved {idx} files')

    extract_feature()


if __name__ == "__main__":
    main()
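The LMDB layout written above is simple: z-{i} holds the raw bytes of one latent-moment tensor, y-{i} holds the class label as text, and length records the total count. A minimal read-back sketch, assuming the moments were stored as float32 with shape [8, resolution // 8, resolution // 8] (mean and log-variance channels of the KL autoencoder; adjust the dtype and channel count if the actual checkpoint differs):

import lmdb
import numpy as np

def read_latent(db_dir, idx, resolution=512, channels=8):
    # channels = 2 * latent channels (mean and log-variance); dtype assumed float32.
    env = lmdb.open(db_dir, readonly=True, lock=False, readahead=False)
    with env.begin() as txn:
        z_bytes = txn.get(f'z-{idx}'.encode('utf-8'))
        y_bytes = txn.get(f'y-{idx}'.encode('utf-8'))
        length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
    env.close()
    side = resolution // 8
    moments = np.frombuffer(z_bytes, dtype=np.float32).reshape(channels, side, side)
    label = int(y_bytes.decode('utf-8'))
    return moments, label, length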
fid.py ADDED
@@ -0,0 +1,177 @@
# MIT License

# Copyright (c) [2023] [Anima-Lab]

# This code is adapted from https://github.com/NVlabs/edm/blob/main/fid.py.
# The original code is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License, which can be found at licenses/LICENSE_EDM.txt.

"""Script for calculating Frechet Inception Distance (FID)."""
import argparse
from multiprocessing import Process

import click
import tqdm
import pickle
import numpy as np
import scipy.linalg
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader

from utils import *
from train_utils.datasets import ImageFolderDataset


#----------------------------------------------------------------------------

def calculate_inception_stats(
    image_path, num_expected=None, seed=0, max_batch_size=64,
    num_workers=3, prefetch_factor=2, device=torch.device('cuda'),
):
    # Rank 0 goes first.
    if dist.get_rank() != 0:
        dist.barrier()

    # Load Inception-v3 model.
    # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    mprint('Loading Inception-v3 model...')
    detector_kwargs = dict(return_features=True)
    feature_dim = 2048
    with open(detector_url, 'rb') as f:
        detector_net = pickle.load(f).to(device)

    # List images.
    mprint(f'Loading images from "{image_path}"...')
    dataset_obj = ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed)
    if num_expected is not None and len(dataset_obj) < num_expected:
        raise click.ClickException(f'Found {len(dataset_obj)} images, but expected at least {num_expected}')
    if len(dataset_obj) < 2:
        raise click.ClickException(f'Found {len(dataset_obj)} images, but need at least 2 to compute statistics')

    # Other ranks follow.
    if dist.get_rank() == 0:
        dist.barrier()

    # Divide images into batches.
    num_batches = ((len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
    all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches)
    rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]
    data_loader = DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor)

    # Accumulate statistics.
    mprint(f'Calculating statistics for {len(dataset_obj)} images...')
    mu = torch.zeros([feature_dim], dtype=torch.float64, device=device)
    sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device)
    for images, _labels in tqdm.tqdm(data_loader, unit='batch', disable=(dist.get_rank() != 0)):
        dist.barrier()
        if images.shape[0] == 0:
            continue
        if images.shape[1] == 1:
            images = images.repeat([1, 3, 1, 1])
        features = detector_net(images.to(device), **detector_kwargs).to(torch.float64)
        mu += features.sum(0)
        sigma += features.T @ features

    # Calculate grand totals.
    dist.all_reduce(mu)
    dist.all_reduce(sigma)
    mu /= len(dataset_obj)
    sigma -= mu.ger(mu) * len(dataset_obj)
    sigma /= len(dataset_obj) - 1
    return mu.cpu().numpy(), sigma.cpu().numpy()

#----------------------------------------------------------------------------

def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref):
    m = np.square(mu - mu_ref).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)
    fid = m + np.trace(sigma + sigma_ref - s * 2)
    return float(np.real(fid))

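# The value returned above is the (squared) Frechet distance between the Gaussians
# N(mu, sigma) and N(mu_ref, sigma_ref) fitted to the Inception features:
#
#     FID = ||mu - mu_ref||^2 + Tr(sigma + sigma_ref - 2 * (sigma @ sigma_ref)^(1/2))
#
# scipy.linalg.sqrtm supplies the matrix square root, and np.real drops the small
# imaginary component it can return when the covariance estimates are near-singular.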
#----------------------------------------------------------------------------


def calc(image_path, ref_path, num_expected, seed, batch):
    """Calculate FID for a given set of images."""
    if dist.get_rank() == 0:
        logger = Logger(file_name=f'{image_path}/log_fid.txt', file_mode="a+", should_flush=True)

    mprint(f'Loading dataset reference statistics from "{ref_path}"...')
    ref = None
    if dist.get_rank() == 0:
        assert ref_path.endswith('.npz')
        ref = dict(np.load(ref_path))

    mu, sigma = calculate_inception_stats(image_path=image_path, num_expected=num_expected, seed=seed, max_batch_size=batch)
    mprint('Calculating FID...')
    fid = None
    if dist.get_rank() == 0:
        fid = calculate_fid_from_inception_stats(mu, sigma, ref['mu'], ref['sigma'])
        print(f'{fid:g}')

    dist.barrier()
    if dist.get_rank() == 0:
        logger.close()

    return fid

#----------------------------------------------------------------------------


def ref(dataset_path, dest_path, batch):
    """Calculate dataset reference statistics needed by 'calc'."""

    mu, sigma = calculate_inception_stats(image_path=dataset_path, max_batch_size=batch)
    mprint(f'Saving dataset reference statistics to "{dest_path}"...')
    if dist.get_rank() == 0:
        if os.path.dirname(dest_path):
            os.makedirs(os.path.dirname(dest_path), exist_ok=True)
        np.savez(dest_path, mu=mu, sigma=sigma)

    dist.barrier()
    mprint('Done.')


if __name__ == '__main__':
    parser = argparse.ArgumentParser('fid parameters')

    # ddp
    parser.add_argument('--num_proc_node', type=int, default=1, help='The number of nodes in multi node env.')
    parser.add_argument('--num_process_per_node', type=int, default=1, help='number of gpus')
    parser.add_argument('--node_rank', type=int, default=0, help='The index of node.')
    parser.add_argument('--local_rank', type=int, default=0, help='rank of process in the node')
    parser.add_argument('--master_address', type=str, default='localhost', help='address for master')

    # fid
    parser.add_argument('--mode', type=str, required=True, choices=['calc', 'ref'], help='Calculate FID or store reference statistics')
    parser.add_argument('--image_path', type=str, required=True, help='Path to the images')
    parser.add_argument('--ref_path', type=str, default='assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz', help='Dataset reference statistics')
    parser.add_argument('--num_expected', type=int, default=50000, help='Number of images to use')
    parser.add_argument('--seed', type=int, default=0, help='Random seed for selecting the images')
    parser.add_argument('--batch', type=int, default=64, help='Maximum batch size per GPU')

    args = parser.parse_args()
    args.global_size = args.num_proc_node * args.num_process_per_node
    size = args.num_process_per_node

    # Choose the worker according to --mode: compute FID or store reference statistics.
    if args.mode == 'calc':
        func = lambda args: calc(args.image_path, args.ref_path, args.num_expected, args.seed, args.batch)
    else:
        func = lambda args: ref(args.image_path, args.ref_path, args.batch)

    if size > 1:
        processes = []
        for rank in range(size):
            args.local_rank = rank
            args.global_rank = rank + args.node_rank * args.num_process_per_node
            p = Process(target=init_processes, args=(func, args))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()
    else:
        print('Single GPU run')
        assert args.global_size == 1 and args.local_rank == 0
        args.global_rank = 0
        init_processes(func, args)
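Given two statistics files (one written by the 'ref' mode and one reference file, each containing 'mu' and 'sigma' arrays), the FID can also be computed directly with calculate_fid_from_inception_stats. A short sketch with hypothetical paths, assuming the function is imported from fid.py:

import numpy as np
from fid import calculate_fid_from_inception_stats

gen = dict(np.load('fid-tmp/generated_stats.npz'))  # hypothetical path
ref = dict(np.load('assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz'))
fid_value = calculate_fid_from_inception_stats(gen['mu'], gen['sigma'], ref['mu'], ref['sigma'])
print(f'FID: {fid_value:.2f}')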
generate.py ADDED
@@ -0,0 +1,91 @@
# MIT License

# Copyright (c) [2023] [Anima-Lab]


from argparse import ArgumentParser
import os
import json
from omegaconf import OmegaConf

import torch
from models.maskdit import Precond_models

from sample import generate_with_net
from utils import parse_float_none, parse_int_list, init_processes


def generate(args):
    rank = args.global_rank
    size = args.global_size
    config = OmegaConf.load(args.config)
    label_dict = json.load(open(args.label_dict, 'r'))
    class_label = label_dict[str(args.class_idx)][1]
    print(f'start sampling class {class_label}...')
    device = torch.device('cuda')
    # setup directory
    sample_dir = os.path.join(args.results_dir, class_label)
    os.makedirs(sample_dir, exist_ok=True)
    args.outdir = sample_dir
    # setup model
    model = Precond_models[config.model.precond](
        img_resolution=config.model.in_size,
        img_channels=config.model.in_channels,
        num_classes=config.model.num_classes,
        model_type=config.model.model_type,
        use_decoder=config.model.use_decoder,
        mae_loss_coef=config.model.mae_loss_coef,
        pad_cls_token=config.model.pad_cls_token,
        use_encoder_feat=config.model.self_cond,
    ).to(device)

    model.eval()
    print(f"{config.model.model_type} (use_decoder: {config.model.use_decoder}) Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f'extras: {model.model.extras}, cls_token: {model.model.cls_token}')

    model = torch.compile(model)
    ckpt = torch.load(args.ckpt_path, map_location=device)
    model.load_state_dict(ckpt['ema'])
    generate_with_net(args, model, device, rank, size)

    print(f'sampling class {class_label} done!')


if __name__ == '__main__':
    parser = ArgumentParser('Sample from a trained model')
    # basic config
    parser.add_argument('--config', type=str, required=True, help='path to config file')
    parser.add_argument('--label_dict', type=str, default='assets/imagenet_label.json', help='path to label dict')
    parser.add_argument("--results_dir", type=str, default="samples", help='path to save samples')
    parser.add_argument('--ckpt_path', type=str, default=None, help='path to ckpt')

    # sampling
    parser.add_argument('--seeds', type=parse_int_list, default='100-131', help='Random seeds (e.g. 1,2,5-10)')
    parser.add_argument('--subdirs', action='store_true', help='Create subdirectory for every 1000 seeds')
    parser.add_argument('--class_idx', type=int, default=None, help='Class label [default: random]')
    parser.add_argument("--cfg_scale", type=parse_float_none, default=None, help='None = no guidance, by default = 4.0')

    parser.add_argument('--num_steps', type=int, default=40, help='Number of sampling steps')
    parser.add_argument('--S_churn', type=int, default=0, help='Stochasticity strength')
    parser.add_argument('--solver', type=str, default=None, choices=['euler', 'heun'], help='Ablate ODE solver')
    parser.add_argument('--discretization', type=str, default=None, choices=['vp', 've', 'iddpm', 'edm'], help='Ablate time step discretization {t_i}')
    parser.add_argument('--schedule', type=str, default=None, choices=['vp', 've', 'linear'], help='Ablate noise schedule sigma(t)')
    parser.add_argument('--scaling', type=str, default=None, choices=['vp', 'none'], help='Ablate signal scaling s(t)')
    parser.add_argument('--pretrained_path', type=str, default='assets/autoencoder_kl.pth', help='Autoencoder ckpt')

    parser.add_argument('--max_batch_size', type=int, default=32, help='Maximum batch size per GPU during sampling')
    parser.add_argument('--num_expected', type=int, default=32, help='Number of images to use')
    parser.add_argument("--global_seed", type=int, default=0)
    parser.add_argument('--fid_batch_size', type=int, default=32, help='Maximum batch size')

    # ddp
    parser.add_argument('--num_proc_node', type=int, default=1, help='The number of nodes in multi node env.')
    parser.add_argument('--num_process_per_node', type=int, default=1, help='number of gpus')
    parser.add_argument('--node_rank', type=int, default=0, help='The index of node.')
    parser.add_argument('--local_rank', type=int, default=0, help='rank of process in the node')
    parser.add_argument('--master_address', type=str, default='localhost', help='address for master')
    args = parser.parse_args()
    args.global_rank = 0
    args.local_rank = 0
    args.global_size = 1
    init_processes(generate, args)
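The --cfg_scale argument is forwarded to the sampler in sample.py and controls classifier-free guidance. As a generic illustration of what the scale does (a sketch of the standard formulation, not necessarily the repository's exact implementation), the conditional and unconditional denoiser outputs are combined like this:

def guided_denoise(denoiser, x, sigma, class_labels, cfg_scale):
    # Classifier-free guidance: push the conditional prediction away from the
    # unconditional one; cfg_scale=None (or 1.0) falls back to plain conditional sampling.
    cond = denoiser(x, sigma, class_labels)
    if cfg_scale is None or cfg_scale == 1.0:
        return cond
    uncond = denoiser(x, sigma, None)  # None stands in for the null/unconditional label
    return uncond + cfg_scale * (cond - uncond)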
licenses/LICENSE_ADM.txt ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 OpenAI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
licenses/LICENSE_DIT.txt ADDED
@@ -0,0 +1,400 @@
| 1 |
+
|
| 2 |
+
Attribution-NonCommercial 4.0 International
|
| 3 |
+
|
| 4 |
+
=======================================================================
|
| 5 |
+
|
| 6 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
| 7 |
+
does not provide legal services or legal advice. Distribution of
|
| 8 |
+
Creative Commons public licenses does not create a lawyer-client or
|
| 9 |
+
other relationship. Creative Commons makes its licenses and related
|
| 10 |
+
information available on an "as-is" basis. Creative Commons gives no
|
| 11 |
+
warranties regarding its licenses, any material licensed under their
|
| 12 |
+
terms and conditions, or any related information. Creative Commons
|
| 13 |
+
disclaims all liability for damages resulting from their use to the
|
| 14 |
+
fullest extent possible.
|
| 15 |
+
|
| 16 |
+
Using Creative Commons Public Licenses
|
| 17 |
+
|
| 18 |
+
Creative Commons public licenses provide a standard set of terms and
|
| 19 |
+
conditions that creators and other rights holders may use to share
|
| 20 |
+
original works of authorship and other material subject to copyright
|
| 21 |
+
and certain other rights specified in the public license below. The
|
| 22 |
+
following considerations are for informational purposes only, are not
|
| 23 |
+
exhaustive, and do not form part of our licenses.
|
| 24 |
+
|
| 25 |
+
Considerations for licensors: Our public licenses are
|
| 26 |
+
intended for use by those authorized to give the public
|
| 27 |
+
permission to use material in ways otherwise restricted by
|
| 28 |
+
copyright and certain other rights. Our licenses are
|
| 29 |
+
irrevocable. Licensors should read and understand the terms
|
| 30 |
+
and conditions of the license they choose before applying it.
|
| 31 |
+
Licensors should also secure all rights necessary before
|
| 32 |
+
applying our licenses so that the public can reuse the
|
| 33 |
+
material as expected. Licensors should clearly mark any
|
| 34 |
+
material not subject to the license. This includes other CC-
|
| 35 |
+
licensed material, or material used under an exception or
|
| 36 |
+
limitation to copyright. More considerations for licensors:
|
| 37 |
+
wiki.creativecommons.org/Considerations_for_licensors
|
| 38 |
+
|
| 39 |
+
Considerations for the public: By using one of our public
|
| 40 |
+
licenses, a licensor grants the public permission to use the
|
| 41 |
+
licensed material under specified terms and conditions. If
|
| 42 |
+
the licensor's permission is not necessary for any reason--for
|
| 43 |
+
example, because of any applicable exception or limitation to
|
| 44 |
+
copyright--then that use is not regulated by the license. Our
|
| 45 |
+
licenses grant only permissions under copyright and certain
|
| 46 |
+
other rights that a licensor has authority to grant. Use of
|
| 47 |
+
the licensed material may still be restricted for other
|
| 48 |
+
reasons, including because others have copyright or other
|
| 49 |
+
rights in the material. A licensor may make special requests,
|
| 50 |
+
such as asking that all changes be marked or described.
|
| 51 |
+
Although not required by our licenses, you are encouraged to
|
| 52 |
+
respect those requests where reasonable. More_considerations
|
| 53 |
+
for the public:
|
| 54 |
+
wiki.creativecommons.org/Considerations_for_licensees
|
| 55 |
+
|
| 56 |
+
=======================================================================
|
| 57 |
+
|
| 58 |
+
Creative Commons Attribution-NonCommercial 4.0 International Public
|
| 59 |
+
License
|
| 60 |
+
|
| 61 |
+
By exercising the Licensed Rights (defined below), You accept and agree
|
| 62 |
+
to be bound by the terms and conditions of this Creative Commons
|
| 63 |
+
Attribution-NonCommercial 4.0 International Public License ("Public
|
| 64 |
+
License"). To the extent this Public License may be interpreted as a
|
| 65 |
+
contract, You are granted the Licensed Rights in consideration of Your
|
| 66 |
+
acceptance of these terms and conditions, and the Licensor grants You
|
| 67 |
+
such rights in consideration of benefits the Licensor receives from
|
| 68 |
+
making the Licensed Material available under these terms and
|
| 69 |
+
conditions.
|
| 70 |
+
|
| 71 |
+
Section 1 -- Definitions.
|
| 72 |
+
|
| 73 |
+
a. Adapted Material means material subject to Copyright and Similar
|
| 74 |
+
Rights that is derived from or based upon the Licensed Material
|
| 75 |
+
and in which the Licensed Material is translated, altered,
|
| 76 |
+
arranged, transformed, or otherwise modified in a manner requiring
|
| 77 |
+
permission under the Copyright and Similar Rights held by the
|
| 78 |
+
Licensor. For purposes of this Public License, where the Licensed
|
| 79 |
+
Material is a musical work, performance, or sound recording,
|
| 80 |
+
Adapted Material is always produced where the Licensed Material is
|
| 81 |
+
synched in timed relation with a moving image.
|
| 82 |
+
|
| 83 |
+
b. Adapter's License means the license You apply to Your Copyright
|
| 84 |
+
and Similar Rights in Your contributions to Adapted Material in
|
| 85 |
+
accordance with the terms and conditions of this Public License.
|
| 86 |
+
|
| 87 |
+
c. Copyright and Similar Rights means copyright and/or similar rights
|
| 88 |
+
closely related to copyright including, without limitation,
|
| 89 |
+
performance, broadcast, sound recording, and Sui Generis Database
|
| 90 |
+
Rights, without regard to how the rights are labeled or
|
| 91 |
+
categorized. For purposes of this Public License, the rights
|
| 92 |
+
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
| 93 |
+
Rights.
|
| 94 |
+
d. Effective Technological Measures means those measures that, in the
|
| 95 |
+
absence of proper authority, may not be circumvented under laws
|
| 96 |
+
fulfilling obligations under Article 11 of the WIPO Copyright
|
| 97 |
+
Treaty adopted on December 20, 1996, and/or similar international
|
| 98 |
+
agreements.
|
| 99 |
+
|
| 100 |
+
e. Exceptions and Limitations means fair use, fair dealing, and/or
|
| 101 |
+
any other exception or limitation to Copyright and Similar Rights
|
| 102 |
+
that applies to Your use of the Licensed Material.
|
| 103 |
+
|
| 104 |
+
f. Licensed Material means the artistic or literary work, database,
|
| 105 |
+
or other material to which the Licensor applied this Public
|
| 106 |
+
License.
|
| 107 |
+
|
| 108 |
+
g. Licensed Rights means the rights granted to You subject to the
|
| 109 |
+
terms and conditions of this Public License, which are limited to
|
| 110 |
+
all Copyright and Similar Rights that apply to Your use of the
|
| 111 |
+
Licensed Material and that the Licensor has authority to license.
|
| 112 |
+
|
| 113 |
+
h. Licensor means the individual(s) or entity(ies) granting rights
|
| 114 |
+
under this Public License.
|
| 115 |
+
|
| 116 |
+
i. NonCommercial means not primarily intended for or directed towards
|
| 117 |
+
commercial advantage or monetary compensation. For purposes of
|
| 118 |
+
this Public License, the exchange of the Licensed Material for
|
| 119 |
+
other material subject to Copyright and Similar Rights by digital
|
| 120 |
+
file-sharing or similar means is NonCommercial provided there is
|
| 121 |
+
no payment of monetary compensation in connection with the
|
| 122 |
+
exchange.
|
| 123 |
+
|
| 124 |
+
j. Share means to provide material to the public by any means or
|
| 125 |
+
process that requires permission under the Licensed Rights, such
|
| 126 |
+
as reproduction, public display, public performance, distribution,
|
| 127 |
+
dissemination, communication, or importation, and to make material
|
| 128 |
+
available to the public including in ways that members of the
|
| 129 |
+
public may access the material from a place and at a time
|
| 130 |
+
individually chosen by them.
|
| 131 |
+
|
| 132 |
+
k. Sui Generis Database Rights means rights other than copyright
|
| 133 |
+
resulting from Directive 96/9/EC of the European Parliament and of
|
| 134 |
+
the Council of 11 March 1996 on the legal protection of databases,
|
| 135 |
+
as amended and/or succeeded, as well as other essentially
|
| 136 |
+
equivalent rights anywhere in the world.
|
| 137 |
+
|
| 138 |
+
l. You means the individual or entity exercising the Licensed Rights
|
| 139 |
+
under this Public License. Your has a corresponding meaning.
|
| 140 |
+
|
| 141 |
+
Section 2 -- Scope.
|
| 142 |
+
|
| 143 |
+
a. License grant.
|
| 144 |
+
|
| 145 |
+
1. Subject to the terms and conditions of this Public License,
|
| 146 |
+
the Licensor hereby grants You a worldwide, royalty-free,
|
| 147 |
+
non-sublicensable, non-exclusive, irrevocable license to
|
| 148 |
+
exercise the Licensed Rights in the Licensed Material to:
|
| 149 |
+
|
| 150 |
+
a. reproduce and Share the Licensed Material, in whole or
|
| 151 |
+
in part, for NonCommercial purposes only; and
|
| 152 |
+
|
| 153 |
+
b. produce, reproduce, and Share Adapted Material for
|
| 154 |
+
NonCommercial purposes only.
|
| 155 |
+
|
| 156 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where
|
| 157 |
+
Exceptions and Limitations apply to Your use, this Public
|
| 158 |
+
License does not apply, and You do not need to comply with
|
| 159 |
+
its terms and conditions.
|
| 160 |
+
|
| 161 |
+
3. Term. The term of this Public License is specified in Section
|
| 162 |
+
6(a).
|
| 163 |
+
|
| 164 |
+
4. Media and formats; technical modifications allowed. The
|
| 165 |
+
Licensor authorizes You to exercise the Licensed Rights in
|
| 166 |
+
all media and formats whether now known or hereafter created,
|
| 167 |
+
and to make technical modifications necessary to do so. The
|
| 168 |
+
Licensor waives and/or agrees not to assert any right or
|
| 169 |
+
authority to forbid You from making technical modifications
|
| 170 |
+
necessary to exercise the Licensed Rights, including
|
| 171 |
+
technical modifications necessary to circumvent Effective
|
| 172 |
+
Technological Measures. For purposes of this Public License,
|
| 173 |
+
simply making modifications authorized by this Section 2(a)
|
| 174 |
+
(4) never produces Adapted Material.
|
| 175 |
+
|
| 176 |
+
5. Downstream recipients.
|
| 177 |
+
|
| 178 |
+
a. Offer from the Licensor -- Licensed Material. Every
|
| 179 |
+
recipient of the Licensed Material automatically
|
| 180 |
+
receives an offer from the Licensor to exercise the
|
| 181 |
+
Licensed Rights under the terms and conditions of this
|
| 182 |
+
Public License.
|
| 183 |
+
|
| 184 |
+
b. No downstream restrictions. You may not offer or impose
|
| 185 |
+
any additional or different terms or conditions on, or
|
| 186 |
+
apply any Effective Technological Measures to, the
|
| 187 |
+
Licensed Material if doing so restricts exercise of the
|
| 188 |
+
Licensed Rights by any recipient of the Licensed
|
| 189 |
+
Material.
|
| 190 |
+
|
| 191 |
+
6. No endorsement. Nothing in this Public License constitutes or
|
| 192 |
+
may be construed as permission to assert or imply that You
|
| 193 |
+
are, or that Your use of the Licensed Material is, connected
|
| 194 |
+
with, or sponsored, endorsed, or granted official status by,
|
| 195 |
+
the Licensor or others designated to receive attribution as
|
| 196 |
+
provided in Section 3(a)(1)(A)(i).
|
| 197 |
+
|
| 198 |
+
b. Other rights.
|
| 199 |
+
|
| 200 |
+
1. Moral rights, such as the right of integrity, are not
|
| 201 |
+
licensed under this Public License, nor are publicity,
|
| 202 |
+
privacy, and/or other similar personality rights; however, to
|
| 203 |
+
the extent possible, the Licensor waives and/or agrees not to
|
| 204 |
+
assert any such rights held by the Licensor to the limited
|
| 205 |
+
extent necessary to allow You to exercise the Licensed
|
| 206 |
+
Rights, but not otherwise.
|
| 207 |
+
|
| 208 |
+
2. Patent and trademark rights are not licensed under this
|
| 209 |
+
Public License.
|
| 210 |
+
|
| 211 |
+
3. To the extent possible, the Licensor waives any right to
|
| 212 |
+
collect royalties from You for the exercise of the Licensed
|
| 213 |
+
Rights, whether directly or through a collecting society
|
| 214 |
+
under any voluntary or waivable statutory or compulsory
|
| 215 |
+
licensing scheme. In all other cases the Licensor expressly
|
| 216 |
+
reserves any right to collect such royalties, including when
|
| 217 |
+
the Licensed Material is used other than for NonCommercial
|
| 218 |
+
purposes.
|
| 219 |
+
|
| 220 |
+
Section 3 -- License Conditions.
|
| 221 |
+
|
| 222 |
+
Your exercise of the Licensed Rights is expressly made subject to the
|
| 223 |
+
following conditions.
|
| 224 |
+
|
| 225 |
+
a. Attribution.
|
| 226 |
+
|
| 227 |
+
1. If You Share the Licensed Material (including in modified
|
| 228 |
+
form), You must:
|
| 229 |
+
|
| 230 |
+
a. retain the following if it is supplied by the Licensor
|
| 231 |
+
with the Licensed Material:
|
| 232 |
+
|
| 233 |
+
i. identification of the creator(s) of the Licensed
|
| 234 |
+
Material and any others designated to receive
|
| 235 |
+
attribution, in any reasonable manner requested by
|
| 236 |
+
the Licensor (including by pseudonym if
|
| 237 |
+
designated);
|
| 238 |
+
|
| 239 |
+
ii. a copyright notice;
|
| 240 |
+
|
| 241 |
+
iii. a notice that refers to this Public License;
|
| 242 |
+
|
| 243 |
+
iv. a notice that refers to the disclaimer of
|
| 244 |
+
warranties;
|
| 245 |
+
|
| 246 |
+
v. a URI or hyperlink to the Licensed Material to the
|
| 247 |
+
extent reasonably practicable;
|
| 248 |
+
|
| 249 |
+
b. indicate if You modified the Licensed Material and
|
| 250 |
+
retain an indication of any previous modifications; and
|
| 251 |
+
|
| 252 |
+
c. indicate the Licensed Material is licensed under this
|
| 253 |
+
Public License, and include the text of, or the URI or
|
| 254 |
+
hyperlink to, this Public License.
|
| 255 |
+
|
| 256 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any
|
| 257 |
+
reasonable manner based on the medium, means, and context in
|
| 258 |
+
which You Share the Licensed Material. For example, it may be
|
| 259 |
+
reasonable to satisfy the conditions by providing a URI or
|
| 260 |
+
hyperlink to a resource that includes the required
|
| 261 |
+
information.
|
| 262 |
+
|
| 263 |
+
3. If requested by the Licensor, You must remove any of the
|
| 264 |
+
information required by Section 3(a)(1)(A) to the extent
|
| 265 |
+
reasonably practicable.
|
| 266 |
+
|
| 267 |
+
4. If You Share Adapted Material You produce, the Adapter's
|
| 268 |
+
License You apply must not prevent recipients of the Adapted
|
| 269 |
+
Material from complying with this Public License.
|
| 270 |
+
|
| 271 |
+
Section 4 -- Sui Generis Database Rights.
|
| 272 |
+
|
| 273 |
+
Where the Licensed Rights include Sui Generis Database Rights that
|
| 274 |
+
apply to Your use of the Licensed Material:
|
| 275 |
+
|
| 276 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
| 277 |
+
to extract, reuse, reproduce, and Share all or a substantial
|
| 278 |
+
portion of the contents of the database for NonCommercial purposes
|
| 279 |
+
only;
|
| 280 |
+
|
| 281 |
+
b. if You include all or a substantial portion of the database
|
| 282 |
+
contents in a database in which You have Sui Generis Database
|
| 283 |
+
Rights, then the database in which You have Sui Generis Database
|
| 284 |
+
Rights (but not its individual contents) is Adapted Material; and
|
| 285 |
+
|
| 286 |
+
c. You must comply with the conditions in Section 3(a) if You Share
|
| 287 |
+
all or a substantial portion of the contents of the database.
|
| 288 |
+
|
| 289 |
+
For the avoidance of doubt, this Section 4 supplements and does not
|
| 290 |
+
replace Your obligations under this Public License where the Licensed
|
| 291 |
+
Rights include other Copyright and Similar Rights.
|
| 292 |
+
|
| 293 |
+
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
| 294 |
+
|
| 295 |
+
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
| 296 |
+
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
| 297 |
+
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
| 298 |
+
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
| 299 |
+
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
| 300 |
+
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
| 301 |
+
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
| 302 |
+
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
| 303 |
+
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
| 304 |
+
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
| 305 |
+
|
| 306 |
+
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
| 307 |
+
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
| 308 |
+
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
| 309 |
+
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
| 310 |
+
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
| 311 |
+
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
| 312 |
+
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
| 313 |
+
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
| 314 |
+
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
| 315 |
+
|
| 316 |
+
c. The disclaimer of warranties and limitation of liability provided
|
| 317 |
+
above shall be interpreted in a manner that, to the extent
possible, most closely approximates an absolute disclaimer and
waiver of all liability.

Section 6 -- Term and Termination.

a. This Public License applies for the term of the Copyright and
Similar Rights licensed here. However, if You fail to comply with
this Public License, then Your rights under this Public License
terminate automatically.

b. Where Your right to use the Licensed Material has terminated under
Section 6(a), it reinstates:

1. automatically as of the date the violation is cured, provided
it is cured within 30 days of Your discovery of the
violation; or

2. upon express reinstatement by the Licensor.

For the avoidance of doubt, this Section 6(b) does not affect any
right the Licensor may have to seek remedies for Your violations
of this Public License.

c. For the avoidance of doubt, the Licensor may also offer the
Licensed Material under separate terms or conditions or stop
distributing the Licensed Material at any time; however, doing so
will not terminate this Public License.

d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
License.

Section 7 -- Other Terms and Conditions.

a. The Licensor shall not be bound by any additional or different
terms or conditions communicated by You unless expressly agreed.

b. Any arrangements, understandings, or agreements regarding the
Licensed Material not stated herein are separate from and
independent of the terms and conditions of this Public License.

Section 8 -- Interpretation.

a. For the avoidance of doubt, this Public License does not, and
shall not be interpreted to, reduce, limit, restrict, or impose
conditions on any use of the Licensed Material that could lawfully
be made without permission under this Public License.

b. To the extent possible, if any provision of this Public License is
deemed unenforceable, it shall be automatically reformed to the
minimum extent necessary to make it enforceable. If the provision
cannot be reformed, it shall be severed from this Public License
without affecting the enforceability of the remaining terms and
conditions.

c. No term or condition of this Public License will be waived and no
failure to comply consented to unless expressly agreed to by the
Licensor.

d. Nothing in this Public License constitutes or may be interpreted
as a limitation upon, or waiver of, any privileges and immunities
that apply to the Licensor or You, including from the legal
processes of any jurisdiction or authority.

=======================================================================

Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the "Licensor." The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.

Creative Commons may be contacted at creativecommons.org.

licenses/LICENSE_EDM.txt
ADDED
@@ -0,0 +1,439 @@
| 1 |
+
Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
|
| 3 |
+
Attribution-NonCommercial-ShareAlike 4.0 International
|
| 4 |
+
|
| 5 |
+
=======================================================================
|
| 6 |
+
|
| 7 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
| 8 |
+
does not provide legal services or legal advice. Distribution of
|
| 9 |
+
Creative Commons public licenses does not create a lawyer-client or
|
| 10 |
+
other relationship. Creative Commons makes its licenses and related
|
| 11 |
+
information available on an "as-is" basis. Creative Commons gives no
|
| 12 |
+
warranties regarding its licenses, any material licensed under their
|
| 13 |
+
terms and conditions, or any related information. Creative Commons
|
| 14 |
+
disclaims all liability for damages resulting from their use to the
|
| 15 |
+
fullest extent possible.
|
| 16 |
+
|
| 17 |
+
Using Creative Commons Public Licenses
|
| 18 |
+
|
| 19 |
+
Creative Commons public licenses provide a standard set of terms and
|
| 20 |
+
conditions that creators and other rights holders may use to share
|
| 21 |
+
original works of authorship and other material subject to copyright
|
| 22 |
+
and certain other rights specified in the public license below. The
|
| 23 |
+
following considerations are for informational purposes only, are not
|
| 24 |
+
exhaustive, and do not form part of our licenses.
|
| 25 |
+
|
| 26 |
+
Considerations for licensors: Our public licenses are
|
| 27 |
+
intended for use by those authorized to give the public
|
| 28 |
+
permission to use material in ways otherwise restricted by
|
| 29 |
+
copyright and certain other rights. Our licenses are
|
| 30 |
+
irrevocable. Licensors should read and understand the terms
|
| 31 |
+
and conditions of the license they choose before applying it.
|
| 32 |
+
Licensors should also secure all rights necessary before
|
| 33 |
+
applying our licenses so that the public can reuse the
|
| 34 |
+
material as expected. Licensors should clearly mark any
|
| 35 |
+
material not subject to the license. This includes other CC-
|
| 36 |
+
licensed material, or material used under an exception or
|
| 37 |
+
limitation to copyright. More considerations for licensors:
|
| 38 |
+
wiki.creativecommons.org/Considerations_for_licensors
|
| 39 |
+
|
| 40 |
+
Considerations for the public: By using one of our public
|
| 41 |
+
licenses, a licensor grants the public permission to use the
|
| 42 |
+
licensed material under specified terms and conditions. If
|
| 43 |
+
the licensor's permission is not necessary for any reason--for
|
| 44 |
+
example, because of any applicable exception or limitation to
|
| 45 |
+
copyright--then that use is not regulated by the license. Our
|
| 46 |
+
licenses grant only permissions under copyright and certain
|
| 47 |
+
other rights that a licensor has authority to grant. Use of
|
| 48 |
+
the licensed material may still be restricted for other
|
| 49 |
+
reasons, including because others have copyright or other
|
| 50 |
+
rights in the material. A licensor may make special requests,
|
| 51 |
+
such as asking that all changes be marked or described.
|
| 52 |
+
Although not required by our licenses, you are encouraged to
|
| 53 |
+
respect those requests where reasonable. More considerations
|
| 54 |
+
for the public:
|
| 55 |
+
wiki.creativecommons.org/Considerations_for_licensees
|
| 56 |
+
|
| 57 |
+
=======================================================================
|
| 58 |
+
|
| 59 |
+
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
|
| 60 |
+
Public License
|
| 61 |
+
|
| 62 |
+
By exercising the Licensed Rights (defined below), You accept and agree
|
| 63 |
+
to be bound by the terms and conditions of this Creative Commons
|
| 64 |
+
Attribution-NonCommercial-ShareAlike 4.0 International Public License
|
| 65 |
+
("Public License"). To the extent this Public License may be
|
| 66 |
+
interpreted as a contract, You are granted the Licensed Rights in
|
| 67 |
+
consideration of Your acceptance of these terms and conditions, and the
|
| 68 |
+
Licensor grants You such rights in consideration of benefits the
|
| 69 |
+
Licensor receives from making the Licensed Material available under
|
| 70 |
+
these terms and conditions.
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
Section 1 -- Definitions.
|
| 74 |
+
|
| 75 |
+
a. Adapted Material means material subject to Copyright and Similar
|
| 76 |
+
Rights that is derived from or based upon the Licensed Material
|
| 77 |
+
and in which the Licensed Material is translated, altered,
|
| 78 |
+
arranged, transformed, or otherwise modified in a manner requiring
|
| 79 |
+
permission under the Copyright and Similar Rights held by the
|
| 80 |
+
Licensor. For purposes of this Public License, where the Licensed
|
| 81 |
+
Material is a musical work, performance, or sound recording,
|
| 82 |
+
Adapted Material is always produced where the Licensed Material is
|
| 83 |
+
synched in timed relation with a moving image.
|
| 84 |
+
|
| 85 |
+
b. Adapter's License means the license You apply to Your Copyright
|
| 86 |
+
and Similar Rights in Your contributions to Adapted Material in
|
| 87 |
+
accordance with the terms and conditions of this Public License.
|
| 88 |
+
|
| 89 |
+
c. BY-NC-SA Compatible License means a license listed at
|
| 90 |
+
creativecommons.org/compatiblelicenses, approved by Creative
|
| 91 |
+
Commons as essentially the equivalent of this Public License.
|
| 92 |
+
|
| 93 |
+
d. Copyright and Similar Rights means copyright and/or similar rights
|
| 94 |
+
closely related to copyright including, without limitation,
|
| 95 |
+
performance, broadcast, sound recording, and Sui Generis Database
|
| 96 |
+
Rights, without regard to how the rights are labeled or
|
| 97 |
+
categorized. For purposes of this Public License, the rights
|
| 98 |
+
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
| 99 |
+
Rights.
|
| 100 |
+
|
| 101 |
+
e. Effective Technological Measures means those measures that, in the
|
| 102 |
+
absence of proper authority, may not be circumvented under laws
|
| 103 |
+
fulfilling obligations under Article 11 of the WIPO Copyright
|
| 104 |
+
Treaty adopted on December 20, 1996, and/or similar international
|
| 105 |
+
agreements.
|
| 106 |
+
|
| 107 |
+
f. Exceptions and Limitations means fair use, fair dealing, and/or
|
| 108 |
+
any other exception or limitation to Copyright and Similar Rights
|
| 109 |
+
that applies to Your use of the Licensed Material.
|
| 110 |
+
|
| 111 |
+
g. License Elements means the license attributes listed in the name
|
| 112 |
+
of a Creative Commons Public License. The License Elements of this
|
| 113 |
+
Public License are Attribution, NonCommercial, and ShareAlike.
|
| 114 |
+
|
| 115 |
+
h. Licensed Material means the artistic or literary work, database,
|
| 116 |
+
or other material to which the Licensor applied this Public
|
| 117 |
+
License.
|
| 118 |
+
|
| 119 |
+
i. Licensed Rights means the rights granted to You subject to the
|
| 120 |
+
terms and conditions of this Public License, which are limited to
|
| 121 |
+
all Copyright and Similar Rights that apply to Your use of the
|
| 122 |
+
Licensed Material and that the Licensor has authority to license.
|
| 123 |
+
|
| 124 |
+
j. Licensor means the individual(s) or entity(ies) granting rights
|
| 125 |
+
under this Public License.
|
| 126 |
+
|
| 127 |
+
k. NonCommercial means not primarily intended for or directed towards
|
| 128 |
+
commercial advantage or monetary compensation. For purposes of
|
| 129 |
+
this Public License, the exchange of the Licensed Material for
|
| 130 |
+
other material subject to Copyright and Similar Rights by digital
|
| 131 |
+
file-sharing or similar means is NonCommercial provided there is
|
| 132 |
+
no payment of monetary compensation in connection with the
|
| 133 |
+
exchange.
|
| 134 |
+
|
| 135 |
+
l. Share means to provide material to the public by any means or
|
| 136 |
+
process that requires permission under the Licensed Rights, such
|
| 137 |
+
as reproduction, public display, public performance, distribution,
|
| 138 |
+
dissemination, communication, or importation, and to make material
|
| 139 |
+
available to the public including in ways that members of the
|
| 140 |
+
public may access the material from a place and at a time
|
| 141 |
+
individually chosen by them.
|
| 142 |
+
|
| 143 |
+
m. Sui Generis Database Rights means rights other than copyright
|
| 144 |
+
resulting from Directive 96/9/EC of the European Parliament and of
|
| 145 |
+
the Council of 11 March 1996 on the legal protection of databases,
|
| 146 |
+
as amended and/or succeeded, as well as other essentially
|
| 147 |
+
equivalent rights anywhere in the world.
|
| 148 |
+
|
| 149 |
+
n. You means the individual or entity exercising the Licensed Rights
|
| 150 |
+
under this Public License. Your has a corresponding meaning.
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
Section 2 -- Scope.
|
| 154 |
+
|
| 155 |
+
a. License grant.
|
| 156 |
+
|
| 157 |
+
1. Subject to the terms and conditions of this Public License,
|
| 158 |
+
the Licensor hereby grants You a worldwide, royalty-free,
|
| 159 |
+
non-sublicensable, non-exclusive, irrevocable license to
|
| 160 |
+
exercise the Licensed Rights in the Licensed Material to:
|
| 161 |
+
|
| 162 |
+
a. reproduce and Share the Licensed Material, in whole or
|
| 163 |
+
in part, for NonCommercial purposes only; and
|
| 164 |
+
|
| 165 |
+
b. produce, reproduce, and Share Adapted Material for
|
| 166 |
+
NonCommercial purposes only.
|
| 167 |
+
|
| 168 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where
|
| 169 |
+
Exceptions and Limitations apply to Your use, this Public
|
| 170 |
+
License does not apply, and You do not need to comply with
|
| 171 |
+
its terms and conditions.
|
| 172 |
+
|
| 173 |
+
3. Term. The term of this Public License is specified in Section
|
| 174 |
+
6(a).
|
| 175 |
+
|
| 176 |
+
4. Media and formats; technical modifications allowed. The
|
| 177 |
+
Licensor authorizes You to exercise the Licensed Rights in
|
| 178 |
+
all media and formats whether now known or hereafter created,
|
| 179 |
+
and to make technical modifications necessary to do so. The
|
| 180 |
+
Licensor waives and/or agrees not to assert any right or
|
| 181 |
+
authority to forbid You from making technical modifications
|
| 182 |
+
necessary to exercise the Licensed Rights, including
|
| 183 |
+
technical modifications necessary to circumvent Effective
|
| 184 |
+
Technological Measures. For purposes of this Public License,
|
| 185 |
+
simply making modifications authorized by this Section 2(a)
|
| 186 |
+
(4) never produces Adapted Material.
|
| 187 |
+
|
| 188 |
+
5. Downstream recipients.
|
| 189 |
+
|
| 190 |
+
a. Offer from the Licensor -- Licensed Material. Every
|
| 191 |
+
recipient of the Licensed Material automatically
|
| 192 |
+
receives an offer from the Licensor to exercise the
|
| 193 |
+
Licensed Rights under the terms and conditions of this
|
| 194 |
+
Public License.
|
| 195 |
+
|
| 196 |
+
b. Additional offer from the Licensor -- Adapted Material.
|
| 197 |
+
Every recipient of Adapted Material from You
|
| 198 |
+
automatically receives an offer from the Licensor to
|
| 199 |
+
exercise the Licensed Rights in the Adapted Material
|
| 200 |
+
under the conditions of the Adapter's License You apply.
|
| 201 |
+
|
| 202 |
+
c. No downstream restrictions. You may not offer or impose
|
| 203 |
+
any additional or different terms or conditions on, or
|
| 204 |
+
apply any Effective Technological Measures to, the
|
| 205 |
+
Licensed Material if doing so restricts exercise of the
|
| 206 |
+
Licensed Rights by any recipient of the Licensed
|
| 207 |
+
Material.
|
| 208 |
+
|
| 209 |
+
6. No endorsement. Nothing in this Public License constitutes or
|
| 210 |
+
may be construed as permission to assert or imply that You
|
| 211 |
+
are, or that Your use of the Licensed Material is, connected
|
| 212 |
+
with, or sponsored, endorsed, or granted official status by,
|
| 213 |
+
the Licensor or others designated to receive attribution as
|
| 214 |
+
provided in Section 3(a)(1)(A)(i).
|
| 215 |
+
|
| 216 |
+
b. Other rights.
|
| 217 |
+
|
| 218 |
+
1. Moral rights, such as the right of integrity, are not
|
| 219 |
+
licensed under this Public License, nor are publicity,
|
| 220 |
+
privacy, and/or other similar personality rights; however, to
|
| 221 |
+
the extent possible, the Licensor waives and/or agrees not to
|
| 222 |
+
assert any such rights held by the Licensor to the limited
|
| 223 |
+
extent necessary to allow You to exercise the Licensed
|
| 224 |
+
Rights, but not otherwise.
|
| 225 |
+
|
| 226 |
+
2. Patent and trademark rights are not licensed under this
|
| 227 |
+
Public License.
|
| 228 |
+
|
| 229 |
+
3. To the extent possible, the Licensor waives any right to
|
| 230 |
+
collect royalties from You for the exercise of the Licensed
|
| 231 |
+
Rights, whether directly or through a collecting society
|
| 232 |
+
under any voluntary or waivable statutory or compulsory
|
| 233 |
+
licensing scheme. In all other cases the Licensor expressly
|
| 234 |
+
reserves any right to collect such royalties, including when
|
| 235 |
+
the Licensed Material is used other than for NonCommercial
|
| 236 |
+
purposes.
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
Section 3 -- License Conditions.
|
| 240 |
+
|
| 241 |
+
Your exercise of the Licensed Rights is expressly made subject to the
|
| 242 |
+
following conditions.
|
| 243 |
+
|
| 244 |
+
a. Attribution.
|
| 245 |
+
|
| 246 |
+
1. If You Share the Licensed Material (including in modified
|
| 247 |
+
form), You must:
|
| 248 |
+
|
| 249 |
+
a. retain the following if it is supplied by the Licensor
|
| 250 |
+
with the Licensed Material:
|
| 251 |
+
|
| 252 |
+
i. identification of the creator(s) of the Licensed
|
| 253 |
+
Material and any others designated to receive
|
| 254 |
+
attribution, in any reasonable manner requested by
|
| 255 |
+
the Licensor (including by pseudonym if
|
| 256 |
+
designated);
|
| 257 |
+
|
| 258 |
+
ii. a copyright notice;
|
| 259 |
+
|
| 260 |
+
iii. a notice that refers to this Public License;
|
| 261 |
+
|
| 262 |
+
iv. a notice that refers to the disclaimer of
|
| 263 |
+
warranties;
|
| 264 |
+
|
| 265 |
+
v. a URI or hyperlink to the Licensed Material to the
|
| 266 |
+
extent reasonably practicable;
|
| 267 |
+
|
| 268 |
+
b. indicate if You modified the Licensed Material and
|
| 269 |
+
retain an indication of any previous modifications; and
|
| 270 |
+
|
| 271 |
+
c. indicate the Licensed Material is licensed under this
|
| 272 |
+
Public License, and include the text of, or the URI or
|
| 273 |
+
hyperlink to, this Public License.
|
| 274 |
+
|
| 275 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any
|
| 276 |
+
reasonable manner based on the medium, means, and context in
|
| 277 |
+
which You Share the Licensed Material. For example, it may be
|
| 278 |
+
reasonable to satisfy the conditions by providing a URI or
|
| 279 |
+
hyperlink to a resource that includes the required
|
| 280 |
+
information.
|
| 281 |
+
3. If requested by the Licensor, You must remove any of the
|
| 282 |
+
information required by Section 3(a)(1)(A) to the extent
|
| 283 |
+
reasonably practicable.
|
| 284 |
+
|
| 285 |
+
b. ShareAlike.
|
| 286 |
+
|
| 287 |
+
In addition to the conditions in Section 3(a), if You Share
|
| 288 |
+
Adapted Material You produce, the following conditions also apply.
|
| 289 |
+
|
| 290 |
+
1. The Adapter's License You apply must be a Creative Commons
|
| 291 |
+
license with the same License Elements, this version or
|
| 292 |
+
later, or a BY-NC-SA Compatible License.
|
| 293 |
+
|
| 294 |
+
2. You must include the text of, or the URI or hyperlink to, the
|
| 295 |
+
Adapter's License You apply. You may satisfy this condition
|
| 296 |
+
in any reasonable manner based on the medium, means, and
|
| 297 |
+
context in which You Share Adapted Material.
|
| 298 |
+
|
| 299 |
+
3. You may not offer or impose any additional or different terms
|
| 300 |
+
or conditions on, or apply any Effective Technological
|
| 301 |
+
Measures to, Adapted Material that restrict exercise of the
|
| 302 |
+
rights granted under the Adapter's License You apply.
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
Section 4 -- Sui Generis Database Rights.
|
| 306 |
+
|
| 307 |
+
Where the Licensed Rights include Sui Generis Database Rights that
|
| 308 |
+
apply to Your use of the Licensed Material:
|
| 309 |
+
|
| 310 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
| 311 |
+
to extract, reuse, reproduce, and Share all or a substantial
|
| 312 |
+
portion of the contents of the database for NonCommercial purposes
|
| 313 |
+
only;
|
| 314 |
+
|
| 315 |
+
b. if You include all or a substantial portion of the database
|
| 316 |
+
contents in a database in which You have Sui Generis Database
|
| 317 |
+
Rights, then the database in which You have Sui Generis Database
|
| 318 |
+
Rights (but not its individual contents) is Adapted Material,
|
| 319 |
+
including for purposes of Section 3(b); and
|
| 320 |
+
|
| 321 |
+
c. You must comply with the conditions in Section 3(a) if You Share
|
| 322 |
+
all or a substantial portion of the contents of the database.
|
| 323 |
+
|
| 324 |
+
For the avoidance of doubt, this Section 4 supplements and does not
|
| 325 |
+
replace Your obligations under this Public License where the Licensed
|
| 326 |
+
Rights include other Copyright and Similar Rights.
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
| 330 |
+
|
| 331 |
+
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
| 332 |
+
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
| 333 |
+
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
| 334 |
+
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
| 335 |
+
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
| 336 |
+
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
| 337 |
+
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
| 338 |
+
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
| 339 |
+
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
| 340 |
+
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
| 341 |
+
|
| 342 |
+
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
| 343 |
+
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
| 344 |
+
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
| 345 |
+
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
| 346 |
+
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
| 347 |
+
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
| 348 |
+
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
| 349 |
+
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
| 350 |
+
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
| 351 |
+
|
| 352 |
+
c. The disclaimer of warranties and limitation of liability provided
|
| 353 |
+
above shall be interpreted in a manner that, to the extent
|
| 354 |
+
possible, most closely approximates an absolute disclaimer and
|
| 355 |
+
waiver of all liability.
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
Section 6 -- Term and Termination.
|
| 359 |
+
|
| 360 |
+
a. This Public License applies for the term of the Copyright and
|
| 361 |
+
Similar Rights licensed here. However, if You fail to comply with
|
| 362 |
+
this Public License, then Your rights under this Public License
|
| 363 |
+
terminate automatically.
|
| 364 |
+
|
| 365 |
+
b. Where Your right to use the Licensed Material has terminated under
|
| 366 |
+
Section 6(a), it reinstates:
|
| 367 |
+
|
| 368 |
+
1. automatically as of the date the violation is cured, provided
|
| 369 |
+
it is cured within 30 days of Your discovery of the
|
| 370 |
+
violation; or
|
| 371 |
+
|
| 372 |
+
2. upon express reinstatement by the Licensor.
|
| 373 |
+
|
| 374 |
+
For the avoidance of doubt, this Section 6(b) does not affect any
|
| 375 |
+
right the Licensor may have to seek remedies for Your violations
|
| 376 |
+
of this Public License.
|
| 377 |
+
|
| 378 |
+
c. For the avoidance of doubt, the Licensor may also offer the
|
| 379 |
+
Licensed Material under separate terms or conditions or stop
|
| 380 |
+
distributing the Licensed Material at any time; however, doing so
|
| 381 |
+
will not terminate this Public License.
|
| 382 |
+
|
| 383 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
| 384 |
+
License.
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
Section 7 -- Other Terms and Conditions.
|
| 388 |
+
|
| 389 |
+
a. The Licensor shall not be bound by any additional or different
|
| 390 |
+
terms or conditions communicated by You unless expressly agreed.
|
| 391 |
+
|
| 392 |
+
b. Any arrangements, understandings, or agreements regarding the
|
| 393 |
+
Licensed Material not stated herein are separate from and
|
| 394 |
+
independent of the terms and conditions of this Public License.
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
Section 8 -- Interpretation.
|
| 398 |
+
|
| 399 |
+
a. For the avoidance of doubt, this Public License does not, and
|
| 400 |
+
shall not be interpreted to, reduce, limit, restrict, or impose
|
| 401 |
+
conditions on any use of the Licensed Material that could lawfully
|
| 402 |
+
be made without permission under this Public License.
|
| 403 |
+
|
| 404 |
+
b. To the extent possible, if any provision of this Public License is
|
| 405 |
+
deemed unenforceable, it shall be automatically reformed to the
|
| 406 |
+
minimum extent necessary to make it enforceable. If the provision
|
| 407 |
+
cannot be reformed, it shall be severed from this Public License
|
| 408 |
+
without affecting the enforceability of the remaining terms and
|
| 409 |
+
conditions.
|
| 410 |
+
|
| 411 |
+
c. No term or condition of this Public License will be waived and no
|
| 412 |
+
failure to comply consented to unless expressly agreed to by the
|
| 413 |
+
Licensor.
|
| 414 |
+
|
| 415 |
+
d. Nothing in this Public License constitutes or may be interpreted
|
| 416 |
+
as a limitation upon, or waiver of, any privileges and immunities
|
| 417 |
+
that apply to the Licensor or You, including from the legal
|
| 418 |
+
processes of any jurisdiction or authority.
|
| 419 |
+
|
| 420 |
+
=======================================================================
|
| 421 |
+
|
| 422 |
+
Creative Commons is not a party to its public
|
| 423 |
+
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
| 424 |
+
its public licenses to material it publishes and in those instances
|
| 425 |
+
will be considered the "Licensor." The text of the Creative Commons
|
| 426 |
+
public licenses is dedicated to the public domain under the CC0 Public
|
| 427 |
+
Domain Dedication. Except for the limited purpose of indicating that
|
| 428 |
+
material is shared under a Creative Commons public license or as
|
| 429 |
+
otherwise permitted by the Creative Commons policies published at
|
| 430 |
+
creativecommons.org/policies, Creative Commons does not authorize the
|
| 431 |
+
use of the trademark "Creative Commons" or any other trademark or logo
|
| 432 |
+
of Creative Commons without its prior written consent including,
|
| 433 |
+
without limitation, in connection with any unauthorized modifications
|
| 434 |
+
to any of its public licenses or any other arrangements,
|
| 435 |
+
understandings, or agreements concerning use of licensed material. For
|
| 436 |
+
the avoidance of doubt, this paragraph does not form part of the
|
| 437 |
+
public licenses.
|
| 438 |
+
|
| 439 |
+
Creative Commons may be contacted at creativecommons.org.
licenses/LICENSE_UVIT.txt
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Fan Bao

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
lmdb2wds.py
ADDED
@@ -0,0 +1,39 @@
'''
This file implements the direct conversion from the latent ImageNet dataset to WebDataset shards.
'''
import os
from argparse import ArgumentParser
from tqdm import tqdm
import numpy as np
import pickle

import webdataset as wds

from train_utils.datasets import ImageNetLatentDataset


def convert2wds(args):
    os.makedirs(args.outdir, exist_ok=True)
    wds_path = os.path.join(args.outdir, f'latent_imagenet_512_{args.split}-%04d.tar')
    dataset = ImageNetLatentDataset(args.datadir, resolution=args.resolution, num_channels=args.num_channels, split=args.split)

    with wds.ShardWriter(wds_path, maxcount=args.maxcount, maxsize=args.maxsize) as sink:
        for i in tqdm(range(len(dataset)), dynamic_ncols=True):
            if i % args.maxcount == 0:
                print(f'writing to the {i // args.maxcount}th shard')
            img, label = dataset[i]  # latent of shape (C, H, W)
            label = np.argmax(label)  # one-hot label -> integer class index
            sink.write({'__key__': f'{i:07d}', 'latent': pickle.dumps(img), 'cls': label})


if __name__ == "__main__":
    parser = ArgumentParser('Convert the latent imagenet dataset to WebDataset')
    parser.add_argument('--maxcount', type=int, default=10010, help='max number of entries per shard')
    parser.add_argument('--maxsize', type=int, default=10 ** 10, help='max size per shard')
    parser.add_argument('--outdir', type=str, default='latent_imagenet_wds', help='path to save the converted dataset')
    parser.add_argument('--datadir', type=str, default='latent_imagenet', help='path to the latent imagenet dataset')
    parser.add_argument('--resolution', type=int, default=64, help='image resolution')
    parser.add_argument('--num_channels', type=int, default=8, help='number of image channels')
    parser.add_argument('--split', type=str, default='train', help='split of the dataset')
    args = parser.parse_args()
    convert2wds(args)
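A minimal sketch of how the shards written by convert2wds might be read back, assuming the shard naming used above; the shard range in `shard_pattern` and the helper name `decode_sample` are illustrative, not part of the repository:

# Illustrative sketch (not part of the repository): iterate over the shards
# produced by convert2wds and recover (latent, class) pairs.
import pickle
import webdataset as wds

shard_pattern = 'latent_imagenet_wds/latent_imagenet_512_train-{0000..0127}.tar'  # assumed shard range

def decode_sample(sample):
    latent = pickle.loads(sample['latent'])   # numpy array of shape (C, H, W)
    label = int(sample['cls'])                # integer class index stored as text
    return latent, label

dataset = wds.WebDataset(shard_pattern).shuffle(1000).map(decode_sample)

for latent, label in dataset:
    print(latent.shape, label)
    break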
models/maskdit.py
ADDED
@@ -0,0 +1,781 @@
# MIT License

# Copyright (c) [2023] [Anima-Lab]

# This code is adapted from https://github.com/facebookresearch/mae/blob/main/models_mae.py
# and https://github.com/facebookresearch/DiT/blob/main/models.py.
# The original code is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License,
# which can be found at licenses/LICENSE_MAE.txt and licenses/LICENSE_DIT.txt.


import torch
import torch.nn as nn
import numpy as np
import math
from functools import partial
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

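# --- Illustrative sketch (not part of the original file) -----------------------
# adaLN modulation broadcasts one (shift, scale) pair per sample over all of its
# tokens. The helper name and sizes below are hypothetical; 1152 matches the
# DiT-XL hidden size used elsewhere in this file.
def _modulate_shape_demo():
    x = torch.randn(4, 256, 1152)      # (N, T, D) token sequence
    shift = torch.randn(4, 1152)       # (N, D) per-sample shift
    scale = torch.randn(4, 1152)       # (N, D) per-sample scale
    out = modulate(x, shift, scale)    # unsqueeze(1) broadcasts over the token axis T
    assert out.shape == x.shape        # modulation preserves (N, T, D)
# --------------------------------------------------------------------------------
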
#################################################################################
#               Embedding Layers for Timesteps and Class Labels                #
#################################################################################

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb

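# --- Illustrative sketch (not part of the original file) -----------------------
# The embedder maps a batch of scalar timesteps to (N, hidden_size) vectors:
# a fixed sinusoidal featurization of size frequency_embedding_size followed by
# a small MLP. The helper name and sizes below are hypothetical.
def _timestep_embedder_demo():
    embedder = TimestepEmbedder(hidden_size=1152)
    t = torch.tensor([0.0, 0.5, 10.0, 999.0])                # (N,) timesteps, may be fractional
    sin_feat = TimestepEmbedder.timestep_embedding(t, 256)   # (N, 256) sinusoidal features
    t_emb = embedder(t)                                       # (N, 1152) conditioning vector
    assert sin_feat.shape == (4, 256) and t_emb.shape == (4, 1152)
# --------------------------------------------------------------------------------
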
class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # The table is a Linear layer, so labels y are expected as (N, num_classes) one-hot or soft label vectors.
        self.embedding_table = nn.Linear(num_classes, hidden_size, bias=False)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def forward(self, y):
        embeddings = self.embedding_table(y)
        return embeddings


#################################################################################
#                          Token Masking and Unmasking                          #
#################################################################################

def get_mask(batch, length, mask_ratio, device):
    """
    Get the binary mask for the input sequence.
    Args:
    - batch: batch size
    - length: sequence length
    - mask_ratio: ratio of tokens to mask
    - device: device on which to create the mask tensors
    return:
        mask_dict with following keys:
        - mask: binary mask, 0 is keep, 1 is remove
        - ids_keep: indices of tokens to keep
        - ids_restore: indices to restore the original order
    """
    len_keep = int(length * (1 - mask_ratio))
    noise = torch.rand(batch, length, device=device)  # noise in [0, 1]
    ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]

    mask = torch.ones([batch, length], device=device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)
    return {'mask': mask,
            'ids_keep': ids_keep,
            'ids_restore': ids_restore}


def mask_out_token(x, ids_keep):
    """
    Keep only the tokens selected by ids_keep and drop (mask out) the rest.
    Args:
    - x: input sequence, [N, L, D]
    - ids_keep: indices of tokens to keep
    return:
    - x_masked: masked sequence
    """
    N, L, D = x.shape  # batch, length, dim
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
    return x_masked


def mask_tokens(x, mask_ratio):
    """
    Perform per-sample random masking by per-sample shuffling.
    Per-sample shuffling is done by argsort random noise.
    x: [N, L, D], sequence
    """
    N, L, D = x.shape  # batch, length, dim
    len_keep = int(L * (1 - mask_ratio))

    noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

    # sort noise for each sample
    ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)

    return x_masked, mask, ids_restore


def unmask_tokens(x, ids_restore, mask_token, extras=0):
    # x: [N, T, D] if extras == 0 (i.e., no cls token) else x: [N, T+1, D]
    mask_tokens = mask_token.repeat(x.shape[0], ids_restore.shape[1] + extras - x.shape[1], 1)
    x_ = torch.cat([x[:, extras:, :], mask_tokens], dim=1)  # no cls token
    x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
    x = torch.cat([x[:, :extras, :], x_], dim=1)  # append cls token
    return x

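# --- Illustrative sketch (not part of the original file) -----------------------
# Round trip of the masking utilities above: draw a random mask, keep a subset of
# tokens, then scatter the kept tokens back to their original positions with a
# mask token filling the dropped slots. The helper name and sizes are hypothetical.
def _masking_round_trip_demo():
    N, L, D = 2, 16, 8
    x = torch.randn(N, L, D)                               # full token sequence
    mask_dict = get_mask(N, L, mask_ratio=0.5, device=x.device)
    x_visible = mask_out_token(x, mask_dict['ids_keep'])   # (N, L * (1 - ratio), D)
    assert x_visible.shape == (N, 8, D)
    mask_token = torch.zeros(1, 1, D)                      # stands in for the learned parameter
    x_full = unmask_tokens(x_visible, mask_dict['ids_restore'], mask_token, extras=0)
    assert x_full.shape == (N, L, D)
    # positions with mask == 0 keep their original tokens; positions with
    # mask == 1 now hold the mask token.
# --------------------------------------------------------------------------------
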
#################################################################################
|
| 167 |
+
# Core DiT Model #
|
| 168 |
+
#################################################################################
|
| 169 |
+
|
| 170 |
+
class DiTBlock(nn.Module):
|
| 171 |
+
"""
|
| 172 |
+
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
| 173 |
+
"""
|
| 174 |
+
|
| 175 |
+
def __init__(self, hidden_size, c_emb_dize, num_heads, mlp_ratio=4.0, **block_kwargs):
|
| 176 |
+
super().__init__()
|
| 177 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 178 |
+
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
|
| 179 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 180 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 181 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
| 182 |
+
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
|
| 183 |
+
self.adaLN_modulation = nn.Sequential(
|
| 184 |
+
nn.SiLU(),
|
| 185 |
+
nn.Linear(c_emb_dize, 6 * hidden_size, bias=True)
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
def forward(self, x, c):
|
| 189 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
|
| 190 |
+
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
|
| 191 |
+
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
| 192 |
+
return x
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class DecoderLayer(nn.Module):
|
| 196 |
+
"""
|
| 197 |
+
The final layer of DiT.
|
| 198 |
+
"""
|
| 199 |
+
|
| 200 |
+
def __init__(self, hidden_size, decoder_hidden_size):
|
| 201 |
+
super().__init__()
|
| 202 |
+
self.norm_decoder = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 203 |
+
self.linear = nn.Linear(hidden_size, decoder_hidden_size, bias=True)
|
| 204 |
+
self.adaLN_modulation = nn.Sequential(
|
| 205 |
+
nn.SiLU(),
|
| 206 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
def forward(self, x, c):
|
| 210 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
| 211 |
+
x = modulate(self.norm_decoder(x), shift, scale)
|
| 212 |
+
x = self.linear(x)
|
| 213 |
+
return x
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class FinalLayer(nn.Module):
|
| 217 |
+
"""
|
| 218 |
+
The final layer of DiT.
|
| 219 |
+
"""
|
| 220 |
+
|
| 221 |
+
def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels):
|
| 222 |
+
super().__init__()
|
| 223 |
+
self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
|
| 224 |
+
self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
|
| 225 |
+
self.adaLN_modulation = nn.Sequential(
|
| 226 |
+
nn.SiLU(),
|
| 227 |
+
nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True)
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
def forward(self, x, c):
|
| 231 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
| 232 |
+
x = modulate(self.norm_final(x), shift, scale)
|
| 233 |
+
x = self.linear(x)
|
| 234 |
+
return x
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class DiT(nn.Module):
|
| 238 |
+
"""
|
| 239 |
+
Diffusion model with a Transformer backbone.
|
| 240 |
+
"""
|
| 241 |
+
|
| 242 |
+
def __init__(
|
| 243 |
+
self,
|
| 244 |
+
input_size=32,
|
| 245 |
+
patch_size=2,
|
| 246 |
+
in_channels=4,
|
| 247 |
+
hidden_size=1152,
|
| 248 |
+
depth=28,
|
| 249 |
+
num_heads=16,
|
| 250 |
+
mlp_ratio=4.0,
|
| 251 |
+
class_dropout_prob=0.1,
|
| 252 |
+
num_classes=1000, # 0 = unconditional
|
| 253 |
+
learn_sigma=False,
|
| 254 |
+
use_decoder=False, # decide if add a lightweight DiT decoder
|
| 255 |
+
mae_loss_coef=0, # 0 = no mae loss
|
| 256 |
+
pad_cls_token=False, # decide if use cls_token as mask token for decoder
|
| 257 |
+
direct_cls_token=False, # decide if cls_token is passed directly to the decoder (False = not passed)
|
| 258 |
+
ext_feature_dim=0, # decide if condition on external features (0 = no feature)
|
| 259 |
+
use_encoder_feat=False, # decide if condition on encoder output feature
|
| 260 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), # normalize the encoder output feature
|
| 261 |
+
):
|
| 262 |
+
super().__init__()
|
| 263 |
+
self.learn_sigma = learn_sigma
|
| 264 |
+
self.in_channels = in_channels
|
| 265 |
+
self.out_channels = in_channels * 2 if learn_sigma else in_channels
|
| 266 |
+
self.patch_size = patch_size
|
| 267 |
+
self.num_heads = num_heads
|
| 268 |
+
self.class_dropout_prob = class_dropout_prob
|
| 269 |
+
self.num_classes = num_classes
|
| 270 |
+
self.use_decoder = use_decoder
|
| 271 |
+
self.mae_loss_coef = mae_loss_coef
|
| 272 |
+
self.pad_cls_token = pad_cls_token
|
| 273 |
+
self.direct_cls_token = direct_cls_token
|
| 274 |
+
self.ext_feature_dim = ext_feature_dim
|
| 275 |
+
self.use_encoder_feat = use_encoder_feat
|
| 276 |
+
self.feat_norm = norm_layer(hidden_size, elementwise_affine=False)
|
| 277 |
+
|
| 278 |
+
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size)
|
| 279 |
+
self.t_embedder = TimestepEmbedder(hidden_size)
|
| 280 |
+
self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) if num_classes else None
|
| 281 |
+
num_patches = self.x_embedder.num_patches
|
| 282 |
+
|
| 283 |
+
self.cls_token = None
|
| 284 |
+
self.extras = 0
|
| 285 |
+
self.decoder_extras = 0
|
| 286 |
+
if self.pad_cls_token:
|
| 287 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
|
| 288 |
+
self.extras = 1
|
| 289 |
+
self.decoder_extras = 1
|
| 290 |
+
|
| 291 |
+
self.feat_embedder = None
|
| 292 |
+
if self.ext_feature_dim > 0:
|
| 293 |
+
self.feat_embedder = nn.Linear(self.ext_feature_dim, hidden_size, bias=True)
|
| 294 |
+
|
| 295 |
+
# Will use fixed sin-cos embedding:
|
| 296 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.extras, hidden_size), requires_grad=False)
|
| 297 |
+
|
| 298 |
+
self.blocks = nn.ModuleList([
|
| 299 |
+
DiTBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
|
| 300 |
+
])
|
| 301 |
+
|
| 302 |
+
self.decoder_pos_embed = None
|
| 303 |
+
self.decoder_layer = None
|
| 304 |
+
self.decoder_blocks = None
|
| 305 |
+
self.mask_token = None
|
| 306 |
+
self.cls_token_embedder = None
|
| 307 |
+
self.enc_feat_embedder = None
|
| 308 |
+
final_hidden_size = hidden_size
|
| 309 |
+
if self.use_decoder:
|
| 310 |
+
decoder_hidden_size = 512
|
| 311 |
+
decoder_depth = 8
|
| 312 |
+
decoder_num_heads = 16
|
| 313 |
+
if not self.direct_cls_token:
|
| 314 |
+
self.decoder_extras = 0
|
| 315 |
+
self.decoder_pos_embed = nn.Parameter(
|
| 316 |
+
torch.zeros(1, num_patches + self.decoder_extras, decoder_hidden_size),
|
| 317 |
+
requires_grad=False)
|
| 318 |
+
self.decoder_layer = DecoderLayer(hidden_size, decoder_hidden_size)
|
| 319 |
+
self.decoder_blocks = nn.ModuleList([
|
| 320 |
+
DiTBlock(decoder_hidden_size, hidden_size, decoder_num_heads, mlp_ratio=mlp_ratio) for _ in
|
| 321 |
+
range(decoder_depth)
|
| 322 |
+
])
|
| 323 |
+
if self.mae_loss_coef > 0:
|
| 324 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size)) # Similar to MAE
|
| 325 |
+
if self.pad_cls_token:
|
| 326 |
+
self.cls_token_embedder = nn.Linear(hidden_size, hidden_size, bias=True)
|
| 327 |
+
if self.use_encoder_feat:
|
| 328 |
+
self.enc_feat_embedder = nn.Linear(hidden_size, hidden_size, bias=True)
|
| 329 |
+
final_hidden_size = decoder_hidden_size
|
| 330 |
+
|
| 331 |
+
self.final_layer = FinalLayer(final_hidden_size, hidden_size, patch_size, self.out_channels)
|
| 332 |
+
self.initialize_weights()
|
| 333 |
+
|
| 334 |
+
def initialize_weights(self):
|
| 335 |
+
# Initialize transformer layers:
|
| 336 |
+
def _basic_init(module):
|
| 337 |
+
if isinstance(module, nn.Linear):
|
| 338 |
+
nn.init.xavier_uniform_(module.weight)
|
| 339 |
+
if module.bias is not None:
|
| 340 |
+
nn.init.constant_(module.bias, 0)
|
| 341 |
+
|
| 342 |
+
self.apply(_basic_init)
|
| 343 |
+
|
| 344 |
+
# Initialize (and freeze) pos_embed by sin-cos embedding:
|
| 345 |
+
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5),
|
| 346 |
+
cls_token=self.pad_cls_token, extra_tokens=self.extras)
|
| 347 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
| 348 |
+
|
| 349 |
+
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
| 350 |
+
w = self.x_embedder.proj.weight.data
|
| 351 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 352 |
+
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
| 353 |
+
|
| 354 |
+
# Initialize label embedding table:
|
| 355 |
+
if self.y_embedder is not None:
|
| 356 |
+
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
|
| 357 |
+
|
| 358 |
+
# Initialize timestep embedding MLP:
|
| 359 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
| 360 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
| 361 |
+
|
| 362 |
+
# Initialize cls_token embedding:
|
| 363 |
+
if self.feat_embedder is not None:
|
| 364 |
+
nn.init.normal_(self.feat_embedder.weight, std=0.02)
|
| 365 |
+
|
| 366 |
+
# Initialize cls token
|
| 367 |
+
if self.cls_token is not None:
|
| 368 |
+
nn.init.normal_(self.cls_token, std=.02)
|
| 369 |
+
|
| 370 |
+
# Initialize cls_token embedding:
|
| 371 |
+
if self.cls_token_embedder is not None:
|
| 372 |
+
nn.init.normal_(self.cls_token_embedder.weight, std=0.02)
|
| 373 |
+
|
| 374 |
+
# Zero-out adaLN modulation layers in DiT blocks:
|
| 375 |
+
for block in self.blocks:
|
| 376 |
+
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
| 377 |
+
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
| 378 |
+
|
| 379 |
+
# Zero-out output layers:
|
| 380 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
| 381 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
| 382 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
| 383 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
| 384 |
+
|
| 385 |
+
# --------------------------- decoder initialization ---------------------------
|
| 386 |
+
# Initialize (and freeze) decoder_pos_embed by sin-cos embedding:
|
| 387 |
+
if self.decoder_pos_embed is not None:
|
| 388 |
+
pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
|
| 389 |
+
int(self.x_embedder.num_patches ** 0.5),
|
| 390 |
+
cls_token=self.pad_cls_token, extra_tokens=self.decoder_extras)
|
| 391 |
+
self.decoder_pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
| 392 |
+
|
| 393 |
+
# Initialize mask token
|
| 394 |
+
if self.mae_loss_coef > 0 and self.mask_token is not None:
|
| 395 |
+
nn.init.normal_(self.mask_token, std=.02)
|
| 396 |
+
|
| 397 |
+
# Zero-out adaLN modulation layers in DiT decoder blocks:
|
| 398 |
+
if self.decoder_blocks is not None:
|
| 399 |
+
for block in self.decoder_blocks:
|
| 400 |
+
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
| 401 |
+
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
| 402 |
+
|
| 403 |
+
# Zero-out decoder layers (TODO: initialized the same way as the final layer; not sure this is the best choice)
|
| 404 |
+
if self.decoder_layer is not None:
|
| 405 |
+
nn.init.constant_(self.decoder_layer.adaLN_modulation[-1].weight, 0)
|
| 406 |
+
nn.init.constant_(self.decoder_layer.adaLN_modulation[-1].bias, 0)
|
| 407 |
+
nn.init.constant_(self.decoder_layer.linear.weight, 0)
|
| 408 |
+
nn.init.constant_(self.decoder_layer.linear.bias, 0)
|
| 409 |
+
# ------------------------------------------------------------------------------
|
| 410 |
+
|
| 411 |
+
def unpatchify(self, x):
|
| 412 |
+
"""
|
| 413 |
+
x: (N, L, patch_size**2 * C)
|
| 414 |
+
imgs: (N, C, H, W)
|
| 415 |
+
"""
|
| 416 |
+
c = self.out_channels
|
| 417 |
+
p = self.x_embedder.patch_size[0]
|
| 418 |
+
h = w = int(x.shape[1] ** 0.5)
|
| 419 |
+
assert h * w == x.shape[1]
|
| 420 |
+
|
| 421 |
+
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
| 422 |
+
x = torch.einsum('nhwpqc->nchpwq', x)
|
| 423 |
+
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
| 424 |
+
return imgs
|
| 425 |
+
|
| 426 |
+
def encode(self, x, t, y, mask_ratio=0, mask_dict=None, feat=None):
|
| 427 |
+
'''
|
| 428 |
+
Encode x and (t, y, feat) into a latent representation.
|
| 429 |
+
Return:
|
| 430 |
+
x_feat: feature
|
| 431 |
+
mask_dict with keys: 'mask', 'ids_keep', 'ids_restore'
|
| 432 |
+
'''
|
| 433 |
+
x = self.x_embedder(x) + self.pos_embed[:, self.extras:, :] # (N, T, D), where T = H * W / patch_size ** 2
|
| 434 |
+
if mask_ratio > 0 and mask_dict is None:
|
| 435 |
+
mask_dict = get_mask(x.shape[0], x.shape[1], mask_ratio, device=x.device)
|
| 436 |
+
if mask_ratio > 0:
|
| 437 |
+
ids_keep = mask_dict['ids_keep']
|
| 438 |
+
x = mask_out_token(x, ids_keep)
|
| 439 |
+
# append cls token
|
| 440 |
+
if self.cls_token is not None:
|
| 441 |
+
cls_token = self.cls_token + self.pos_embed[:, :self.extras, :]
|
| 442 |
+
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
|
| 443 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 444 |
+
t = self.t_embedder(t) # (N, D)
|
| 445 |
+
c = t
|
| 446 |
+
if self.y_embedder is not None:
|
| 447 |
+
y = self.y_embedder(y) # (N, D)
|
| 448 |
+
c = c + y # (N, D)
|
| 449 |
+
assert (self.feat_embedder is None) or (self.enc_feat_embedder is None)
|
| 450 |
+
if self.feat_embedder is not None:
|
| 451 |
+
assert feat.shape[-1] == self.ext_feature_dim
|
| 452 |
+
feat_embed = self.feat_embedder(feat) # (N, D)
|
| 453 |
+
c = c + feat_embed # (N, D)
|
| 454 |
+
if self.enc_feat_embedder is not None and feat is not None:
|
| 455 |
+
assert feat.shape[-1] == c.shape[-1]
|
| 456 |
+
feat_embed = self.enc_feat_embedder(feat) # (N, D)
|
| 457 |
+
c = c + feat_embed # (N, D)
|
| 458 |
+
|
| 459 |
+
for block in self.blocks:
|
| 460 |
+
x = block(x, c) # (N, T, D)
|
| 461 |
+
|
| 462 |
+
x_feat = x[:, self.extras:, :].mean(dim=1) # global pool without cls token
|
| 463 |
+
x_feat = self.feat_norm(x_feat)
|
| 464 |
+
return x_feat, mask_dict
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def forward_encoder(self, x, t, y, mask_ratio=0, mask_dict=None, feat=None, train=True):
|
| 468 |
+
'''
|
| 469 |
+
Encode x and (t, y, feat) into a latent representation.
|
| 470 |
+
Return:
|
| 471 |
+
- out_enc: dict, containing the following keys: x, x_feat
|
| 472 |
+
- c: the conditional embedding
|
| 473 |
+
'''
|
| 474 |
+
out_enc = dict()
|
| 475 |
+
x = self.x_embedder(x) + self.pos_embed[:, self.extras:, :] # (N, T, D), where T = H * W / patch_size ** 2
|
| 476 |
+
if mask_ratio > 0 and mask_dict is None:
|
| 477 |
+
mask_dict = get_mask(x.shape[0], x.shape[1], mask_ratio=mask_ratio, device=x.device)
|
| 478 |
+
|
| 479 |
+
if mask_ratio > 0:
|
| 480 |
+
ids_keep = mask_dict['ids_keep']
|
| 481 |
+
ids_restore = mask_dict['ids_restore']
|
| 482 |
+
if train:
|
| 483 |
+
x = mask_out_token(x, ids_keep)
|
| 484 |
+
|
| 485 |
+
# append cls token
|
| 486 |
+
if self.cls_token is not None:
|
| 487 |
+
cls_token = self.cls_token + self.pos_embed[:, :self.extras, :]
|
| 488 |
+
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
|
| 489 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 490 |
+
|
| 491 |
+
t = self.t_embedder(t) # (N, D)
|
| 492 |
+
c = t
|
| 493 |
+
if self.y_embedder is not None:
|
| 494 |
+
y = self.y_embedder(y) # (N, D)
|
| 495 |
+
c = c + y # (N, D)
|
| 496 |
+
assert (self.feat_embedder is None) or (self.enc_feat_embedder is None)
|
| 497 |
+
if self.feat_embedder is not None:
|
| 498 |
+
assert feat.shape[-1] == self.ext_feature_dim
|
| 499 |
+
feat_embed = self.feat_embedder(feat) # (N, D)
|
| 500 |
+
c = c + feat_embed # (N, D)
|
| 501 |
+
if self.enc_feat_embedder is not None and feat is not None:
|
| 502 |
+
assert feat.shape[-1] == c.shape[-1]
|
| 503 |
+
feat_embed = self.enc_feat_embedder(feat) # (N, D)
|
| 504 |
+
c = c + feat_embed # (N, D)
|
| 505 |
+
for block in self.blocks:
|
| 506 |
+
x = block(x, c) # (N, T, D)
|
| 507 |
+
out_enc['x'] = x
|
| 508 |
+
|
| 509 |
+
return out_enc, c, mask_dict
|
| 510 |
+
|
| 511 |
+
def forward(self, x, t, y, mask_ratio=0, mask_dict=None, feat=None):
|
| 512 |
+
"""
|
| 513 |
+
Forward pass of DiT.
|
| 514 |
+
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
| 515 |
+
t: (N,) tensor of diffusion timesteps
|
| 516 |
+
y: (N,) tensor of class labels
|
| 517 |
+
"""
|
| 518 |
+
if not self.training and self.use_encoder_feat:
|
| 519 |
+
feat, _ = self.encode(x, t, y, feat=feat)
|
| 520 |
+
out, c, mask_dict = self.forward_encoder(x, t, y, mask_ratio=mask_ratio, mask_dict=mask_dict, feat=feat, train=self.training)
|
| 521 |
+
if mask_ratio > 0:
|
| 522 |
+
ids_keep = mask_dict['ids_keep']
|
| 523 |
+
ids_restore = mask_dict['ids_restore']
|
| 524 |
+
out['mask'] = mask_dict['mask']
|
| 525 |
+
else:
|
| 526 |
+
ids_keep = ids_restore = None
|
| 527 |
+
x = out['x']
|
| 528 |
+
# Pass to a DiT decoder (if available)
|
| 529 |
+
if self.use_decoder:
|
| 530 |
+
if self.cls_token_embedder is not None:
|
| 531 |
+
# cls_token_output = x[:, :self.extras, :].squeeze(dim=1).detach().clone() # stop gradient
|
| 532 |
+
cls_token_output = x[:, :self.extras, :].squeeze(dim=1)
|
| 533 |
+
cls_token_embed = self.cls_token_embedder(self.feat_norm(cls_token_output)) # normalize cls token
|
| 534 |
+
c = c + cls_token_embed # add the cls token output's embedding as feature conditioning
|
| 535 |
+
|
| 536 |
+
assert self.decoder_layer is not None
|
| 537 |
+
diff_extras = self.extras - self.decoder_extras
|
| 538 |
+
x = self.decoder_layer(x[:, diff_extras:, :], c) # remove cls token (if necessary)
|
| 539 |
+
if self.training and mask_ratio > 0:
|
| 540 |
+
mask_token = self.mask_token
|
| 541 |
+
if mask_token is None:
|
| 542 |
+
mask_token = torch.zeros(1, 1, x.shape[2]).to(x) # concat zeros to match shape
|
| 543 |
+
x = unmask_tokens(x, ids_restore, mask_token, extras=self.decoder_extras)
|
| 544 |
+
assert self.decoder_pos_embed is not None
|
| 545 |
+
x = x + self.decoder_pos_embed
|
| 546 |
+
assert self.decoder_blocks is not None
|
| 547 |
+
for block in self.decoder_blocks:
|
| 548 |
+
x = block(x, c) # (N, T, D)
|
| 549 |
+
|
| 550 |
+
x = self.final_layer(x, c) # (N, T or T+1, patch_size ** 2 * out_channels)
|
| 551 |
+
if not self.use_decoder and (self.training and mask_ratio > 0):
|
| 552 |
+
mask_token = torch.zeros(1, 1, x.shape[2]).to(x) # concat zeros to match shape
|
| 553 |
+
x = unmask_tokens(x, ids_restore, mask_token, extras=self.extras)
|
| 554 |
+
x = x[:, self.decoder_extras:, :] # remove cls token (if necessary)
|
| 555 |
+
x = self.unpatchify(x) # (N, out_channels, H, W)
|
| 556 |
+
out['x'] = x
|
| 557 |
+
return out
|
| 558 |
+
|
| 559 |
+
def forward_with_cfg(self, x, t, y, cfg_scale, feat=None, **model_kwargs):
|
| 560 |
+
"""
|
| 561 |
+
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
| 562 |
+
"""
|
| 563 |
+
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
|
| 564 |
+
out = dict()
|
| 565 |
+
|
| 566 |
+
# Setup classifier-free guidance
|
| 567 |
+
x = torch.cat([x, x], 0)
|
| 568 |
+
y_null = torch.zeros_like(y)
|
| 569 |
+
y = torch.cat([y, y_null], 0)
|
| 570 |
+
if feat is not None:
|
| 571 |
+
feat = torch.cat([feat, feat], 0)
|
| 572 |
+
|
| 573 |
+
half = x[: len(x) // 2]
|
| 574 |
+
combined = torch.cat([half, half], dim=0)
|
| 575 |
+
assert self.num_classes and y is not None
|
| 576 |
+
model_out = self.forward(combined, t, y, feat=feat)['x']
|
| 577 |
+
# The original DiT applies classifier-free guidance to only the first three channels
|
| 578 |
+
# for exact reproducibility; here guidance is applied to all channels by default.
|
| 579 |
+
# To reproduce the three-channel behavior, swap the commented and uncommented lines below.
|
| 580 |
+
eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
|
| 581 |
+
# eps, rest = model_out[:, :3], model_out[:, 3:]
|
| 582 |
+
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
| 583 |
+
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
| 584 |
+
half_rest = rest[: len(rest) // 2]
|
| 585 |
+
x = torch.cat([half_eps, half_rest], dim=1)
|
| 586 |
+
out['x'] = x
|
| 587 |
+
return out
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
#################################################################################
|
| 591 |
+
# Sine/Cosine Positional Embedding Functions #
|
| 592 |
+
#################################################################################
|
| 593 |
+
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
| 594 |
+
|
| 595 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=1):
|
| 596 |
+
"""
|
| 597 |
+
grid_size: int of the grid height and width
|
| 598 |
+
return:
|
| 599 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
| 600 |
+
"""
|
| 601 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
| 602 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
| 603 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
| 604 |
+
grid = np.stack(grid, axis=0)
|
| 605 |
+
|
| 606 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
| 607 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 608 |
+
if cls_token and extra_tokens > 0:
|
| 609 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
| 610 |
+
return pos_embed
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 614 |
+
assert embed_dim % 2 == 0
|
| 615 |
+
|
| 616 |
+
# use half of dimensions to encode grid_h
|
| 617 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
| 618 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
| 619 |
+
|
| 620 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
| 621 |
+
return emb
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 625 |
+
"""
|
| 626 |
+
embed_dim: output dimension for each position
|
| 627 |
+
pos: a list of positions to be encoded: size (M,)
|
| 628 |
+
out: (M, D)
|
| 629 |
+
"""
|
| 630 |
+
assert embed_dim % 2 == 0
|
| 631 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
| 632 |
+
omega /= embed_dim / 2.
|
| 633 |
+
omega = 1. / 10000 ** omega # (D/2,)
|
| 634 |
+
|
| 635 |
+
pos = pos.reshape(-1) # (M,)
|
| 636 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
| 637 |
+
|
| 638 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 639 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 640 |
+
|
| 641 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
| 642 |
+
return emb
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
#################################################################################
|
| 646 |
+
# DiT Configs #
|
| 647 |
+
#################################################################################
|
| 648 |
+
|
| 649 |
+
def DiT_H_2(**kwargs):
|
| 650 |
+
return DiT(depth=32, hidden_size=1280, patch_size=2, num_heads=16, **kwargs)
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
def DiT_H_4(**kwargs):
|
| 654 |
+
return DiT(depth=32, hidden_size=1280, patch_size=4, num_heads=16, **kwargs)
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
def DiT_H_8(**kwargs):
|
| 658 |
+
return DiT(depth=32, hidden_size=1280, patch_size=8, num_heads=16, **kwargs)
|
| 659 |
+
|
| 660 |
+
|
| 661 |
+
def DiT_XL_2(**kwargs):
|
| 662 |
+
return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
def DiT_XL_4(**kwargs):
|
| 666 |
+
return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
def DiT_XL_8(**kwargs):
|
| 670 |
+
return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
def DiT_L_2(**kwargs):
|
| 674 |
+
return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
def DiT_L_4(**kwargs):
|
| 678 |
+
return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
def DiT_L_8(**kwargs):
|
| 682 |
+
return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
def DiT_B_2(**kwargs):
|
| 686 |
+
return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
|
| 687 |
+
|
| 688 |
+
|
| 689 |
+
def DiT_B_4(**kwargs):
|
| 690 |
+
return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
|
| 691 |
+
|
| 692 |
+
|
| 693 |
+
def DiT_B_8(**kwargs):
|
| 694 |
+
return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
def DiT_S_2(**kwargs):
|
| 698 |
+
return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
def DiT_S_4(**kwargs):
|
| 702 |
+
return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
|
| 703 |
+
|
| 704 |
+
|
| 705 |
+
def DiT_S_8(**kwargs):
|
| 706 |
+
return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
DiT_models = {
|
| 710 |
+
'DiT-H/2': DiT_H_2, 'DiT-H/4': DiT_H_4, 'DiT-H/8': DiT_H_8,
|
| 711 |
+
'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
|
| 712 |
+
'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
|
| 713 |
+
'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
|
| 714 |
+
'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
|
| 715 |
+
}
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
# ----------------------------------------------------------------------------
|
| 719 |
+
# Improved preconditioning proposed in the paper "Elucidating the Design
|
| 720 |
+
# Space of Diffusion-Based Generative Models" (EDM).
|
| 721 |
+
|
| 722 |
+
class EDMPrecond(nn.Module):
|
| 723 |
+
def __init__(self,
|
| 724 |
+
img_resolution, # Image resolution.
|
| 725 |
+
img_channels, # Number of color channels.
|
| 726 |
+
num_classes=0, # Number of class labels, 0 = unconditional.
|
| 727 |
+
sigma_min=0, # Minimum supported noise level.
|
| 728 |
+
sigma_max=float('inf'), # Maximum supported noise level.
|
| 729 |
+
sigma_data=0.5, # Expected standard deviation of the training data.
|
| 730 |
+
model_type='DiT-B/2', # Class name of the underlying model.
|
| 731 |
+
**model_kwargs, # Keyword arguments for the underlying model.
|
| 732 |
+
):
|
| 733 |
+
super().__init__()
|
| 734 |
+
self.img_resolution = img_resolution
|
| 735 |
+
self.img_channels = img_channels
|
| 736 |
+
self.num_classes = num_classes
|
| 737 |
+
self.sigma_min = sigma_min
|
| 738 |
+
self.sigma_max = sigma_max
|
| 739 |
+
self.sigma_data = sigma_data
|
| 740 |
+
self.model = DiT_models[model_type](input_size=img_resolution, in_channels=img_channels,
|
| 741 |
+
num_classes=num_classes, **model_kwargs)
|
| 742 |
+
|
| 743 |
+
def encode(self, x, sigma, class_labels=None, **model_kwargs):
|
| 744 |
+
|
| 745 |
+
sigma = sigma.to(x.dtype).reshape(-1, 1, 1, 1)
|
| 746 |
+
class_labels = None if self.num_classes == 0 else \
|
| 747 |
+
torch.zeros([x.shape[0], self.num_classes], device=x.device) if class_labels is None else \
|
| 748 |
+
class_labels.to(x.dtype).reshape(-1, self.num_classes)
|
| 749 |
+
|
| 750 |
+
c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
|
| 751 |
+
c_noise = sigma.log() / 4
|
| 752 |
+
|
| 753 |
+
feat, mask_dict = self.model.encode((c_in * x).to(x.dtype), c_noise.flatten(), y=class_labels, **model_kwargs)
|
| 754 |
+
return feat
|
| 755 |
+
|
| 756 |
+
def forward(self, x, sigma, class_labels=None, cfg_scale=None, **model_kwargs):
|
| 757 |
+
model_fn = self.model if cfg_scale is None else partial(self.model.forward_with_cfg, cfg_scale=cfg_scale)
|
| 758 |
+
|
| 759 |
+
sigma = sigma.to(x.dtype).reshape(-1, 1, 1, 1)
|
| 760 |
+
class_labels = None if self.num_classes == 0 else \
|
| 761 |
+
torch.zeros([x.shape[0], self.num_classes], device=x.device) if class_labels is None else \
|
| 762 |
+
class_labels.to(x.dtype).reshape(-1, self.num_classes)
|
| 763 |
+
|
| 764 |
+
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
|
| 765 |
+
c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
|
| 766 |
+
c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
|
| 767 |
+
c_noise = sigma.log() / 4
|
| 768 |
+
|
| 769 |
+
model_out = model_fn((c_in * x).to(x.dtype), c_noise.flatten(), y=class_labels, **model_kwargs)
|
| 770 |
+
F_x = model_out['x']
|
| 771 |
+
D_x = c_skip * x + c_out * F_x
|
| 772 |
+
model_out['x'] = D_x
|
| 773 |
+
return model_out
|
| 774 |
+
|
| 775 |
+
def round_sigma(self, sigma):
|
| 776 |
+
return torch.as_tensor(sigma)
|
| 777 |
+
|
| 778 |
+
|
| 779 |
+
Precond_models = {
|
| 780 |
+
'edm': EDMPrecond
|
| 781 |
+
}
|
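Editor's note: the block below is not part of the committed file; it is a minimal, hedged sketch of how the EDM-preconditioned DiT defined above can be instantiated and called. The 'DiT-B/2' size, the latent shape, and the default DiT keyword arguments are illustrative assumptions, and the sketch presumes the model's output channels match its input channels (as required by the preconditioning formula in EDMPrecond.forward).

# Illustrative usage sketch (assumptions noted above; not from the original commit).
import torch
from models.maskdit import Precond_models

net = Precond_models['edm'](img_resolution=32, img_channels=4,
                            num_classes=1000, model_type='DiT-B/2')
x = torch.randn(2, 4, 32, 32)                   # noisy latents (N, C, H, W)
sigma = torch.full([2], 1.0)                    # per-sample noise levels
y = torch.eye(1000)[torch.randint(1000, (2,))]  # one-hot class labels, as in sample.py
out = net(x, sigma, class_labels=y)             # EDM preconditioning around the DiT forward pass
denoised = out['x']                             # denoised latents D_x = c_skip * x + c_out * F_x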
sample.py
ADDED
|
@@ -0,0 +1,397 @@
|
| 1 |
+
# MIT License
|
| 2 |
+
|
| 3 |
+
# Copyright (c) [2023] [Anima-Lab]
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
# This code is adapted from https://github.com/NVlabs/edm/blob/main/generate.py.
|
| 7 |
+
# The original code is licensed under a Creative Commons
|
| 8 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License, which can be found at licenses/LICENSE_EDM.txt.
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import random
|
| 12 |
+
|
| 13 |
+
import PIL.Image
|
| 14 |
+
import lmdb
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.distributed as dist
|
| 19 |
+
from torch.multiprocessing import Process
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
|
| 22 |
+
from models.maskdit import Precond_models, DiT_models
|
| 23 |
+
from utils import *
|
| 24 |
+
import autoencoder
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ----------------------------------------------------------------------------
|
| 28 |
+
# Proposed EDM sampler (Algorithm 2).
|
| 29 |
+
|
| 30 |
+
def edm_sampler(
|
| 31 |
+
net, latents, class_labels=None, cfg_scale=None, feat=None, randn_like=torch.randn_like,
|
| 32 |
+
num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
|
| 33 |
+
S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
|
| 34 |
+
):
|
| 35 |
+
# Adjust noise levels based on what's supported by the network.
|
| 36 |
+
sigma_min = max(sigma_min, net.sigma_min)
|
| 37 |
+
sigma_max = min(sigma_max, net.sigma_max)
|
| 38 |
+
|
| 39 |
+
# Time step discretization.
|
| 40 |
+
step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
|
| 41 |
+
t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
|
| 42 |
+
sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
|
| 43 |
+
t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
|
| 44 |
+
|
| 45 |
+
# Main sampling loop.
|
| 46 |
+
x_next = latents.to(torch.float64) * t_steps[0]
|
| 47 |
+
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
|
| 48 |
+
x_cur = x_next
|
| 49 |
+
|
| 50 |
+
# Increase noise temporarily.
|
| 51 |
+
gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
|
| 52 |
+
t_hat = net.round_sigma(t_cur + gamma * t_cur)
|
| 53 |
+
x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
|
| 54 |
+
|
| 55 |
+
# Euler step.
|
| 56 |
+
denoised = net(x_hat.float(), t_hat, class_labels, cfg_scale, feat=feat)['x'].to(torch.float64)
|
| 57 |
+
d_cur = (x_hat - denoised) / t_hat
|
| 58 |
+
x_next = x_hat + (t_next - t_hat) * d_cur
|
| 59 |
+
|
| 60 |
+
# Apply 2nd order correction.
|
| 61 |
+
if i < num_steps - 1:
|
| 62 |
+
denoised = net(x_next.float(), t_next, class_labels, cfg_scale, feat=feat)['x'].to(torch.float64)
|
| 63 |
+
d_prime = (x_next - denoised) / t_next
|
| 64 |
+
x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
|
| 65 |
+
|
| 66 |
+
return x_next
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# ----------------------------------------------------------------------------
|
| 70 |
+
# Generalized ablation sampler, representing the superset of all sampling
|
| 71 |
+
# methods discussed in the paper.
|
| 72 |
+
|
| 73 |
+
def ablation_sampler(
|
| 74 |
+
net, latents, class_labels=None, cfg_scale=None, feat=None, randn_like=torch.randn_like,
|
| 75 |
+
num_steps=18, sigma_min=None, sigma_max=None, rho=7,
|
| 76 |
+
solver='heun', discretization='edm', schedule='linear', scaling='none',
|
| 77 |
+
epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
|
| 78 |
+
S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
|
| 79 |
+
):
|
| 80 |
+
assert solver in ['euler', 'heun']
|
| 81 |
+
assert discretization in ['vp', 've', 'iddpm', 'edm']
|
| 82 |
+
assert schedule in ['vp', 've', 'linear']
|
| 83 |
+
assert scaling in ['vp', 'none']
|
| 84 |
+
|
| 85 |
+
# Helper functions for VP & VE noise level schedules.
|
| 86 |
+
vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
|
| 87 |
+
vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
|
| 88 |
+
vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (
|
| 89 |
+
sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
|
| 90 |
+
ve_sigma = lambda t: t.sqrt()
|
| 91 |
+
ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
|
| 92 |
+
ve_sigma_inv = lambda sigma: sigma ** 2
|
| 93 |
+
|
| 94 |
+
# Select default noise level range based on the specified time step discretization.
|
| 95 |
+
if sigma_min is None:
|
| 96 |
+
vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s)
|
| 97 |
+
sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
|
| 98 |
+
if sigma_max is None:
|
| 99 |
+
vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1)
|
| 100 |
+
sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]
|
| 101 |
+
|
| 102 |
+
# Adjust noise levels based on what's supported by the network.
|
| 103 |
+
sigma_min = max(sigma_min, net.sigma_min)
|
| 104 |
+
sigma_max = min(sigma_max, net.sigma_max)
|
| 105 |
+
|
| 106 |
+
# Compute corresponding betas for VP.
|
| 107 |
+
vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
|
| 108 |
+
vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d
|
| 109 |
+
|
| 110 |
+
# Define time steps in terms of noise level.
|
| 111 |
+
step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
|
| 112 |
+
if discretization == 'vp':
|
| 113 |
+
orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
|
| 114 |
+
sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
|
| 115 |
+
elif discretization == 've':
|
| 116 |
+
orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
|
| 117 |
+
sigma_steps = ve_sigma(orig_t_steps)
|
| 118 |
+
elif discretization == 'iddpm':
|
| 119 |
+
u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
|
| 120 |
+
alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
|
| 121 |
+
for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1
|
| 122 |
+
u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
|
| 123 |
+
u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
|
| 124 |
+
sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
|
| 125 |
+
else:
|
| 126 |
+
assert discretization == 'edm'
|
| 127 |
+
sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
|
| 128 |
+
sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
|
| 129 |
+
|
| 130 |
+
# Define noise level schedule.
|
| 131 |
+
if schedule == 'vp':
|
| 132 |
+
sigma = vp_sigma(vp_beta_d, vp_beta_min)
|
| 133 |
+
sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
|
| 134 |
+
sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
|
| 135 |
+
elif schedule == 've':
|
| 136 |
+
sigma = ve_sigma
|
| 137 |
+
sigma_deriv = ve_sigma_deriv
|
| 138 |
+
sigma_inv = ve_sigma_inv
|
| 139 |
+
else:
|
| 140 |
+
assert schedule == 'linear'
|
| 141 |
+
sigma = lambda t: t
|
| 142 |
+
sigma_deriv = lambda t: 1
|
| 143 |
+
sigma_inv = lambda sigma: sigma
|
| 144 |
+
|
| 145 |
+
# Define scaling schedule.
|
| 146 |
+
if scaling == 'vp':
|
| 147 |
+
s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
|
| 148 |
+
s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
|
| 149 |
+
else:
|
| 150 |
+
assert scaling == 'none'
|
| 151 |
+
s = lambda t: 1
|
| 152 |
+
s_deriv = lambda t: 0
|
| 153 |
+
|
| 154 |
+
# Compute final time steps based on the corresponding noise levels.
|
| 155 |
+
t_steps = sigma_inv(net.round_sigma(sigma_steps))
|
| 156 |
+
t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0
|
| 157 |
+
|
| 158 |
+
# Main sampling loop.
|
| 159 |
+
t_next = t_steps[0]
|
| 160 |
+
x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
|
| 161 |
+
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
|
| 162 |
+
x_cur = x_next
|
| 163 |
+
|
| 164 |
+
# Increase noise temporarily.
|
| 165 |
+
gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
|
| 166 |
+
t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
|
| 167 |
+
x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(
|
| 168 |
+
t_hat) * S_noise * randn_like(x_cur)
|
| 169 |
+
|
| 170 |
+
# Euler step.
|
| 171 |
+
h = t_next - t_hat
|
| 172 |
+
denoised = net(x_hat.float() / s(t_hat), sigma(t_hat), class_labels, cfg_scale, feat=feat)['x'].to(torch.float64)
|
| 173 |
+
d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(
|
| 174 |
+
t_hat) / sigma(t_hat) * denoised
|
| 175 |
+
x_prime = x_hat + alpha * h * d_cur
|
| 176 |
+
t_prime = t_hat + alpha * h
|
| 177 |
+
|
| 178 |
+
# Apply 2nd order correction.
|
| 179 |
+
if solver == 'euler' or i == num_steps - 1:
|
| 180 |
+
x_next = x_hat + h * d_cur
|
| 181 |
+
else:
|
| 182 |
+
assert solver == 'heun'
|
| 183 |
+
denoised = net(x_prime.float() / s(t_prime), sigma(t_prime), class_labels, cfg_scale, feat=feat)['x'].to(torch.float64)
|
| 184 |
+
d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(
|
| 185 |
+
t_prime) * s(t_prime) / sigma(t_prime) * denoised
|
| 186 |
+
x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)
|
| 187 |
+
|
| 188 |
+
return x_next
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
# ----------------------------------------------------------------------------
|
| 192 |
+
|
| 193 |
+
def retrieve_n_features(batch_size, feat_path, feat_dim, num_classes, device, split='train', sample_mode='rand_full'):
|
| 194 |
+
env = lmdb.open(os.path.join(feat_path, split), readonly=True, lock=False, create=False)
|
| 195 |
+
|
| 196 |
+
# Start a new read transaction
|
| 197 |
+
with env.begin() as txn:
|
| 198 |
+
# Read all images in one single transaction, with one lock
|
| 199 |
+
# We could split this up into multiple transactions if needed
|
| 200 |
+
length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
|
| 201 |
+
if sample_mode == 'rand_full':
|
| 202 |
+
image_ids = random.sample(range(length // 2), batch_size)
|
| 203 |
+
image_ids_y = image_ids
|
| 204 |
+
elif sample_mode == 'rand_repeat':
|
| 205 |
+
image_ids = random.sample(range(length // 2), 1) * batch_size
|
| 206 |
+
image_ids_y = image_ids
|
| 207 |
+
elif sample_mode == 'rand_y':
|
| 208 |
+
image_ids = random.sample(range(length // 2), 1) * batch_size
|
| 209 |
+
image_ids_y = random.sample(range(length // 2), batch_size)
|
| 210 |
+
else:
|
| 211 |
+
raise NotImplementedError
|
| 212 |
+
features, labels = [], []
|
| 213 |
+
for image_id, image_id_y in zip(image_ids, image_ids_y):
|
| 214 |
+
feat_bytes = txn.get(f'feat-{str(image_id)}'.encode('utf-8'))
|
| 215 |
+
y_bytes = txn.get(f'y-{str(image_id_y)}'.encode('utf-8'))
|
| 216 |
+
feat = np.frombuffer(feat_bytes, dtype=np.float32).reshape([feat_dim]).copy()
|
| 217 |
+
y = int(y_bytes.decode('utf-8'))
|
| 218 |
+
features.append(feat)
|
| 219 |
+
labels.append(y)
|
| 220 |
+
features = torch.from_numpy(np.stack(features)).to(device)
|
| 221 |
+
labels = torch.from_numpy(np.array(labels)).to(device)
|
| 222 |
+
class_labels = torch.zeros([batch_size, num_classes], device=device)
|
| 223 |
+
if num_classes > 0:
|
| 224 |
+
class_labels = torch.eye(num_classes, device=device)[labels]
|
| 225 |
+
assert features.shape[0] == class_labels.shape[0] == batch_size
|
| 226 |
+
return features, class_labels
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
@torch.no_grad()
|
| 231 |
+
def generate_with_net(args, net, device, rank, size):
|
| 232 |
+
seeds = args.seeds
|
| 233 |
+
num_batches = ((len(seeds) - 1) // (args.max_batch_size * size) + 1) * size
|
| 234 |
+
all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
|
| 235 |
+
rank_batches = all_batches[rank:: size]
|
| 236 |
+
|
| 237 |
+
net.eval()
|
| 238 |
+
|
| 239 |
+
# Setup sampler
|
| 240 |
+
sampler_kwargs = dict(num_steps=args.num_steps, S_churn=args.S_churn,
|
| 241 |
+
solver=args.solver, discretization=args.discretization,
|
| 242 |
+
schedule=args.schedule, scaling=args.scaling)
|
| 243 |
+
sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None}
|
| 244 |
+
have_ablation_kwargs = any(x in sampler_kwargs for x in ['solver', 'discretization', 'schedule', 'scaling'])
|
| 245 |
+
sampler_fn = ablation_sampler if have_ablation_kwargs else edm_sampler
|
| 246 |
+
mprint(f"sampler_kwargs: {sampler_kwargs}, \nsampler fn: {sampler_fn.__name__}")
|
| 247 |
+
# Setup autoencoder
|
| 248 |
+
vae = autoencoder.get_model(args.pretrained_path).to(device)
|
| 249 |
+
|
| 250 |
+
# generate images
|
| 251 |
+
mprint(f'Generating {len(seeds)} images to "{args.outdir}"...')
|
| 252 |
+
for batch_seeds in tqdm(rank_batches, unit='batch', disable=(rank != 0)):
|
| 253 |
+
dist.barrier()
|
| 254 |
+
batch_size = len(batch_seeds)
|
| 255 |
+
if batch_size == 0:
|
| 256 |
+
continue
|
| 257 |
+
|
| 258 |
+
# Pick latents and labels.
|
| 259 |
+
rnd = StackedRandomGenerator(device, batch_seeds)
|
| 260 |
+
latents = rnd.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device)
|
| 261 |
+
class_labels = torch.zeros([batch_size, net.num_classes], device=device)
|
| 262 |
+
if net.num_classes:
|
| 263 |
+
class_labels = torch.eye(net.num_classes, device=device)[
|
| 264 |
+
rnd.randint(net.num_classes, size=[batch_size], device=device)]
|
| 265 |
+
if args.class_idx is not None:
|
| 266 |
+
class_labels[:, :] = 0
|
| 267 |
+
class_labels[:, args.class_idx] = 1
|
| 268 |
+
|
| 269 |
+
# retrieve features from training set [support random only]
|
| 270 |
+
feat = None
|
| 271 |
+
|
| 272 |
+
# Generate images.
|
| 273 |
+
def recur_decode(z):
|
| 274 |
+
try:
|
| 275 |
+
return vae.decode(z)
|
| 276 |
+
except RuntimeError: # if the VAE decoder occasionally runs out of memory, halve the batch and decode in two forward passes
|
| 277 |
+
assert z.shape[2] % 2 == 0
|
| 278 |
+
z1, z2 = z.tensor_split(2)
|
| 279 |
+
return torch.cat([recur_decode(z1), recur_decode(z2)])
|
| 280 |
+
|
| 281 |
+
with torch.no_grad():
|
| 282 |
+
z = sampler_fn(net, latents.float(), class_labels.float(), randn_like=rnd.randn_like,
|
| 283 |
+
cfg_scale=args.cfg_scale, feat=feat, **sampler_kwargs).float()
|
| 284 |
+
images = recur_decode(z)
|
| 285 |
+
|
| 286 |
+
# Save images.
|
| 287 |
+
images_np = images.add_(1).mul(127.5).clamp_(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
|
| 288 |
+
# images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
|
| 289 |
+
for seed, image_np in zip(batch_seeds, images_np):
|
| 290 |
+
image_dir = os.path.join(args.outdir, f'{seed - seed % 1000:06d}') if args.subdirs else args.outdir
|
| 291 |
+
os.makedirs(image_dir, exist_ok=True)
|
| 292 |
+
image_path = os.path.join(image_dir, f'{seed:06d}.png')
|
| 293 |
+
if image_np.shape[2] == 1:
|
| 294 |
+
PIL.Image.fromarray(image_np[:, :, 0], 'L').save(image_path)
|
| 295 |
+
else:
|
| 296 |
+
PIL.Image.fromarray(image_np, 'RGB').save(image_path)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def generate(args):
|
| 300 |
+
device = torch.device("cuda")
|
| 301 |
+
|
| 302 |
+
mprint(f'cfg_scale: {args.cfg_scale}')
|
| 303 |
+
if args.global_rank == 0:
|
| 304 |
+
os.makedirs(args.outdir, exist_ok=True)
|
| 305 |
+
logger = Logger(file_name=f'{args.outdir}/log.txt', file_mode="a+", should_flush=True)
|
| 306 |
+
|
| 307 |
+
# Create model:
|
| 308 |
+
net = Precond_models[args.precond](
|
| 309 |
+
img_resolution=args.image_size,
|
| 310 |
+
img_channels=args.image_channels,
|
| 311 |
+
num_classes=args.num_classes,
|
| 312 |
+
model_type=args.model_type,
|
| 313 |
+
use_decoder=args.use_decoder,
|
| 314 |
+
mae_loss_coef=args.mae_loss_coef,
|
| 315 |
+
pad_cls_token=args.pad_cls_token,
|
| 316 |
+
ext_feature_dim=args.ext_feature_dim
|
| 317 |
+
).to(device)
|
| 318 |
+
mprint(
|
| 319 |
+
f"{args.model_type} (use_decoder: {args.use_decoder}) Model Parameters: {sum(p.numel() for p in net.parameters()):,}")
|
| 320 |
+
|
| 321 |
+
# Load checkpoints
|
| 322 |
+
ckpt = torch.load(args.ckpt_path, map_location=device)
|
| 323 |
+
net.load_state_dict(ckpt['ema'])
|
| 324 |
+
mprint(f'Load weights from {args.ckpt_path}')
|
| 325 |
+
|
| 326 |
+
generate_with_net(args, net, device, args.global_rank, args.global_size)
|
| 327 |
+
|
| 328 |
+
# Done.
|
| 329 |
+
cleanup()
|
| 330 |
+
if args.global_rank == 0:
|
| 331 |
+
logger.close()
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
if __name__ == '__main__':
|
| 335 |
+
parser = argparse.ArgumentParser('sampling parameters')
|
| 336 |
+
|
| 337 |
+
# ddp
|
| 338 |
+
parser.add_argument('--num_proc_node', type=int, default=1, help='The number of nodes in multi node env.')
|
| 339 |
+
parser.add_argument('--num_process_per_node', type=int, default=1, help='number of gpus')
|
| 340 |
+
parser.add_argument('--node_rank', type=int, default=0, help='The index of node.')
|
| 341 |
+
parser.add_argument('--local_rank', type=int, default=0, help='rank of process in the node')
|
| 342 |
+
parser.add_argument('--master_address', type=str, default='localhost', help='address for master')
|
| 343 |
+
|
| 344 |
+
# sampling
|
| 345 |
+
parser.add_argument("--feat_path", type=str, default='')
|
| 346 |
+
parser.add_argument("--ext_feature_dim", type=int, default=0)
|
| 347 |
+
parser.add_argument('--ckpt_path', type=str, required=True, help='Path to the model checkpoint (.pt)')
|
| 348 |
+
parser.add_argument('--outdir', type=str, required=True, help='Directory to save the generated images')
|
| 349 |
+
parser.add_argument('--seeds', type=parse_int_list, default='0-63', help='Random seeds (e.g. 1,2,5-10)')
|
| 350 |
+
parser.add_argument('--subdirs', action='store_true', help='Create subdirectory for every 1000 seeds')
|
| 351 |
+
parser.add_argument('--class_idx', type=int, default=None, help='Class label [default: random]')
|
| 352 |
+
parser.add_argument('--max_batch_size', type=int, default=64, help='Maximum batch size per GPU')
|
| 353 |
+
|
| 354 |
+
parser.add_argument("--cfg_scale", type=parse_float_none, default=None, help='Classifier-free guidance scale; None (default) = no guidance, a typical value is 4.0')
|
| 355 |
+
|
| 356 |
+
parser.add_argument('--num_steps', type=int, default=18, help='Number of sampling steps')
|
| 357 |
+
parser.add_argument('--S_churn', type=int, default=0, help='Stochasticity strength')
|
| 358 |
+
parser.add_argument('--solver', type=str, default=None, choices=['euler', 'heun'], help='Ablate ODE solver')
|
| 359 |
+
parser.add_argument('--discretization', type=str, default=None, choices=['vp', 've', 'iddpm', 'edm'],
|
| 360 |
+
help='Ablate time-step discretization {t_i}')
|
| 361 |
+
parser.add_argument('--schedule', type=str, default=None, choices=['vp', 've', 'linear'],
|
| 362 |
+
help='Ablate noise schedule sigma(t)')
|
| 363 |
+
parser.add_argument('--scaling', type=str, default=None, choices=['vp', 'none'], help='Ablate signal scaling s(t)')
|
| 364 |
+
parser.add_argument('--pretrained_path', type=str, default='assets/stable_diffusion/autoencoder_kl.pth',
|
| 365 |
+
help='Autoencoder ckpt')
|
| 366 |
+
|
| 367 |
+
# model
|
| 368 |
+
parser.add_argument("--image_size", type=int, default=32)
|
| 369 |
+
parser.add_argument("--image_channels", type=int, default=4)
|
| 370 |
+
parser.add_argument("--num_classes", type=int, default=1000, help='0 means unconditional')
|
| 371 |
+
parser.add_argument("--model_type", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2")
|
| 372 |
+
parser.add_argument('--precond', type=str, choices=['vp', 've', 'edm'], default='edm', help='precond train & loss')
|
| 373 |
+
parser.add_argument("--use_decoder", type=str2bool, default=False)
|
| 374 |
+
parser.add_argument("--pad_cls_token", type=str2bool, default=False)
|
| 375 |
+
parser.add_argument('--mae_loss_coef', type=float, default=0, help='0 means no MAE loss')
|
| 376 |
+
parser.add_argument('--sample_mode', type=str, default='rand_full', help='[rand_full, rand_repeat, rand_y]')
|
| 377 |
+
|
| 378 |
+
args = parser.parse_args()
|
| 379 |
+
args.global_size = args.num_proc_node * args.num_process_per_node
|
| 380 |
+
size = args.num_process_per_node
|
| 381 |
+
|
| 382 |
+
if size > 1:
|
| 383 |
+
processes = []
|
| 384 |
+
for rank in range(size):
|
| 385 |
+
args.local_rank = rank
|
| 386 |
+
args.global_rank = rank + args.node_rank * args.num_process_per_node
|
| 387 |
+
p = Process(target=init_processes, args=(generate, args))
|
| 388 |
+
p.start()
|
| 389 |
+
processes.append(p)
|
| 390 |
+
|
| 391 |
+
for p in processes:
|
| 392 |
+
p.join()
|
| 393 |
+
else:
|
| 394 |
+
print('Single GPU run')
|
| 395 |
+
assert args.global_size == 1 and args.local_rank == 0
|
| 396 |
+
args.global_rank = 0
|
| 397 |
+
init_processes(generate, args)
|
scripts/download_assets.sh
ADDED
|
@@ -0,0 +1,8 @@
|
| 1 |
+
# download pretrained VAE
|
| 2 |
+
python3 download_assets.py --name vae --dest assets/stable_diffusion
|
| 3 |
+
|
| 4 |
+
# download ImageNet256 training set
|
| 5 |
+
python3 download_assets.py --name imagenet256-latent-lmdb --dest ../data/imagenet256
|
| 6 |
+
|
| 7 |
+
# download ImageNet512 training set
|
| 8 |
+
python3 download_assets.py --name imagenet512-latent-wds --dest ../data/imagenet512-wds
|
scripts/finetune_latent512.sh
ADDED
|
@@ -0,0 +1,14 @@
|
| 1 |
+
accelerate launch \
|
| 2 |
+
--main_process_ip $MASTER_ADDR \
|
| 3 |
+
--main_process_port $MASTER_PORT \
|
| 4 |
+
--num_machines 4 \
|
| 5 |
+
--machine_rank $NODE_RANK \
|
| 6 |
+
--num_processes 32 \
|
| 7 |
+
train_wds.py \
|
| 8 |
+
--config configs/finetune/imagenet512-latent.yaml \
|
| 9 |
+
--resample \
|
| 10 |
+
--ckpt_path checkpoints/1050000.pt \
|
| 11 |
+
--use_ckpt_path False --use_strict_load False \
|
| 12 |
+
--no_amp
|
| 13 |
+
|
| 14 |
+
|
scripts/prepare_latent256.sh
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
# Encode ImageNet 256x256 into latent space
|
| 2 |
+
|
| 3 |
+
python3 extract_latent.py --resolution 256 --ckpt assets/vae/autoencoder_kl.pth --batch_size 64 --outdir ../data/imagenet256-latent
|
scripts/prepare_latent512.sh
ADDED
|
@@ -0,0 +1,6 @@
|
| 1 |
+
# Encode ImageNet 512x512 into latent space
|
| 2 |
+
|
| 3 |
+
python3 extract_latent.py --resolution 512 --ckpt assets/vae/autoencoder_kl.pth --batch_size 64 --outdir ../data/imagenet512-latent
|
| 4 |
+
|
| 5 |
+
# Convert lmdb to webdataset
|
| 6 |
+
python3 lmdb2wds.py --maxcount 10010 --datadir ../data/imagenet512-latent --outdir ../data/imagenet512-latent-wds --resolution 64 --num_channels 8
|
scripts/train_latent512.sh
ADDED
|
@@ -0,0 +1,11 @@
|
| 1 |
+
accelerate launch \
|
| 2 |
+
--main_process_ip $MASTER_ADDR \
|
| 3 |
+
--main_process_port $MASTER_PORT \
|
| 4 |
+
--num_machines 4 \
|
| 5 |
+
--machine_rank $NODE_RANK \
|
| 6 |
+
--num_processes 32 \
|
| 7 |
+
train_wds.py \
|
| 8 |
+
--config configs/train/imagenet512-latent.yaml \
|
| 9 |
+
--resample
|
| 10 |
+
|
| 11 |
+
|
torch_utils/__init__.py
ADDED
|
File without changes
|
torch_utils/persistence.py
ADDED
|
@@ -0,0 +1,276 @@
|
| 1 |
+
# MIT License
|
| 2 |
+
|
| 3 |
+
# Copyright (c) [2023] [Anima-Lab]
|
| 4 |
+
|
| 5 |
+
# This code is adapted from https://github.com/NVlabs/edm/blob/main/torch_utils/persistence.py.
|
| 6 |
+
# The original code is licensed under a Creative Commons
|
| 7 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License, which can be found at licenses/LICENSE_EDM.txt.
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
"""Facilities for pickling Python code alongside other data.
|
| 11 |
+
|
| 12 |
+
The pickled code is automatically imported into a separate Python module
|
| 13 |
+
during unpickling. This way, any previously exported pickles will remain
|
| 14 |
+
usable even if the original code is no longer available, or if the current
|
| 15 |
+
version of the code is not consistent with what was originally pickled."""
|
| 16 |
+
|
| 17 |
+
import sys
|
| 18 |
+
import pickle
|
| 19 |
+
import io
|
| 20 |
+
import inspect
|
| 21 |
+
import copy
|
| 22 |
+
import uuid
|
| 23 |
+
import types
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
#----------------------------------------------------------------------------
|
| 27 |
+
|
| 28 |
+
class EasyDict(dict):
|
| 29 |
+
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
|
| 30 |
+
|
| 31 |
+
def __getattr__(self, name):
|
| 32 |
+
try:
|
| 33 |
+
return self[name]
|
| 34 |
+
except KeyError:
|
| 35 |
+
raise AttributeError(name)
|
| 36 |
+
|
| 37 |
+
def __setattr__(self, name, value):
|
| 38 |
+
self[name] = value
|
| 39 |
+
|
| 40 |
+
def __delattr__(self, name):
|
| 41 |
+
del self[name]
|
| 42 |
+
|
| 43 |
+
#----------------------------------------------------------------------------
|
| 44 |
+
|
| 45 |
+
_version = 6 # internal version number
|
| 46 |
+
_decorators = set() # {decorator_class, ...}
|
| 47 |
+
_import_hooks = [] # [hook_function, ...]
|
| 48 |
+
_module_to_src_dict = dict() # {module: src, ...}
|
| 49 |
+
_src_to_module_dict = dict() # {src: module, ...}
|
| 50 |
+
|
| 51 |
+
#----------------------------------------------------------------------------
|
| 52 |
+
|
| 53 |
+
def persistent_class(orig_class):
|
| 54 |
+
r"""Class decorator that extends a given class to save its source code
|
| 55 |
+
when pickled.
|
| 56 |
+
|
| 57 |
+
Example:
|
| 58 |
+
|
| 59 |
+
from torch_utils import persistence
|
| 60 |
+
|
| 61 |
+
@persistence.persistent_class
|
| 62 |
+
class MyNetwork(torch.nn.Module):
|
| 63 |
+
def __init__(self, num_inputs, num_outputs):
|
| 64 |
+
super().__init__()
|
| 65 |
+
self.fc = MyLayer(num_inputs, num_outputs)
|
| 66 |
+
...
|
| 67 |
+
|
| 68 |
+
@persistence.persistent_class
|
| 69 |
+
class MyLayer(torch.nn.Module):
|
| 70 |
+
...
|
| 71 |
+
|
| 72 |
+
When pickled, any instance of `MyNetwork` and `MyLayer` will save its
|
| 73 |
+
source code alongside other internal state (e.g., parameters, buffers,
|
| 74 |
+
and submodules). This way, any previously exported pickle will remain
|
| 75 |
+
usable even if the class definitions have been modified or are no
|
| 76 |
+
longer available.
|
| 77 |
+
|
| 78 |
+
The decorator saves the source code of the entire Python module
|
| 79 |
+
containing the decorated class. It does *not* save the source code of
|
| 80 |
+
any imported modules. Thus, the imported modules must be available
|
| 81 |
+
during unpickling, also including `torch_utils.persistence` itself.
|
| 82 |
+
|
| 83 |
+
It is ok to call functions defined in the same module from the
|
| 84 |
+
decorated class. However, if the decorated class depends on other
|
| 85 |
+
classes defined in the same module, they must be decorated as well.
|
| 86 |
+
This is illustrated in the above example in the case of `MyLayer`.
|
| 87 |
+
|
| 88 |
+
It is also possible to employ the decorator just-in-time before
|
| 89 |
+
calling the constructor. For example:
|
| 90 |
+
|
| 91 |
+
cls = MyLayer
|
| 92 |
+
if want_to_make_it_persistent:
|
| 93 |
+
cls = persistence.persistent_class(cls)
|
| 94 |
+
layer = cls(num_inputs, num_outputs)
|
| 95 |
+
|
| 96 |
+
As an additional feature, the decorator also keeps track of the
|
| 97 |
+
arguments that were used to construct each instance of the decorated
|
| 98 |
+
class. The arguments can be queried via `obj.init_args` and
|
| 99 |
+
`obj.init_kwargs`, and they are automatically pickled alongside other
|
| 100 |
+
object state. This feature can be disabled on a per-instance basis
|
| 101 |
+
by setting `self._record_init_args = False` in the constructor.
|
| 102 |
+
|
| 103 |
+
A typical use case is to first unpickle a previous instance of a
|
| 104 |
+
persistent class, and then upgrade it to use the latest version of
|
| 105 |
+
the source code:
|
| 106 |
+
|
| 107 |
+
with open('old_pickle.pkl', 'rb') as f:
|
| 108 |
+
old_net = pickle.load(f)
|
| 109 |
+
new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
|
| 110 |
+
misc.copy_params_and_buffers(old_net, new_net, require_all=True)
|
| 111 |
+
"""
|
| 112 |
+
assert isinstance(orig_class, type)
|
| 113 |
+
if is_persistent(orig_class):
|
| 114 |
+
return orig_class
|
| 115 |
+
|
| 116 |
+
assert orig_class.__module__ in sys.modules
|
| 117 |
+
orig_module = sys.modules[orig_class.__module__]
|
| 118 |
+
orig_module_src = _module_to_src(orig_module)
|
| 119 |
+
|
| 120 |
+
class Decorator(orig_class):
|
| 121 |
+
_orig_module_src = orig_module_src
|
| 122 |
+
_orig_class_name = orig_class.__name__
|
| 123 |
+
|
| 124 |
+
def __init__(self, *args, **kwargs):
|
| 125 |
+
super().__init__(*args, **kwargs)
|
| 126 |
+
record_init_args = getattr(self, '_record_init_args', True)
|
| 127 |
+
self._init_args = copy.deepcopy(args) if record_init_args else None
|
| 128 |
+
self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
|
| 129 |
+
assert orig_class.__name__ in orig_module.__dict__
|
| 130 |
+
_check_pickleable(self.__reduce__())
|
| 131 |
+
|
| 132 |
+
@property
|
| 133 |
+
def init_args(self):
|
| 134 |
+
assert self._init_args is not None
|
| 135 |
+
return copy.deepcopy(self._init_args)
|
| 136 |
+
|
| 137 |
+
@property
|
| 138 |
+
def init_kwargs(self):
|
| 139 |
+
assert self._init_kwargs is not None
|
| 140 |
+
return EasyDict(copy.deepcopy(self._init_kwargs))
|
| 141 |
+
|
| 142 |
+
def __reduce__(self):
|
| 143 |
+
fields = list(super().__reduce__())
|
| 144 |
+
fields += [None] * max(3 - len(fields), 0)
|
| 145 |
+
if fields[0] is not _reconstruct_persistent_obj:
|
| 146 |
+
meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
|
| 147 |
+
fields[0] = _reconstruct_persistent_obj # reconstruct func
|
| 148 |
+
fields[1] = (meta,) # reconstruct args
|
| 149 |
+
fields[2] = None # state dict
|
| 150 |
+
return tuple(fields)
|
| 151 |
+
|
| 152 |
+
Decorator.__name__ = orig_class.__name__
|
| 153 |
+
Decorator.__module__ = orig_class.__module__
|
| 154 |
+
_decorators.add(Decorator)
|
| 155 |
+
return Decorator
|
| 156 |
+
|
| 157 |
+
#----------------------------------------------------------------------------
|
| 158 |
+
|
| 159 |
+
def is_persistent(obj):
|
| 160 |
+
r"""Test whether the given object or class is persistent, i.e.,
|
| 161 |
+
whether it will save its source code when pickled.
|
| 162 |
+
"""
|
| 163 |
+
try:
|
| 164 |
+
if obj in _decorators:
|
| 165 |
+
return True
|
| 166 |
+
except TypeError:
|
| 167 |
+
pass
|
| 168 |
+
return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
|
| 169 |
+
|
| 170 |
+
#----------------------------------------------------------------------------
|
| 171 |
+
|
| 172 |
+
def import_hook(hook):
|
| 173 |
+
r"""Register an import hook that is called whenever a persistent object
|
| 174 |
+
is being unpickled. A typical use case is to patch the pickled source
|
| 175 |
+
code to avoid errors and inconsistencies when the API of some imported
|
| 176 |
+
module has changed.
|
| 177 |
+
|
| 178 |
+
The hook should have the following signature:
|
| 179 |
+
|
| 180 |
+
hook(meta) -> modified meta
|
| 181 |
+
|
| 182 |
+
`meta` is an instance of `EasyDict` with the following fields:
|
| 183 |
+
|
| 184 |
+
type: Type of the persistent object, e.g. `'class'`.
|
| 185 |
+
version: Internal version number of `torch_utils.persistence`.
|
| 186 |
+
module_src Original source code of the Python module.
|
| 187 |
+
class_name: Class name in the original Python module.
|
| 188 |
+
state: Internal state of the object.
|
| 189 |
+
|
| 190 |
+
Example:
|
| 191 |
+
|
| 192 |
+
@persistence.import_hook
|
| 193 |
+
def wreck_my_network(meta):
|
| 194 |
+
if meta.class_name == 'MyNetwork':
|
| 195 |
+
print('MyNetwork is being imported. I will wreck it!')
|
| 196 |
+
meta.module_src = meta.module_src.replace("True", "False")
|
| 197 |
+
return meta
|
| 198 |
+
"""
|
| 199 |
+
assert callable(hook)
|
| 200 |
+
_import_hooks.append(hook)
|
| 201 |
+
|
| 202 |
+
#----------------------------------------------------------------------------
|
| 203 |
+
|
| 204 |
+
def _reconstruct_persistent_obj(meta):
|
| 205 |
+
r"""Hook that is called internally by the `pickle` module to unpickle
|
| 206 |
+
a persistent object.
|
| 207 |
+
"""
|
| 208 |
+
meta = EasyDict(meta)
|
| 209 |
+
meta.state = EasyDict(meta.state)
|
| 210 |
+
for hook in _import_hooks:
|
| 211 |
+
meta = hook(meta)
|
| 212 |
+
assert meta is not None
|
| 213 |
+
|
| 214 |
+
assert meta.version == _version
|
| 215 |
+
module = _src_to_module(meta.module_src)
|
| 216 |
+
|
| 217 |
+
assert meta.type == 'class'
|
| 218 |
+
orig_class = module.__dict__[meta.class_name]
|
| 219 |
+
decorator_class = persistent_class(orig_class)
|
| 220 |
+
obj = decorator_class.__new__(decorator_class)
|
| 221 |
+
|
| 222 |
+
setstate = getattr(obj, '__setstate__', None)
|
| 223 |
+
if callable(setstate):
|
| 224 |
+
setstate(meta.state) # pylint: disable=not-callable
|
| 225 |
+
else:
|
| 226 |
+
obj.__dict__.update(meta.state)
|
| 227 |
+
return obj
|
| 228 |
+
|
| 229 |
+
#----------------------------------------------------------------------------
|
| 230 |
+
|
| 231 |
+
def _module_to_src(module):
|
| 232 |
+
r"""Query the source code of a given Python module.
|
| 233 |
+
"""
|
| 234 |
+
src = _module_to_src_dict.get(module, None)
|
| 235 |
+
if src is None:
|
| 236 |
+
src = inspect.getsource(module)
|
| 237 |
+
_module_to_src_dict[module] = src
|
| 238 |
+
_src_to_module_dict[src] = module
|
| 239 |
+
return src
|
| 240 |
+
|
| 241 |
+
def _src_to_module(src):
|
| 242 |
+
r"""Get or create a Python module for the given source code.
|
| 243 |
+
"""
|
| 244 |
+
module = _src_to_module_dict.get(src, None)
|
| 245 |
+
if module is None:
|
| 246 |
+
module_name = "_imported_module_" + uuid.uuid4().hex
|
| 247 |
+
module = types.ModuleType(module_name)
|
| 248 |
+
sys.modules[module_name] = module
|
| 249 |
+
_module_to_src_dict[module] = src
|
| 250 |
+
_src_to_module_dict[src] = module
|
| 251 |
+
exec(src, module.__dict__) # pylint: disable=exec-used
|
| 252 |
+
return module
|
| 253 |
+
|
| 254 |
+
#----------------------------------------------------------------------------
|
| 255 |
+
|
| 256 |
+
def _check_pickleable(obj):
|
| 257 |
+
r"""Check that the given object is pickleable, raising an exception if
|
| 258 |
+
it is not. This function is expected to be considerably more efficient
|
| 259 |
+
than actually pickling the object.
|
| 260 |
+
"""
|
| 261 |
+
def recurse(obj):
|
| 262 |
+
if isinstance(obj, (list, tuple, set)):
|
| 263 |
+
return [recurse(x) for x in obj]
|
| 264 |
+
if isinstance(obj, dict):
|
| 265 |
+
return [[recurse(x), recurse(y)] for x, y in obj.items()]
|
| 266 |
+
if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
|
| 267 |
+
return None # Python primitive types are pickleable.
|
| 268 |
+
if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
|
| 269 |
+
return None # NumPy arrays and PyTorch tensors are pickleable.
|
| 270 |
+
if is_persistent(obj):
|
| 271 |
+
return None # Persistent objects are pickleable, by virtue of the constructor check.
|
| 272 |
+
return obj
|
| 273 |
+
with io.BytesIO() as f:
|
| 274 |
+
pickle.dump(recurse(obj), f)
|
| 275 |
+
|
| 276 |
+
#----------------------------------------------------------------------------
|
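Editor's note: the snippet below is not part of the commit; it is a small, hedged round-trip example of the persistence.persistent_class decorator documented above. TinyNet and the file name are hypothetical, and the code must live in a regular .py module so that inspect can capture its source.

# Hypothetical round-trip example for torch_utils.persistence (save as, e.g., demo_persistence.py).
import pickle
import torch
from torch_utils import persistence

@persistence.persistent_class
class TinyNet(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.fc = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return self.fc(x)

if __name__ == '__main__':
    net = TinyNet(8)
    blob = pickle.dumps(net)        # the module's source code is pickled alongside the state
    restored = pickle.loads(blob)   # remains loadable even if TinyNet's definition later changes
    print(restored.init_args)       # (8,) -- constructor arguments are recorded automatically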
train.py
ADDED
|
@@ -0,0 +1,336 @@
|
| 1 |
+
# MIT License
|
| 2 |
+
|
| 3 |
+
# Copyright (c) [2023] [Anima-Lab]
|
| 4 |
+
'''
|
| 5 |
+
Training MaskDiT on a latent dataset in LMDB format. Used for experiments on ImageNet 256x256.
|
| 6 |
+
'''
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import os.path
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
from time import time
|
| 12 |
+
from omegaconf import OmegaConf
|
| 13 |
+
|
| 14 |
+
import apex
|
| 15 |
+
import torch
|
| 16 |
+
import accelerate
|
| 17 |
+
|
| 18 |
+
from torch.utils.data import DataLoader
|
| 19 |
+
|
| 20 |
+
from fid import calc
|
| 21 |
+
from models.maskdit import Precond_models
|
| 22 |
+
from train_utils.loss import Losses
|
| 23 |
+
from train_utils.datasets import ImageNetLatentDataset
|
| 24 |
+
|
| 25 |
+
from train_utils.helper import get_mask_ratio_fn, requires_grad, update_ema, unwrap_model
|
| 26 |
+
|
| 27 |
+
from sample import generate_with_net
|
| 28 |
+
from utils import dist, mprint, get_latest_ckpt, Logger, sample, \
|
| 29 |
+
str2bool, parse_str_none, parse_int_list, parse_float_none
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ------------------------------------------------------------
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def train_loop(args):
|
| 36 |
+
# load configuration
|
| 37 |
+
config = OmegaConf.load(args.config)
|
| 38 |
+
|
| 39 |
+
if not args.no_amp:
|
| 40 |
+
config.train.amp = 'fp16'
|
| 41 |
+
else:
|
| 42 |
+
config.train.amp = 'no'
|
| 43 |
+
|
| 44 |
+
if config.train.tf32:
|
| 45 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 46 |
+
torch.set_float32_matmul_precision('high')
|
| 47 |
+
|
| 48 |
+
accelerator = accelerate.Accelerator(mixed_precision=config.train.amp,
|
| 49 |
+
gradient_accumulation_steps=config.train.grad_accum,
|
| 50 |
+
log_with='wandb')
|
| 51 |
+
# setup wandb
|
| 52 |
+
if args.use_wandb:
|
| 53 |
+
wandb_init_kwargs = {
|
| 54 |
+
'entity': config.wandb.entity,
|
| 55 |
+
'project': config.wandb.project,
|
| 56 |
+
'group': config.wandb.group,
|
| 57 |
+
}
|
| 58 |
+
accelerator.init_trackers(config.wandb.project, config=OmegaConf.to_container(config), init_kwargs=wandb_init_kwargs)
|
| 59 |
+
|
| 60 |
+
mprint('start training...')
|
| 61 |
+
size = accelerator.num_processes
|
| 62 |
+
rank = accelerator.process_index
|
| 63 |
+
|
| 64 |
+
print(f'global_rank: {rank}, global_size: {size}')
|
| 65 |
+
device = accelerator.device
|
| 66 |
+
|
| 67 |
+
seed = args.global_seed
|
| 68 |
+
torch.manual_seed(seed)
|
| 69 |
+
|
| 70 |
+
mprint(f"enable_amp: {not args.no_amp}, TF32: {config.train.tf32}")
|
| 71 |
+
# Select batch size per GPU
|
| 72 |
+
num_accumulation_rounds = config.train.grad_accum
|
| 73 |
+
micro_batch = config.train.batchsize
|
| 74 |
+
batch_gpu_total = micro_batch * num_accumulation_rounds
|
| 75 |
+
global_batch_size = batch_gpu_total * size
|
| 76 |
+
mprint(f"Global batchsize: {global_batch_size}, batchsize per GPU: {batch_gpu_total}, micro_batch: {micro_batch}.")
|
| 77 |
+
|
| 78 |
+
class_dropout_prob = config.model.class_dropout_prob
|
| 79 |
+
log_every = config.log.log_every
|
| 80 |
+
ckpt_every = config.log.ckpt_every
|
| 81 |
+
|
| 82 |
+
mask_ratio_fn = get_mask_ratio_fn(config.model.mask_ratio_fn, config.model.mask_ratio, config.model.mask_ratio_min)
|
| 83 |
+
|
| 84 |
+
# Setup an experiment folder
|
| 85 |
+
model_name = config.model.model_type.replace("/", "-") # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders)
|
| 86 |
+
data_name = config.data.dataset
|
| 87 |
+
if args.ckpt_path is not None and args.use_ckpt_path: # use the existing exp path (mainly used for fine-tuning)
|
| 88 |
+
checkpoint_dir = os.path.dirname(args.ckpt_path)
|
| 89 |
+
experiment_dir = os.path.dirname(checkpoint_dir)
|
| 90 |
+
exp_name = os.path.basename(experiment_dir)
|
| 91 |
+
else: # start a new exp path (and resume from the latest checkpoint if possible)
|
| 92 |
+
cond_gen = 'cond' if config.model.num_classes else 'uncond'
|
| 93 |
+
exp_name = f'{model_name}-{config.model.precond}-{data_name}-{cond_gen}-m{config.model.mask_ratio}-de{int(config.model.use_decoder)}' \
|
| 94 |
+
f'-mae{config.model.mae_loss_coef}-bs-{global_batch_size}-lr{config.train.lr}{config.log.tag}'
|
| 95 |
+
experiment_dir = f"{args.results_dir}/{exp_name}"
|
| 96 |
+
checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
|
| 97 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 98 |
+
if args.ckpt_path is None:
|
| 99 |
+
args.ckpt_path = get_latest_ckpt(checkpoint_dir) # Resumes from the latest checkpoint if it exists
|
| 100 |
+
mprint(f"Experiment directory created at {experiment_dir}")
|
| 101 |
+
|
| 102 |
+
if accelerator.is_main_process:
|
| 103 |
+
logger = Logger(file_name=f'{experiment_dir}/log.txt', file_mode="a+", should_flush=True)
|
| 104 |
+
|
| 105 |
+
mprint(f"Experiment directory created at {experiment_dir}")
|
| 106 |
+
# Setup dataset
|
| 107 |
+
dataset = ImageNetLatentDataset(
|
| 108 |
+
config.data.root, resolution=config.data.resolution,
|
| 109 |
+
num_channels=config.data.num_channels, xflip=config.train.xflip,
|
| 110 |
+
feat_path=config.data.feat_path, feat_dim=config.model.ext_feature_dim)
|
| 111 |
+
|
| 112 |
+
loader = DataLoader(
|
| 113 |
+
dataset, batch_size=batch_gpu_total, shuffle=False,
|
| 114 |
+
num_workers=args.num_workers,
|
| 115 |
+
pin_memory=True, persistent_workers=True,
|
| 116 |
+
drop_last=True
|
| 117 |
+
)
|
| 118 |
+
mprint(f"Dataset contains {len(dataset):,} images ({config.data.root})")
|
| 119 |
+
|
| 120 |
+
steps_per_epoch = len(dataset) // global_batch_size
|
| 121 |
+
mprint(f"{steps_per_epoch} steps per epoch")
|
| 122 |
+
|
| 123 |
+
model = Precond_models[config.model.precond](
|
| 124 |
+
img_resolution=config.model.in_size,
|
| 125 |
+
img_channels=config.model.in_channels,
|
| 126 |
+
num_classes=config.model.num_classes,
|
| 127 |
+
model_type=config.model.model_type,
|
| 128 |
+
use_decoder=config.model.use_decoder,
|
| 129 |
+
mae_loss_coef=config.model.mae_loss_coef,
|
| 130 |
+
pad_cls_token=config.model.pad_cls_token
|
| 131 |
+
).to(device)
|
| 132 |
+
|
| 133 |
+
# Note that parameter initialization is done within the model constructor
|
| 134 |
+
ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
|
| 135 |
+
requires_grad(ema, False)
|
| 136 |
+
|
| 137 |
+
mprint(f"{config.model.model_type} ((use_decoder: {config.model.use_decoder})) Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
|
| 138 |
+
mprint(f'extras: {model.model.extras}, cls_token: {model.model.cls_token}')
|
| 139 |
+
|
| 140 |
+
# Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
|
| 141 |
+
optimizer = apex.optimizers.FusedAdam(model.parameters(), lr=config.train.lr, adam_w_mode=True, weight_decay=0)
|
| 142 |
+
|
| 143 |
+
# Load checkpoints
|
| 144 |
+
train_steps_start = 0
|
| 145 |
+
epoch_start = 0
|
| 146 |
+
|
| 147 |
+
if args.ckpt_path is not None:
|
| 148 |
+
ckpt = torch.load(args.ckpt_path, map_location=device)
|
| 149 |
+
model.load_state_dict(ckpt['model'], strict=args.use_strict_load)
|
| 150 |
+
ema.load_state_dict(ckpt['ema'], strict=args.use_strict_load)
|
| 151 |
+
mprint(f'Load weights from {args.ckpt_path}')
|
| 152 |
+
if args.use_strict_load:
|
| 153 |
+
optimizer.load_state_dict(ckpt['opt'])
|
| 154 |
+
for state in optimizer.state.values():
|
| 155 |
+
for k, v in state.items():
|
| 156 |
+
if isinstance(v, torch.Tensor):
|
| 157 |
+
state[k] = v.cuda()
|
| 158 |
+
mprint(f'Load optimizer state..')
|
| 159 |
+
train_steps_start = int(os.path.basename(args.ckpt_path).split('.pt')[0])
|
| 160 |
+
epoch_start = train_steps_start // steps_per_epoch
|
| 161 |
+
mprint(f"train_steps_start: {train_steps_start}")
|
| 162 |
+
del ckpt # conserve memory
|
| 163 |
+
|
| 164 |
+
# FID evaluation for the loaded weights
|
| 165 |
+
if args.enable_eval:
|
| 166 |
+
start_time = time()
|
| 167 |
+
args.outdir = os.path.join(experiment_dir, 'fid', f'edm-steps{args.num_steps}-ckpt{train_steps_start}_cfg{args.cfg_scale}')
|
| 168 |
+
os.makedirs(args.outdir, exist_ok=True)
|
| 169 |
+
generate_with_net(args, ema, device)
|
| 170 |
+
dist.barrier()
|
| 171 |
+
fid = calc(args.outdir, config.eval.ref_path, args.num_expected, args.global_seed, args.fid_batch_size)
|
| 172 |
+
mprint(f"time for fid calc: {time() - start_time}")
|
| 173 |
+
if args.use_wandb:
|
| 174 |
+
accelerator.log({f'eval/fid': fid}, step=train_steps_start)
|
| 175 |
+
mprint(f'guidance: {args.cfg_scale} FID: {fid}')
|
| 176 |
+
dist.barrier()
|
| 177 |
+
|
| 178 |
+
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
|
| 179 |
+
model = torch.compile(model)
|
| 180 |
+
|
| 181 |
+
# Setup loss
|
| 182 |
+
loss_fn = Losses[config.model.precond]()
|
| 183 |
+
|
| 184 |
+
# Prepare models for training:
|
| 185 |
+
if args.ckpt_path is None:
|
| 186 |
+
assert train_steps_start == 0
|
| 187 |
+
raw_model = unwrap_model(model)
|
| 188 |
+
update_ema(ema, raw_model, decay=0) # Ensure EMA is initialized with synced weights
|
| 189 |
+
model.train() # important! This enables embedding dropout for classifier-free guidance
|
| 190 |
+
ema.eval() # EMA model should always be in eval mode
|
| 191 |
+
|
| 192 |
+
# Variables for monitoring/logging purposes:
|
| 193 |
+
train_steps = train_steps_start
|
| 194 |
+
log_steps = 0
|
| 195 |
+
running_loss = 0
|
| 196 |
+
start_time = time()
|
| 197 |
+
mprint(f"Training for {config.train.epochs} epochs...")
|
| 198 |
+
for epoch in range(epoch_start, config.train.epochs):
|
| 199 |
+
mprint(f"Beginning epoch {epoch}...")
|
| 200 |
+
for x, cond in loader:
|
| 201 |
+
x = x.to(device)
|
| 202 |
+
y = cond.to(device)
|
| 203 |
+
x = sample(x)
|
| 204 |
+
# Accumulate gradients.
|
| 205 |
+
loss_batch = 0
|
| 206 |
+
model.zero_grad(set_to_none=True)
|
| 207 |
+
curr_mask_ratio = mask_ratio_fn((train_steps - train_steps_start) / config.train.max_num_steps)
|
| 208 |
+
if class_dropout_prob > 0:
|
| 209 |
+
y = y * (torch.rand([y.shape[0], 1], device=device) >= class_dropout_prob)
|
| 210 |
+
|
| 211 |
+
for round_idx in range(num_accumulation_rounds):
|
| 212 |
+
x_ = x[round_idx * micro_batch: (round_idx + 1) * micro_batch]
|
| 213 |
+
y_ = y[round_idx * micro_batch: (round_idx + 1) * micro_batch]
|
| 214 |
+
|
| 215 |
+
with accelerator.accumulate(model):
|
| 216 |
+
loss = loss_fn(net=model, images=x_, labels=y_,
|
| 217 |
+
mask_ratio=curr_mask_ratio,
|
| 218 |
+
mae_loss_coef=config.model.mae_loss_coef)
|
| 219 |
+
loss_mean = loss.mean()
|
| 220 |
+
accelerator.backward(loss_mean)
|
| 221 |
+
|
| 222 |
+
# Update weights with lr warmup.
|
| 223 |
+
lr_cur = config.train.lr * min(train_steps * global_batch_size / max(config.train.lr_rampup_kimg * 1000, 1e-8), 1)
|
| 224 |
+
for g in optimizer.param_groups:
|
| 225 |
+
g['lr'] = lr_cur
|
| 226 |
+
optimizer.step()
|
| 227 |
+
loss_batch += loss_mean.item()
|
| 228 |
+
|
| 229 |
+
raw_model = unwrap_model(model)
|
| 230 |
+
update_ema(ema, model.module)
|
| 231 |
+
|
| 232 |
+
# Log loss values:
|
| 233 |
+
running_loss += loss_batch
|
| 234 |
+
log_steps += 1
|
| 235 |
+
train_steps += 1
|
| 236 |
+
if train_steps > (train_steps_start + config.train.max_num_steps):
|
| 237 |
+
break
|
| 238 |
+
if train_steps % log_every == 0:
|
| 239 |
+
# Measure training speed:
|
| 240 |
+
torch.cuda.synchronize()
|
| 241 |
+
end_time = time()
|
| 242 |
+
steps_per_sec = log_steps / (end_time - start_time)
|
| 243 |
+
# Reduce loss history over all processes:
|
| 244 |
+
avg_loss = torch.tensor(running_loss / log_steps, device=device)
|
| 245 |
+
dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
|
| 246 |
+
avg_loss = avg_loss.item() / size
|
| 247 |
+
mprint(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
|
| 248 |
+
mprint(f'Peak GPU memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GB')
|
| 249 |
+
mprint(f'Reserved GPU memory: {torch.cuda.memory_reserved() / 1024 ** 3:.2f} GB')
|
| 250 |
+
|
| 251 |
+
if args.use_wandb:
|
| 252 |
+
accelerator.log({f'train/loss': avg_loss, 'train/lr': lr_cur}, step=train_steps)
|
| 253 |
+
# Reset monitoring variables:
|
| 254 |
+
running_loss = 0
|
| 255 |
+
log_steps = 0
|
| 256 |
+
start_time = time()
|
| 257 |
+
|
| 258 |
+
# Save checkpoint:
|
| 259 |
+
if train_steps % ckpt_every == 0 and train_steps > train_steps_start:
|
| 260 |
+
if rank == 0:
|
| 261 |
+
checkpoint = {
|
| 262 |
+
"model": raw_model.state_dict(),
|
| 263 |
+
"ema": ema.state_dict(),
|
| 264 |
+
"opt": optimizer.state_dict(),
|
| 265 |
+
"args": args
|
| 266 |
+
}
|
| 267 |
+
checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
|
| 268 |
+
torch.save(checkpoint, checkpoint_path)
|
| 269 |
+
mprint(f"Saved checkpoint to {checkpoint_path}")
|
| 270 |
+
del checkpoint # conserve memory
|
| 271 |
+
dist.barrier()
|
| 272 |
+
|
| 273 |
+
# FID evaluation during training
|
| 274 |
+
if args.enable_eval:
|
| 275 |
+
start_time = time()
|
| 276 |
+
args.outdir = os.path.join(experiment_dir, 'fid', f'edm-steps{args.num_steps}-ckpt{train_steps}_cfg{args.cfg_scale}')
|
| 277 |
+
os.makedirs(args.outdir, exist_ok=True)
|
| 278 |
+
generate_with_net(args, ema, device, rank, size)
|
| 279 |
+
|
| 280 |
+
dist.barrier()
|
| 281 |
+
fid = calc(args.outdir, args.ref_path, args.num_expected, args.global_seed, args.fid_batch_size)
|
| 282 |
+
mprint(f"time for fid calc: {time() - start_time}, fid: {fid}")
|
| 283 |
+
if args.use_wandb:
|
| 284 |
+
accelerator.log({f'eval/fid': fid}, step=train_steps)
|
| 285 |
+
mprint(f'Guidance: {args.cfg_scale}, FID: {fid}')
|
| 286 |
+
dist.barrier()
|
| 287 |
+
start_time = time()
|
| 288 |
+
|
| 289 |
+
if accelerator.is_main_process:
|
| 290 |
+
logger.close()
|
| 291 |
+
accelerator.end_training()
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
if __name__ == '__main__':
|
| 295 |
+
parser = argparse.ArgumentParser('training parameters')
|
| 296 |
+
# basic config
|
| 297 |
+
parser.add_argument('--config', type=str, required=True, help='path to config file')
|
| 298 |
+
|
| 299 |
+
# training
|
| 300 |
+
parser.add_argument("--results_dir", type=str, default="results")
|
| 301 |
+
parser.add_argument("--ckpt_path", type=parse_str_none, default=None)
|
| 302 |
+
|
| 303 |
+
parser.add_argument("--global_seed", type=int, default=0)
|
| 304 |
+
parser.add_argument("--num_workers", type=int, default=4)
|
| 305 |
+
parser.add_argument('--no_amp', action='store_true', help="Disable automatic mixed precision.")
|
| 306 |
+
|
| 307 |
+
parser.add_argument("--use_wandb", action='store_true', help='enable wandb logging')
|
| 308 |
+
parser.add_argument("--use_ckpt_path", type=str2bool, default=True)
|
| 309 |
+
parser.add_argument("--use_strict_load", type=str2bool, default=True)
|
| 310 |
+
parser.add_argument("--tag", type=str, default='')
|
| 311 |
+
|
| 312 |
+
# sampling
|
| 313 |
+
parser.add_argument('--enable_eval', action='store_true', help='enable fid calc during training')
|
| 314 |
+
parser.add_argument('--seeds', type=parse_int_list, default='0-49999', help='Random seeds (e.g. 1,2,5-10)')
|
| 315 |
+
parser.add_argument('--subdirs', action='store_true', help='Create subdirectory for every 1000 seeds')
|
| 316 |
+
parser.add_argument('--class_idx', type=int, default=None, help='Class label [default: random]')
|
| 317 |
+
parser.add_argument('--max_batch_size', type=int, default=50, help='Maximum batch size per GPU during sampling, must be a factor of 50k if torch.compile is used')
|
| 318 |
+
|
| 319 |
+
parser.add_argument("--cfg_scale", type=parse_float_none, default=None, help='None = no guidance, by default = 4.0')
|
| 320 |
+
|
| 321 |
+
parser.add_argument('--num_steps', type=int, default=40, help='Number of sampling steps')
|
| 322 |
+
parser.add_argument('--S_churn', type=int, default=0, help='Stochasticity strength')
|
| 323 |
+
parser.add_argument('--solver', type=str, default=None, choices=['euler', 'heun'], help='Ablate ODE solver')
|
| 324 |
+
parser.add_argument('--discretization', type=str, default=None, choices=['vp', 've', 'iddpm', 'edm'], help='Ablate time step discretization {t_i}')
|
| 325 |
+
parser.add_argument('--schedule', type=str, default=None, choices=['vp', 've', 'linear'], help='Ablate noise schedule sigma(t)')
|
| 326 |
+
parser.add_argument('--scaling', type=str, default=None, choices=['vp', 'none'], help='Ablate signal scaling s(t)')
|
| 327 |
+
parser.add_argument('--pretrained_path', type=str, default='assets/stable_diffusion/autoencoder_kl.pth', help='Autoencoder ckpt')
|
| 328 |
+
|
| 329 |
+
parser.add_argument('--ref_path', type=str, default='assets/fid_stats/fid_stats_imagenet256_guided_diffusion.npz', help='Dataset reference statistics')
|
| 330 |
+
parser.add_argument('--num_expected', type=int, default=50000, help='Number of images to use')
|
| 331 |
+
parser.add_argument('--fid_batch_size', type=int, default=64, help='Maximum batch size per GPU')
|
| 332 |
+
|
| 333 |
+
args = parser.parse_args()
|
| 334 |
+
|
| 335 |
+
torch.backends.cudnn.benchmark = True
|
| 336 |
+
train_loop(args)
|
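
train.py above ramps the learning rate linearly in the number of images seen (`train_steps * global_batch_size`) until `lr_rampup_kimg` thousand images have been processed, then holds it constant. A standalone sketch of that schedule (the batch size and ramp-up length below are made-up example values, not taken from the repo's configs):

def warmup_lr(base_lr, train_steps, global_batch_size, lr_rampup_kimg):
    # Same formula as in the training loop: ramp linearly in images seen, then hold flat.
    images_seen = train_steps * global_batch_size
    return base_lr * min(images_seen / max(lr_rampup_kimg * 1000, 1e-8), 1)

# Hypothetical numbers: base lr 1e-4, global batch 1024, 10k-image ramp-up.
for step in (0, 5, 10, 100):
    print(step, warmup_lr(1e-4, step, 1024, lr_rampup_kimg=10))
# step 0 -> 0.0, step 5 -> ~5.1e-5, step 10 and beyond -> 1e-4 (capped at base lr)
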
train_utils/datasets.py
ADDED
|
@@ -0,0 +1,412 @@
|
| 1 |
+
# MIT License
|
| 2 |
+
|
| 3 |
+
# Copyright (c) [2023] [Anima-Lab]
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
import io
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import zipfile
|
| 10 |
+
|
| 11 |
+
import lmdb
|
| 12 |
+
import numpy as np
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import torch
|
| 15 |
+
from torchvision.datasets import ImageFolder, VisionDataset
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def center_crop_arr(pil_image, image_size):
|
| 20 |
+
"""
|
| 21 |
+
Center cropping implementation from ADM.
|
| 22 |
+
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
|
| 23 |
+
"""
|
| 24 |
+
while min(*pil_image.size) >= 2 * image_size:
|
| 25 |
+
pil_image = pil_image.resize(
|
| 26 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
scale = image_size / min(*pil_image.size)
|
| 30 |
+
pil_image = pil_image.resize(
|
| 31 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
arr = np.array(pil_image)
|
| 35 |
+
crop_y = (arr.shape[0] - image_size) // 2
|
| 36 |
+
crop_x = (arr.shape[1] - image_size) // 2
|
| 37 |
+
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
|
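
A quick usage sketch of the crop above (the file name is a placeholder; any RGB image works):

from PIL import Image
from train_utils.datasets import center_crop_arr

img = Image.open("example.jpg").convert("RGB")    # hypothetical input image
crop = center_crop_arr(img, 256)                  # halve repeatedly, resize, then center-crop
assert crop.size == (256, 256)
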
| 38 |
+
|
| 39 |
+
|
| 40 |
+
################################################################################
|
| 41 |
+
# ImageNet - LMDB
|
| 42 |
+
###############################################################################
|
| 43 |
+
|
| 44 |
+
def lmdb_loader(path, lmdb_data, resolution):
|
| 45 |
+
# In-memory binary streams
|
| 46 |
+
with lmdb_data.begin(write=False, buffers=True) as txn:
|
| 47 |
+
bytedata = txn.get(path.encode('ascii'))
|
| 48 |
+
img = Image.open(io.BytesIO(bytedata)).convert('RGB')
|
| 49 |
+
arr = center_crop_arr(img, resolution)
|
| 50 |
+
# arr = arr.astype(np.float32) / 127.5 - 1
|
| 51 |
+
# arr = np.transpose(arr, [2, 0, 1]) # CHW
|
| 52 |
+
return arr
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def imagenet_lmdb_dataset(
|
| 56 |
+
root,
|
| 57 |
+
transform=None, target_transform=None,
|
| 58 |
+
resolution=256):
|
| 59 |
+
"""
|
| 60 |
+
You can create this dataloader using:
|
| 61 |
+
train_data = imagenet_lmdb_dataset(traindir, transform=train_transform)
|
| 62 |
+
valid_data = imagenet_lmdb_dataset(validdir, transform=val_transform)
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
if root.endswith('/'):
|
| 66 |
+
root = root[:-1]
|
| 67 |
+
pt_path = os.path.join(
|
| 68 |
+
root + '_faster_imagefolder.lmdb.pt')
|
| 69 |
+
lmdb_path = os.path.join(
|
| 70 |
+
root + '_faster_imagefolder.lmdb')
|
| 71 |
+
if os.path.isfile(pt_path) and os.path.isdir(lmdb_path):
|
| 72 |
+
print('Loading pt {} and lmdb {}'.format(pt_path, lmdb_path))
|
| 73 |
+
data_set = torch.load(pt_path)
|
| 74 |
+
else:
|
| 75 |
+
data_set = ImageFolder(
|
| 76 |
+
root, None, None, None)
|
| 77 |
+
torch.save(data_set, pt_path, pickle_protocol=4)
|
| 78 |
+
print('Saving pt to {}'.format(pt_path))
|
| 79 |
+
print('Building lmdb to {}'.format(lmdb_path))
|
| 80 |
+
env = lmdb.open(lmdb_path, map_size=1e12)
|
| 81 |
+
with env.begin(write=True) as txn:
|
| 82 |
+
for path, class_index in data_set.imgs:
|
| 83 |
+
with open(path, 'rb') as f:
|
| 84 |
+
data = f.read()
|
| 85 |
+
txn.put(path.encode('ascii'), data)
|
| 86 |
+
|
| 87 |
+
lmdb_dataset = ImageLMDB(lmdb_path, transform, target_transform, resolution, data_set.imgs, data_set.class_to_idx, data_set.classes)
|
| 88 |
+
return lmdb_dataset
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
################################################################################
|
| 92 |
+
# ImageNet Dataset class- LMDB
|
| 93 |
+
###############################################################################
|
| 94 |
+
|
| 95 |
+
class ImageLMDB(VisionDataset):
|
| 96 |
+
"""
|
| 97 |
+
A dataset that reads ImageNet from an LMDB database, which is faster to load than the original ImageFolder.
|
| 98 |
+
"""
|
| 99 |
+
def __init__(self, root, transform=None, target_transform=None,
|
| 100 |
+
resolution=256, samples=None, class_to_idx=None, classes=None):
|
| 101 |
+
super().__init__(root, transform=transform,
|
| 102 |
+
target_transform=target_transform)
|
| 103 |
+
self.root = root
|
| 104 |
+
self.resolution = resolution
|
| 105 |
+
self.samples = samples
|
| 106 |
+
self.class_to_idx = class_to_idx
|
| 107 |
+
self.classes = classes
|
| 108 |
+
|
| 109 |
+
def __getitem__(self, index: int):
|
| 110 |
+
path, target = self.samples[index]
|
| 111 |
+
|
| 112 |
+
# load image from path
|
| 113 |
+
if not hasattr(self, 'txn'):
|
| 114 |
+
self.open_db()
|
| 115 |
+
bytedata = self.txn.get(path.encode('ascii'))
|
| 116 |
+
img = Image.open(io.BytesIO(bytedata)).convert('RGB')
|
| 117 |
+
arr = center_crop_arr(img, self.resolution)
|
| 118 |
+
if self.transform is not None:
|
| 119 |
+
arr = self.transform(arr)
|
| 120 |
+
if self.target_transform is not None:
|
| 121 |
+
target = self.target_transform(target)
|
| 122 |
+
return arr, target
|
| 123 |
+
|
| 124 |
+
def __len__(self) -> int:
|
| 125 |
+
return len(self.samples)
|
| 126 |
+
|
| 127 |
+
def open_db(self):
|
| 128 |
+
self.env = lmdb.open(self.root, readonly=True, max_readers=256, lock=False, readahead=False, meminit=False)
|
| 129 |
+
self.txn = self.env.begin(write=False, buffers=True)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
################################################################################
|
| 134 |
+
# ImageNet - LMDB - latent space
|
| 135 |
+
###############################################################################
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# ----------------------------------------------------------------------------
|
| 140 |
+
# Abstract base class for datasets.
|
| 141 |
+
|
| 142 |
+
class Dataset(torch.utils.data.Dataset):
|
| 143 |
+
def __init__(self,
|
| 144 |
+
name, # Name of the dataset.
|
| 145 |
+
raw_shape, # Shape of the raw image data (NCHW).
|
| 146 |
+
max_size=None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
|
| 147 |
+
label_dim=1000, # Ensure specific number of classes
|
| 148 |
+
xflip=False, # Artificially double the size of the dataset via x-flips. Applied after max_size.
|
| 149 |
+
random_seed=0, # Random seed to use when applying max_size.
|
| 150 |
+
):
|
| 151 |
+
self._name = name
|
| 152 |
+
self._raw_shape = list(raw_shape)
|
| 153 |
+
self._label_dim = label_dim
|
| 154 |
+
self._label_shape = None
|
| 155 |
+
|
| 156 |
+
# Apply max_size.
|
| 157 |
+
self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
|
| 158 |
+
if (max_size is not None) and (self._raw_idx.size > max_size):
|
| 159 |
+
np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx)
|
| 160 |
+
self._raw_idx = np.sort(self._raw_idx[:max_size])
|
| 161 |
+
|
| 162 |
+
# Apply xflip. (Assume the dataset already contains the same number of xflipped samples)
|
| 163 |
+
if xflip:
|
| 164 |
+
self._raw_idx = np.concatenate([self._raw_idx, self._raw_idx + self._raw_shape[0]])
|
| 165 |
+
|
| 166 |
+
def close(self): # to be overridden by subclass
|
| 167 |
+
pass
|
| 168 |
+
|
| 169 |
+
def _load_raw_data(self, raw_idx): # to be overridden by subclass
|
| 170 |
+
raise NotImplementedError
|
| 171 |
+
|
| 172 |
+
def __getstate__(self):
|
| 173 |
+
return dict(self.__dict__, _raw_labels=None)
|
| 174 |
+
|
| 175 |
+
def __del__(self):
|
| 176 |
+
try:
|
| 177 |
+
self.close()
|
| 178 |
+
except:
|
| 179 |
+
pass
|
| 180 |
+
|
| 181 |
+
def __len__(self):
|
| 182 |
+
return self._raw_idx.size
|
| 183 |
+
|
| 184 |
+
def __getitem__(self, idx):
|
| 185 |
+
raw_idx = self._raw_idx[idx]
|
| 186 |
+
image, cond = self._load_raw_data(raw_idx)
|
| 187 |
+
assert isinstance(image, np.ndarray)
|
| 188 |
+
if isinstance(cond, list): # [label, feature]
|
| 189 |
+
cond[0] = self._get_onehot(cond[0])
|
| 190 |
+
else: # label
|
| 191 |
+
cond = self._get_onehot(cond)
|
| 192 |
+
return image.copy(), cond
|
| 193 |
+
|
| 194 |
+
def _get_onehot(self, label):
|
| 195 |
+
if isinstance(label, int) or label.dtype == np.int64:
|
| 196 |
+
onehot = np.zeros(self.label_shape, dtype=np.float32)
|
| 197 |
+
onehot[label] = 1
|
| 198 |
+
label = onehot
|
| 199 |
+
assert isinstance(label, np.ndarray)
|
| 200 |
+
return label.copy()
|
| 201 |
+
|
| 202 |
+
@property
|
| 203 |
+
def name(self):
|
| 204 |
+
return self._name
|
| 205 |
+
|
| 206 |
+
@property
|
| 207 |
+
def image_shape(self):
|
| 208 |
+
return list(self._raw_shape[1:])
|
| 209 |
+
|
| 210 |
+
@property
|
| 211 |
+
def num_channels(self):
|
| 212 |
+
assert len(self.image_shape) == 3 # CHW
|
| 213 |
+
return self.image_shape[0]
|
| 214 |
+
|
| 215 |
+
@property
|
| 216 |
+
def resolution(self):
|
| 217 |
+
assert len(self.image_shape) == 3 # CHW
|
| 218 |
+
assert self.image_shape[1] == self.image_shape[2]
|
| 219 |
+
return self.image_shape[1]
|
| 220 |
+
|
| 221 |
+
@property
|
| 222 |
+
def label_shape(self):
|
| 223 |
+
if self._label_shape is None:
|
| 224 |
+
self._label_shape = [self._label_dim]
|
| 225 |
+
return list(self._label_shape)
|
| 226 |
+
|
| 227 |
+
@property
|
| 228 |
+
def label_dim(self):
|
| 229 |
+
assert len(self.label_shape) == 1
|
| 230 |
+
return self.label_shape[0]
|
| 231 |
+
|
| 232 |
+
@property
|
| 233 |
+
def has_labels(self):
|
| 234 |
+
return True
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
# ----------------------------------------------------------------------------
|
| 238 |
+
# Dataset subclass that loads latent images from the specified LMDB database.
|
| 239 |
+
|
| 240 |
+
class ImageNetLatentDataset(Dataset):
|
| 241 |
+
def __init__(self,
|
| 242 |
+
path, # Path to directory or zip.
|
| 243 |
+
resolution=32, # Ensure specific resolution, default 32.
|
| 244 |
+
num_channels=4, # Ensure specific number of channels, default 4.
|
| 245 |
+
split='train', # train or val split
|
| 246 |
+
feat_path=None, # Path to features lmdb file (only works when feat_cond=True)
|
| 247 |
+
feat_dim=0, # feature dim
|
| 248 |
+
**super_kwargs, # Additional arguments for the Dataset base class.
|
| 249 |
+
):
|
| 250 |
+
self._path = os.path.join(path, split)
|
| 251 |
+
self.feat_dim = feat_dim
|
| 252 |
+
if not hasattr(self, 'txn'):
|
| 253 |
+
self.open_lmdb()
|
| 254 |
+
self.feat_txn = None
|
| 255 |
+
if feat_path is not None and os.path.isdir(feat_path):
|
| 256 |
+
assert self.feat_dim > 0
|
| 257 |
+
self._feat_path = os.path.join(feat_path, split)
|
| 258 |
+
self.open_feat_lmdb()
|
| 259 |
+
|
| 260 |
+
length = int(self.txn.get('length'.encode('utf-8')).decode('utf-8'))
|
| 261 |
+
name = os.path.basename(path)
|
| 262 |
+
raw_shape = [length, num_channels, resolution, resolution] # 1281167 x 4 x 32 x 32
|
| 263 |
+
if raw_shape[2] != resolution or raw_shape[3] != resolution:
|
| 264 |
+
raise IOError('Image files do not match the specified resolution')
|
| 265 |
+
|
| 266 |
+
super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
|
| 267 |
+
|
| 268 |
+
def open_lmdb(self):
|
| 269 |
+
self.env = lmdb.open(self._path, readonly=True, lock=False, create=False)
|
| 270 |
+
self.txn = self.env.begin(write=False)
|
| 271 |
+
|
| 272 |
+
def open_feat_lmdb(self):
|
| 273 |
+
self.feat_env = lmdb.open(self._feat_path, readonly=True, lock=False, create=False)
|
| 274 |
+
self.feat_txn = self.feat_env.begin(write=False)
|
| 275 |
+
|
| 276 |
+
def _load_raw_data(self, idx):
|
| 277 |
+
if not hasattr(self, 'txn'):
|
| 278 |
+
self.open_lmdb()
|
| 279 |
+
|
| 280 |
+
z_bytes = self.txn.get(f'z-{str(idx)}'.encode('utf-8'))
|
| 281 |
+
y_bytes = self.txn.get(f'y-{str(idx)}'.encode('utf-8'))
|
| 282 |
+
z = np.frombuffer(z_bytes, dtype=np.float32).reshape([-1, self.resolution, self.resolution]).copy()
|
| 283 |
+
y = int(y_bytes.decode('utf-8'))
|
| 284 |
+
|
| 285 |
+
cond = y
|
| 286 |
+
if self.feat_txn is not None:
|
| 287 |
+
feat_bytes = self.feat_txn.get(f'feat-{str(idx)}'.encode('utf-8'))
|
| 288 |
+
feat_y_bytes = self.feat_txn.get(f'y-{str(idx)}'.encode('utf-8'))
|
| 289 |
+
feat = np.frombuffer(feat_bytes, dtype=np.float32).reshape([self.feat_dim]).copy()
|
| 290 |
+
feat_y = int(feat_y_bytes.decode('utf-8'))
|
| 291 |
+
assert y == feat_y, 'Ordering mismatch between txn and feat_txn!'
|
| 292 |
+
cond = [y, feat]
|
| 293 |
+
|
| 294 |
+
return z, cond
|
| 295 |
+
|
| 296 |
+
def close(self):
|
| 297 |
+
try:
|
| 298 |
+
if self.env is not None:
|
| 299 |
+
self.env.close()
|
| 300 |
+
if self.feat_env is not None:
|
| 301 |
+
self.feat_env.close()
|
| 302 |
+
finally:
|
| 303 |
+
self.env = None
|
| 304 |
+
self.feat_env = None
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
# ----------------------------------------------------------------------------
|
| 308 |
+
# Dataset subclass that loads images recursively from the specified directory or zip file.
|
| 309 |
+
|
| 310 |
+
class ImageFolderDataset(Dataset):
|
| 311 |
+
def __init__(self,
|
| 312 |
+
path, # Path to directory or zip.
|
| 313 |
+
resolution=None, # Ensure specific resolution, None = highest available.
|
| 314 |
+
use_labels=False, # Enable conditioning labels? False = label dimension is zero.
|
| 315 |
+
**super_kwargs, # Additional arguments for the Dataset base class.
|
| 316 |
+
):
|
| 317 |
+
self._path = path
|
| 318 |
+
self._zipfile = None
|
| 319 |
+
self._raw_labels = None
|
| 320 |
+
self._use_labels = use_labels
|
| 321 |
+
|
| 322 |
+
if os.path.isdir(self._path):
|
| 323 |
+
self._type = 'dir'
|
| 324 |
+
self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in
|
| 325 |
+
os.walk(self._path) for fname in files}
|
| 326 |
+
elif self._file_ext(self._path) == '.zip':
|
| 327 |
+
self._type = 'zip'
|
| 328 |
+
self._all_fnames = set(self._get_zipfile().namelist())
|
| 329 |
+
else:
|
| 330 |
+
raise IOError('Path must point to a directory or zip')
|
| 331 |
+
|
| 332 |
+
Image.init()
|
| 333 |
+
self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in Image.EXTENSION)
|
| 334 |
+
if len(self._image_fnames) == 0:
|
| 335 |
+
raise IOError('No image files found in the specified path')
|
| 336 |
+
|
| 337 |
+
name = os.path.splitext(os.path.basename(self._path))[0]
|
| 338 |
+
raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
|
| 339 |
+
if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
|
| 340 |
+
raise IOError('Image files do not match the specified resolution')
|
| 341 |
+
super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
|
| 342 |
+
|
| 343 |
+
@staticmethod
|
| 344 |
+
def _file_ext(fname):
|
| 345 |
+
return os.path.splitext(fname)[1].lower()
|
| 346 |
+
|
| 347 |
+
def _get_zipfile(self):
|
| 348 |
+
assert self._type == 'zip'
|
| 349 |
+
if self._zipfile is None:
|
| 350 |
+
self._zipfile = zipfile.ZipFile(self._path)
|
| 351 |
+
return self._zipfile
|
| 352 |
+
|
| 353 |
+
def _open_file(self, fname):
|
| 354 |
+
if self._type == 'dir':
|
| 355 |
+
return open(os.path.join(self._path, fname), 'rb')
|
| 356 |
+
if self._type == 'zip':
|
| 357 |
+
return self._get_zipfile().open(fname, 'r')
|
| 358 |
+
return None
|
| 359 |
+
|
| 360 |
+
def close(self):
|
| 361 |
+
try:
|
| 362 |
+
if self._zipfile is not None:
|
| 363 |
+
self._zipfile.close()
|
| 364 |
+
finally:
|
| 365 |
+
self._zipfile = None
|
| 366 |
+
|
| 367 |
+
def __getstate__(self):
|
| 368 |
+
return dict(super().__getstate__(), _zipfile=None)
|
| 369 |
+
|
| 370 |
+
def _load_raw_data(self, raw_idx):
|
| 371 |
+
image = self._load_raw_image(raw_idx)
|
| 372 |
+
assert image.dtype == np.uint8
|
| 373 |
+
label = self._get_raw_labels()[raw_idx]
|
| 374 |
+
return image, label
|
| 375 |
+
|
| 376 |
+
def _load_raw_image(self, raw_idx):
|
| 377 |
+
fname = self._image_fnames[raw_idx]
|
| 378 |
+
with self._open_file(fname) as f:
|
| 379 |
+
image = np.array(Image.open(f))
|
| 380 |
+
if image.ndim == 2:
|
| 381 |
+
image = image[:, :, np.newaxis] # HW => HWC
|
| 382 |
+
image = image.transpose(2, 0, 1) # HWC => CHW
|
| 383 |
+
return image
|
| 384 |
+
|
| 385 |
+
def _get_raw_labels(self):
|
| 386 |
+
if self._raw_labels is None:
|
| 387 |
+
self._raw_labels = self._load_raw_labels() if self._use_labels else None
|
| 388 |
+
if self._raw_labels is None:
|
| 389 |
+
self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
|
| 390 |
+
assert isinstance(self._raw_labels, np.ndarray)
|
| 391 |
+
assert self._raw_labels.shape[0] == self._raw_shape[0]
|
| 392 |
+
assert self._raw_labels.dtype in [np.float32, np.int64]
|
| 393 |
+
if self._raw_labels.dtype == np.int64:
|
| 394 |
+
assert self._raw_labels.ndim == 1
|
| 395 |
+
assert np.all(self._raw_labels >= 0)
|
| 396 |
+
return self._raw_labels
|
| 397 |
+
|
| 398 |
+
def _load_raw_labels(self):
|
| 399 |
+
fname = 'dataset.json'
|
| 400 |
+
if fname not in self._all_fnames:
|
| 401 |
+
return None
|
| 402 |
+
with self._open_file(fname) as f:
|
| 403 |
+
labels = json.load(f)['labels']
|
| 404 |
+
if labels is None:
|
| 405 |
+
return None
|
| 406 |
+
labels = dict(labels)
|
| 407 |
+
labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
|
| 408 |
+
labels = np.array(labels)
|
| 409 |
+
labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
|
| 410 |
+
return labels
|
| 411 |
+
|
| 412 |
+
# ----------------------------------------------------------------------------
|
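
ImageNetLatentDataset above expects an LMDB with a `length` entry plus `z-{i}` keys holding float32 latents and `y-{i}` keys holding the class label as text (the repo's extract_latent.py produces these). A hypothetical writer for a compatible database, with made-up paths and shapes, could look like:

import os
import lmdb
import numpy as np

def write_latent_lmdb(out_path, latents, labels):
    # latents: float32 array of shape (N, 4, 32, 32); labels: iterable of ints.
    os.makedirs(out_path, exist_ok=True)
    env = lmdb.open(out_path, map_size=1 << 40)
    with env.begin(write=True) as txn:
        for i, (z, y) in enumerate(zip(latents, labels)):
            txn.put(f'z-{i}'.encode('utf-8'), np.ascontiguousarray(z, dtype=np.float32).tobytes())
            txn.put(f'y-{i}'.encode('utf-8'), str(int(y)).encode('utf-8'))
        txn.put('length'.encode('utf-8'), str(len(latents)).encode('utf-8'))
    env.close()

# Toy example: 8 random latents with random labels under <root>/train.
write_latent_lmdb('latents_lmdb/train',
                  np.random.randn(8, 4, 32, 32).astype(np.float32),
                  np.random.randint(0, 1000, size=8))
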
train_utils/helper.py
ADDED
|
@@ -0,0 +1,69 @@
|
| 1 |
+
# MIT License
|
| 2 |
+
|
| 3 |
+
# Copyright (c) [2023] [Anima-Lab]
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_mask_ratio_fn(name='constant', ratio_scale=0.5, ratio_min=0.0):
|
| 10 |
+
if name == 'cosine2':
|
| 11 |
+
return lambda x: (ratio_scale - ratio_min) * np.cos(np.pi * x / 2) ** 2 + ratio_min
|
| 12 |
+
elif name == 'cosine3':
|
| 13 |
+
return lambda x: (ratio_scale - ratio_min) * np.cos(np.pi * x / 2) ** 3 + ratio_min
|
| 14 |
+
elif name == 'cosine4':
|
| 15 |
+
return lambda x: (ratio_scale - ratio_min) * np.cos(np.pi * x / 2) ** 4 + ratio_min
|
| 16 |
+
elif name == 'cosine5':
|
| 17 |
+
return lambda x: (ratio_scale - ratio_min) * np.cos(np.pi * x / 2) ** 5 + ratio_min
|
| 18 |
+
elif name == 'cosine6':
|
| 19 |
+
return lambda x: (ratio_scale - ratio_min) * np.cos(np.pi * x / 2) ** 6 + ratio_min
|
| 20 |
+
elif name == 'exp':
|
| 21 |
+
return lambda x: (ratio_scale - ratio_min) * np.exp(-x * 7) + ratio_min
|
| 22 |
+
elif name == 'linear':
|
| 23 |
+
return lambda x: (ratio_scale - ratio_min) * x + ratio_min
|
| 24 |
+
elif name == 'constant':
|
| 25 |
+
return lambda x: ratio_scale
|
| 26 |
+
else:
|
| 27 |
+
raise ValueError('Unknown mask ratio function: {}'.format(name))
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_one_hot(labels, num_classes=1000):
|
| 31 |
+
one_hot = torch.zeros(labels.shape[0], num_classes, device=labels.device)
|
| 32 |
+
one_hot.scatter_(1, labels.view(-1, 1), 1)
|
| 33 |
+
return one_hot
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def requires_grad(model, flag=True):
|
| 37 |
+
"""
|
| 38 |
+
Set requires_grad flag for all parameters in a model.
|
| 39 |
+
"""
|
| 40 |
+
for p in model.parameters():
|
| 41 |
+
p.requires_grad = flag
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ------------------------------------------------------------
|
| 45 |
+
# Training Helper Function
|
| 46 |
+
|
| 47 |
+
@torch.no_grad()
|
| 48 |
+
def update_ema(ema_model, model, decay=0.9999):
|
| 49 |
+
"""
|
| 50 |
+
Step the EMA model towards the current model.
|
| 51 |
+
"""
|
| 52 |
+
ema_params = OrderedDict(ema_model.named_parameters())
|
| 53 |
+
model_params = OrderedDict(model.named_parameters())
|
| 54 |
+
|
| 55 |
+
for name, param in model_params.items():
|
| 56 |
+
if param.requires_grad:
|
| 57 |
+
ema_name = name.replace('_orig_mod.', '')
|
| 58 |
+
ema_params[ema_name].mul_(decay).add_(param.data, alpha=1 - decay)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def unwrap_model(model):
|
| 62 |
+
"""
|
| 63 |
+
Unwrap a model from any distributed or compiled wrappers.
|
| 64 |
+
"""
|
| 65 |
+
if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
|
| 66 |
+
model = model._orig_mod
|
| 67 |
+
if isinstance(model, (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)):
|
| 68 |
+
model = model.module
|
| 69 |
+
return model
|
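
get_mask_ratio_fn above maps training progress in [0, 1] to a mask ratio; the cosine variants start at `ratio_scale` and decay toward `ratio_min`, while 'constant' ignores progress. A quick check of two schedules (arguments chosen only for illustration; assumes the repo root is on the Python path):

from train_utils.helper import get_mask_ratio_fn

fn_const = get_mask_ratio_fn('constant', ratio_scale=0.5)
fn_cos3 = get_mask_ratio_fn('cosine3', ratio_scale=0.7, ratio_min=0.5)

for progress in (0.0, 0.5, 1.0):
    print(progress, fn_const(progress), round(fn_cos3(progress), 3))
# 'constant' stays at 0.5; 'cosine3' decays 0.7 -> ~0.571 -> 0.5 over training
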
train_utils/loss.py
ADDED
|
@@ -0,0 +1,101 @@
|
| 1 |
+
# MIT License
|
| 2 |
+
|
| 3 |
+
# Copyright (c) [2023] [Anima-Lab]
|
| 4 |
+
|
| 5 |
+
# This code is adapted from https://github.com/NVlabs/edm/blob/main/training/loss.py.
|
| 6 |
+
# The original code is licensed under a Creative Commons
|
| 7 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International License, which can be found at licenses/LICENSE_EDM.txt.
|
| 8 |
+
|
| 9 |
+
"""Loss functions used in the paper
|
| 10 |
+
"Elucidating the Design Space of Diffusion-Based Generative Models"."""
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
from utils import *
|
| 16 |
+
from train_utils.helper import unwrap_model
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Improved loss function proposed in the paper "Elucidating the Design Space
|
| 20 |
+
# of Diffusion-Based Generative Models" (EDM).
|
| 21 |
+
|
| 22 |
+
class EDMLoss:
|
| 23 |
+
def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
|
| 24 |
+
self.P_mean = P_mean
|
| 25 |
+
self.P_std = P_std
|
| 26 |
+
self.sigma_data = sigma_data
|
| 27 |
+
|
| 28 |
+
def __call__(self, net,
|
| 29 |
+
images,
|
| 30 |
+
labels=None,
|
| 31 |
+
mask_ratio=0,
|
| 32 |
+
mae_loss_coef=0,
|
| 33 |
+
feat=None, augment_pipe=None):
|
| 34 |
+
# sample x_t
|
| 35 |
+
rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device)
|
| 36 |
+
sigma = (rnd_normal * self.P_std + self.P_mean).exp()
|
| 37 |
+
weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
|
| 38 |
+
y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None)
|
| 39 |
+
n = torch.randn_like(y) * sigma
|
| 40 |
+
|
| 41 |
+
model_out = net(y + n, sigma, labels, mask_ratio=mask_ratio, mask_dict=None, feat=feat)
|
| 42 |
+
D_yn = model_out['x']
|
| 43 |
+
assert D_yn.shape == y.shape
|
| 44 |
+
loss = weight * ((D_yn - y) ** 2) # (N, C, H, W)
|
| 45 |
+
if mask_ratio > 0:
|
| 46 |
+
assert net.training and 'mask' in model_out
|
| 47 |
+
loss = F.avg_pool2d(loss.mean(dim=1), net.module.model.patch_size).flatten(1) # (N, L)
|
| 48 |
+
unmask = 1 - model_out['mask']
|
| 49 |
+
loss = (loss * unmask).sum(dim=1) / unmask.sum(dim=1) # (N)
|
| 50 |
+
assert loss.ndim == 1
|
| 51 |
+
if mae_loss_coef > 0:
|
| 52 |
+
loss += mae_loss_coef * mae_loss(net.module, y + n, D_yn, 1 - unmask)
|
| 53 |
+
else:
|
| 54 |
+
loss = mean_flat(loss) # (N)
|
| 55 |
+
|
| 56 |
+
raw_net = unwrap_model(net)
|
| 57 |
+
if mask_ratio == 0.0 and raw_net.model.mask_token is not None:
|
| 58 |
+
loss += 0 * torch.sum(raw_net.model.mask_token)
|
| 59 |
+
assert loss.ndim == 1
|
| 60 |
+
return loss
|
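
EDMLoss draws noise levels log-normally, sigma = exp(P_mean + P_std * n) with n ~ N(0, 1), and weights each sample by (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2, so low-noise samples receive larger weights. A standalone sketch of just that sampling and weighting, using the same default constants as above:

import torch

P_mean, P_std, sigma_data = -1.2, 1.2, 0.5

rnd_normal = torch.randn(4, 1, 1, 1)
sigma = (rnd_normal * P_std + P_mean).exp()                           # log-normal noise levels
weight = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2   # larger for small sigma
print(sigma.flatten().tolist(), weight.flatten().tolist())
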
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ----------------------------------------------------------------------------
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
Losses = {
|
| 67 |
+
'edm': EDMLoss
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ----------------------------------------------------------------------------
|
| 72 |
+
|
| 73 |
+
def patchify(imgs, patch_size=2, num_channels=4):
|
| 74 |
+
"""
|
| 75 |
+
imgs: (N, 3, H, W)
|
| 76 |
+
x: (N, L, patch_size**2 *3)
|
| 77 |
+
"""
|
| 78 |
+
p, c = patch_size, num_channels
|
| 79 |
+
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
|
| 80 |
+
|
| 81 |
+
h = w = imgs.shape[2] // p
|
| 82 |
+
x = imgs.reshape(shape=(imgs.shape[0], c, h, p, w, p))
|
| 83 |
+
x = torch.einsum('nchpwq->nhwpqc', x)
|
| 84 |
+
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * c))
|
| 85 |
+
return x
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def mae_loss(net, target, pred, mask, norm_pix_loss=True):
|
| 89 |
+
target = patchify(target, net.model.patch_size, net.model.out_channels)
|
| 90 |
+
pred = patchify(pred, net.model.patch_size, net.model.out_channels)
|
| 91 |
+
if norm_pix_loss:
|
| 92 |
+
mean = target.mean(dim=-1, keepdim=True)
|
| 93 |
+
var = target.var(dim=-1, keepdim=True)
|
| 94 |
+
target = (target - mean) / (var + 1.e-6)**.5
|
| 95 |
+
|
| 96 |
+
loss = (pred - target) ** 2
|
| 97 |
+
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
|
| 98 |
+
|
| 99 |
+
loss = (loss * mask).sum(dim=1) / mask.sum(dim=1) # mean loss on removed patches, (N)
|
| 100 |
+
assert loss.ndim == 1
|
| 101 |
+
return loss
|
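
patchify reshapes (N, C, H, W) latents into (N, L, p*p*C) patch tokens with L = (H/p) * (W/p); mae_loss then averages the squared error per patch and keeps only the masked patches. A quick shape check with a random tensor (default patch size 2 and 4 latent channels; assumes the repo's dependencies are installed):

import torch
from train_utils.loss import patchify

x = torch.randn(2, 4, 32, 32)                      # (N, C, H, W) latent batch
tokens = patchify(x, patch_size=2, num_channels=4)
assert tokens.shape == (2, 16 * 16, 2 * 2 * 4)     # (N, L, p*p*C) with L = (32/2)**2
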
train_wds.py
ADDED
|
@@ -0,0 +1,400 @@
|
| 1 |
+
# MIT License
|
| 2 |
+
|
| 3 |
+
# Copyright (c) [2023] [Anima-Lab]
|
| 4 |
+
'''
|
| 5 |
+
Training MaskDiT on a latent dataset in WebDataset format. Used for experiments on ImageNet 512x512.
|
| 6 |
+
'''
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import os.path
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
from time import time
|
| 12 |
+
from omegaconf import OmegaConf
|
| 13 |
+
import pickle
|
| 14 |
+
from itertools import islice
|
| 15 |
+
|
| 16 |
+
import apex
|
| 17 |
+
import torch
|
| 18 |
+
import webdataset as wds
|
| 19 |
+
|
| 20 |
+
import accelerate
|
| 21 |
+
|
| 22 |
+
from fid import calc
|
| 23 |
+
from models.maskdit import Precond_models
|
| 24 |
+
from train_utils.loss import Losses
|
| 25 |
+
|
| 26 |
+
from train_utils.helper import get_mask_ratio_fn, get_one_hot, requires_grad, update_ema, unwrap_model
|
| 27 |
+
|
| 28 |
+
from sample import generate_with_net
|
| 29 |
+
from utils import dist, mprint, get_latest_ckpt, Logger, sample, \
|
| 30 |
+
str2bool, parse_str_none, parse_int_list, parse_float_none
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# ------------------------------------------------------------
|
| 34 |
+
# WebDataset Helper Function
|
| 35 |
+
def nodesplitter(src, group=None):
|
| 36 |
+
rank, world_size, worker, num_workers = wds.utils.pytorch_worker_info()
|
| 37 |
+
if world_size > 1:
|
| 38 |
+
for s in islice(src, rank, None, world_size):
|
| 39 |
+
yield s
|
| 40 |
+
else:
|
| 41 |
+
for s in src:
|
| 42 |
+
yield s
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_file_paths(dir):
|
| 46 |
+
return [os.path.join(dir, file) for file in os.listdir(dir)]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def split_by_proc(data_list, global_rank, total_size):
|
| 50 |
+
'''
|
| 51 |
+
Evenly split the data_list into total_size parts and return the part indexed by global_rank.
|
| 52 |
+
'''
|
| 53 |
+
assert len(data_list) >= total_size
|
| 54 |
+
assert global_rank < total_size
|
| 55 |
+
return data_list[global_rank::total_size]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def decode_data(item):
|
| 59 |
+
output = {}
|
| 60 |
+
img = pickle.loads(item['latent'])
|
| 61 |
+
output['latent'] = img
|
| 62 |
+
label = int(item['cls'].decode('utf-8'))
|
| 63 |
+
output['label'] = label
|
| 64 |
+
return output
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def make_loader(root, mode='train', batch_size=32,
|
| 68 |
+
num_workers=4, cache_dir=None,
|
| 69 |
+
resampled=False, world_size=1, total_num=1281167,
|
| 70 |
+
bufsize=1000, initial=100):
|
| 71 |
+
data_list = get_file_paths(root)
|
| 72 |
+
num_batches_in_total = total_num // (batch_size * world_size)
|
| 73 |
+
if resampled:
|
| 74 |
+
repeat = True
|
| 75 |
+
splitter = False
|
| 76 |
+
else:
|
| 77 |
+
repeat = False
|
| 78 |
+
splitter = nodesplitter
|
| 79 |
+
dataset = (
|
| 80 |
+
wds.WebDataset(
|
| 81 |
+
data_list,
|
| 82 |
+
cache_dir=cache_dir,
|
| 83 |
+
repeat=repeat,
|
| 84 |
+
resampled=resampled,
|
| 85 |
+
handler=wds.handlers.warn_and_stop,
|
| 86 |
+
nodesplitter=splitter,
|
| 87 |
+
)
|
| 88 |
+
.shuffle(bufsize, initial=initial)
|
| 89 |
+
.map(decode_data, handler=wds.handlers.warn_and_stop)
|
| 90 |
+
.to_tuple('latent label')
|
| 91 |
+
.batched(batch_size, partial=False)
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
loader = wds.WebLoader(dataset, batch_size=None, num_workers=num_workers, shuffle=False, persistent_workers=True)
|
| 95 |
+
if resampled:
|
| 96 |
+
loader = loader.with_epoch(num_batches_in_total)
|
| 97 |
+
return loader
|
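
decode_data above expects each WebDataset sample to carry a pickled latent under the 'latent' key and a text class id under 'cls' (the repo's lmdb2wds.py produces such shards). A hypothetical writer for compatible toy shards, with a placeholder output pattern and made-up shapes:

import pickle
import numpy as np
import webdataset as wds

# Toy shards: 8 random latents in the key layout decode_data expects.
with wds.ShardWriter('latents-%06d.tar', maxcount=4) as sink:
    for i in range(8):
        sink.write({
            '__key__': f'{i:08d}',
            'latent': pickle.dumps(np.random.randn(4, 64, 64).astype(np.float32)),
            'cls': str(np.random.randint(0, 1000)).encode('utf-8'),
        })
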
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# ------------------------------------------------------------
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def train_loop(args):
|
| 104 |
+
# load configuration
|
| 105 |
+
config = OmegaConf.load(args.config)
|
| 106 |
+
if not args.no_amp:
|
| 107 |
+
config.train.amp = 'fp16'
|
| 108 |
+
else:
|
| 109 |
+
config.train.amp = 'no'
|
| 110 |
+
if config.train.tf32:
|
| 111 |
+
torch.set_float32_matmul_precision('high')
|
| 112 |
+
|
| 113 |
+
accelerator = accelerate.Accelerator(mixed_precision=config.train.amp,
|
| 114 |
+
gradient_accumulation_steps=config.train.grad_accum,
|
| 115 |
+
log_with='wandb')
|
| 116 |
+
# setup wandb
|
| 117 |
+
if args.use_wandb:
|
| 118 |
+
wandb_init_kwargs = {
|
| 119 |
+
'entity': config.wandb.entity,
|
| 120 |
+
'project': config.wandb.project,
|
| 121 |
+
'group': config.wandb.group,
|
| 122 |
+
}
|
| 123 |
+
accelerator.init_trackers(config.wandb.project, config=OmegaConf.to_container(config), init_kwargs=wandb_init_kwargs)
|
| 124 |
+
|
| 125 |
+
mprint('start training...')
|
| 126 |
+
size = accelerator.num_processes
|
| 127 |
+
rank = accelerator.process_index
|
| 128 |
+
|
| 129 |
+
print(f'global_rank: {rank}, global_size: {size}')
|
| 130 |
+
device = accelerator.device
|
| 131 |
+
|
| 132 |
+
seed = args.global_seed
|
| 133 |
+
torch.manual_seed(seed)
|
| 134 |
+
|
| 135 |
+
mprint(f"enable_amp: {not args.no_amp}, TF32: {config.train.tf32}")
|
| 136 |
+
# Select batch size per GPU
|
| 137 |
+
num_accumulation_rounds = config.train.grad_accum
|
| 138 |
+
|
| 139 |
+
micro_batch = config.train.batchsize
|
| 140 |
+
batch_gpu_total = micro_batch * num_accumulation_rounds
|
| 141 |
+
global_batch_size = batch_gpu_total * size
|
| 142 |
+
mprint(f"Global batchsize: {global_batch_size}, batchsize per GPU: {batch_gpu_total}, micro_batch: {micro_batch}.")
|
| 143 |
+
|
| 144 |
+
class_dropout_prob = config.model.class_dropout_prob
|
| 145 |
+
log_every = config.log.log_every
|
| 146 |
+
ckpt_every = config.log.ckpt_every
|
| 147 |
+
|
| 148 |
+
mask_ratio_fn = get_mask_ratio_fn(config.model.mask_ratio_fn, config.model.mask_ratio, config.model.mask_ratio_min)
|
| 149 |
+
|
| 150 |
+
# Setup an experiment folder
|
| 151 |
+
model_name = config.model.model_type.replace("/", "-") # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders)
|
| 152 |
+
data_name = config.data.dataset
|
| 153 |
+
if args.ckpt_path is not None and args.use_ckpt_path: # use the existing exp path (mainly used for fine-tuning)
|
| 154 |
+
checkpoint_dir = os.path.dirname(args.ckpt_path)
|
| 155 |
+
experiment_dir = os.path.dirname(checkpoint_dir)
|
| 156 |
+
exp_name = os.path.basename(experiment_dir)
|
| 157 |
+
else: # start a new exp path (and resume from the latest checkpoint if possible)
|
| 158 |
+
cond_gen = 'cond' if config.model.num_classes else 'uncond'
|
| 159 |
+
exp_name = f'{model_name}-{config.model.precond}-{data_name}-{cond_gen}-m{config.model.mask_ratio}-de{int(config.model.use_decoder)}' \
|
| 160 |
+
f'-mae{config.model.mae_loss_coef}-bs-{global_batch_size}-lr{config.train.lr}{config.log.tag}'
|
| 161 |
+
experiment_dir = f"{args.results_dir}/{exp_name}"
|
| 162 |
+
checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
|
| 163 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 164 |
+
if args.ckpt_path is None:
|
| 165 |
+
args.ckpt_path = get_latest_ckpt(checkpoint_dir) # Resumes from the latest checkpoint if it exists
|
| 166 |
+
|
| 167 |
+
if accelerator.is_main_process:
|
| 168 |
+
logger = Logger(file_name=f'{experiment_dir}/log.txt', file_mode="a+", should_flush=True)
|
| 169 |
+
mprint(f"Experiment directory created at {experiment_dir}")
|
| 170 |
+
|
| 171 |
+
    # Setup dataset
    loader = make_loader(config.data.root,
                         mode='train',
                         batch_size=batch_gpu_total,
                         num_workers=args.num_workers,
                         resampled=args.resample,
                         world_size=size,
                         total_num=config.data.total_num)

    steps_per_epoch = config.data.total_num // global_batch_size
    mprint(f"{steps_per_epoch} steps per epoch")

    model = Precond_models[config.model.precond](
        img_resolution=config.model.in_size,
        img_channels=config.model.in_channels,
        num_classes=config.model.num_classes,
        model_type=config.model.model_type,
        use_decoder=config.model.use_decoder,
        mae_loss_coef=config.model.mae_loss_coef,
        pad_cls_token=config.model.pad_cls_token
    ).to(device)
    # Note that parameter initialization is done within the model constructor
    ema = deepcopy(model)  # Create an EMA of the model for use after training
    requires_grad(ema, False)
    ema = ema.to(device)

    mprint(f"{config.model.model_type} ((use_decoder: {config.model.use_decoder})) Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
    mprint(f'extras: {model.model.extras}, cls_token: {model.model.cls_token}')

    # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
    optimizer = apex.optimizers.FusedAdam(model.parameters(), lr=config.train.lr, adam_w_mode=True, weight_decay=0)
    # optimizer = torch.optim.Adam(model.parameters(), lr=config.train.lr)

    # Load checkpoints
    train_steps_start = 0
    epoch_start = 0

    if args.ckpt_path is not None:
        ckpt = torch.load(args.ckpt_path, map_location=device)
        model.load_state_dict(ckpt['model'], strict=args.use_strict_load)
        ema.load_state_dict(ckpt['ema'], strict=args.use_strict_load)
        mprint(f'Load weights from {args.ckpt_path}')
        if args.use_strict_load:
            optimizer.load_state_dict(ckpt['opt'])
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()
            mprint(f'Load optimizer state..')
        train_steps_start = int(os.path.basename(args.ckpt_path).split('.pt')[0])
        epoch_start = train_steps_start // steps_per_epoch
        mprint(f"train_steps_start: {train_steps_start}")
        del ckpt  # conserve memory

        # FID evaluation for the loaded weights
        if args.enable_eval:
            start_time = time()
            args.outdir = os.path.join(experiment_dir, 'fid', f'edm-steps{args.num_steps}-ckpt{train_steps_start}_cfg{args.cfg_scale}')
            os.makedirs(args.outdir, exist_ok=True)
            generate_with_net(args, ema, device, rank, size)

            dist.barrier()
            fid = calc(args.outdir, config.eval.ref_path, args.num_expected, args.global_seed, args.fid_batch_size)
            mprint(f"time for fid calc: {time() - start_time}")
            if args.use_wandb:
                accelerator.log({"eval/fid": fid}, step=train_steps_start)
            mprint(f'guidance: {args.cfg_scale} FID: {fid}')
            dist.barrier()

    model, optimizer = accelerator.prepare(model, optimizer)
    model = torch.compile(model)

    # Setup loss
    loss_fn = Losses[config.model.precond]()

    # Prepare models for training:
    if args.ckpt_path is None:
        assert train_steps_start == 0
        raw_model = unwrap_model(model)
        update_ema(ema, raw_model, decay=0)  # Ensure EMA is initialized with synced weights
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    ema.eval()  # EMA model should always be in eval mode

    # Variables for monitoring/logging purposes:
    train_steps = train_steps_start
    log_steps = 0
    running_loss = 0
    start_time = time()
    mprint(f"Training for {config.train.epochs} epochs...")
    for epoch in range(epoch_start, config.train.epochs):
        mprint(f"Beginning epoch {epoch}...")
        for x, cond in loader:
            x = x.to(device)
            y = cond.to(device)

            y = get_one_hot(y, num_classes=config.model.num_classes)
            x = sample(x)

            loss_batch = 0
            model.zero_grad(set_to_none=True)
            curr_mask_ratio = mask_ratio_fn((train_steps - train_steps_start) / config.train.max_num_steps)
            if class_dropout_prob > 0:
                y = y * (torch.rand([y.shape[0], 1], device=device) >= class_dropout_prob)

            for round_idx in range(num_accumulation_rounds):
                x_ = x[round_idx * micro_batch:(round_idx + 1) * micro_batch]
                y_ = y[round_idx * micro_batch:(round_idx + 1) * micro_batch]

                with accelerator.accumulate(model):
                    loss = loss_fn(net=model, images=x_, labels=y_,
                                   mask_ratio=curr_mask_ratio,
                                   mae_loss_coef=config.model.mae_loss_coef)
                    loss_mean = loss.mean()
                    accelerator.backward(loss_mean)
                    # Update weights with lr warmup.
                    lr_cur = config.train.lr * min(train_steps * global_batch_size / max(config.train.lr_rampup_kimg * 1000, 1), 1)
                    for g in optimizer.param_groups:
                        g['lr'] = lr_cur
                    optimizer.step()
                loss_batch = loss_mean.item()

            raw_model = unwrap_model(model)
            update_ema(ema, raw_model)

            # Log loss values:
            running_loss += loss_batch
            log_steps += 1
            train_steps += 1
            if train_steps > (train_steps_start + config.train.max_num_steps):
                break
            if train_steps % log_every == 0:
                # Measure training speed:
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                # Reduce loss history over all processes:
                avg_loss = torch.tensor(running_loss / log_steps, device=device)

                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / size
                mprint(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
                mprint(f'Peak GPU memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GB')
                mprint(f'Reserved GPU memory: {torch.cuda.memory_reserved() / 1024 ** 3:.2f} GB')
                if args.use_wandb:
                    accelerator.log({"train/loss": avg_loss, "train/lr": lr_cur}, step=train_steps)
                # Reset monitoring variables:
                running_loss = 0
                log_steps = 0
                start_time = time()

            # Save checkpoint:
            if train_steps % ckpt_every == 0 and train_steps > train_steps_start:
                if accelerator.is_main_process:
                    checkpoint = {
                        "model": raw_model.state_dict(),
                        "ema": ema.state_dict(),
                        "opt": optimizer.state_dict(),
                        "args": args
                    }
                    checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
                    torch.save(checkpoint, checkpoint_path)
                    mprint(f"Saved checkpoint to {checkpoint_path}")
                    del checkpoint  # conserve memory
                dist.barrier()

                # FID evaluation during training
                if args.enable_eval:
                    start_time = time()
                    args.outdir = os.path.join(experiment_dir, 'fid', f'edm-steps{args.num_steps}-ckpt{train_steps}_cfg{args.cfg_scale}')
                    os.makedirs(args.outdir, exist_ok=True)
                    generate_with_net(args, ema, device, rank, size)

                    dist.barrier()
                    fid = calc(args.outdir, args.ref_path, args.num_expected, args.global_seed, args.fid_batch_size)
                    mprint(f"time for fid calc: {time() - start_time}, fid: {fid}")
                    if args.use_wandb:
                        accelerator.log({"eval/fid": fid}, step=train_steps)
                    mprint(f'Guidance: {args.cfg_scale}, FID: {fid}')
                    dist.barrier()
                start_time = time()

    if accelerator.is_main_process:
        logger.close()
    accelerator.end_training()


if __name__ == '__main__':
    parser = argparse.ArgumentParser('training parameters')
    # basic config
    parser.add_argument('--config', type=str, required=True, help='path to config file')
    # training
    parser.add_argument("--results_dir", type=str, default="results")
    parser.add_argument("--ckpt_path", type=parse_str_none, default=None)

    parser.add_argument("--global_seed", type=int, default=0)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument('--no_amp', action='store_true', help="Disable automatic mixed precision.")

    parser.add_argument("--use_wandb", action='store_true', help='enable wandb logging')
    parser.add_argument("--use_ckpt_path", type=str2bool, default=True)
    parser.add_argument("--use_strict_load", type=str2bool, default=True)
    parser.add_argument("--tag", type=str, default='')
    parser.add_argument("--resample", action='store_true', help='enable shard resample')

    # sampling
    parser.add_argument('--enable_eval', action='store_true', help='enable fid calc during training')
    parser.add_argument('--seeds', type=parse_int_list, default='100000-104999', help='Random seeds (e.g. 1,2,5-10)')
    parser.add_argument('--subdirs', action='store_true', help='Create subdirectory for every 1000 seeds')
    parser.add_argument('--class_idx', type=int, default=None, help='Class label [default: random]')
    parser.add_argument('--max_batch_size', type=int, default=25, help='Maximum batch size per GPU during sampling, must be a factor of 50k if torch.compile is used')

    parser.add_argument("--cfg_scale", type=parse_float_none, default=None, help='None = no guidance, by default = 4.0')

    parser.add_argument('--num_steps', type=int, default=40, help='Number of sampling steps')
    parser.add_argument('--S_churn', type=int, default=0, help='Stochasticity strength')
    parser.add_argument('--solver', type=str, default=None, choices=['euler', 'heun'], help='Ablate ODE solver')
    parser.add_argument('--discretization', type=str, default=None, choices=['vp', 've', 'iddpm', 'edm'], help='Ablate ODE solver')
    parser.add_argument('--schedule', type=str, default=None, choices=['vp', 've', 'linear'], help='Ablate noise schedule sigma(t)')
    parser.add_argument('--scaling', type=str, default=None, choices=['vp', 'none'], help='Ablate signal scaling s(t)')
    parser.add_argument('--pretrained_path', type=str, default='assets/stable_diffusion/autoencoder_kl.pth', help='Autoencoder ckpt')

    parser.add_argument('--ref_path', type=str, default='assets/fid_stats/VIRTUAL_imagenet512.npz', help='Dataset reference statistics')
    parser.add_argument('--num_expected', type=int, default=5000, help='Number of images to use')
    parser.add_argument('--fid_batch_size', type=int, default=64, help='Maximum batch size per GPU')

    args = parser.parse_args()

    torch.backends.cudnn.benchmark = True
    train_loop(args)
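# Illustrative sketch (not part of the original train_wds.py; values below are
# hypothetical): the lr warmup in the loop above ramps linearly with the number of
# images seen and caps at the configured rate after lr_rampup_kimg * 1000 images.
def _warmup_lr(base_lr, step, global_batch_size, lr_rampup_kimg):
    return base_lr * min(step * global_batch_size / max(lr_rampup_kimg * 1000, 1), 1)

assert _warmup_lr(1e-4, 0, 1024, 100) == 0.0
assert _warmup_lr(1e-4, 200, 1024, 100) == 1e-4  # 200 * 1024 images > 100 kimg: fully warmed up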
utils.py
ADDED
@@ -0,0 +1,225 @@
# MIT License

# Copyright (c) [2023] [Anima-Lab]

# This code is adapted from https://github.com/NVlabs/edm/blob/main/generate.py.
# The original code is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License, which can be found at licenses/LICENSE_EDM.txt.


import os
import re
import sys
import contextlib

import torch
import torch.distributed as dist


#----------------------------------------------------------------------------
# Get the latest checkpoint from the save dir

def get_latest_ckpt(dir):
    latest_id = -1
    for file in os.listdir(dir):
        if file.endswith('.pt'):
            m = re.search(r'(\d+)\.pt', file)
            if m:
                ckpt_id = int(m.group(1))
                latest_id = max(latest_id, ckpt_id)
    if latest_id == -1:
        return None
    else:
        ckpt_path = os.path.join(dir, f'{latest_id:07d}.pt')
        return ckpt_path


def get_ckpt_paths(dir, id_min, id_max):
    ckpt_dict = {}
    for file in os.listdir(dir):
        if file.endswith('.pt'):
            m = re.search(r'(\d+)\.pt', file)
            if m:
                ckpt_id = int(m.group(1))
                if id_min <= ckpt_id <= id_max:
                    ckpt_dict[ckpt_id] = os.path.join(dir, f'{ckpt_id:07d}.pt')
    return ckpt_dict


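# Illustrative note (not part of the original utils.py): train_wds.py saves
# checkpoints as zero-padded step counts, e.g. '.../checkpoints/0050000.pt', so
# get_latest_ckpt(checkpoint_dir) returns the highest-numbered file (or None when
# no '*.pt' file exists), and get_ckpt_paths(dir, id_min, id_max) maps step -> path
# for every checkpoint whose step falls in [id_min, id_max].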
#----------------------------------------------------------------------------
# Take the mean over all non-batch dimensions.

def mean_flat(tensor):
    return tensor.mean(dim=list(range(1, tensor.ndim)))


#----------------------------------------------------------------------------
# Convert latent (mean, logvar) to latent variable (inherited from autoencoder.py)

def sample(moments, scale_factor=0.18215):
    mean, logvar = torch.chunk(moments, 2, dim=1)
    logvar = torch.clamp(logvar, -30.0, 20.0)
    std = torch.exp(0.5 * logvar)
    z = mean + std * torch.randn_like(mean)
    z = scale_factor * z
    return z


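# Illustrative usage (not part of the original utils.py): `moments` is the
# concatenated (mean, logvar) tensor produced by the autoencoder and stored in the
# extracted-latent dataset, so chunking along dim=1 halves the channel count, e.g.
#   moments = torch.randn(2, 8, 32, 32)   # hypothetical latent moments
#   z = sample(moments)                   # -> shape [2, 4, 32, 32], scaled by 0.18215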
#----------------------------------------------------------------------------
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.

@contextlib.contextmanager
def ddp_sync(module, sync):
    assert isinstance(module, torch.nn.Module)
    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with module.no_sync():
            yield

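# Illustrative usage (not part of the original utils.py): skip gradient
# synchronization on all but the last micro-batch of a gradient-accumulation loop,
# e.g.
#   for i in range(num_rounds):
#       with ddp_sync(ddp_model, sync=(i == num_rounds - 1)):
#           loss_fn(ddp_model, micro_batch_i).backward()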
#----------------------------------------------------------------------------
# Distributed training helper functions

def init_processes(fn, args):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = args.master_address
    os.environ['MASTER_PORT'] = '6020'
    print(f'MASTER_ADDR = {os.environ["MASTER_ADDR"]}')
    print(f'MASTER_PORT = {os.environ["MASTER_PORT"]}')
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl', init_method='env://', rank=args.global_rank, world_size=args.global_size)
    fn(args)
    if args.global_size > 1:
        cleanup()


def mprint(*args, **kwargs):
    """
    Print only from rank 0.
    """
    if dist.get_rank() == 0:
        print(*args, **kwargs)


def cleanup():
    """
    End DDP training.
    """
    dist.barrier()
    mprint("Done!")
    dist.barrier()
    dist.destroy_process_group()


#----------------------------------------------------------------------------
# Wrapper for torch.Generator that allows specifying a different random seed
# for each sample in a minibatch.

class StackedRandomGenerator:
    def __init__(self, device, seeds):
        super().__init__()
        self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]

    def randn(self, size, **kwargs):
        assert size[0] == len(self.generators)
        return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])

    def randn_like(self, input):
        return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)

    def randint(self, *args, size, **kwargs):
        assert size[0] == len(self.generators)
        return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])


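# Illustrative usage (not part of the original utils.py): one generator per sample
# keeps every generated image reproducible from its own seed, regardless of how the
# seeds are batched, e.g.
#   rnd = StackedRandomGenerator(device='cpu', seeds=[0, 1, 2, 3])
#   latents = rnd.randn([4, 4, 32, 32], device='cpu')  # row i depends only on seeds[i]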
#----------------------------------------------------------------------------
# Parse a comma separated list of numbers or ranges and return a list of ints.
# Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]

def parse_int_list(s):
    if isinstance(s, list): return s
    ranges = []
    range_re = re.compile(r'^(\d+)-(\d+)$')
    for p in s.split(','):
        m = range_re.match(p)
        if m:
            ranges.extend(range(int(m.group(1)), int(m.group(2))+1))
        else:
            ranges.append(int(p))
    return ranges

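# Note (not part of the original utils.py): train_wds.py feeds its default
# '--seeds 100000-104999' through this parser, which expands to the 5,000
# consecutive seeds used for FID sampling (matching the default num_expected=5000).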
# Parse 'None' to None and others to float value
def parse_float_none(s):
    assert isinstance(s, str)
    return None if s == 'None' else float(s)


# Parse 'None' to None and others to str
def parse_str_none(s):
    assert isinstance(s, str)
    return None if s == 'None' else s


# Parse 'true' to True
def str2bool(s):
    return s.lower() in ['true', '1', 'yes']


#----------------------------------------------------------------------------
# logging info.
class Logger(object):
    """
    Redirect stderr to stdout, optionally print stdout to a file,
    and optionally force flushing on both stdout and the file.
    """

    def __init__(self, file_name=None, file_mode="w", should_flush=True):
        self.file = None

        if file_name is not None:
            self.file = open(file_name, file_mode)

        self.should_flush = should_flush
        self.stdout = sys.stdout
        self.stderr = sys.stderr

        sys.stdout = self
        sys.stderr = self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def write(self, text):
        """Write text to stdout (and a file) and optionally flush."""
        if len(text) == 0:  # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return

        if self.file is not None:
            self.file.write(text)

        self.stdout.write(text)

        if self.should_flush:
            self.flush()

    def flush(self):
        """Flush written text to both stdout and a file, if open."""
        if self.file is not None:
            self.file.flush()

        self.stdout.flush()

    def close(self):
        """Flush, close possible files, and remove stdout/stderr mirroring."""
        self.flush()

        # if using multiple loggers, prevent closing in wrong order
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr

        if self.file is not None:
            self.file.close()
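# Usage note (not part of the original utils.py): train_wds.py instantiates this on
# the main process as Logger(file_name=f'{experiment_dir}/log.txt', file_mode="a+"),
# so everything printed through mprint is mirrored to the experiment's log.txt until
# logger.close() restores the real stdout/stderr at the end of training.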