Add files using upload-large-folder tool
Browse files- REG/evaluations/README.md +72 -0
- REG/evaluations/evaluator.py +679 -0
- REG/evaluations/requirements.txt +4 -0
- REG/models/clip_vit.py +426 -0
- REG/models/jepa.py +547 -0
- REG/models/mae_vit.py +71 -0
- REG/models/mocov3_vit.py +207 -0
- REG/models/sit.py +420 -0
- back/LICENSE +21 -0
- back/README.md +156 -0
- back/eval.sh +52 -0
- back/loss.py +168 -0
- back/requirements.txt +97 -0
- back/sample_from_checkpoint.py +596 -0
- back/samples.sh +15 -0
- back/samples_0.5.log +0 -0
- back/samples_ddp.sh +32 -0
- back/train.py +670 -0
- back/train.sh +43 -0
- back/utils.py +225 -0
REG/evaluations/README.md
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Evaluations
|
| 2 |
+
|
| 3 |
+
To compare different generative models, we use FID, sFID, Precision, Recall, and Inception Score. These metrics can all be calculated using batches of samples, which we store in `.npz` (numpy) files.
|
| 4 |
+
|
| 5 |
+
# Download batches
|
| 6 |
+
|
| 7 |
+
We provide pre-computed sample batches for the reference datasets, our diffusion models, and several baselines we compare against. These are all stored in `.npz` format.
|
| 8 |
+
|
| 9 |
+
Reference dataset batches contain pre-computed statistics over the whole dataset, as well as 10,000 images for computing Precision and Recall. All other batches contain 50,000 images which can be used to compute statistics and Precision/Recall.
|
| 10 |
+
|
| 11 |
+
Here are links to download all of the sample and reference batches:
|
| 12 |
+
|
| 13 |
+
* LSUN
|
| 14 |
+
* LSUN bedroom: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/VIRTUAL_lsun_bedroom256.npz)
|
| 15 |
+
* [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/admnet_dropout_lsun_bedroom.npz)
|
| 16 |
+
* [DDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/ddpm_lsun_bedroom.npz)
|
| 17 |
+
* [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/iddpm_lsun_bedroom.npz)
|
| 18 |
+
* [StyleGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/stylegan_lsun_bedroom.npz)
|
| 19 |
+
* LSUN cat: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/VIRTUAL_lsun_cat256.npz)
|
| 20 |
+
* [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/admnet_dropout_lsun_cat.npz)
|
| 21 |
+
* [StyleGAN2](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/stylegan2_lsun_cat.npz)
|
| 22 |
+
* LSUN horse: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/VIRTUAL_lsun_horse256.npz)
|
| 23 |
+
* [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_dropout_lsun_horse.npz)
|
| 24 |
+
* [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_lsun_horse.npz)
|
| 25 |
+
|
| 26 |
+
* ImageNet
|
| 27 |
+
* ImageNet 64x64: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz)
|
| 28 |
+
* [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/admnet_imagenet64.npz)
|
| 29 |
+
* [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/iddpm_imagenet64.npz)
|
| 30 |
+
* [BigGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/biggan_deep_imagenet64.npz)
|
| 31 |
+
* ImageNet 128x128: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz)
|
| 32 |
+
* [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_imagenet128.npz)
|
| 33 |
+
* [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_imagenet128.npz)
|
| 34 |
+
* [ADM-G, 25 steps](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_25step_imagenet128.npz)
|
| 35 |
+
* [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/biggan_deep_trunc1_imagenet128.npz)
|
| 36 |
+
* ImageNet 256x256: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz)
|
| 37 |
+
* [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_imagenet256.npz)
|
| 38 |
+
* [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_imagenet256.npz)
|
| 39 |
+
* [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_25step_imagenet256.npz)
|
| 40 |
+
* [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_upsampled_imagenet256.npz)
|
| 41 |
+
* [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_upsampled_imagenet256.npz)
|
| 42 |
+
* [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/biggan_deep_trunc1_imagenet256.npz)
|
| 43 |
+
* ImageNet 512x512: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz)
|
| 44 |
+
* [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_imagenet512.npz)
|
| 45 |
+
* [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_imagenet512.npz)
|
| 46 |
+
* [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_25step_imagenet512.npz)
|
| 47 |
+
* [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_upsampled_imagenet512.npz)
|
| 48 |
+
* [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_upsampled_imagenet512.npz)
|
| 49 |
+
* [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/biggan_deep_trunc1_imagenet512.npz)
|
| 50 |
+
|
| 51 |
+
# Run evaluations
|
| 52 |
+
|
| 53 |
+
First, generate or download a batch of samples and download the corresponding reference batch for the given dataset. For this example, we'll use ImageNet 256x256, so the refernce batch is `VIRTUAL_imagenet256_labeled.npz` and we can use the sample batch `admnet_guided_upsampled_imagenet256.npz`.
|
| 54 |
+
|
| 55 |
+
Next, run the `evaluator.py` script. The requirements of this script can be found in [requirements.txt](requirements.txt). Pass two arguments to the script: the reference batch and the sample batch. The script will download the InceptionV3 model used for evaluations into the current working directory (if it is not already present). This file is roughly 100MB.
|
| 56 |
+
|
| 57 |
+
The output of the script will look something like this, where the first `...` is a bunch of verbose TensorFlow logging:
|
| 58 |
+
|
| 59 |
+
```
|
| 60 |
+
$ python evaluator.py VIRTUAL_imagenet256_labeled.npz admnet_guided_upsampled_imagenet256.npz
|
| 61 |
+
...
|
| 62 |
+
computing reference batch activations...
|
| 63 |
+
computing/reading reference batch statistics...
|
| 64 |
+
computing sample batch activations...
|
| 65 |
+
computing/reading sample batch statistics...
|
| 66 |
+
Computing evaluations...
|
| 67 |
+
Inception Score: 215.8370361328125
|
| 68 |
+
FID: 3.9425574129223264
|
| 69 |
+
sFID: 6.140433703346162
|
| 70 |
+
Precision: 0.8265
|
| 71 |
+
Recall: 0.5309
|
| 72 |
+
```
|
REG/evaluations/evaluator.py
ADDED
|
@@ -0,0 +1,679 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import io
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
import warnings
|
| 6 |
+
import zipfile
|
| 7 |
+
from abc import ABC, abstractmethod
|
| 8 |
+
from contextlib import contextmanager
|
| 9 |
+
from functools import partial
|
| 10 |
+
from multiprocessing import cpu_count
|
| 11 |
+
from multiprocessing.pool import ThreadPool
|
| 12 |
+
from typing import Iterable, Optional, Tuple
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import requests
|
| 16 |
+
import tensorflow.compat.v1 as tf
|
| 17 |
+
from scipy import linalg
|
| 18 |
+
from tqdm.auto import tqdm
|
| 19 |
+
|
| 20 |
+
INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
|
| 21 |
+
INCEPTION_V3_PATH = "classify_image_graph_def.pb"
|
| 22 |
+
|
| 23 |
+
FID_POOL_NAME = "pool_3:0"
|
| 24 |
+
FID_SPATIAL_NAME = "mixed_6/conv:0"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main():
|
| 28 |
+
parser = argparse.ArgumentParser()
|
| 29 |
+
parser.add_argument("--ref_batch", help="path to reference batch npz file")
|
| 30 |
+
parser.add_argument("--sample_batch", help="path to sample batch npz file")
|
| 31 |
+
parser.add_argument("--save_path", help="path to sample batch npz file")
|
| 32 |
+
parser.add_argument("--cfg_cond", default=1, type=int)
|
| 33 |
+
parser.add_argument("--step", default=1, type=int)
|
| 34 |
+
parser.add_argument("--cfg", default=1.0, type=float)
|
| 35 |
+
parser.add_argument("--cls_cfg", default=1.0, type=float)
|
| 36 |
+
parser.add_argument("--gh", default=1.0, type=float)
|
| 37 |
+
parser.add_argument("--num_steps", default=250, type=int)
|
| 38 |
+
args = parser.parse_args()
|
| 39 |
+
|
| 40 |
+
if not os.path.exists(args.save_path):
|
| 41 |
+
os.mkdir(args.save_path)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
config = tf.ConfigProto(
|
| 45 |
+
allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph
|
| 46 |
+
)
|
| 47 |
+
config.gpu_options.allow_growth = True
|
| 48 |
+
evaluator = Evaluator(tf.Session(config=config))
|
| 49 |
+
|
| 50 |
+
print("warming up TensorFlow...")
|
| 51 |
+
# This will cause TF to print a bunch of verbose stuff now rather
|
| 52 |
+
# than after the next print(), to help prevent confusion.
|
| 53 |
+
evaluator.warmup()
|
| 54 |
+
|
| 55 |
+
print("computing reference batch activations...")
|
| 56 |
+
ref_acts = evaluator.read_activations(args.ref_batch)
|
| 57 |
+
print("computing/reading reference batch statistics...")
|
| 58 |
+
ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
|
| 59 |
+
|
| 60 |
+
print("computing sample batch activations...")
|
| 61 |
+
sample_acts = evaluator.read_activations(args.sample_batch)
|
| 62 |
+
print("computing/reading sample batch statistics...")
|
| 63 |
+
sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)
|
| 64 |
+
|
| 65 |
+
print("Computing evaluations...")
|
| 66 |
+
Inception_Score = evaluator.compute_inception_score(sample_acts[0])
|
| 67 |
+
FID = sample_stats.frechet_distance(ref_stats)
|
| 68 |
+
sFID = sample_stats_spatial.frechet_distance(ref_stats_spatial)
|
| 69 |
+
prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
|
| 70 |
+
|
| 71 |
+
print("Inception Score:", Inception_Score)
|
| 72 |
+
print("FID:", FID)
|
| 73 |
+
print("sFID:", sFID)
|
| 74 |
+
print("Precision:", prec)
|
| 75 |
+
print("Recall:", recall)
|
| 76 |
+
|
| 77 |
+
if args.cfg_cond:
|
| 78 |
+
file_path = args.save_path + str(args.num_steps) + str(args.step) + str(args.cfg) + str(args.gh) + str(args.cls_cfg)+ "cfg_cond_true.txt"
|
| 79 |
+
else:
|
| 80 |
+
file_path = args.save_path + str(args.num_steps) + str(args.step) + str(args.cfg) + str(args.gh) + str(args.cls_cfg)+ "cfg_cond_false.txt"
|
| 81 |
+
with open(file_path, "w") as file:
|
| 82 |
+
file.write("Inception Score: {}\n".format(Inception_Score))
|
| 83 |
+
file.write("FID: {}\n".format(FID))
|
| 84 |
+
file.write("sFID: {}\n".format(sFID))
|
| 85 |
+
file.write("Precision: {}\n".format(prec))
|
| 86 |
+
file.write("Recall: {}\n".format(recall))
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class InvalidFIDException(Exception):
|
| 90 |
+
pass
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class FIDStatistics:
|
| 94 |
+
def __init__(self, mu: np.ndarray, sigma: np.ndarray):
|
| 95 |
+
self.mu = mu
|
| 96 |
+
self.sigma = sigma
|
| 97 |
+
|
| 98 |
+
def frechet_distance(self, other, eps=1e-6):
|
| 99 |
+
"""
|
| 100 |
+
Compute the Frechet distance between two sets of statistics.
|
| 101 |
+
"""
|
| 102 |
+
# https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
|
| 103 |
+
mu1, sigma1 = self.mu, self.sigma
|
| 104 |
+
mu2, sigma2 = other.mu, other.sigma
|
| 105 |
+
|
| 106 |
+
mu1 = np.atleast_1d(mu1)
|
| 107 |
+
mu2 = np.atleast_1d(mu2)
|
| 108 |
+
|
| 109 |
+
sigma1 = np.atleast_2d(sigma1)
|
| 110 |
+
sigma2 = np.atleast_2d(sigma2)
|
| 111 |
+
|
| 112 |
+
assert (
|
| 113 |
+
mu1.shape == mu2.shape
|
| 114 |
+
), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
|
| 115 |
+
assert (
|
| 116 |
+
sigma1.shape == sigma2.shape
|
| 117 |
+
), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
|
| 118 |
+
|
| 119 |
+
diff = mu1 - mu2
|
| 120 |
+
|
| 121 |
+
# product might be almost singular
|
| 122 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
| 123 |
+
if not np.isfinite(covmean).all():
|
| 124 |
+
msg = (
|
| 125 |
+
"fid calculation produces singular product; adding %s to diagonal of cov estimates"
|
| 126 |
+
% eps
|
| 127 |
+
)
|
| 128 |
+
warnings.warn(msg)
|
| 129 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
| 130 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
| 131 |
+
|
| 132 |
+
# numerical error might give slight imaginary component
|
| 133 |
+
if np.iscomplexobj(covmean):
|
| 134 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
| 135 |
+
m = np.max(np.abs(covmean.imag))
|
| 136 |
+
raise ValueError("Imaginary component {}".format(m))
|
| 137 |
+
covmean = covmean.real
|
| 138 |
+
|
| 139 |
+
tr_covmean = np.trace(covmean)
|
| 140 |
+
|
| 141 |
+
return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class Evaluator:
|
| 145 |
+
def __init__(
|
| 146 |
+
self,
|
| 147 |
+
session,
|
| 148 |
+
batch_size=64,
|
| 149 |
+
softmax_batch_size=512,
|
| 150 |
+
):
|
| 151 |
+
self.sess = session
|
| 152 |
+
self.batch_size = batch_size
|
| 153 |
+
self.softmax_batch_size = softmax_batch_size
|
| 154 |
+
self.manifold_estimator = ManifoldEstimator(session)
|
| 155 |
+
with self.sess.graph.as_default():
|
| 156 |
+
self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
|
| 157 |
+
self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
|
| 158 |
+
self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
|
| 159 |
+
self.softmax = _create_softmax_graph(self.softmax_input)
|
| 160 |
+
|
| 161 |
+
def warmup(self):
|
| 162 |
+
self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
|
| 163 |
+
|
| 164 |
+
def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
|
| 165 |
+
with open_npz_array(npz_path, "arr_0") as reader:
|
| 166 |
+
return self.compute_activations(reader.read_batches(self.batch_size))
|
| 167 |
+
|
| 168 |
+
def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
|
| 169 |
+
"""
|
| 170 |
+
Compute image features for downstream evals.
|
| 171 |
+
|
| 172 |
+
:param batches: a iterator over NHWC numpy arrays in [0, 255].
|
| 173 |
+
:return: a tuple of numpy arrays of shape [N x X], where X is a feature
|
| 174 |
+
dimension. The tuple is (pool_3, spatial).
|
| 175 |
+
"""
|
| 176 |
+
preds = []
|
| 177 |
+
spatial_preds = []
|
| 178 |
+
for batch in tqdm(batches):
|
| 179 |
+
batch = batch.astype(np.float32)
|
| 180 |
+
pred, spatial_pred = self.sess.run(
|
| 181 |
+
[self.pool_features, self.spatial_features], {self.image_input: batch}
|
| 182 |
+
)
|
| 183 |
+
preds.append(pred.reshape([pred.shape[0], -1]))
|
| 184 |
+
spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
|
| 185 |
+
return (
|
| 186 |
+
np.concatenate(preds, axis=0),
|
| 187 |
+
np.concatenate(spatial_preds, axis=0),
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
def read_statistics(
|
| 191 |
+
self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
|
| 192 |
+
) -> Tuple[FIDStatistics, FIDStatistics]:
|
| 193 |
+
obj = np.load(npz_path)
|
| 194 |
+
if "mu" in list(obj.keys()):
|
| 195 |
+
return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
|
| 196 |
+
obj["mu_s"], obj["sigma_s"]
|
| 197 |
+
)
|
| 198 |
+
return tuple(self.compute_statistics(x) for x in activations)
|
| 199 |
+
|
| 200 |
+
def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
|
| 201 |
+
mu = np.mean(activations, axis=0)
|
| 202 |
+
sigma = np.cov(activations, rowvar=False)
|
| 203 |
+
return FIDStatistics(mu, sigma)
|
| 204 |
+
|
| 205 |
+
def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
|
| 206 |
+
softmax_out = []
|
| 207 |
+
for i in range(0, len(activations), self.softmax_batch_size):
|
| 208 |
+
acts = activations[i : i + self.softmax_batch_size]
|
| 209 |
+
softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
|
| 210 |
+
preds = np.concatenate(softmax_out, axis=0)
|
| 211 |
+
# https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
|
| 212 |
+
scores = []
|
| 213 |
+
for i in range(0, len(preds), split_size):
|
| 214 |
+
part = preds[i : i + split_size]
|
| 215 |
+
kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
|
| 216 |
+
kl = np.mean(np.sum(kl, 1))
|
| 217 |
+
scores.append(np.exp(kl))
|
| 218 |
+
return float(np.mean(scores))
|
| 219 |
+
|
| 220 |
+
def compute_prec_recall(
|
| 221 |
+
self, activations_ref: np.ndarray, activations_sample: np.ndarray
|
| 222 |
+
) -> Tuple[float, float]:
|
| 223 |
+
radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
|
| 224 |
+
radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
|
| 225 |
+
pr = self.manifold_estimator.evaluate_pr(
|
| 226 |
+
activations_ref, radii_1, activations_sample, radii_2
|
| 227 |
+
)
|
| 228 |
+
return (float(pr[0][0]), float(pr[1][0]))
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class ManifoldEstimator:
|
| 232 |
+
"""
|
| 233 |
+
A helper for comparing manifolds of feature vectors.
|
| 234 |
+
|
| 235 |
+
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
|
| 236 |
+
"""
|
| 237 |
+
|
| 238 |
+
def __init__(
|
| 239 |
+
self,
|
| 240 |
+
session,
|
| 241 |
+
row_batch_size=10000,
|
| 242 |
+
col_batch_size=10000,
|
| 243 |
+
nhood_sizes=(3,),
|
| 244 |
+
clamp_to_percentile=None,
|
| 245 |
+
eps=1e-5,
|
| 246 |
+
):
|
| 247 |
+
"""
|
| 248 |
+
Estimate the manifold of given feature vectors.
|
| 249 |
+
|
| 250 |
+
:param session: the TensorFlow session.
|
| 251 |
+
:param row_batch_size: row batch size to compute pairwise distances
|
| 252 |
+
(parameter to trade-off between memory usage and performance).
|
| 253 |
+
:param col_batch_size: column batch size to compute pairwise distances.
|
| 254 |
+
:param nhood_sizes: number of neighbors used to estimate the manifold.
|
| 255 |
+
:param clamp_to_percentile: prune hyperspheres that have radius larger than
|
| 256 |
+
the given percentile.
|
| 257 |
+
:param eps: small number for numerical stability.
|
| 258 |
+
"""
|
| 259 |
+
self.distance_block = DistanceBlock(session)
|
| 260 |
+
self.row_batch_size = row_batch_size
|
| 261 |
+
self.col_batch_size = col_batch_size
|
| 262 |
+
self.nhood_sizes = nhood_sizes
|
| 263 |
+
self.num_nhoods = len(nhood_sizes)
|
| 264 |
+
self.clamp_to_percentile = clamp_to_percentile
|
| 265 |
+
self.eps = eps
|
| 266 |
+
|
| 267 |
+
def warmup(self):
|
| 268 |
+
feats, radii = (
|
| 269 |
+
np.zeros([1, 2048], dtype=np.float32),
|
| 270 |
+
np.zeros([1, 1], dtype=np.float32),
|
| 271 |
+
)
|
| 272 |
+
self.evaluate_pr(feats, radii, feats, radii)
|
| 273 |
+
|
| 274 |
+
def manifold_radii(self, features: np.ndarray) -> np.ndarray:
|
| 275 |
+
num_images = len(features)
|
| 276 |
+
|
| 277 |
+
# Estimate manifold of features by calculating distances to k-NN of each sample.
|
| 278 |
+
radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
|
| 279 |
+
distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
|
| 280 |
+
seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
|
| 281 |
+
|
| 282 |
+
for begin1 in range(0, num_images, self.row_batch_size):
|
| 283 |
+
end1 = min(begin1 + self.row_batch_size, num_images)
|
| 284 |
+
row_batch = features[begin1:end1]
|
| 285 |
+
|
| 286 |
+
for begin2 in range(0, num_images, self.col_batch_size):
|
| 287 |
+
end2 = min(begin2 + self.col_batch_size, num_images)
|
| 288 |
+
col_batch = features[begin2:end2]
|
| 289 |
+
|
| 290 |
+
# Compute distances between batches.
|
| 291 |
+
distance_batch[
|
| 292 |
+
0 : end1 - begin1, begin2:end2
|
| 293 |
+
] = self.distance_block.pairwise_distances(row_batch, col_batch)
|
| 294 |
+
|
| 295 |
+
# Find the k-nearest neighbor from the current batch.
|
| 296 |
+
radii[begin1:end1, :] = np.concatenate(
|
| 297 |
+
[
|
| 298 |
+
x[:, self.nhood_sizes]
|
| 299 |
+
for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
|
| 300 |
+
],
|
| 301 |
+
axis=0,
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
if self.clamp_to_percentile is not None:
|
| 305 |
+
max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
|
| 306 |
+
radii[radii > max_distances] = 0
|
| 307 |
+
return radii
|
| 308 |
+
|
| 309 |
+
def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
|
| 310 |
+
"""
|
| 311 |
+
Evaluate if new feature vectors are at the manifold.
|
| 312 |
+
"""
|
| 313 |
+
num_eval_images = eval_features.shape[0]
|
| 314 |
+
num_ref_images = radii.shape[0]
|
| 315 |
+
distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
|
| 316 |
+
batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
|
| 317 |
+
max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
|
| 318 |
+
nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
|
| 319 |
+
|
| 320 |
+
for begin1 in range(0, num_eval_images, self.row_batch_size):
|
| 321 |
+
end1 = min(begin1 + self.row_batch_size, num_eval_images)
|
| 322 |
+
feature_batch = eval_features[begin1:end1]
|
| 323 |
+
|
| 324 |
+
for begin2 in range(0, num_ref_images, self.col_batch_size):
|
| 325 |
+
end2 = min(begin2 + self.col_batch_size, num_ref_images)
|
| 326 |
+
ref_batch = features[begin2:end2]
|
| 327 |
+
|
| 328 |
+
distance_batch[
|
| 329 |
+
0 : end1 - begin1, begin2:end2
|
| 330 |
+
] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
|
| 331 |
+
|
| 332 |
+
# From the minibatch of new feature vectors, determine if they are in the estimated manifold.
|
| 333 |
+
# If a feature vector is inside a hypersphere of some reference sample, then
|
| 334 |
+
# the new sample lies at the estimated manifold.
|
| 335 |
+
# The radii of the hyperspheres are determined from distances of neighborhood size k.
|
| 336 |
+
samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
|
| 337 |
+
batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
|
| 338 |
+
|
| 339 |
+
max_realism_score[begin1:end1] = np.max(
|
| 340 |
+
radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
|
| 341 |
+
)
|
| 342 |
+
nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)
|
| 343 |
+
|
| 344 |
+
return {
|
| 345 |
+
"fraction": float(np.mean(batch_predictions)),
|
| 346 |
+
"batch_predictions": batch_predictions,
|
| 347 |
+
"max_realisim_score": max_realism_score,
|
| 348 |
+
"nearest_indices": nearest_indices,
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
def evaluate_pr(
|
| 352 |
+
self,
|
| 353 |
+
features_1: np.ndarray,
|
| 354 |
+
radii_1: np.ndarray,
|
| 355 |
+
features_2: np.ndarray,
|
| 356 |
+
radii_2: np.ndarray,
|
| 357 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
| 358 |
+
"""
|
| 359 |
+
Evaluate precision and recall efficiently.
|
| 360 |
+
|
| 361 |
+
:param features_1: [N1 x D] feature vectors for reference batch.
|
| 362 |
+
:param radii_1: [N1 x K1] radii for reference vectors.
|
| 363 |
+
:param features_2: [N2 x D] feature vectors for the other batch.
|
| 364 |
+
:param radii_2: [N x K2] radii for other vectors.
|
| 365 |
+
:return: a tuple of arrays for (precision, recall):
|
| 366 |
+
- precision: an np.ndarray of length K1
|
| 367 |
+
- recall: an np.ndarray of length K2
|
| 368 |
+
"""
|
| 369 |
+
features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool_)
|
| 370 |
+
features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool_)
|
| 371 |
+
for begin_1 in range(0, len(features_1), self.row_batch_size):
|
| 372 |
+
end_1 = begin_1 + self.row_batch_size
|
| 373 |
+
batch_1 = features_1[begin_1:end_1]
|
| 374 |
+
for begin_2 in range(0, len(features_2), self.col_batch_size):
|
| 375 |
+
end_2 = begin_2 + self.col_batch_size
|
| 376 |
+
batch_2 = features_2[begin_2:end_2]
|
| 377 |
+
batch_1_in, batch_2_in = self.distance_block.less_thans(
|
| 378 |
+
batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
|
| 379 |
+
)
|
| 380 |
+
features_1_status[begin_1:end_1] |= batch_1_in
|
| 381 |
+
features_2_status[begin_2:end_2] |= batch_2_in
|
| 382 |
+
return (
|
| 383 |
+
np.mean(features_2_status.astype(np.float64), axis=0),
|
| 384 |
+
np.mean(features_1_status.astype(np.float64), axis=0),
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
class DistanceBlock:
|
| 389 |
+
"""
|
| 390 |
+
Calculate pairwise distances between vectors.
|
| 391 |
+
|
| 392 |
+
Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
|
| 393 |
+
"""
|
| 394 |
+
|
| 395 |
+
def __init__(self, session):
|
| 396 |
+
self.session = session
|
| 397 |
+
|
| 398 |
+
# Initialize TF graph to calculate pairwise distances.
|
| 399 |
+
with session.graph.as_default():
|
| 400 |
+
self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
|
| 401 |
+
self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
|
| 402 |
+
distance_block_16 = _batch_pairwise_distances(
|
| 403 |
+
tf.cast(self._features_batch1, tf.float16),
|
| 404 |
+
tf.cast(self._features_batch2, tf.float16),
|
| 405 |
+
)
|
| 406 |
+
self.distance_block = tf.cond(
|
| 407 |
+
tf.reduce_all(tf.math.is_finite(distance_block_16)),
|
| 408 |
+
lambda: tf.cast(distance_block_16, tf.float32),
|
| 409 |
+
lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
# Extra logic for less thans.
|
| 413 |
+
self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
|
| 414 |
+
self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
|
| 415 |
+
dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
|
| 416 |
+
self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
|
| 417 |
+
self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)
|
| 418 |
+
|
| 419 |
+
def pairwise_distances(self, U, V):
|
| 420 |
+
"""
|
| 421 |
+
Evaluate pairwise distances between two batches of feature vectors.
|
| 422 |
+
"""
|
| 423 |
+
return self.session.run(
|
| 424 |
+
self.distance_block,
|
| 425 |
+
feed_dict={self._features_batch1: U, self._features_batch2: V},
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
def less_thans(self, batch_1, radii_1, batch_2, radii_2):
|
| 429 |
+
return self.session.run(
|
| 430 |
+
[self._batch_1_in, self._batch_2_in],
|
| 431 |
+
feed_dict={
|
| 432 |
+
self._features_batch1: batch_1,
|
| 433 |
+
self._features_batch2: batch_2,
|
| 434 |
+
self._radii1: radii_1,
|
| 435 |
+
self._radii2: radii_2,
|
| 436 |
+
},
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def _batch_pairwise_distances(U, V):
|
| 441 |
+
"""
|
| 442 |
+
Compute pairwise distances between two batches of feature vectors.
|
| 443 |
+
"""
|
| 444 |
+
with tf.variable_scope("pairwise_dist_block"):
|
| 445 |
+
# Squared norms of each row in U and V.
|
| 446 |
+
norm_u = tf.reduce_sum(tf.square(U), 1)
|
| 447 |
+
norm_v = tf.reduce_sum(tf.square(V), 1)
|
| 448 |
+
|
| 449 |
+
# norm_u as a column and norm_v as a row vectors.
|
| 450 |
+
norm_u = tf.reshape(norm_u, [-1, 1])
|
| 451 |
+
norm_v = tf.reshape(norm_v, [1, -1])
|
| 452 |
+
|
| 453 |
+
# Pairwise squared Euclidean distances.
|
| 454 |
+
D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
|
| 455 |
+
|
| 456 |
+
return D
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
class NpzArrayReader(ABC):
|
| 460 |
+
@abstractmethod
|
| 461 |
+
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
|
| 462 |
+
pass
|
| 463 |
+
|
| 464 |
+
@abstractmethod
|
| 465 |
+
def remaining(self) -> int:
|
| 466 |
+
pass
|
| 467 |
+
|
| 468 |
+
def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
|
| 469 |
+
def gen_fn():
|
| 470 |
+
while True:
|
| 471 |
+
batch = self.read_batch(batch_size)
|
| 472 |
+
if batch is None:
|
| 473 |
+
break
|
| 474 |
+
yield batch
|
| 475 |
+
|
| 476 |
+
rem = self.remaining()
|
| 477 |
+
num_batches = rem // batch_size + int(rem % batch_size != 0)
|
| 478 |
+
return BatchIterator(gen_fn, num_batches)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
class BatchIterator:
|
| 482 |
+
def __init__(self, gen_fn, length):
|
| 483 |
+
self.gen_fn = gen_fn
|
| 484 |
+
self.length = length
|
| 485 |
+
|
| 486 |
+
def __len__(self):
|
| 487 |
+
return self.length
|
| 488 |
+
|
| 489 |
+
def __iter__(self):
|
| 490 |
+
return self.gen_fn()
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
class StreamingNpzArrayReader(NpzArrayReader):
|
| 494 |
+
def __init__(self, arr_f, shape, dtype):
|
| 495 |
+
self.arr_f = arr_f
|
| 496 |
+
self.shape = shape
|
| 497 |
+
self.dtype = dtype
|
| 498 |
+
self.idx = 0
|
| 499 |
+
|
| 500 |
+
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
|
| 501 |
+
if self.idx >= self.shape[0]:
|
| 502 |
+
return None
|
| 503 |
+
|
| 504 |
+
bs = min(batch_size, self.shape[0] - self.idx)
|
| 505 |
+
self.idx += bs
|
| 506 |
+
|
| 507 |
+
if self.dtype.itemsize == 0:
|
| 508 |
+
return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
|
| 509 |
+
|
| 510 |
+
read_count = bs * np.prod(self.shape[1:])
|
| 511 |
+
read_size = int(read_count * self.dtype.itemsize)
|
| 512 |
+
data = _read_bytes(self.arr_f, read_size, "array data")
|
| 513 |
+
return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
|
| 514 |
+
|
| 515 |
+
def remaining(self) -> int:
|
| 516 |
+
return max(0, self.shape[0] - self.idx)
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
class MemoryNpzArrayReader(NpzArrayReader):
|
| 520 |
+
def __init__(self, arr):
|
| 521 |
+
self.arr = arr
|
| 522 |
+
self.idx = 0
|
| 523 |
+
|
| 524 |
+
@classmethod
|
| 525 |
+
def load(cls, path: str, arr_name: str):
|
| 526 |
+
with open(path, "rb") as f:
|
| 527 |
+
arr = np.load(f)[arr_name]
|
| 528 |
+
return cls(arr)
|
| 529 |
+
|
| 530 |
+
def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
|
| 531 |
+
if self.idx >= self.arr.shape[0]:
|
| 532 |
+
return None
|
| 533 |
+
|
| 534 |
+
res = self.arr[self.idx : self.idx + batch_size]
|
| 535 |
+
self.idx += batch_size
|
| 536 |
+
return res
|
| 537 |
+
|
| 538 |
+
def remaining(self) -> int:
|
| 539 |
+
return max(0, self.arr.shape[0] - self.idx)
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
@contextmanager
|
| 543 |
+
def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
|
| 544 |
+
with _open_npy_file(path, arr_name) as arr_f:
|
| 545 |
+
version = np.lib.format.read_magic(arr_f)
|
| 546 |
+
if version == (1, 0):
|
| 547 |
+
header = np.lib.format.read_array_header_1_0(arr_f)
|
| 548 |
+
elif version == (2, 0):
|
| 549 |
+
header = np.lib.format.read_array_header_2_0(arr_f)
|
| 550 |
+
else:
|
| 551 |
+
yield MemoryNpzArrayReader.load(path, arr_name)
|
| 552 |
+
return
|
| 553 |
+
shape, fortran, dtype = header
|
| 554 |
+
if fortran or dtype.hasobject:
|
| 555 |
+
yield MemoryNpzArrayReader.load(path, arr_name)
|
| 556 |
+
else:
|
| 557 |
+
yield StreamingNpzArrayReader(arr_f, shape, dtype)
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def _read_bytes(fp, size, error_template="ran out of data"):
|
| 561 |
+
"""
|
| 562 |
+
Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
|
| 563 |
+
|
| 564 |
+
Read from file-like object until size bytes are read.
|
| 565 |
+
Raises ValueError if not EOF is encountered before size bytes are read.
|
| 566 |
+
Non-blocking objects only supported if they derive from io objects.
|
| 567 |
+
Required as e.g. ZipExtFile in python 2.6 can return less data than
|
| 568 |
+
requested.
|
| 569 |
+
"""
|
| 570 |
+
data = bytes()
|
| 571 |
+
while True:
|
| 572 |
+
# io files (default in python3) return None or raise on
|
| 573 |
+
# would-block, python2 file will truncate, probably nothing can be
|
| 574 |
+
# done about that. note that regular files can't be non-blocking
|
| 575 |
+
try:
|
| 576 |
+
r = fp.read(size - len(data))
|
| 577 |
+
data += r
|
| 578 |
+
if len(r) == 0 or len(data) == size:
|
| 579 |
+
break
|
| 580 |
+
except io.BlockingIOError:
|
| 581 |
+
pass
|
| 582 |
+
if len(data) != size:
|
| 583 |
+
msg = "EOF: reading %s, expected %d bytes got %d"
|
| 584 |
+
raise ValueError(msg % (error_template, size, len(data)))
|
| 585 |
+
else:
|
| 586 |
+
return data
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
@contextmanager
|
| 590 |
+
def _open_npy_file(path: str, arr_name: str):
|
| 591 |
+
with open(path, "rb") as f:
|
| 592 |
+
with zipfile.ZipFile(f, "r") as zip_f:
|
| 593 |
+
if f"{arr_name}.npy" not in zip_f.namelist():
|
| 594 |
+
raise ValueError(f"missing {arr_name} in npz file")
|
| 595 |
+
with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
|
| 596 |
+
yield arr_f
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
def _download_inception_model():
|
| 600 |
+
if os.path.exists(INCEPTION_V3_PATH):
|
| 601 |
+
return
|
| 602 |
+
print("downloading InceptionV3 model...")
|
| 603 |
+
with requests.get(INCEPTION_V3_URL, stream=True) as r:
|
| 604 |
+
r.raise_for_status()
|
| 605 |
+
tmp_path = INCEPTION_V3_PATH + ".tmp"
|
| 606 |
+
with open(tmp_path, "wb") as f:
|
| 607 |
+
for chunk in tqdm(r.iter_content(chunk_size=8192)):
|
| 608 |
+
f.write(chunk)
|
| 609 |
+
os.rename(tmp_path, INCEPTION_V3_PATH)
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
def _create_feature_graph(input_batch):
|
| 613 |
+
_download_inception_model()
|
| 614 |
+
prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
|
| 615 |
+
with open(INCEPTION_V3_PATH, "rb") as f:
|
| 616 |
+
graph_def = tf.GraphDef()
|
| 617 |
+
graph_def.ParseFromString(f.read())
|
| 618 |
+
pool3, spatial = tf.import_graph_def(
|
| 619 |
+
graph_def,
|
| 620 |
+
input_map={f"ExpandDims:0": input_batch},
|
| 621 |
+
return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
|
| 622 |
+
name=prefix,
|
| 623 |
+
)
|
| 624 |
+
_update_shapes(pool3)
|
| 625 |
+
spatial = spatial[..., :7]
|
| 626 |
+
return pool3, spatial
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
def _create_softmax_graph(input_batch):
|
| 630 |
+
_download_inception_model()
|
| 631 |
+
prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
|
| 632 |
+
with open(INCEPTION_V3_PATH, "rb") as f:
|
| 633 |
+
graph_def = tf.GraphDef()
|
| 634 |
+
graph_def.ParseFromString(f.read())
|
| 635 |
+
(matmul,) = tf.import_graph_def(
|
| 636 |
+
graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
|
| 637 |
+
)
|
| 638 |
+
w = matmul.inputs[1]
|
| 639 |
+
logits = tf.matmul(input_batch, w)
|
| 640 |
+
return tf.nn.softmax(logits)
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
def _update_shapes(pool3):
|
| 644 |
+
# https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
|
| 645 |
+
ops = pool3.graph.get_operations()
|
| 646 |
+
for op in ops:
|
| 647 |
+
for o in op.outputs:
|
| 648 |
+
shape = o.get_shape()
|
| 649 |
+
if shape._dims is not None: # pylint: disable=protected-access
|
| 650 |
+
# shape = [s.value for s in shape] TF 1.x
|
| 651 |
+
shape = [s for s in shape] # TF 2.x
|
| 652 |
+
new_shape = []
|
| 653 |
+
for j, s in enumerate(shape):
|
| 654 |
+
if s == 1 and j == 0:
|
| 655 |
+
new_shape.append(None)
|
| 656 |
+
else:
|
| 657 |
+
new_shape.append(s)
|
| 658 |
+
o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
|
| 659 |
+
return pool3
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
def _numpy_partition(arr, kth, **kwargs):
|
| 663 |
+
num_workers = min(cpu_count(), len(arr))
|
| 664 |
+
chunk_size = len(arr) // num_workers
|
| 665 |
+
extra = len(arr) % num_workers
|
| 666 |
+
|
| 667 |
+
start_idx = 0
|
| 668 |
+
batches = []
|
| 669 |
+
for i in range(num_workers):
|
| 670 |
+
size = chunk_size + (1 if i < extra else 0)
|
| 671 |
+
batches.append(arr[start_idx : start_idx + size])
|
| 672 |
+
start_idx += size
|
| 673 |
+
|
| 674 |
+
with ThreadPool(num_workers) as pool:
|
| 675 |
+
return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
if __name__ == "__main__":
|
| 679 |
+
main()
|
REG/evaluations/requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tensorflow-gpu>=2.0
|
| 2 |
+
scipy
|
| 3 |
+
requests
|
| 4 |
+
tqdm
|
REG/models/clip_vit.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
from typing import Tuple, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
import clip
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Bottleneck(nn.Module):
|
| 13 |
+
expansion = 4
|
| 14 |
+
|
| 15 |
+
def __init__(self, inplanes, planes, stride=1):
|
| 16 |
+
super().__init__()
|
| 17 |
+
|
| 18 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
| 19 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
| 20 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 21 |
+
self.relu1 = nn.ReLU(inplace=True)
|
| 22 |
+
|
| 23 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
| 24 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 25 |
+
self.relu2 = nn.ReLU(inplace=True)
|
| 26 |
+
|
| 27 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
| 28 |
+
|
| 29 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
| 30 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
| 31 |
+
self.relu3 = nn.ReLU(inplace=True)
|
| 32 |
+
|
| 33 |
+
self.downsample = None
|
| 34 |
+
self.stride = stride
|
| 35 |
+
|
| 36 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
| 37 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
| 38 |
+
self.downsample = nn.Sequential(OrderedDict([
|
| 39 |
+
("-1", nn.AvgPool2d(stride)),
|
| 40 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
| 41 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
| 42 |
+
]))
|
| 43 |
+
|
| 44 |
+
def forward(self, x: torch.Tensor):
|
| 45 |
+
identity = x
|
| 46 |
+
|
| 47 |
+
out = self.relu1(self.bn1(self.conv1(x)))
|
| 48 |
+
out = self.relu2(self.bn2(self.conv2(out)))
|
| 49 |
+
out = self.avgpool(out)
|
| 50 |
+
out = self.bn3(self.conv3(out))
|
| 51 |
+
|
| 52 |
+
if self.downsample is not None:
|
| 53 |
+
identity = self.downsample(x)
|
| 54 |
+
|
| 55 |
+
out += identity
|
| 56 |
+
out = self.relu3(out)
|
| 57 |
+
return out
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class AttentionPool2d(nn.Module):
|
| 61 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
| 62 |
+
super().__init__()
|
| 63 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
| 64 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
| 65 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
| 66 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
| 67 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
| 68 |
+
self.num_heads = num_heads
|
| 69 |
+
|
| 70 |
+
def forward(self, x):
|
| 71 |
+
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
|
| 72 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
| 73 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
| 74 |
+
x, _ = F.multi_head_attention_forward(
|
| 75 |
+
query=x[:1], key=x, value=x,
|
| 76 |
+
embed_dim_to_check=x.shape[-1],
|
| 77 |
+
num_heads=self.num_heads,
|
| 78 |
+
q_proj_weight=self.q_proj.weight,
|
| 79 |
+
k_proj_weight=self.k_proj.weight,
|
| 80 |
+
v_proj_weight=self.v_proj.weight,
|
| 81 |
+
in_proj_weight=None,
|
| 82 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
| 83 |
+
bias_k=None,
|
| 84 |
+
bias_v=None,
|
| 85 |
+
add_zero_attn=False,
|
| 86 |
+
dropout_p=0,
|
| 87 |
+
out_proj_weight=self.c_proj.weight,
|
| 88 |
+
out_proj_bias=self.c_proj.bias,
|
| 89 |
+
use_separate_proj_weight=True,
|
| 90 |
+
training=self.training,
|
| 91 |
+
need_weights=False
|
| 92 |
+
)
|
| 93 |
+
return x.squeeze(0)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class ModifiedResNet(nn.Module):
|
| 97 |
+
"""
|
| 98 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
| 99 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
| 100 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
| 101 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
| 105 |
+
super().__init__()
|
| 106 |
+
self.output_dim = output_dim
|
| 107 |
+
self.input_resolution = input_resolution
|
| 108 |
+
|
| 109 |
+
# the 3-layer stem
|
| 110 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
| 111 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
| 112 |
+
self.relu1 = nn.ReLU(inplace=True)
|
| 113 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
| 114 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
| 115 |
+
self.relu2 = nn.ReLU(inplace=True)
|
| 116 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
| 117 |
+
self.bn3 = nn.BatchNorm2d(width)
|
| 118 |
+
self.relu3 = nn.ReLU(inplace=True)
|
| 119 |
+
self.avgpool = nn.AvgPool2d(2)
|
| 120 |
+
|
| 121 |
+
# residual layers
|
| 122 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
| 123 |
+
self.layer1 = self._make_layer(width, layers[0])
|
| 124 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
| 125 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
| 126 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
| 127 |
+
|
| 128 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
| 129 |
+
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
| 130 |
+
|
| 131 |
+
def _make_layer(self, planes, blocks, stride=1):
|
| 132 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
| 133 |
+
|
| 134 |
+
self._inplanes = planes * Bottleneck.expansion
|
| 135 |
+
for _ in range(1, blocks):
|
| 136 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
| 137 |
+
|
| 138 |
+
return nn.Sequential(*layers)
|
| 139 |
+
|
| 140 |
+
def forward(self, x):
|
| 141 |
+
def stem(x):
|
| 142 |
+
x = self.relu1(self.bn1(self.conv1(x)))
|
| 143 |
+
x = self.relu2(self.bn2(self.conv2(x)))
|
| 144 |
+
x = self.relu3(self.bn3(self.conv3(x)))
|
| 145 |
+
x = self.avgpool(x)
|
| 146 |
+
return x
|
| 147 |
+
|
| 148 |
+
x = x.type(self.conv1.weight.dtype)
|
| 149 |
+
x = stem(x)
|
| 150 |
+
x = self.layer1(x)
|
| 151 |
+
x = self.layer2(x)
|
| 152 |
+
x = self.layer3(x)
|
| 153 |
+
x = self.layer4(x)
|
| 154 |
+
x = self.attnpool(x)
|
| 155 |
+
|
| 156 |
+
return x
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class LayerNorm(nn.LayerNorm):
|
| 160 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
| 161 |
+
|
| 162 |
+
def forward(self, x: torch.Tensor):
|
| 163 |
+
orig_type = x.dtype
|
| 164 |
+
ret = super().forward(x.type(torch.float32))
|
| 165 |
+
return ret.type(orig_type)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class QuickGELU(nn.Module):
|
| 169 |
+
def forward(self, x: torch.Tensor):
|
| 170 |
+
return x * torch.sigmoid(1.702 * x)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class ResidualAttentionBlock(nn.Module):
|
| 174 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
| 175 |
+
super().__init__()
|
| 176 |
+
|
| 177 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
| 178 |
+
self.ln_1 = LayerNorm(d_model)
|
| 179 |
+
self.mlp = nn.Sequential(OrderedDict([
|
| 180 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
| 181 |
+
("gelu", QuickGELU()),
|
| 182 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
| 183 |
+
]))
|
| 184 |
+
self.ln_2 = LayerNorm(d_model)
|
| 185 |
+
self.attn_mask = attn_mask
|
| 186 |
+
|
| 187 |
+
def attention(self, x: torch.Tensor):
|
| 188 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
| 189 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
| 190 |
+
|
| 191 |
+
def forward(self, x: torch.Tensor):
|
| 192 |
+
x = x + self.attention(self.ln_1(x))
|
| 193 |
+
x = x + self.mlp(self.ln_2(x))
|
| 194 |
+
return x
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class Transformer(nn.Module):
|
| 198 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
|
| 199 |
+
super().__init__()
|
| 200 |
+
self.width = width
|
| 201 |
+
self.layers = layers
|
| 202 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
| 203 |
+
|
| 204 |
+
def forward(self, x: torch.Tensor):
|
| 205 |
+
return self.resblocks(x)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class UpdatedVisionTransformer(nn.Module):
|
| 209 |
+
def __init__(self, model):
|
| 210 |
+
super().__init__()
|
| 211 |
+
self.model = model
|
| 212 |
+
|
| 213 |
+
def forward(self, x: torch.Tensor):
|
| 214 |
+
x = self.model.conv1(x) # shape = [*, width, grid, grid]
|
| 215 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
| 216 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
| 217 |
+
x = torch.cat([self.model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
| 218 |
+
x = x + self.model.positional_embedding.to(x.dtype)
|
| 219 |
+
x = self.model.ln_pre(x)
|
| 220 |
+
|
| 221 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
| 222 |
+
x = self.model.transformer(x)
|
| 223 |
+
x = x.permute(1, 0, 2)[:, 1:] # LND -> NLD
|
| 224 |
+
|
| 225 |
+
# x = self.ln_post(x[:, 0, :])
|
| 226 |
+
|
| 227 |
+
# if self.proj is not None:
|
| 228 |
+
# x = x @ self.proj
|
| 229 |
+
|
| 230 |
+
return x
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class CLIP(nn.Module):
|
| 234 |
+
def __init__(self,
|
| 235 |
+
embed_dim: int,
|
| 236 |
+
# vision
|
| 237 |
+
image_resolution: int,
|
| 238 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
| 239 |
+
vision_width: int,
|
| 240 |
+
vision_patch_size: int,
|
| 241 |
+
# text
|
| 242 |
+
context_length: int,
|
| 243 |
+
vocab_size: int,
|
| 244 |
+
transformer_width: int,
|
| 245 |
+
transformer_heads: int,
|
| 246 |
+
transformer_layers: int
|
| 247 |
+
):
|
| 248 |
+
super().__init__()
|
| 249 |
+
|
| 250 |
+
self.context_length = context_length
|
| 251 |
+
|
| 252 |
+
if isinstance(vision_layers, (tuple, list)):
|
| 253 |
+
vision_heads = vision_width * 32 // 64
|
| 254 |
+
self.visual = ModifiedResNet(
|
| 255 |
+
layers=vision_layers,
|
| 256 |
+
output_dim=embed_dim,
|
| 257 |
+
heads=vision_heads,
|
| 258 |
+
input_resolution=image_resolution,
|
| 259 |
+
width=vision_width
|
| 260 |
+
)
|
| 261 |
+
else:
|
| 262 |
+
vision_heads = vision_width // 64
|
| 263 |
+
self.visual = UpdatedVisionTransformer(
|
| 264 |
+
input_resolution=image_resolution,
|
| 265 |
+
patch_size=vision_patch_size,
|
| 266 |
+
width=vision_width,
|
| 267 |
+
layers=vision_layers,
|
| 268 |
+
heads=vision_heads,
|
| 269 |
+
output_dim=embed_dim
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
self.transformer = Transformer(
|
| 273 |
+
width=transformer_width,
|
| 274 |
+
layers=transformer_layers,
|
| 275 |
+
heads=transformer_heads,
|
| 276 |
+
attn_mask=self.build_attention_mask()
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
self.vocab_size = vocab_size
|
| 280 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
| 281 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
| 282 |
+
self.ln_final = LayerNorm(transformer_width)
|
| 283 |
+
|
| 284 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
| 285 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 286 |
+
|
| 287 |
+
self.initialize_parameters()
|
| 288 |
+
|
| 289 |
+
def initialize_parameters(self):
|
| 290 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
| 291 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
| 292 |
+
|
| 293 |
+
if isinstance(self.visual, ModifiedResNet):
|
| 294 |
+
if self.visual.attnpool is not None:
|
| 295 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
| 296 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
| 297 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
| 298 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
| 299 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
| 300 |
+
|
| 301 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
| 302 |
+
for name, param in resnet_block.named_parameters():
|
| 303 |
+
if name.endswith("bn3.weight"):
|
| 304 |
+
nn.init.zeros_(param)
|
| 305 |
+
|
| 306 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
| 307 |
+
attn_std = self.transformer.width ** -0.5
|
| 308 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
| 309 |
+
for block in self.transformer.resblocks:
|
| 310 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
| 311 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
| 312 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
| 313 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
| 314 |
+
|
| 315 |
+
if self.text_projection is not None:
|
| 316 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
| 317 |
+
|
| 318 |
+
def build_attention_mask(self):
|
| 319 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
| 320 |
+
# pytorch uses additive attention mask; fill with -inf
|
| 321 |
+
mask = torch.empty(self.context_length, self.context_length)
|
| 322 |
+
mask.fill_(float("-inf"))
|
| 323 |
+
mask.triu_(1) # zero out the lower diagonal
|
| 324 |
+
return mask
|
| 325 |
+
|
| 326 |
+
@property
|
| 327 |
+
def dtype(self):
|
| 328 |
+
return self.visual.conv1.weight.dtype
|
| 329 |
+
|
| 330 |
+
def encode_image(self, image):
|
| 331 |
+
return self.visual(image.type(self.dtype))
|
| 332 |
+
|
| 333 |
+
def encode_text(self, text):
|
| 334 |
+
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
| 335 |
+
|
| 336 |
+
x = x + self.positional_embedding.type(self.dtype)
|
| 337 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
| 338 |
+
x = self.transformer(x)
|
| 339 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
| 340 |
+
x = self.ln_final(x).type(self.dtype)
|
| 341 |
+
|
| 342 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
| 343 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
| 344 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
| 345 |
+
|
| 346 |
+
return x
|
| 347 |
+
|
| 348 |
+
def forward(self, image, text):
|
| 349 |
+
image_features = self.encode_image(image)
|
| 350 |
+
text_features = self.encode_text(text)
|
| 351 |
+
|
| 352 |
+
# normalized features
|
| 353 |
+
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
| 354 |
+
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
| 355 |
+
|
| 356 |
+
# cosine similarity as logits
|
| 357 |
+
logit_scale = self.logit_scale.exp()
|
| 358 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
| 359 |
+
logits_per_text = logits_per_image.t()
|
| 360 |
+
|
| 361 |
+
# shape = [global_batch_size, global_batch_size]
|
| 362 |
+
return logits_per_image, logits_per_text
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def convert_weights(model: nn.Module):
|
| 366 |
+
"""Convert applicable model parameters to fp16"""
|
| 367 |
+
|
| 368 |
+
def _convert_weights_to_fp16(l):
|
| 369 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
| 370 |
+
l.weight.data = l.weight.data.half()
|
| 371 |
+
if l.bias is not None:
|
| 372 |
+
l.bias.data = l.bias.data.half()
|
| 373 |
+
|
| 374 |
+
if isinstance(l, nn.MultiheadAttention):
|
| 375 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
| 376 |
+
tensor = getattr(l, attr)
|
| 377 |
+
if tensor is not None:
|
| 378 |
+
tensor.data = tensor.data.half()
|
| 379 |
+
|
| 380 |
+
for name in ["text_projection", "proj"]:
|
| 381 |
+
if hasattr(l, name):
|
| 382 |
+
attr = getattr(l, name)
|
| 383 |
+
if attr is not None:
|
| 384 |
+
attr.data = attr.data.half()
|
| 385 |
+
|
| 386 |
+
model.apply(_convert_weights_to_fp16)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def build_model(state_dict: dict):
|
| 390 |
+
vit = "visual.proj" in state_dict
|
| 391 |
+
|
| 392 |
+
if vit:
|
| 393 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
| 394 |
+
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
| 395 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
| 396 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
| 397 |
+
image_resolution = vision_patch_size * grid_size
|
| 398 |
+
else:
|
| 399 |
+
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
| 400 |
+
vision_layers = tuple(counts)
|
| 401 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
| 402 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
| 403 |
+
vision_patch_size = None
|
| 404 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
| 405 |
+
image_resolution = output_width * 32
|
| 406 |
+
|
| 407 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
| 408 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
| 409 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
| 410 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
| 411 |
+
transformer_heads = transformer_width // 64
|
| 412 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
|
| 413 |
+
|
| 414 |
+
model = CLIP(
|
| 415 |
+
embed_dim,
|
| 416 |
+
image_resolution, vision_layers, vision_width, vision_patch_size,
|
| 417 |
+
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
| 421 |
+
if key in state_dict:
|
| 422 |
+
del state_dict[key]
|
| 423 |
+
|
| 424 |
+
convert_weights(model)
|
| 425 |
+
model.load_state_dict(state_dict)
|
| 426 |
+
return model.eval()
|
REG/models/jepa.py
ADDED
|
@@ -0,0 +1,547 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
from functools import partial
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
| 16 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 17 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 18 |
+
def norm_cdf(x):
|
| 19 |
+
# Computes standard normal cumulative distribution function
|
| 20 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
| 21 |
+
|
| 22 |
+
with torch.no_grad():
|
| 23 |
+
# Values are generated by using a truncated uniform distribution and
|
| 24 |
+
# then using the inverse CDF for the normal distribution.
|
| 25 |
+
# Get upper and lower cdf values
|
| 26 |
+
l = norm_cdf((a - mean) / std)
|
| 27 |
+
u = norm_cdf((b - mean) / std)
|
| 28 |
+
|
| 29 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
| 30 |
+
# [2l-1, 2u-1].
|
| 31 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 32 |
+
|
| 33 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 34 |
+
# standard normal
|
| 35 |
+
tensor.erfinv_()
|
| 36 |
+
|
| 37 |
+
# Transform to proper mean, std
|
| 38 |
+
tensor.mul_(std * math.sqrt(2.))
|
| 39 |
+
tensor.add_(mean)
|
| 40 |
+
|
| 41 |
+
# Clamp to ensure it's in the proper range
|
| 42 |
+
tensor.clamp_(min=a, max=b)
|
| 43 |
+
return tensor
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
| 47 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def repeat_interleave_batch(x, B, repeat):
|
| 51 |
+
N = len(x) // B
|
| 52 |
+
x = torch.cat([
|
| 53 |
+
torch.cat([x[i*B:(i+1)*B] for _ in range(repeat)], dim=0)
|
| 54 |
+
for i in range(N)
|
| 55 |
+
], dim=0)
|
| 56 |
+
return x
|
| 57 |
+
|
| 58 |
+
def apply_masks(x, masks):
|
| 59 |
+
"""
|
| 60 |
+
:param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
|
| 61 |
+
:param masks: list of tensors containing indices of patches in [N] to keep
|
| 62 |
+
"""
|
| 63 |
+
all_x = []
|
| 64 |
+
for m in masks:
|
| 65 |
+
mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
|
| 66 |
+
all_x += [torch.gather(x, dim=1, index=mask_keep)]
|
| 67 |
+
return torch.cat(all_x, dim=0)
|
| 68 |
+
|
| 69 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
| 70 |
+
"""
|
| 71 |
+
grid_size: int of the grid height and width
|
| 72 |
+
return:
|
| 73 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
| 74 |
+
"""
|
| 75 |
+
grid_h = np.arange(grid_size, dtype=float)
|
| 76 |
+
grid_w = np.arange(grid_size, dtype=float)
|
| 77 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
| 78 |
+
grid = np.stack(grid, axis=0)
|
| 79 |
+
|
| 80 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
| 81 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 82 |
+
if cls_token:
|
| 83 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
| 84 |
+
return pos_embed
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 88 |
+
assert embed_dim % 2 == 0
|
| 89 |
+
|
| 90 |
+
# use half of dimensions to encode grid_h
|
| 91 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
| 92 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
| 93 |
+
|
| 94 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
| 95 |
+
return emb
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
| 99 |
+
"""
|
| 100 |
+
grid_size: int of the grid length
|
| 101 |
+
return:
|
| 102 |
+
pos_embed: [grid_size, embed_dim] or [1+grid_size, embed_dim] (w/ or w/o cls_token)
|
| 103 |
+
"""
|
| 104 |
+
grid = np.arange(grid_size, dtype=float)
|
| 105 |
+
pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 106 |
+
if cls_token:
|
| 107 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
| 108 |
+
return pos_embed
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 112 |
+
"""
|
| 113 |
+
embed_dim: output dimension for each position
|
| 114 |
+
pos: a list of positions to be encoded: size (M,)
|
| 115 |
+
out: (M, D)
|
| 116 |
+
"""
|
| 117 |
+
assert embed_dim % 2 == 0
|
| 118 |
+
omega = np.arange(embed_dim // 2, dtype=float)
|
| 119 |
+
omega /= embed_dim / 2.
|
| 120 |
+
omega = 1. / 10000**omega # (D/2,)
|
| 121 |
+
|
| 122 |
+
pos = pos.reshape(-1) # (M,)
|
| 123 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
| 124 |
+
|
| 125 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 126 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 127 |
+
|
| 128 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
| 129 |
+
return emb
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
| 133 |
+
if drop_prob == 0. or not training:
|
| 134 |
+
return x
|
| 135 |
+
keep_prob = 1 - drop_prob
|
| 136 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 137 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
| 138 |
+
random_tensor.floor_() # binarize
|
| 139 |
+
output = x.div(keep_prob) * random_tensor
|
| 140 |
+
return output
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class DropPath(nn.Module):
|
| 144 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
| 145 |
+
"""
|
| 146 |
+
def __init__(self, drop_prob=None):
|
| 147 |
+
super(DropPath, self).__init__()
|
| 148 |
+
self.drop_prob = drop_prob
|
| 149 |
+
|
| 150 |
+
def forward(self, x):
|
| 151 |
+
return drop_path(x, self.drop_prob, self.training)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class MLP(nn.Module):
|
| 155 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
| 156 |
+
super().__init__()
|
| 157 |
+
out_features = out_features or in_features
|
| 158 |
+
hidden_features = hidden_features or in_features
|
| 159 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 160 |
+
self.act = act_layer()
|
| 161 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 162 |
+
self.drop = nn.Dropout(drop)
|
| 163 |
+
|
| 164 |
+
def forward(self, x):
|
| 165 |
+
x = self.fc1(x)
|
| 166 |
+
x = self.act(x)
|
| 167 |
+
x = self.drop(x)
|
| 168 |
+
x = self.fc2(x)
|
| 169 |
+
x = self.drop(x)
|
| 170 |
+
return x
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class Attention(nn.Module):
|
| 174 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
| 175 |
+
super().__init__()
|
| 176 |
+
self.num_heads = num_heads
|
| 177 |
+
head_dim = dim // num_heads
|
| 178 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 179 |
+
|
| 180 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 181 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 182 |
+
self.proj = nn.Linear(dim, dim)
|
| 183 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 184 |
+
|
| 185 |
+
def forward(self, x):
|
| 186 |
+
B, N, C = x.shape
|
| 187 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 188 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 189 |
+
|
| 190 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
| 191 |
+
attn = attn.softmax(dim=-1)
|
| 192 |
+
attn = self.attn_drop(attn)
|
| 193 |
+
|
| 194 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 195 |
+
x = self.proj(x)
|
| 196 |
+
x = self.proj_drop(x)
|
| 197 |
+
return x, attn
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class Block(nn.Module):
|
| 201 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
| 202 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
| 203 |
+
super().__init__()
|
| 204 |
+
self.norm1 = norm_layer(dim)
|
| 205 |
+
self.attn = Attention(
|
| 206 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
| 207 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 208 |
+
self.norm2 = norm_layer(dim)
|
| 209 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 210 |
+
self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 211 |
+
|
| 212 |
+
def forward(self, x, return_attention=False):
|
| 213 |
+
y, attn = self.attn(self.norm1(x))
|
| 214 |
+
if return_attention:
|
| 215 |
+
return attn
|
| 216 |
+
x = x + self.drop_path(y)
|
| 217 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 218 |
+
return x
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class PatchEmbed(nn.Module):
|
| 222 |
+
""" Image to Patch Embedding
|
| 223 |
+
"""
|
| 224 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
| 225 |
+
super().__init__()
|
| 226 |
+
num_patches = (img_size // patch_size) * (img_size // patch_size)
|
| 227 |
+
self.img_size = img_size
|
| 228 |
+
self.patch_size = patch_size
|
| 229 |
+
self.num_patches = num_patches
|
| 230 |
+
|
| 231 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 232 |
+
|
| 233 |
+
def forward(self, x):
|
| 234 |
+
B, C, H, W = x.shape
|
| 235 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
| 236 |
+
return x
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class ConvEmbed(nn.Module):
|
| 240 |
+
"""
|
| 241 |
+
3x3 Convolution stems for ViT following ViTC models
|
| 242 |
+
"""
|
| 243 |
+
|
| 244 |
+
def __init__(self, channels, strides, img_size=224, in_chans=3, batch_norm=True):
|
| 245 |
+
super().__init__()
|
| 246 |
+
# Build the stems
|
| 247 |
+
stem = []
|
| 248 |
+
channels = [in_chans] + channels
|
| 249 |
+
for i in range(len(channels) - 2):
|
| 250 |
+
stem += [nn.Conv2d(channels[i], channels[i+1], kernel_size=3,
|
| 251 |
+
stride=strides[i], padding=1, bias=(not batch_norm))]
|
| 252 |
+
if batch_norm:
|
| 253 |
+
stem += [nn.BatchNorm2d(channels[i+1])]
|
| 254 |
+
stem += [nn.ReLU(inplace=True)]
|
| 255 |
+
stem += [nn.Conv2d(channels[-2], channels[-1], kernel_size=1, stride=strides[-1])]
|
| 256 |
+
self.stem = nn.Sequential(*stem)
|
| 257 |
+
|
| 258 |
+
# Comptute the number of patches
|
| 259 |
+
stride_prod = int(np.prod(strides))
|
| 260 |
+
self.num_patches = (img_size[0] // stride_prod)**2
|
| 261 |
+
|
| 262 |
+
def forward(self, x):
|
| 263 |
+
p = self.stem(x)
|
| 264 |
+
return p.flatten(2).transpose(1, 2)
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class VisionTransformerPredictor(nn.Module):
|
| 268 |
+
""" Vision Transformer """
|
| 269 |
+
def __init__(
|
| 270 |
+
self,
|
| 271 |
+
num_patches,
|
| 272 |
+
embed_dim=768,
|
| 273 |
+
predictor_embed_dim=384,
|
| 274 |
+
depth=6,
|
| 275 |
+
num_heads=12,
|
| 276 |
+
mlp_ratio=4.0,
|
| 277 |
+
qkv_bias=True,
|
| 278 |
+
qk_scale=None,
|
| 279 |
+
drop_rate=0.0,
|
| 280 |
+
attn_drop_rate=0.0,
|
| 281 |
+
drop_path_rate=0.0,
|
| 282 |
+
norm_layer=nn.LayerNorm,
|
| 283 |
+
init_std=0.02,
|
| 284 |
+
**kwargs
|
| 285 |
+
):
|
| 286 |
+
super().__init__()
|
| 287 |
+
self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
|
| 288 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
|
| 289 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 290 |
+
# --
|
| 291 |
+
self.predictor_pos_embed = nn.Parameter(torch.zeros(1, num_patches, predictor_embed_dim),
|
| 292 |
+
requires_grad=False)
|
| 293 |
+
predictor_pos_embed = get_2d_sincos_pos_embed(self.predictor_pos_embed.shape[-1],
|
| 294 |
+
int(num_patches**.5),
|
| 295 |
+
cls_token=False)
|
| 296 |
+
self.predictor_pos_embed.data.copy_(torch.from_numpy(predictor_pos_embed).float().unsqueeze(0))
|
| 297 |
+
# --
|
| 298 |
+
self.predictor_blocks = nn.ModuleList([
|
| 299 |
+
Block(
|
| 300 |
+
dim=predictor_embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 301 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
|
| 302 |
+
for i in range(depth)])
|
| 303 |
+
self.predictor_norm = norm_layer(predictor_embed_dim)
|
| 304 |
+
self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
|
| 305 |
+
# ------
|
| 306 |
+
self.init_std = init_std
|
| 307 |
+
trunc_normal_(self.mask_token, std=self.init_std)
|
| 308 |
+
self.apply(self._init_weights)
|
| 309 |
+
self.fix_init_weight()
|
| 310 |
+
|
| 311 |
+
def fix_init_weight(self):
|
| 312 |
+
def rescale(param, layer_id):
|
| 313 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
| 314 |
+
|
| 315 |
+
for layer_id, layer in enumerate(self.predictor_blocks):
|
| 316 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
| 317 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
| 318 |
+
|
| 319 |
+
def _init_weights(self, m):
|
| 320 |
+
if isinstance(m, nn.Linear):
|
| 321 |
+
trunc_normal_(m.weight, std=self.init_std)
|
| 322 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 323 |
+
nn.init.constant_(m.bias, 0)
|
| 324 |
+
elif isinstance(m, nn.LayerNorm):
|
| 325 |
+
nn.init.constant_(m.bias, 0)
|
| 326 |
+
nn.init.constant_(m.weight, 1.0)
|
| 327 |
+
elif isinstance(m, nn.Conv2d):
|
| 328 |
+
trunc_normal_(m.weight, std=self.init_std)
|
| 329 |
+
if m.bias is not None:
|
| 330 |
+
nn.init.constant_(m.bias, 0)
|
| 331 |
+
|
| 332 |
+
def forward(self, x, masks_x, masks):
|
| 333 |
+
assert (masks is not None) and (masks_x is not None), 'Cannot run predictor without mask indices'
|
| 334 |
+
|
| 335 |
+
if not isinstance(masks_x, list):
|
| 336 |
+
masks_x = [masks_x]
|
| 337 |
+
|
| 338 |
+
if not isinstance(masks, list):
|
| 339 |
+
masks = [masks]
|
| 340 |
+
|
| 341 |
+
# -- Batch Size
|
| 342 |
+
B = len(x) // len(masks_x)
|
| 343 |
+
|
| 344 |
+
# -- map from encoder-dim to pedictor-dim
|
| 345 |
+
x = self.predictor_embed(x)
|
| 346 |
+
|
| 347 |
+
# -- add positional embedding to x tokens
|
| 348 |
+
x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1)
|
| 349 |
+
x += apply_masks(x_pos_embed, masks_x)
|
| 350 |
+
|
| 351 |
+
_, N_ctxt, D = x.shape
|
| 352 |
+
|
| 353 |
+
# -- concat mask tokens to x
|
| 354 |
+
pos_embs = self.predictor_pos_embed.repeat(B, 1, 1)
|
| 355 |
+
pos_embs = apply_masks(pos_embs, masks)
|
| 356 |
+
pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_x))
|
| 357 |
+
# --
|
| 358 |
+
pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1)
|
| 359 |
+
# --
|
| 360 |
+
pred_tokens += pos_embs
|
| 361 |
+
x = x.repeat(len(masks), 1, 1)
|
| 362 |
+
x = torch.cat([x, pred_tokens], dim=1)
|
| 363 |
+
|
| 364 |
+
# -- fwd prop
|
| 365 |
+
for blk in self.predictor_blocks:
|
| 366 |
+
x = blk(x)
|
| 367 |
+
x = self.predictor_norm(x)
|
| 368 |
+
|
| 369 |
+
# -- return preds for mask tokens
|
| 370 |
+
x = x[:, N_ctxt:]
|
| 371 |
+
x = self.predictor_proj(x)
|
| 372 |
+
|
| 373 |
+
return x
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class VisionTransformer(nn.Module):
|
| 377 |
+
""" Vision Transformer """
|
| 378 |
+
def __init__(
|
| 379 |
+
self,
|
| 380 |
+
img_size=[224],
|
| 381 |
+
patch_size=16,
|
| 382 |
+
in_chans=3,
|
| 383 |
+
embed_dim=768,
|
| 384 |
+
predictor_embed_dim=384,
|
| 385 |
+
depth=12,
|
| 386 |
+
predictor_depth=12,
|
| 387 |
+
num_heads=12,
|
| 388 |
+
mlp_ratio=4.0,
|
| 389 |
+
qkv_bias=True,
|
| 390 |
+
qk_scale=None,
|
| 391 |
+
drop_rate=0.0,
|
| 392 |
+
attn_drop_rate=0.0,
|
| 393 |
+
drop_path_rate=0.0,
|
| 394 |
+
norm_layer=nn.LayerNorm,
|
| 395 |
+
init_std=0.02,
|
| 396 |
+
**kwargs
|
| 397 |
+
):
|
| 398 |
+
super().__init__()
|
| 399 |
+
self.num_features = self.embed_dim = embed_dim
|
| 400 |
+
self.num_heads = num_heads
|
| 401 |
+
# --
|
| 402 |
+
self.patch_embed = PatchEmbed(
|
| 403 |
+
img_size=img_size[0],
|
| 404 |
+
patch_size=patch_size,
|
| 405 |
+
in_chans=in_chans,
|
| 406 |
+
embed_dim=embed_dim)
|
| 407 |
+
num_patches = self.patch_embed.num_patches
|
| 408 |
+
# --
|
| 409 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False)
|
| 410 |
+
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
|
| 411 |
+
int(self.patch_embed.num_patches**.5),
|
| 412 |
+
cls_token=False)
|
| 413 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
| 414 |
+
# --
|
| 415 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 416 |
+
self.blocks = nn.ModuleList([
|
| 417 |
+
Block(
|
| 418 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 419 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
|
| 420 |
+
for i in range(depth)])
|
| 421 |
+
self.norm = norm_layer(embed_dim)
|
| 422 |
+
# ------
|
| 423 |
+
self.init_std = init_std
|
| 424 |
+
self.apply(self._init_weights)
|
| 425 |
+
self.fix_init_weight()
|
| 426 |
+
|
| 427 |
+
def fix_init_weight(self):
|
| 428 |
+
def rescale(param, layer_id):
|
| 429 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
| 430 |
+
|
| 431 |
+
for layer_id, layer in enumerate(self.blocks):
|
| 432 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
| 433 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
| 434 |
+
|
| 435 |
+
def _init_weights(self, m):
|
| 436 |
+
if isinstance(m, nn.Linear):
|
| 437 |
+
trunc_normal_(m.weight, std=self.init_std)
|
| 438 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 439 |
+
nn.init.constant_(m.bias, 0)
|
| 440 |
+
elif isinstance(m, nn.LayerNorm):
|
| 441 |
+
nn.init.constant_(m.bias, 0)
|
| 442 |
+
nn.init.constant_(m.weight, 1.0)
|
| 443 |
+
elif isinstance(m, nn.Conv2d):
|
| 444 |
+
trunc_normal_(m.weight, std=self.init_std)
|
| 445 |
+
if m.bias is not None:
|
| 446 |
+
nn.init.constant_(m.bias, 0)
|
| 447 |
+
|
| 448 |
+
def forward(self, x, masks=None):
|
| 449 |
+
if masks is not None:
|
| 450 |
+
if not isinstance(masks, list):
|
| 451 |
+
masks = [masks]
|
| 452 |
+
|
| 453 |
+
# -- patchify x
|
| 454 |
+
x = self.patch_embed(x)
|
| 455 |
+
B, N, D = x.shape
|
| 456 |
+
|
| 457 |
+
# -- add positional embedding to x
|
| 458 |
+
pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
|
| 459 |
+
x = x + pos_embed
|
| 460 |
+
|
| 461 |
+
# -- mask x
|
| 462 |
+
if masks is not None:
|
| 463 |
+
x = apply_masks(x, masks)
|
| 464 |
+
|
| 465 |
+
# -- fwd prop
|
| 466 |
+
for i, blk in enumerate(self.blocks):
|
| 467 |
+
x = blk(x)
|
| 468 |
+
|
| 469 |
+
if self.norm is not None:
|
| 470 |
+
x = self.norm(x)
|
| 471 |
+
|
| 472 |
+
return x
|
| 473 |
+
|
| 474 |
+
def interpolate_pos_encoding(self, x, pos_embed):
|
| 475 |
+
npatch = x.shape[1] - 1
|
| 476 |
+
N = pos_embed.shape[1] - 1
|
| 477 |
+
if npatch == N:
|
| 478 |
+
return pos_embed
|
| 479 |
+
class_emb = pos_embed[:, 0]
|
| 480 |
+
pos_embed = pos_embed[:, 1:]
|
| 481 |
+
dim = x.shape[-1]
|
| 482 |
+
pos_embed = nn.functional.interpolate(
|
| 483 |
+
pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
| 484 |
+
scale_factor=math.sqrt(npatch / N),
|
| 485 |
+
mode='bicubic',
|
| 486 |
+
)
|
| 487 |
+
pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 488 |
+
return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def vit_predictor(**kwargs):
|
| 492 |
+
model = VisionTransformerPredictor(
|
| 493 |
+
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
| 494 |
+
**kwargs)
|
| 495 |
+
return model
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
def vit_tiny(patch_size=16, **kwargs):
|
| 499 |
+
model = VisionTransformer(
|
| 500 |
+
patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
|
| 501 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 502 |
+
return model
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
def vit_small(patch_size=16, **kwargs):
|
| 506 |
+
model = VisionTransformer(
|
| 507 |
+
patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
|
| 508 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 509 |
+
return model
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def vit_base(patch_size=16, **kwargs):
|
| 513 |
+
model = VisionTransformer(
|
| 514 |
+
patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
|
| 515 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 516 |
+
return model
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def vit_large(patch_size=16, **kwargs):
|
| 520 |
+
model = VisionTransformer(
|
| 521 |
+
patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
|
| 522 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 523 |
+
return model
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
def vit_huge(patch_size=16, **kwargs):
|
| 527 |
+
model = VisionTransformer(
|
| 528 |
+
patch_size=patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
|
| 529 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 530 |
+
return model
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def vit_giant(patch_size=16, **kwargs):
|
| 534 |
+
model = VisionTransformer(
|
| 535 |
+
patch_size=patch_size, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
|
| 536 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 537 |
+
return model
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
VIT_EMBED_DIMS = {
|
| 541 |
+
'vit_tiny': 192,
|
| 542 |
+
'vit_small': 384,
|
| 543 |
+
'vit_base': 768,
|
| 544 |
+
'vit_large': 1024,
|
| 545 |
+
'vit_huge': 1280,
|
| 546 |
+
'vit_giant': 1408,
|
| 547 |
+
}
|
REG/models/mae_vit.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
| 9 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
|
| 12 |
+
from functools import partial
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
|
| 17 |
+
import timm.models.vision_transformer
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
|
| 21 |
+
""" Vision Transformer with support for global average pooling
|
| 22 |
+
"""
|
| 23 |
+
def __init__(self, global_pool=False, **kwargs):
|
| 24 |
+
super(VisionTransformer, self).__init__(**kwargs)
|
| 25 |
+
|
| 26 |
+
self.global_pool = global_pool
|
| 27 |
+
if self.global_pool:
|
| 28 |
+
norm_layer = kwargs['norm_layer']
|
| 29 |
+
embed_dim = kwargs['embed_dim']
|
| 30 |
+
self.fc_norm = norm_layer(embed_dim)
|
| 31 |
+
|
| 32 |
+
del self.norm # remove the original norm
|
| 33 |
+
|
| 34 |
+
def forward_features(self, x):
|
| 35 |
+
B = x.shape[0]
|
| 36 |
+
x = self.patch_embed(x)
|
| 37 |
+
|
| 38 |
+
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
| 39 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 40 |
+
x = x + self.pos_embed
|
| 41 |
+
x = self.pos_drop(x)
|
| 42 |
+
|
| 43 |
+
for blk in self.blocks:
|
| 44 |
+
x = blk(x)
|
| 45 |
+
|
| 46 |
+
x = x[:, 1:, :] #.mean(dim=1) # global pool without cls token
|
| 47 |
+
|
| 48 |
+
return x
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def vit_base_patch16(**kwargs):
|
| 52 |
+
model = VisionTransformer(
|
| 53 |
+
num_classes=0,
|
| 54 |
+
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
| 55 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 56 |
+
return model
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def vit_large_patch16(**kwargs):
|
| 60 |
+
model = VisionTransformer(
|
| 61 |
+
num_classes=0,
|
| 62 |
+
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
| 63 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 64 |
+
return model
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def vit_huge_patch14(**kwargs):
|
| 68 |
+
model = VisionTransformer(
|
| 69 |
+
patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
| 70 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 71 |
+
return model
|
REG/models/mocov3_vit.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from functools import partial, reduce
|
| 11 |
+
from operator import mul
|
| 12 |
+
|
| 13 |
+
from timm.layers.helpers import to_2tuple
|
| 14 |
+
from timm.models.vision_transformer import VisionTransformer, _cfg
|
| 15 |
+
from timm.models.vision_transformer import PatchEmbed
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
'vit_small',
|
| 19 |
+
'vit_base',
|
| 20 |
+
'vit_large',
|
| 21 |
+
'vit_conv_small',
|
| 22 |
+
'vit_conv_base',
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def patchify_avg(input_tensor, patch_size):
|
| 27 |
+
# Ensure input tensor is 4D: (batch_size, channels, height, width)
|
| 28 |
+
if input_tensor.dim() != 4:
|
| 29 |
+
raise ValueError("Input tensor must be 4D (batch_size, channels, height, width)")
|
| 30 |
+
|
| 31 |
+
# Get input tensor dimensions
|
| 32 |
+
batch_size, channels, height, width = input_tensor.shape
|
| 33 |
+
|
| 34 |
+
# Ensure patch_size is valid
|
| 35 |
+
patch_height, patch_width = patch_size, patch_size
|
| 36 |
+
if height % patch_height != 0 or width % patch_width != 0:
|
| 37 |
+
raise ValueError("Input tensor dimensions must be divisible by patch_size")
|
| 38 |
+
|
| 39 |
+
# Use unfold to create patches
|
| 40 |
+
patches = input_tensor.unfold(2, patch_height, patch_height).unfold(3, patch_width, patch_width)
|
| 41 |
+
|
| 42 |
+
# Reshape patches to desired format: (batch_size, num_patches, channels)
|
| 43 |
+
patches = patches.contiguous().view(
|
| 44 |
+
batch_size, channels, -1, patch_height, patch_width
|
| 45 |
+
).mean(dim=-1).mean(dim=-1)
|
| 46 |
+
patches = patches.permute(0, 2, 1).contiguous()
|
| 47 |
+
|
| 48 |
+
return patches
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class VisionTransformerMoCo(VisionTransformer):
|
| 53 |
+
def __init__(self, stop_grad_conv1=False, **kwargs):
|
| 54 |
+
super().__init__(**kwargs)
|
| 55 |
+
# Use fixed 2D sin-cos position embedding
|
| 56 |
+
self.build_2d_sincos_position_embedding()
|
| 57 |
+
|
| 58 |
+
# weight initialization
|
| 59 |
+
for name, m in self.named_modules():
|
| 60 |
+
if isinstance(m, nn.Linear):
|
| 61 |
+
if 'qkv' in name:
|
| 62 |
+
# treat the weights of Q, K, V separately
|
| 63 |
+
val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
|
| 64 |
+
nn.init.uniform_(m.weight, -val, val)
|
| 65 |
+
else:
|
| 66 |
+
nn.init.xavier_uniform_(m.weight)
|
| 67 |
+
nn.init.zeros_(m.bias)
|
| 68 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
| 69 |
+
|
| 70 |
+
if isinstance(self.patch_embed, PatchEmbed):
|
| 71 |
+
# xavier_uniform initialization
|
| 72 |
+
val = math.sqrt(6. / float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim))
|
| 73 |
+
nn.init.uniform_(self.patch_embed.proj.weight, -val, val)
|
| 74 |
+
nn.init.zeros_(self.patch_embed.proj.bias)
|
| 75 |
+
|
| 76 |
+
if stop_grad_conv1:
|
| 77 |
+
self.patch_embed.proj.weight.requires_grad = False
|
| 78 |
+
self.patch_embed.proj.bias.requires_grad = False
|
| 79 |
+
|
| 80 |
+
def build_2d_sincos_position_embedding(self, temperature=10000.):
|
| 81 |
+
h = self.patch_embed.img_size[0] // self.patch_embed.patch_size[0]
|
| 82 |
+
w = self.patch_embed.img_size[1] // self.patch_embed.patch_size[1]
|
| 83 |
+
grid_w = torch.arange(w, dtype=torch.float32)
|
| 84 |
+
grid_h = torch.arange(h, dtype=torch.float32)
|
| 85 |
+
grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
|
| 86 |
+
assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
|
| 87 |
+
pos_dim = self.embed_dim // 4
|
| 88 |
+
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
|
| 89 |
+
omega = 1. / (temperature**omega)
|
| 90 |
+
out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
|
| 91 |
+
out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
|
| 92 |
+
pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
|
| 93 |
+
|
| 94 |
+
# assert self.num_tokens == 1, 'Assuming one and only one token, [cls]'
|
| 95 |
+
pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
|
| 96 |
+
self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
|
| 97 |
+
self.pos_embed.requires_grad = False
|
| 98 |
+
|
| 99 |
+
def forward_diffusion_output(self, x):
|
| 100 |
+
x = x.reshape(*x.shape[0:2], -1).permute(0, 2, 1)
|
| 101 |
+
x = self._pos_embed(x)
|
| 102 |
+
x = self.patch_drop(x)
|
| 103 |
+
x = self.norm_pre(x)
|
| 104 |
+
x = self.blocks(x)
|
| 105 |
+
x = self.norm(x)
|
| 106 |
+
return x
|
| 107 |
+
|
| 108 |
+
class ConvStem(nn.Module):
|
| 109 |
+
"""
|
| 110 |
+
ConvStem, from Early Convolutions Help Transformers See Better, Tete et al. https://arxiv.org/abs/2106.14881
|
| 111 |
+
"""
|
| 112 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
|
| 113 |
+
super().__init__()
|
| 114 |
+
|
| 115 |
+
assert patch_size == 16, 'ConvStem only supports patch size of 16'
|
| 116 |
+
assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem'
|
| 117 |
+
|
| 118 |
+
img_size = to_2tuple(img_size)
|
| 119 |
+
patch_size = to_2tuple(patch_size)
|
| 120 |
+
self.img_size = img_size
|
| 121 |
+
self.patch_size = patch_size
|
| 122 |
+
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
| 123 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
| 124 |
+
self.flatten = flatten
|
| 125 |
+
|
| 126 |
+
# build stem, similar to the design in https://arxiv.org/abs/2106.14881
|
| 127 |
+
stem = []
|
| 128 |
+
input_dim, output_dim = 3, embed_dim // 8
|
| 129 |
+
for l in range(4):
|
| 130 |
+
stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
|
| 131 |
+
stem.append(nn.BatchNorm2d(output_dim))
|
| 132 |
+
stem.append(nn.ReLU(inplace=True))
|
| 133 |
+
input_dim = output_dim
|
| 134 |
+
output_dim *= 2
|
| 135 |
+
stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
|
| 136 |
+
self.proj = nn.Sequential(*stem)
|
| 137 |
+
|
| 138 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 139 |
+
|
| 140 |
+
def forward(self, x):
|
| 141 |
+
B, C, H, W = x.shape
|
| 142 |
+
assert H == self.img_size[0] and W == self.img_size[1], \
|
| 143 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
| 144 |
+
x = self.proj(x)
|
| 145 |
+
if self.flatten:
|
| 146 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
| 147 |
+
x = self.norm(x)
|
| 148 |
+
return x
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def vit_small(**kwargs):
|
| 152 |
+
model = VisionTransformerMoCo(
|
| 153 |
+
img_size=256,
|
| 154 |
+
patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
| 155 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 156 |
+
model.default_cfg = _cfg()
|
| 157 |
+
return model
|
| 158 |
+
|
| 159 |
+
def vit_base(**kwargs):
|
| 160 |
+
model = VisionTransformerMoCo(
|
| 161 |
+
img_size=256,
|
| 162 |
+
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
| 163 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 164 |
+
model.default_cfg = _cfg()
|
| 165 |
+
return model
|
| 166 |
+
|
| 167 |
+
def vit_large(**kwargs):
|
| 168 |
+
model = VisionTransformerMoCo(
|
| 169 |
+
img_size=256,
|
| 170 |
+
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
| 171 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 172 |
+
model.default_cfg = _cfg()
|
| 173 |
+
return model
|
| 174 |
+
|
| 175 |
+
def vit_conv_small(**kwargs):
|
| 176 |
+
# minus one ViT block
|
| 177 |
+
model = VisionTransformerMoCo(
|
| 178 |
+
patch_size=16, embed_dim=384, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
| 179 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs)
|
| 180 |
+
model.default_cfg = _cfg()
|
| 181 |
+
return model
|
| 182 |
+
|
| 183 |
+
def vit_conv_base(**kwargs):
|
| 184 |
+
# minus one ViT block
|
| 185 |
+
model = VisionTransformerMoCo(
|
| 186 |
+
patch_size=16, embed_dim=768, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
| 187 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs)
|
| 188 |
+
model.default_cfg = _cfg()
|
| 189 |
+
return model
|
| 190 |
+
|
| 191 |
+
def build_mlp(num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
|
| 192 |
+
mlp = []
|
| 193 |
+
for l in range(num_layers):
|
| 194 |
+
dim1 = input_dim if l == 0 else mlp_dim
|
| 195 |
+
dim2 = output_dim if l == num_layers - 1 else mlp_dim
|
| 196 |
+
|
| 197 |
+
mlp.append(nn.Linear(dim1, dim2, bias=False))
|
| 198 |
+
|
| 199 |
+
if l < num_layers - 1:
|
| 200 |
+
mlp.append(nn.BatchNorm1d(dim2))
|
| 201 |
+
mlp.append(nn.ReLU(inplace=True))
|
| 202 |
+
elif last_bn:
|
| 203 |
+
# follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157
|
| 204 |
+
# for simplicity, we further removed gamma in BN
|
| 205 |
+
mlp.append(nn.BatchNorm1d(dim2, affine=False))
|
| 206 |
+
|
| 207 |
+
return nn.Sequential(*mlp)
|
REG/models/sit.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This source code is licensed under the license found in the
|
| 2 |
+
# LICENSE file in the root directory of this source tree.
|
| 3 |
+
# --------------------------------------------------------
|
| 4 |
+
# References:
|
| 5 |
+
# GLIDE: https://github.com/openai/glide-text2im
|
| 6 |
+
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
| 7 |
+
# --------------------------------------------------------
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import numpy as np
|
| 12 |
+
import math
|
| 13 |
+
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def build_mlp(hidden_size, projector_dim, z_dim):
|
| 17 |
+
return nn.Sequential(
|
| 18 |
+
nn.Linear(hidden_size, projector_dim),
|
| 19 |
+
nn.SiLU(),
|
| 20 |
+
nn.Linear(projector_dim, projector_dim),
|
| 21 |
+
nn.SiLU(),
|
| 22 |
+
nn.Linear(projector_dim, z_dim),
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
def modulate(x, shift, scale):
|
| 26 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
| 27 |
+
|
| 28 |
+
#################################################################################
|
| 29 |
+
# Embedding Layers for Timesteps and Class Labels #
|
| 30 |
+
#################################################################################
|
| 31 |
+
class TimestepEmbedder(nn.Module):
|
| 32 |
+
"""
|
| 33 |
+
Embeds scalar timesteps into vector representations.
|
| 34 |
+
"""
|
| 35 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.mlp = nn.Sequential(
|
| 38 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
| 39 |
+
nn.SiLU(),
|
| 40 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 41 |
+
)
|
| 42 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 43 |
+
|
| 44 |
+
@staticmethod
|
| 45 |
+
def positional_embedding(t, dim, max_period=10000):
|
| 46 |
+
"""
|
| 47 |
+
Create sinusoidal timestep embeddings.
|
| 48 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
| 49 |
+
These may be fractional.
|
| 50 |
+
:param dim: the dimension of the output.
|
| 51 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 52 |
+
:return: an (N, D) Tensor of positional embeddings.
|
| 53 |
+
"""
|
| 54 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
| 55 |
+
half = dim // 2
|
| 56 |
+
freqs = torch.exp(
|
| 57 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
| 58 |
+
).to(device=t.device)
|
| 59 |
+
args = t[:, None].float() * freqs[None]
|
| 60 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 61 |
+
if dim % 2:
|
| 62 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 63 |
+
return embedding
|
| 64 |
+
|
| 65 |
+
def forward(self, t):
|
| 66 |
+
self.timestep_embedding = self.positional_embedding
|
| 67 |
+
t_freq = self.timestep_embedding(t, dim=self.frequency_embedding_size).to(t.dtype)
|
| 68 |
+
t_emb = self.mlp(t_freq)
|
| 69 |
+
return t_emb
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class LabelEmbedder(nn.Module):
|
| 73 |
+
"""
|
| 74 |
+
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
| 75 |
+
"""
|
| 76 |
+
def __init__(self, num_classes, hidden_size, dropout_prob):
|
| 77 |
+
super().__init__()
|
| 78 |
+
use_cfg_embedding = dropout_prob > 0
|
| 79 |
+
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
|
| 80 |
+
self.num_classes = num_classes
|
| 81 |
+
self.dropout_prob = dropout_prob
|
| 82 |
+
|
| 83 |
+
def token_drop(self, labels, force_drop_ids=None):
|
| 84 |
+
"""
|
| 85 |
+
Drops labels to enable classifier-free guidance.
|
| 86 |
+
"""
|
| 87 |
+
if force_drop_ids is None:
|
| 88 |
+
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
|
| 89 |
+
else:
|
| 90 |
+
drop_ids = force_drop_ids == 1
|
| 91 |
+
labels = torch.where(drop_ids, self.num_classes, labels)
|
| 92 |
+
return labels
|
| 93 |
+
|
| 94 |
+
def forward(self, labels, train, force_drop_ids=None):
|
| 95 |
+
use_dropout = self.dropout_prob > 0
|
| 96 |
+
if (train and use_dropout) or (force_drop_ids is not None):
|
| 97 |
+
labels = self.token_drop(labels, force_drop_ids)
|
| 98 |
+
embeddings = self.embedding_table(labels)
|
| 99 |
+
return embeddings
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
#################################################################################
|
| 103 |
+
# Core SiT Model #
|
| 104 |
+
#################################################################################
|
| 105 |
+
|
| 106 |
+
class SiTBlock(nn.Module):
|
| 107 |
+
"""
|
| 108 |
+
A SiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
| 109 |
+
"""
|
| 110 |
+
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
|
| 111 |
+
super().__init__()
|
| 112 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 113 |
+
self.attn = Attention(
|
| 114 |
+
hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=block_kwargs["qk_norm"]
|
| 115 |
+
)
|
| 116 |
+
if "fused_attn" in block_kwargs.keys():
|
| 117 |
+
self.attn.fused_attn = block_kwargs["fused_attn"]
|
| 118 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 119 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 120 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
| 121 |
+
self.mlp = Mlp(
|
| 122 |
+
in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0
|
| 123 |
+
)
|
| 124 |
+
self.adaLN_modulation = nn.Sequential(
|
| 125 |
+
nn.SiLU(),
|
| 126 |
+
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
def forward(self, x, c):
|
| 130 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
| 131 |
+
self.adaLN_modulation(c).chunk(6, dim=-1)
|
| 132 |
+
)
|
| 133 |
+
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
|
| 134 |
+
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
| 135 |
+
|
| 136 |
+
return x
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class FinalLayer(nn.Module):
|
| 140 |
+
"""
|
| 141 |
+
The final layer of SiT.
|
| 142 |
+
"""
|
| 143 |
+
def __init__(self, hidden_size, patch_size, out_channels, cls_token_dim):
|
| 144 |
+
super().__init__()
|
| 145 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 146 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
| 147 |
+
self.linear_cls = nn.Linear(hidden_size, cls_token_dim, bias=True)
|
| 148 |
+
self.adaLN_modulation = nn.Sequential(
|
| 149 |
+
nn.SiLU(),
|
| 150 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
def forward(self, x, c, cls=None):
|
| 154 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
| 155 |
+
x = modulate(self.norm_final(x), shift, scale)
|
| 156 |
+
|
| 157 |
+
if cls is None:
|
| 158 |
+
x = self.linear(x)
|
| 159 |
+
return x, None
|
| 160 |
+
else:
|
| 161 |
+
cls_token = self.linear_cls(x[:, 0]).unsqueeze(1)
|
| 162 |
+
x = self.linear(x[:, 1:])
|
| 163 |
+
return x, cls_token.squeeze(1)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class SiT(nn.Module):
|
| 167 |
+
"""
|
| 168 |
+
Diffusion model with a Transformer backbone.
|
| 169 |
+
"""
|
| 170 |
+
def __init__(
|
| 171 |
+
self,
|
| 172 |
+
path_type='edm',
|
| 173 |
+
input_size=32,
|
| 174 |
+
patch_size=2,
|
| 175 |
+
in_channels=4,
|
| 176 |
+
hidden_size=1152,
|
| 177 |
+
decoder_hidden_size=768,
|
| 178 |
+
encoder_depth=8,
|
| 179 |
+
depth=28,
|
| 180 |
+
num_heads=16,
|
| 181 |
+
mlp_ratio=4.0,
|
| 182 |
+
class_dropout_prob=0.1,
|
| 183 |
+
num_classes=1000,
|
| 184 |
+
use_cfg=False,
|
| 185 |
+
z_dims=[768],
|
| 186 |
+
projector_dim=2048,
|
| 187 |
+
cls_token_dim=768,
|
| 188 |
+
**block_kwargs # fused_attn
|
| 189 |
+
):
|
| 190 |
+
super().__init__()
|
| 191 |
+
self.path_type = path_type
|
| 192 |
+
self.in_channels = in_channels
|
| 193 |
+
self.out_channels = in_channels
|
| 194 |
+
self.patch_size = patch_size
|
| 195 |
+
self.num_heads = num_heads
|
| 196 |
+
self.use_cfg = use_cfg
|
| 197 |
+
self.num_classes = num_classes
|
| 198 |
+
self.z_dims = z_dims
|
| 199 |
+
self.encoder_depth = encoder_depth
|
| 200 |
+
|
| 201 |
+
self.x_embedder = PatchEmbed(
|
| 202 |
+
input_size, patch_size, in_channels, hidden_size, bias=True
|
| 203 |
+
)
|
| 204 |
+
self.t_embedder = TimestepEmbedder(hidden_size) # timestep embedding type
|
| 205 |
+
self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
|
| 206 |
+
num_patches = self.x_embedder.num_patches
|
| 207 |
+
# Will use fixed sin-cos embedding:
|
| 208 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, hidden_size), requires_grad=False)
|
| 209 |
+
|
| 210 |
+
self.blocks = nn.ModuleList([
|
| 211 |
+
SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, **block_kwargs) for _ in range(depth)
|
| 212 |
+
])
|
| 213 |
+
self.projectors = nn.ModuleList([
|
| 214 |
+
build_mlp(hidden_size, projector_dim, z_dim) for z_dim in z_dims
|
| 215 |
+
])
|
| 216 |
+
|
| 217 |
+
z_dim = self.z_dims[0]
|
| 218 |
+
cls_token_dim = z_dim
|
| 219 |
+
self.final_layer = FinalLayer(decoder_hidden_size, patch_size, self.out_channels, cls_token_dim)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
self.cls_projectors2 = nn.Linear(in_features=cls_token_dim, out_features=hidden_size, bias=True)
|
| 223 |
+
self.wg_norm = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
| 224 |
+
|
| 225 |
+
self.initialize_weights()
|
| 226 |
+
|
| 227 |
+
def initialize_weights(self):
|
| 228 |
+
# Initialize transformer layers:
|
| 229 |
+
def _basic_init(module):
|
| 230 |
+
if isinstance(module, nn.Linear):
|
| 231 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 232 |
+
if module.bias is not None:
|
| 233 |
+
nn.init.constant_(module.bias, 0)
|
| 234 |
+
self.apply(_basic_init)
|
| 235 |
+
|
| 236 |
+
# Initialize (and freeze) pos_embed by sin-cos embedding:
|
| 237 |
+
pos_embed = get_2d_sincos_pos_embed(
|
| 238 |
+
self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5), cls_token=1, extra_tokens=1
|
| 239 |
+
)
|
| 240 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
| 241 |
+
|
| 242 |
+
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
| 243 |
+
w = self.x_embedder.proj.weight.data
|
| 244 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 245 |
+
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
| 246 |
+
|
| 247 |
+
# Initialize label embedding table:
|
| 248 |
+
nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
|
| 249 |
+
|
| 250 |
+
# Initialize timestep embedding MLP:
|
| 251 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
| 252 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
| 253 |
+
|
| 254 |
+
# Zero-out adaLN modulation layers in SiT blocks:
|
| 255 |
+
for block in self.blocks:
|
| 256 |
+
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
| 257 |
+
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
| 258 |
+
|
| 259 |
+
# Zero-out output layers:
|
| 260 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
| 261 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
| 262 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
| 263 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
| 264 |
+
nn.init.constant_(self.final_layer.linear_cls.weight, 0)
|
| 265 |
+
nn.init.constant_(self.final_layer.linear_cls.bias, 0)
|
| 266 |
+
|
| 267 |
+
def unpatchify(self, x, patch_size=None):
|
| 268 |
+
"""
|
| 269 |
+
x: (N, T, patch_size**2 * C)
|
| 270 |
+
imgs: (N, C, H, W)
|
| 271 |
+
"""
|
| 272 |
+
c = self.out_channels
|
| 273 |
+
p = self.x_embedder.patch_size[0] if patch_size is None else patch_size
|
| 274 |
+
h = w = int(x.shape[1] ** 0.5)
|
| 275 |
+
assert h * w == x.shape[1]
|
| 276 |
+
|
| 277 |
+
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
| 278 |
+
x = torch.einsum('nhwpqc->nchpwq', x)
|
| 279 |
+
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
| 280 |
+
return imgs
|
| 281 |
+
|
| 282 |
+
def forward(self, x, t, y, return_logvar=False, cls_token=None):
|
| 283 |
+
"""
|
| 284 |
+
Forward pass of SiT.
|
| 285 |
+
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
| 286 |
+
t: (N,) tensor of diffusion timesteps
|
| 287 |
+
y: (N,) tensor of class labels
|
| 288 |
+
"""
|
| 289 |
+
|
| 290 |
+
#cat with cls_token
|
| 291 |
+
x = self.x_embedder(x) # (N, T, D), where T = H * W / patch_size ** 2
|
| 292 |
+
if cls_token is not None:
|
| 293 |
+
cls_token = self.cls_projectors2(cls_token)
|
| 294 |
+
cls_token = self.wg_norm(cls_token)
|
| 295 |
+
cls_token = cls_token.unsqueeze(1) # [b, length, d]
|
| 296 |
+
x = torch.cat((cls_token, x), dim=1)
|
| 297 |
+
x = x + self.pos_embed
|
| 298 |
+
else:
|
| 299 |
+
exit()
|
| 300 |
+
N, T, D = x.shape
|
| 301 |
+
|
| 302 |
+
# timestep and class embedding
|
| 303 |
+
t_embed = self.t_embedder(t) # (N, D)
|
| 304 |
+
y = self.y_embedder(y, self.training) # (N, D)
|
| 305 |
+
c = t_embed + y
|
| 306 |
+
|
| 307 |
+
for i, block in enumerate(self.blocks):
|
| 308 |
+
x = block(x, c)
|
| 309 |
+
if (i + 1) == self.encoder_depth:
|
| 310 |
+
zs = [projector(x.reshape(-1, D)).reshape(N, T, -1) for projector in self.projectors]
|
| 311 |
+
|
| 312 |
+
x, cls_token = self.final_layer(x, c, cls=cls_token)
|
| 313 |
+
x = self.unpatchify(x)
|
| 314 |
+
|
| 315 |
+
return x, zs, cls_token
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
#################################################################################
|
| 319 |
+
# Sine/Cosine Positional Embedding Functions #
|
| 320 |
+
#################################################################################
|
| 321 |
+
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
| 322 |
+
|
| 323 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
| 324 |
+
"""
|
| 325 |
+
grid_size: int of the grid height and width
|
| 326 |
+
return:
|
| 327 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
| 328 |
+
"""
|
| 329 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
| 330 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
| 331 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
| 332 |
+
grid = np.stack(grid, axis=0)
|
| 333 |
+
|
| 334 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
| 335 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 336 |
+
if cls_token and extra_tokens > 0:
|
| 337 |
+
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
| 338 |
+
return pos_embed
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 342 |
+
assert embed_dim % 2 == 0
|
| 343 |
+
|
| 344 |
+
# use half of dimensions to encode grid_h
|
| 345 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
| 346 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
| 347 |
+
|
| 348 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
| 349 |
+
return emb
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 353 |
+
"""
|
| 354 |
+
embed_dim: output dimension for each position
|
| 355 |
+
pos: a list of positions to be encoded: size (M,)
|
| 356 |
+
out: (M, D)
|
| 357 |
+
"""
|
| 358 |
+
assert embed_dim % 2 == 0
|
| 359 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
| 360 |
+
omega /= embed_dim / 2.
|
| 361 |
+
omega = 1. / 10000**omega # (D/2,)
|
| 362 |
+
|
| 363 |
+
pos = pos.reshape(-1) # (M,)
|
| 364 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
| 365 |
+
|
| 366 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 367 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 368 |
+
|
| 369 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
| 370 |
+
return emb
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
#################################################################################
|
| 374 |
+
# SiT Configs #
|
| 375 |
+
#################################################################################
|
| 376 |
+
|
| 377 |
+
def SiT_XL_2(**kwargs):
|
| 378 |
+
return SiT(depth=28, hidden_size=1152, decoder_hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
|
| 379 |
+
|
| 380 |
+
def SiT_XL_4(**kwargs):
|
| 381 |
+
return SiT(depth=28, hidden_size=1152, decoder_hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
|
| 382 |
+
|
| 383 |
+
def SiT_XL_8(**kwargs):
|
| 384 |
+
return SiT(depth=28, hidden_size=1152, decoder_hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
|
| 385 |
+
|
| 386 |
+
def SiT_L_2(**kwargs):
|
| 387 |
+
return SiT(depth=24, hidden_size=1024, decoder_hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
|
| 388 |
+
|
| 389 |
+
def SiT_L_4(**kwargs):
|
| 390 |
+
return SiT(depth=24, hidden_size=1024, decoder_hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
|
| 391 |
+
|
| 392 |
+
def SiT_L_8(**kwargs):
|
| 393 |
+
return SiT(depth=24, hidden_size=1024, decoder_hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
|
| 394 |
+
|
| 395 |
+
def SiT_B_2(**kwargs):
|
| 396 |
+
return SiT(depth=12, hidden_size=768, decoder_hidden_size=768, patch_size=2, num_heads=12, **kwargs)
|
| 397 |
+
|
| 398 |
+
def SiT_B_4(**kwargs):
|
| 399 |
+
return SiT(depth=12, hidden_size=768, decoder_hidden_size=768, patch_size=4, num_heads=12, **kwargs)
|
| 400 |
+
|
| 401 |
+
def SiT_B_8(**kwargs):
|
| 402 |
+
return SiT(depth=12, hidden_size=768, decoder_hidden_size=768, patch_size=8, num_heads=12, **kwargs)
|
| 403 |
+
|
| 404 |
+
def SiT_S_2(**kwargs):
|
| 405 |
+
return SiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
|
| 406 |
+
|
| 407 |
+
def SiT_S_4(**kwargs):
|
| 408 |
+
return SiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
|
| 409 |
+
|
| 410 |
+
def SiT_S_8(**kwargs):
|
| 411 |
+
return SiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
SiT_models = {
|
| 415 |
+
'SiT-XL/2': SiT_XL_2, 'SiT-XL/4': SiT_XL_4, 'SiT-XL/8': SiT_XL_8,
|
| 416 |
+
'SiT-L/2': SiT_L_2, 'SiT-L/4': SiT_L_4, 'SiT-L/8': SiT_L_8,
|
| 417 |
+
'SiT-B/2': SiT_B_2, 'SiT-B/4': SiT_B_4, 'SiT-B/8': SiT_B_8,
|
| 418 |
+
'SiT-S/2': SiT_S_2, 'SiT-S/4': SiT_S_4, 'SiT-S/8': SiT_S_8,
|
| 419 |
+
}
|
| 420 |
+
|
back/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 Sihyun Yu
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
back/README.md
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="center">
|
| 2 |
+
<h1 align="center">Representation Entanglement for Generation: Training Diffusion Transformers Is Much Easier Than You Think (NeurIPS 2025 Oral)
|
| 3 |
+
</h1>
|
| 4 |
+
<p align="center">
|
| 5 |
+
<a href='https://github.com/Martinser' style='text-decoration: none' >Ge Wu</a><sup>1</sup> 
|
| 6 |
+
<a href='https://github.com/ShenZhang-Shin' style='text-decoration: none' >Shen Zhang</a><sup>3</sup> 
|
| 7 |
+
<a href='' style='text-decoration: none' >Ruijing Shi</a><sup>1</sup> 
|
| 8 |
+
<a href='https://shgao.site/' style='text-decoration: none' >Shanghua Gao</a><sup>4</sup> 
|
| 9 |
+
<a href='https://zhenyuanchenai.github.io/' style='text-decoration: none' >Zhenyuan Chen</a><sup>1</sup> 
|
| 10 |
+
<a href='https://scholar.google.com/citations?user=6Z66DAwAAAAJ&hl=en' style='text-decoration: none' >Lei Wang</a><sup>1</sup> 
|
| 11 |
+
<a href='https://www.zhihu.com/people/chen-zhao-wei-16-2' style='text-decoration: none' >Zhaowei Chen</a><sup>3</sup> 
|
| 12 |
+
<a href='https://gao-hongcheng.github.io/' style='text-decoration: none' >Hongcheng Gao</a><sup>5</sup> 
|
| 13 |
+
<a href='https://scholar.google.com/citations?view_op=list_works&hl=zh-CN&hl=zh-CN&user=0xP6bxcAAAAJ' style='text-decoration: none' >Yao Tang</a><sup>3</sup> 
|
| 14 |
+
<a href='https://scholar.google.com/citations?user=6CIDtZQAAAAJ&hl=en' style='text-decoration: none' >Jian Yang</a><sup>1</sup> 
|
| 15 |
+
<a href='https://mmcheng.net/cmm/' style='text-decoration: none' >Ming-Ming Cheng</a><sup>1,2</sup> 
|
| 16 |
+
<a href='https://implus.github.io/' style='text-decoration: none' >Xiang Li</a><sup>1,2*</sup> 
|
| 17 |
+
<p align="center">
|
| 18 |
+
$^{1}$ VCIP, CS, Nankai University, $^{2}$ NKIARI, Shenzhen Futian, $^{3}$ JIIOV Technology,
|
| 19 |
+
$^{4}$ Harvard University, $^{5}$ University of Chinese Academy of Sciences
|
| 20 |
+
<p align='center'>
|
| 21 |
+
<div align="center">
|
| 22 |
+
<a href='https://arxiv.org/abs/2507.01467v2'><img src='https://img.shields.io/badge/arXiv-2507.01467v2-brown.svg?logo=arxiv&logoColor=white'></a>
|
| 23 |
+
<a href='https://huggingface.co/Martinser/REG/tree/main'><img src='https://img.shields.io/badge/🤗-Model-blue.svg'></a>
|
| 24 |
+
<a href='https://zhuanlan.zhihu.com/p/1952346823168595518'><img src='https://img.shields.io/badge/Zhihu-chinese_article-blue.svg?logo=zhihu&logoColor=white'></a>
|
| 25 |
+
</div>
|
| 26 |
+
<p align='center'>
|
| 27 |
+
</p>
|
| 28 |
+
</p>
|
| 29 |
+
</p>
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
## 🚩 Overview
|
| 33 |
+
|
| 34 |
+

|
| 35 |
+
|
| 36 |
+
REPA and its variants effectively mitigate training challenges in diffusion models by incorporating external visual representations from pretrained models, through alignment between the noisy hidden projections of denoising networks and foundational clean image representations.
|
| 37 |
+
We argue that the external alignment, which is absent during the entire denoising inference process, falls short of fully harnessing the potential of discriminative representations.
|
| 38 |
+
|
| 39 |
+
In this work, we propose a straightforward method called Representation Entanglement for Generation (REG), which entangles low-level image latents with a single high-level class token from pretrained foundation models for denoising.
|
| 40 |
+
REG acquires the capability to produce coherent image-class pairs directly from pure noise,
|
| 41 |
+
substantially improving both generation quality and training efficiency.
|
| 42 |
+
This is accomplished with negligible additional inference overhead, **requiring only one single additional token for denoising (<0.5\% increase in FLOPs and latency).**
|
| 43 |
+
The inference process concurrently reconstructs both image latents and their corresponding global semantics, where the acquired semantic knowledge actively guides and enhances the image generation process.
|
| 44 |
+
|
| 45 |
+
On ImageNet $256{\times}256$, SiT-XL/2 + REG demonstrates remarkable convergence acceleration, **achieving $\textbf{63}\times$ and $\textbf{23}\times$ faster training than SiT-XL/2 and SiT-XL/2 + REPA, respectively.**
|
| 46 |
+
More impressively, SiT-L/2 + REG trained for merely 400K iterations outperforms SiT-XL/2 + REPA trained for 4M iterations ($\textbf{10}\times$ longer).
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
## 📰 News
|
| 51 |
+
|
| 52 |
+
- **[2025.08.05]** We have released the pre-trained weights of REG + SiT-XL/2 in 4M (800 epochs).
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
## 📝 Results
|
| 56 |
+
|
| 57 |
+
- Performance on ImageNet $256{\times}256$ with FID=1.36 by introducing a single class token.
|
| 58 |
+
- $\textbf{63}\times$ and $\textbf{23}\times$ faster training than SiT-XL/2 and SiT-XL/2 + REPA.
|
| 59 |
+
|
| 60 |
+
<div align="center">
|
| 61 |
+
<img src="fig/img.png" alt="Results">
|
| 62 |
+
</div>
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
## 📋 Plan
|
| 66 |
+
- More training steps on ImageNet 256&512 and T2I.
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
## 👊 Usage
|
| 70 |
+
|
| 71 |
+
### 1. Environment setup
|
| 72 |
+
|
| 73 |
+
```bash
|
| 74 |
+
conda create -n reg python=3.10.16 -y
|
| 75 |
+
conda activate reg
|
| 76 |
+
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1
|
| 77 |
+
pip install -r requirements.txt
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
### 2. Dataset
|
| 81 |
+
|
| 82 |
+
#### Dataset download
|
| 83 |
+
|
| 84 |
+
Currently, we provide experiments for ImageNet. You can place the data that you want and can specifiy it via `--data-dir` arguments in training scripts.
|
| 85 |
+
|
| 86 |
+
#### Preprocessing data
|
| 87 |
+
Please refer to the preprocessing guide. And you can directly download our processed data, ImageNet data [link](https://huggingface.co/WindATree/ImageNet-256-VAE/tree/main), and ImageNet data after VAE encoder [link]( https://huggingface.co/WindATree/vae-sd/tree/main)
|
| 88 |
+
|
| 89 |
+
### 3. Training
|
| 90 |
+
Run train.sh
|
| 91 |
+
```bash
|
| 92 |
+
bash train.sh
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
train.sh contains the following content.
|
| 96 |
+
```bash
|
| 97 |
+
accelerate launch --multi_gpu --num_processes $NUM_GPUS train.py \
|
| 98 |
+
--report-to="wandb" \
|
| 99 |
+
--allow-tf32 \
|
| 100 |
+
--mixed-precision="fp16" \
|
| 101 |
+
--seed=0 \
|
| 102 |
+
--path-type="linear" \
|
| 103 |
+
--prediction="v" \
|
| 104 |
+
--weighting="uniform" \
|
| 105 |
+
--model="SiT-B/2" \
|
| 106 |
+
--enc-type="dinov2-vit-b" \
|
| 107 |
+
--proj-coeff=0.5 \
|
| 108 |
+
--encoder-depth=4 \ #SiT-L/XL use 8, SiT-B use 4
|
| 109 |
+
--output-dir="your_path" \
|
| 110 |
+
--exp-name="linear-dinov2-b-enc4" \
|
| 111 |
+
--batch-size=256 \
|
| 112 |
+
--data-dir="data_path/imagenet_vae" \
|
| 113 |
+
--cls=0.03
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
Then this script will automatically create the folder in `exps` to save logs and checkpoints. You can adjust the following options:
|
| 117 |
+
|
| 118 |
+
- `--models`: `[SiT-B/2, SiT-L/2, SiT-XL/2]`
|
| 119 |
+
- `--enc-type`: `[dinov2-vit-b, clip-vit-L]`
|
| 120 |
+
- `--proj-coeff`: Any values larger than 0
|
| 121 |
+
- `--encoder-depth`: Any values between 1 to the depth of the model
|
| 122 |
+
- `--output-dir`: Any directory that you want to save checkpoints and logs
|
| 123 |
+
- `--exp-name`: Any string name (the folder will be created under `output-dir`)
|
| 124 |
+
- `--cls`: Weight coefficients of REG loss
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
### 4. Generate images and evaluation
|
| 128 |
+
You can generate images and get the final results through the following script.
|
| 129 |
+
The weight of REG can be found in this [link](https://pan.baidu.com/s/1QX2p3ybh1KfNU7wsp5McWw?pwd=khpp) or [HF](https://huggingface.co/Martinser/REG/tree/main).
|
| 130 |
+
|
| 131 |
+
```bash
|
| 132 |
+
bash eval.sh
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
## Citation
|
| 137 |
+
If you find our work, this repository, or pretrained models useful, please consider giving a star and citation.
|
| 138 |
+
```
|
| 139 |
+
@article{wu2025representation,
|
| 140 |
+
title={Representation Entanglement for Generation: Training Diffusion Transformers Is Much Easier Than You Think},
|
| 141 |
+
author={Wu, Ge and Zhang, Shen and Shi, Ruijing and Gao, Shanghua and Chen, Zhenyuan and Wang, Lei and Chen, Zhaowei and Gao, Hongcheng and Tang, Yao and Yang, Jian and others},
|
| 142 |
+
journal={arXiv preprint arXiv:2507.01467},
|
| 143 |
+
year={2025}
|
| 144 |
+
}
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
## Contact
|
| 148 |
+
If you have any questions, please create an issue on this repository, contact at gewu.nku@gmail.com or wechat(wg1158848).
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
## Acknowledgements
|
| 152 |
+
|
| 153 |
+
Our code is based on [REPA](https://github.com/sihyun-yu/REPA), along with [SiT](https://github.com/willisma/SiT), [DINOv2](https://github.com/facebookresearch/dinov2), [ADM](https://github.com/openai/guided-diffusion) and [U-ViT](https://github.com/baofff/U-ViT) repositories. We thank the authors for releasing their code. If you use our model and code, please consider citing these works as well.
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
|
back/eval.sh
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
random_number=$((RANDOM % 100 + 1200))
|
| 3 |
+
NUM_GPUS=8
|
| 4 |
+
STEP="4000000"
|
| 5 |
+
SAVE_PATH="your_path/reg_xlarge_dinov2_base_align_8_cls/linear-dinov2-b-enc8"
|
| 6 |
+
VAE_PATH="your_vae_path/"
|
| 7 |
+
NUM_STEP=250
|
| 8 |
+
MODEL_SIZE='XL'
|
| 9 |
+
CFG_SCALE=2.3
|
| 10 |
+
CLS_CFG_SCALE=2.3
|
| 11 |
+
GH=0.85
|
| 12 |
+
|
| 13 |
+
export NCCL_P2P_DISABLE=1
|
| 14 |
+
|
| 15 |
+
python -m torch.distributed.launch --master_port=$random_number --nproc_per_node=$NUM_GPUS generate.py \
|
| 16 |
+
--model SiT-XL/2 \
|
| 17 |
+
--num-fid-samples 50000 \
|
| 18 |
+
--ckpt ${SAVE_PATH}/checkpoints/${STEP}.pt \
|
| 19 |
+
--path-type=linear \
|
| 20 |
+
--encoder-depth=8 \
|
| 21 |
+
--projector-embed-dims=768 \
|
| 22 |
+
--per-proc-batch-size=64 \
|
| 23 |
+
--mode=sde \
|
| 24 |
+
--num-steps=${NUM_STEP} \
|
| 25 |
+
--cfg-scale=${CFG_SCALE} \
|
| 26 |
+
--cls-cfg-scale=${CLS_CFG_SCALE} \
|
| 27 |
+
--guidance-high=${GH} \
|
| 28 |
+
--sample-dir ${SAVE_PATH}/checkpoints \
|
| 29 |
+
--cls=768
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
python ./evaluations/evaluator.py \
|
| 33 |
+
--ref_batch your_path/VIRTUAL_imagenet256_labeled.npz \
|
| 34 |
+
--sample_batch ${SAVE_PATH}/checkpoints/SiT-${MODEL_SIZE}-2-${STEP}-size-256-vae-ema-cfg-${CFG_SCALE}-seed-0-sde-${GH}-${CLS_CFG_SCALE}.npz \
|
| 35 |
+
--save_path ${SAVE_PATH}/checkpoints \
|
| 36 |
+
--cfg_cond 1 \
|
| 37 |
+
--step ${STEP} \
|
| 38 |
+
--num_steps ${NUM_STEP} \
|
| 39 |
+
--cfg ${CFG_SCALE} \
|
| 40 |
+
--cls_cfg ${CLS_CFG_SCALE} \
|
| 41 |
+
--gh ${GH}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
|
back/loss.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
try:
|
| 6 |
+
from scipy.optimize import linear_sum_assignment
|
| 7 |
+
except ImportError:
|
| 8 |
+
linear_sum_assignment = None
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def ot_pair_noise_to_cls(noise_cls, cls_gt):
|
| 12 |
+
"""
|
| 13 |
+
Minibatch OT(与 conditional-flow-matching / torchcfm 中 sample_plan_with_scipy 一致):
|
| 14 |
+
在 batch 内用平方欧氏代价重排 noise,使 noise_ot[i] 与 cls_gt[i] 构成近似最优传输配对。
|
| 15 |
+
noise_cls, cls_gt: (N, D) 或任意可在最后一维展平为 D 的形状。
|
| 16 |
+
"""
|
| 17 |
+
n = noise_cls.shape[0]
|
| 18 |
+
if n <= 1:
|
| 19 |
+
return noise_cls, cls_gt
|
| 20 |
+
if linear_sum_assignment is None:
|
| 21 |
+
return noise_cls, cls_gt
|
| 22 |
+
x0 = noise_cls.detach().float().reshape(n, -1)
|
| 23 |
+
x1 = cls_gt.detach().float().reshape(n, -1)
|
| 24 |
+
M = torch.cdist(x0, x1) ** 2
|
| 25 |
+
_, j = linear_sum_assignment(M.cpu().numpy())
|
| 26 |
+
j = torch.as_tensor(j, device=noise_cls.device, dtype=torch.long)
|
| 27 |
+
return noise_cls[j], cls_gt
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def mean_flat(x):
|
| 31 |
+
"""
|
| 32 |
+
Take the mean over all non-batch dimensions.
|
| 33 |
+
"""
|
| 34 |
+
return torch.mean(x, dim=list(range(1, len(x.size()))))
|
| 35 |
+
|
| 36 |
+
def sum_flat(x):
|
| 37 |
+
"""
|
| 38 |
+
Take the mean over all non-batch dimensions.
|
| 39 |
+
"""
|
| 40 |
+
return torch.sum(x, dim=list(range(1, len(x.size()))))
|
| 41 |
+
|
| 42 |
+
class SILoss:
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
prediction='v',
|
| 46 |
+
path_type="linear",
|
| 47 |
+
weighting="uniform",
|
| 48 |
+
encoders=[],
|
| 49 |
+
accelerator=None,
|
| 50 |
+
latents_scale=None,
|
| 51 |
+
latents_bias=None,
|
| 52 |
+
t_c=0.5,
|
| 53 |
+
ot_cls=True,
|
| 54 |
+
):
|
| 55 |
+
self.prediction = prediction
|
| 56 |
+
self.weighting = weighting
|
| 57 |
+
self.path_type = path_type
|
| 58 |
+
self.encoders = encoders
|
| 59 |
+
self.accelerator = accelerator
|
| 60 |
+
self.latents_scale = latents_scale
|
| 61 |
+
self.latents_bias = latents_bias
|
| 62 |
+
# t 与 train.py / JsFlow 一致:t=0 为干净 latent,t=1 为纯噪声。
|
| 63 |
+
# t ∈ (t_c, 1]:语义 cls 沿 OT 配对后的路径从噪声演化为 cls_gt(生成语义通道);
|
| 64 |
+
# t ∈ [0, t_c]:cls 恒为真实 cls_gt,目标速度为 0(通道不再插值)。
|
| 65 |
+
tc = float(t_c)
|
| 66 |
+
self.t_c = min(max(tc, 1e-4), 1.0 - 1e-4)
|
| 67 |
+
self.ot_cls = bool(ot_cls)
|
| 68 |
+
|
| 69 |
+
def interpolant(self, t):
|
| 70 |
+
if self.path_type == "linear":
|
| 71 |
+
alpha_t = 1 - t
|
| 72 |
+
sigma_t = t
|
| 73 |
+
d_alpha_t = -1
|
| 74 |
+
d_sigma_t = 1
|
| 75 |
+
elif self.path_type == "cosine":
|
| 76 |
+
alpha_t = torch.cos(t * np.pi / 2)
|
| 77 |
+
sigma_t = torch.sin(t * np.pi / 2)
|
| 78 |
+
d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2)
|
| 79 |
+
d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2)
|
| 80 |
+
else:
|
| 81 |
+
raise NotImplementedError()
|
| 82 |
+
|
| 83 |
+
return alpha_t, sigma_t, d_alpha_t, d_sigma_t
|
| 84 |
+
|
| 85 |
+
def __call__(self, model, images, model_kwargs=None, zs=None, cls_token=None,
|
| 86 |
+
time_input=None, noises=None,):
|
| 87 |
+
if model_kwargs == None:
|
| 88 |
+
model_kwargs = {}
|
| 89 |
+
# sample timesteps
|
| 90 |
+
if time_input is None:
|
| 91 |
+
if self.weighting == "uniform":
|
| 92 |
+
time_input = torch.rand((images.shape[0], 1, 1, 1))
|
| 93 |
+
elif self.weighting == "lognormal":
|
| 94 |
+
# sample timestep according to log-normal distribution of sigmas following EDM
|
| 95 |
+
rnd_normal = torch.randn((images.shape[0], 1 ,1, 1))
|
| 96 |
+
sigma = rnd_normal.exp()
|
| 97 |
+
if self.path_type == "linear":
|
| 98 |
+
time_input = sigma / (1 + sigma)
|
| 99 |
+
elif self.path_type == "cosine":
|
| 100 |
+
time_input = 2 / np.pi * torch.atan(sigma)
|
| 101 |
+
|
| 102 |
+
time_input = time_input.to(device=images.device, dtype=torch.float32)
|
| 103 |
+
cls_token = cls_token.to(device=images.device, dtype=torch.float32)
|
| 104 |
+
|
| 105 |
+
if noises is None:
|
| 106 |
+
noises = torch.randn_like(images)
|
| 107 |
+
noises_cls = torch.randn_like(cls_token)
|
| 108 |
+
else:
|
| 109 |
+
if isinstance(noises, (tuple, list)) and len(noises) == 2:
|
| 110 |
+
noises, noises_cls = noises
|
| 111 |
+
else:
|
| 112 |
+
noises_cls = torch.randn_like(cls_token)
|
| 113 |
+
|
| 114 |
+
alpha_t, sigma_t, d_alpha_t, d_sigma_t = self.interpolant(time_input)
|
| 115 |
+
|
| 116 |
+
model_input = alpha_t * images + sigma_t * noises
|
| 117 |
+
if self.prediction == 'v':
|
| 118 |
+
model_target = d_alpha_t * images + d_sigma_t * noises
|
| 119 |
+
else:
|
| 120 |
+
raise NotImplementedError()
|
| 121 |
+
|
| 122 |
+
N = images.shape[0]
|
| 123 |
+
t_flat = time_input.view(-1).float()
|
| 124 |
+
high_noise_mask = (t_flat > self.t_c).float().view(N, *([1] * (cls_token.dim() - 1)))
|
| 125 |
+
low_noise_mask = 1.0 - high_noise_mask
|
| 126 |
+
|
| 127 |
+
noise_cls_raw = noises_cls
|
| 128 |
+
if self.ot_cls:
|
| 129 |
+
noise_cls_paired, cls_gt_paired = ot_pair_noise_to_cls(noise_cls_raw, cls_token)
|
| 130 |
+
else:
|
| 131 |
+
noise_cls_paired, cls_gt_paired = noise_cls_raw, cls_token
|
| 132 |
+
|
| 133 |
+
tau_shape = (N,) + (1,) * max(0, cls_token.dim() - 1)
|
| 134 |
+
tau = (time_input.reshape(tau_shape) - self.t_c) / (1.0 - self.t_c + 1e-8)
|
| 135 |
+
tau = torch.clamp(tau, 0.0, 1.0)
|
| 136 |
+
alpha_sem = 1.0 - tau
|
| 137 |
+
sigma_sem = tau
|
| 138 |
+
|
| 139 |
+
cls_t_high = alpha_sem * cls_gt_paired + sigma_sem * noise_cls_paired
|
| 140 |
+
cls_t = high_noise_mask * cls_t_high + low_noise_mask * cls_token
|
| 141 |
+
cls_t = torch.nan_to_num(cls_t, nan=0.0, posinf=1e4, neginf=-1e4)
|
| 142 |
+
cls_t = torch.clamp(cls_t, -1e4, 1e4)
|
| 143 |
+
|
| 144 |
+
cls_for_model = cls_t * high_noise_mask + cls_t.detach() * low_noise_mask
|
| 145 |
+
|
| 146 |
+
inv_scale = 1.0 / (1.0 - self.t_c + 1e-8)
|
| 147 |
+
v_cls_high = (noise_cls_paired - cls_gt_paired) * inv_scale
|
| 148 |
+
v_cls_target = high_noise_mask * v_cls_high
|
| 149 |
+
|
| 150 |
+
model_output, zs_tilde, cls_output = model(
|
| 151 |
+
model_input, time_input.flatten(), **model_kwargs, cls_token=cls_for_model
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
#denoising_loss
|
| 155 |
+
denoising_loss = mean_flat((model_output - model_target) ** 2)
|
| 156 |
+
denoising_loss_cls = mean_flat((cls_output - v_cls_target) ** 2)
|
| 157 |
+
|
| 158 |
+
# projection loss
|
| 159 |
+
proj_loss = 0.
|
| 160 |
+
bsz = zs[0].shape[0]
|
| 161 |
+
for i, (z, z_tilde) in enumerate(zip(zs, zs_tilde)):
|
| 162 |
+
for j, (z_j, z_tilde_j) in enumerate(zip(z, z_tilde)):
|
| 163 |
+
z_tilde_j = torch.nn.functional.normalize(z_tilde_j, dim=-1)
|
| 164 |
+
z_j = torch.nn.functional.normalize(z_j, dim=-1)
|
| 165 |
+
proj_loss += mean_flat(-(z_j * z_tilde_j).sum(dim=-1))
|
| 166 |
+
proj_loss /= (len(zs) * bsz)
|
| 167 |
+
|
| 168 |
+
return denoising_loss, proj_loss, time_input, noises, denoising_loss_cls
|
back/requirements.txt
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
- pip:
|
| 2 |
+
absl-py==2.2.2
|
| 3 |
+
accelerate==1.2.1
|
| 4 |
+
aiohappyeyeballs==2.6.1
|
| 5 |
+
aiohttp==3.11.16
|
| 6 |
+
aiosignal==1.3.2
|
| 7 |
+
astunparse==1.6.3
|
| 8 |
+
async-timeout==5.0.1
|
| 9 |
+
attrs==25.3.0
|
| 10 |
+
certifi==2022.12.7
|
| 11 |
+
charset-normalizer==2.1.1
|
| 12 |
+
click==8.1.8
|
| 13 |
+
datasets==2.20.0
|
| 14 |
+
diffusers==0.32.1
|
| 15 |
+
dill==0.3.8
|
| 16 |
+
docker-pycreds==0.4.0
|
| 17 |
+
einops==0.8.1
|
| 18 |
+
filelock==3.13.1
|
| 19 |
+
flatbuffers==25.2.10
|
| 20 |
+
frozenlist==1.5.0
|
| 21 |
+
fsspec==2024.5.0
|
| 22 |
+
ftfy==6.3.1
|
| 23 |
+
gast==0.6.0
|
| 24 |
+
gitdb==4.0.12
|
| 25 |
+
gitpython==3.1.44
|
| 26 |
+
google-pasta==0.2.0
|
| 27 |
+
grpcio==1.71.0
|
| 28 |
+
h5py==3.13.0
|
| 29 |
+
huggingface-hub==0.27.1
|
| 30 |
+
idna==3.4
|
| 31 |
+
importlib-metadata==8.6.1
|
| 32 |
+
jinja2==3.1.4
|
| 33 |
+
joblib==1.4.2
|
| 34 |
+
keras==3.9.2
|
| 35 |
+
libclang==18.1.1
|
| 36 |
+
markdown==3.8
|
| 37 |
+
markdown-it-py==3.0.0
|
| 38 |
+
markupsafe==2.1.5
|
| 39 |
+
mdurl==0.1.2
|
| 40 |
+
ml-dtypes==0.3.2
|
| 41 |
+
mpmath==1.3.0
|
| 42 |
+
multidict==6.4.3
|
| 43 |
+
multiprocess==0.70.16
|
| 44 |
+
namex==0.0.8
|
| 45 |
+
networkx==3.3
|
| 46 |
+
numpy==1.26.4
|
| 47 |
+
opt-einsum==3.4.0
|
| 48 |
+
optree==0.15.0
|
| 49 |
+
packaging==24.2
|
| 50 |
+
pandas==2.2.3
|
| 51 |
+
pillow==11.0.0
|
| 52 |
+
platformdirs==4.3.7
|
| 53 |
+
propcache==0.3.1
|
| 54 |
+
protobuf==4.25.6
|
| 55 |
+
psutil==7.0.0
|
| 56 |
+
pyarrow==19.0.1
|
| 57 |
+
pyarrow-hotfix==0.6
|
| 58 |
+
pygments==2.19.1
|
| 59 |
+
python-dateutil==2.9.0.post0
|
| 60 |
+
pytz==2025.2
|
| 61 |
+
pyyaml==6.0.2
|
| 62 |
+
regex==2024.11.6
|
| 63 |
+
requests==2.32.3
|
| 64 |
+
rich==14.0.0
|
| 65 |
+
safetensors==0.5.3
|
| 66 |
+
scikit-learn==1.5.1
|
| 67 |
+
scipy==1.15.2
|
| 68 |
+
sentry-sdk==2.26.1
|
| 69 |
+
setproctitle==1.3.5
|
| 70 |
+
six==1.17.0
|
| 71 |
+
smmap==5.0.2
|
| 72 |
+
sympy==1.13.1
|
| 73 |
+
tensorboard==2.16.1
|
| 74 |
+
tensorboard-data-server==0.7.2
|
| 75 |
+
tensorflow==2.16.1
|
| 76 |
+
tensorflow-io-gcs-filesystem==0.37.1
|
| 77 |
+
termcolor==3.0.1
|
| 78 |
+
tf-keras==2.16.0
|
| 79 |
+
threadpoolctl==3.6.0
|
| 80 |
+
timm==1.0.12
|
| 81 |
+
tokenizers==0.21.0
|
| 82 |
+
tqdm==4.67.1
|
| 83 |
+
transformers==4.47.0
|
| 84 |
+
triton==2.1.0
|
| 85 |
+
typing-extensions==4.12.2
|
| 86 |
+
tzdata==2025.2
|
| 87 |
+
urllib3==1.26.13
|
| 88 |
+
wandb==0.17.6
|
| 89 |
+
wcwidth==0.2.13
|
| 90 |
+
werkzeug==3.1.3
|
| 91 |
+
wrapt==1.17.2
|
| 92 |
+
xformer==1.0.1
|
| 93 |
+
xformers==0.0.23
|
| 94 |
+
xxhash==3.5.0
|
| 95 |
+
yarl==1.20.0
|
| 96 |
+
zipp==3.21.0
|
| 97 |
+
|
back/sample_from_checkpoint.py
ADDED
|
@@ -0,0 +1,596 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
从 REG/train.py 保存的检查点加载权重,在指定目录生成若干 PNG。
|
| 4 |
+
|
| 5 |
+
示例:
|
| 6 |
+
python sample_from_checkpoint.py \\
|
| 7 |
+
--ckpt exps/jsflow-experiment/checkpoints/0050000.pt \\
|
| 8 |
+
--out-dir ./samples_gen \\
|
| 9 |
+
--num-images 64 \\
|
| 10 |
+
--batch-size 8
|
| 11 |
+
|
| 12 |
+
# 按训练 t_c 分段分配步数(t=1→t_c 与 t_c→0;--t-c 可省略若检查点含 t_c):
|
| 13 |
+
python sample_from_checkpoint.py ... \\
|
| 14 |
+
--steps-before-tc 150 --steps-after-tc 100 --t-c 0.5
|
| 15 |
+
|
| 16 |
+
# 同一批初始噪声连跑两种 t_c 后段步数(输出到 out-dir 下子目录):
|
| 17 |
+
python sample_from_checkpoint.py ... \\
|
| 18 |
+
--steps-before-tc 150 --steps-after-tc 5 --dual-compare-after
|
| 19 |
+
# 分段时会在 at_tc/(或 at_tc/after_input、at_tc/after_equal_before)额外保存 t≈t_c 的解码图。
|
| 20 |
+
|
| 21 |
+
检查点需包含 train.py 写入的键:ema(或 model)、args(推荐,用于自动还原结构)。
|
| 22 |
+
若缺少 args,需通过命令行显式传入 --model、--resolution、--enc-type 等。
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import argparse
|
| 28 |
+
import os
|
| 29 |
+
import sys
|
| 30 |
+
import types
|
| 31 |
+
import numpy as np
|
| 32 |
+
import torch
|
| 33 |
+
from diffusers.models import AutoencoderKL
|
| 34 |
+
from PIL import Image
|
| 35 |
+
from tqdm import tqdm
|
| 36 |
+
|
| 37 |
+
from models.sit import SiT_models
|
| 38 |
+
from samplers import (
|
| 39 |
+
euler_maruyama_image_noise_before_tc_sampler,
|
| 40 |
+
euler_maruyama_image_noise_sampler,
|
| 41 |
+
euler_maruyama_sampler,
|
| 42 |
+
euler_ode_sampler,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def semantic_dim_from_enc_type(enc_type):
|
| 47 |
+
"""与 train.py 一致:按 enc_type 推断语义/class token 维度。"""
|
| 48 |
+
if enc_type is None:
|
| 49 |
+
return 768
|
| 50 |
+
s = str(enc_type).lower()
|
| 51 |
+
if "vit-g" in s or "vitg" in s:
|
| 52 |
+
return 1536
|
| 53 |
+
if "vit-l" in s or "vitl" in s:
|
| 54 |
+
return 1024
|
| 55 |
+
if "vit-s" in s or "vits" in s:
|
| 56 |
+
return 384
|
| 57 |
+
return 768
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def load_train_args_from_ckpt(ckpt: dict) -> argparse.Namespace | None:
|
| 61 |
+
a = ckpt.get("args")
|
| 62 |
+
if a is None:
|
| 63 |
+
return None
|
| 64 |
+
if isinstance(a, argparse.Namespace):
|
| 65 |
+
return a
|
| 66 |
+
if isinstance(a, dict):
|
| 67 |
+
return argparse.Namespace(**a)
|
| 68 |
+
if isinstance(a, types.SimpleNamespace):
|
| 69 |
+
return argparse.Namespace(**vars(a))
|
| 70 |
+
return None
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def load_vae(device: torch.device):
|
| 74 |
+
"""与 train.py 相同策略:优先本地 diffusers 缓存中的 sd-vae-ft-mse。"""
|
| 75 |
+
try:
|
| 76 |
+
from preprocessing import dnnlib
|
| 77 |
+
|
| 78 |
+
cache_dir = dnnlib.make_cache_dir_path("diffusers")
|
| 79 |
+
os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
|
| 80 |
+
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
|
| 81 |
+
os.environ["HF_HOME"] = cache_dir
|
| 82 |
+
try:
|
| 83 |
+
vae = AutoencoderKL.from_pretrained(
|
| 84 |
+
"stabilityai/sd-vae-ft-mse",
|
| 85 |
+
cache_dir=cache_dir,
|
| 86 |
+
local_files_only=True,
|
| 87 |
+
).to(device)
|
| 88 |
+
vae.eval()
|
| 89 |
+
print(f"Loaded VAE from local cache: {cache_dir}")
|
| 90 |
+
return vae
|
| 91 |
+
except Exception:
|
| 92 |
+
pass
|
| 93 |
+
candidate_dir = None
|
| 94 |
+
for root_dir in [
|
| 95 |
+
cache_dir,
|
| 96 |
+
os.path.join(os.path.expanduser("~"), ".cache", "dnnlib", "diffusers"),
|
| 97 |
+
os.path.join(os.path.expanduser("~"), ".cache", "diffusers"),
|
| 98 |
+
os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"),
|
| 99 |
+
]:
|
| 100 |
+
if not os.path.isdir(root_dir):
|
| 101 |
+
continue
|
| 102 |
+
for root, _, files in os.walk(root_dir):
|
| 103 |
+
if "config.json" in files and "sd-vae-ft-mse" in root.replace("\\", "/"):
|
| 104 |
+
candidate_dir = root
|
| 105 |
+
break
|
| 106 |
+
if candidate_dir is not None:
|
| 107 |
+
break
|
| 108 |
+
if candidate_dir is not None:
|
| 109 |
+
vae = AutoencoderKL.from_pretrained(candidate_dir, local_files_only=True).to(device)
|
| 110 |
+
vae.eval()
|
| 111 |
+
print(f"Loaded VAE from {candidate_dir}")
|
| 112 |
+
return vae
|
| 113 |
+
except Exception as e:
|
| 114 |
+
print(f"VAE local cache search failed: {e}", file=sys.stderr)
|
| 115 |
+
try:
|
| 116 |
+
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
|
| 117 |
+
vae.eval()
|
| 118 |
+
print("Loaded VAE from Hub: stabilityai/sd-vae-ft-mse")
|
| 119 |
+
return vae
|
| 120 |
+
except Exception as e:
|
| 121 |
+
raise RuntimeError(
|
| 122 |
+
"无法加载 VAE stabilityai/sd-vae-ft-mse,请确认已下载或网络可用。"
|
| 123 |
+
) from e
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def build_model_from_train_args(ta: argparse.Namespace, device: torch.device):
|
| 127 |
+
res = int(getattr(ta, "resolution", 256))
|
| 128 |
+
latent_size = res // 8
|
| 129 |
+
enc_type = getattr(ta, "enc_type", "dinov2-vit-b")
|
| 130 |
+
z_dims = [semantic_dim_from_enc_type(enc_type)]
|
| 131 |
+
block_kwargs = {
|
| 132 |
+
"fused_attn": getattr(ta, "fused_attn", True),
|
| 133 |
+
"qk_norm": getattr(ta, "qk_norm", False),
|
| 134 |
+
}
|
| 135 |
+
cfg_prob = float(getattr(ta, "cfg_prob", 0.1))
|
| 136 |
+
if ta.model not in SiT_models:
|
| 137 |
+
raise ValueError(f"未知 model={ta.model!r},可选:{list(SiT_models.keys())}")
|
| 138 |
+
model = SiT_models[ta.model](
|
| 139 |
+
input_size=latent_size,
|
| 140 |
+
num_classes=int(getattr(ta, "num_classes", 1000)),
|
| 141 |
+
use_cfg=(cfg_prob > 0),
|
| 142 |
+
z_dims=z_dims,
|
| 143 |
+
encoder_depth=int(getattr(ta, "encoder_depth", 8)),
|
| 144 |
+
**block_kwargs,
|
| 145 |
+
).to(device)
|
| 146 |
+
return model, z_dims[0]
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def resolve_tc_schedule(cli, ta):
|
| 150 |
+
"""
|
| 151 |
+
若同时给出 --steps-before-tc 与 --steps-after-tc:在 t_c 处分段(--t-c 缺省则用检查点 args.t_c)。
|
| 152 |
+
否则使用均匀 --num-steps(与旧版一致)。
|
| 153 |
+
"""
|
| 154 |
+
sb = cli.steps_before_tc
|
| 155 |
+
sa = cli.steps_after_tc
|
| 156 |
+
tc = cli.t_c
|
| 157 |
+
if sb is None and sa is None:
|
| 158 |
+
return None, None, None
|
| 159 |
+
if sb is None or sa is None:
|
| 160 |
+
print(
|
| 161 |
+
"使用分段步数时必须同时指定 --steps-before-tc 与 --steps-after-tc。",
|
| 162 |
+
file=sys.stderr,
|
| 163 |
+
)
|
| 164 |
+
sys.exit(1)
|
| 165 |
+
if tc is None:
|
| 166 |
+
tc = getattr(ta, "t_c", None) if ta is not None else None
|
| 167 |
+
if tc is None:
|
| 168 |
+
print(
|
| 169 |
+
"分段采样需要 --t-c,或检查点 args 中含 t_c。",
|
| 170 |
+
file=sys.stderr,
|
| 171 |
+
)
|
| 172 |
+
sys.exit(1)
|
| 173 |
+
return float(tc), int(sb), int(sa)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def parse_cli():
|
| 177 |
+
p = argparse.ArgumentParser(description="REG 检查点采样出图(可选 ODE/EM/EM-图像噪声)")
|
| 178 |
+
p.add_argument("--ckpt", type=str, required=True, help="train.py 保存的 .pt 路径")
|
| 179 |
+
p.add_argument("--out-dir", type=str, required=True, help="输出 PNG 目录(会创建)")
|
| 180 |
+
p.add_argument("--num-images", type=int, required=True, help="生成图片总数")
|
| 181 |
+
p.add_argument("--batch-size", type=int, default=16)
|
| 182 |
+
p.add_argument("--seed", type=int, default=0)
|
| 183 |
+
p.add_argument(
|
| 184 |
+
"--weights",
|
| 185 |
+
type=str,
|
| 186 |
+
choices=("ema", "model"),
|
| 187 |
+
default="ema",
|
| 188 |
+
help="使用检查点中的 ema 或 model 权重",
|
| 189 |
+
)
|
| 190 |
+
p.add_argument("--device", type=str, default="cuda", help="如 cuda 或 cuda:0")
|
| 191 |
+
p.add_argument(
|
| 192 |
+
"--num-steps",
|
| 193 |
+
type=int,
|
| 194 |
+
default=50,
|
| 195 |
+
help="均匀时间网格时的欧拉步数(未使用 --steps-before-tc/--steps-after-tc 时生效)",
|
| 196 |
+
)
|
| 197 |
+
p.add_argument(
|
| 198 |
+
"--t-c",
|
| 199 |
+
type=float,
|
| 200 |
+
default=None,
|
| 201 |
+
help="分段时刻:t∈(t_c,1] 与 t∈[0,t_c] 两段;缺省可用检查点 args.t_c(需配合两段步数)",
|
| 202 |
+
)
|
| 203 |
+
p.add_argument(
|
| 204 |
+
"--steps-before-tc",
|
| 205 |
+
type=int,
|
| 206 |
+
default=None,
|
| 207 |
+
help="从 t=1 积分到 t=t_c 的步数(与 --steps-after-tc 成对使用)",
|
| 208 |
+
)
|
| 209 |
+
p.add_argument(
|
| 210 |
+
"--steps-after-tc",
|
| 211 |
+
type=int,
|
| 212 |
+
default=None,
|
| 213 |
+
help="从 t=t_c 积分到 t=0(经 t_floor=0.04)的步数",
|
| 214 |
+
)
|
| 215 |
+
p.add_argument("--cfg-scale", type=float, default=1.0)
|
| 216 |
+
p.add_argument("--cls-cfg-scale", type=float, default=0.0, help="cls 分支 CFG(>0 时需 cfg-scale>1)")
|
| 217 |
+
p.add_argument("--guidance-low", type=float, default=0.0)
|
| 218 |
+
p.add_argument("--guidance-high", type=float, default=1.0)
|
| 219 |
+
p.add_argument(
|
| 220 |
+
"--path-type",
|
| 221 |
+
type=str,
|
| 222 |
+
default=None,
|
| 223 |
+
choices=["linear", "cosine"],
|
| 224 |
+
help="默认从检查点 args 读取;可覆盖",
|
| 225 |
+
)
|
| 226 |
+
p.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
|
| 227 |
+
# 无 args 时的兜底
|
| 228 |
+
p.add_argument("--model", type=str, default=None, help="无检查点 args 时必填;与 SiT_models 键一致,如 SiT-XL/2")
|
| 229 |
+
p.add_argument("--resolution", type=int, default=None, choices=[256, 512])
|
| 230 |
+
p.add_argument("--num-classes", type=int, default=None)
|
| 231 |
+
p.add_argument("--encoder-depth", type=int, default=None)
|
| 232 |
+
p.add_argument("--enc-type", type=str, default=None)
|
| 233 |
+
p.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=None)
|
| 234 |
+
p.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=None)
|
| 235 |
+
p.add_argument("--cfg-prob", type=float, default=None)
|
| 236 |
+
p.add_argument(
|
| 237 |
+
"--sampler",
|
| 238 |
+
type=str,
|
| 239 |
+
default="em_image_noise",
|
| 240 |
+
choices=["ode", "em", "em_image_noise", "em_image_noise_before_tc"],
|
| 241 |
+
help="采样器:ode=euler_sampler 确定性漂移(linspace 1→0 或 t_c 分段直连 0,无 t_floor;与 EM 网格不同),"
|
| 242 |
+
"em=标准EM(含图像+cls噪声),em_image_noise=仅图像噪声,"
|
| 243 |
+
"em_image_noise_before_tc=t<=t_c时图像去随机+cls全程去随机",
|
| 244 |
+
)
|
| 245 |
+
p.add_argument(
|
| 246 |
+
"--dual-compare-after",
|
| 247 |
+
action="store_true",
|
| 248 |
+
help="需配合分段步数:同批 z/y/cls 连跑两次;after_input 用 --steps-after-tc,"
|
| 249 |
+
"after_equal_before 将 after 步数设为与 --steps-before-tc 相同",
|
| 250 |
+
)
|
| 251 |
+
p.add_argument(
|
| 252 |
+
"--save-fixed-trajectory",
|
| 253 |
+
action="store_true",
|
| 254 |
+
help="保存固定步采样轨迹(npy);仅对非 em 采样器启用,输出在 out-dir/trajectory",
|
| 255 |
+
)
|
| 256 |
+
return p.parse_args()
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae):
|
| 260 |
+
imgs = vae.decode((latents - latents_bias) / latents_scale).sample
|
| 261 |
+
imgs = (imgs + 1) / 2.0
|
| 262 |
+
imgs = torch.clamp(imgs, 0, 1)
|
| 263 |
+
return (
|
| 264 |
+
(imgs * 255.0)
|
| 265 |
+
.round()
|
| 266 |
+
.to(torch.uint8)
|
| 267 |
+
.permute(0, 2, 3, 1)
|
| 268 |
+
.cpu()
|
| 269 |
+
.numpy()
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def main():
|
| 274 |
+
cli = parse_cli()
|
| 275 |
+
device = torch.device(cli.device if torch.cuda.is_available() else "cpu")
|
| 276 |
+
if device.type == "cuda":
|
| 277 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 278 |
+
|
| 279 |
+
try:
|
| 280 |
+
ckpt = torch.load(cli.ckpt, map_location="cpu", weights_only=False)
|
| 281 |
+
except TypeError:
|
| 282 |
+
ckpt = torch.load(cli.ckpt, map_location="cpu")
|
| 283 |
+
ta = load_train_args_from_ckpt(ckpt)
|
| 284 |
+
if ta is None:
|
| 285 |
+
if cli.model is None or cli.resolution is None or cli.enc_type is None:
|
| 286 |
+
print(
|
| 287 |
+
"检查点中无 args,请至少指定:--model --resolution --enc-type "
|
| 288 |
+
"(以及按需 --num-classes --encoder-depth)",
|
| 289 |
+
file=sys.stderr,
|
| 290 |
+
)
|
| 291 |
+
sys.exit(1)
|
| 292 |
+
ta = argparse.Namespace(
|
| 293 |
+
model=cli.model,
|
| 294 |
+
resolution=cli.resolution,
|
| 295 |
+
num_classes=cli.num_classes if cli.num_classes is not None else 1000,
|
| 296 |
+
encoder_depth=cli.encoder_depth if cli.encoder_depth is not None else 8,
|
| 297 |
+
enc_type=cli.enc_type,
|
| 298 |
+
fused_attn=cli.fused_attn if cli.fused_attn is not None else True,
|
| 299 |
+
qk_norm=cli.qk_norm if cli.qk_norm is not None else False,
|
| 300 |
+
cfg_prob=cli.cfg_prob if cli.cfg_prob is not None else 0.1,
|
| 301 |
+
)
|
| 302 |
+
else:
|
| 303 |
+
if cli.model is not None:
|
| 304 |
+
ta.model = cli.model
|
| 305 |
+
if cli.resolution is not None:
|
| 306 |
+
ta.resolution = cli.resolution
|
| 307 |
+
if cli.num_classes is not None:
|
| 308 |
+
ta.num_classes = cli.num_classes
|
| 309 |
+
if cli.encoder_depth is not None:
|
| 310 |
+
ta.encoder_depth = cli.encoder_depth
|
| 311 |
+
if cli.enc_type is not None:
|
| 312 |
+
ta.enc_type = cli.enc_type
|
| 313 |
+
if cli.fused_attn is not None:
|
| 314 |
+
ta.fused_attn = cli.fused_attn
|
| 315 |
+
if cli.qk_norm is not None:
|
| 316 |
+
ta.qk_norm = cli.qk_norm
|
| 317 |
+
if cli.cfg_prob is not None:
|
| 318 |
+
ta.cfg_prob = cli.cfg_prob
|
| 319 |
+
|
| 320 |
+
path_type = cli.path_type if cli.path_type is not None else getattr(ta, "path_type", "linear")
|
| 321 |
+
|
| 322 |
+
tc_split = resolve_tc_schedule(cli, ta)
|
| 323 |
+
if cli.dual_compare_after and tc_split[0] is None:
|
| 324 |
+
print("--dual-compare-after 必须配合 --steps-before-tc 与 --steps-after-tc(分段采样)", file=sys.stderr)
|
| 325 |
+
sys.exit(1)
|
| 326 |
+
if tc_split[0] is not None:
|
| 327 |
+
if cli.dual_compare_after:
|
| 328 |
+
print(
|
| 329 |
+
f"双次对比:t_c={tc_split[0]}, before={tc_split[1]}, "
|
| 330 |
+
f"after_input={tc_split[2]}, after_equal_before={tc_split[1]}"
|
| 331 |
+
)
|
| 332 |
+
else:
|
| 333 |
+
print(
|
| 334 |
+
f"时间网格:t_c={tc_split[0]}, 步数 (1→t_c)={tc_split[1]}, (t_c→0)={tc_split[2]} "
|
| 335 |
+
f"(总模型前向约 {tc_split[1] + tc_split[2] + 1} 次)"
|
| 336 |
+
)
|
| 337 |
+
else:
|
| 338 |
+
print(f"时间网格:均匀 num_steps={cli.num_steps}")
|
| 339 |
+
|
| 340 |
+
if cli.sampler == "ode":
|
| 341 |
+
sampler_fn = euler_ode_sampler
|
| 342 |
+
elif cli.sampler == "em":
|
| 343 |
+
sampler_fn = euler_maruyama_sampler
|
| 344 |
+
elif cli.sampler == "em_image_noise_before_tc":
|
| 345 |
+
sampler_fn = euler_maruyama_image_noise_before_tc_sampler
|
| 346 |
+
else:
|
| 347 |
+
sampler_fn = euler_maruyama_image_noise_sampler
|
| 348 |
+
|
| 349 |
+
model, cls_dim = build_model_from_train_args(ta, device)
|
| 350 |
+
wkey = cli.weights
|
| 351 |
+
if wkey not in ckpt:
|
| 352 |
+
raise KeyError(f"检查点中无 '{wkey}' 键,现有键:{list(ckpt.keys())}")
|
| 353 |
+
state = ckpt[wkey]
|
| 354 |
+
if cli.legacy:
|
| 355 |
+
from utils import load_legacy_checkpoints
|
| 356 |
+
|
| 357 |
+
state = load_legacy_checkpoints(
|
| 358 |
+
state_dict=state, encoder_depth=int(getattr(ta, "encoder_depth", 8))
|
| 359 |
+
)
|
| 360 |
+
model.load_state_dict(state, strict=True)
|
| 361 |
+
model.eval()
|
| 362 |
+
|
| 363 |
+
vae = load_vae(device)
|
| 364 |
+
latents_scale = torch.tensor([0.18215] * 4, device=device).view(1, 4, 1, 1)
|
| 365 |
+
latents_bias = torch.tensor([0.0] * 4, device=device).view(1, 4, 1, 1)
|
| 366 |
+
|
| 367 |
+
sampler_args = argparse.Namespace(cls_cfg_scale=float(cli.cls_cfg_scale))
|
| 368 |
+
|
| 369 |
+
at_tc_dir = at_tc_a = at_tc_b = None
|
| 370 |
+
traj_dir = traj_a = traj_b = None
|
| 371 |
+
if cli.dual_compare_after:
|
| 372 |
+
out_a = os.path.join(cli.out_dir, "after_input")
|
| 373 |
+
out_b = os.path.join(cli.out_dir, "after_equal_before")
|
| 374 |
+
os.makedirs(out_a, exist_ok=True)
|
| 375 |
+
os.makedirs(out_b, exist_ok=True)
|
| 376 |
+
if tc_split[0] is not None:
|
| 377 |
+
at_tc_a = os.path.join(cli.out_dir, "at_tc", "after_input")
|
| 378 |
+
at_tc_b = os.path.join(cli.out_dir, "at_tc", "after_equal_before")
|
| 379 |
+
os.makedirs(at_tc_a, exist_ok=True)
|
| 380 |
+
os.makedirs(at_tc_b, exist_ok=True)
|
| 381 |
+
if cli.save_fixed_trajectory and cli.sampler != "em":
|
| 382 |
+
traj_a = os.path.join(cli.out_dir, "trajectory", "after_input")
|
| 383 |
+
traj_b = os.path.join(cli.out_dir, "trajectory", "after_equal_before")
|
| 384 |
+
os.makedirs(traj_a, exist_ok=True)
|
| 385 |
+
os.makedirs(traj_b, exist_ok=True)
|
| 386 |
+
else:
|
| 387 |
+
os.makedirs(cli.out_dir, exist_ok=True)
|
| 388 |
+
if tc_split[0] is not None:
|
| 389 |
+
at_tc_dir = os.path.join(cli.out_dir, "at_tc")
|
| 390 |
+
os.makedirs(at_tc_dir, exist_ok=True)
|
| 391 |
+
if cli.save_fixed_trajectory and cli.sampler != "em":
|
| 392 |
+
traj_dir = os.path.join(cli.out_dir, "trajectory")
|
| 393 |
+
os.makedirs(traj_dir, exist_ok=True)
|
| 394 |
+
latent_size = int(getattr(ta, "resolution", 256)) // 8
|
| 395 |
+
n_total = int(cli.num_images)
|
| 396 |
+
b = max(1, int(cli.batch_size))
|
| 397 |
+
|
| 398 |
+
torch.manual_seed(cli.seed)
|
| 399 |
+
if device.type == "cuda":
|
| 400 |
+
torch.cuda.manual_seed_all(cli.seed)
|
| 401 |
+
|
| 402 |
+
written = 0
|
| 403 |
+
pbar = tqdm(total=n_total, desc="sampling")
|
| 404 |
+
while written < n_total:
|
| 405 |
+
cur = min(b, n_total - written)
|
| 406 |
+
z = torch.randn(cur, model.in_channels, latent_size, latent_size, device=device)
|
| 407 |
+
y = torch.randint(0, int(ta.num_classes), (cur,), device=device)
|
| 408 |
+
cls_z = torch.randn(cur, cls_dim, device=device)
|
| 409 |
+
|
| 410 |
+
with torch.no_grad():
|
| 411 |
+
base_kw = dict(
|
| 412 |
+
num_steps=cli.num_steps,
|
| 413 |
+
cfg_scale=cli.cfg_scale,
|
| 414 |
+
guidance_low=cli.guidance_low,
|
| 415 |
+
guidance_high=cli.guidance_high,
|
| 416 |
+
path_type=path_type,
|
| 417 |
+
cls_latents=cls_z,
|
| 418 |
+
args=sampler_args,
|
| 419 |
+
)
|
| 420 |
+
if cli.dual_compare_after:
|
| 421 |
+
tc_v, sb, sa_in = tc_split
|
| 422 |
+
# 两次完整采样会各自消耗 RNG;不重置则第二条的 1→t_c 噪声与第一条不同,z_tc/at_tc 会对不齐。
|
| 423 |
+
# 在固定 z/y/cls_z 之后打快照,第二条运行前恢复,使 t_c 中间态一致(仅后段步数不同)。
|
| 424 |
+
_rng_cpu_dual = torch.random.get_rng_state()
|
| 425 |
+
_rng_cuda_dual = (
|
| 426 |
+
torch.cuda.get_rng_state_all()
|
| 427 |
+
if device.type == "cuda"
|
| 428 |
+
else None
|
| 429 |
+
)
|
| 430 |
+
for _run_i, (subdir, sa, tc_save_dir) in enumerate(
|
| 431 |
+
(
|
| 432 |
+
(out_a, sa_in, at_tc_a),
|
| 433 |
+
(out_b, sb, at_tc_b),
|
| 434 |
+
)
|
| 435 |
+
):
|
| 436 |
+
if _run_i > 0:
|
| 437 |
+
torch.random.set_rng_state(_rng_cpu_dual)
|
| 438 |
+
if _rng_cuda_dual is not None:
|
| 439 |
+
torch.cuda.set_rng_state_all(_rng_cuda_dual)
|
| 440 |
+
em_kw = dict(base_kw)
|
| 441 |
+
em_kw["t_c"] = tc_v
|
| 442 |
+
em_kw["num_steps_before_tc"] = sb
|
| 443 |
+
em_kw["num_steps_after_tc"] = sa
|
| 444 |
+
if cli.sampler == "em_image_noise_before_tc":
|
| 445 |
+
if cli.save_fixed_trajectory and cli.sampler != "em":
|
| 446 |
+
latents, z_tc, cls_tc, cls_t0, traj = sampler_fn(
|
| 447 |
+
model,
|
| 448 |
+
z,
|
| 449 |
+
y,
|
| 450 |
+
**em_kw,
|
| 451 |
+
return_mid_state=True,
|
| 452 |
+
t_mid=float(tc_v),
|
| 453 |
+
return_cls_final=True,
|
| 454 |
+
return_trajectory=True,
|
| 455 |
+
)
|
| 456 |
+
else:
|
| 457 |
+
latents, z_tc, cls_tc, cls_t0 = sampler_fn(
|
| 458 |
+
model,
|
| 459 |
+
z,
|
| 460 |
+
y,
|
| 461 |
+
**em_kw,
|
| 462 |
+
return_mid_state=True,
|
| 463 |
+
t_mid=float(tc_v),
|
| 464 |
+
return_cls_final=True,
|
| 465 |
+
)
|
| 466 |
+
traj = None
|
| 467 |
+
else:
|
| 468 |
+
if cli.save_fixed_trajectory and cli.sampler != "em":
|
| 469 |
+
latents, z_tc, cls_tc, traj = sampler_fn(
|
| 470 |
+
model,
|
| 471 |
+
z,
|
| 472 |
+
y,
|
| 473 |
+
**em_kw,
|
| 474 |
+
return_mid_state=True,
|
| 475 |
+
t_mid=float(tc_v),
|
| 476 |
+
return_trajectory=True,
|
| 477 |
+
)
|
| 478 |
+
else:
|
| 479 |
+
latents, z_tc, cls_tc = sampler_fn(
|
| 480 |
+
model,
|
| 481 |
+
z,
|
| 482 |
+
y,
|
| 483 |
+
**em_kw,
|
| 484 |
+
return_mid_state=True,
|
| 485 |
+
t_mid=float(tc_v),
|
| 486 |
+
)
|
| 487 |
+
traj = None
|
| 488 |
+
cls_t0 = None
|
| 489 |
+
latents = latents.to(torch.float32)
|
| 490 |
+
imgs = _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae)
|
| 491 |
+
for i in range(cur):
|
| 492 |
+
Image.fromarray(imgs[i]).save(
|
| 493 |
+
os.path.join(subdir, f"{written + i:06d}.png")
|
| 494 |
+
)
|
| 495 |
+
if tc_save_dir is not None and z_tc is not None:
|
| 496 |
+
imgs_tc = _decode_to_uint8_hwc(
|
| 497 |
+
z_tc.to(torch.float32), latents_bias, latents_scale, vae
|
| 498 |
+
)
|
| 499 |
+
for i in range(cur):
|
| 500 |
+
Image.fromarray(imgs_tc[i]).save(
|
| 501 |
+
os.path.join(tc_save_dir, f"{written + i:06d}.png")
|
| 502 |
+
)
|
| 503 |
+
if traj is not None:
|
| 504 |
+
traj_np = torch.stack(traj, dim=0).to(torch.float32).cpu().numpy()
|
| 505 |
+
save_traj_dir = traj_a if subdir == out_a else traj_b
|
| 506 |
+
np.save(os.path.join(save_traj_dir, f"{written:06d}_traj.npy"), traj_np)
|
| 507 |
+
else:
|
| 508 |
+
em_kw = dict(base_kw)
|
| 509 |
+
if tc_split[0] is not None:
|
| 510 |
+
em_kw["t_c"] = tc_split[0]
|
| 511 |
+
em_kw["num_steps_before_tc"] = tc_split[1]
|
| 512 |
+
em_kw["num_steps_after_tc"] = tc_split[2]
|
| 513 |
+
if cli.sampler == "em_image_noise_before_tc":
|
| 514 |
+
if cli.save_fixed_trajectory and cli.sampler != "em":
|
| 515 |
+
latents, z_tc, cls_tc, cls_t0, traj = sampler_fn(
|
| 516 |
+
model,
|
| 517 |
+
z,
|
| 518 |
+
y,
|
| 519 |
+
**em_kw,
|
| 520 |
+
return_mid_state=True,
|
| 521 |
+
t_mid=float(tc_split[0]),
|
| 522 |
+
return_cls_final=True,
|
| 523 |
+
return_trajectory=True,
|
| 524 |
+
)
|
| 525 |
+
else:
|
| 526 |
+
latents, z_tc, cls_tc, cls_t0 = sampler_fn(
|
| 527 |
+
model,
|
| 528 |
+
z,
|
| 529 |
+
y,
|
| 530 |
+
**em_kw,
|
| 531 |
+
return_mid_state=True,
|
| 532 |
+
t_mid=float(tc_split[0]),
|
| 533 |
+
return_cls_final=True,
|
| 534 |
+
)
|
| 535 |
+
traj = None
|
| 536 |
+
else:
|
| 537 |
+
if cli.save_fixed_trajectory and cli.sampler != "em":
|
| 538 |
+
latents, z_tc, cls_tc, traj = sampler_fn(
|
| 539 |
+
model,
|
| 540 |
+
z,
|
| 541 |
+
y,
|
| 542 |
+
**em_kw,
|
| 543 |
+
return_mid_state=True,
|
| 544 |
+
t_mid=float(tc_split[0]),
|
| 545 |
+
return_trajectory=True,
|
| 546 |
+
)
|
| 547 |
+
else:
|
| 548 |
+
latents, z_tc, cls_tc = sampler_fn(
|
| 549 |
+
model,
|
| 550 |
+
z,
|
| 551 |
+
y,
|
| 552 |
+
**em_kw,
|
| 553 |
+
return_mid_state=True,
|
| 554 |
+
t_mid=float(tc_split[0]),
|
| 555 |
+
)
|
| 556 |
+
traj = None
|
| 557 |
+
cls_t0 = None
|
| 558 |
+
latents = latents.to(torch.float32)
|
| 559 |
+
if z_tc is not None and at_tc_dir is not None:
|
| 560 |
+
imgs_tc = _decode_to_uint8_hwc(
|
| 561 |
+
z_tc.to(torch.float32), latents_bias, latents_scale, vae
|
| 562 |
+
)
|
| 563 |
+
for i in range(cur):
|
| 564 |
+
Image.fromarray(imgs_tc[i]).save(
|
| 565 |
+
os.path.join(at_tc_dir, f"{written + i:06d}.png")
|
| 566 |
+
)
|
| 567 |
+
if traj is not None and traj_dir is not None:
|
| 568 |
+
traj_np = torch.stack(traj, dim=0).to(torch.float32).cpu().numpy()
|
| 569 |
+
np.save(os.path.join(traj_dir, f"{written:06d}_traj.npy"), traj_np)
|
| 570 |
+
else:
|
| 571 |
+
latents = sampler_fn(model, z, y, **em_kw).to(torch.float32)
|
| 572 |
+
imgs = _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae)
|
| 573 |
+
for i in range(cur):
|
| 574 |
+
Image.fromarray(imgs[i]).save(
|
| 575 |
+
os.path.join(cli.out_dir, f"{written + i:06d}.png")
|
| 576 |
+
)
|
| 577 |
+
written += cur
|
| 578 |
+
pbar.update(cur)
|
| 579 |
+
pbar.close()
|
| 580 |
+
if cli.dual_compare_after:
|
| 581 |
+
msg = (
|
| 582 |
+
f"Done. Saved {written} images per run under {out_a} and {out_b} "
|
| 583 |
+
f"(parent: {cli.out_dir})"
|
| 584 |
+
)
|
| 585 |
+
if tc_split[0] is not None and at_tc_a is not None:
|
| 586 |
+
msg += f"; t≈t_c decoded under {at_tc_a} and {at_tc_b}"
|
| 587 |
+
print(msg)
|
| 588 |
+
else:
|
| 589 |
+
msg = f"Done. Saved {written} images under {cli.out_dir}"
|
| 590 |
+
if tc_split[0] is not None and at_tc_dir is not None:
|
| 591 |
+
msg += f"; t≈t_c decoded under {at_tc_dir}"
|
| 592 |
+
print(msg)
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
if __name__ == "__main__":
|
| 596 |
+
main()
|
back/samples.sh
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# 双次对比步数请用 --dual-compare-after(见 sample_from_checkpoint.py),输出在 out-dir 子目录。
|
| 3 |
+
|
| 4 |
+
CUDA_VISIBLE_DEVICES=1 python sample_from_checkpoint.py \
|
| 5 |
+
--ckpt /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/exps/jsflow-experiment-0.75/checkpoints/0500000.pt \
|
| 6 |
+
--out-dir ./my_samples_test \
|
| 7 |
+
--num-images 24 \
|
| 8 |
+
--batch-size 4 \
|
| 9 |
+
--seed 0 \
|
| 10 |
+
--t-c 0.75 \
|
| 11 |
+
--steps-before-tc 50 \
|
| 12 |
+
--steps-after-tc 5 \
|
| 13 |
+
--sampler ode \
|
| 14 |
+
--cfg-scale 1.0 \
|
| 15 |
+
--dual-compare-after \
|
back/samples_0.5.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
back/samples_ddp.sh
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
|
| 3 |
+
# 4 卡 DDP 单路径采样(不做 dual-compare,不保存 at_tc 中间图)
|
| 4 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 nohup nohup torchrun \
|
| 5 |
+
--nnodes=1 \
|
| 6 |
+
--nproc_per_node=4 \
|
| 7 |
+
--rdzv_endpoint=localhost:29110 \
|
| 8 |
+
sample_from_checkpoint_ddp.py \
|
| 9 |
+
--ckpt /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/exps/jsflow-experiment-0.75/checkpoints/0600000.pt \
|
| 10 |
+
--out-dir ./my_samples_600k_new \
|
| 11 |
+
--num-images 40000 \
|
| 12 |
+
--batch-size 64 \
|
| 13 |
+
--seed 0 \
|
| 14 |
+
--t-c 0.75 \
|
| 15 |
+
--steps-before-tc 100 \
|
| 16 |
+
--steps-after-tc 50 \
|
| 17 |
+
--sampler em_image_noise_before_tc \
|
| 18 |
+
--cfg-scale 1.0 \
|
| 19 |
+
> samples_0.75_new.log 2>&1 &
|
| 20 |
+
|
| 21 |
+
# nohup python sample_from_checkpoint_ddp.py \
|
| 22 |
+
# --ckpt /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/exps/jsflow-experiment-0.5/checkpoints/0250000.pt \
|
| 23 |
+
# --out-dir ./my_samples_5 \
|
| 24 |
+
# --num-images 20000 \
|
| 25 |
+
# --batch-size 16 \
|
| 26 |
+
# --seed 0 \
|
| 27 |
+
# --t-c 0.5 \
|
| 28 |
+
# --steps-before-tc 100 \
|
| 29 |
+
# --steps-after-tc 50 \
|
| 30 |
+
# --sampler em_image_noise_before_tc \
|
| 31 |
+
# --cfg-scale 1.0 \
|
| 32 |
+
# > samples_0.5.log 2>&1 &
|
back/train.py
ADDED
|
@@ -0,0 +1,670 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import copy
|
| 3 |
+
from copy import deepcopy
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from collections import OrderedDict
|
| 8 |
+
import json
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
import torch.utils.checkpoint
|
| 14 |
+
from tqdm.auto import tqdm
|
| 15 |
+
from torch.utils.data import DataLoader
|
| 16 |
+
|
| 17 |
+
from accelerate import Accelerator, DistributedDataParallelKwargs
|
| 18 |
+
from accelerate.logging import get_logger
|
| 19 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 20 |
+
|
| 21 |
+
from models.sit import SiT_models
|
| 22 |
+
from loss import SILoss
|
| 23 |
+
from utils import load_encoders
|
| 24 |
+
|
| 25 |
+
from dataset import CustomDataset
|
| 26 |
+
from diffusers.models import AutoencoderKL
|
| 27 |
+
# import wandb_utils
|
| 28 |
+
import wandb
|
| 29 |
+
import math
|
| 30 |
+
from torchvision.utils import make_grid
|
| 31 |
+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 32 |
+
from torchvision.transforms import Normalize
|
| 33 |
+
from PIL import Image
|
| 34 |
+
|
| 35 |
+
logger = get_logger(__name__)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def semantic_dim_from_enc_type(enc_type):
|
| 39 |
+
"""DINOv2 等 enc_type 字符串推断 class token 维度(与预处理特征一致)。"""
|
| 40 |
+
if enc_type is None:
|
| 41 |
+
return 768
|
| 42 |
+
s = str(enc_type).lower()
|
| 43 |
+
if "vit-g" in s or "vitg" in s:
|
| 44 |
+
return 1536
|
| 45 |
+
if "vit-l" in s or "vitl" in s:
|
| 46 |
+
return 1024
|
| 47 |
+
if "vit-s" in s or "vits" in s:
|
| 48 |
+
return 384
|
| 49 |
+
return 768
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
| 53 |
+
CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def preprocess_raw_image(x, enc_type):
|
| 58 |
+
resolution = x.shape[-1]
|
| 59 |
+
if 'clip' in enc_type:
|
| 60 |
+
x = x / 255.
|
| 61 |
+
x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
|
| 62 |
+
x = Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(x)
|
| 63 |
+
elif 'mocov3' in enc_type or 'mae' in enc_type:
|
| 64 |
+
x = x / 255.
|
| 65 |
+
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
|
| 66 |
+
elif 'dinov2' in enc_type:
|
| 67 |
+
x = x / 255.
|
| 68 |
+
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
|
| 69 |
+
x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
|
| 70 |
+
elif 'dinov1' in enc_type:
|
| 71 |
+
x = x / 255.
|
| 72 |
+
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
|
| 73 |
+
elif 'jepa' in enc_type:
|
| 74 |
+
x = x / 255.
|
| 75 |
+
x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
|
| 76 |
+
x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
|
| 77 |
+
|
| 78 |
+
return x
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def array2grid(x):
|
| 82 |
+
nrow = round(math.sqrt(x.size(0)))
|
| 83 |
+
x = make_grid(x.clamp(0, 1), nrow=nrow, value_range=(0, 1))
|
| 84 |
+
x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
|
| 85 |
+
return x
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@torch.no_grad()
|
| 89 |
+
def sample_posterior(moments, latents_scale=1., latents_bias=0.):
|
| 90 |
+
device = moments.device
|
| 91 |
+
|
| 92 |
+
mean, std = torch.chunk(moments, 2, dim=1)
|
| 93 |
+
z = mean + std * torch.randn_like(mean)
|
| 94 |
+
z = (z * latents_scale + latents_bias)
|
| 95 |
+
return z
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@torch.no_grad()
|
| 99 |
+
def update_ema(ema_model, model, decay=0.9999):
|
| 100 |
+
"""
|
| 101 |
+
Step the EMA model towards the current model.
|
| 102 |
+
"""
|
| 103 |
+
ema_params = OrderedDict(ema_model.named_parameters())
|
| 104 |
+
model_params = OrderedDict(model.named_parameters())
|
| 105 |
+
|
| 106 |
+
for name, param in model_params.items():
|
| 107 |
+
name = name.replace("module.", "")
|
| 108 |
+
# TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
|
| 109 |
+
ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def create_logger(logging_dir):
|
| 113 |
+
"""
|
| 114 |
+
Create a logger that writes to a log file and stdout.
|
| 115 |
+
"""
|
| 116 |
+
logging.basicConfig(
|
| 117 |
+
level=logging.INFO,
|
| 118 |
+
format='[\033[34m%(asctime)s\033[0m] %(message)s',
|
| 119 |
+
datefmt='%Y-%m-%d %H:%M:%S',
|
| 120 |
+
handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
|
| 121 |
+
)
|
| 122 |
+
logger = logging.getLogger(__name__)
|
| 123 |
+
return logger
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def requires_grad(model, flag=True):
|
| 127 |
+
"""
|
| 128 |
+
Set requires_grad flag for all parameters in a model.
|
| 129 |
+
"""
|
| 130 |
+
for p in model.parameters():
|
| 131 |
+
p.requires_grad = flag
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
#################################################################################
|
| 135 |
+
# Training Loop #
|
| 136 |
+
#################################################################################
|
| 137 |
+
|
| 138 |
+
def main(args):
|
| 139 |
+
# set accelerator
|
| 140 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 141 |
+
accelerator_project_config = ProjectConfiguration(
|
| 142 |
+
project_dir=args.output_dir, logging_dir=logging_dir
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
accelerator = Accelerator(
|
| 146 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 147 |
+
mixed_precision=args.mixed_precision,
|
| 148 |
+
log_with=args.report_to,
|
| 149 |
+
project_config=accelerator_project_config,
|
| 150 |
+
kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)]
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
if accelerator.is_main_process:
|
| 154 |
+
os.makedirs(args.output_dir, exist_ok=True) # Make results folder (holds all experiment subfolders)
|
| 155 |
+
save_dir = os.path.join(args.output_dir, args.exp_name)
|
| 156 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 157 |
+
args_dict = vars(args)
|
| 158 |
+
# Save to a JSON file
|
| 159 |
+
json_dir = os.path.join(save_dir, "args.json")
|
| 160 |
+
with open(json_dir, 'w') as f:
|
| 161 |
+
json.dump(args_dict, f, indent=4)
|
| 162 |
+
checkpoint_dir = f"{save_dir}/checkpoints" # Stores saved model checkpoints
|
| 163 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
| 164 |
+
logger = create_logger(save_dir)
|
| 165 |
+
logger.info(f"Experiment directory created at {save_dir}")
|
| 166 |
+
device = accelerator.device
|
| 167 |
+
if torch.backends.mps.is_available():
|
| 168 |
+
accelerator.native_amp = False
|
| 169 |
+
if args.seed is not None:
|
| 170 |
+
set_seed(args.seed + accelerator.process_index)
|
| 171 |
+
|
| 172 |
+
# Create model:
|
| 173 |
+
assert args.resolution % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
|
| 174 |
+
latent_size = args.resolution // 8
|
| 175 |
+
|
| 176 |
+
train_dataset = CustomDataset(
|
| 177 |
+
args.data_dir, semantic_features_dir=args.semantic_features_dir
|
| 178 |
+
)
|
| 179 |
+
use_preprocessed_semantic = train_dataset.use_preprocessed_semantic
|
| 180 |
+
|
| 181 |
+
if use_preprocessed_semantic:
|
| 182 |
+
encoders, encoder_types, architectures = [], [], []
|
| 183 |
+
z_dims = [semantic_dim_from_enc_type(args.enc_type)]
|
| 184 |
+
if accelerator.is_main_process:
|
| 185 |
+
logger.info(
|
| 186 |
+
f"Preprocessed semantic features: skip loading online encoder, z_dims={z_dims}"
|
| 187 |
+
)
|
| 188 |
+
elif args.enc_type is not None:
|
| 189 |
+
encoders, encoder_types, architectures = load_encoders(
|
| 190 |
+
args.enc_type, device, args.resolution
|
| 191 |
+
)
|
| 192 |
+
z_dims = [encoder.embed_dim for encoder in encoders]
|
| 193 |
+
else:
|
| 194 |
+
raise NotImplementedError()
|
| 195 |
+
block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
|
| 196 |
+
model = SiT_models[args.model](
|
| 197 |
+
input_size=latent_size,
|
| 198 |
+
num_classes=args.num_classes,
|
| 199 |
+
use_cfg = (args.cfg_prob > 0),
|
| 200 |
+
z_dims = z_dims,
|
| 201 |
+
encoder_depth=args.encoder_depth,
|
| 202 |
+
**block_kwargs
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
model = model.to(device)
|
| 206 |
+
ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
|
| 207 |
+
requires_grad(ema, False)
|
| 208 |
+
|
| 209 |
+
latents_scale = torch.tensor(
|
| 210 |
+
[0.18215, 0.18215, 0.18215, 0.18215]
|
| 211 |
+
).view(1, 4, 1, 1).to(device)
|
| 212 |
+
latents_bias = torch.tensor(
|
| 213 |
+
[0., 0., 0., 0.]
|
| 214 |
+
).view(1, 4, 1, 1).to(device)
|
| 215 |
+
|
| 216 |
+
# VAE decoder:采样阶段将 latent 解码为图像(与根目录 train.py / 预处理一致:sd-vae-ft-mse)
|
| 217 |
+
try:
|
| 218 |
+
from preprocessing import dnnlib
|
| 219 |
+
cache_dir = dnnlib.make_cache_dir_path("diffusers")
|
| 220 |
+
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
|
| 221 |
+
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
| 222 |
+
os.environ["HF_HOME"] = cache_dir
|
| 223 |
+
try:
|
| 224 |
+
vae = AutoencoderKL.from_pretrained(
|
| 225 |
+
"stabilityai/sd-vae-ft-mse",
|
| 226 |
+
cache_dir=cache_dir,
|
| 227 |
+
local_files_only=True,
|
| 228 |
+
).to(device)
|
| 229 |
+
vae.eval()
|
| 230 |
+
if accelerator.is_main_process:
|
| 231 |
+
logger.info(
|
| 232 |
+
"Loaded VAE 'stabilityai/sd-vae-ft-mse' from local diffusers cache "
|
| 233 |
+
f"at '{cache_dir}' for intermediate sampling."
|
| 234 |
+
)
|
| 235 |
+
except Exception as e_main:
|
| 236 |
+
vae = None
|
| 237 |
+
candidate_dir = None
|
| 238 |
+
possible_roots = [
|
| 239 |
+
cache_dir,
|
| 240 |
+
os.path.join(os.path.expanduser("~"), ".cache", "dnnlib", "diffusers"),
|
| 241 |
+
os.path.join(os.path.expanduser("~"), ".cache", "diffusers"),
|
| 242 |
+
os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"),
|
| 243 |
+
]
|
| 244 |
+
checked_roots = []
|
| 245 |
+
for root_dir in possible_roots:
|
| 246 |
+
if not os.path.isdir(root_dir):
|
| 247 |
+
continue
|
| 248 |
+
checked_roots.append(root_dir)
|
| 249 |
+
for root, dirs, files in os.walk(root_dir):
|
| 250 |
+
if "config.json" in files and "sd-vae-ft-mse" in root.replace("\\", "/"):
|
| 251 |
+
candidate_dir = root
|
| 252 |
+
break
|
| 253 |
+
if candidate_dir is not None:
|
| 254 |
+
break
|
| 255 |
+
if candidate_dir is not None:
|
| 256 |
+
try:
|
| 257 |
+
vae = AutoencoderKL.from_pretrained(
|
| 258 |
+
candidate_dir,
|
| 259 |
+
local_files_only=True,
|
| 260 |
+
).to(device)
|
| 261 |
+
vae.eval()
|
| 262 |
+
if accelerator.is_main_process:
|
| 263 |
+
logger.info(
|
| 264 |
+
"Loaded VAE 'stabilityai/sd-vae-ft-mse' from discovered local path "
|
| 265 |
+
f"'{candidate_dir}'. Searched roots: {checked_roots}"
|
| 266 |
+
)
|
| 267 |
+
except Exception as e_fallback:
|
| 268 |
+
if accelerator.is_main_process:
|
| 269 |
+
logger.warning(
|
| 270 |
+
"Tried to load VAE from discovered local path "
|
| 271 |
+
f"'{candidate_dir}' but failed: {e_fallback}"
|
| 272 |
+
)
|
| 273 |
+
if vae is None and accelerator.is_main_process:
|
| 274 |
+
logger.warning(
|
| 275 |
+
"Could not load VAE 'stabilityai/sd-vae-ft-mse' via repo name or local search. "
|
| 276 |
+
f"Last repo-level error: {e_main}"
|
| 277 |
+
)
|
| 278 |
+
except Exception as e:
|
| 279 |
+
vae = None
|
| 280 |
+
if accelerator.is_main_process:
|
| 281 |
+
logger.warning(
|
| 282 |
+
f"Failed to initialize VAE loading logic (will skip image decoding): {e}"
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
# create loss function
|
| 286 |
+
loss_fn = SILoss(
|
| 287 |
+
prediction=args.prediction,
|
| 288 |
+
path_type=args.path_type,
|
| 289 |
+
encoders=encoders,
|
| 290 |
+
accelerator=accelerator,
|
| 291 |
+
latents_scale=latents_scale,
|
| 292 |
+
latents_bias=latents_bias,
|
| 293 |
+
weighting=args.weighting,
|
| 294 |
+
t_c=args.t_c,
|
| 295 |
+
ot_cls=args.ot_cls,
|
| 296 |
+
)
|
| 297 |
+
if accelerator.is_main_process:
|
| 298 |
+
logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
|
| 299 |
+
|
| 300 |
+
# Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
|
| 301 |
+
if args.allow_tf32:
|
| 302 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 303 |
+
torch.backends.cudnn.allow_tf32 = True
|
| 304 |
+
|
| 305 |
+
optimizer = torch.optim.AdamW(
|
| 306 |
+
model.parameters(),
|
| 307 |
+
lr=args.learning_rate,
|
| 308 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 309 |
+
weight_decay=args.adam_weight_decay,
|
| 310 |
+
eps=args.adam_epsilon,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
# Setup data(train_dataset 已在上方创建)
|
| 314 |
+
local_batch_size = int(args.batch_size // accelerator.num_processes)
|
| 315 |
+
train_dataloader = DataLoader(
|
| 316 |
+
train_dataset,
|
| 317 |
+
batch_size=local_batch_size,
|
| 318 |
+
shuffle=True,
|
| 319 |
+
num_workers=args.num_workers,
|
| 320 |
+
pin_memory=True,
|
| 321 |
+
drop_last=True
|
| 322 |
+
)
|
| 323 |
+
if accelerator.is_main_process:
|
| 324 |
+
logger.info(f"Dataset contains {len(train_dataset):,} images ({args.data_dir})")
|
| 325 |
+
|
| 326 |
+
# Prepare models for training:
|
| 327 |
+
update_ema(ema, model, decay=0) # Ensure EMA is initialized with synced weights
|
| 328 |
+
model.train() # important! This enables embedding dropout for classifier-free guidance
|
| 329 |
+
ema.eval() # EMA model should always be in eval mode
|
| 330 |
+
|
| 331 |
+
# resume:
|
| 332 |
+
global_step = 0
|
| 333 |
+
if args.resume_step > 0:
|
| 334 |
+
ckpt_name = str(args.resume_step).zfill(7) +'.pt'
|
| 335 |
+
ckpt = torch.load(
|
| 336 |
+
f'{os.path.join(args.output_dir, args.exp_name)}/checkpoints/{ckpt_name}',
|
| 337 |
+
map_location='cpu',
|
| 338 |
+
)
|
| 339 |
+
model.load_state_dict(ckpt['model'])
|
| 340 |
+
ema.load_state_dict(ckpt['ema'])
|
| 341 |
+
optimizer.load_state_dict(ckpt['opt'])
|
| 342 |
+
global_step = ckpt['steps']
|
| 343 |
+
|
| 344 |
+
model, optimizer, train_dataloader = accelerator.prepare(
|
| 345 |
+
model, optimizer, train_dataloader
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
if accelerator.is_main_process:
|
| 349 |
+
tracker_config = vars(copy.deepcopy(args))
|
| 350 |
+
accelerator.init_trackers(
|
| 351 |
+
project_name="REG",
|
| 352 |
+
config=tracker_config,
|
| 353 |
+
init_kwargs={
|
| 354 |
+
"wandb": {"name": f"{args.exp_name}"}
|
| 355 |
+
},
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
progress_bar = tqdm(
|
| 360 |
+
range(0, args.max_train_steps),
|
| 361 |
+
initial=global_step,
|
| 362 |
+
desc="Steps",
|
| 363 |
+
# Only show the progress bar once on each machine.
|
| 364 |
+
disable=not accelerator.is_local_main_process,
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
# Labels to condition the model with (feel free to change):
|
| 368 |
+
sample_batch_size = 64 // accelerator.num_processes
|
| 369 |
+
first_batch = next(iter(train_dataloader))
|
| 370 |
+
if len(first_batch) == 4:
|
| 371 |
+
gt_raw_images, gt_xs, _, _ = first_batch
|
| 372 |
+
else:
|
| 373 |
+
gt_raw_images, gt_xs, _ = first_batch
|
| 374 |
+
assert gt_raw_images.shape[-1] == args.resolution
|
| 375 |
+
gt_xs = gt_xs[:sample_batch_size]
|
| 376 |
+
gt_xs = sample_posterior(
|
| 377 |
+
gt_xs.to(device), latents_scale=latents_scale, latents_bias=latents_bias
|
| 378 |
+
)
|
| 379 |
+
ys = torch.randint(1000, size=(sample_batch_size,), device=device)
|
| 380 |
+
ys = ys.to(device)
|
| 381 |
+
# Create sampling noise:
|
| 382 |
+
n = ys.size(0)
|
| 383 |
+
xT = torch.randn((n, 4, latent_size, latent_size), device=device)
|
| 384 |
+
|
| 385 |
+
for epoch in range(args.epochs):
|
| 386 |
+
model.train()
|
| 387 |
+
for batch in train_dataloader:
|
| 388 |
+
if len(batch) == 4:
|
| 389 |
+
raw_image, x, r_preprocessed, y = batch
|
| 390 |
+
use_sem_file = True
|
| 391 |
+
else:
|
| 392 |
+
raw_image, x, y = batch
|
| 393 |
+
r_preprocessed = None
|
| 394 |
+
use_sem_file = False
|
| 395 |
+
|
| 396 |
+
raw_image = raw_image.to(device)
|
| 397 |
+
x = x.squeeze(dim=1).to(device).float()
|
| 398 |
+
y = y.to(device)
|
| 399 |
+
if args.legacy:
|
| 400 |
+
# In our early experiments, we accidentally apply label dropping twice:
|
| 401 |
+
# once in train.py and once in sit.py.
|
| 402 |
+
# We keep this option for exact reproducibility with previous runs.
|
| 403 |
+
drop_ids = torch.rand(y.shape[0], device=y.device) < args.cfg_prob
|
| 404 |
+
labels = torch.where(drop_ids, args.num_classes, y)
|
| 405 |
+
else:
|
| 406 |
+
labels = y
|
| 407 |
+
with torch.no_grad():
|
| 408 |
+
x = sample_posterior(x, latents_scale=latents_scale, latents_bias=latents_bias)
|
| 409 |
+
zs = []
|
| 410 |
+
if use_sem_file and r_preprocessed is not None:
|
| 411 |
+
cls_token = r_preprocessed.to(device).float()
|
| 412 |
+
if cls_token.dim() == 1:
|
| 413 |
+
cls_token = cls_token.unsqueeze(0)
|
| 414 |
+
while cls_token.dim() > 2:
|
| 415 |
+
cls_token = cls_token.squeeze(1)
|
| 416 |
+
base_m = model.module if hasattr(model, "module") else model
|
| 417 |
+
n_pad = base_m.x_embedder.num_patches
|
| 418 |
+
zs = [
|
| 419 |
+
torch.cat(
|
| 420 |
+
[
|
| 421 |
+
cls_token.unsqueeze(1),
|
| 422 |
+
cls_token.unsqueeze(1).expand(-1, n_pad, -1),
|
| 423 |
+
],
|
| 424 |
+
dim=1,
|
| 425 |
+
)
|
| 426 |
+
]
|
| 427 |
+
else:
|
| 428 |
+
with accelerator.autocast():
|
| 429 |
+
for encoder, encoder_type, arch in zip(
|
| 430 |
+
encoders, encoder_types, architectures
|
| 431 |
+
):
|
| 432 |
+
raw_image_ = preprocess_raw_image(raw_image, encoder_type)
|
| 433 |
+
z = encoder.forward_features(raw_image_)
|
| 434 |
+
if 'dinov2' in encoder_type:
|
| 435 |
+
dense_z = z['x_norm_patchtokens']
|
| 436 |
+
cls_token = z['x_norm_clstoken']
|
| 437 |
+
dense_z = torch.cat([cls_token.unsqueeze(1), dense_z], dim=1)
|
| 438 |
+
else:
|
| 439 |
+
exit()
|
| 440 |
+
zs.append(dense_z)
|
| 441 |
+
|
| 442 |
+
with accelerator.accumulate(model):
|
| 443 |
+
model_kwargs = dict(y=labels)
|
| 444 |
+
loss1, proj_loss1, time_input, noises, loss2 = loss_fn(model, x, model_kwargs, zs=zs,
|
| 445 |
+
cls_token=cls_token,
|
| 446 |
+
time_input=None, noises=None)
|
| 447 |
+
loss_mean = loss1.mean()
|
| 448 |
+
loss_mean_cls = loss2.mean() * args.cls
|
| 449 |
+
proj_loss_mean = proj_loss1.mean() * args.proj_coeff
|
| 450 |
+
loss = loss_mean + proj_loss_mean + loss_mean_cls
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
## optimization
|
| 454 |
+
accelerator.backward(loss)
|
| 455 |
+
if accelerator.sync_gradients:
|
| 456 |
+
params_to_clip = model.parameters()
|
| 457 |
+
grad_norm = accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 458 |
+
optimizer.step()
|
| 459 |
+
optimizer.zero_grad(set_to_none=True)
|
| 460 |
+
|
| 461 |
+
if accelerator.sync_gradients:
|
| 462 |
+
update_ema(ema, model) # change ema function
|
| 463 |
+
|
| 464 |
+
### enter
|
| 465 |
+
if accelerator.sync_gradients:
|
| 466 |
+
progress_bar.update(1)
|
| 467 |
+
global_step += 1
|
| 468 |
+
if global_step % args.checkpointing_steps == 0 and global_step > 0:
|
| 469 |
+
if accelerator.is_main_process:
|
| 470 |
+
checkpoint = {
|
| 471 |
+
"model": model.module.state_dict(),
|
| 472 |
+
"ema": ema.state_dict(),
|
| 473 |
+
"opt": optimizer.state_dict(),
|
| 474 |
+
"args": args,
|
| 475 |
+
"steps": global_step,
|
| 476 |
+
}
|
| 477 |
+
checkpoint_path = f"{checkpoint_dir}/{global_step:07d}.pt"
|
| 478 |
+
torch.save(checkpoint, checkpoint_path)
|
| 479 |
+
logger.info(f"Saved checkpoint to {checkpoint_path}")
|
| 480 |
+
|
| 481 |
+
if (global_step == 1 or (global_step % args.sampling_steps == 0 and global_step > 0)):
|
| 482 |
+
t_mid_vis = float(args.t_c)
|
| 483 |
+
tc_tag = f"{t_mid_vis:.4f}".rstrip("0").rstrip(".").replace(".", "_")
|
| 484 |
+
logging.info(
|
| 485 |
+
f"Generating EMA samples (Euler-Maruyama; t≈{t_mid_vis:g} → t=0)..."
|
| 486 |
+
)
|
| 487 |
+
ema.eval()
|
| 488 |
+
with torch.no_grad():
|
| 489 |
+
latent_size = args.resolution // 8
|
| 490 |
+
n_samples = min(16, args.batch_size)
|
| 491 |
+
base_model = model.module if hasattr(model, "module") else model
|
| 492 |
+
cls_dim = base_model.z_dims[0]
|
| 493 |
+
shared_seed = torch.randint(0, 2**32, (1,), device=device).item()
|
| 494 |
+
torch.manual_seed(shared_seed)
|
| 495 |
+
z_init = torch.randn(n_samples, base_model.in_channels, latent_size, latent_size, device=device)
|
| 496 |
+
torch.manual_seed(shared_seed)
|
| 497 |
+
cls_init = torch.randn(n_samples, cls_dim, device=device)
|
| 498 |
+
y_samples = torch.randint(0, args.num_classes, (n_samples,), device=device)
|
| 499 |
+
|
| 500 |
+
from samplers import euler_maruyama_sampler
|
| 501 |
+
z_0, z_mid, _ = euler_maruyama_sampler(
|
| 502 |
+
ema,
|
| 503 |
+
z_init,
|
| 504 |
+
y_samples,
|
| 505 |
+
num_steps=50,
|
| 506 |
+
cfg_scale=1.0,
|
| 507 |
+
guidance_low=0.0,
|
| 508 |
+
guidance_high=1.0,
|
| 509 |
+
path_type=args.path_type,
|
| 510 |
+
cls_latents=cls_init,
|
| 511 |
+
args=args,
|
| 512 |
+
return_mid_state=True,
|
| 513 |
+
t_mid=t_mid_vis,
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
samples_root = os.path.join(args.output_dir, args.exp_name, "samples")
|
| 517 |
+
t0_dir = os.path.join(samples_root, "t0")
|
| 518 |
+
t_mid_dir = os.path.join(samples_root, f"t0_{tc_tag}")
|
| 519 |
+
os.makedirs(t0_dir, exist_ok=True)
|
| 520 |
+
os.makedirs(t_mid_dir, exist_ok=True)
|
| 521 |
+
|
| 522 |
+
if vae is not None:
|
| 523 |
+
z_f = z_0.to(dtype=torch.float32)
|
| 524 |
+
samples_final = vae.decode((z_f - latents_bias) / latents_scale).sample
|
| 525 |
+
samples_final = (samples_final + 1) / 2.0
|
| 526 |
+
samples_final = samples_final.clamp(0, 1)
|
| 527 |
+
grid_final = array2grid(samples_final)
|
| 528 |
+
Image.fromarray(grid_final).save(
|
| 529 |
+
os.path.join(t0_dir, f"step_{global_step:07d}_t0.png")
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
if z_mid is not None:
|
| 533 |
+
z_m = z_mid.to(dtype=torch.float32)
|
| 534 |
+
samples_mid = vae.decode((z_m - latents_bias) / latents_scale).sample
|
| 535 |
+
samples_mid = (samples_mid + 1) / 2.0
|
| 536 |
+
samples_mid = samples_mid.clamp(0, 1)
|
| 537 |
+
grid_mid = array2grid(samples_mid)
|
| 538 |
+
Image.fromarray(grid_mid).save(
|
| 539 |
+
os.path.join(t_mid_dir, f"step_{global_step:07d}_t0_{tc_tag}.png")
|
| 540 |
+
)
|
| 541 |
+
else:
|
| 542 |
+
logging.warning(
|
| 543 |
+
f"Sampling time grid did not bracket t_mid={t_mid_vis:g}; "
|
| 544 |
+
f"skip t0_{tc_tag} image this step."
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
del z_init, cls_init, y_samples, z_0
|
| 548 |
+
if z_mid is not None:
|
| 549 |
+
del z_mid
|
| 550 |
+
if vae is not None:
|
| 551 |
+
del samples_final, grid_final
|
| 552 |
+
if "samples_mid" in locals():
|
| 553 |
+
del samples_mid, grid_mid
|
| 554 |
+
torch.cuda.empty_cache()
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
logs = {
|
| 558 |
+
"loss_final": accelerator.gather(loss).mean().detach().item(),
|
| 559 |
+
"loss_mean": accelerator.gather(loss_mean).mean().detach().item(),
|
| 560 |
+
"proj_loss": accelerator.gather(proj_loss_mean).mean().detach().item(),
|
| 561 |
+
"loss_mean_cls": accelerator.gather(loss_mean_cls).mean().detach().item(),
|
| 562 |
+
"grad_norm": accelerator.gather(grad_norm).mean().detach().item()
|
| 563 |
+
}
|
| 564 |
+
|
| 565 |
+
log_message = ", ".join(f"{key}: {value:.6f}" for key, value in logs.items())
|
| 566 |
+
logging.info(f"Step: {global_step}, Training Logs: {log_message}")
|
| 567 |
+
|
| 568 |
+
progress_bar.set_postfix(**logs)
|
| 569 |
+
accelerator.log(logs, step=global_step)
|
| 570 |
+
|
| 571 |
+
if global_step >= args.max_train_steps:
|
| 572 |
+
break
|
| 573 |
+
if global_step >= args.max_train_steps:
|
| 574 |
+
break
|
| 575 |
+
|
| 576 |
+
model.eval() # important! This disables randomized embedding dropout
|
| 577 |
+
# do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
|
| 578 |
+
|
| 579 |
+
accelerator.wait_for_everyone()
|
| 580 |
+
if accelerator.is_main_process:
|
| 581 |
+
logger.info("Done!")
|
| 582 |
+
accelerator.end_training()
|
| 583 |
+
|
| 584 |
+
def parse_args(input_args=None):
|
| 585 |
+
parser = argparse.ArgumentParser(description="Training")
|
| 586 |
+
|
| 587 |
+
# logging:
|
| 588 |
+
parser.add_argument("--output-dir", type=str, default="exps")
|
| 589 |
+
parser.add_argument("--exp-name", type=str, required=True)
|
| 590 |
+
parser.add_argument("--logging-dir", type=str, default="logs")
|
| 591 |
+
parser.add_argument("--report-to", type=str, default="wandb")
|
| 592 |
+
parser.add_argument("--sampling-steps", type=int, default=2000)
|
| 593 |
+
parser.add_argument("--resume-step", type=int, default=0)
|
| 594 |
+
|
| 595 |
+
# model
|
| 596 |
+
parser.add_argument("--model", type=str)
|
| 597 |
+
parser.add_argument("--num-classes", type=int, default=1000)
|
| 598 |
+
parser.add_argument("--encoder-depth", type=int, default=8)
|
| 599 |
+
parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=True)
|
| 600 |
+
parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False)
|
| 601 |
+
parser.add_argument("--ops-head", type=int, default=16)
|
| 602 |
+
|
| 603 |
+
# dataset
|
| 604 |
+
parser.add_argument("--data-dir", type=str, default="../data/imagenet256")
|
| 605 |
+
parser.add_argument(
|
| 606 |
+
"--semantic-features-dir",
|
| 607 |
+
type=str,
|
| 608 |
+
default=None,
|
| 609 |
+
help="预处理 DINOv2 class token 等特征目录(含 dataset.json)。"
|
| 610 |
+
"默认 None 时若存在 data-dir/imagenet_256_features/dinov2-vit-b_tmp/gpu0 则自动使用。",
|
| 611 |
+
)
|
| 612 |
+
parser.add_argument("--resolution", type=int, choices=[256, 512], default=256)
|
| 613 |
+
parser.add_argument("--batch-size", type=int, default=256)#256
|
| 614 |
+
|
| 615 |
+
# precision
|
| 616 |
+
parser.add_argument("--allow-tf32", action="store_true")
|
| 617 |
+
parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
|
| 618 |
+
|
| 619 |
+
# optimization
|
| 620 |
+
parser.add_argument("--epochs", type=int, default=1400)
|
| 621 |
+
parser.add_argument("--max-train-steps", type=int, default=1000000)
|
| 622 |
+
parser.add_argument("--checkpointing-steps", type=int, default=10000)
|
| 623 |
+
parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
|
| 624 |
+
parser.add_argument("--learning-rate", type=float, default=1e-4)
|
| 625 |
+
parser.add_argument("--adam-beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
| 626 |
+
parser.add_argument("--adam-beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
| 627 |
+
parser.add_argument("--adam-weight-decay", type=float, default=0., help="Weight decay to use.")
|
| 628 |
+
parser.add_argument("--adam-epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
| 629 |
+
parser.add_argument("--max-grad-norm", default=1.0, type=float, help="Max gradient norm.")
|
| 630 |
+
|
| 631 |
+
# seed
|
| 632 |
+
parser.add_argument("--seed", type=int, default=0)
|
| 633 |
+
|
| 634 |
+
# cpu
|
| 635 |
+
parser.add_argument("--num-workers", type=int, default=4)
|
| 636 |
+
|
| 637 |
+
# loss
|
| 638 |
+
parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
|
| 639 |
+
parser.add_argument("--prediction", type=str, default="v", choices=["v"]) # currently we only support v-prediction
|
| 640 |
+
parser.add_argument("--cfg-prob", type=float, default=0.1)
|
| 641 |
+
parser.add_argument("--enc-type", type=str, default='dinov2-vit-b')
|
| 642 |
+
parser.add_argument("--proj-coeff", type=float, default=0.5)
|
| 643 |
+
parser.add_argument("--weighting", default="uniform", type=str, help="Max gradient norm.")
|
| 644 |
+
parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
|
| 645 |
+
parser.add_argument("--cls", type=float, default=0.03)
|
| 646 |
+
parser.add_argument(
|
| 647 |
+
"--t-c",
|
| 648 |
+
type=float,
|
| 649 |
+
default=0.5,
|
| 650 |
+
help="语义分界时刻(与脚本内 t 约定一致:t=1 噪声→t=0 数据)。"
|
| 651 |
+
"t∈(t_c,1]:cls 沿 OT 配对后的路径插值(CFM/OT-CFM 式 minibatch OT);"
|
| 652 |
+
"t∈[0,t_c]:cls 固定为真实 encoder cls,目标 cls 速度为 0。",
|
| 653 |
+
)
|
| 654 |
+
parser.add_argument(
|
| 655 |
+
"--ot-cls",
|
| 656 |
+
action=argparse.BooleanOptionalAction,
|
| 657 |
+
default=True,
|
| 658 |
+
help="在 t>t_c 段对 cls 噪声与 batch 内 cls_gt 做 minibatch 最优传输配对(需 scipy);关闭则退化为独立高斯噪声配对。",
|
| 659 |
+
)
|
| 660 |
+
if input_args is not None:
|
| 661 |
+
args = parser.parse_args(input_args)
|
| 662 |
+
else:
|
| 663 |
+
args = parser.parse_args()
|
| 664 |
+
|
| 665 |
+
return args
|
| 666 |
+
|
| 667 |
+
if __name__ == "__main__":
|
| 668 |
+
args = parse_args()
|
| 669 |
+
|
| 670 |
+
main(args)
|
back/train.sh
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# REG/train.py:与主仓库类似,可单独指定数据根目录与预处理 cls 特征目录。
|
| 3 |
+
# 数据布局:${DATA_DIR}/imagenet_256_vae/ 下 VAE latent;
|
| 4 |
+
# ${SEMANTIC_FEATURES_DIR}/ 下 img-feature-*.npy + dataset.json(与 parallel_encode 一致)。
|
| 5 |
+
|
| 6 |
+
NUM_GPUS=4
|
| 7 |
+
|
| 8 |
+
# ------------ 按本机路径修改 ------------
|
| 9 |
+
DATA_DIR="/gemini/space/zhaozy/dataset/Imagenet/imagenet_256"
|
| 10 |
+
SEMANTIC_FEATURES_DIR="/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0"
|
| 11 |
+
|
| 12 |
+
# 后台示例(与主实验脚本风格一致):
|
| 13 |
+
# nohup bash train.sh > jsflow-experiment.log 2>&1 &
|
| 14 |
+
|
| 15 |
+
nohup accelerate launch --multi_gpu --num_processes "${NUM_GPUS}" --mixed_precision bf16 train.py \
|
| 16 |
+
--report-to wandb \
|
| 17 |
+
--allow-tf32 \
|
| 18 |
+
--mixed-precision bf16 \
|
| 19 |
+
--seed 0 \
|
| 20 |
+
--path-type linear \
|
| 21 |
+
--prediction v \
|
| 22 |
+
--weighting uniform \
|
| 23 |
+
--model SiT-XL/2 \
|
| 24 |
+
--enc-type dinov2-vit-b \
|
| 25 |
+
--encoder-depth 8 \
|
| 26 |
+
--proj-coeff 0.5 \
|
| 27 |
+
--output-dir exps \
|
| 28 |
+
--exp-name jsflow-experiment-0.75 \
|
| 29 |
+
--batch-size 256 \
|
| 30 |
+
--data-dir "${DATA_DIR}" \
|
| 31 |
+
--semantic-features-dir "${SEMANTIC_FEATURES_DIR}" \
|
| 32 |
+
--learning-rate 0.00005 \
|
| 33 |
+
--t-c 0.75 \
|
| 34 |
+
--cls 0.05 \
|
| 35 |
+
--ot-cls \
|
| 36 |
+
> jsflow-experiment.log 2>&1 &
|
| 37 |
+
|
| 38 |
+
# 说明:
|
| 39 |
+
# - 不使用预处理特征、改在线抽 DINO 时:去掉 --semantic-features-dir,并保证 data-dir 为 REG 原布局
|
| 40 |
+
# (imagenet_256_vae + vae-sd)。
|
| 41 |
+
# - 关闭 minibatch OT:追加 --no-ot-cls。
|
| 42 |
+
# - 主仓库 train.py 中的 --weight-ratio / --semantic-reg-coeff / --repa-* 等为本 REG 脚本未实现项;
|
| 43 |
+
# 投影强度请用 --proj-coeff,cls 流损失权重用 --cls。
|
back/utils.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from torchvision.datasets.utils import download_url
|
| 3 |
+
import torch
|
| 4 |
+
import torchvision.models as torchvision_models
|
| 5 |
+
import timm
|
| 6 |
+
from models import mocov3_vit
|
| 7 |
+
import math
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# code from SiT repository
|
| 12 |
+
pretrained_models = {'last.pt'}
|
| 13 |
+
|
| 14 |
+
def download_model(model_name):
|
| 15 |
+
"""
|
| 16 |
+
Downloads a pre-trained SiT model from the web.
|
| 17 |
+
"""
|
| 18 |
+
assert model_name in pretrained_models
|
| 19 |
+
local_path = f'pretrained_models/{model_name}'
|
| 20 |
+
if not os.path.isfile(local_path):
|
| 21 |
+
os.makedirs('pretrained_models', exist_ok=True)
|
| 22 |
+
web_path = f'https://www.dl.dropboxusercontent.com/scl/fi/cxedbs4da5ugjq5wg3zrg/last.pt?rlkey=8otgrdkno0nd89po3dpwngwcc&st=apcc645o&dl=0'
|
| 23 |
+
download_url(web_path, 'pretrained_models', filename=model_name)
|
| 24 |
+
model = torch.load(local_path, map_location=lambda storage, loc: storage)
|
| 25 |
+
return model
|
| 26 |
+
|
| 27 |
+
def fix_mocov3_state_dict(state_dict):
|
| 28 |
+
for k in list(state_dict.keys()):
|
| 29 |
+
# retain only base_encoder up to before the embedding layer
|
| 30 |
+
if k.startswith('module.base_encoder'):
|
| 31 |
+
# fix naming bug in checkpoint
|
| 32 |
+
new_k = k[len("module.base_encoder."):]
|
| 33 |
+
if "blocks.13.norm13" in new_k:
|
| 34 |
+
new_k = new_k.replace("norm13", "norm1")
|
| 35 |
+
if "blocks.13.mlp.fc13" in k:
|
| 36 |
+
new_k = new_k.replace("fc13", "fc1")
|
| 37 |
+
if "blocks.14.norm14" in k:
|
| 38 |
+
new_k = new_k.replace("norm14", "norm2")
|
| 39 |
+
if "blocks.14.mlp.fc14" in k:
|
| 40 |
+
new_k = new_k.replace("fc14", "fc2")
|
| 41 |
+
# remove prefix
|
| 42 |
+
if 'head' not in new_k and new_k.split('.')[0] != 'fc':
|
| 43 |
+
state_dict[new_k] = state_dict[k]
|
| 44 |
+
# delete renamed or unused k
|
| 45 |
+
del state_dict[k]
|
| 46 |
+
if 'pos_embed' in state_dict.keys():
|
| 47 |
+
state_dict['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
|
| 48 |
+
state_dict['pos_embed'], [16, 16],
|
| 49 |
+
)
|
| 50 |
+
return state_dict
|
| 51 |
+
|
| 52 |
+
@torch.no_grad()
|
| 53 |
+
def load_encoders(enc_type, device, resolution=256):
|
| 54 |
+
assert (resolution == 256) or (resolution == 512)
|
| 55 |
+
|
| 56 |
+
enc_names = enc_type.split(',')
|
| 57 |
+
encoders, architectures, encoder_types = [], [], []
|
| 58 |
+
for enc_name in enc_names:
|
| 59 |
+
encoder_type, architecture, model_config = enc_name.split('-')
|
| 60 |
+
# Currently, we only support 512x512 experiments with DINOv2 encoders.
|
| 61 |
+
if resolution == 512:
|
| 62 |
+
if encoder_type != 'dinov2':
|
| 63 |
+
raise NotImplementedError(
|
| 64 |
+
"Currently, we only support 512x512 experiments with DINOv2 encoders."
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
architectures.append(architecture)
|
| 68 |
+
encoder_types.append(encoder_type)
|
| 69 |
+
if encoder_type == 'mocov3':
|
| 70 |
+
if architecture == 'vit':
|
| 71 |
+
if model_config == 's':
|
| 72 |
+
encoder = mocov3_vit.vit_small()
|
| 73 |
+
elif model_config == 'b':
|
| 74 |
+
encoder = mocov3_vit.vit_base()
|
| 75 |
+
elif model_config == 'l':
|
| 76 |
+
encoder = mocov3_vit.vit_large()
|
| 77 |
+
ckpt = torch.load(f'./ckpts/mocov3_vit{model_config}.pth')
|
| 78 |
+
state_dict = fix_mocov3_state_dict(ckpt['state_dict'])
|
| 79 |
+
del encoder.head
|
| 80 |
+
encoder.load_state_dict(state_dict, strict=True)
|
| 81 |
+
encoder.head = torch.nn.Identity()
|
| 82 |
+
elif architecture == 'resnet':
|
| 83 |
+
raise NotImplementedError()
|
| 84 |
+
|
| 85 |
+
encoder = encoder.to(device)
|
| 86 |
+
encoder.eval()
|
| 87 |
+
|
| 88 |
+
elif 'dinov2' in encoder_type:
|
| 89 |
+
import timm
|
| 90 |
+
if 'reg' in encoder_type:
|
| 91 |
+
try:
|
| 92 |
+
encoder = torch.hub.load('your_path/.cache/torch/hub/facebookresearch_dinov2_main',
|
| 93 |
+
f'dinov2_vit{model_config}14_reg', source='local')
|
| 94 |
+
except:
|
| 95 |
+
encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14_reg')
|
| 96 |
+
else:
|
| 97 |
+
try:
|
| 98 |
+
encoder = torch.hub.load('your_path/.cache/torch/hub/facebookresearch_dinov2_main',
|
| 99 |
+
f'dinov2_vit{model_config}14', source='local')
|
| 100 |
+
except:
|
| 101 |
+
encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14')
|
| 102 |
+
|
| 103 |
+
print(f"Now you are using the {enc_name} as the aligning model")
|
| 104 |
+
del encoder.head
|
| 105 |
+
patch_resolution = 16 * (resolution // 256)
|
| 106 |
+
encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed(
|
| 107 |
+
encoder.pos_embed.data, [patch_resolution, patch_resolution],
|
| 108 |
+
)
|
| 109 |
+
encoder.head = torch.nn.Identity()
|
| 110 |
+
encoder = encoder.to(device)
|
| 111 |
+
encoder.eval()
|
| 112 |
+
|
| 113 |
+
elif 'dinov1' == encoder_type:
|
| 114 |
+
import timm
|
| 115 |
+
from models import dinov1
|
| 116 |
+
encoder = dinov1.vit_base()
|
| 117 |
+
ckpt = torch.load(f'./ckpts/dinov1_vit{model_config}.pth')
|
| 118 |
+
if 'pos_embed' in ckpt.keys():
|
| 119 |
+
ckpt['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
|
| 120 |
+
ckpt['pos_embed'], [16, 16],
|
| 121 |
+
)
|
| 122 |
+
del encoder.head
|
| 123 |
+
encoder.head = torch.nn.Identity()
|
| 124 |
+
encoder.load_state_dict(ckpt, strict=True)
|
| 125 |
+
encoder = encoder.to(device)
|
| 126 |
+
encoder.forward_features = encoder.forward
|
| 127 |
+
encoder.eval()
|
| 128 |
+
|
| 129 |
+
elif encoder_type == 'clip':
|
| 130 |
+
import clip
|
| 131 |
+
from models.clip_vit import UpdatedVisionTransformer
|
| 132 |
+
encoder_ = clip.load(f"ViT-{model_config}/14", device='cpu')[0].visual
|
| 133 |
+
encoder = UpdatedVisionTransformer(encoder_).to(device)
|
| 134 |
+
#.to(device)
|
| 135 |
+
encoder.embed_dim = encoder.model.transformer.width
|
| 136 |
+
encoder.forward_features = encoder.forward
|
| 137 |
+
encoder.eval()
|
| 138 |
+
|
| 139 |
+
elif encoder_type == 'mae':
|
| 140 |
+
from models.mae_vit import vit_large_patch16
|
| 141 |
+
import timm
|
| 142 |
+
kwargs = dict(img_size=256)
|
| 143 |
+
encoder = vit_large_patch16(**kwargs).to(device)
|
| 144 |
+
with open(f"ckpts/mae_vit{model_config}.pth", "rb") as f:
|
| 145 |
+
state_dict = torch.load(f)
|
| 146 |
+
if 'pos_embed' in state_dict["model"].keys():
|
| 147 |
+
state_dict["model"]['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
|
| 148 |
+
state_dict["model"]['pos_embed'], [16, 16],
|
| 149 |
+
)
|
| 150 |
+
encoder.load_state_dict(state_dict["model"])
|
| 151 |
+
|
| 152 |
+
encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed(
|
| 153 |
+
encoder.pos_embed.data, [16, 16],
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
elif encoder_type == 'jepa':
|
| 157 |
+
from models.jepa import vit_huge
|
| 158 |
+
kwargs = dict(img_size=[224, 224], patch_size=14)
|
| 159 |
+
encoder = vit_huge(**kwargs).to(device)
|
| 160 |
+
with open(f"ckpts/ijepa_vit{model_config}.pth", "rb") as f:
|
| 161 |
+
state_dict = torch.load(f, map_location=device)
|
| 162 |
+
new_state_dict = dict()
|
| 163 |
+
for key, value in state_dict['encoder'].items():
|
| 164 |
+
new_state_dict[key[7:]] = value
|
| 165 |
+
encoder.load_state_dict(new_state_dict)
|
| 166 |
+
encoder.forward_features = encoder.forward
|
| 167 |
+
|
| 168 |
+
encoders.append(encoder)
|
| 169 |
+
|
| 170 |
+
return encoders, encoder_types, architectures
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
| 174 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
| 175 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
| 176 |
+
def norm_cdf(x):
|
| 177 |
+
# Computes standard normal cumulative distribution function
|
| 178 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
| 179 |
+
|
| 180 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
| 181 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
| 182 |
+
"The distribution of values may be incorrect.",
|
| 183 |
+
stacklevel=2)
|
| 184 |
+
|
| 185 |
+
with torch.no_grad():
|
| 186 |
+
# Values are generated by using a truncated uniform distribution and
|
| 187 |
+
# then using the inverse CDF for the normal distribution.
|
| 188 |
+
# Get upper and lower cdf values
|
| 189 |
+
l = norm_cdf((a - mean) / std)
|
| 190 |
+
u = norm_cdf((b - mean) / std)
|
| 191 |
+
|
| 192 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
| 193 |
+
# [2l-1, 2u-1].
|
| 194 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
| 195 |
+
|
| 196 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
| 197 |
+
# standard normal
|
| 198 |
+
tensor.erfinv_()
|
| 199 |
+
|
| 200 |
+
# Transform to proper mean, std
|
| 201 |
+
tensor.mul_(std * math.sqrt(2.))
|
| 202 |
+
tensor.add_(mean)
|
| 203 |
+
|
| 204 |
+
# Clamp to ensure it's in the proper range
|
| 205 |
+
tensor.clamp_(min=a, max=b)
|
| 206 |
+
return tensor
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
| 210 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def load_legacy_checkpoints(state_dict, encoder_depth):
|
| 214 |
+
new_state_dict = dict()
|
| 215 |
+
for key, value in state_dict.items():
|
| 216 |
+
if 'decoder_blocks' in key:
|
| 217 |
+
parts =key.split('.')
|
| 218 |
+
new_idx = int(parts[1]) + encoder_depth
|
| 219 |
+
parts[0] = 'blocks'
|
| 220 |
+
parts[1] = str(new_idx)
|
| 221 |
+
new_key = '.'.join(parts)
|
| 222 |
+
new_state_dict[new_key] = value
|
| 223 |
+
else:
|
| 224 |
+
new_state_dict[key] = value
|
| 225 |
+
return new_state_dict
|