xiangzai commited on
Commit
db5d40d
·
verified ·
1 Parent(s): 3e4f775

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. New/REG/LICENSE +21 -0
  2. New/REG/__pycache__/dataset.cpython-312.pyc +0 -0
  3. New/REG/__pycache__/dataset.cpython-313.pyc +0 -0
  4. New/REG/__pycache__/generate.cpython-312.pyc +0 -0
  5. New/REG/__pycache__/generate.cpython-313.pyc +0 -0
  6. New/REG/__pycache__/loss.cpython-312.pyc +0 -0
  7. New/REG/__pycache__/samplers.cpython-312.pyc +0 -0
  8. New/REG/__pycache__/samplers.cpython-313.pyc +0 -0
  9. New/REG/__pycache__/train.cpython-313.pyc +0 -0
  10. New/REG/__pycache__/utils.cpython-312.pyc +0 -0
  11. New/REG/dataset.py +165 -0
  12. New/REG/evaluations/README.md +72 -0
  13. New/REG/evaluations/evaluator.py +679 -0
  14. New/REG/evaluations/requirements.txt +4 -0
  15. New/REG/generate_2400000_100.sh +137 -0
  16. New/REG/models/__pycache__/mocov3_vit.cpython-310.pyc +0 -0
  17. New/REG/models/__pycache__/mocov3_vit.cpython-312.pyc +0 -0
  18. New/REG/models/__pycache__/sit.cpython-310.pyc +0 -0
  19. New/REG/models/__pycache__/sit.cpython-312.pyc +0 -0
  20. New/REG/models/clip_vit.py +426 -0
  21. New/REG/models/jepa.py +547 -0
  22. New/REG/models/mae_vit.py +71 -0
  23. New/REG/models/mocov3_vit.py +207 -0
  24. New/REG/models/sit.py +420 -0
  25. New/REG/preprocessing/README.md +25 -0
  26. New/REG/preprocessing/dataset_image_encoder.py +353 -0
  27. New/REG/preprocessing/dataset_prepare_convert.sh +11 -0
  28. New/REG/preprocessing/dataset_prepare_encode.sh +9 -0
  29. New/REG/preprocessing/dataset_tools.py +422 -0
  30. New/REG/preprocessing/dnnlib/__init__.py +8 -0
  31. New/REG/preprocessing/dnnlib/util.py +485 -0
  32. New/REG/preprocessing/encoders.py +103 -0
  33. New/REG/preprocessing/torch_utils/__init__.py +8 -0
  34. New/REG/preprocessing/torch_utils/distributed.py +140 -0
  35. New/REG/preprocessing/torch_utils/misc.py +277 -0
  36. New/REG/preprocessing/torch_utils/persistence.py +257 -0
  37. New/REG/preprocessing/torch_utils/training_stats.py +283 -0
  38. New/REG/samplers.py +264 -0
  39. New/REG/train.py +573 -0
  40. New/REG/train.sh +42 -0
  41. New/REG/utils.py +225 -0
  42. New/REG/wandb/debug-internal.log +7 -0
  43. New/REG/wandb/debug.log +20 -0
  44. New/REG/wandb/run-20260326_123101-m3lli51t/files/output.log +0 -0
  45. New/REG/wandb/run-20260326_123101-m3lli51t/files/requirements.txt +168 -0
  46. New/REG/wandb/run-20260326_123101-m3lli51t/files/wandb-metadata.json +80 -0
  47. New/REG/wandb/run-20260326_123101-m3lli51t/logs/debug-internal.log +8 -0
  48. New/REG/wandb/run-20260326_123101-m3lli51t/logs/debug.log +20 -0
  49. New/REG/wandb/run-20260326_130847-0e5vs4f8/files/requirements.txt +168 -0
  50. New/REG/wandb/run-20260326_130847-0e5vs4f8/files/wandb-metadata.json +80 -0
New/REG/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Sihyun Yu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
New/REG/__pycache__/dataset.cpython-312.pyc ADDED
Binary file (10.3 kB). View file
 
New/REG/__pycache__/dataset.cpython-313.pyc ADDED
Binary file (10.6 kB). View file
 
New/REG/__pycache__/generate.cpython-312.pyc ADDED
Binary file (13.6 kB). View file
 
New/REG/__pycache__/generate.cpython-313.pyc ADDED
Binary file (13.4 kB). View file
 
New/REG/__pycache__/loss.cpython-312.pyc ADDED
Binary file (5.48 kB). View file
 
New/REG/__pycache__/samplers.cpython-312.pyc ADDED
Binary file (11.5 kB). View file
 
New/REG/__pycache__/samplers.cpython-313.pyc ADDED
Binary file (10.6 kB). View file
 
New/REG/__pycache__/train.cpython-313.pyc ADDED
Binary file (30.2 kB). View file
 
New/REG/__pycache__/utils.cpython-312.pyc ADDED
Binary file (10.8 kB). View file
 
New/REG/dataset.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ import numpy as np
7
+
8
+ from PIL import Image
9
+ import PIL.Image
10
+
11
+ try:
12
+ import pyspng
13
+ except ImportError:
14
+ pyspng = None
15
+
16
+
17
+ class CustomDataset(Dataset):
18
+ """
19
+ New/REG 的数据集对齐:支持和 REG 相同的两种模式:
20
+ 1) 仅加载 VAE latent + vae-sd features(与原本 New 行为一致)
21
+ 2) 若提供 semantic_features_dir,则加载预处理 semantic cls token,并在 __getitem__ 返回 4 元组:
22
+ (raw_image_latent, x_latent, semantic_cls_token, y_label)
23
+ """
24
+
25
+ def __init__(self, data_dir, semantic_features_dir=None):
26
+ PIL.Image.init()
27
+ supported_ext = PIL.Image.EXTENSION.keys() | {".npy"}
28
+
29
+ self.images_dir = os.path.join(data_dir, "imagenet_256_vae")
30
+ self.semantic_features_dir = semantic_features_dir
31
+ self.use_preprocessed_semantic = False
32
+
33
+ # 1) 预处理 semantic 模式(与 REG/dataset.py 一致)
34
+ if semantic_features_dir is None:
35
+ potential_semantic_dir = os.path.join(
36
+ data_dir, "imagenet_256_features", "dinov2-vit-b_tmp", "gpu0"
37
+ )
38
+ if os.path.exists(potential_semantic_dir):
39
+ self.semantic_features_dir = potential_semantic_dir
40
+ self.use_preprocessed_semantic = True
41
+ print(f"Found preprocessed semantic features at: {self.semantic_features_dir}")
42
+ else:
43
+ self.use_preprocessed_semantic = True
44
+
45
+ if self.use_preprocessed_semantic:
46
+ if self.semantic_features_dir is None:
47
+ raise ValueError("semantic_features_dir is None but use_preprocessed_semantic=True")
48
+
49
+ label_fname = os.path.join(self.semantic_features_dir, "dataset.json")
50
+ if not os.path.exists(label_fname):
51
+ raise FileNotFoundError(f"Label file not found: {label_fname}")
52
+
53
+ print(f"Using {label_fname}.")
54
+ with open(label_fname, "rb") as f:
55
+ data = json.load(f)
56
+
57
+ labels_list = data.get("labels", None)
58
+ if labels_list is None:
59
+ raise ValueError(f"'labels' field is missing in {label_fname}")
60
+
61
+ semantic_fnames = []
62
+ labels = []
63
+ for entry in labels_list:
64
+ if entry is None:
65
+ continue
66
+ fname, lab = entry
67
+ semantic_fnames.append(fname)
68
+ labels.append(0 if lab is None else lab)
69
+
70
+ self.semantic_fnames = semantic_fnames
71
+ self.labels = np.array(labels, dtype=np.int64)
72
+ self.num_samples = len(self.semantic_fnames)
73
+ print(f"Loaded {self.num_samples} semantic entries from dataset.json")
74
+
75
+ # semantic 模式下无需额外维护 vae-sd feature 列表
76
+ return
77
+
78
+ # 2) 非预处理 semantic 模式:VAE latent + vae-sd features
79
+ self.features_dir = os.path.join(data_dir, "vae-sd")
80
+
81
+ # images
82
+ self._image_fnames = {
83
+ os.path.relpath(os.path.join(root, fname), start=self.images_dir)
84
+ for root, _dirs, files in os.walk(self.images_dir)
85
+ for fname in files
86
+ }
87
+ self.image_fnames = sorted(
88
+ fname for fname in self._image_fnames if self._file_ext(fname) in supported_ext
89
+ )
90
+
91
+ # features
92
+ self._feature_fnames = {
93
+ os.path.relpath(os.path.join(root, fname), start=self.features_dir)
94
+ for root, _dirs, files in os.walk(self.features_dir)
95
+ for fname in files
96
+ }
97
+ self.feature_fnames = sorted(
98
+ fname for fname in self._feature_fnames if self._file_ext(fname) in supported_ext
99
+ )
100
+
101
+ # labels
102
+ fname = os.path.join(self.features_dir, "dataset.json")
103
+ if os.path.exists(fname):
104
+ print(f"Using {fname}.")
105
+ else:
106
+ raise FileNotFoundError("Neither of the specified files exists.")
107
+
108
+ with open(fname, "rb") as f:
109
+ labels = json.load(f)["labels"]
110
+ labels = dict(labels)
111
+ labels = [labels[fname.replace("\\", "/")] for fname in self.feature_fnames]
112
+ labels = np.array(labels)
113
+ self.labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
114
+
115
+ def _file_ext(self, fname):
116
+ return os.path.splitext(fname)[1].lower()
117
+
118
+ def __len__(self):
119
+ if self.use_preprocessed_semantic:
120
+ return self.num_samples
121
+ assert len(self.image_fnames) == len(self.feature_fnames), (
122
+ "Number of feature files and label files should be same"
123
+ )
124
+ return len(self.feature_fnames)
125
+
126
+ def __getitem__(self, idx):
127
+ if self.use_preprocessed_semantic:
128
+ semantic_fname = self.semantic_fnames[idx]
129
+ basename = os.path.basename(semantic_fname)
130
+ idx_str = basename.split("-")[-1].split(".")[0]
131
+ subdir = idx_str[:5]
132
+ vae_relpath = os.path.join(subdir, f"img-mean-std-{idx_str}.npy")
133
+ vae_path = os.path.join(self.images_dir, vae_relpath)
134
+
135
+ with open(vae_path, "rb") as f:
136
+ image = np.load(f)
137
+
138
+ semantic_path = os.path.join(self.semantic_features_dir, semantic_fname)
139
+ semantic_features = np.load(semantic_path)
140
+
141
+ y = torch.tensor(self.labels[idx])
142
+ return (
143
+ torch.from_numpy(image).float(),
144
+ torch.from_numpy(image).float(),
145
+ torch.from_numpy(semantic_features).float(),
146
+ y,
147
+ )
148
+
149
+ # ---- non-preprocessed mode ----
150
+ image_fname = self.image_fnames[idx]
151
+ feature_fname = self.feature_fnames[idx]
152
+ image_ext = self._file_ext(image_fname)
153
+ with open(os.path.join(self.images_dir, image_fname), "rb") as f:
154
+ if image_ext == ".npy":
155
+ image = np.load(f)
156
+ image = image.reshape(-1, *image.shape[-2:])
157
+ elif image_ext == ".png" and pyspng is not None:
158
+ image = pyspng.load(f.read())
159
+ image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
160
+ else:
161
+ image = np.array(PIL.Image.open(f))
162
+ image = image.reshape(*image.shape[:2], -1).transpose(2, 0, 1)
163
+
164
+ features = np.load(os.path.join(self.features_dir, feature_fname))
165
+ return torch.from_numpy(image), torch.from_numpy(features), torch.tensor(self.labels[idx])
New/REG/evaluations/README.md ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Evaluations
2
+
3
+ To compare different generative models, we use FID, sFID, Precision, Recall, and Inception Score. These metrics can all be calculated using batches of samples, which we store in `.npz` (numpy) files.
4
+
5
+ # Download batches
6
+
7
+ We provide pre-computed sample batches for the reference datasets, our diffusion models, and several baselines we compare against. These are all stored in `.npz` format.
8
+
9
+ Reference dataset batches contain pre-computed statistics over the whole dataset, as well as 10,000 images for computing Precision and Recall. All other batches contain 50,000 images which can be used to compute statistics and Precision/Recall.
10
+
11
+ Here are links to download all of the sample and reference batches:
12
+
13
+ * LSUN
14
+ * LSUN bedroom: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/VIRTUAL_lsun_bedroom256.npz)
15
+ * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/admnet_dropout_lsun_bedroom.npz)
16
+ * [DDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/ddpm_lsun_bedroom.npz)
17
+ * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/iddpm_lsun_bedroom.npz)
18
+ * [StyleGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/stylegan_lsun_bedroom.npz)
19
+ * LSUN cat: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/VIRTUAL_lsun_cat256.npz)
20
+ * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/admnet_dropout_lsun_cat.npz)
21
+ * [StyleGAN2](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/stylegan2_lsun_cat.npz)
22
+ * LSUN horse: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/VIRTUAL_lsun_horse256.npz)
23
+ * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_dropout_lsun_horse.npz)
24
+ * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_lsun_horse.npz)
25
+
26
+ * ImageNet
27
+ * ImageNet 64x64: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz)
28
+ * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/admnet_imagenet64.npz)
29
+ * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/iddpm_imagenet64.npz)
30
+ * [BigGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/biggan_deep_imagenet64.npz)
31
+ * ImageNet 128x128: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz)
32
+ * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_imagenet128.npz)
33
+ * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_imagenet128.npz)
34
+ * [ADM-G, 25 steps](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_25step_imagenet128.npz)
35
+ * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/biggan_deep_trunc1_imagenet128.npz)
36
+ * ImageNet 256x256: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz)
37
+ * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_imagenet256.npz)
38
+ * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_imagenet256.npz)
39
+ * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_25step_imagenet256.npz)
40
+ * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_upsampled_imagenet256.npz)
41
+ * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_upsampled_imagenet256.npz)
42
+ * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/biggan_deep_trunc1_imagenet256.npz)
43
+ * ImageNet 512x512: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz)
44
+ * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_imagenet512.npz)
45
+ * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_imagenet512.npz)
46
+ * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_25step_imagenet512.npz)
47
+ * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_upsampled_imagenet512.npz)
48
+ * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_upsampled_imagenet512.npz)
49
+ * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/biggan_deep_trunc1_imagenet512.npz)
50
+
51
+ # Run evaluations
52
+
53
+ First, generate or download a batch of samples and download the corresponding reference batch for the given dataset. For this example, we'll use ImageNet 256x256, so the refernce batch is `VIRTUAL_imagenet256_labeled.npz` and we can use the sample batch `admnet_guided_upsampled_imagenet256.npz`.
54
+
55
+ Next, run the `evaluator.py` script. The requirements of this script can be found in [requirements.txt](requirements.txt). Pass two arguments to the script: the reference batch and the sample batch. The script will download the InceptionV3 model used for evaluations into the current working directory (if it is not already present). This file is roughly 100MB.
56
+
57
+ The output of the script will look something like this, where the first `...` is a bunch of verbose TensorFlow logging:
58
+
59
+ ```
60
+ $ python evaluator.py VIRTUAL_imagenet256_labeled.npz admnet_guided_upsampled_imagenet256.npz
61
+ ...
62
+ computing reference batch activations...
63
+ computing/reading reference batch statistics...
64
+ computing sample batch activations...
65
+ computing/reading sample batch statistics...
66
+ Computing evaluations...
67
+ Inception Score: 215.8370361328125
68
+ FID: 3.9425574129223264
69
+ sFID: 6.140433703346162
70
+ Precision: 0.8265
71
+ Recall: 0.5309
72
+ ```
New/REG/evaluations/evaluator.py ADDED
@@ -0,0 +1,679 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import io
3
+ import os
4
+ import random
5
+ import warnings
6
+ import zipfile
7
+ from abc import ABC, abstractmethod
8
+ from contextlib import contextmanager
9
+ from functools import partial
10
+ from multiprocessing import cpu_count
11
+ from multiprocessing.pool import ThreadPool
12
+ from typing import Iterable, Optional, Tuple
13
+
14
+ import numpy as np
15
+ import requests
16
+ import tensorflow.compat.v1 as tf
17
+ from scipy import linalg
18
+ from tqdm.auto import tqdm
19
+
20
+ INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
21
+ INCEPTION_V3_PATH = "classify_image_graph_def.pb"
22
+
23
+ FID_POOL_NAME = "pool_3:0"
24
+ FID_SPATIAL_NAME = "mixed_6/conv:0"
25
+
26
+
27
+ def main():
28
+ parser = argparse.ArgumentParser()
29
+ parser.add_argument("--ref_batch", help="path to reference batch npz file")
30
+ parser.add_argument("--sample_batch", help="path to sample batch npz file")
31
+ parser.add_argument("--save_path", help="path to sample batch npz file")
32
+ parser.add_argument("--cfg_cond", default=1, type=int)
33
+ parser.add_argument("--step", default=1, type=int)
34
+ parser.add_argument("--cfg", default=1.0, type=float)
35
+ parser.add_argument("--cls_cfg", default=1.0, type=float)
36
+ parser.add_argument("--gh", default=1.0, type=float)
37
+ parser.add_argument("--num_steps", default=250, type=int)
38
+ args = parser.parse_args()
39
+
40
+ if not os.path.exists(args.save_path):
41
+ os.mkdir(args.save_path)
42
+
43
+
44
+ config = tf.ConfigProto(
45
+ allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph
46
+ )
47
+ config.gpu_options.allow_growth = True
48
+ evaluator = Evaluator(tf.Session(config=config))
49
+
50
+ print("warming up TensorFlow...")
51
+ # This will cause TF to print a bunch of verbose stuff now rather
52
+ # than after the next print(), to help prevent confusion.
53
+ evaluator.warmup()
54
+
55
+ print("computing reference batch activations...")
56
+ ref_acts = evaluator.read_activations(args.ref_batch)
57
+ print("computing/reading reference batch statistics...")
58
+ ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
59
+
60
+ print("computing sample batch activations...")
61
+ sample_acts = evaluator.read_activations(args.sample_batch)
62
+ print("computing/reading sample batch statistics...")
63
+ sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)
64
+
65
+ print("Computing evaluations...")
66
+ Inception_Score = evaluator.compute_inception_score(sample_acts[0])
67
+ FID = sample_stats.frechet_distance(ref_stats)
68
+ sFID = sample_stats_spatial.frechet_distance(ref_stats_spatial)
69
+ prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
70
+
71
+ print("Inception Score:", Inception_Score)
72
+ print("FID:", FID)
73
+ print("sFID:", sFID)
74
+ print("Precision:", prec)
75
+ print("Recall:", recall)
76
+
77
+ if args.cfg_cond:
78
+ file_path = args.save_path + str(args.num_steps) + str(args.step) + str(args.cfg) + str(args.gh) + str(args.cls_cfg)+ "cfg_cond_true.txt"
79
+ else:
80
+ file_path = args.save_path + str(args.num_steps) + str(args.step) + str(args.cfg) + str(args.gh) + str(args.cls_cfg)+ "cfg_cond_false.txt"
81
+ with open(file_path, "w") as file:
82
+ file.write("Inception Score: {}\n".format(Inception_Score))
83
+ file.write("FID: {}\n".format(FID))
84
+ file.write("sFID: {}\n".format(sFID))
85
+ file.write("Precision: {}\n".format(prec))
86
+ file.write("Recall: {}\n".format(recall))
87
+
88
+
89
+ class InvalidFIDException(Exception):
90
+ pass
91
+
92
+
93
+ class FIDStatistics:
94
+ def __init__(self, mu: np.ndarray, sigma: np.ndarray):
95
+ self.mu = mu
96
+ self.sigma = sigma
97
+
98
+ def frechet_distance(self, other, eps=1e-6):
99
+ """
100
+ Compute the Frechet distance between two sets of statistics.
101
+ """
102
+ # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
103
+ mu1, sigma1 = self.mu, self.sigma
104
+ mu2, sigma2 = other.mu, other.sigma
105
+
106
+ mu1 = np.atleast_1d(mu1)
107
+ mu2 = np.atleast_1d(mu2)
108
+
109
+ sigma1 = np.atleast_2d(sigma1)
110
+ sigma2 = np.atleast_2d(sigma2)
111
+
112
+ assert (
113
+ mu1.shape == mu2.shape
114
+ ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
115
+ assert (
116
+ sigma1.shape == sigma2.shape
117
+ ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
118
+
119
+ diff = mu1 - mu2
120
+
121
+ # product might be almost singular
122
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
123
+ if not np.isfinite(covmean).all():
124
+ msg = (
125
+ "fid calculation produces singular product; adding %s to diagonal of cov estimates"
126
+ % eps
127
+ )
128
+ warnings.warn(msg)
129
+ offset = np.eye(sigma1.shape[0]) * eps
130
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
131
+
132
+ # numerical error might give slight imaginary component
133
+ if np.iscomplexobj(covmean):
134
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
135
+ m = np.max(np.abs(covmean.imag))
136
+ raise ValueError("Imaginary component {}".format(m))
137
+ covmean = covmean.real
138
+
139
+ tr_covmean = np.trace(covmean)
140
+
141
+ return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
142
+
143
+
144
+ class Evaluator:
145
+ def __init__(
146
+ self,
147
+ session,
148
+ batch_size=64,
149
+ softmax_batch_size=512,
150
+ ):
151
+ self.sess = session
152
+ self.batch_size = batch_size
153
+ self.softmax_batch_size = softmax_batch_size
154
+ self.manifold_estimator = ManifoldEstimator(session)
155
+ with self.sess.graph.as_default():
156
+ self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
157
+ self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
158
+ self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
159
+ self.softmax = _create_softmax_graph(self.softmax_input)
160
+
161
+ def warmup(self):
162
+ self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
163
+
164
+ def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
165
+ with open_npz_array(npz_path, "arr_0") as reader:
166
+ return self.compute_activations(reader.read_batches(self.batch_size))
167
+
168
+ def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
169
+ """
170
+ Compute image features for downstream evals.
171
+
172
+ :param batches: a iterator over NHWC numpy arrays in [0, 255].
173
+ :return: a tuple of numpy arrays of shape [N x X], where X is a feature
174
+ dimension. The tuple is (pool_3, spatial).
175
+ """
176
+ preds = []
177
+ spatial_preds = []
178
+ for batch in tqdm(batches):
179
+ batch = batch.astype(np.float32)
180
+ pred, spatial_pred = self.sess.run(
181
+ [self.pool_features, self.spatial_features], {self.image_input: batch}
182
+ )
183
+ preds.append(pred.reshape([pred.shape[0], -1]))
184
+ spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
185
+ return (
186
+ np.concatenate(preds, axis=0),
187
+ np.concatenate(spatial_preds, axis=0),
188
+ )
189
+
190
+ def read_statistics(
191
+ self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
192
+ ) -> Tuple[FIDStatistics, FIDStatistics]:
193
+ obj = np.load(npz_path)
194
+ if "mu" in list(obj.keys()):
195
+ return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
196
+ obj["mu_s"], obj["sigma_s"]
197
+ )
198
+ return tuple(self.compute_statistics(x) for x in activations)
199
+
200
+ def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
201
+ mu = np.mean(activations, axis=0)
202
+ sigma = np.cov(activations, rowvar=False)
203
+ return FIDStatistics(mu, sigma)
204
+
205
+ def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
206
+ softmax_out = []
207
+ for i in range(0, len(activations), self.softmax_batch_size):
208
+ acts = activations[i : i + self.softmax_batch_size]
209
+ softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
210
+ preds = np.concatenate(softmax_out, axis=0)
211
+ # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
212
+ scores = []
213
+ for i in range(0, len(preds), split_size):
214
+ part = preds[i : i + split_size]
215
+ kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
216
+ kl = np.mean(np.sum(kl, 1))
217
+ scores.append(np.exp(kl))
218
+ return float(np.mean(scores))
219
+
220
+ def compute_prec_recall(
221
+ self, activations_ref: np.ndarray, activations_sample: np.ndarray
222
+ ) -> Tuple[float, float]:
223
+ radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
224
+ radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
225
+ pr = self.manifold_estimator.evaluate_pr(
226
+ activations_ref, radii_1, activations_sample, radii_2
227
+ )
228
+ return (float(pr[0][0]), float(pr[1][0]))
229
+
230
+
231
+ class ManifoldEstimator:
232
+ """
233
+ A helper for comparing manifolds of feature vectors.
234
+
235
+ Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
236
+ """
237
+
238
+ def __init__(
239
+ self,
240
+ session,
241
+ row_batch_size=10000,
242
+ col_batch_size=10000,
243
+ nhood_sizes=(3,),
244
+ clamp_to_percentile=None,
245
+ eps=1e-5,
246
+ ):
247
+ """
248
+ Estimate the manifold of given feature vectors.
249
+
250
+ :param session: the TensorFlow session.
251
+ :param row_batch_size: row batch size to compute pairwise distances
252
+ (parameter to trade-off between memory usage and performance).
253
+ :param col_batch_size: column batch size to compute pairwise distances.
254
+ :param nhood_sizes: number of neighbors used to estimate the manifold.
255
+ :param clamp_to_percentile: prune hyperspheres that have radius larger than
256
+ the given percentile.
257
+ :param eps: small number for numerical stability.
258
+ """
259
+ self.distance_block = DistanceBlock(session)
260
+ self.row_batch_size = row_batch_size
261
+ self.col_batch_size = col_batch_size
262
+ self.nhood_sizes = nhood_sizes
263
+ self.num_nhoods = len(nhood_sizes)
264
+ self.clamp_to_percentile = clamp_to_percentile
265
+ self.eps = eps
266
+
267
+ def warmup(self):
268
+ feats, radii = (
269
+ np.zeros([1, 2048], dtype=np.float32),
270
+ np.zeros([1, 1], dtype=np.float32),
271
+ )
272
+ self.evaluate_pr(feats, radii, feats, radii)
273
+
274
+ def manifold_radii(self, features: np.ndarray) -> np.ndarray:
275
+ num_images = len(features)
276
+
277
+ # Estimate manifold of features by calculating distances to k-NN of each sample.
278
+ radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
279
+ distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
280
+ seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
281
+
282
+ for begin1 in range(0, num_images, self.row_batch_size):
283
+ end1 = min(begin1 + self.row_batch_size, num_images)
284
+ row_batch = features[begin1:end1]
285
+
286
+ for begin2 in range(0, num_images, self.col_batch_size):
287
+ end2 = min(begin2 + self.col_batch_size, num_images)
288
+ col_batch = features[begin2:end2]
289
+
290
+ # Compute distances between batches.
291
+ distance_batch[
292
+ 0 : end1 - begin1, begin2:end2
293
+ ] = self.distance_block.pairwise_distances(row_batch, col_batch)
294
+
295
+ # Find the k-nearest neighbor from the current batch.
296
+ radii[begin1:end1, :] = np.concatenate(
297
+ [
298
+ x[:, self.nhood_sizes]
299
+ for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
300
+ ],
301
+ axis=0,
302
+ )
303
+
304
+ if self.clamp_to_percentile is not None:
305
+ max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
306
+ radii[radii > max_distances] = 0
307
+ return radii
308
+
309
+ def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
310
+ """
311
+ Evaluate if new feature vectors are at the manifold.
312
+ """
313
+ num_eval_images = eval_features.shape[0]
314
+ num_ref_images = radii.shape[0]
315
+ distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
316
+ batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
317
+ max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
318
+ nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
319
+
320
+ for begin1 in range(0, num_eval_images, self.row_batch_size):
321
+ end1 = min(begin1 + self.row_batch_size, num_eval_images)
322
+ feature_batch = eval_features[begin1:end1]
323
+
324
+ for begin2 in range(0, num_ref_images, self.col_batch_size):
325
+ end2 = min(begin2 + self.col_batch_size, num_ref_images)
326
+ ref_batch = features[begin2:end2]
327
+
328
+ distance_batch[
329
+ 0 : end1 - begin1, begin2:end2
330
+ ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
331
+
332
+ # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
333
+ # If a feature vector is inside a hypersphere of some reference sample, then
334
+ # the new sample lies at the estimated manifold.
335
+ # The radii of the hyperspheres are determined from distances of neighborhood size k.
336
+ samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
337
+ batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
338
+
339
+ max_realism_score[begin1:end1] = np.max(
340
+ radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
341
+ )
342
+ nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)
343
+
344
+ return {
345
+ "fraction": float(np.mean(batch_predictions)),
346
+ "batch_predictions": batch_predictions,
347
+ "max_realisim_score": max_realism_score,
348
+ "nearest_indices": nearest_indices,
349
+ }
350
+
351
+ def evaluate_pr(
352
+ self,
353
+ features_1: np.ndarray,
354
+ radii_1: np.ndarray,
355
+ features_2: np.ndarray,
356
+ radii_2: np.ndarray,
357
+ ) -> Tuple[np.ndarray, np.ndarray]:
358
+ """
359
+ Evaluate precision and recall efficiently.
360
+
361
+ :param features_1: [N1 x D] feature vectors for reference batch.
362
+ :param radii_1: [N1 x K1] radii for reference vectors.
363
+ :param features_2: [N2 x D] feature vectors for the other batch.
364
+ :param radii_2: [N x K2] radii for other vectors.
365
+ :return: a tuple of arrays for (precision, recall):
366
+ - precision: an np.ndarray of length K1
367
+ - recall: an np.ndarray of length K2
368
+ """
369
+ features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool_)
370
+ features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool_)
371
+ for begin_1 in range(0, len(features_1), self.row_batch_size):
372
+ end_1 = begin_1 + self.row_batch_size
373
+ batch_1 = features_1[begin_1:end_1]
374
+ for begin_2 in range(0, len(features_2), self.col_batch_size):
375
+ end_2 = begin_2 + self.col_batch_size
376
+ batch_2 = features_2[begin_2:end_2]
377
+ batch_1_in, batch_2_in = self.distance_block.less_thans(
378
+ batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
379
+ )
380
+ features_1_status[begin_1:end_1] |= batch_1_in
381
+ features_2_status[begin_2:end_2] |= batch_2_in
382
+ return (
383
+ np.mean(features_2_status.astype(np.float64), axis=0),
384
+ np.mean(features_1_status.astype(np.float64), axis=0),
385
+ )
386
+
387
+
388
+ class DistanceBlock:
389
+ """
390
+ Calculate pairwise distances between vectors.
391
+
392
+ Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
393
+ """
394
+
395
+ def __init__(self, session):
396
+ self.session = session
397
+
398
+ # Initialize TF graph to calculate pairwise distances.
399
+ with session.graph.as_default():
400
+ self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
401
+ self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
402
+ distance_block_16 = _batch_pairwise_distances(
403
+ tf.cast(self._features_batch1, tf.float16),
404
+ tf.cast(self._features_batch2, tf.float16),
405
+ )
406
+ self.distance_block = tf.cond(
407
+ tf.reduce_all(tf.math.is_finite(distance_block_16)),
408
+ lambda: tf.cast(distance_block_16, tf.float32),
409
+ lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
410
+ )
411
+
412
+ # Extra logic for less thans.
413
+ self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
414
+ self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
415
+ dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
416
+ self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
417
+ self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)
418
+
419
+ def pairwise_distances(self, U, V):
420
+ """
421
+ Evaluate pairwise distances between two batches of feature vectors.
422
+ """
423
+ return self.session.run(
424
+ self.distance_block,
425
+ feed_dict={self._features_batch1: U, self._features_batch2: V},
426
+ )
427
+
428
+ def less_thans(self, batch_1, radii_1, batch_2, radii_2):
429
+ return self.session.run(
430
+ [self._batch_1_in, self._batch_2_in],
431
+ feed_dict={
432
+ self._features_batch1: batch_1,
433
+ self._features_batch2: batch_2,
434
+ self._radii1: radii_1,
435
+ self._radii2: radii_2,
436
+ },
437
+ )
438
+
439
+
440
+ def _batch_pairwise_distances(U, V):
441
+ """
442
+ Compute pairwise distances between two batches of feature vectors.
443
+ """
444
+ with tf.variable_scope("pairwise_dist_block"):
445
+ # Squared norms of each row in U and V.
446
+ norm_u = tf.reduce_sum(tf.square(U), 1)
447
+ norm_v = tf.reduce_sum(tf.square(V), 1)
448
+
449
+ # norm_u as a column and norm_v as a row vectors.
450
+ norm_u = tf.reshape(norm_u, [-1, 1])
451
+ norm_v = tf.reshape(norm_v, [1, -1])
452
+
453
+ # Pairwise squared Euclidean distances.
454
+ D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
455
+
456
+ return D
457
+
458
+
459
+ class NpzArrayReader(ABC):
460
+ @abstractmethod
461
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
462
+ pass
463
+
464
+ @abstractmethod
465
+ def remaining(self) -> int:
466
+ pass
467
+
468
+ def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
469
+ def gen_fn():
470
+ while True:
471
+ batch = self.read_batch(batch_size)
472
+ if batch is None:
473
+ break
474
+ yield batch
475
+
476
+ rem = self.remaining()
477
+ num_batches = rem // batch_size + int(rem % batch_size != 0)
478
+ return BatchIterator(gen_fn, num_batches)
479
+
480
+
481
+ class BatchIterator:
482
+ def __init__(self, gen_fn, length):
483
+ self.gen_fn = gen_fn
484
+ self.length = length
485
+
486
+ def __len__(self):
487
+ return self.length
488
+
489
+ def __iter__(self):
490
+ return self.gen_fn()
491
+
492
+
493
+ class StreamingNpzArrayReader(NpzArrayReader):
494
+ def __init__(self, arr_f, shape, dtype):
495
+ self.arr_f = arr_f
496
+ self.shape = shape
497
+ self.dtype = dtype
498
+ self.idx = 0
499
+
500
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
501
+ if self.idx >= self.shape[0]:
502
+ return None
503
+
504
+ bs = min(batch_size, self.shape[0] - self.idx)
505
+ self.idx += bs
506
+
507
+ if self.dtype.itemsize == 0:
508
+ return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
509
+
510
+ read_count = bs * np.prod(self.shape[1:])
511
+ read_size = int(read_count * self.dtype.itemsize)
512
+ data = _read_bytes(self.arr_f, read_size, "array data")
513
+ return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
514
+
515
+ def remaining(self) -> int:
516
+ return max(0, self.shape[0] - self.idx)
517
+
518
+
519
+ class MemoryNpzArrayReader(NpzArrayReader):
520
+ def __init__(self, arr):
521
+ self.arr = arr
522
+ self.idx = 0
523
+
524
+ @classmethod
525
+ def load(cls, path: str, arr_name: str):
526
+ with open(path, "rb") as f:
527
+ arr = np.load(f)[arr_name]
528
+ return cls(arr)
529
+
530
+ def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
531
+ if self.idx >= self.arr.shape[0]:
532
+ return None
533
+
534
+ res = self.arr[self.idx : self.idx + batch_size]
535
+ self.idx += batch_size
536
+ return res
537
+
538
+ def remaining(self) -> int:
539
+ return max(0, self.arr.shape[0] - self.idx)
540
+
541
+
542
+ @contextmanager
543
+ def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
544
+ with _open_npy_file(path, arr_name) as arr_f:
545
+ version = np.lib.format.read_magic(arr_f)
546
+ if version == (1, 0):
547
+ header = np.lib.format.read_array_header_1_0(arr_f)
548
+ elif version == (2, 0):
549
+ header = np.lib.format.read_array_header_2_0(arr_f)
550
+ else:
551
+ yield MemoryNpzArrayReader.load(path, arr_name)
552
+ return
553
+ shape, fortran, dtype = header
554
+ if fortran or dtype.hasobject:
555
+ yield MemoryNpzArrayReader.load(path, arr_name)
556
+ else:
557
+ yield StreamingNpzArrayReader(arr_f, shape, dtype)
558
+
559
+
560
+ def _read_bytes(fp, size, error_template="ran out of data"):
561
+ """
562
+ Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
563
+
564
+ Read from file-like object until size bytes are read.
565
+ Raises ValueError if not EOF is encountered before size bytes are read.
566
+ Non-blocking objects only supported if they derive from io objects.
567
+ Required as e.g. ZipExtFile in python 2.6 can return less data than
568
+ requested.
569
+ """
570
+ data = bytes()
571
+ while True:
572
+ # io files (default in python3) return None or raise on
573
+ # would-block, python2 file will truncate, probably nothing can be
574
+ # done about that. note that regular files can't be non-blocking
575
+ try:
576
+ r = fp.read(size - len(data))
577
+ data += r
578
+ if len(r) == 0 or len(data) == size:
579
+ break
580
+ except io.BlockingIOError:
581
+ pass
582
+ if len(data) != size:
583
+ msg = "EOF: reading %s, expected %d bytes got %d"
584
+ raise ValueError(msg % (error_template, size, len(data)))
585
+ else:
586
+ return data
587
+
588
+
589
+ @contextmanager
590
+ def _open_npy_file(path: str, arr_name: str):
591
+ with open(path, "rb") as f:
592
+ with zipfile.ZipFile(f, "r") as zip_f:
593
+ if f"{arr_name}.npy" not in zip_f.namelist():
594
+ raise ValueError(f"missing {arr_name} in npz file")
595
+ with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
596
+ yield arr_f
597
+
598
+
599
+ def _download_inception_model():
600
+ if os.path.exists(INCEPTION_V3_PATH):
601
+ return
602
+ print("downloading InceptionV3 model...")
603
+ with requests.get(INCEPTION_V3_URL, stream=True) as r:
604
+ r.raise_for_status()
605
+ tmp_path = INCEPTION_V3_PATH + ".tmp"
606
+ with open(tmp_path, "wb") as f:
607
+ for chunk in tqdm(r.iter_content(chunk_size=8192)):
608
+ f.write(chunk)
609
+ os.rename(tmp_path, INCEPTION_V3_PATH)
610
+
611
+
612
+ def _create_feature_graph(input_batch):
613
+ _download_inception_model()
614
+ prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
615
+ with open(INCEPTION_V3_PATH, "rb") as f:
616
+ graph_def = tf.GraphDef()
617
+ graph_def.ParseFromString(f.read())
618
+ pool3, spatial = tf.import_graph_def(
619
+ graph_def,
620
+ input_map={f"ExpandDims:0": input_batch},
621
+ return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
622
+ name=prefix,
623
+ )
624
+ _update_shapes(pool3)
625
+ spatial = spatial[..., :7]
626
+ return pool3, spatial
627
+
628
+
629
+ def _create_softmax_graph(input_batch):
630
+ _download_inception_model()
631
+ prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
632
+ with open(INCEPTION_V3_PATH, "rb") as f:
633
+ graph_def = tf.GraphDef()
634
+ graph_def.ParseFromString(f.read())
635
+ (matmul,) = tf.import_graph_def(
636
+ graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
637
+ )
638
+ w = matmul.inputs[1]
639
+ logits = tf.matmul(input_batch, w)
640
+ return tf.nn.softmax(logits)
641
+
642
+
643
+ def _update_shapes(pool3):
644
+ # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
645
+ ops = pool3.graph.get_operations()
646
+ for op in ops:
647
+ for o in op.outputs:
648
+ shape = o.get_shape()
649
+ if shape._dims is not None: # pylint: disable=protected-access
650
+ # shape = [s.value for s in shape] TF 1.x
651
+ shape = [s for s in shape] # TF 2.x
652
+ new_shape = []
653
+ for j, s in enumerate(shape):
654
+ if s == 1 and j == 0:
655
+ new_shape.append(None)
656
+ else:
657
+ new_shape.append(s)
658
+ o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
659
+ return pool3
660
+
661
+
662
+ def _numpy_partition(arr, kth, **kwargs):
663
+ num_workers = min(cpu_count(), len(arr))
664
+ chunk_size = len(arr) // num_workers
665
+ extra = len(arr) % num_workers
666
+
667
+ start_idx = 0
668
+ batches = []
669
+ for i in range(num_workers):
670
+ size = chunk_size + (1 if i < extra else 0)
671
+ batches.append(arr[start_idx : start_idx + size])
672
+ start_idx += size
673
+
674
+ with ThreadPool(num_workers) as pool:
675
+ return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
676
+
677
+
678
+ if __name__ == "__main__":
679
+ main()
New/REG/evaluations/requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ tensorflow-gpu>=2.0
2
+ scipy
3
+ requests
4
+ tqdm
New/REG/generate_2400000_100.sh ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # ODE 采样(euler_sampler,与 samplers.py 中 REG ODE 一致):分别用 5 / 10 / 50 步,输出到不同子目录。
3
+ set -euo pipefail
4
+
5
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
6
+ cd "$SCRIPT_DIR"
7
+
8
+ NUM_GPUS=1
9
+ NUM_SAMPLES=100
10
+ PER_PROC_BATCH_SIZE=10
11
+
12
+ CFG_SCALE=1
13
+ CLS_CFG_SCALE=1
14
+ GH=0.85
15
+
16
+ CKPT_PATH="/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/REG-XLarge-256/REG-XLarge-256/2400000.pt"
17
+ #CKPT_PATH="/your_path/reg_xlarge_dinov2_base_align_8_cls/linear-dinov2-b-enc8/checkpoints/0040000.pt"
18
+
19
+ # 根输出目录(可用环境变量覆盖);其下会生成 ode_steps_5 / ode_steps_10 / ode_steps_50
20
+ OUTPUT_BASE="${OUTPUT_BASE:-./reg_xl_2400000_ode_multistep}"
21
+ NOISE_BANK_PATH="${NOISE_BANK_PATH:-${OUTPUT_BASE}/noise_bank.pt}"
22
+
23
+ export NCCL_P2P_DISABLE=1
24
+ export OUTPUT_BASE
25
+
26
+ GLOBAL_BATCH_SIZE=$((PER_PROC_BATCH_SIZE * NUM_GPUS))
27
+ TOTAL_SAMPLES=$((((NUM_SAMPLES + GLOBAL_BATCH_SIZE - 1) / GLOBAL_BATCH_SIZE) * GLOBAL_BATCH_SIZE))
28
+ export NOISE_BANK_PATH TOTAL_SAMPLES
29
+
30
+ REBUILD_NOISE_BANK=0
31
+ if [[ ! -f "${NOISE_BANK_PATH}" ]]; then
32
+ REBUILD_NOISE_BANK=1
33
+ else
34
+ EXISTING_SAMPLES=$(python - <<'PY'
35
+ import os
36
+ import torch
37
+ p = os.environ["NOISE_BANK_PATH"]
38
+ try:
39
+ d = torch.load(p, map_location="cpu", weights_only=True)
40
+ except TypeError:
41
+ d = torch.load(p, map_location="cpu")
42
+ print(int(d["z"].shape[0]))
43
+ PY
44
+ )
45
+ if [[ "${EXISTING_SAMPLES}" != "${TOTAL_SAMPLES}" ]]; then
46
+ echo "noise_bank size mismatch: existing=${EXISTING_SAMPLES}, required=${TOTAL_SAMPLES}. Rebuilding..."
47
+ REBUILD_NOISE_BANK=1
48
+ fi
49
+ fi
50
+
51
+ if [[ "${REBUILD_NOISE_BANK}" -eq 1 ]]; then
52
+ mkdir -p "$(dirname "${NOISE_BANK_PATH}")"
53
+ echo "Creating fixed noise bank: ${NOISE_BANK_PATH} (total_samples=${TOTAL_SAMPLES})"
54
+ python - <<'PY'
55
+ import os
56
+ import torch
57
+
58
+ noise_path = os.environ["NOISE_BANK_PATH"]
59
+ total_samples = int(os.environ["TOTAL_SAMPLES"])
60
+ num_classes = 1000
61
+ in_channels = 4
62
+ latent_size = 32 # 256 / 8
63
+ cls_dim = 768
64
+
65
+ z = torch.randn(total_samples, in_channels, latent_size, latent_size)
66
+ y = torch.randint(0, num_classes, (total_samples,), dtype=torch.long)
67
+ cls_z = torch.randn(total_samples, cls_dim)
68
+
69
+ torch.save({"z": z, "y": y, "cls_z": cls_z}, noise_path)
70
+ print(f"Saved fixed noise bank to {noise_path}")
71
+ PY
72
+ fi
73
+
74
+ for NUM_STEP in 5 10 20 50 100; do
75
+ RANDOM_PORT=$((RANDOM % 100 + 1200))
76
+ OUTPUT_DIR="${OUTPUT_BASE}/ode_steps_${NUM_STEP}/checkpoints"
77
+ echo "=== ODE num_steps=${NUM_STEP} -> ${OUTPUT_DIR} ==="
78
+
79
+ python -m torch.distributed.launch \
80
+ --master_port="${RANDOM_PORT}" \
81
+ --nproc_per_node="${NUM_GPUS}" \
82
+ generate.py \
83
+ --model "SiT-XL/2" \
84
+ --num-fid-samples "${NUM_SAMPLES}" \
85
+ --ckpt "${CKPT_PATH}" \
86
+ --path-type=linear \
87
+ --encoder-depth=8 \
88
+ --projector-embed-dims=768 \
89
+ --per-proc-batch-size="${PER_PROC_BATCH_SIZE}" \
90
+ --mode=ode \
91
+ --num-steps="${NUM_STEP}" \
92
+ --fixed-noise-file "${NOISE_BANK_PATH}" \
93
+ --cfg-scale="${CFG_SCALE}" \
94
+ --cls-cfg-scale="${CLS_CFG_SCALE}" \
95
+ --guidance-high="${GH}" \
96
+ --sample-dir "${OUTPUT_DIR}" \
97
+ --cls=768
98
+ done
99
+
100
+ echo "All ODE runs finished. Building side-by-side pairs..."
101
+
102
+ python - <<'PY'
103
+ import os
104
+ from PIL import Image
105
+
106
+ output_base = os.environ.get("OUTPUT_BASE", "./reg_xl_2400000_ode_multistep")
107
+ steps = [5, 10, 50]
108
+
109
+ def find_single_subdir(step):
110
+ root = os.path.join(output_base, f"ode_steps_{step}", "checkpoints")
111
+ if not os.path.isdir(root):
112
+ raise RuntimeError(f"Missing directory: {root}")
113
+ subdirs = [d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))]
114
+ if len(subdirs) != 1:
115
+ raise RuntimeError(f"Expected exactly 1 run folder under {root}, got {len(subdirs)}: {subdirs}")
116
+ return os.path.join(root, subdirs[0])
117
+
118
+ run_dirs = {s: find_single_subdir(s) for s in steps}
119
+ pair_dir = os.path.join(output_base, "pair")
120
+ os.makedirs(pair_dir, exist_ok=True)
121
+
122
+ files = sorted([f for f in os.listdir(run_dirs[steps[0]]) if f.endswith(".png")])
123
+ for name in files:
124
+ imgs = [Image.open(os.path.join(run_dirs[s], name)).convert("RGB") for s in steps]
125
+ w = sum(im.width for im in imgs)
126
+ h = max(im.height for im in imgs)
127
+ canvas = Image.new("RGB", (w, h))
128
+ x = 0
129
+ for im in imgs:
130
+ canvas.paste(im, (x, 0))
131
+ x += im.width
132
+ canvas.save(os.path.join(pair_dir, name))
133
+
134
+ print(f"Saved paired comparisons to: {pair_dir}")
135
+ PY
136
+
137
+ echo "Done. Outputs under: ${OUTPUT_BASE}/ode_steps_{5,10,50}/checkpoints ; pairs under: ${OUTPUT_BASE}/pair"
New/REG/models/__pycache__/mocov3_vit.cpython-310.pyc ADDED
Binary file (6.5 kB). View file
 
New/REG/models/__pycache__/mocov3_vit.cpython-312.pyc ADDED
Binary file (12.4 kB). View file
 
New/REG/models/__pycache__/sit.cpython-310.pyc ADDED
Binary file (13.4 kB). View file
 
New/REG/models/__pycache__/sit.cpython-312.pyc ADDED
Binary file (22.9 kB). View file
 
New/REG/models/clip_vit.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+ import clip
10
+
11
+
12
+ class Bottleneck(nn.Module):
13
+ expansion = 4
14
+
15
+ def __init__(self, inplanes, planes, stride=1):
16
+ super().__init__()
17
+
18
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
19
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
20
+ self.bn1 = nn.BatchNorm2d(planes)
21
+ self.relu1 = nn.ReLU(inplace=True)
22
+
23
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
24
+ self.bn2 = nn.BatchNorm2d(planes)
25
+ self.relu2 = nn.ReLU(inplace=True)
26
+
27
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
28
+
29
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
30
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
31
+ self.relu3 = nn.ReLU(inplace=True)
32
+
33
+ self.downsample = None
34
+ self.stride = stride
35
+
36
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
37
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
38
+ self.downsample = nn.Sequential(OrderedDict([
39
+ ("-1", nn.AvgPool2d(stride)),
40
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
41
+ ("1", nn.BatchNorm2d(planes * self.expansion))
42
+ ]))
43
+
44
+ def forward(self, x: torch.Tensor):
45
+ identity = x
46
+
47
+ out = self.relu1(self.bn1(self.conv1(x)))
48
+ out = self.relu2(self.bn2(self.conv2(out)))
49
+ out = self.avgpool(out)
50
+ out = self.bn3(self.conv3(out))
51
+
52
+ if self.downsample is not None:
53
+ identity = self.downsample(x)
54
+
55
+ out += identity
56
+ out = self.relu3(out)
57
+ return out
58
+
59
+
60
+ class AttentionPool2d(nn.Module):
61
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
62
+ super().__init__()
63
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
64
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
66
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
67
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
68
+ self.num_heads = num_heads
69
+
70
+ def forward(self, x):
71
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
72
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
73
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
74
+ x, _ = F.multi_head_attention_forward(
75
+ query=x[:1], key=x, value=x,
76
+ embed_dim_to_check=x.shape[-1],
77
+ num_heads=self.num_heads,
78
+ q_proj_weight=self.q_proj.weight,
79
+ k_proj_weight=self.k_proj.weight,
80
+ v_proj_weight=self.v_proj.weight,
81
+ in_proj_weight=None,
82
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
83
+ bias_k=None,
84
+ bias_v=None,
85
+ add_zero_attn=False,
86
+ dropout_p=0,
87
+ out_proj_weight=self.c_proj.weight,
88
+ out_proj_bias=self.c_proj.bias,
89
+ use_separate_proj_weight=True,
90
+ training=self.training,
91
+ need_weights=False
92
+ )
93
+ return x.squeeze(0)
94
+
95
+
96
+ class ModifiedResNet(nn.Module):
97
+ """
98
+ A ResNet class that is similar to torchvision's but contains the following changes:
99
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
100
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
101
+ - The final pooling layer is a QKV attention instead of an average pool
102
+ """
103
+
104
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
105
+ super().__init__()
106
+ self.output_dim = output_dim
107
+ self.input_resolution = input_resolution
108
+
109
+ # the 3-layer stem
110
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
111
+ self.bn1 = nn.BatchNorm2d(width // 2)
112
+ self.relu1 = nn.ReLU(inplace=True)
113
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
114
+ self.bn2 = nn.BatchNorm2d(width // 2)
115
+ self.relu2 = nn.ReLU(inplace=True)
116
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
117
+ self.bn3 = nn.BatchNorm2d(width)
118
+ self.relu3 = nn.ReLU(inplace=True)
119
+ self.avgpool = nn.AvgPool2d(2)
120
+
121
+ # residual layers
122
+ self._inplanes = width # this is a *mutable* variable used during construction
123
+ self.layer1 = self._make_layer(width, layers[0])
124
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
125
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
126
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
127
+
128
+ embed_dim = width * 32 # the ResNet feature dimension
129
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
130
+
131
+ def _make_layer(self, planes, blocks, stride=1):
132
+ layers = [Bottleneck(self._inplanes, planes, stride)]
133
+
134
+ self._inplanes = planes * Bottleneck.expansion
135
+ for _ in range(1, blocks):
136
+ layers.append(Bottleneck(self._inplanes, planes))
137
+
138
+ return nn.Sequential(*layers)
139
+
140
+ def forward(self, x):
141
+ def stem(x):
142
+ x = self.relu1(self.bn1(self.conv1(x)))
143
+ x = self.relu2(self.bn2(self.conv2(x)))
144
+ x = self.relu3(self.bn3(self.conv3(x)))
145
+ x = self.avgpool(x)
146
+ return x
147
+
148
+ x = x.type(self.conv1.weight.dtype)
149
+ x = stem(x)
150
+ x = self.layer1(x)
151
+ x = self.layer2(x)
152
+ x = self.layer3(x)
153
+ x = self.layer4(x)
154
+ x = self.attnpool(x)
155
+
156
+ return x
157
+
158
+
159
+ class LayerNorm(nn.LayerNorm):
160
+ """Subclass torch's LayerNorm to handle fp16."""
161
+
162
+ def forward(self, x: torch.Tensor):
163
+ orig_type = x.dtype
164
+ ret = super().forward(x.type(torch.float32))
165
+ return ret.type(orig_type)
166
+
167
+
168
+ class QuickGELU(nn.Module):
169
+ def forward(self, x: torch.Tensor):
170
+ return x * torch.sigmoid(1.702 * x)
171
+
172
+
173
+ class ResidualAttentionBlock(nn.Module):
174
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
175
+ super().__init__()
176
+
177
+ self.attn = nn.MultiheadAttention(d_model, n_head)
178
+ self.ln_1 = LayerNorm(d_model)
179
+ self.mlp = nn.Sequential(OrderedDict([
180
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
181
+ ("gelu", QuickGELU()),
182
+ ("c_proj", nn.Linear(d_model * 4, d_model))
183
+ ]))
184
+ self.ln_2 = LayerNorm(d_model)
185
+ self.attn_mask = attn_mask
186
+
187
+ def attention(self, x: torch.Tensor):
188
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
189
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
190
+
191
+ def forward(self, x: torch.Tensor):
192
+ x = x + self.attention(self.ln_1(x))
193
+ x = x + self.mlp(self.ln_2(x))
194
+ return x
195
+
196
+
197
+ class Transformer(nn.Module):
198
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
199
+ super().__init__()
200
+ self.width = width
201
+ self.layers = layers
202
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
203
+
204
+ def forward(self, x: torch.Tensor):
205
+ return self.resblocks(x)
206
+
207
+
208
+ class UpdatedVisionTransformer(nn.Module):
209
+ def __init__(self, model):
210
+ super().__init__()
211
+ self.model = model
212
+
213
+ def forward(self, x: torch.Tensor):
214
+ x = self.model.conv1(x) # shape = [*, width, grid, grid]
215
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
216
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
217
+ x = torch.cat([self.model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
218
+ x = x + self.model.positional_embedding.to(x.dtype)
219
+ x = self.model.ln_pre(x)
220
+
221
+ x = x.permute(1, 0, 2) # NLD -> LND
222
+ x = self.model.transformer(x)
223
+ x = x.permute(1, 0, 2)[:, 1:] # LND -> NLD
224
+
225
+ # x = self.ln_post(x[:, 0, :])
226
+
227
+ # if self.proj is not None:
228
+ # x = x @ self.proj
229
+
230
+ return x
231
+
232
+
233
+ class CLIP(nn.Module):
234
+ def __init__(self,
235
+ embed_dim: int,
236
+ # vision
237
+ image_resolution: int,
238
+ vision_layers: Union[Tuple[int, int, int, int], int],
239
+ vision_width: int,
240
+ vision_patch_size: int,
241
+ # text
242
+ context_length: int,
243
+ vocab_size: int,
244
+ transformer_width: int,
245
+ transformer_heads: int,
246
+ transformer_layers: int
247
+ ):
248
+ super().__init__()
249
+
250
+ self.context_length = context_length
251
+
252
+ if isinstance(vision_layers, (tuple, list)):
253
+ vision_heads = vision_width * 32 // 64
254
+ self.visual = ModifiedResNet(
255
+ layers=vision_layers,
256
+ output_dim=embed_dim,
257
+ heads=vision_heads,
258
+ input_resolution=image_resolution,
259
+ width=vision_width
260
+ )
261
+ else:
262
+ vision_heads = vision_width // 64
263
+ self.visual = UpdatedVisionTransformer(
264
+ input_resolution=image_resolution,
265
+ patch_size=vision_patch_size,
266
+ width=vision_width,
267
+ layers=vision_layers,
268
+ heads=vision_heads,
269
+ output_dim=embed_dim
270
+ )
271
+
272
+ self.transformer = Transformer(
273
+ width=transformer_width,
274
+ layers=transformer_layers,
275
+ heads=transformer_heads,
276
+ attn_mask=self.build_attention_mask()
277
+ )
278
+
279
+ self.vocab_size = vocab_size
280
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
281
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
282
+ self.ln_final = LayerNorm(transformer_width)
283
+
284
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
285
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
286
+
287
+ self.initialize_parameters()
288
+
289
+ def initialize_parameters(self):
290
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
291
+ nn.init.normal_(self.positional_embedding, std=0.01)
292
+
293
+ if isinstance(self.visual, ModifiedResNet):
294
+ if self.visual.attnpool is not None:
295
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
296
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
297
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
298
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
299
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
300
+
301
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
302
+ for name, param in resnet_block.named_parameters():
303
+ if name.endswith("bn3.weight"):
304
+ nn.init.zeros_(param)
305
+
306
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
307
+ attn_std = self.transformer.width ** -0.5
308
+ fc_std = (2 * self.transformer.width) ** -0.5
309
+ for block in self.transformer.resblocks:
310
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
311
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
312
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
313
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
314
+
315
+ if self.text_projection is not None:
316
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
317
+
318
+ def build_attention_mask(self):
319
+ # lazily create causal attention mask, with full attention between the vision tokens
320
+ # pytorch uses additive attention mask; fill with -inf
321
+ mask = torch.empty(self.context_length, self.context_length)
322
+ mask.fill_(float("-inf"))
323
+ mask.triu_(1) # zero out the lower diagonal
324
+ return mask
325
+
326
+ @property
327
+ def dtype(self):
328
+ return self.visual.conv1.weight.dtype
329
+
330
+ def encode_image(self, image):
331
+ return self.visual(image.type(self.dtype))
332
+
333
+ def encode_text(self, text):
334
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
335
+
336
+ x = x + self.positional_embedding.type(self.dtype)
337
+ x = x.permute(1, 0, 2) # NLD -> LND
338
+ x = self.transformer(x)
339
+ x = x.permute(1, 0, 2) # LND -> NLD
340
+ x = self.ln_final(x).type(self.dtype)
341
+
342
+ # x.shape = [batch_size, n_ctx, transformer.width]
343
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
344
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
345
+
346
+ return x
347
+
348
+ def forward(self, image, text):
349
+ image_features = self.encode_image(image)
350
+ text_features = self.encode_text(text)
351
+
352
+ # normalized features
353
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
354
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
355
+
356
+ # cosine similarity as logits
357
+ logit_scale = self.logit_scale.exp()
358
+ logits_per_image = logit_scale * image_features @ text_features.t()
359
+ logits_per_text = logits_per_image.t()
360
+
361
+ # shape = [global_batch_size, global_batch_size]
362
+ return logits_per_image, logits_per_text
363
+
364
+
365
+ def convert_weights(model: nn.Module):
366
+ """Convert applicable model parameters to fp16"""
367
+
368
+ def _convert_weights_to_fp16(l):
369
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
370
+ l.weight.data = l.weight.data.half()
371
+ if l.bias is not None:
372
+ l.bias.data = l.bias.data.half()
373
+
374
+ if isinstance(l, nn.MultiheadAttention):
375
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
376
+ tensor = getattr(l, attr)
377
+ if tensor is not None:
378
+ tensor.data = tensor.data.half()
379
+
380
+ for name in ["text_projection", "proj"]:
381
+ if hasattr(l, name):
382
+ attr = getattr(l, name)
383
+ if attr is not None:
384
+ attr.data = attr.data.half()
385
+
386
+ model.apply(_convert_weights_to_fp16)
387
+
388
+
389
+ def build_model(state_dict: dict):
390
+ vit = "visual.proj" in state_dict
391
+
392
+ if vit:
393
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
394
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
395
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
396
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
397
+ image_resolution = vision_patch_size * grid_size
398
+ else:
399
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
400
+ vision_layers = tuple(counts)
401
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
402
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
403
+ vision_patch_size = None
404
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
405
+ image_resolution = output_width * 32
406
+
407
+ embed_dim = state_dict["text_projection"].shape[1]
408
+ context_length = state_dict["positional_embedding"].shape[0]
409
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
410
+ transformer_width = state_dict["ln_final.weight"].shape[0]
411
+ transformer_heads = transformer_width // 64
412
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
413
+
414
+ model = CLIP(
415
+ embed_dim,
416
+ image_resolution, vision_layers, vision_width, vision_patch_size,
417
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
418
+ )
419
+
420
+ for key in ["input_resolution", "context_length", "vocab_size"]:
421
+ if key in state_dict:
422
+ del state_dict[key]
423
+
424
+ convert_weights(model)
425
+ model.load_state_dict(state_dict)
426
+ return model.eval()
New/REG/models/jepa.py ADDED
@@ -0,0 +1,547 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+
8
+ import math
9
+ from functools import partial
10
+ import numpy as np
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+
15
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
16
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
17
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
18
+ def norm_cdf(x):
19
+ # Computes standard normal cumulative distribution function
20
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
21
+
22
+ with torch.no_grad():
23
+ # Values are generated by using a truncated uniform distribution and
24
+ # then using the inverse CDF for the normal distribution.
25
+ # Get upper and lower cdf values
26
+ l = norm_cdf((a - mean) / std)
27
+ u = norm_cdf((b - mean) / std)
28
+
29
+ # Uniformly fill tensor with values from [l, u], then translate to
30
+ # [2l-1, 2u-1].
31
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
32
+
33
+ # Use inverse cdf transform for normal distribution to get truncated
34
+ # standard normal
35
+ tensor.erfinv_()
36
+
37
+ # Transform to proper mean, std
38
+ tensor.mul_(std * math.sqrt(2.))
39
+ tensor.add_(mean)
40
+
41
+ # Clamp to ensure it's in the proper range
42
+ tensor.clamp_(min=a, max=b)
43
+ return tensor
44
+
45
+
46
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
47
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
48
+
49
+
50
+ def repeat_interleave_batch(x, B, repeat):
51
+ N = len(x) // B
52
+ x = torch.cat([
53
+ torch.cat([x[i*B:(i+1)*B] for _ in range(repeat)], dim=0)
54
+ for i in range(N)
55
+ ], dim=0)
56
+ return x
57
+
58
+ def apply_masks(x, masks):
59
+ """
60
+ :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
61
+ :param masks: list of tensors containing indices of patches in [N] to keep
62
+ """
63
+ all_x = []
64
+ for m in masks:
65
+ mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
66
+ all_x += [torch.gather(x, dim=1, index=mask_keep)]
67
+ return torch.cat(all_x, dim=0)
68
+
69
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
70
+ """
71
+ grid_size: int of the grid height and width
72
+ return:
73
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
74
+ """
75
+ grid_h = np.arange(grid_size, dtype=float)
76
+ grid_w = np.arange(grid_size, dtype=float)
77
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
78
+ grid = np.stack(grid, axis=0)
79
+
80
+ grid = grid.reshape([2, 1, grid_size, grid_size])
81
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
82
+ if cls_token:
83
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
84
+ return pos_embed
85
+
86
+
87
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
88
+ assert embed_dim % 2 == 0
89
+
90
+ # use half of dimensions to encode grid_h
91
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
92
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
93
+
94
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
95
+ return emb
96
+
97
+
98
+ def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
99
+ """
100
+ grid_size: int of the grid length
101
+ return:
102
+ pos_embed: [grid_size, embed_dim] or [1+grid_size, embed_dim] (w/ or w/o cls_token)
103
+ """
104
+ grid = np.arange(grid_size, dtype=float)
105
+ pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
106
+ if cls_token:
107
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
108
+ return pos_embed
109
+
110
+
111
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
112
+ """
113
+ embed_dim: output dimension for each position
114
+ pos: a list of positions to be encoded: size (M,)
115
+ out: (M, D)
116
+ """
117
+ assert embed_dim % 2 == 0
118
+ omega = np.arange(embed_dim // 2, dtype=float)
119
+ omega /= embed_dim / 2.
120
+ omega = 1. / 10000**omega # (D/2,)
121
+
122
+ pos = pos.reshape(-1) # (M,)
123
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
124
+
125
+ emb_sin = np.sin(out) # (M, D/2)
126
+ emb_cos = np.cos(out) # (M, D/2)
127
+
128
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
129
+ return emb
130
+
131
+
132
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
133
+ if drop_prob == 0. or not training:
134
+ return x
135
+ keep_prob = 1 - drop_prob
136
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
137
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
138
+ random_tensor.floor_() # binarize
139
+ output = x.div(keep_prob) * random_tensor
140
+ return output
141
+
142
+
143
+ class DropPath(nn.Module):
144
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
145
+ """
146
+ def __init__(self, drop_prob=None):
147
+ super(DropPath, self).__init__()
148
+ self.drop_prob = drop_prob
149
+
150
+ def forward(self, x):
151
+ return drop_path(x, self.drop_prob, self.training)
152
+
153
+
154
+ class MLP(nn.Module):
155
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
156
+ super().__init__()
157
+ out_features = out_features or in_features
158
+ hidden_features = hidden_features or in_features
159
+ self.fc1 = nn.Linear(in_features, hidden_features)
160
+ self.act = act_layer()
161
+ self.fc2 = nn.Linear(hidden_features, out_features)
162
+ self.drop = nn.Dropout(drop)
163
+
164
+ def forward(self, x):
165
+ x = self.fc1(x)
166
+ x = self.act(x)
167
+ x = self.drop(x)
168
+ x = self.fc2(x)
169
+ x = self.drop(x)
170
+ return x
171
+
172
+
173
+ class Attention(nn.Module):
174
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
175
+ super().__init__()
176
+ self.num_heads = num_heads
177
+ head_dim = dim // num_heads
178
+ self.scale = qk_scale or head_dim ** -0.5
179
+
180
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
181
+ self.attn_drop = nn.Dropout(attn_drop)
182
+ self.proj = nn.Linear(dim, dim)
183
+ self.proj_drop = nn.Dropout(proj_drop)
184
+
185
+ def forward(self, x):
186
+ B, N, C = x.shape
187
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
188
+ q, k, v = qkv[0], qkv[1], qkv[2]
189
+
190
+ attn = (q @ k.transpose(-2, -1)) * self.scale
191
+ attn = attn.softmax(dim=-1)
192
+ attn = self.attn_drop(attn)
193
+
194
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
195
+ x = self.proj(x)
196
+ x = self.proj_drop(x)
197
+ return x, attn
198
+
199
+
200
+ class Block(nn.Module):
201
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
202
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
203
+ super().__init__()
204
+ self.norm1 = norm_layer(dim)
205
+ self.attn = Attention(
206
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
207
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
208
+ self.norm2 = norm_layer(dim)
209
+ mlp_hidden_dim = int(dim * mlp_ratio)
210
+ self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
211
+
212
+ def forward(self, x, return_attention=False):
213
+ y, attn = self.attn(self.norm1(x))
214
+ if return_attention:
215
+ return attn
216
+ x = x + self.drop_path(y)
217
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
218
+ return x
219
+
220
+
221
+ class PatchEmbed(nn.Module):
222
+ """ Image to Patch Embedding
223
+ """
224
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
225
+ super().__init__()
226
+ num_patches = (img_size // patch_size) * (img_size // patch_size)
227
+ self.img_size = img_size
228
+ self.patch_size = patch_size
229
+ self.num_patches = num_patches
230
+
231
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
232
+
233
+ def forward(self, x):
234
+ B, C, H, W = x.shape
235
+ x = self.proj(x).flatten(2).transpose(1, 2)
236
+ return x
237
+
238
+
239
+ class ConvEmbed(nn.Module):
240
+ """
241
+ 3x3 Convolution stems for ViT following ViTC models
242
+ """
243
+
244
+ def __init__(self, channels, strides, img_size=224, in_chans=3, batch_norm=True):
245
+ super().__init__()
246
+ # Build the stems
247
+ stem = []
248
+ channels = [in_chans] + channels
249
+ for i in range(len(channels) - 2):
250
+ stem += [nn.Conv2d(channels[i], channels[i+1], kernel_size=3,
251
+ stride=strides[i], padding=1, bias=(not batch_norm))]
252
+ if batch_norm:
253
+ stem += [nn.BatchNorm2d(channels[i+1])]
254
+ stem += [nn.ReLU(inplace=True)]
255
+ stem += [nn.Conv2d(channels[-2], channels[-1], kernel_size=1, stride=strides[-1])]
256
+ self.stem = nn.Sequential(*stem)
257
+
258
+ # Comptute the number of patches
259
+ stride_prod = int(np.prod(strides))
260
+ self.num_patches = (img_size[0] // stride_prod)**2
261
+
262
+ def forward(self, x):
263
+ p = self.stem(x)
264
+ return p.flatten(2).transpose(1, 2)
265
+
266
+
267
+ class VisionTransformerPredictor(nn.Module):
268
+ """ Vision Transformer """
269
+ def __init__(
270
+ self,
271
+ num_patches,
272
+ embed_dim=768,
273
+ predictor_embed_dim=384,
274
+ depth=6,
275
+ num_heads=12,
276
+ mlp_ratio=4.0,
277
+ qkv_bias=True,
278
+ qk_scale=None,
279
+ drop_rate=0.0,
280
+ attn_drop_rate=0.0,
281
+ drop_path_rate=0.0,
282
+ norm_layer=nn.LayerNorm,
283
+ init_std=0.02,
284
+ **kwargs
285
+ ):
286
+ super().__init__()
287
+ self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
288
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
289
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
290
+ # --
291
+ self.predictor_pos_embed = nn.Parameter(torch.zeros(1, num_patches, predictor_embed_dim),
292
+ requires_grad=False)
293
+ predictor_pos_embed = get_2d_sincos_pos_embed(self.predictor_pos_embed.shape[-1],
294
+ int(num_patches**.5),
295
+ cls_token=False)
296
+ self.predictor_pos_embed.data.copy_(torch.from_numpy(predictor_pos_embed).float().unsqueeze(0))
297
+ # --
298
+ self.predictor_blocks = nn.ModuleList([
299
+ Block(
300
+ dim=predictor_embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
301
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
302
+ for i in range(depth)])
303
+ self.predictor_norm = norm_layer(predictor_embed_dim)
304
+ self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
305
+ # ------
306
+ self.init_std = init_std
307
+ trunc_normal_(self.mask_token, std=self.init_std)
308
+ self.apply(self._init_weights)
309
+ self.fix_init_weight()
310
+
311
+ def fix_init_weight(self):
312
+ def rescale(param, layer_id):
313
+ param.div_(math.sqrt(2.0 * layer_id))
314
+
315
+ for layer_id, layer in enumerate(self.predictor_blocks):
316
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
317
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
318
+
319
+ def _init_weights(self, m):
320
+ if isinstance(m, nn.Linear):
321
+ trunc_normal_(m.weight, std=self.init_std)
322
+ if isinstance(m, nn.Linear) and m.bias is not None:
323
+ nn.init.constant_(m.bias, 0)
324
+ elif isinstance(m, nn.LayerNorm):
325
+ nn.init.constant_(m.bias, 0)
326
+ nn.init.constant_(m.weight, 1.0)
327
+ elif isinstance(m, nn.Conv2d):
328
+ trunc_normal_(m.weight, std=self.init_std)
329
+ if m.bias is not None:
330
+ nn.init.constant_(m.bias, 0)
331
+
332
+ def forward(self, x, masks_x, masks):
333
+ assert (masks is not None) and (masks_x is not None), 'Cannot run predictor without mask indices'
334
+
335
+ if not isinstance(masks_x, list):
336
+ masks_x = [masks_x]
337
+
338
+ if not isinstance(masks, list):
339
+ masks = [masks]
340
+
341
+ # -- Batch Size
342
+ B = len(x) // len(masks_x)
343
+
344
+ # -- map from encoder-dim to pedictor-dim
345
+ x = self.predictor_embed(x)
346
+
347
+ # -- add positional embedding to x tokens
348
+ x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1)
349
+ x += apply_masks(x_pos_embed, masks_x)
350
+
351
+ _, N_ctxt, D = x.shape
352
+
353
+ # -- concat mask tokens to x
354
+ pos_embs = self.predictor_pos_embed.repeat(B, 1, 1)
355
+ pos_embs = apply_masks(pos_embs, masks)
356
+ pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_x))
357
+ # --
358
+ pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1)
359
+ # --
360
+ pred_tokens += pos_embs
361
+ x = x.repeat(len(masks), 1, 1)
362
+ x = torch.cat([x, pred_tokens], dim=1)
363
+
364
+ # -- fwd prop
365
+ for blk in self.predictor_blocks:
366
+ x = blk(x)
367
+ x = self.predictor_norm(x)
368
+
369
+ # -- return preds for mask tokens
370
+ x = x[:, N_ctxt:]
371
+ x = self.predictor_proj(x)
372
+
373
+ return x
374
+
375
+
376
+ class VisionTransformer(nn.Module):
377
+ """ Vision Transformer """
378
+ def __init__(
379
+ self,
380
+ img_size=[224],
381
+ patch_size=16,
382
+ in_chans=3,
383
+ embed_dim=768,
384
+ predictor_embed_dim=384,
385
+ depth=12,
386
+ predictor_depth=12,
387
+ num_heads=12,
388
+ mlp_ratio=4.0,
389
+ qkv_bias=True,
390
+ qk_scale=None,
391
+ drop_rate=0.0,
392
+ attn_drop_rate=0.0,
393
+ drop_path_rate=0.0,
394
+ norm_layer=nn.LayerNorm,
395
+ init_std=0.02,
396
+ **kwargs
397
+ ):
398
+ super().__init__()
399
+ self.num_features = self.embed_dim = embed_dim
400
+ self.num_heads = num_heads
401
+ # --
402
+ self.patch_embed = PatchEmbed(
403
+ img_size=img_size[0],
404
+ patch_size=patch_size,
405
+ in_chans=in_chans,
406
+ embed_dim=embed_dim)
407
+ num_patches = self.patch_embed.num_patches
408
+ # --
409
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False)
410
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
411
+ int(self.patch_embed.num_patches**.5),
412
+ cls_token=False)
413
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
414
+ # --
415
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
416
+ self.blocks = nn.ModuleList([
417
+ Block(
418
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
419
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
420
+ for i in range(depth)])
421
+ self.norm = norm_layer(embed_dim)
422
+ # ------
423
+ self.init_std = init_std
424
+ self.apply(self._init_weights)
425
+ self.fix_init_weight()
426
+
427
+ def fix_init_weight(self):
428
+ def rescale(param, layer_id):
429
+ param.div_(math.sqrt(2.0 * layer_id))
430
+
431
+ for layer_id, layer in enumerate(self.blocks):
432
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
433
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
434
+
435
+ def _init_weights(self, m):
436
+ if isinstance(m, nn.Linear):
437
+ trunc_normal_(m.weight, std=self.init_std)
438
+ if isinstance(m, nn.Linear) and m.bias is not None:
439
+ nn.init.constant_(m.bias, 0)
440
+ elif isinstance(m, nn.LayerNorm):
441
+ nn.init.constant_(m.bias, 0)
442
+ nn.init.constant_(m.weight, 1.0)
443
+ elif isinstance(m, nn.Conv2d):
444
+ trunc_normal_(m.weight, std=self.init_std)
445
+ if m.bias is not None:
446
+ nn.init.constant_(m.bias, 0)
447
+
448
+ def forward(self, x, masks=None):
449
+ if masks is not None:
450
+ if not isinstance(masks, list):
451
+ masks = [masks]
452
+
453
+ # -- patchify x
454
+ x = self.patch_embed(x)
455
+ B, N, D = x.shape
456
+
457
+ # -- add positional embedding to x
458
+ pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
459
+ x = x + pos_embed
460
+
461
+ # -- mask x
462
+ if masks is not None:
463
+ x = apply_masks(x, masks)
464
+
465
+ # -- fwd prop
466
+ for i, blk in enumerate(self.blocks):
467
+ x = blk(x)
468
+
469
+ if self.norm is not None:
470
+ x = self.norm(x)
471
+
472
+ return x
473
+
474
+ def interpolate_pos_encoding(self, x, pos_embed):
475
+ npatch = x.shape[1] - 1
476
+ N = pos_embed.shape[1] - 1
477
+ if npatch == N:
478
+ return pos_embed
479
+ class_emb = pos_embed[:, 0]
480
+ pos_embed = pos_embed[:, 1:]
481
+ dim = x.shape[-1]
482
+ pos_embed = nn.functional.interpolate(
483
+ pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
484
+ scale_factor=math.sqrt(npatch / N),
485
+ mode='bicubic',
486
+ )
487
+ pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
488
+ return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
489
+
490
+
491
+ def vit_predictor(**kwargs):
492
+ model = VisionTransformerPredictor(
493
+ mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
494
+ **kwargs)
495
+ return model
496
+
497
+
498
+ def vit_tiny(patch_size=16, **kwargs):
499
+ model = VisionTransformer(
500
+ patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
501
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
502
+ return model
503
+
504
+
505
+ def vit_small(patch_size=16, **kwargs):
506
+ model = VisionTransformer(
507
+ patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
508
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
509
+ return model
510
+
511
+
512
+ def vit_base(patch_size=16, **kwargs):
513
+ model = VisionTransformer(
514
+ patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
515
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
516
+ return model
517
+
518
+
519
+ def vit_large(patch_size=16, **kwargs):
520
+ model = VisionTransformer(
521
+ patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
522
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
523
+ return model
524
+
525
+
526
+ def vit_huge(patch_size=16, **kwargs):
527
+ model = VisionTransformer(
528
+ patch_size=patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
529
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
530
+ return model
531
+
532
+
533
+ def vit_giant(patch_size=16, **kwargs):
534
+ model = VisionTransformer(
535
+ patch_size=patch_size, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
536
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
537
+ return model
538
+
539
+
540
+ VIT_EMBED_DIMS = {
541
+ 'vit_tiny': 192,
542
+ 'vit_small': 384,
543
+ 'vit_base': 768,
544
+ 'vit_large': 1024,
545
+ 'vit_huge': 1280,
546
+ 'vit_giant': 1408,
547
+ }
New/REG/models/mae_vit.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ # DeiT: https://github.com/facebookresearch/deit
10
+ # --------------------------------------------------------
11
+
12
+ from functools import partial
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ import timm.models.vision_transformer
18
+
19
+
20
+ class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
21
+ """ Vision Transformer with support for global average pooling
22
+ """
23
+ def __init__(self, global_pool=False, **kwargs):
24
+ super(VisionTransformer, self).__init__(**kwargs)
25
+
26
+ self.global_pool = global_pool
27
+ if self.global_pool:
28
+ norm_layer = kwargs['norm_layer']
29
+ embed_dim = kwargs['embed_dim']
30
+ self.fc_norm = norm_layer(embed_dim)
31
+
32
+ del self.norm # remove the original norm
33
+
34
+ def forward_features(self, x):
35
+ B = x.shape[0]
36
+ x = self.patch_embed(x)
37
+
38
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
39
+ x = torch.cat((cls_tokens, x), dim=1)
40
+ x = x + self.pos_embed
41
+ x = self.pos_drop(x)
42
+
43
+ for blk in self.blocks:
44
+ x = blk(x)
45
+
46
+ x = x[:, 1:, :] #.mean(dim=1) # global pool without cls token
47
+
48
+ return x
49
+
50
+
51
+ def vit_base_patch16(**kwargs):
52
+ model = VisionTransformer(
53
+ num_classes=0,
54
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
55
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
56
+ return model
57
+
58
+
59
+ def vit_large_patch16(**kwargs):
60
+ model = VisionTransformer(
61
+ num_classes=0,
62
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
63
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
64
+ return model
65
+
66
+
67
+ def vit_huge_patch14(**kwargs):
68
+ model = VisionTransformer(
69
+ patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
70
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
71
+ return model
New/REG/models/mocov3_vit.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import torch
9
+ import torch.nn as nn
10
+ from functools import partial, reduce
11
+ from operator import mul
12
+
13
+ from timm.layers.helpers import to_2tuple
14
+ from timm.models.vision_transformer import VisionTransformer, _cfg
15
+ from timm.models.vision_transformer import PatchEmbed
16
+
17
+ __all__ = [
18
+ 'vit_small',
19
+ 'vit_base',
20
+ 'vit_large',
21
+ 'vit_conv_small',
22
+ 'vit_conv_base',
23
+ ]
24
+
25
+
26
+ def patchify_avg(input_tensor, patch_size):
27
+ # Ensure input tensor is 4D: (batch_size, channels, height, width)
28
+ if input_tensor.dim() != 4:
29
+ raise ValueError("Input tensor must be 4D (batch_size, channels, height, width)")
30
+
31
+ # Get input tensor dimensions
32
+ batch_size, channels, height, width = input_tensor.shape
33
+
34
+ # Ensure patch_size is valid
35
+ patch_height, patch_width = patch_size, patch_size
36
+ if height % patch_height != 0 or width % patch_width != 0:
37
+ raise ValueError("Input tensor dimensions must be divisible by patch_size")
38
+
39
+ # Use unfold to create patches
40
+ patches = input_tensor.unfold(2, patch_height, patch_height).unfold(3, patch_width, patch_width)
41
+
42
+ # Reshape patches to desired format: (batch_size, num_patches, channels)
43
+ patches = patches.contiguous().view(
44
+ batch_size, channels, -1, patch_height, patch_width
45
+ ).mean(dim=-1).mean(dim=-1)
46
+ patches = patches.permute(0, 2, 1).contiguous()
47
+
48
+ return patches
49
+
50
+
51
+
52
+ class VisionTransformerMoCo(VisionTransformer):
53
+ def __init__(self, stop_grad_conv1=False, **kwargs):
54
+ super().__init__(**kwargs)
55
+ # Use fixed 2D sin-cos position embedding
56
+ self.build_2d_sincos_position_embedding()
57
+
58
+ # weight initialization
59
+ for name, m in self.named_modules():
60
+ if isinstance(m, nn.Linear):
61
+ if 'qkv' in name:
62
+ # treat the weights of Q, K, V separately
63
+ val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
64
+ nn.init.uniform_(m.weight, -val, val)
65
+ else:
66
+ nn.init.xavier_uniform_(m.weight)
67
+ nn.init.zeros_(m.bias)
68
+ nn.init.normal_(self.cls_token, std=1e-6)
69
+
70
+ if isinstance(self.patch_embed, PatchEmbed):
71
+ # xavier_uniform initialization
72
+ val = math.sqrt(6. / float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim))
73
+ nn.init.uniform_(self.patch_embed.proj.weight, -val, val)
74
+ nn.init.zeros_(self.patch_embed.proj.bias)
75
+
76
+ if stop_grad_conv1:
77
+ self.patch_embed.proj.weight.requires_grad = False
78
+ self.patch_embed.proj.bias.requires_grad = False
79
+
80
+ def build_2d_sincos_position_embedding(self, temperature=10000.):
81
+ h = self.patch_embed.img_size[0] // self.patch_embed.patch_size[0]
82
+ w = self.patch_embed.img_size[1] // self.patch_embed.patch_size[1]
83
+ grid_w = torch.arange(w, dtype=torch.float32)
84
+ grid_h = torch.arange(h, dtype=torch.float32)
85
+ grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
86
+ assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
87
+ pos_dim = self.embed_dim // 4
88
+ omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
89
+ omega = 1. / (temperature**omega)
90
+ out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
91
+ out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
92
+ pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
93
+
94
+ # assert self.num_tokens == 1, 'Assuming one and only one token, [cls]'
95
+ pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
96
+ self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
97
+ self.pos_embed.requires_grad = False
98
+
99
+ def forward_diffusion_output(self, x):
100
+ x = x.reshape(*x.shape[0:2], -1).permute(0, 2, 1)
101
+ x = self._pos_embed(x)
102
+ x = self.patch_drop(x)
103
+ x = self.norm_pre(x)
104
+ x = self.blocks(x)
105
+ x = self.norm(x)
106
+ return x
107
+
108
+ class ConvStem(nn.Module):
109
+ """
110
+ ConvStem, from Early Convolutions Help Transformers See Better, Tete et al. https://arxiv.org/abs/2106.14881
111
+ """
112
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
113
+ super().__init__()
114
+
115
+ assert patch_size == 16, 'ConvStem only supports patch size of 16'
116
+ assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem'
117
+
118
+ img_size = to_2tuple(img_size)
119
+ patch_size = to_2tuple(patch_size)
120
+ self.img_size = img_size
121
+ self.patch_size = patch_size
122
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
123
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
124
+ self.flatten = flatten
125
+
126
+ # build stem, similar to the design in https://arxiv.org/abs/2106.14881
127
+ stem = []
128
+ input_dim, output_dim = 3, embed_dim // 8
129
+ for l in range(4):
130
+ stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
131
+ stem.append(nn.BatchNorm2d(output_dim))
132
+ stem.append(nn.ReLU(inplace=True))
133
+ input_dim = output_dim
134
+ output_dim *= 2
135
+ stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
136
+ self.proj = nn.Sequential(*stem)
137
+
138
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
139
+
140
+ def forward(self, x):
141
+ B, C, H, W = x.shape
142
+ assert H == self.img_size[0] and W == self.img_size[1], \
143
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
144
+ x = self.proj(x)
145
+ if self.flatten:
146
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
147
+ x = self.norm(x)
148
+ return x
149
+
150
+
151
+ def vit_small(**kwargs):
152
+ model = VisionTransformerMoCo(
153
+ img_size=256,
154
+ patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
155
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
156
+ model.default_cfg = _cfg()
157
+ return model
158
+
159
+ def vit_base(**kwargs):
160
+ model = VisionTransformerMoCo(
161
+ img_size=256,
162
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
163
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
164
+ model.default_cfg = _cfg()
165
+ return model
166
+
167
+ def vit_large(**kwargs):
168
+ model = VisionTransformerMoCo(
169
+ img_size=256,
170
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
171
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
172
+ model.default_cfg = _cfg()
173
+ return model
174
+
175
+ def vit_conv_small(**kwargs):
176
+ # minus one ViT block
177
+ model = VisionTransformerMoCo(
178
+ patch_size=16, embed_dim=384, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True,
179
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs)
180
+ model.default_cfg = _cfg()
181
+ return model
182
+
183
+ def vit_conv_base(**kwargs):
184
+ # minus one ViT block
185
+ model = VisionTransformerMoCo(
186
+ patch_size=16, embed_dim=768, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True,
187
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs)
188
+ model.default_cfg = _cfg()
189
+ return model
190
+
191
+ def build_mlp(num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
192
+ mlp = []
193
+ for l in range(num_layers):
194
+ dim1 = input_dim if l == 0 else mlp_dim
195
+ dim2 = output_dim if l == num_layers - 1 else mlp_dim
196
+
197
+ mlp.append(nn.Linear(dim1, dim2, bias=False))
198
+
199
+ if l < num_layers - 1:
200
+ mlp.append(nn.BatchNorm1d(dim2))
201
+ mlp.append(nn.ReLU(inplace=True))
202
+ elif last_bn:
203
+ # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157
204
+ # for simplicity, we further removed gamma in BN
205
+ mlp.append(nn.BatchNorm1d(dim2, affine=False))
206
+
207
+ return nn.Sequential(*mlp)
New/REG/models/sit.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This source code is licensed under the license found in the
2
+ # LICENSE file in the root directory of this source tree.
3
+ # --------------------------------------------------------
4
+ # References:
5
+ # GLIDE: https://github.com/openai/glide-text2im
6
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
7
+ # --------------------------------------------------------
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ import math
13
+ from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
14
+
15
+
16
+ def build_mlp(hidden_size, projector_dim, z_dim):
17
+ return nn.Sequential(
18
+ nn.Linear(hidden_size, projector_dim),
19
+ nn.SiLU(),
20
+ nn.Linear(projector_dim, projector_dim),
21
+ nn.SiLU(),
22
+ nn.Linear(projector_dim, z_dim),
23
+ )
24
+
25
+ def modulate(x, shift, scale):
26
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
27
+
28
+ #################################################################################
29
+ # Embedding Layers for Timesteps and Class Labels #
30
+ #################################################################################
31
+ class TimestepEmbedder(nn.Module):
32
+ """
33
+ Embeds scalar timesteps into vector representations.
34
+ """
35
+ def __init__(self, hidden_size, frequency_embedding_size=256):
36
+ super().__init__()
37
+ self.mlp = nn.Sequential(
38
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
39
+ nn.SiLU(),
40
+ nn.Linear(hidden_size, hidden_size, bias=True),
41
+ )
42
+ self.frequency_embedding_size = frequency_embedding_size
43
+
44
+ @staticmethod
45
+ def positional_embedding(t, dim, max_period=10000):
46
+ """
47
+ Create sinusoidal timestep embeddings.
48
+ :param t: a 1-D Tensor of N indices, one per batch element.
49
+ These may be fractional.
50
+ :param dim: the dimension of the output.
51
+ :param max_period: controls the minimum frequency of the embeddings.
52
+ :return: an (N, D) Tensor of positional embeddings.
53
+ """
54
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
55
+ half = dim // 2
56
+ freqs = torch.exp(
57
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
58
+ ).to(device=t.device)
59
+ args = t[:, None].float() * freqs[None]
60
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
61
+ if dim % 2:
62
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
63
+ return embedding
64
+
65
+ def forward(self, t):
66
+ self.timestep_embedding = self.positional_embedding
67
+ t_freq = self.timestep_embedding(t, dim=self.frequency_embedding_size).to(t.dtype)
68
+ t_emb = self.mlp(t_freq)
69
+ return t_emb
70
+
71
+
72
+ class LabelEmbedder(nn.Module):
73
+ """
74
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
75
+ """
76
+ def __init__(self, num_classes, hidden_size, dropout_prob):
77
+ super().__init__()
78
+ use_cfg_embedding = dropout_prob > 0
79
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
80
+ self.num_classes = num_classes
81
+ self.dropout_prob = dropout_prob
82
+
83
+ def token_drop(self, labels, force_drop_ids=None):
84
+ """
85
+ Drops labels to enable classifier-free guidance.
86
+ """
87
+ if force_drop_ids is None:
88
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
89
+ else:
90
+ drop_ids = force_drop_ids == 1
91
+ labels = torch.where(drop_ids, self.num_classes, labels)
92
+ return labels
93
+
94
+ def forward(self, labels, train, force_drop_ids=None):
95
+ use_dropout = self.dropout_prob > 0
96
+ if (train and use_dropout) or (force_drop_ids is not None):
97
+ labels = self.token_drop(labels, force_drop_ids)
98
+ embeddings = self.embedding_table(labels)
99
+ return embeddings
100
+
101
+
102
+ #################################################################################
103
+ # Core SiT Model #
104
+ #################################################################################
105
+
106
+ class SiTBlock(nn.Module):
107
+ """
108
+ A SiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
109
+ """
110
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
111
+ super().__init__()
112
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
113
+ self.attn = Attention(
114
+ hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=block_kwargs["qk_norm"]
115
+ )
116
+ if "fused_attn" in block_kwargs.keys():
117
+ self.attn.fused_attn = block_kwargs["fused_attn"]
118
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
119
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
120
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
121
+ self.mlp = Mlp(
122
+ in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0
123
+ )
124
+ self.adaLN_modulation = nn.Sequential(
125
+ nn.SiLU(),
126
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
127
+ )
128
+
129
+ def forward(self, x, c):
130
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
131
+ self.adaLN_modulation(c).chunk(6, dim=-1)
132
+ )
133
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
134
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
135
+
136
+ return x
137
+
138
+
139
+ class FinalLayer(nn.Module):
140
+ """
141
+ The final layer of SiT.
142
+ """
143
+ def __init__(self, hidden_size, patch_size, out_channels, cls_token_dim):
144
+ super().__init__()
145
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
146
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
147
+ self.linear_cls = nn.Linear(hidden_size, cls_token_dim, bias=True)
148
+ self.adaLN_modulation = nn.Sequential(
149
+ nn.SiLU(),
150
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
151
+ )
152
+
153
+ def forward(self, x, c, cls=None):
154
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
155
+ x = modulate(self.norm_final(x), shift, scale)
156
+
157
+ if cls is None:
158
+ x = self.linear(x)
159
+ return x, None
160
+ else:
161
+ cls_token = self.linear_cls(x[:, 0]).unsqueeze(1)
162
+ x = self.linear(x[:, 1:])
163
+ return x, cls_token.squeeze(1)
164
+
165
+
166
+ class SiT(nn.Module):
167
+ """
168
+ Diffusion model with a Transformer backbone.
169
+ """
170
+ def __init__(
171
+ self,
172
+ path_type='edm',
173
+ input_size=32,
174
+ patch_size=2,
175
+ in_channels=4,
176
+ hidden_size=1152,
177
+ decoder_hidden_size=768,
178
+ encoder_depth=8,
179
+ depth=28,
180
+ num_heads=16,
181
+ mlp_ratio=4.0,
182
+ class_dropout_prob=0.1,
183
+ num_classes=1000,
184
+ use_cfg=False,
185
+ z_dims=[768],
186
+ projector_dim=2048,
187
+ cls_token_dim=768,
188
+ **block_kwargs # fused_attn
189
+ ):
190
+ super().__init__()
191
+ self.path_type = path_type
192
+ self.in_channels = in_channels
193
+ self.out_channels = in_channels
194
+ self.patch_size = patch_size
195
+ self.num_heads = num_heads
196
+ self.use_cfg = use_cfg
197
+ self.num_classes = num_classes
198
+ self.z_dims = z_dims
199
+ self.encoder_depth = encoder_depth
200
+
201
+ self.x_embedder = PatchEmbed(
202
+ input_size, patch_size, in_channels, hidden_size, bias=True
203
+ )
204
+ self.t_embedder = TimestepEmbedder(hidden_size) # timestep embedding type
205
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
206
+ num_patches = self.x_embedder.num_patches
207
+ # Will use fixed sin-cos embedding:
208
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, hidden_size), requires_grad=False)
209
+
210
+ self.blocks = nn.ModuleList([
211
+ SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, **block_kwargs) for _ in range(depth)
212
+ ])
213
+ self.projectors = nn.ModuleList([
214
+ build_mlp(hidden_size, projector_dim, z_dim) for z_dim in z_dims
215
+ ])
216
+
217
+ z_dim = self.z_dims[0]
218
+ cls_token_dim = z_dim
219
+ self.final_layer = FinalLayer(decoder_hidden_size, patch_size, self.out_channels, cls_token_dim)
220
+
221
+
222
+ self.cls_projectors2 = nn.Linear(in_features=cls_token_dim, out_features=hidden_size, bias=True)
223
+ self.wg_norm = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
224
+
225
+ self.initialize_weights()
226
+
227
+ def initialize_weights(self):
228
+ # Initialize transformer layers:
229
+ def _basic_init(module):
230
+ if isinstance(module, nn.Linear):
231
+ torch.nn.init.xavier_uniform_(module.weight)
232
+ if module.bias is not None:
233
+ nn.init.constant_(module.bias, 0)
234
+ self.apply(_basic_init)
235
+
236
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
237
+ pos_embed = get_2d_sincos_pos_embed(
238
+ self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5), cls_token=1, extra_tokens=1
239
+ )
240
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
241
+
242
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
243
+ w = self.x_embedder.proj.weight.data
244
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
245
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
246
+
247
+ # Initialize label embedding table:
248
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
249
+
250
+ # Initialize timestep embedding MLP:
251
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
252
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
253
+
254
+ # Zero-out adaLN modulation layers in SiT blocks:
255
+ for block in self.blocks:
256
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
257
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
258
+
259
+ # Zero-out output layers:
260
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
261
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
262
+ nn.init.constant_(self.final_layer.linear.weight, 0)
263
+ nn.init.constant_(self.final_layer.linear.bias, 0)
264
+ nn.init.constant_(self.final_layer.linear_cls.weight, 0)
265
+ nn.init.constant_(self.final_layer.linear_cls.bias, 0)
266
+
267
+ def unpatchify(self, x, patch_size=None):
268
+ """
269
+ x: (N, T, patch_size**2 * C)
270
+ imgs: (N, C, H, W)
271
+ """
272
+ c = self.out_channels
273
+ p = self.x_embedder.patch_size[0] if patch_size is None else patch_size
274
+ h = w = int(x.shape[1] ** 0.5)
275
+ assert h * w == x.shape[1]
276
+
277
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
278
+ x = torch.einsum('nhwpqc->nchpwq', x)
279
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
280
+ return imgs
281
+
282
+ def forward(self, x, t, y, return_logvar=False, cls_token=None):
283
+ """
284
+ Forward pass of SiT.
285
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
286
+ t: (N,) tensor of diffusion timesteps
287
+ y: (N,) tensor of class labels
288
+ """
289
+
290
+ #cat with cls_token
291
+ x = self.x_embedder(x) # (N, T, D), where T = H * W / patch_size ** 2
292
+ if cls_token is not None:
293
+ cls_token = self.cls_projectors2(cls_token)
294
+ cls_token = self.wg_norm(cls_token)
295
+ cls_token = cls_token.unsqueeze(1) # [b, length, d]
296
+ x = torch.cat((cls_token, x), dim=1)
297
+ x = x + self.pos_embed
298
+ else:
299
+ exit()
300
+ N, T, D = x.shape
301
+
302
+ # timestep and class embedding
303
+ t_embed = self.t_embedder(t) # (N, D)
304
+ y = self.y_embedder(y, self.training) # (N, D)
305
+ c = t_embed + y
306
+
307
+ for i, block in enumerate(self.blocks):
308
+ x = block(x, c)
309
+ if (i + 1) == self.encoder_depth:
310
+ zs = [projector(x.reshape(-1, D)).reshape(N, T, -1) for projector in self.projectors]
311
+
312
+ x, cls_token = self.final_layer(x, c, cls=cls_token)
313
+ x = self.unpatchify(x)
314
+
315
+ return x, zs, cls_token
316
+
317
+
318
+ #################################################################################
319
+ # Sine/Cosine Positional Embedding Functions #
320
+ #################################################################################
321
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
322
+
323
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
324
+ """
325
+ grid_size: int of the grid height and width
326
+ return:
327
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
328
+ """
329
+ grid_h = np.arange(grid_size, dtype=np.float32)
330
+ grid_w = np.arange(grid_size, dtype=np.float32)
331
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
332
+ grid = np.stack(grid, axis=0)
333
+
334
+ grid = grid.reshape([2, 1, grid_size, grid_size])
335
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
336
+ if cls_token and extra_tokens > 0:
337
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
338
+ return pos_embed
339
+
340
+
341
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
342
+ assert embed_dim % 2 == 0
343
+
344
+ # use half of dimensions to encode grid_h
345
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
346
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
347
+
348
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
349
+ return emb
350
+
351
+
352
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
353
+ """
354
+ embed_dim: output dimension for each position
355
+ pos: a list of positions to be encoded: size (M,)
356
+ out: (M, D)
357
+ """
358
+ assert embed_dim % 2 == 0
359
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
360
+ omega /= embed_dim / 2.
361
+ omega = 1. / 10000**omega # (D/2,)
362
+
363
+ pos = pos.reshape(-1) # (M,)
364
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
365
+
366
+ emb_sin = np.sin(out) # (M, D/2)
367
+ emb_cos = np.cos(out) # (M, D/2)
368
+
369
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
370
+ return emb
371
+
372
+
373
+ #################################################################################
374
+ # SiT Configs #
375
+ #################################################################################
376
+
377
+ def SiT_XL_2(**kwargs):
378
+ return SiT(depth=28, hidden_size=1152, decoder_hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
379
+
380
+ def SiT_XL_4(**kwargs):
381
+ return SiT(depth=28, hidden_size=1152, decoder_hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
382
+
383
+ def SiT_XL_8(**kwargs):
384
+ return SiT(depth=28, hidden_size=1152, decoder_hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
385
+
386
+ def SiT_L_2(**kwargs):
387
+ return SiT(depth=24, hidden_size=1024, decoder_hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
388
+
389
+ def SiT_L_4(**kwargs):
390
+ return SiT(depth=24, hidden_size=1024, decoder_hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
391
+
392
+ def SiT_L_8(**kwargs):
393
+ return SiT(depth=24, hidden_size=1024, decoder_hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
394
+
395
+ def SiT_B_2(**kwargs):
396
+ return SiT(depth=12, hidden_size=768, decoder_hidden_size=768, patch_size=2, num_heads=12, **kwargs)
397
+
398
+ def SiT_B_4(**kwargs):
399
+ return SiT(depth=12, hidden_size=768, decoder_hidden_size=768, patch_size=4, num_heads=12, **kwargs)
400
+
401
+ def SiT_B_8(**kwargs):
402
+ return SiT(depth=12, hidden_size=768, decoder_hidden_size=768, patch_size=8, num_heads=12, **kwargs)
403
+
404
+ def SiT_S_2(**kwargs):
405
+ return SiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
406
+
407
+ def SiT_S_4(**kwargs):
408
+ return SiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
409
+
410
+ def SiT_S_8(**kwargs):
411
+ return SiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
412
+
413
+
414
+ SiT_models = {
415
+ 'SiT-XL/2': SiT_XL_2, 'SiT-XL/4': SiT_XL_4, 'SiT-XL/8': SiT_XL_8,
416
+ 'SiT-L/2': SiT_L_2, 'SiT-L/4': SiT_L_4, 'SiT-L/8': SiT_L_8,
417
+ 'SiT-B/2': SiT_B_2, 'SiT-B/4': SiT_B_4, 'SiT-B/8': SiT_B_8,
418
+ 'SiT-S/2': SiT_S_2, 'SiT-S/4': SiT_S_4, 'SiT-S/8': SiT_S_8,
419
+ }
420
+
New/REG/preprocessing/README.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <h1 align="center"> Preprocessing Guide
2
+ </h1>
3
+
4
+ #### Dataset download
5
+
6
+ We follow the preprocessing code used in [edm2](https://github.com/NVlabs/edm2). In this code we made a several edits: (1) we removed unncessary parts except preprocessing because this code is only used for preprocessing, (2) we use [-1, 1] range for an input to the stable diffusion VAE (similar to DiT or SiT) unlike edm2 that uses [0, 1] range, and (3) we consider preprocessing to 256x256 resolution (or 512x512 resolution).
7
+
8
+ After downloading ImageNet, please run the following scripts (please update 256x256 to 512x512 if you want to do experiments on 512x512 resolution);
9
+
10
+ Convert raw ImageNet data to a ZIP archive at 256x256 resolution
11
+ ```bash
12
+ bash dataset_prepare_encode.sh
13
+ ```
14
+
15
+ Convert the pixel data to VAE latents
16
+
17
+ ```bash
18
+ bash dataset_prepare_convert.sh
19
+ ```
20
+
21
+ Here,`YOUR_DOWNLOAD_PATH` is the directory that you downloaded the dataset, and `TARGET_PATH` is the directory that you will save the preprocessed images and corresponding compressed latent vectors. This directory will be used for your experiment scripts.
22
+
23
+ ## Acknowledgement
24
+
25
+ This code is mainly built upon [edm2](https://github.com/NVlabs/edm2) repository.
New/REG/preprocessing/dataset_image_encoder.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Tool for creating ZIP/PNG based datasets."""
9
+
10
+ from collections.abc import Iterator
11
+ from dataclasses import dataclass
12
+ import functools
13
+ import io
14
+ import json
15
+ import os
16
+ import re
17
+ import zipfile
18
+ from pathlib import Path
19
+ from typing import Callable, Optional, Tuple, Union
20
+ import click
21
+ import numpy as np
22
+ import PIL.Image
23
+ import torch
24
+ from tqdm import tqdm
25
+
26
+ from encoders import StabilityVAEEncoder
27
+ from utils import load_encoders
28
+ from torchvision.transforms import Normalize
29
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
30
+ CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
31
+ CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)
32
+
33
+ def preprocess_raw_image(x, enc_type):
34
+ resolution = x.shape[-1]
35
+ if 'clip' in enc_type:
36
+ x = x / 255.
37
+ x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
38
+ x = Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(x)
39
+ elif 'mocov3' in enc_type or 'mae' in enc_type:
40
+ x = x / 255.
41
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
42
+ elif 'dinov2' in enc_type:
43
+ x = x / 255.
44
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
45
+ x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
46
+ elif 'dinov1' in enc_type:
47
+ x = x / 255.
48
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
49
+ elif 'jepa' in enc_type:
50
+ x = x / 255.
51
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
52
+ x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
53
+
54
+ return x
55
+
56
+
57
+ #----------------------------------------------------------------------------
58
+
59
+ @dataclass
60
+ class ImageEntry:
61
+ img: np.ndarray
62
+ label: Optional[int]
63
+
64
+ #----------------------------------------------------------------------------
65
+ # Parse a 'M,N' or 'MxN' integer tuple.
66
+ # Example: '4x2' returns (4,2)
67
+
68
+ def parse_tuple(s: str) -> Tuple[int, int]:
69
+ m = re.match(r'^(\d+)[x,](\d+)$', s)
70
+ if m:
71
+ return int(m.group(1)), int(m.group(2))
72
+ raise click.ClickException(f'cannot parse tuple {s}')
73
+
74
+ #----------------------------------------------------------------------------
75
+
76
+ def maybe_min(a: int, b: Optional[int]) -> int:
77
+ if b is not None:
78
+ return min(a, b)
79
+ return a
80
+
81
+ #----------------------------------------------------------------------------
82
+
83
+ def file_ext(name: Union[str, Path]) -> str:
84
+ return str(name).split('.')[-1]
85
+
86
+ #----------------------------------------------------------------------------
87
+
88
+ def is_image_ext(fname: Union[str, Path]) -> bool:
89
+ ext = file_ext(fname).lower()
90
+ return f'.{ext}' in PIL.Image.EXTENSION
91
+
92
+ #----------------------------------------------------------------------------
93
+
94
+ def open_image_folder(source_dir, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]:
95
+ input_images = []
96
+ def _recurse_dirs(root: str): # workaround Path().rglob() slowness
97
+ with os.scandir(root) as it:
98
+ for e in it:
99
+ if e.is_file():
100
+ input_images.append(os.path.join(root, e.name))
101
+ elif e.is_dir():
102
+ _recurse_dirs(os.path.join(root, e.name))
103
+ _recurse_dirs(source_dir)
104
+ input_images = sorted([f for f in input_images if is_image_ext(f)])
105
+
106
+ arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images}
107
+ max_idx = maybe_min(len(input_images), max_images)
108
+
109
+ # Load labels.
110
+ labels = dict()
111
+ meta_fname = os.path.join(source_dir, 'dataset.json')
112
+ if os.path.isfile(meta_fname):
113
+ with open(meta_fname, 'r') as file:
114
+ data = json.load(file)['labels']
115
+ if data is not None:
116
+ labels = {x[0]: x[1] for x in data}
117
+
118
+ # No labels available => determine from top-level directory names.
119
+ if len(labels) == 0:
120
+ toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()}
121
+ toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))}
122
+ if len(toplevel_indices) > 1:
123
+ labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()}
124
+
125
+ def iterate_images():
126
+ for idx, fname in enumerate(input_images):
127
+ img = np.array(PIL.Image.open(fname).convert('RGB'))#.transpose(2, 0, 1)
128
+ yield ImageEntry(img=img, label=labels.get(arch_fnames[fname]))
129
+ if idx >= max_idx - 1:
130
+ break
131
+ return max_idx, iterate_images()
132
+
133
+ #----------------------------------------------------------------------------
134
+
135
+ def open_image_zip(source, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]:
136
+ with zipfile.ZipFile(source, mode='r') as z:
137
+ input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
138
+ max_idx = maybe_min(len(input_images), max_images)
139
+
140
+ # Load labels.
141
+ labels = dict()
142
+ if 'dataset.json' in z.namelist():
143
+ with z.open('dataset.json', 'r') as file:
144
+ data = json.load(file)['labels']
145
+ if data is not None:
146
+ labels = {x[0]: x[1] for x in data}
147
+
148
+ def iterate_images():
149
+ with zipfile.ZipFile(source, mode='r') as z:
150
+ for idx, fname in enumerate(input_images):
151
+ with z.open(fname, 'r') as file:
152
+ img = np.array(PIL.Image.open(file).convert('RGB'))
153
+ yield ImageEntry(img=img, label=labels.get(fname))
154
+ if idx >= max_idx - 1:
155
+ break
156
+ return max_idx, iterate_images()
157
+
158
+ #----------------------------------------------------------------------------
159
+
160
+ def make_transform(
161
+ transform: Optional[str],
162
+ output_width: Optional[int],
163
+ output_height: Optional[int]
164
+ ) -> Callable[[np.ndarray], Optional[np.ndarray]]:
165
+ def scale(width, height, img):
166
+ w = img.shape[1]
167
+ h = img.shape[0]
168
+ if width == w and height == h:
169
+ return img
170
+ img = PIL.Image.fromarray(img, 'RGB')
171
+ ww = width if width is not None else w
172
+ hh = height if height is not None else h
173
+ img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS)
174
+ return np.array(img)
175
+
176
+ def center_crop(width, height, img):
177
+ crop = np.min(img.shape[:2])
178
+ img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
179
+ img = PIL.Image.fromarray(img, 'RGB')
180
+ img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
181
+ return np.array(img)
182
+
183
+ def center_crop_wide(width, height, img):
184
+ ch = int(np.round(width * img.shape[0] / img.shape[1]))
185
+ if img.shape[1] < width or ch < height:
186
+ return None
187
+
188
+ img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
189
+ img = PIL.Image.fromarray(img, 'RGB')
190
+ img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
191
+ img = np.array(img)
192
+
193
+ canvas = np.zeros([width, width, 3], dtype=np.uint8)
194
+ canvas[(width - height) // 2 : (width + height) // 2, :] = img
195
+ return canvas
196
+
197
+ def center_crop_imagenet(image_size: int, arr: np.ndarray):
198
+ """
199
+ Center cropping implementation from ADM.
200
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
201
+ """
202
+ pil_image = PIL.Image.fromarray(arr)
203
+ while min(*pil_image.size) >= 2 * image_size:
204
+ new_size = tuple(x // 2 for x in pil_image.size)
205
+ assert len(new_size) == 2
206
+ pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BOX)
207
+
208
+ scale = image_size / min(*pil_image.size)
209
+ new_size = tuple(round(x * scale) for x in pil_image.size)
210
+ assert len(new_size) == 2
211
+ pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BICUBIC)
212
+
213
+ arr = np.array(pil_image)
214
+ crop_y = (arr.shape[0] - image_size) // 2
215
+ crop_x = (arr.shape[1] - image_size) // 2
216
+ return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]
217
+
218
+ if transform is None:
219
+ return functools.partial(scale, output_width, output_height)
220
+ if transform == 'center-crop':
221
+ if output_width is None or output_height is None:
222
+ raise click.ClickException('must specify --resolution=WxH when using ' + transform + 'transform')
223
+ return functools.partial(center_crop, output_width, output_height)
224
+ if transform == 'center-crop-wide':
225
+ if output_width is None or output_height is None:
226
+ raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
227
+ return functools.partial(center_crop_wide, output_width, output_height)
228
+ if transform == 'center-crop-dhariwal':
229
+ if output_width is None or output_height is None:
230
+ raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
231
+ if output_width != output_height:
232
+ raise click.ClickException('width and height must match in --resolution=WxH when using ' + transform + ' transform')
233
+ return functools.partial(center_crop_imagenet, output_width)
234
+ assert False, 'unknown transform'
235
+
236
+ #----------------------------------------------------------------------------
237
+
238
+ def open_dataset(source, *, max_images: Optional[int]):
239
+ if os.path.isdir(source):
240
+ return open_image_folder(source, max_images=max_images)
241
+ elif os.path.isfile(source):
242
+ if file_ext(source) == 'zip':
243
+ return open_image_zip(source, max_images=max_images)
244
+ else:
245
+ raise click.ClickException(f'Only zip archives are supported: {source}')
246
+ else:
247
+ raise click.ClickException(f'Missing input file or directory: {source}')
248
+
249
+ #----------------------------------------------------------------------------
250
+
251
+ def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
252
+ dest_ext = file_ext(dest)
253
+
254
+ if dest_ext == 'zip':
255
+ if os.path.dirname(dest) != '':
256
+ os.makedirs(os.path.dirname(dest), exist_ok=True)
257
+ zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
258
+ def zip_write_bytes(fname: str, data: Union[bytes, str]):
259
+ zf.writestr(fname, data)
260
+ return '', zip_write_bytes, zf.close
261
+ else:
262
+ # If the output folder already exists, check that is is
263
+ # empty.
264
+ #
265
+ # Note: creating the output directory is not strictly
266
+ # necessary as folder_write_bytes() also mkdirs, but it's better
267
+ # to give an error message earlier in case the dest folder
268
+ # somehow cannot be created.
269
+ if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
270
+ raise click.ClickException('--dest folder must be empty')
271
+ os.makedirs(dest, exist_ok=True)
272
+
273
+ def folder_write_bytes(fname: str, data: Union[bytes, str]):
274
+ os.makedirs(os.path.dirname(fname), exist_ok=True)
275
+ with open(fname, 'wb') as fout:
276
+ if isinstance(data, str):
277
+ data = data.encode('utf8')
278
+ fout.write(data)
279
+ return dest, folder_write_bytes, lambda: None
280
+
281
+ #----------------------------------------------------------------------------
282
+
283
+ @click.group()
284
+ def cmdline():
285
+ '''Dataset processing tool for dataset image data conversion and VAE encode/decode preprocessing.'''
286
+ if os.environ.get('WORLD_SIZE', '1') != '1':
287
+ raise click.ClickException('Distributed execution is not supported.')
288
+
289
+
290
+ #----------------------------------------------------------------------------
291
+
292
+
293
+
294
+ @cmdline.command()
295
+ @click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True)
296
+ @click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True)
297
+ @click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
298
+ @click.option('--enc-type', help='Maximum number of images to output', metavar='PATH', type=str, default='dinov2-vit-b')
299
+ @click.option('--resolution', help='Maximum number of images to output', metavar='INT', type=int, default=256)
300
+
301
+ def encode(
302
+ source: str,
303
+ dest: str,
304
+ max_images: Optional[int],
305
+ enc_type,
306
+ resolution
307
+ ):
308
+
309
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
310
+ encoder, encoder_type, architectures = load_encoders(enc_type, device, resolution)
311
+ encoder, encoder_type, architectures = encoder[0], encoder_type[0], architectures[0]
312
+ print("Encoder is over!!!")
313
+
314
+ """Encode pixel data to VAE latents."""
315
+ PIL.Image.init()
316
+ if dest == '':
317
+ raise click.ClickException('--dest output filename or directory must not be an empty string')
318
+
319
+ num_files, input_iter = open_dataset(source, max_images=max_images)
320
+ archive_root_dir, save_bytes, close_dest = open_dest(dest)
321
+ print("Data is over!!!")
322
+ labels = []
323
+
324
+ temp_list1 = []
325
+ temp_list2 = []
326
+ for idx, image in tqdm(enumerate(input_iter), total=num_files):
327
+ with torch.no_grad():
328
+ img_tensor = torch.tensor(image.img).to('cuda').permute(2, 0, 1).unsqueeze(0)
329
+ raw_image_ = preprocess_raw_image(img_tensor, encoder_type)
330
+ z = encoder.forward_features(raw_image_)
331
+ if 'dinov2' in encoder_type: z = z['x_norm_patchtokens']
332
+ temp_list1.append(z)
333
+ z = z.detach().cpu().numpy()
334
+ temp_list2.append(z)
335
+
336
+ idx_str = f'{idx:08d}'
337
+ archive_fname = f'{idx_str[:5]}/img-feature-{idx_str}.npy'
338
+
339
+ f = io.BytesIO()
340
+ np.save(f, z)
341
+ save_bytes(os.path.join(archive_root_dir, archive_fname), f.getvalue())
342
+ labels.append([archive_fname, image.label] if image.label is not None else None)
343
+
344
+
345
+ metadata = {'labels': labels if all(x is not None for x in labels) else None}
346
+ save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
347
+ close_dest()
348
+
349
+ if __name__ == "__main__":
350
+ cmdline()
351
+
352
+
353
+ #----------------------------------------------------------------------------
New/REG/preprocessing/dataset_prepare_convert.sh ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+
5
+
6
+ #256
7
+ python preprocessing/dataset_tools.py convert \
8
+ --source=/home/share/imagenet/train \
9
+ --dest=/home/share/imagenet_vae/imagenet_256_vae \
10
+ --resolution=256x256 \
11
+ --transform=center-crop-dhariwal
New/REG/preprocessing/dataset_prepare_encode.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+
5
+
6
+ #256
7
+ python preprocessing/dataset_tools.py encode \
8
+ --source=/home/share/imagenet_vae/imagenet_256_vae \
9
+ --dest=/home/share/imagenet_vae/vae-sd-256
New/REG/preprocessing/dataset_tools.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Tool for creating ZIP/PNG based datasets."""
9
+
10
+ from collections.abc import Iterator
11
+ from dataclasses import dataclass
12
+ import functools
13
+ import io
14
+ import json
15
+ import os
16
+ import re
17
+ import zipfile
18
+ from pathlib import Path
19
+ from typing import Callable, Optional, Tuple, Union
20
+ import click
21
+ import numpy as np
22
+ import PIL.Image
23
+ import torch
24
+ from tqdm import tqdm
25
+
26
+ from encoders import StabilityVAEEncoder
27
+
28
+ #----------------------------------------------------------------------------
29
+
30
+ @dataclass
31
+ class ImageEntry:
32
+ img: np.ndarray
33
+ label: Optional[int]
34
+
35
+ #----------------------------------------------------------------------------
36
+ # Parse a 'M,N' or 'MxN' integer tuple.
37
+ # Example: '4x2' returns (4,2)
38
+
39
+ def parse_tuple(s: str) -> Tuple[int, int]:
40
+ m = re.match(r'^(\d+)[x,](\d+)$', s)
41
+ if m:
42
+ return int(m.group(1)), int(m.group(2))
43
+ raise click.ClickException(f'cannot parse tuple {s}')
44
+
45
+ #----------------------------------------------------------------------------
46
+
47
+ def maybe_min(a: int, b: Optional[int]) -> int:
48
+ if b is not None:
49
+ return min(a, b)
50
+ return a
51
+
52
+ #----------------------------------------------------------------------------
53
+
54
+ def file_ext(name: Union[str, Path]) -> str:
55
+ return str(name).split('.')[-1]
56
+
57
+ #----------------------------------------------------------------------------
58
+
59
+ def is_image_ext(fname: Union[str, Path]) -> bool:
60
+ ext = file_ext(fname).lower()
61
+ return f'.{ext}' in PIL.Image.EXTENSION
62
+
63
+ #----------------------------------------------------------------------------
64
+
65
+ def open_image_folder(source_dir, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]:
66
+ input_images = []
67
+ def _recurse_dirs(root: str): # workaround Path().rglob() slowness
68
+ with os.scandir(root) as it:
69
+ for e in it:
70
+ if e.is_file():
71
+ input_images.append(os.path.join(root, e.name))
72
+ elif e.is_dir():
73
+ _recurse_dirs(os.path.join(root, e.name))
74
+ _recurse_dirs(source_dir)
75
+ input_images = sorted([f for f in input_images if is_image_ext(f)])
76
+
77
+ arch_fnames = {fname: os.path.relpath(fname, source_dir).replace('\\', '/') for fname in input_images}
78
+ max_idx = maybe_min(len(input_images), max_images)
79
+
80
+ # Load labels.
81
+ labels = dict()
82
+ meta_fname = os.path.join(source_dir, 'dataset.json')
83
+ if os.path.isfile(meta_fname):
84
+ with open(meta_fname, 'r') as file:
85
+ data = json.load(file)['labels']
86
+ if data is not None:
87
+ labels = {x[0]: x[1] for x in data}
88
+
89
+ # No labels available => determine from top-level directory names.
90
+ if len(labels) == 0:
91
+ toplevel_names = {arch_fname: arch_fname.split('/')[0] if '/' in arch_fname else '' for arch_fname in arch_fnames.values()}
92
+ toplevel_indices = {toplevel_name: idx for idx, toplevel_name in enumerate(sorted(set(toplevel_names.values())))}
93
+ if len(toplevel_indices) > 1:
94
+ labels = {arch_fname: toplevel_indices[toplevel_name] for arch_fname, toplevel_name in toplevel_names.items()}
95
+
96
+ def iterate_images():
97
+ for idx, fname in enumerate(input_images):
98
+ img = np.array(PIL.Image.open(fname).convert('RGB'))
99
+ yield ImageEntry(img=img, label=labels.get(arch_fnames[fname]))
100
+ if idx >= max_idx - 1:
101
+ break
102
+ return max_idx, iterate_images()
103
+
104
+ #----------------------------------------------------------------------------
105
+
106
+ def open_image_zip(source, *, max_images: Optional[int]) -> tuple[int, Iterator[ImageEntry]]:
107
+ with zipfile.ZipFile(source, mode='r') as z:
108
+ input_images = [str(f) for f in sorted(z.namelist()) if is_image_ext(f)]
109
+ max_idx = maybe_min(len(input_images), max_images)
110
+
111
+ # Load labels.
112
+ labels = dict()
113
+ if 'dataset.json' in z.namelist():
114
+ with z.open('dataset.json', 'r') as file:
115
+ data = json.load(file)['labels']
116
+ if data is not None:
117
+ labels = {x[0]: x[1] for x in data}
118
+
119
+ def iterate_images():
120
+ with zipfile.ZipFile(source, mode='r') as z:
121
+ for idx, fname in enumerate(input_images):
122
+ with z.open(fname, 'r') as file:
123
+ img = np.array(PIL.Image.open(file).convert('RGB'))
124
+ yield ImageEntry(img=img, label=labels.get(fname))
125
+ if idx >= max_idx - 1:
126
+ break
127
+ return max_idx, iterate_images()
128
+
129
+ #----------------------------------------------------------------------------
130
+
131
+ def make_transform(
132
+ transform: Optional[str],
133
+ output_width: Optional[int],
134
+ output_height: Optional[int]
135
+ ) -> Callable[[np.ndarray], Optional[np.ndarray]]:
136
+ def scale(width, height, img):
137
+ w = img.shape[1]
138
+ h = img.shape[0]
139
+ if width == w and height == h:
140
+ return img
141
+ img = PIL.Image.fromarray(img, 'RGB')
142
+ ww = width if width is not None else w
143
+ hh = height if height is not None else h
144
+ img = img.resize((ww, hh), PIL.Image.Resampling.LANCZOS)
145
+ return np.array(img)
146
+
147
+ def center_crop(width, height, img):
148
+ crop = np.min(img.shape[:2])
149
+ img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
150
+ img = PIL.Image.fromarray(img, 'RGB')
151
+ img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
152
+ return np.array(img)
153
+
154
+ def center_crop_wide(width, height, img):
155
+ ch = int(np.round(width * img.shape[0] / img.shape[1]))
156
+ if img.shape[1] < width or ch < height:
157
+ return None
158
+
159
+ img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
160
+ img = PIL.Image.fromarray(img, 'RGB')
161
+ img = img.resize((width, height), PIL.Image.Resampling.LANCZOS)
162
+ img = np.array(img)
163
+
164
+ canvas = np.zeros([width, width, 3], dtype=np.uint8)
165
+ canvas[(width - height) // 2 : (width + height) // 2, :] = img
166
+ return canvas
167
+
168
+ def center_crop_imagenet(image_size: int, arr: np.ndarray):
169
+ """
170
+ Center cropping implementation from ADM.
171
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
172
+ """
173
+ pil_image = PIL.Image.fromarray(arr)
174
+ while min(*pil_image.size) >= 2 * image_size:
175
+ new_size = tuple(x // 2 for x in pil_image.size)
176
+ assert len(new_size) == 2
177
+ pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BOX)
178
+
179
+ scale = image_size / min(*pil_image.size)
180
+ new_size = tuple(round(x * scale) for x in pil_image.size)
181
+ assert len(new_size) == 2
182
+ pil_image = pil_image.resize(new_size, resample=PIL.Image.Resampling.BICUBIC)
183
+
184
+ arr = np.array(pil_image)
185
+ crop_y = (arr.shape[0] - image_size) // 2
186
+ crop_x = (arr.shape[1] - image_size) // 2
187
+ return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]
188
+
189
+ if transform is None:
190
+ return functools.partial(scale, output_width, output_height)
191
+ if transform == 'center-crop':
192
+ if output_width is None or output_height is None:
193
+ raise click.ClickException('must specify --resolution=WxH when using ' + transform + 'transform')
194
+ return functools.partial(center_crop, output_width, output_height)
195
+ if transform == 'center-crop-wide':
196
+ if output_width is None or output_height is None:
197
+ raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
198
+ return functools.partial(center_crop_wide, output_width, output_height)
199
+ if transform == 'center-crop-dhariwal':
200
+ if output_width is None or output_height is None:
201
+ raise click.ClickException('must specify --resolution=WxH when using ' + transform + ' transform')
202
+ if output_width != output_height:
203
+ raise click.ClickException('width and height must match in --resolution=WxH when using ' + transform + ' transform')
204
+ return functools.partial(center_crop_imagenet, output_width)
205
+ assert False, 'unknown transform'
206
+
207
+ #----------------------------------------------------------------------------
208
+
209
+ def open_dataset(source, *, max_images: Optional[int]):
210
+ if os.path.isdir(source):
211
+ return open_image_folder(source, max_images=max_images)
212
+ elif os.path.isfile(source):
213
+ if file_ext(source) == 'zip':
214
+ return open_image_zip(source, max_images=max_images)
215
+ else:
216
+ raise click.ClickException(f'Only zip archives are supported: {source}')
217
+ else:
218
+ raise click.ClickException(f'Missing input file or directory: {source}')
219
+
220
+ #----------------------------------------------------------------------------
221
+
222
+ def open_dest(dest: str) -> Tuple[str, Callable[[str, Union[bytes, str]], None], Callable[[], None]]:
223
+ dest_ext = file_ext(dest)
224
+
225
+ if dest_ext == 'zip':
226
+ if os.path.dirname(dest) != '':
227
+ os.makedirs(os.path.dirname(dest), exist_ok=True)
228
+ zf = zipfile.ZipFile(file=dest, mode='w', compression=zipfile.ZIP_STORED)
229
+ def zip_write_bytes(fname: str, data: Union[bytes, str]):
230
+ zf.writestr(fname, data)
231
+ return '', zip_write_bytes, zf.close
232
+ else:
233
+ # If the output folder already exists, check that is is
234
+ # empty.
235
+ #
236
+ # Note: creating the output directory is not strictly
237
+ # necessary as folder_write_bytes() also mkdirs, but it's better
238
+ # to give an error message earlier in case the dest folder
239
+ # somehow cannot be created.
240
+ if os.path.isdir(dest) and len(os.listdir(dest)) != 0:
241
+ raise click.ClickException('--dest folder must be empty')
242
+ os.makedirs(dest, exist_ok=True)
243
+
244
+ def folder_write_bytes(fname: str, data: Union[bytes, str]):
245
+ os.makedirs(os.path.dirname(fname), exist_ok=True)
246
+ with open(fname, 'wb') as fout:
247
+ if isinstance(data, str):
248
+ data = data.encode('utf8')
249
+ fout.write(data)
250
+ return dest, folder_write_bytes, lambda: None
251
+
252
+ #----------------------------------------------------------------------------
253
+
254
+ @click.group()
255
+ def cmdline():
256
+ '''Dataset processing tool for dataset image data conversion and VAE encode/decode preprocessing.'''
257
+ if os.environ.get('WORLD_SIZE', '1') != '1':
258
+ raise click.ClickException('Distributed execution is not supported.')
259
+
260
+ #----------------------------------------------------------------------------
261
+
262
+ @cmdline.command()
263
+ @click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True)
264
+ @click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True)
265
+ @click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
266
+ @click.option('--transform', help='Input crop/resize mode', metavar='MODE', type=click.Choice(['center-crop', 'center-crop-wide', 'center-crop-dhariwal']))
267
+ @click.option('--resolution', help='Output resolution (e.g., 512x512)', metavar='WxH', type=parse_tuple)
268
+
269
+ def convert(
270
+ source: str,
271
+ dest: str,
272
+ max_images: Optional[int],
273
+ transform: Optional[str],
274
+ resolution: Optional[Tuple[int, int]]
275
+ ):
276
+ """Convert an image dataset into archive format for training.
277
+
278
+ Specifying the input images:
279
+
280
+ \b
281
+ --source path/ Recursively load all images from path/
282
+ --source dataset.zip Load all images from dataset.zip
283
+
284
+ Specifying the output format and path:
285
+
286
+ \b
287
+ --dest /path/to/dir Save output files under /path/to/dir
288
+ --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip
289
+
290
+ The output dataset format can be either an image folder or an uncompressed zip archive.
291
+ Zip archives makes it easier to move datasets around file servers and clusters, and may
292
+ offer better training performance on network file systems.
293
+
294
+ Images within the dataset archive will be stored as uncompressed PNG.
295
+ Uncompresed PNGs can be efficiently decoded in the training loop.
296
+
297
+ Class labels are stored in a file called 'dataset.json' that is stored at the
298
+ dataset root folder. This file has the following structure:
299
+
300
+ \b
301
+ {
302
+ "labels": [
303
+ ["00000/img00000000.png",6],
304
+ ["00000/img00000001.png",9],
305
+ ... repeated for every image in the datase
306
+ ["00049/img00049999.png",1]
307
+ ]
308
+ }
309
+
310
+ If the 'dataset.json' file cannot be found, class labels are determined from
311
+ top-level directory names.
312
+
313
+ Image scale/crop and resolution requirements:
314
+
315
+ Output images must be square-shaped and they must all have the same power-of-two
316
+ dimensions.
317
+
318
+ To scale arbitrary input image size to a specific width and height, use the
319
+ --resolution option. Output resolution will be either the original
320
+ input resolution (if resolution was not specified) or the one specified with
321
+ --resolution option.
322
+
323
+ The --transform=center-crop-dhariwal selects a crop/rescale mode that is intended
324
+ to exactly match with results obtained for ImageNet in common diffusion model literature:
325
+
326
+ \b
327
+ python dataset_tool.py convert --source=downloads/imagenet/ILSVRC/Data/CLS-LOC/train \\
328
+ --dest=datasets/img64.zip --resolution=64x64 --transform=center-crop-dhariwal
329
+ """
330
+ PIL.Image.init()
331
+ if dest == '':
332
+ raise click.ClickException('--dest output filename or directory must not be an empty string')
333
+ print("Begin!!!!!!!!")
334
+ num_files, input_iter = open_dataset(source, max_images=max_images)
335
+ print("open_dataset is over")
336
+ archive_root_dir, save_bytes, close_dest = open_dest(dest)
337
+ print("open_dest is over")
338
+ transform_image = make_transform(transform, *resolution if resolution is not None else (None, None))
339
+ dataset_attrs = None
340
+
341
+ labels = []
342
+ for idx, image in tqdm(enumerate(input_iter), total=num_files):
343
+ idx_str = f'{idx:08d}'
344
+ archive_fname = f'{idx_str[:5]}/img{idx_str}.png'
345
+
346
+ # Apply crop and resize.
347
+ img = transform_image(image.img)
348
+ if img is None:
349
+ continue
350
+
351
+ # Error check to require uniform image attributes across
352
+ # the whole dataset.
353
+ assert img.ndim == 3
354
+ cur_image_attrs = {'width': img.shape[1], 'height': img.shape[0]}
355
+ if dataset_attrs is None:
356
+ dataset_attrs = cur_image_attrs
357
+ width = dataset_attrs['width']
358
+ height = dataset_attrs['height']
359
+ if width != height:
360
+ raise click.ClickException(f'Image dimensions after scale and crop are required to be square. Got {width}x{height}')
361
+ if width != 2 ** int(np.floor(np.log2(width))):
362
+ raise click.ClickException('Image width/height after scale and crop are required to be power-of-two')
363
+ elif dataset_attrs != cur_image_attrs:
364
+ err = [f' dataset {k}/cur image {k}: {dataset_attrs[k]}/{cur_image_attrs[k]}' for k in dataset_attrs.keys()]
365
+ raise click.ClickException(f'Image {archive_fname} attributes must be equal across all images of the dataset. Got:\n' + '\n'.join(err))
366
+
367
+ # Save the image as an uncompressed PNG.
368
+ img = PIL.Image.fromarray(img)
369
+ image_bits = io.BytesIO()
370
+ img.save(image_bits, format='png', compress_level=0, optimize=False)
371
+ save_bytes(os.path.join(archive_root_dir, archive_fname), image_bits.getbuffer())
372
+ labels.append([archive_fname, image.label] if image.label is not None else None)
373
+
374
+ metadata = {'labels': labels if all(x is not None for x in labels) else None}
375
+ save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
376
+ close_dest()
377
+
378
+ #----------------------------------------------------------------------------
379
+
380
+ @cmdline.command()
381
+ @click.option('--model-url', help='VAE encoder model', metavar='URL', type=str, default='stabilityai/sd-vae-ft-mse', show_default=True)
382
+ @click.option('--source', help='Input directory or archive name', metavar='PATH', type=str, required=True)
383
+ @click.option('--dest', help='Output directory or archive name', metavar='PATH', type=str, required=True)
384
+ @click.option('--max-images', help='Maximum number of images to output', metavar='INT', type=int)
385
+
386
+ def encode(
387
+ model_url: str,
388
+ source: str,
389
+ dest: str,
390
+ max_images: Optional[int],
391
+ ):
392
+ """Encode pixel data to VAE latents."""
393
+ PIL.Image.init()
394
+ if dest == '':
395
+ raise click.ClickException('--dest output filename or directory must not be an empty string')
396
+
397
+ vae = StabilityVAEEncoder(vae_name=model_url, batch_size=1)
398
+ print("VAE is over!!!")
399
+ num_files, input_iter = open_dataset(source, max_images=max_images)
400
+ archive_root_dir, save_bytes, close_dest = open_dest(dest)
401
+ print("Data is over!!!")
402
+ labels = []
403
+ #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
404
+ for idx, image in tqdm(enumerate(input_iter), total=num_files):
405
+ img_tensor = torch.tensor(image.img).to('cuda').permute(2, 0, 1).unsqueeze(0)
406
+ mean_std = vae.encode_pixels(img_tensor)[0].cpu()
407
+ idx_str = f'{idx:08d}'
408
+ archive_fname = f'{idx_str[:5]}/img-mean-std-{idx_str}.npy'
409
+
410
+ f = io.BytesIO()
411
+ np.save(f, mean_std)
412
+ save_bytes(os.path.join(archive_root_dir, archive_fname), f.getvalue())
413
+ labels.append([archive_fname, image.label] if image.label is not None else None)
414
+
415
+ metadata = {'labels': labels if all(x is not None for x in labels) else None}
416
+ save_bytes(os.path.join(archive_root_dir, 'dataset.json'), json.dumps(metadata))
417
+ close_dest()
418
+
419
+ if __name__ == "__main__":
420
+ cmdline()
421
+
422
+ #----------------------------------------------------------------------------
New/REG/preprocessing/dnnlib/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ from .util import EasyDict, make_cache_dir_path
New/REG/preprocessing/dnnlib/util.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Miscellaneous utility classes and functions."""
9
+
10
+ import ctypes
11
+ import fnmatch
12
+ import importlib
13
+ import inspect
14
+ import numpy as np
15
+ import os
16
+ import shutil
17
+ import sys
18
+ import types
19
+ import io
20
+ import pickle
21
+ import re
22
+ import requests
23
+ import html
24
+ import hashlib
25
+ import glob
26
+ import tempfile
27
+ import urllib
28
+ import urllib.parse
29
+ import uuid
30
+
31
+ from typing import Any, Callable, BinaryIO, List, Tuple, Union, Optional
32
+
33
+ # Util classes
34
+ # ------------------------------------------------------------------------------------------
35
+
36
+
37
+ class EasyDict(dict):
38
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
39
+
40
+ def __getattr__(self, name: str) -> Any:
41
+ try:
42
+ return self[name]
43
+ except KeyError:
44
+ raise AttributeError(name)
45
+
46
+ def __setattr__(self, name: str, value: Any) -> None:
47
+ self[name] = value
48
+
49
+ def __delattr__(self, name: str) -> None:
50
+ del self[name]
51
+
52
+
53
+ class Logger(object):
54
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
55
+
56
+ def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True):
57
+ self.file = None
58
+
59
+ if file_name is not None:
60
+ self.file = open(file_name, file_mode)
61
+
62
+ self.should_flush = should_flush
63
+ self.stdout = sys.stdout
64
+ self.stderr = sys.stderr
65
+
66
+ sys.stdout = self
67
+ sys.stderr = self
68
+
69
+ def __enter__(self) -> "Logger":
70
+ return self
71
+
72
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
73
+ self.close()
74
+
75
+ def write(self, text: Union[str, bytes]) -> None:
76
+ """Write text to stdout (and a file) and optionally flush."""
77
+ if isinstance(text, bytes):
78
+ text = text.decode()
79
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
80
+ return
81
+
82
+ if self.file is not None:
83
+ self.file.write(text)
84
+
85
+ self.stdout.write(text)
86
+
87
+ if self.should_flush:
88
+ self.flush()
89
+
90
+ def flush(self) -> None:
91
+ """Flush written text to both stdout and a file, if open."""
92
+ if self.file is not None:
93
+ self.file.flush()
94
+
95
+ self.stdout.flush()
96
+
97
+ def close(self) -> None:
98
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
99
+ self.flush()
100
+
101
+ # if using multiple loggers, prevent closing in wrong order
102
+ if sys.stdout is self:
103
+ sys.stdout = self.stdout
104
+ if sys.stderr is self:
105
+ sys.stderr = self.stderr
106
+
107
+ if self.file is not None:
108
+ self.file.close()
109
+ self.file = None
110
+
111
+
112
+ # Cache directories
113
+ # ------------------------------------------------------------------------------------------
114
+
115
+ _dnnlib_cache_dir = None
116
+
117
+ def set_cache_dir(path: str) -> None:
118
+ global _dnnlib_cache_dir
119
+ _dnnlib_cache_dir = path
120
+
121
+ def make_cache_dir_path(*paths: str) -> str:
122
+ if _dnnlib_cache_dir is not None:
123
+ return os.path.join(_dnnlib_cache_dir, *paths)
124
+ if 'DNNLIB_CACHE_DIR' in os.environ:
125
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
126
+ if 'HOME' in os.environ:
127
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
128
+ if 'USERPROFILE' in os.environ:
129
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
130
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
131
+
132
+ # Small util functions
133
+ # ------------------------------------------------------------------------------------------
134
+
135
+
136
+ def format_time(seconds: Union[int, float]) -> str:
137
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
138
+ s = int(np.rint(seconds))
139
+
140
+ if s < 60:
141
+ return "{0}s".format(s)
142
+ elif s < 60 * 60:
143
+ return "{0}m {1:02}s".format(s // 60, s % 60)
144
+ elif s < 24 * 60 * 60:
145
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
146
+ else:
147
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
148
+
149
+
150
+ def format_time_brief(seconds: Union[int, float]) -> str:
151
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
152
+ s = int(np.rint(seconds))
153
+
154
+ if s < 60:
155
+ return "{0}s".format(s)
156
+ elif s < 60 * 60:
157
+ return "{0}m {1:02}s".format(s // 60, s % 60)
158
+ elif s < 24 * 60 * 60:
159
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
160
+ else:
161
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
162
+
163
+
164
+ def tuple_product(t: Tuple) -> Any:
165
+ """Calculate the product of the tuple elements."""
166
+ result = 1
167
+
168
+ for v in t:
169
+ result *= v
170
+
171
+ return result
172
+
173
+
174
+ _str_to_ctype = {
175
+ "uint8": ctypes.c_ubyte,
176
+ "uint16": ctypes.c_uint16,
177
+ "uint32": ctypes.c_uint32,
178
+ "uint64": ctypes.c_uint64,
179
+ "int8": ctypes.c_byte,
180
+ "int16": ctypes.c_int16,
181
+ "int32": ctypes.c_int32,
182
+ "int64": ctypes.c_int64,
183
+ "float32": ctypes.c_float,
184
+ "float64": ctypes.c_double
185
+ }
186
+
187
+
188
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
189
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
190
+ type_str = None
191
+
192
+ if isinstance(type_obj, str):
193
+ type_str = type_obj
194
+ elif hasattr(type_obj, "__name__"):
195
+ type_str = type_obj.__name__
196
+ elif hasattr(type_obj, "name"):
197
+ type_str = type_obj.name
198
+ else:
199
+ raise RuntimeError("Cannot infer type name from input")
200
+
201
+ assert type_str in _str_to_ctype.keys()
202
+
203
+ my_dtype = np.dtype(type_str)
204
+ my_ctype = _str_to_ctype[type_str]
205
+
206
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
207
+
208
+ return my_dtype, my_ctype
209
+
210
+
211
+ def is_pickleable(obj: Any) -> bool:
212
+ try:
213
+ with io.BytesIO() as stream:
214
+ pickle.dump(obj, stream)
215
+ return True
216
+ except:
217
+ return False
218
+
219
+
220
+ # Functionality to import modules/objects by name, and call functions by name
221
+ # ------------------------------------------------------------------------------------------
222
+
223
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
224
+ """Searches for the underlying module behind the name to some python object.
225
+ Returns the module and the object name (original name with module part removed)."""
226
+
227
+ # allow convenience shorthands, substitute them by full names
228
+ obj_name = re.sub("^np.", "numpy.", obj_name)
229
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
230
+
231
+ # list alternatives for (module_name, local_obj_name)
232
+ parts = obj_name.split(".")
233
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
234
+
235
+ # try each alternative in turn
236
+ for module_name, local_obj_name in name_pairs:
237
+ try:
238
+ module = importlib.import_module(module_name) # may raise ImportError
239
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
240
+ return module, local_obj_name
241
+ except:
242
+ pass
243
+
244
+ # maybe some of the modules themselves contain errors?
245
+ for module_name, _local_obj_name in name_pairs:
246
+ try:
247
+ importlib.import_module(module_name) # may raise ImportError
248
+ except ImportError:
249
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
250
+ raise
251
+
252
+ # maybe the requested attribute is missing?
253
+ for module_name, local_obj_name in name_pairs:
254
+ try:
255
+ module = importlib.import_module(module_name) # may raise ImportError
256
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
257
+ except ImportError:
258
+ pass
259
+
260
+ # we are out of luck, but we have no idea why
261
+ raise ImportError(obj_name)
262
+
263
+
264
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
265
+ """Traverses the object name and returns the last (rightmost) python object."""
266
+ if obj_name == '':
267
+ return module
268
+ obj = module
269
+ for part in obj_name.split("."):
270
+ obj = getattr(obj, part)
271
+ return obj
272
+
273
+
274
+ def get_obj_by_name(name: str) -> Any:
275
+ """Finds the python object with the given name."""
276
+ module, obj_name = get_module_from_obj_name(name)
277
+ return get_obj_from_module(module, obj_name)
278
+
279
+
280
+ def call_func_by_name(*args, func_name: Union[str, Callable], **kwargs) -> Any:
281
+ """Finds the python object with the given name and calls it as a function."""
282
+ assert func_name is not None
283
+ func_obj = get_obj_by_name(func_name) if isinstance(func_name, str) else func_name
284
+ assert callable(func_obj)
285
+ return func_obj(*args, **kwargs)
286
+
287
+
288
+ def construct_class_by_name(*args, class_name: Union[str, type], **kwargs) -> Any:
289
+ """Finds the python class with the given name and constructs it with the given arguments."""
290
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
291
+
292
+
293
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
294
+ """Get the directory path of the module containing the given object name."""
295
+ module, _ = get_module_from_obj_name(obj_name)
296
+ return os.path.dirname(inspect.getfile(module))
297
+
298
+
299
+ def is_top_level_function(obj: Any) -> bool:
300
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
301
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
302
+
303
+
304
+ def get_top_level_function_name(obj: Any) -> str:
305
+ """Return the fully-qualified name of a top-level function."""
306
+ assert is_top_level_function(obj)
307
+ module = obj.__module__
308
+ if module == '__main__':
309
+ fname = sys.modules[module].__file__
310
+ assert fname is not None
311
+ module = os.path.splitext(os.path.basename(fname))[0]
312
+ return module + "." + obj.__name__
313
+
314
+
315
+ # File system helpers
316
+ # ------------------------------------------------------------------------------------------
317
+
318
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: Optional[List[str]] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
319
+ """List all files recursively in a given directory while ignoring given file and directory names.
320
+ Returns list of tuples containing both absolute and relative paths."""
321
+ assert os.path.isdir(dir_path)
322
+ base_name = os.path.basename(os.path.normpath(dir_path))
323
+
324
+ if ignores is None:
325
+ ignores = []
326
+
327
+ result = []
328
+
329
+ for root, dirs, files in os.walk(dir_path, topdown=True):
330
+ for ignore_ in ignores:
331
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
332
+
333
+ # dirs need to be edited in-place
334
+ for d in dirs_to_remove:
335
+ dirs.remove(d)
336
+
337
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
338
+
339
+ absolute_paths = [os.path.join(root, f) for f in files]
340
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
341
+
342
+ if add_base_to_relative:
343
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
344
+
345
+ assert len(absolute_paths) == len(relative_paths)
346
+ result += zip(absolute_paths, relative_paths)
347
+
348
+ return result
349
+
350
+
351
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
352
+ """Takes in a list of tuples of (src, dst) paths and copies files.
353
+ Will create all necessary directories."""
354
+ for file in files:
355
+ target_dir_name = os.path.dirname(file[1])
356
+
357
+ # will create all intermediate-level directories
358
+ os.makedirs(target_dir_name, exist_ok=True)
359
+ shutil.copyfile(file[0], file[1])
360
+
361
+
362
+ # URL helpers
363
+ # ------------------------------------------------------------------------------------------
364
+
365
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
366
+ """Determine whether the given object is a valid URL string."""
367
+ if not isinstance(obj, str) or not "://" in obj:
368
+ return False
369
+ if allow_file_urls and obj.startswith('file://'):
370
+ return True
371
+ try:
372
+ res = urllib.parse.urlparse(obj)
373
+ if not res.scheme or not res.netloc or not "." in res.netloc:
374
+ return False
375
+ res = urllib.parse.urlparse(urllib.parse.urljoin(obj, "/"))
376
+ if not res.scheme or not res.netloc or not "." in res.netloc:
377
+ return False
378
+ except:
379
+ return False
380
+ return True
381
+
382
+ # Note on static typing: a better API would be to split 'open_url' to 'openl_url' and
383
+ # 'download_url' with separate return types (BinaryIO, str). As the `return_filename=True`
384
+ # case is somewhat uncommon, we just pretend like this function never returns a string
385
+ # and type ignore return value for those cases.
386
+ def open_url(url: str, cache_dir: Optional[str] = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> BinaryIO:
387
+ """Download the given URL and return a binary-mode file object to access the data."""
388
+ assert num_attempts >= 1
389
+ assert not (return_filename and (not cache))
390
+
391
+ # Doesn't look like an URL scheme so interpret it as a local filename.
392
+ if not re.match('^[a-z]+://', url):
393
+ return url if return_filename else open(url, "rb") # type: ignore
394
+
395
+ # Handle file URLs. This code handles unusual file:// patterns that
396
+ # arise on Windows:
397
+ #
398
+ # file:///c:/foo.txt
399
+ #
400
+ # which would translate to a local '/c:/foo.txt' filename that's
401
+ # invalid. Drop the forward slash for such pathnames.
402
+ #
403
+ # If you touch this code path, you should test it on both Linux and
404
+ # Windows.
405
+ #
406
+ # Some internet resources suggest using urllib.request.url2pathname()
407
+ # but that converts forward slashes to backslashes and this causes
408
+ # its own set of problems.
409
+ if url.startswith('file://'):
410
+ filename = urllib.parse.urlparse(url).path
411
+ if re.match(r'^/[a-zA-Z]:', filename):
412
+ filename = filename[1:]
413
+ return filename if return_filename else open(filename, "rb") # type: ignore
414
+
415
+ assert is_url(url)
416
+
417
+ # Lookup from cache.
418
+ if cache_dir is None:
419
+ cache_dir = make_cache_dir_path('downloads')
420
+
421
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
422
+ if cache:
423
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
424
+ if len(cache_files) == 1:
425
+ filename = cache_files[0]
426
+ return filename if return_filename else open(filename, "rb") # type: ignore
427
+
428
+ # Download.
429
+ url_name = None
430
+ url_data = None
431
+ with requests.Session() as session:
432
+ if verbose:
433
+ print("Downloading %s ..." % url, end="", flush=True)
434
+ for attempts_left in reversed(range(num_attempts)):
435
+ try:
436
+ with session.get(url) as res:
437
+ res.raise_for_status()
438
+ if len(res.content) == 0:
439
+ raise IOError("No data received")
440
+
441
+ if len(res.content) < 8192:
442
+ content_str = res.content.decode("utf-8")
443
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
444
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
445
+ if len(links) == 1:
446
+ url = urllib.parse.urljoin(url, links[0])
447
+ raise IOError("Google Drive virus checker nag")
448
+ if "Google Drive - Quota exceeded" in content_str:
449
+ raise IOError("Google Drive download quota exceeded -- please try again later")
450
+
451
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
452
+ url_name = match[1] if match else url
453
+ url_data = res.content
454
+ if verbose:
455
+ print(" done")
456
+ break
457
+ except KeyboardInterrupt:
458
+ raise
459
+ except:
460
+ if not attempts_left:
461
+ if verbose:
462
+ print(" failed")
463
+ raise
464
+ if verbose:
465
+ print(".", end="", flush=True)
466
+
467
+ assert url_data is not None
468
+
469
+ # Save to cache.
470
+ if cache:
471
+ assert url_name is not None
472
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
473
+ safe_name = safe_name[:min(len(safe_name), 128)]
474
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
475
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
476
+ os.makedirs(cache_dir, exist_ok=True)
477
+ with open(temp_file, "wb") as f:
478
+ f.write(url_data)
479
+ os.replace(temp_file, cache_file) # atomic
480
+ if return_filename:
481
+ return cache_file # type: ignore
482
+
483
+ # Return data as file object.
484
+ assert not return_filename
485
+ return io.BytesIO(url_data)
New/REG/preprocessing/encoders.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Converting between pixel and latent representations of image data."""
9
+
10
+ import os
11
+ import warnings
12
+ import numpy as np
13
+ import torch
14
+ from torch_utils import persistence
15
+ from torch_utils import misc
16
+
17
+ warnings.filterwarnings('ignore', 'torch.utils._pytree._register_pytree_node is deprecated.')
18
+ warnings.filterwarnings('ignore', '`resume_download` is deprecated')
19
+
20
+ #----------------------------------------------------------------------------
21
+ # Abstract base class for encoders/decoders that convert back and forth
22
+ # between pixel and latent representations of image data.
23
+ #
24
+ # Logically, "raw pixels" are first encoded into "raw latents" that are
25
+ # then further encoded into "final latents". Decoding, on the other hand,
26
+ # goes directly from the final latents to raw pixels. The final latents are
27
+ # used as inputs and outputs of the model, whereas the raw latents are
28
+ # stored in the dataset. This separation provides added flexibility in terms
29
+ # of performing just-in-time adjustments, such as data whitening, without
30
+ # having to construct a new dataset.
31
+ #
32
+ # All image data is represented as PyTorch tensors in NCHW order.
33
+ # Raw pixels are represented as 3-channel uint8.
34
+
35
+ @persistence.persistent_class
36
+ class Encoder:
37
+ def __init__(self):
38
+ pass
39
+
40
+ def init(self, device): # force lazy init to happen now
41
+ pass
42
+
43
+ def __getstate__(self):
44
+ return self.__dict__
45
+
46
+ def encode_pixels(self, x): # raw pixels => raw latents
47
+ raise NotImplementedError # to be overridden by subclass
48
+ #----------------------------------------------------------------------------
49
+ # Pre-trained VAE encoder from Stability AI.
50
+
51
+ @persistence.persistent_class
52
+ class StabilityVAEEncoder(Encoder):
53
+ def __init__(self,
54
+ vae_name = 'stabilityai/sd-vae-ft-mse', # Name of the VAE to use.
55
+ batch_size = 8, # Batch size to use when running the VAE.
56
+ ):
57
+ super().__init__()
58
+ self.vae_name = vae_name
59
+ self.batch_size = int(batch_size)
60
+ self._vae = None
61
+
62
+ def init(self, device): # force lazy init to happen now
63
+ super().init(device)
64
+ if self._vae is None:
65
+ self._vae = load_stability_vae(self.vae_name, device=device)
66
+ else:
67
+ self._vae.to(device)
68
+
69
+ def __getstate__(self):
70
+ return dict(super().__getstate__(), _vae=None) # do not pickle the vae
71
+
72
+ def _run_vae_encoder(self, x):
73
+ d = self._vae.encode(x)['latent_dist']
74
+ return torch.cat([d.mean, d.std], dim=1)
75
+
76
+ def encode_pixels(self, x): # raw pixels => raw latents
77
+ self.init(x.device)
78
+ x = x.to(torch.float32) / 127.5 - 1
79
+ x = torch.cat([self._run_vae_encoder(batch) for batch in x.split(self.batch_size)])
80
+ return x
81
+
82
+ #----------------------------------------------------------------------------
83
+
84
+ def load_stability_vae(vae_name='stabilityai/sd-vae-ft-mse', device=torch.device('cpu')):
85
+ import dnnlib
86
+ cache_dir = dnnlib.make_cache_dir_path('diffusers')
87
+ os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
88
+ os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
89
+ os.environ['HF_HOME'] = cache_dir
90
+
91
+
92
+ import diffusers # pip install diffusers # pyright: ignore [reportMissingImports]
93
+ try:
94
+ # First try with local_files_only to avoid consulting tfhub metadata if the model is already in cache.
95
+ vae = diffusers.models.AutoencoderKL.from_pretrained(
96
+ vae_name, cache_dir=cache_dir, local_files_only=True
97
+ )
98
+ except:
99
+ # Could not load the model from cache; try without local_files_only.
100
+ vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, cache_dir=cache_dir)
101
+ return vae.eval().requires_grad_(False).to(device)
102
+
103
+ #----------------------------------------------------------------------------
New/REG/preprocessing/torch_utils/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ # empty
New/REG/preprocessing/torch_utils/distributed.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ import os
9
+ import re
10
+ import socket
11
+ import torch
12
+ import torch.distributed
13
+ from . import training_stats
14
+
15
+ _sync_device = None
16
+
17
+ #----------------------------------------------------------------------------
18
+
19
+ def init():
20
+ global _sync_device
21
+
22
+ if not torch.distributed.is_initialized():
23
+ # Setup some reasonable defaults for env-based distributed init if
24
+ # not set by the running environment.
25
+ if 'MASTER_ADDR' not in os.environ:
26
+ os.environ['MASTER_ADDR'] = 'localhost'
27
+ if 'MASTER_PORT' not in os.environ:
28
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
29
+ s.bind(('', 0))
30
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
31
+ os.environ['MASTER_PORT'] = str(s.getsockname()[1])
32
+ s.close()
33
+ if 'RANK' not in os.environ:
34
+ os.environ['RANK'] = '0'
35
+ if 'LOCAL_RANK' not in os.environ:
36
+ os.environ['LOCAL_RANK'] = '0'
37
+ if 'WORLD_SIZE' not in os.environ:
38
+ os.environ['WORLD_SIZE'] = '1'
39
+ backend = 'gloo' if os.name == 'nt' else 'nccl'
40
+ torch.distributed.init_process_group(backend=backend, init_method='env://')
41
+ torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
42
+
43
+ _sync_device = torch.device('cuda') if get_world_size() > 1 else None
44
+ training_stats.init_multiprocessing(rank=get_rank(), sync_device=_sync_device)
45
+
46
+ #----------------------------------------------------------------------------
47
+
48
+ def get_rank():
49
+ return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
50
+
51
+ #----------------------------------------------------------------------------
52
+
53
+ def get_world_size():
54
+ return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
55
+
56
+ #----------------------------------------------------------------------------
57
+
58
+ def should_stop():
59
+ return False
60
+
61
+ #----------------------------------------------------------------------------
62
+
63
+ def should_suspend():
64
+ return False
65
+
66
+ #----------------------------------------------------------------------------
67
+
68
+ def request_suspend():
69
+ pass
70
+
71
+ #----------------------------------------------------------------------------
72
+
73
+ def update_progress(cur, total):
74
+ pass
75
+
76
+ #----------------------------------------------------------------------------
77
+
78
+ def print0(*args, **kwargs):
79
+ if get_rank() == 0:
80
+ print(*args, **kwargs)
81
+
82
+ #----------------------------------------------------------------------------
83
+
84
+ class CheckpointIO:
85
+ def __init__(self, **kwargs):
86
+ self._state_objs = kwargs
87
+
88
+ def save(self, pt_path, verbose=True):
89
+ if verbose:
90
+ print0(f'Saving {pt_path} ... ', end='', flush=True)
91
+ data = dict()
92
+ for name, obj in self._state_objs.items():
93
+ if obj is None:
94
+ data[name] = None
95
+ elif isinstance(obj, dict):
96
+ data[name] = obj
97
+ elif hasattr(obj, 'state_dict'):
98
+ data[name] = obj.state_dict()
99
+ elif hasattr(obj, '__getstate__'):
100
+ data[name] = obj.__getstate__()
101
+ elif hasattr(obj, '__dict__'):
102
+ data[name] = obj.__dict__
103
+ else:
104
+ raise ValueError(f'Invalid state object of type {type(obj).__name__}')
105
+ if get_rank() == 0:
106
+ torch.save(data, pt_path)
107
+ if verbose:
108
+ print0('done')
109
+
110
+ def load(self, pt_path, verbose=True):
111
+ if verbose:
112
+ print0(f'Loading {pt_path} ... ', end='', flush=True)
113
+ data = torch.load(pt_path, map_location=torch.device('cpu'))
114
+ for name, obj in self._state_objs.items():
115
+ if obj is None:
116
+ pass
117
+ elif isinstance(obj, dict):
118
+ obj.clear()
119
+ obj.update(data[name])
120
+ elif hasattr(obj, 'load_state_dict'):
121
+ obj.load_state_dict(data[name])
122
+ elif hasattr(obj, '__setstate__'):
123
+ obj.__setstate__(data[name])
124
+ elif hasattr(obj, '__dict__'):
125
+ obj.__dict__.clear()
126
+ obj.__dict__.update(data[name])
127
+ else:
128
+ raise ValueError(f'Invalid state object of type {type(obj).__name__}')
129
+ if verbose:
130
+ print0('done')
131
+
132
+ def load_latest(self, run_dir, pattern=r'training-state-(\d+).pt', verbose=True):
133
+ fnames = [entry.name for entry in os.scandir(run_dir) if entry.is_file() and re.fullmatch(pattern, entry.name)]
134
+ if len(fnames) == 0:
135
+ return None
136
+ pt_path = os.path.join(run_dir, max(fnames, key=lambda x: float(re.fullmatch(pattern, x).group(1))))
137
+ self.load(pt_path, verbose=verbose)
138
+ return pt_path
139
+
140
+ #----------------------------------------------------------------------------
New/REG/preprocessing/torch_utils/misc.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ import re
9
+ import contextlib
10
+ import functools
11
+ import numpy as np
12
+ import torch
13
+ import warnings
14
+ import dnnlib
15
+
16
+ #----------------------------------------------------------------------------
17
+ # Re-seed torch & numpy random generators based on the given arguments.
18
+
19
+ def set_random_seed(*args):
20
+ seed = hash(args) % (1 << 31)
21
+ torch.manual_seed(seed)
22
+ np.random.seed(seed)
23
+
24
+ #----------------------------------------------------------------------------
25
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
26
+ # same constant is used multiple times.
27
+
28
+ _constant_cache = dict()
29
+
30
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
31
+ value = np.asarray(value)
32
+ if shape is not None:
33
+ shape = tuple(shape)
34
+ if dtype is None:
35
+ dtype = torch.get_default_dtype()
36
+ if device is None:
37
+ device = torch.device('cpu')
38
+ if memory_format is None:
39
+ memory_format = torch.contiguous_format
40
+
41
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
42
+ tensor = _constant_cache.get(key, None)
43
+ if tensor is None:
44
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
45
+ if shape is not None:
46
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
47
+ tensor = tensor.contiguous(memory_format=memory_format)
48
+ _constant_cache[key] = tensor
49
+ return tensor
50
+
51
+ #----------------------------------------------------------------------------
52
+ # Variant of constant() that inherits dtype and device from the given
53
+ # reference tensor by default.
54
+
55
+ def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
56
+ if dtype is None:
57
+ dtype = ref.dtype
58
+ if device is None:
59
+ device = ref.device
60
+ return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)
61
+
62
+ #----------------------------------------------------------------------------
63
+ # Cached construction of temporary tensors in pinned CPU memory.
64
+
65
+ @functools.lru_cache(None)
66
+ def pinned_buf(shape, dtype):
67
+ return torch.empty(shape, dtype=dtype).pin_memory()
68
+
69
+ #----------------------------------------------------------------------------
70
+ # Symbolic assert.
71
+
72
+ try:
73
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
74
+ except AttributeError:
75
+ symbolic_assert = torch.Assert # 1.7.0
76
+
77
+ #----------------------------------------------------------------------------
78
+ # Context manager to temporarily suppress known warnings in torch.jit.trace().
79
+ # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
80
+
81
+ @contextlib.contextmanager
82
+ def suppress_tracer_warnings():
83
+ flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
84
+ warnings.filters.insert(0, flt)
85
+ yield
86
+ warnings.filters.remove(flt)
87
+
88
+ #----------------------------------------------------------------------------
89
+ # Assert that the shape of a tensor matches the given list of integers.
90
+ # None indicates that the size of a dimension is allowed to vary.
91
+ # Performs symbolic assertion when used in torch.jit.trace().
92
+
93
+ def assert_shape(tensor, ref_shape):
94
+ if tensor.ndim != len(ref_shape):
95
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
96
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
97
+ if ref_size is None:
98
+ pass
99
+ elif isinstance(ref_size, torch.Tensor):
100
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
101
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
102
+ elif isinstance(size, torch.Tensor):
103
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
104
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
105
+ elif size != ref_size:
106
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
107
+
108
+ #----------------------------------------------------------------------------
109
+ # Function decorator that calls torch.autograd.profiler.record_function().
110
+
111
+ def profiled_function(fn):
112
+ def decorator(*args, **kwargs):
113
+ with torch.autograd.profiler.record_function(fn.__name__):
114
+ return fn(*args, **kwargs)
115
+ decorator.__name__ = fn.__name__
116
+ return decorator
117
+
118
+ #----------------------------------------------------------------------------
119
+ # Sampler for torch.utils.data.DataLoader that loops over the dataset
120
+ # indefinitely, shuffling items as it goes.
121
+
122
+ class InfiniteSampler(torch.utils.data.Sampler):
123
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, start_idx=0):
124
+ assert len(dataset) > 0
125
+ assert num_replicas > 0
126
+ assert 0 <= rank < num_replicas
127
+ warnings.filterwarnings('ignore', '`data_source` argument is not used and will be removed')
128
+ super().__init__(dataset)
129
+ self.dataset_size = len(dataset)
130
+ self.start_idx = start_idx + rank
131
+ self.stride = num_replicas
132
+ self.shuffle = shuffle
133
+ self.seed = seed
134
+
135
+ def __iter__(self):
136
+ idx = self.start_idx
137
+ epoch = None
138
+ while True:
139
+ if epoch != idx // self.dataset_size:
140
+ epoch = idx // self.dataset_size
141
+ order = np.arange(self.dataset_size)
142
+ if self.shuffle:
143
+ np.random.RandomState(hash((self.seed, epoch)) % (1 << 31)).shuffle(order)
144
+ yield int(order[idx % self.dataset_size])
145
+ idx += self.stride
146
+
147
+ #----------------------------------------------------------------------------
148
+ # Utilities for operating with torch.nn.Module parameters and buffers.
149
+
150
+ def params_and_buffers(module):
151
+ assert isinstance(module, torch.nn.Module)
152
+ return list(module.parameters()) + list(module.buffers())
153
+
154
+ def named_params_and_buffers(module):
155
+ assert isinstance(module, torch.nn.Module)
156
+ return list(module.named_parameters()) + list(module.named_buffers())
157
+
158
+ @torch.no_grad()
159
+ def copy_params_and_buffers(src_module, dst_module, require_all=False):
160
+ assert isinstance(src_module, torch.nn.Module)
161
+ assert isinstance(dst_module, torch.nn.Module)
162
+ src_tensors = dict(named_params_and_buffers(src_module))
163
+ for name, tensor in named_params_and_buffers(dst_module):
164
+ assert (name in src_tensors) or (not require_all)
165
+ if name in src_tensors:
166
+ tensor.copy_(src_tensors[name])
167
+
168
+ #----------------------------------------------------------------------------
169
+ # Context manager for easily enabling/disabling DistributedDataParallel
170
+ # synchronization.
171
+
172
+ @contextlib.contextmanager
173
+ def ddp_sync(module, sync):
174
+ assert isinstance(module, torch.nn.Module)
175
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
176
+ yield
177
+ else:
178
+ with module.no_sync():
179
+ yield
180
+
181
+ #----------------------------------------------------------------------------
182
+ # Check DistributedDataParallel consistency across processes.
183
+
184
+ def check_ddp_consistency(module, ignore_regex=None):
185
+ assert isinstance(module, torch.nn.Module)
186
+ for name, tensor in named_params_and_buffers(module):
187
+ fullname = type(module).__name__ + '.' + name
188
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
189
+ continue
190
+ tensor = tensor.detach()
191
+ if tensor.is_floating_point():
192
+ tensor = torch.nan_to_num(tensor)
193
+ other = tensor.clone()
194
+ torch.distributed.broadcast(tensor=other, src=0)
195
+ assert (tensor == other).all(), fullname
196
+
197
+ #----------------------------------------------------------------------------
198
+ # Print summary table of module hierarchy.
199
+
200
+ @torch.no_grad()
201
+ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
202
+ assert isinstance(module, torch.nn.Module)
203
+ assert not isinstance(module, torch.jit.ScriptModule)
204
+ assert isinstance(inputs, (tuple, list))
205
+
206
+ # Register hooks.
207
+ entries = []
208
+ nesting = [0]
209
+ def pre_hook(_mod, _inputs):
210
+ nesting[0] += 1
211
+ def post_hook(mod, _inputs, outputs):
212
+ nesting[0] -= 1
213
+ if nesting[0] <= max_nesting:
214
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
215
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
216
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
217
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
218
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
219
+
220
+ # Run module.
221
+ outputs = module(*inputs)
222
+ for hook in hooks:
223
+ hook.remove()
224
+
225
+ # Identify unique outputs, parameters, and buffers.
226
+ tensors_seen = set()
227
+ for e in entries:
228
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
229
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
230
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
231
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
232
+
233
+ # Filter out redundant entries.
234
+ if skip_redundant:
235
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
236
+
237
+ # Construct table.
238
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
239
+ rows += [['---'] * len(rows[0])]
240
+ param_total = 0
241
+ buffer_total = 0
242
+ submodule_names = {mod: name for name, mod in module.named_modules()}
243
+ for e in entries:
244
+ name = '<top-level>' if e.mod is module else submodule_names[e.mod]
245
+ param_size = sum(t.numel() for t in e.unique_params)
246
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
247
+ output_shapes = [str(list(t.shape)) for t in e.outputs]
248
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
249
+ rows += [[
250
+ name + (':0' if len(e.outputs) >= 2 else ''),
251
+ str(param_size) if param_size else '-',
252
+ str(buffer_size) if buffer_size else '-',
253
+ (output_shapes + ['-'])[0],
254
+ (output_dtypes + ['-'])[0],
255
+ ]]
256
+ for idx in range(1, len(e.outputs)):
257
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
258
+ param_total += param_size
259
+ buffer_total += buffer_size
260
+ rows += [['---'] * len(rows[0])]
261
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
262
+
263
+ # Print table.
264
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
265
+ print()
266
+ for row in rows:
267
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
268
+ print()
269
+
270
+ #----------------------------------------------------------------------------
271
+ # Tile a batch of images into a 2D grid.
272
+
273
+ def tile_images(x, w, h):
274
+ assert x.ndim == 4 # NCHW => CHW
275
+ return x.reshape(h, w, *x.shape[1:]).permute(2, 0, 3, 1, 4).reshape(x.shape[1], h * x.shape[2], w * x.shape[3])
276
+
277
+ #----------------------------------------------------------------------------
New/REG/preprocessing/torch_utils/persistence.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Facilities for pickling Python code alongside other data.
9
+
10
+ The pickled code is automatically imported into a separate Python module
11
+ during unpickling. This way, any previously exported pickles will remain
12
+ usable even if the original code is no longer available, or if the current
13
+ version of the code is not consistent with what was originally pickled."""
14
+
15
+ import sys
16
+ import pickle
17
+ import io
18
+ import inspect
19
+ import copy
20
+ import uuid
21
+ import types
22
+ import functools
23
+ import dnnlib
24
+
25
+ #----------------------------------------------------------------------------
26
+
27
+ _version = 6 # internal version number
28
+ _decorators = set() # {decorator_class, ...}
29
+ _import_hooks = [] # [hook_function, ...]
30
+ _module_to_src_dict = dict() # {module: src, ...}
31
+ _src_to_module_dict = dict() # {src: module, ...}
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ def persistent_class(orig_class):
36
+ r"""Class decorator that extends a given class to save its source code
37
+ when pickled.
38
+
39
+ Example:
40
+
41
+ from torch_utils import persistence
42
+
43
+ @persistence.persistent_class
44
+ class MyNetwork(torch.nn.Module):
45
+ def __init__(self, num_inputs, num_outputs):
46
+ super().__init__()
47
+ self.fc = MyLayer(num_inputs, num_outputs)
48
+ ...
49
+
50
+ @persistence.persistent_class
51
+ class MyLayer(torch.nn.Module):
52
+ ...
53
+
54
+ When pickled, any instance of `MyNetwork` and `MyLayer` will save its
55
+ source code alongside other internal state (e.g., parameters, buffers,
56
+ and submodules). This way, any previously exported pickle will remain
57
+ usable even if the class definitions have been modified or are no
58
+ longer available.
59
+
60
+ The decorator saves the source code of the entire Python module
61
+ containing the decorated class. It does *not* save the source code of
62
+ any imported modules. Thus, the imported modules must be available
63
+ during unpickling, also including `torch_utils.persistence` itself.
64
+
65
+ It is ok to call functions defined in the same module from the
66
+ decorated class. However, if the decorated class depends on other
67
+ classes defined in the same module, they must be decorated as well.
68
+ This is illustrated in the above example in the case of `MyLayer`.
69
+
70
+ It is also possible to employ the decorator just-in-time before
71
+ calling the constructor. For example:
72
+
73
+ cls = MyLayer
74
+ if want_to_make_it_persistent:
75
+ cls = persistence.persistent_class(cls)
76
+ layer = cls(num_inputs, num_outputs)
77
+
78
+ As an additional feature, the decorator also keeps track of the
79
+ arguments that were used to construct each instance of the decorated
80
+ class. The arguments can be queried via `obj.init_args` and
81
+ `obj.init_kwargs`, and they are automatically pickled alongside other
82
+ object state. This feature can be disabled on a per-instance basis
83
+ by setting `self._record_init_args = False` in the constructor.
84
+
85
+ A typical use case is to first unpickle a previous instance of a
86
+ persistent class, and then upgrade it to use the latest version of
87
+ the source code:
88
+
89
+ with open('old_pickle.pkl', 'rb') as f:
90
+ old_net = pickle.load(f)
91
+ new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs)
92
+ misc.copy_params_and_buffers(old_net, new_net, require_all=True)
93
+ """
94
+ assert isinstance(orig_class, type)
95
+ if is_persistent(orig_class):
96
+ return orig_class
97
+
98
+ assert orig_class.__module__ in sys.modules
99
+ orig_module = sys.modules[orig_class.__module__]
100
+ orig_module_src = _module_to_src(orig_module)
101
+
102
+ @functools.wraps(orig_class, updated=())
103
+ class Decorator(orig_class):
104
+ _orig_module_src = orig_module_src
105
+ _orig_class_name = orig_class.__name__
106
+
107
+ def __init__(self, *args, **kwargs):
108
+ super().__init__(*args, **kwargs)
109
+ record_init_args = getattr(self, '_record_init_args', True)
110
+ self._init_args = copy.deepcopy(args) if record_init_args else None
111
+ self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
112
+ assert orig_class.__name__ in orig_module.__dict__
113
+ _check_pickleable(self.__reduce__())
114
+
115
+ @property
116
+ def init_args(self):
117
+ assert self._init_args is not None
118
+ return copy.deepcopy(self._init_args)
119
+
120
+ @property
121
+ def init_kwargs(self):
122
+ assert self._init_kwargs is not None
123
+ return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
124
+
125
+ def __reduce__(self):
126
+ fields = list(super().__reduce__())
127
+ fields += [None] * max(3 - len(fields), 0)
128
+ if fields[0] is not _reconstruct_persistent_obj:
129
+ meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
130
+ fields[0] = _reconstruct_persistent_obj # reconstruct func
131
+ fields[1] = (meta,) # reconstruct args
132
+ fields[2] = None # state dict
133
+ return tuple(fields)
134
+
135
+ _decorators.add(Decorator)
136
+ return Decorator
137
+
138
+ #----------------------------------------------------------------------------
139
+
140
+ def is_persistent(obj):
141
+ r"""Test whether the given object or class is persistent, i.e.,
142
+ whether it will save its source code when pickled.
143
+ """
144
+ try:
145
+ if obj in _decorators:
146
+ return True
147
+ except TypeError:
148
+ pass
149
+ return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
150
+
151
+ #----------------------------------------------------------------------------
152
+
153
+ def import_hook(hook):
154
+ r"""Register an import hook that is called whenever a persistent object
155
+ is being unpickled. A typical use case is to patch the pickled source
156
+ code to avoid errors and inconsistencies when the API of some imported
157
+ module has changed.
158
+
159
+ The hook should have the following signature:
160
+
161
+ hook(meta) -> modified meta
162
+
163
+ `meta` is an instance of `dnnlib.EasyDict` with the following fields:
164
+
165
+ type: Type of the persistent object, e.g. `'class'`.
166
+ version: Internal version number of `torch_utils.persistence`.
167
+ module_src Original source code of the Python module.
168
+ class_name: Class name in the original Python module.
169
+ state: Internal state of the object.
170
+
171
+ Example:
172
+
173
+ @persistence.import_hook
174
+ def wreck_my_network(meta):
175
+ if meta.class_name == 'MyNetwork':
176
+ print('MyNetwork is being imported. I will wreck it!')
177
+ meta.module_src = meta.module_src.replace("True", "False")
178
+ return meta
179
+ """
180
+ assert callable(hook)
181
+ _import_hooks.append(hook)
182
+
183
+ #----------------------------------------------------------------------------
184
+
185
+ def _reconstruct_persistent_obj(meta):
186
+ r"""Hook that is called internally by the `pickle` module to unpickle
187
+ a persistent object.
188
+ """
189
+ meta = dnnlib.EasyDict(meta)
190
+ meta.state = dnnlib.EasyDict(meta.state)
191
+ for hook in _import_hooks:
192
+ meta = hook(meta)
193
+ assert meta is not None
194
+
195
+ assert meta.version == _version
196
+ module = _src_to_module(meta.module_src)
197
+
198
+ assert meta.type == 'class'
199
+ orig_class = module.__dict__[meta.class_name]
200
+ decorator_class = persistent_class(orig_class)
201
+ obj = decorator_class.__new__(decorator_class)
202
+
203
+ setstate = getattr(obj, '__setstate__', None)
204
+ if callable(setstate):
205
+ setstate(meta.state) # pylint: disable=not-callable
206
+ else:
207
+ obj.__dict__.update(meta.state)
208
+ return obj
209
+
210
+ #----------------------------------------------------------------------------
211
+
212
+ def _module_to_src(module):
213
+ r"""Query the source code of a given Python module.
214
+ """
215
+ src = _module_to_src_dict.get(module, None)
216
+ if src is None:
217
+ src = inspect.getsource(module)
218
+ _module_to_src_dict[module] = src
219
+ _src_to_module_dict[src] = module
220
+ return src
221
+
222
+ def _src_to_module(src):
223
+ r"""Get or create a Python module for the given source code.
224
+ """
225
+ module = _src_to_module_dict.get(src, None)
226
+ if module is None:
227
+ module_name = "_imported_module_" + uuid.uuid4().hex
228
+ module = types.ModuleType(module_name)
229
+ sys.modules[module_name] = module
230
+ _module_to_src_dict[module] = src
231
+ _src_to_module_dict[src] = module
232
+ exec(src, module.__dict__) # pylint: disable=exec-used
233
+ return module
234
+
235
+ #----------------------------------------------------------------------------
236
+
237
+ def _check_pickleable(obj):
238
+ r"""Check that the given object is pickleable, raising an exception if
239
+ it is not. This function is expected to be considerably more efficient
240
+ than actually pickling the object.
241
+ """
242
+ def recurse(obj):
243
+ if isinstance(obj, (list, tuple, set)):
244
+ return [recurse(x) for x in obj]
245
+ if isinstance(obj, dict):
246
+ return [[recurse(x), recurse(y)] for x, y in obj.items()]
247
+ if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
248
+ return None # Python primitive types are pickleable.
249
+ if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
250
+ return None # NumPy arrays and PyTorch tensors are pickleable.
251
+ if is_persistent(obj):
252
+ return None # Persistent objects are pickleable, by virtue of the constructor check.
253
+ return obj
254
+ with io.BytesIO() as f:
255
+ pickle.dump(recurse(obj), f)
256
+
257
+ #----------------------------------------------------------------------------
New/REG/preprocessing/torch_utils/training_stats.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+
8
+ """Facilities for reporting and collecting training statistics across
9
+ multiple processes and devices. The interface is designed to minimize
10
+ synchronization overhead as well as the amount of boilerplate in user
11
+ code."""
12
+
13
+ import re
14
+ import numpy as np
15
+ import torch
16
+ import dnnlib
17
+
18
+ from . import misc
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
23
+ _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
24
+ _counter_dtype = torch.float64 # Data type to use for the internal counters.
25
+ _rank = 0 # Rank of the current process.
26
+ _sync_device = None # Device to use for multiprocess communication. None = single-process.
27
+ _sync_called = False # Has _sync() been called yet?
28
+ _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
29
+ _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
30
+
31
+ #----------------------------------------------------------------------------
32
+
33
+ def init_multiprocessing(rank, sync_device):
34
+ r"""Initializes `torch_utils.training_stats` for collecting statistics
35
+ across multiple processes.
36
+
37
+ This function must be called after
38
+ `torch.distributed.init_process_group()` and before `Collector.update()`.
39
+ The call is not necessary if multi-process collection is not needed.
40
+
41
+ Args:
42
+ rank: Rank of the current process.
43
+ sync_device: PyTorch device to use for inter-process
44
+ communication, or None to disable multi-process
45
+ collection. Typically `torch.device('cuda', rank)`.
46
+ """
47
+ global _rank, _sync_device
48
+ assert not _sync_called
49
+ _rank = rank
50
+ _sync_device = sync_device
51
+
52
+ #----------------------------------------------------------------------------
53
+
54
+ @misc.profiled_function
55
+ def report(name, value):
56
+ r"""Broadcasts the given set of scalars to all interested instances of
57
+ `Collector`, across device and process boundaries. NaNs and Infs are
58
+ ignored.
59
+
60
+ This function is expected to be extremely cheap and can be safely
61
+ called from anywhere in the training loop, loss function, or inside a
62
+ `torch.nn.Module`.
63
+
64
+ Warning: The current implementation expects the set of unique names to
65
+ be consistent across processes. Please make sure that `report()` is
66
+ called at least once for each unique name by each process, and in the
67
+ same order. If a given process has no scalars to broadcast, it can do
68
+ `report(name, [])` (empty list).
69
+
70
+ Args:
71
+ name: Arbitrary string specifying the name of the statistic.
72
+ Averages are accumulated separately for each unique name.
73
+ value: Arbitrary set of scalars. Can be a list, tuple,
74
+ NumPy array, PyTorch tensor, or Python scalar.
75
+
76
+ Returns:
77
+ The same `value` that was passed in.
78
+ """
79
+ if name not in _counters:
80
+ _counters[name] = dict()
81
+
82
+ elems = torch.as_tensor(value)
83
+ if elems.numel() == 0:
84
+ return value
85
+
86
+ elems = elems.detach().flatten().to(_reduce_dtype)
87
+ square = elems.square()
88
+ finite = square.isfinite()
89
+ moments = torch.stack([
90
+ finite.sum(dtype=_reduce_dtype),
91
+ torch.where(finite, elems, 0).sum(),
92
+ torch.where(finite, square, 0).sum(),
93
+ ])
94
+ assert moments.ndim == 1 and moments.shape[0] == _num_moments
95
+ moments = moments.to(_counter_dtype)
96
+
97
+ device = moments.device
98
+ if device not in _counters[name]:
99
+ _counters[name][device] = torch.zeros_like(moments)
100
+ _counters[name][device].add_(moments)
101
+ return value
102
+
103
+ #----------------------------------------------------------------------------
104
+
105
+ def report0(name, value):
106
+ r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
107
+ but ignores any scalars provided by the other processes.
108
+ See `report()` for further details.
109
+ """
110
+ report(name, value if _rank == 0 else [])
111
+ return value
112
+
113
+ #----------------------------------------------------------------------------
114
+
115
+ class Collector:
116
+ r"""Collects the scalars broadcasted by `report()` and `report0()` and
117
+ computes their long-term averages (mean and standard deviation) over
118
+ user-defined periods of time.
119
+
120
+ The averages are first collected into internal counters that are not
121
+ directly visible to the user. They are then copied to the user-visible
122
+ state as a result of calling `update()` and can then be queried using
123
+ `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
124
+ internal counters for the next round, so that the user-visible state
125
+ effectively reflects averages collected between the last two calls to
126
+ `update()`.
127
+
128
+ Args:
129
+ regex: Regular expression defining which statistics to
130
+ collect. The default is to collect everything.
131
+ keep_previous: Whether to retain the previous averages if no
132
+ scalars were collected on a given round
133
+ (default: False).
134
+ """
135
+ def __init__(self, regex='.*', keep_previous=False):
136
+ self._regex = re.compile(regex)
137
+ self._keep_previous = keep_previous
138
+ self._cumulative = dict()
139
+ self._moments = dict()
140
+ self.update()
141
+ self._moments.clear()
142
+
143
+ def names(self):
144
+ r"""Returns the names of all statistics broadcasted so far that
145
+ match the regular expression specified at construction time.
146
+ """
147
+ return [name for name in _counters if self._regex.fullmatch(name)]
148
+
149
+ def update(self):
150
+ r"""Copies current values of the internal counters to the
151
+ user-visible state and resets them for the next round.
152
+
153
+ If `keep_previous=True` was specified at construction time, the
154
+ operation is skipped for statistics that have received no scalars
155
+ since the last update, retaining their previous averages.
156
+
157
+ This method performs a number of GPU-to-CPU transfers and one
158
+ `torch.distributed.all_reduce()`. It is intended to be called
159
+ periodically in the main training loop, typically once every
160
+ N training steps.
161
+ """
162
+ if not self._keep_previous:
163
+ self._moments.clear()
164
+ for name, cumulative in _sync(self.names()):
165
+ if name not in self._cumulative:
166
+ self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
167
+ delta = cumulative - self._cumulative[name]
168
+ self._cumulative[name].copy_(cumulative)
169
+ if float(delta[0]) != 0:
170
+ self._moments[name] = delta
171
+
172
+ def _get_delta(self, name):
173
+ r"""Returns the raw moments that were accumulated for the given
174
+ statistic between the last two calls to `update()`, or zero if
175
+ no scalars were collected.
176
+ """
177
+ assert self._regex.fullmatch(name)
178
+ if name not in self._moments:
179
+ self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
180
+ return self._moments[name]
181
+
182
+ def num(self, name):
183
+ r"""Returns the number of scalars that were accumulated for the given
184
+ statistic between the last two calls to `update()`, or zero if
185
+ no scalars were collected.
186
+ """
187
+ delta = self._get_delta(name)
188
+ return int(delta[0])
189
+
190
+ def mean(self, name):
191
+ r"""Returns the mean of the scalars that were accumulated for the
192
+ given statistic between the last two calls to `update()`, or NaN if
193
+ no scalars were collected.
194
+ """
195
+ delta = self._get_delta(name)
196
+ if int(delta[0]) == 0:
197
+ return float('nan')
198
+ return float(delta[1] / delta[0])
199
+
200
+ def std(self, name):
201
+ r"""Returns the standard deviation of the scalars that were
202
+ accumulated for the given statistic between the last two calls to
203
+ `update()`, or NaN if no scalars were collected.
204
+ """
205
+ delta = self._get_delta(name)
206
+ if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
207
+ return float('nan')
208
+ if int(delta[0]) == 1:
209
+ return float(0)
210
+ mean = float(delta[1] / delta[0])
211
+ raw_var = float(delta[2] / delta[0])
212
+ return np.sqrt(max(raw_var - np.square(mean), 0))
213
+
214
+ def as_dict(self):
215
+ r"""Returns the averages accumulated between the last two calls to
216
+ `update()` as an `dnnlib.EasyDict`. The contents are as follows:
217
+
218
+ dnnlib.EasyDict(
219
+ NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
220
+ ...
221
+ )
222
+ """
223
+ stats = dnnlib.EasyDict()
224
+ for name in self.names():
225
+ stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
226
+ return stats
227
+
228
+ def __getitem__(self, name):
229
+ r"""Convenience getter.
230
+ `collector[name]` is a synonym for `collector.mean(name)`.
231
+ """
232
+ return self.mean(name)
233
+
234
+ #----------------------------------------------------------------------------
235
+
236
+ def _sync(names):
237
+ r"""Synchronize the global cumulative counters across devices and
238
+ processes. Called internally by `Collector.update()`.
239
+ """
240
+ if len(names) == 0:
241
+ return []
242
+ global _sync_called
243
+ _sync_called = True
244
+
245
+ # Check that all ranks have the same set of names.
246
+ if _sync_device is not None:
247
+ value = hash(tuple(tuple(ord(char) for char in name) for name in names))
248
+ other = torch.as_tensor(value, dtype=torch.int64, device=_sync_device)
249
+ torch.distributed.broadcast(tensor=other, src=0)
250
+ if value != int(other.cpu()):
251
+ raise ValueError('Training statistics are inconsistent between ranks')
252
+
253
+ # Collect deltas within current rank.
254
+ deltas = []
255
+ device = _sync_device if _sync_device is not None else torch.device('cpu')
256
+ for name in names:
257
+ delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
258
+ for counter in _counters[name].values():
259
+ delta.add_(counter.to(device))
260
+ counter.copy_(torch.zeros_like(counter))
261
+ deltas.append(delta)
262
+ deltas = torch.stack(deltas)
263
+
264
+ # Sum deltas across ranks.
265
+ if _sync_device is not None:
266
+ torch.distributed.all_reduce(deltas)
267
+
268
+ # Update cumulative values.
269
+ deltas = deltas.cpu()
270
+ for idx, name in enumerate(names):
271
+ if name not in _cumulative:
272
+ _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
273
+ _cumulative[name].add_(deltas[idx])
274
+
275
+ # Return name-value pairs.
276
+ return [(name, _cumulative[name]) for name in names]
277
+
278
+ #----------------------------------------------------------------------------
279
+ # Convenience.
280
+
281
+ default_collector = Collector()
282
+
283
+ #----------------------------------------------------------------------------
New/REG/samplers.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ def expand_t_like_x(t, x_cur):
6
+ """Function to reshape time t to broadcastable dimension of x
7
+ Args:
8
+ t: [batch_dim,], time vector
9
+ x: [batch_dim,...], data point
10
+ """
11
+ dims = [1] * (len(x_cur.size()) - 1)
12
+ t = t.view(t.size(0), *dims)
13
+ return t
14
+
15
+ def get_score_from_velocity(vt, xt, t, path_type="linear"):
16
+ """Wrapper function: transfrom velocity prediction model to score
17
+ Args:
18
+ velocity: [batch_dim, ...] shaped tensor; velocity model output
19
+ x: [batch_dim, ...] shaped tensor; x_t data point
20
+ t: [batch_dim,] time tensor
21
+ """
22
+ t = expand_t_like_x(t, xt)
23
+ if path_type == "linear":
24
+ alpha_t, d_alpha_t = 1 - t, torch.ones_like(xt, device=xt.device) * -1
25
+ sigma_t, d_sigma_t = t, torch.ones_like(xt, device=xt.device)
26
+ elif path_type == "cosine":
27
+ alpha_t = torch.cos(t * np.pi / 2)
28
+ sigma_t = torch.sin(t * np.pi / 2)
29
+ d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2)
30
+ d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2)
31
+ else:
32
+ raise NotImplementedError
33
+
34
+ mean = xt
35
+ reverse_alpha_ratio = alpha_t / d_alpha_t
36
+ var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
37
+ score = (reverse_alpha_ratio * vt - mean) / var
38
+
39
+ return score
40
+
41
+
42
+ def compute_diffusion(t_cur):
43
+ return 2 * t_cur
44
+
45
+
46
+ def euler_maruyama_sampler(
47
+ model,
48
+ latents,
49
+ y,
50
+ num_steps=20,
51
+ heun=False, # not used, just for compatability
52
+ cfg_scale=1.0,
53
+ guidance_low=0.0,
54
+ guidance_high=1.0,
55
+ path_type="linear",
56
+ cls_latents=None,
57
+ args=None,
58
+ return_mid_state=False,
59
+ t_mid=0.5,
60
+ ):
61
+ # setup conditioning
62
+ if cfg_scale > 1.0:
63
+ y_null = torch.tensor([1000] * y.size(0), device=y.device)
64
+ #[1000, 1000]
65
+ _dtype = latents.dtype
66
+
67
+
68
+ t_steps = torch.linspace(1., 0.04, num_steps, dtype=torch.float64)
69
+ t_steps = torch.cat([t_steps, torch.tensor([0.], dtype=torch.float64)])
70
+ x_next = latents.to(torch.float64)
71
+ cls_x_next = cls_latents.to(torch.float64)
72
+ device = x_next.device
73
+ z_mid, cls_mid = None, None
74
+ t_mid = float(t_mid)
75
+
76
+
77
+ with torch.no_grad():
78
+ for i, (t_cur, t_next) in enumerate(zip(t_steps[:-2], t_steps[1:-1])):
79
+ dt = t_next - t_cur
80
+ x_cur = x_next
81
+ cls_x_cur = cls_x_next
82
+ tc, tn = float(t_cur), float(t_next)
83
+ if return_mid_state and z_mid is None and tn <= t_mid <= tc:
84
+ if abs(tc - t_mid) < abs(tn - t_mid):
85
+ z_mid = x_cur.clone()
86
+ cls_mid = cls_x_cur.clone()
87
+
88
+ if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
89
+ model_input = torch.cat([x_cur] * 2, dim=0)
90
+ cls_model_input = torch.cat([cls_x_cur] * 2, dim=0)
91
+ y_cur = torch.cat([y, y_null], dim=0)
92
+ else:
93
+ model_input = x_cur
94
+ cls_model_input = cls_x_cur
95
+ y_cur = y
96
+
97
+ kwargs = dict(y=y_cur)
98
+ time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
99
+ diffusion = compute_diffusion(t_cur)
100
+
101
+ eps_i = torch.randn_like(x_cur).to(device)
102
+ cls_eps_i = torch.randn_like(cls_x_cur).to(device)
103
+ deps = eps_i * torch.sqrt(torch.abs(dt))
104
+ cls_deps = cls_eps_i * torch.sqrt(torch.abs(dt))
105
+
106
+ # compute drift
107
+ v_cur, _, cls_v_cur = model(
108
+ model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
109
+ )
110
+ v_cur = v_cur.to(torch.float64)
111
+ cls_v_cur = cls_v_cur.to(torch.float64)
112
+
113
+ s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
114
+ d_cur = v_cur - 0.5 * diffusion * s_cur
115
+
116
+ cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
117
+ cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur
118
+
119
+ if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
120
+ d_cur_cond, d_cur_uncond = d_cur.chunk(2)
121
+ d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)
122
+
123
+ cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
124
+ if args.cls_cfg_scale >0:
125
+ cls_d_cur = cls_d_cur_uncond + args.cls_cfg_scale * (cls_d_cur_cond - cls_d_cur_uncond)
126
+ else:
127
+ cls_d_cur = cls_d_cur_cond
128
+ x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
129
+ cls_x_next = cls_x_cur + cls_d_cur * dt + torch.sqrt(diffusion) * cls_deps
130
+ if return_mid_state and z_mid is None and tn <= t_mid <= tc:
131
+ z_mid = x_next.clone()
132
+ cls_mid = cls_x_next.clone()
133
+
134
+ # last step
135
+ t_cur, t_next = t_steps[-2], t_steps[-1]
136
+ dt = t_next - t_cur
137
+ x_cur = x_next
138
+ cls_x_cur = cls_x_next
139
+
140
+ if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
141
+ model_input = torch.cat([x_cur] * 2, dim=0)
142
+ cls_model_input = torch.cat([cls_x_cur] * 2, dim=0)
143
+ y_cur = torch.cat([y, y_null], dim=0)
144
+ else:
145
+ model_input = x_cur
146
+ cls_model_input = cls_x_cur
147
+ y_cur = y
148
+ kwargs = dict(y=y_cur)
149
+ time_input = torch.ones(model_input.size(0)).to(
150
+ device=device, dtype=torch.float64
151
+ ) * t_cur
152
+
153
+ # compute drift
154
+ v_cur, _, cls_v_cur = model(
155
+ model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
156
+ )
157
+ v_cur = v_cur.to(torch.float64)
158
+ cls_v_cur = cls_v_cur.to(torch.float64)
159
+
160
+
161
+ s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
162
+ cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
163
+
164
+ diffusion = compute_diffusion(t_cur)
165
+ d_cur = v_cur - 0.5 * diffusion * s_cur
166
+ cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur # d_cur [b, 4, 32 ,32]
167
+
168
+ if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
169
+ d_cur_cond, d_cur_uncond = d_cur.chunk(2)
170
+ d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)
171
+
172
+ cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
173
+ if args.cls_cfg_scale > 0:
174
+ cls_d_cur = cls_d_cur_uncond + args.cls_cfg_scale * (cls_d_cur_cond - cls_d_cur_uncond)
175
+ else:
176
+ cls_d_cur = cls_d_cur_cond
177
+
178
+ mean_x = x_cur + dt * d_cur
179
+ cls_mean_x = cls_x_cur + dt * cls_d_cur
180
+
181
+ if return_mid_state:
182
+ return mean_x, z_mid, cls_mean_x if cls_mid is None else cls_mid
183
+ return mean_x
184
+
185
+
186
+ def euler_sampler(
187
+ model,
188
+ latents,
189
+ y,
190
+ num_steps=20,
191
+ heun=False, # not used; only for compatibility with caller
192
+ cfg_scale=1.0,
193
+ guidance_low=0.0,
194
+ guidance_high=1.0,
195
+ path_type="linear", # not used for ODE (velocity parameterization directly)
196
+ cls_latents=None,
197
+ args=None
198
+ ):
199
+ """
200
+ REG 的 ODE 采样器:确定性(不注入扩散噪声)。
201
+
202
+ 这里按照 REG/SiT 的 velocity 参数化直接做 ODE:
203
+ d/dt x_t = v_t
204
+ 因此不需要把 velocity 再转成 score 再转 drift。
205
+ """
206
+ # setup conditioning
207
+ if cfg_scale > 1.0:
208
+ y_null = torch.tensor([1000] * y.size(0), device=y.device)
209
+ _dtype = latents.dtype
210
+
211
+ cls_cfg_scale = getattr(args, "cls_cfg_scale", 0) if args is not None else 0
212
+
213
+ # ODE 时间网格:默认从 t=1 到 t=0
214
+ t_steps = torch.linspace(1.0, 0.0, int(num_steps) + 1, dtype=torch.float64, device=latents.device)
215
+
216
+ x_next = latents.to(torch.float64)
217
+ cls_x_next = cls_latents.to(torch.float64)
218
+ device = x_next.device
219
+
220
+ with torch.no_grad():
221
+ for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):
222
+ dt = t_next - t_cur
223
+ x_cur = x_next
224
+ cls_x_cur = cls_x_next
225
+
226
+ # classifier-free guidance(只在指定时间段启用)
227
+ if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
228
+ model_input = torch.cat([x_cur] * 2, dim=0)
229
+ cls_model_input = torch.cat([cls_x_cur] * 2, dim=0)
230
+ y_cur = torch.cat([y, y_null], dim=0)
231
+ else:
232
+ model_input = x_cur
233
+ cls_model_input = cls_x_cur
234
+ y_cur = y
235
+
236
+ kwargs = dict(y=y_cur)
237
+ time_input = torch.ones(model_input.size(0), device=device, dtype=torch.float64) * t_cur
238
+
239
+ v_cur, _, cls_v_cur = model(
240
+ model_input.to(dtype=_dtype),
241
+ time_input.to(dtype=_dtype),
242
+ **kwargs,
243
+ cls_token=cls_model_input.to(dtype=_dtype),
244
+ )
245
+
246
+ # ODE:velocity 参数化直接作为导数
247
+ d_cur = v_cur.to(torch.float64)
248
+ cls_d_cur = cls_v_cur.to(torch.float64)
249
+
250
+ # 指定时间段内进行 guidance 合成
251
+ if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
252
+ d_cur_cond, d_cur_uncond = d_cur.chunk(2)
253
+ d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)
254
+
255
+ cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
256
+ if cls_cfg_scale > 0:
257
+ cls_d_cur = cls_d_cur_uncond + cls_cfg_scale * (cls_d_cur_cond - cls_d_cur_uncond)
258
+ else:
259
+ cls_d_cur = cls_d_cur_cond
260
+
261
+ x_next = x_cur + dt * d_cur
262
+ cls_x_next = cls_x_cur + dt * cls_d_cur
263
+
264
+ return x_next
New/REG/train.py ADDED
@@ -0,0 +1,573 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import copy
3
+ from copy import deepcopy
4
+ import logging
5
+ import os
6
+ from pathlib import Path
7
+ from collections import OrderedDict
8
+ import json
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint
14
+ from tqdm.auto import tqdm
15
+ from torch.utils.data import DataLoader
16
+
17
+ from accelerate import Accelerator, DistributedDataParallelKwargs
18
+ from accelerate.logging import get_logger
19
+ from accelerate.utils import ProjectConfiguration, set_seed
20
+
21
+ from models.sit import SiT_models
22
+ from loss import SILoss
23
+ from utils import load_encoders
24
+
25
+ from dataset import CustomDataset
26
+ from diffusers.models import AutoencoderKL
27
+ from PIL import Image
28
+
29
+ from samplers import euler_maruyama_sampler
30
+ # import wandb_utils
31
+ import wandb
32
+ import math
33
+ from torchvision.utils import make_grid
34
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
35
+ from torchvision.transforms import Normalize
36
+
37
+ logger = get_logger(__name__)
38
+
39
+ CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
40
+ CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)
41
+
42
+
43
+
44
+ def preprocess_raw_image(x, enc_type):
45
+ resolution = x.shape[-1]
46
+ if 'clip' in enc_type:
47
+ x = x / 255.
48
+ x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
49
+ x = Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(x)
50
+ elif 'mocov3' in enc_type or 'mae' in enc_type:
51
+ x = x / 255.
52
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
53
+ elif 'dinov2' in enc_type:
54
+ x = x / 255.
55
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
56
+ x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
57
+ elif 'dinov1' in enc_type:
58
+ x = x / 255.
59
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
60
+ elif 'jepa' in enc_type:
61
+ x = x / 255.
62
+ x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
63
+ x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
64
+
65
+ return x
66
+
67
+
68
+ def array2grid(x):
69
+ nrow = round(math.sqrt(x.size(0)))
70
+ x = make_grid(x.clamp(0, 1), nrow=nrow, value_range=(0, 1))
71
+ x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
72
+ return x
73
+
74
+
75
+ @torch.no_grad()
76
+ def sample_posterior(moments, latents_scale=1., latents_bias=0.):
77
+ device = moments.device
78
+
79
+ mean, std = torch.chunk(moments, 2, dim=1)
80
+ z = mean + std * torch.randn_like(mean)
81
+ z = (z * latents_scale + latents_bias)
82
+ return z
83
+
84
+
85
+ @torch.no_grad()
86
+ def update_ema(ema_model, model, decay=0.9999):
87
+ """
88
+ Step the EMA model towards the current model.
89
+ """
90
+ ema_params = OrderedDict(ema_model.named_parameters())
91
+ model_params = OrderedDict(model.named_parameters())
92
+
93
+ for name, param in model_params.items():
94
+ name = name.replace("module.", "")
95
+ # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
96
+ ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
97
+
98
+
99
+ def create_logger(logging_dir):
100
+ """
101
+ Create a logger that writes to a log file and stdout.
102
+ """
103
+ logging.basicConfig(
104
+ level=logging.INFO,
105
+ format='[\033[34m%(asctime)s\033[0m] %(message)s',
106
+ datefmt='%Y-%m-%d %H:%M:%S',
107
+ handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
108
+ )
109
+ logger = logging.getLogger(__name__)
110
+ return logger
111
+
112
+
113
+ def requires_grad(model, flag=True):
114
+ """
115
+ Set requires_grad flag for all parameters in a model.
116
+ """
117
+ for p in model.parameters():
118
+ p.requires_grad = flag
119
+
120
+
121
+ #################################################################################
122
+ # Training Loop #
123
+ #################################################################################
124
+
125
+ def main(args):
126
+ # set accelerator
127
+ logging_dir = Path(args.output_dir, args.logging_dir)
128
+ accelerator_project_config = ProjectConfiguration(
129
+ project_dir=args.output_dir, logging_dir=logging_dir
130
+ )
131
+
132
+ accelerator = Accelerator(
133
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
134
+ mixed_precision=args.mixed_precision,
135
+ log_with=args.report_to,
136
+ project_config=accelerator_project_config,
137
+ kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)]
138
+ )
139
+
140
+ if accelerator.is_main_process:
141
+ os.makedirs(args.output_dir, exist_ok=True) # Make results folder (holds all experiment subfolders)
142
+ save_dir = os.path.join(args.output_dir, args.exp_name)
143
+ os.makedirs(save_dir, exist_ok=True)
144
+ args_dict = vars(args)
145
+ # Save to a JSON file
146
+ json_dir = os.path.join(save_dir, "args.json")
147
+ with open(json_dir, 'w') as f:
148
+ json.dump(args_dict, f, indent=4)
149
+ checkpoint_dir = f"{save_dir}/checkpoints" # Stores saved model checkpoints
150
+ os.makedirs(checkpoint_dir, exist_ok=True)
151
+ logger = create_logger(save_dir)
152
+ logger.info(f"Experiment directory created at {save_dir}")
153
+ device = accelerator.device
154
+ if torch.backends.mps.is_available():
155
+ accelerator.native_amp = False
156
+ if args.seed is not None:
157
+ set_seed(args.seed + accelerator.process_index)
158
+
159
+ # Create model:
160
+ assert args.resolution % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
161
+ latent_size = args.resolution // 8
162
+
163
+ if args.enc_type != None:
164
+ encoders, encoder_types, architectures = load_encoders(
165
+ args.enc_type, device, args.resolution
166
+ )
167
+ else:
168
+ raise NotImplementedError()
169
+ z_dims = [encoder.embed_dim for encoder in encoders] if args.enc_type != 'None' else [0]
170
+ block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
171
+ model = SiT_models[args.model](
172
+ input_size=latent_size,
173
+ num_classes=args.num_classes,
174
+ use_cfg = (args.cfg_prob > 0),
175
+ z_dims = z_dims,
176
+ encoder_depth=args.encoder_depth,
177
+ **block_kwargs
178
+ )
179
+
180
+ model = model.to(device)
181
+ ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
182
+ requires_grad(ema, False)
183
+
184
+ latents_scale = torch.tensor(
185
+ [0.18215, 0.18215, 0.18215, 0.18215]
186
+ ).view(1, 4, 1, 1).to(device)
187
+ latents_bias = torch.tensor(
188
+ [0., 0., 0., 0.]
189
+ ).view(1, 4, 1, 1).to(device)
190
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
191
+ vae.eval()
192
+
193
+ # create loss function
194
+ loss_fn = SILoss(
195
+ prediction=args.prediction,
196
+ path_type=args.path_type,
197
+ encoders=encoders,
198
+ accelerator=accelerator,
199
+ latents_scale=latents_scale,
200
+ latents_bias=latents_bias,
201
+ weighting=args.weighting
202
+ )
203
+ if accelerator.is_main_process:
204
+ logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
205
+
206
+ # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
207
+ if args.allow_tf32:
208
+ torch.backends.cuda.matmul.allow_tf32 = True
209
+ torch.backends.cudnn.allow_tf32 = True
210
+
211
+ optimizer = torch.optim.AdamW(
212
+ model.parameters(),
213
+ lr=args.learning_rate,
214
+ betas=(args.adam_beta1, args.adam_beta2),
215
+ weight_decay=args.adam_weight_decay,
216
+ eps=args.adam_epsilon,
217
+ )
218
+
219
+ # Setup data:
220
+ train_dataset = CustomDataset(
221
+ args.data_dir, semantic_features_dir=args.semantic_features_dir
222
+ )
223
+ local_batch_size = int(args.batch_size // accelerator.num_processes)
224
+ train_dataloader = DataLoader(
225
+ train_dataset,
226
+ batch_size=local_batch_size,
227
+ shuffle=True,
228
+ num_workers=args.num_workers,
229
+ pin_memory=True,
230
+ drop_last=True
231
+ )
232
+ if accelerator.is_main_process:
233
+ logger.info(f"Dataset contains {len(train_dataset):,} images ({args.data_dir})")
234
+
235
+ # Prepare models for training:
236
+ update_ema(ema, model, decay=0) # Ensure EMA is initialized with synced weights
237
+ model.train() # important! This enables embedding dropout for classifier-free guidance
238
+ ema.eval() # EMA model should always be in eval mode
239
+
240
+ # resume:
241
+ global_step = 0
242
+ if args.resume_step > 0:
243
+ ckpt_name = str(args.resume_step).zfill(7) +'.pt'
244
+ ckpt = torch.load(
245
+ f'{os.path.join(args.output_dir, args.exp_name)}/checkpoints/{ckpt_name}',
246
+ map_location='cpu',
247
+ )
248
+ model.load_state_dict(ckpt['model'])
249
+ ema.load_state_dict(ckpt['ema'])
250
+ optimizer.load_state_dict(ckpt['opt'])
251
+ global_step = ckpt['steps']
252
+
253
+ model, optimizer, train_dataloader = accelerator.prepare(
254
+ model, optimizer, train_dataloader
255
+ )
256
+
257
+ if accelerator.is_main_process:
258
+ tracker_config = vars(copy.deepcopy(args))
259
+ accelerator.init_trackers(
260
+ project_name="REG",
261
+ config=tracker_config,
262
+ init_kwargs={
263
+ "wandb": {"name": f"{args.exp_name}"}
264
+ },
265
+ )
266
+
267
+
268
+ progress_bar = tqdm(
269
+ range(0, args.max_train_steps),
270
+ initial=global_step,
271
+ desc="Steps",
272
+ # Only show the progress bar once on each machine.
273
+ disable=not accelerator.is_local_main_process,
274
+ )
275
+
276
+ # Labels to condition the model with (feel free to change):
277
+ sample_batch_size = 64 // accelerator.num_processes
278
+ first_batch = next(iter(train_dataloader))
279
+ preprocessed_semantic = len(first_batch) == 4
280
+ if preprocessed_semantic:
281
+ gt_raw_images, gt_xs, _r_pre, _y = first_batch
282
+ else:
283
+ gt_raw_images, gt_xs, _y = first_batch
284
+ # 仅在“非预处理 semantic 模式”下,raw_image 是 RGB 图(分辨率应与 args.resolution 对齐)。
285
+ if not preprocessed_semantic:
286
+ assert gt_raw_images.shape[-1] == args.resolution
287
+ gt_xs = gt_xs[:sample_batch_size]
288
+ gt_xs = sample_posterior(
289
+ gt_xs.to(device), latents_scale=latents_scale, latents_bias=latents_bias
290
+ )
291
+ ys = torch.randint(1000, size=(sample_batch_size,), device=device)
292
+ ys = ys.to(device)
293
+ # Create sampling noise:
294
+ n = ys.size(0)
295
+ xT = torch.randn((n, 4, latent_size, latent_size), device=device)
296
+
297
+ for epoch in range(args.epochs):
298
+ model.train()
299
+ for batch in train_dataloader:
300
+ if len(batch) == 4:
301
+ raw_image, x, r_preprocessed, y = batch
302
+ r_preprocessed = r_preprocessed.to(device).float()
303
+ else:
304
+ raw_image, x, y = batch
305
+ r_preprocessed = None
306
+
307
+ raw_image = raw_image.to(device)
308
+ x = x.squeeze(dim=1).to(device)
309
+ y = y.to(device)
310
+
311
+ z = None
312
+ if args.legacy:
313
+ # In our early experiments, we accidentally apply label dropping twice:
314
+ # once in train.py and once in sit.py.
315
+ # We keep this option for exact reproducibility with previous runs.
316
+ drop_ids = torch.rand(y.shape[0], device=y.device) < args.cfg_prob
317
+ labels = torch.where(drop_ids, args.num_classes, y)
318
+ else:
319
+ labels = y
320
+ with torch.no_grad():
321
+ x = sample_posterior(x, latents_scale=latents_scale, latents_bias=latents_bias)
322
+ zs = []
323
+ if r_preprocessed is not None:
324
+ # 预处理 semantic 模式:直接用 cls token 构造 dense tokens。
325
+ cls_token = r_preprocessed
326
+ while cls_token.dim() > 2:
327
+ cls_token = cls_token.squeeze(1)
328
+
329
+ base_m = model.module if hasattr(model, "module") else model
330
+ n_pad = base_m.x_embedder.num_patches
331
+ zs = [
332
+ torch.cat(
333
+ [
334
+ cls_token.unsqueeze(1),
335
+ cls_token.unsqueeze(1).expand(-1, n_pad, -1),
336
+ ],
337
+ dim=1,
338
+ )
339
+ ]
340
+ else:
341
+ # 在线 encoder 模式:与原 New/REG 行为一致
342
+ with accelerator.autocast():
343
+ for encoder, encoder_type, arch in zip(
344
+ encoders, encoder_types, architectures
345
+ ):
346
+ raw_image_ = preprocess_raw_image(raw_image, encoder_type)
347
+ z = encoder.forward_features(raw_image_)
348
+ if "dinov2" in encoder_type:
349
+ dense_z = z["x_norm_patchtokens"]
350
+ cls_token = z["x_norm_clstoken"]
351
+ dense_z = torch.cat(
352
+ [cls_token.unsqueeze(1), dense_z], dim=1
353
+ )
354
+ else:
355
+ exit()
356
+ zs.append(dense_z)
357
+
358
+ with accelerator.accumulate(model):
359
+ model_kwargs = dict(y=labels)
360
+ loss1, proj_loss1, time_input, noises, loss2 = loss_fn(model, x, model_kwargs, zs=zs,
361
+ cls_token=cls_token,
362
+ time_input=None, noises=None)
363
+ loss_mean = loss1.mean()
364
+ loss_mean_cls = loss2.mean() * args.cls
365
+ proj_loss_mean = proj_loss1.mean() * args.proj_coeff
366
+ loss = loss_mean + proj_loss_mean + loss_mean_cls
367
+
368
+
369
+ ## optimization
370
+ accelerator.backward(loss)
371
+ if accelerator.sync_gradients:
372
+ params_to_clip = model.parameters()
373
+ grad_norm = accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
374
+ optimizer.step()
375
+ optimizer.zero_grad(set_to_none=True)
376
+
377
+ if accelerator.sync_gradients:
378
+ update_ema(ema, model) # change ema function
379
+
380
+ ### enter
381
+ if accelerator.sync_gradients:
382
+ progress_bar.update(1)
383
+ global_step += 1
384
+ if global_step % args.checkpointing_steps == 0 and global_step > 0:
385
+ if accelerator.is_main_process:
386
+ checkpoint = {
387
+ "model": model.module.state_dict(),
388
+ "ema": ema.state_dict(),
389
+ "opt": optimizer.state_dict(),
390
+ "args": args,
391
+ "steps": global_step,
392
+ }
393
+ checkpoint_path = f"{checkpoint_dir}/{global_step:07d}.pt"
394
+ torch.save(checkpoint, checkpoint_path)
395
+ logger.info(f"Saved checkpoint to {checkpoint_path}")
396
+
397
+ if (global_step == 1 or (global_step % args.sampling_steps == 0 and global_step > 0)):
398
+ t_mid_vis = float(args.t_c)
399
+ tc_tag = f"{t_mid_vis:.4f}".rstrip("0").rstrip(".").replace(".", "_")
400
+ logging.info(
401
+ f"Generating EMA samples (Euler-Maruyama; t≈{t_mid_vis:g} -> t=0)..."
402
+ )
403
+ ema.eval()
404
+ with torch.no_grad():
405
+ latent_size = args.resolution // 8
406
+ n_samples = min(16, args.batch_size)
407
+ base_model = model.module if hasattr(model, "module") else model
408
+ cls_dim = base_model.z_dims[0]
409
+ shared_seed = torch.randint(0, 2**32, (1,), device=device).item()
410
+ torch.manual_seed(shared_seed)
411
+ z_init = torch.randn(n_samples, base_model.in_channels, latent_size, latent_size, device=device)
412
+ torch.manual_seed(shared_seed)
413
+ cls_init = torch.randn(n_samples, cls_dim, device=device)
414
+ y_samples = torch.randint(0, args.num_classes, (n_samples,), device=device)
415
+
416
+ z_0, z_mid, _ = euler_maruyama_sampler(
417
+ ema,
418
+ z_init,
419
+ y_samples,
420
+ num_steps=50,
421
+ cfg_scale=1.0,
422
+ guidance_low=0.0,
423
+ guidance_high=1.0,
424
+ path_type=args.path_type,
425
+ cls_latents=cls_init,
426
+ args=args,
427
+ return_mid_state=True,
428
+ t_mid=t_mid_vis,
429
+ )
430
+
431
+ samples_root = os.path.join(args.output_dir, args.exp_name, "samples")
432
+ t0_dir = os.path.join(samples_root, "t0")
433
+ t_mid_dir = os.path.join(samples_root, f"t0_{tc_tag}")
434
+ os.makedirs(t0_dir, exist_ok=True)
435
+ os.makedirs(t_mid_dir, exist_ok=True)
436
+
437
+ z_f = z_0.to(dtype=torch.float32)
438
+ samples_final = vae.decode((z_f - latents_bias) / latents_scale).sample
439
+ samples_final = (samples_final + 1) / 2.0
440
+ samples_final = samples_final.clamp(0, 1)
441
+ grid_final = array2grid(samples_final)
442
+ Image.fromarray(grid_final).save(
443
+ os.path.join(t0_dir, f"step_{global_step:07d}_t0.png")
444
+ )
445
+
446
+ if z_mid is not None:
447
+ z_m = z_mid.to(dtype=torch.float32)
448
+ samples_mid = vae.decode((z_m - latents_bias) / latents_scale).sample
449
+ samples_mid = (samples_mid + 1) / 2.0
450
+ samples_mid = samples_mid.clamp(0, 1)
451
+ grid_mid = array2grid(samples_mid)
452
+ Image.fromarray(grid_mid).save(
453
+ os.path.join(t_mid_dir, f"step_{global_step:07d}_t0_{tc_tag}.png")
454
+ )
455
+ else:
456
+ logging.warning(
457
+ f"Sampling time grid did not bracket t_mid={t_mid_vis:g}; "
458
+ f"skip t0_{tc_tag} image this step."
459
+ )
460
+
461
+ del z_init, cls_init, y_samples, z_0
462
+ if z_mid is not None:
463
+ del z_mid
464
+ del samples_final, grid_final
465
+ if "samples_mid" in locals():
466
+ del samples_mid, grid_mid
467
+ torch.cuda.empty_cache()
468
+
469
+ logs = {
470
+ "loss_final": accelerator.gather(loss).mean().detach().item(),
471
+ "loss_mean": accelerator.gather(loss_mean).mean().detach().item(),
472
+ "proj_loss": accelerator.gather(proj_loss_mean).mean().detach().item(),
473
+ "loss_mean_cls": accelerator.gather(loss_mean_cls).mean().detach().item(),
474
+ "grad_norm": accelerator.gather(grad_norm).mean().detach().item()
475
+ }
476
+
477
+ log_message = ", ".join(f"{key}: {value:.6f}" for key, value in logs.items())
478
+ logging.info(f"Step: {global_step}, Training Logs: {log_message}")
479
+
480
+ progress_bar.set_postfix(**logs)
481
+ accelerator.log(logs, step=global_step)
482
+
483
+ if global_step >= args.max_train_steps:
484
+ break
485
+ if global_step >= args.max_train_steps:
486
+ break
487
+
488
+ model.eval() # important! This disables randomized embedding dropout
489
+ # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
490
+
491
+ accelerator.wait_for_everyone()
492
+ if accelerator.is_main_process:
493
+ logger.info("Done!")
494
+ accelerator.end_training()
495
+
496
+ def parse_args(input_args=None):
497
+ parser = argparse.ArgumentParser(description="Training")
498
+
499
+ # logging:
500
+ parser.add_argument("--output-dir", type=str, default="exps")
501
+ parser.add_argument("--exp-name", type=str, required=True)
502
+ parser.add_argument("--logging-dir", type=str, default="logs")
503
+ parser.add_argument("--report-to", type=str, default="wandb")
504
+ parser.add_argument("--sampling-steps", type=int, default=10000)
505
+ parser.add_argument("--resume-step", type=int, default=0)
506
+
507
+ # model
508
+ parser.add_argument("--model", type=str)
509
+ parser.add_argument("--num-classes", type=int, default=1000)
510
+ parser.add_argument("--encoder-depth", type=int, default=8)
511
+ parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=True)
512
+ parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False)
513
+ parser.add_argument("--ops-head", type=int, default=16)
514
+
515
+ # dataset
516
+ parser.add_argument("--data-dir", type=str, default="../data/imagenet256")
517
+ parser.add_argument(
518
+ "--semantic-features-dir",
519
+ type=str,
520
+ default=None,
521
+ help="预处理 semantic features 目录(与 REG/dataset.py 语义相同),仅影响数据加载方式,不引入 t_c。",
522
+ )
523
+ parser.add_argument("--resolution", type=int, choices=[256, 512], default=256)
524
+ parser.add_argument("--batch-size", type=int, default=8)#256
525
+
526
+ # precision
527
+ parser.add_argument("--allow-tf32", action="store_true")
528
+ parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
529
+
530
+ # optimization
531
+ parser.add_argument("--epochs", type=int, default=1400)
532
+ parser.add_argument("--max-train-steps", type=int, default=1000000)
533
+ parser.add_argument("--checkpointing-steps", type=int, default=10000)
534
+ parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
535
+ parser.add_argument("--learning-rate", type=float, default=1e-4)
536
+ parser.add_argument("--adam-beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
537
+ parser.add_argument("--adam-beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
538
+ parser.add_argument("--adam-weight-decay", type=float, default=0., help="Weight decay to use.")
539
+ parser.add_argument("--adam-epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
540
+ parser.add_argument("--max-grad-norm", default=1.0, type=float, help="Max gradient norm.")
541
+
542
+ # seed
543
+ parser.add_argument("--seed", type=int, default=0)
544
+
545
+ # cpu
546
+ parser.add_argument("--num-workers", type=int, default=4)
547
+
548
+ # loss
549
+ parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
550
+ parser.add_argument("--prediction", type=str, default="v", choices=["v"]) # currently we only support v-prediction
551
+ parser.add_argument("--cfg-prob", type=float, default=0.1)
552
+ parser.add_argument("--enc-type", type=str, default='dinov2-vit-b')
553
+ parser.add_argument("--proj-coeff", type=float, default=0.5)
554
+ parser.add_argument("--weighting", default="uniform", type=str, help="Max gradient norm.")
555
+ parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
556
+ parser.add_argument("--cls", type=float, default=0.03)
557
+ parser.add_argument(
558
+ "--t-c",
559
+ type=float,
560
+ default=0.5,
561
+ help="训练中采样时保存的中间时刻 t(用于输出 t0 与 t0_tc 对比图)。",
562
+ )
563
+ if input_args is not None:
564
+ args = parser.parse_args(input_args)
565
+ else:
566
+ args = parser.parse_args()
567
+
568
+ return args
569
+
570
+ if __name__ == "__main__":
571
+ args = parse_args()
572
+
573
+ main(args)
New/REG/train.sh ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # New/REG/train.sh:只对齐数据加载相关目录/参数;不引入 t_c/ot-cls。
3
+
4
+ set -euo pipefail
5
+
6
+ NUM_GPUS=4
7
+
8
+ # ------------------------- 按需修改这些路径 -------------------------
9
+ # New/REG/dataset.py 期望 data_dir 下至少有:
10
+ # - imagenet_256_vae/
11
+ # - vae-sd/(非预处理 semantic 模式需要)
12
+ # 若启用预处理语义:还需要
13
+ # - imagenet_256_features/dinov2-vit-b_tmp/gpu0/dataset.json + 对应 .npy
14
+ DATA_DIR="/gemini/space/zhaozy/dataset/Imagenet/imagenet_256"
15
+ SEMANTIC_FEATURES_DIR="/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0"
16
+
17
+ OUTPUT_BASE_DIR="/your_path/reg_xlarge_dinov2_base_align_8_cls"
18
+ EXP_NAME="linear-dinov2-b-enc8"
19
+
20
+ # ------------------------- 不建议改动以下逻辑 -------------------------
21
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
22
+ TRAIN_PY="${SCRIPT_DIR}/train.py"
23
+
24
+ nohup accelerate launch --multi_gpu --num_processes "${NUM_GPUS}" "${TRAIN_PY}" \
25
+ --report-to="wandb" \
26
+ --allow-tf32 \
27
+ --mixed-precision="fp16" \
28
+ --seed=0 \
29
+ --path-type="linear" \
30
+ --prediction="v" \
31
+ --weighting="uniform" \
32
+ --model="SiT-XL/2" \
33
+ --enc-type="dinov2-vit-b" \
34
+ --proj-coeff=0.5 \
35
+ --encoder-depth=8 \
36
+ --cls=0.03 \
37
+ --output-dir="${OUTPUT_BASE_DIR}" \
38
+ --exp-name="${EXP_NAME}" \
39
+ --batch-size=256 \
40
+ --data-dir="${DATA_DIR}" \
41
+ --semantic-features-dir="${SEMANTIC_FEATURES_DIR}" \
42
+ > jsflow-experiment.log 2>&1 &
New/REG/utils.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from torchvision.datasets.utils import download_url
3
+ import torch
4
+ import torchvision.models as torchvision_models
5
+ import timm
6
+ from models import mocov3_vit
7
+ import math
8
+ import warnings
9
+
10
+
11
+ # code from SiT repository
12
+ pretrained_models = {'last.pt'}
13
+
14
+ def download_model(model_name):
15
+ """
16
+ Downloads a pre-trained SiT model from the web.
17
+ """
18
+ assert model_name in pretrained_models
19
+ local_path = f'pretrained_models/{model_name}'
20
+ if not os.path.isfile(local_path):
21
+ os.makedirs('pretrained_models', exist_ok=True)
22
+ web_path = f'https://www.dl.dropboxusercontent.com/scl/fi/cxedbs4da5ugjq5wg3zrg/last.pt?rlkey=8otgrdkno0nd89po3dpwngwcc&st=apcc645o&dl=0'
23
+ download_url(web_path, 'pretrained_models', filename=model_name)
24
+ model = torch.load(local_path, map_location=lambda storage, loc: storage)
25
+ return model
26
+
27
+ def fix_mocov3_state_dict(state_dict):
28
+ for k in list(state_dict.keys()):
29
+ # retain only base_encoder up to before the embedding layer
30
+ if k.startswith('module.base_encoder'):
31
+ # fix naming bug in checkpoint
32
+ new_k = k[len("module.base_encoder."):]
33
+ if "blocks.13.norm13" in new_k:
34
+ new_k = new_k.replace("norm13", "norm1")
35
+ if "blocks.13.mlp.fc13" in k:
36
+ new_k = new_k.replace("fc13", "fc1")
37
+ if "blocks.14.norm14" in k:
38
+ new_k = new_k.replace("norm14", "norm2")
39
+ if "blocks.14.mlp.fc14" in k:
40
+ new_k = new_k.replace("fc14", "fc2")
41
+ # remove prefix
42
+ if 'head' not in new_k and new_k.split('.')[0] != 'fc':
43
+ state_dict[new_k] = state_dict[k]
44
+ # delete renamed or unused k
45
+ del state_dict[k]
46
+ if 'pos_embed' in state_dict.keys():
47
+ state_dict['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
48
+ state_dict['pos_embed'], [16, 16],
49
+ )
50
+ return state_dict
51
+
52
+ @torch.no_grad()
53
+ def load_encoders(enc_type, device, resolution=256):
54
+ assert (resolution == 256) or (resolution == 512)
55
+
56
+ enc_names = enc_type.split(',')
57
+ encoders, architectures, encoder_types = [], [], []
58
+ for enc_name in enc_names:
59
+ encoder_type, architecture, model_config = enc_name.split('-')
60
+ # Currently, we only support 512x512 experiments with DINOv2 encoders.
61
+ if resolution == 512:
62
+ if encoder_type != 'dinov2':
63
+ raise NotImplementedError(
64
+ "Currently, we only support 512x512 experiments with DINOv2 encoders."
65
+ )
66
+
67
+ architectures.append(architecture)
68
+ encoder_types.append(encoder_type)
69
+ if encoder_type == 'mocov3':
70
+ if architecture == 'vit':
71
+ if model_config == 's':
72
+ encoder = mocov3_vit.vit_small()
73
+ elif model_config == 'b':
74
+ encoder = mocov3_vit.vit_base()
75
+ elif model_config == 'l':
76
+ encoder = mocov3_vit.vit_large()
77
+ ckpt = torch.load(f'./ckpts/mocov3_vit{model_config}.pth')
78
+ state_dict = fix_mocov3_state_dict(ckpt['state_dict'])
79
+ del encoder.head
80
+ encoder.load_state_dict(state_dict, strict=True)
81
+ encoder.head = torch.nn.Identity()
82
+ elif architecture == 'resnet':
83
+ raise NotImplementedError()
84
+
85
+ encoder = encoder.to(device)
86
+ encoder.eval()
87
+
88
+ elif 'dinov2' in encoder_type:
89
+ import timm
90
+ if 'reg' in encoder_type:
91
+ try:
92
+ encoder = torch.hub.load('your_path/.cache/torch/hub/facebookresearch_dinov2_main',
93
+ f'dinov2_vit{model_config}14_reg', source='local')
94
+ except:
95
+ encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14_reg')
96
+ else:
97
+ try:
98
+ encoder = torch.hub.load('your_path/.cache/torch/hub/facebookresearch_dinov2_main',
99
+ f'dinov2_vit{model_config}14', source='local')
100
+ except:
101
+ encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14')
102
+
103
+ print(f"Now you are using the {enc_name} as the aligning model")
104
+ del encoder.head
105
+ patch_resolution = 16 * (resolution // 256)
106
+ encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed(
107
+ encoder.pos_embed.data, [patch_resolution, patch_resolution],
108
+ )
109
+ encoder.head = torch.nn.Identity()
110
+ encoder = encoder.to(device)
111
+ encoder.eval()
112
+
113
+ elif 'dinov1' == encoder_type:
114
+ import timm
115
+ from models import dinov1
116
+ encoder = dinov1.vit_base()
117
+ ckpt = torch.load(f'./ckpts/dinov1_vit{model_config}.pth')
118
+ if 'pos_embed' in ckpt.keys():
119
+ ckpt['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
120
+ ckpt['pos_embed'], [16, 16],
121
+ )
122
+ del encoder.head
123
+ encoder.head = torch.nn.Identity()
124
+ encoder.load_state_dict(ckpt, strict=True)
125
+ encoder = encoder.to(device)
126
+ encoder.forward_features = encoder.forward
127
+ encoder.eval()
128
+
129
+ elif encoder_type == 'clip':
130
+ import clip
131
+ from models.clip_vit import UpdatedVisionTransformer
132
+ encoder_ = clip.load(f"ViT-{model_config}/14", device='cpu')[0].visual
133
+ encoder = UpdatedVisionTransformer(encoder_).to(device)
134
+ #.to(device)
135
+ encoder.embed_dim = encoder.model.transformer.width
136
+ encoder.forward_features = encoder.forward
137
+ encoder.eval()
138
+
139
+ elif encoder_type == 'mae':
140
+ from models.mae_vit import vit_large_patch16
141
+ import timm
142
+ kwargs = dict(img_size=256)
143
+ encoder = vit_large_patch16(**kwargs).to(device)
144
+ with open(f"ckpts/mae_vit{model_config}.pth", "rb") as f:
145
+ state_dict = torch.load(f)
146
+ if 'pos_embed' in state_dict["model"].keys():
147
+ state_dict["model"]['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
148
+ state_dict["model"]['pos_embed'], [16, 16],
149
+ )
150
+ encoder.load_state_dict(state_dict["model"])
151
+
152
+ encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed(
153
+ encoder.pos_embed.data, [16, 16],
154
+ )
155
+
156
+ elif encoder_type == 'jepa':
157
+ from models.jepa import vit_huge
158
+ kwargs = dict(img_size=[224, 224], patch_size=14)
159
+ encoder = vit_huge(**kwargs).to(device)
160
+ with open(f"ckpts/ijepa_vit{model_config}.pth", "rb") as f:
161
+ state_dict = torch.load(f, map_location=device)
162
+ new_state_dict = dict()
163
+ for key, value in state_dict['encoder'].items():
164
+ new_state_dict[key[7:]] = value
165
+ encoder.load_state_dict(new_state_dict)
166
+ encoder.forward_features = encoder.forward
167
+
168
+ encoders.append(encoder)
169
+
170
+ return encoders, encoder_types, architectures
171
+
172
+
173
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
174
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
175
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
176
+ def norm_cdf(x):
177
+ # Computes standard normal cumulative distribution function
178
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
179
+
180
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
181
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
182
+ "The distribution of values may be incorrect.",
183
+ stacklevel=2)
184
+
185
+ with torch.no_grad():
186
+ # Values are generated by using a truncated uniform distribution and
187
+ # then using the inverse CDF for the normal distribution.
188
+ # Get upper and lower cdf values
189
+ l = norm_cdf((a - mean) / std)
190
+ u = norm_cdf((b - mean) / std)
191
+
192
+ # Uniformly fill tensor with values from [l, u], then translate to
193
+ # [2l-1, 2u-1].
194
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
195
+
196
+ # Use inverse cdf transform for normal distribution to get truncated
197
+ # standard normal
198
+ tensor.erfinv_()
199
+
200
+ # Transform to proper mean, std
201
+ tensor.mul_(std * math.sqrt(2.))
202
+ tensor.add_(mean)
203
+
204
+ # Clamp to ensure it's in the proper range
205
+ tensor.clamp_(min=a, max=b)
206
+ return tensor
207
+
208
+
209
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
210
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
211
+
212
+
213
+ def load_legacy_checkpoints(state_dict, encoder_depth):
214
+ new_state_dict = dict()
215
+ for key, value in state_dict.items():
216
+ if 'decoder_blocks' in key:
217
+ parts =key.split('.')
218
+ new_idx = int(parts[1]) + encoder_depth
219
+ parts[0] = 'blocks'
220
+ parts[1] = str(new_idx)
221
+ new_key = '.'.join(parts)
222
+ new_state_dict[new_key] = value
223
+ else:
224
+ new_state_dict[key] = value
225
+ return new_state_dict
New/REG/wandb/debug-internal.log ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {"time":"2026-03-26T13:08:48.077546544+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
2
+ {"time":"2026-03-26T13:08:49.00834408+08:00","level":"INFO","msg":"stream: created new stream","id":"0e5vs4f8"}
3
+ {"time":"2026-03-26T13:08:49.008541708+08:00","level":"INFO","msg":"handler: started","stream_id":"0e5vs4f8"}
4
+ {"time":"2026-03-26T13:08:49.009420311+08:00","level":"INFO","msg":"stream: started","id":"0e5vs4f8"}
5
+ {"time":"2026-03-26T13:08:49.00944+08:00","level":"INFO","msg":"writer: started","stream_id":"0e5vs4f8"}
6
+ {"time":"2026-03-26T13:08:49.009442857+08:00","level":"INFO","msg":"sender: started","stream_id":"0e5vs4f8"}
7
+ {"time":"2026-03-27T00:25:27.785418795+08:00","level":"INFO","msg":"api: retrying HTTP error","status":502,"url":"https://api.wandb.ai/files/2365972933-teleai/REG/0e5vs4f8/file_stream","body":"\n<html><head>\n<meta http-equiv=\"content-type\" content=\"text/html;charset=utf-8\">\n<title>502 Server Error</title>\n</head>\n<body text=#000000 bgcolor=#ffffff>\n<h1>Error: Server Error</h1>\n<h2>The server encountered a temporary error and could not complete your request.<p>Please try again in 30 seconds.</h2>\n<h2></h2>\n</body></html>\n"}
New/REG/wandb/debug.log ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-03-26 13:08:47,761 INFO MainThread:576284 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
2
+ 2026-03-26 13:08:47,761 INFO MainThread:576284 [wandb_setup.py:_flush():81] Configure stats pid to 576284
3
+ 2026-03-26 13:08:47,761 INFO MainThread:576284 [wandb_setup.py:_flush():81] Loading settings from environment variables
4
+ 2026-03-26 13:08:47,761 INFO MainThread:576284 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/New/REG/wandb/run-20260326_130847-0e5vs4f8/logs/debug.log
5
+ 2026-03-26 13:08:47,761 INFO MainThread:576284 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/New/REG/wandb/run-20260326_130847-0e5vs4f8/logs/debug-internal.log
6
+ 2026-03-26 13:08:47,761 INFO MainThread:576284 [wandb_init.py:init():844] calling init triggers
7
+ 2026-03-26 13:08:47,761 INFO MainThread:576284 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
8
+ config: {'_wandb': {}}
9
+ 2026-03-26 13:08:47,761 INFO MainThread:576284 [wandb_init.py:init():892] starting backend
10
+ 2026-03-26 13:08:48,060 INFO MainThread:576284 [wandb_init.py:init():895] sending inform_init request
11
+ 2026-03-26 13:08:48,073 INFO MainThread:576284 [wandb_init.py:init():903] backend started and connected
12
+ 2026-03-26 13:08:48,075 INFO MainThread:576284 [wandb_init.py:init():973] updated telemetry
13
+ 2026-03-26 13:08:48,089 INFO MainThread:576284 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
14
+ 2026-03-26 13:08:49,730 INFO MainThread:576284 [wandb_init.py:init():1042] starting run threads in backend
15
+ 2026-03-26 13:08:49,821 INFO MainThread:576284 [wandb_run.py:_console_start():2524] atexit reg
16
+ 2026-03-26 13:08:49,821 INFO MainThread:576284 [wandb_run.py:_redirect():2373] redirect: wrap_raw
17
+ 2026-03-26 13:08:49,821 INFO MainThread:576284 [wandb_run.py:_redirect():2442] Wrapping output streams.
18
+ 2026-03-26 13:08:49,821 INFO MainThread:576284 [wandb_run.py:_redirect():2465] Redirects installed.
19
+ 2026-03-26 13:08:49,826 INFO MainThread:576284 [wandb_init.py:init():1082] run started, returning control to user process
20
+ 2026-03-26 13:08:49,826 INFO MainThread:576284 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': '/your_path/reg_xlarge_dinov2_base_align_8_cls', 'exp_name': 'linear-dinov2-b-enc8', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 10000, 'resume_step': 0, 'train_generate_step': 2000, 'train_generate_num_steps': 2000, 'train_generate_num_images': 16, 'train_generate_mode': 'sde', 'train_generate_cfg_scale': 2.3, 'train_generate_cls_cfg_scale': 2.3, 'train_generate_guidance_high': 0.85, 'train_generate_vae': 'ema', 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'fp16', 'epochs': 1400, 'max_train_steps': 1000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 0.0001, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.03}
New/REG/wandb/run-20260326_123101-m3lli51t/files/output.log ADDED
The diff for this file is too large to render. See raw diff
 
New/REG/wandb/run-20260326_123101-m3lli51t/files/requirements.txt ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dill==0.3.8
2
+ mkl-service==2.4.0
3
+ mpmath==1.3.0
4
+ typing_extensions==4.12.2
5
+ urllib3==2.3.0
6
+ torch==2.5.1
7
+ ptyprocess==0.7.0
8
+ traitlets==5.14.3
9
+ pyasn1==0.6.1
10
+ opencv-python-headless==4.12.0.88
11
+ nest-asyncio==1.6.0
12
+ kiwisolver==1.4.8
13
+ click==8.2.1
14
+ fire==0.7.1
15
+ diffusers==0.35.1
16
+ accelerate==1.7.0
17
+ ipykernel==6.29.5
18
+ peft==0.17.1
19
+ attrs==24.3.0
20
+ six==1.17.0
21
+ numpy==2.0.1
22
+ yarl==1.18.0
23
+ huggingface_hub==0.34.4
24
+ Bottleneck==1.4.2
25
+ numexpr==2.11.0
26
+ dataclasses==0.6
27
+ typing-inspection==0.4.1
28
+ safetensors==0.5.3
29
+ pyparsing==3.2.3
30
+ psutil==7.0.0
31
+ imageio==2.37.0
32
+ debugpy==1.8.14
33
+ cycler==0.12.1
34
+ pyasn1_modules==0.4.2
35
+ matplotlib-inline==0.1.7
36
+ matplotlib==3.10.3
37
+ jedi==0.19.2
38
+ tokenizers==0.21.2
39
+ seaborn==0.13.2
40
+ timm==1.0.15
41
+ aiohappyeyeballs==2.6.1
42
+ hf-xet==1.1.8
43
+ multidict==6.1.0
44
+ tqdm==4.67.1
45
+ wheel==0.45.1
46
+ simsimd==6.5.1
47
+ sentencepiece==0.2.1
48
+ grpcio==1.74.0
49
+ asttokens==3.0.0
50
+ absl-py==2.3.1
51
+ stack-data==0.6.3
52
+ pandas==2.3.0
53
+ importlib_metadata==8.7.0
54
+ pytorch-image-generation-metrics==0.6.1
55
+ frozenlist==1.5.0
56
+ MarkupSafe==3.0.2
57
+ setuptools==78.1.1
58
+ multiprocess==0.70.15
59
+ pip==25.1
60
+ requests==2.32.3
61
+ mkl_random==1.2.8
62
+ tensorboard-plugin-wit==1.8.1
63
+ ExifRead-nocycle==3.0.1
64
+ webdataset==0.2.111
65
+ threadpoolctl==3.6.0
66
+ pyarrow==21.0.0
67
+ executing==2.2.0
68
+ decorator==5.2.1
69
+ contourpy==1.3.2
70
+ annotated-types==0.7.0
71
+ scikit-learn==1.7.1
72
+ jupyter_client==8.6.3
73
+ albumentations==1.4.24
74
+ wandb==0.25.0
75
+ certifi==2025.8.3
76
+ idna==3.7
77
+ xxhash==3.5.0
78
+ Jinja2==3.1.6
79
+ python-dateutil==2.9.0.post0
80
+ aiosignal==1.4.0
81
+ triton==3.1.0
82
+ torchvision==0.20.1
83
+ stringzilla==3.12.6
84
+ pure_eval==0.2.3
85
+ braceexpand==0.1.7
86
+ zipp==3.22.0
87
+ oauthlib==3.3.1
88
+ Markdown==3.8.2
89
+ fsspec==2025.3.0
90
+ fonttools==4.58.2
91
+ comm==0.2.2
92
+ ipython==9.3.0
93
+ img2dataset==1.47.0
94
+ networkx==3.4.2
95
+ PySocks==1.7.1
96
+ tzdata==2025.2
97
+ smmap==5.0.2
98
+ mkl_fft==1.3.11
99
+ sentry-sdk==2.29.1
100
+ Pygments==2.19.1
101
+ pexpect==4.9.0
102
+ ftfy==6.3.1
103
+ einops==0.8.1
104
+ requests-oauthlib==2.0.0
105
+ gitdb==4.0.12
106
+ albucore==0.0.23
107
+ torchdiffeq==0.2.5
108
+ GitPython==3.1.44
109
+ bitsandbytes==0.47.0
110
+ pytorch-fid==0.3.0
111
+ clean-fid==0.1.35
112
+ pytorch-gan-metrics==0.5.4
113
+ Brotli==1.0.9
114
+ charset-normalizer==3.3.2
115
+ gmpy2==2.2.1
116
+ pillow==11.1.0
117
+ PyYAML==6.0.2
118
+ tornado==6.5.1
119
+ termcolor==3.1.0
120
+ setproctitle==1.3.6
121
+ scipy==1.15.3
122
+ regex==2024.11.6
123
+ protobuf==6.31.1
124
+ platformdirs==4.3.8
125
+ joblib==1.5.1
126
+ cachetools==4.2.4
127
+ ipython_pygments_lexers==1.1.1
128
+ google-auth==1.35.0
129
+ transformers==4.53.2
130
+ torch-fidelity==0.3.0
131
+ tensorboard==2.4.0
132
+ filelock==3.17.0
133
+ packaging==25.0
134
+ propcache==0.3.1
135
+ pytz==2025.2
136
+ aiohttp==3.11.10
137
+ wcwidth==0.2.13
138
+ clip==0.2.0
139
+ Werkzeug==3.1.3
140
+ tensorboard-data-server==0.6.1
141
+ sympy==1.13.1
142
+ pyzmq==26.4.0
143
+ pydantic_core==2.33.2
144
+ prompt_toolkit==3.0.51
145
+ parso==0.8.4
146
+ docker-pycreds==0.4.0
147
+ rsa==4.9.1
148
+ pydantic==2.11.5
149
+ jupyter_core==5.8.1
150
+ google-auth-oauthlib==0.4.6
151
+ datasets==4.0.0
152
+ torch-tb-profiler==0.4.3
153
+ autocommand==2.2.2
154
+ backports.tarfile==1.2.0
155
+ importlib_metadata==8.0.0
156
+ jaraco.collections==5.1.0
157
+ jaraco.context==5.3.0
158
+ jaraco.functools==4.0.1
159
+ more-itertools==10.3.0
160
+ packaging==24.2
161
+ platformdirs==4.2.2
162
+ typeguard==4.3.0
163
+ inflect==7.3.1
164
+ jaraco.text==3.12.1
165
+ tomli==2.0.1
166
+ typing_extensions==4.12.2
167
+ wheel==0.45.1
168
+ zipp==3.19.2
New/REG/wandb/run-20260326_123101-m3lli51t/files/wandb-metadata.json ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.15.0-94-generic-x86_64-with-glibc2.35",
3
+ "python": "CPython 3.12.9",
4
+ "startedAt": "2026-03-26T04:31:01.251691Z",
5
+ "args": [
6
+ "--report-to=wandb",
7
+ "--allow-tf32",
8
+ "--mixed-precision=fp16",
9
+ "--seed=0",
10
+ "--path-type=linear",
11
+ "--prediction=v",
12
+ "--weighting=uniform",
13
+ "--model=SiT-XL/2",
14
+ "--enc-type=dinov2-vit-b",
15
+ "--proj-coeff=0.5",
16
+ "--encoder-depth=8",
17
+ "--cls=0.03",
18
+ "--output-dir=/your_path/reg_xlarge_dinov2_base_align_8_cls",
19
+ "--exp-name=linear-dinov2-b-enc8",
20
+ "--batch-size=256",
21
+ "--data-dir=/gemini/space/zhaozy/dataset/Imagenet/imagenet_256",
22
+ "--semantic-features-dir=/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0"
23
+ ],
24
+ "program": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/New/REG/train.py",
25
+ "codePath": "train.py",
26
+ "codePathLocal": "train.py",
27
+ "git": {
28
+ "remote": "https://github.com/Martinser/REG.git",
29
+ "commit": "021ea2e50c38c5803bd9afff16316958a01fbd1d"
30
+ },
31
+ "email": "2365972933@qq.com",
32
+ "root": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/New/REG",
33
+ "host": "24c964746905d416ce09d045f9a06f23-taskrole1-0",
34
+ "executable": "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python",
35
+ "cpu_count": 96,
36
+ "cpu_count_logical": 192,
37
+ "gpu": "NVIDIA H100 80GB HBM3",
38
+ "gpu_count": 4,
39
+ "disk": {
40
+ "/": {
41
+ "total": "3838880616448",
42
+ "used": "367215697920"
43
+ }
44
+ },
45
+ "memory": {
46
+ "total": "2164115296256"
47
+ },
48
+ "gpu_nvidia": [
49
+ {
50
+ "name": "NVIDIA H100 80GB HBM3",
51
+ "memoryTotal": "85520809984",
52
+ "cudaCores": 16896,
53
+ "architecture": "Hopper",
54
+ "uuid": "GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc"
55
+ },
56
+ {
57
+ "name": "NVIDIA H100 80GB HBM3",
58
+ "memoryTotal": "85520809984",
59
+ "cudaCores": 16896,
60
+ "architecture": "Hopper",
61
+ "uuid": "GPU-a09f2421-99e6-a72e-63bd-fd7452510758"
62
+ },
63
+ {
64
+ "name": "NVIDIA H100 80GB HBM3",
65
+ "memoryTotal": "85520809984",
66
+ "cudaCores": 16896,
67
+ "architecture": "Hopper",
68
+ "uuid": "GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d"
69
+ },
70
+ {
71
+ "name": "NVIDIA H100 80GB HBM3",
72
+ "memoryTotal": "85520809984",
73
+ "cudaCores": 16896,
74
+ "architecture": "Hopper",
75
+ "uuid": "GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e"
76
+ }
77
+ ],
78
+ "cudaVersion": "13.0",
79
+ "writerId": "shmu4nsng7t9zxqf30qry601bzvrritx"
80
+ }
New/REG/wandb/run-20260326_123101-m3lli51t/logs/debug-internal.log ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {"time":"2026-03-26T12:31:01.594683189+08:00","level":"INFO","msg":"stream: starting","core version":"0.25.0"}
2
+ {"time":"2026-03-26T12:31:02.390583334+08:00","level":"INFO","msg":"stream: created new stream","id":"m3lli51t"}
3
+ {"time":"2026-03-26T12:31:02.390723765+08:00","level":"INFO","msg":"handler: started","stream_id":"m3lli51t"}
4
+ {"time":"2026-03-26T12:31:02.39153628+08:00","level":"INFO","msg":"stream: started","id":"m3lli51t"}
5
+ {"time":"2026-03-26T12:31:02.391552999+08:00","level":"INFO","msg":"writer: started","stream_id":"m3lli51t"}
6
+ {"time":"2026-03-26T12:31:02.391598449+08:00","level":"INFO","msg":"sender: started","stream_id":"m3lli51t"}
7
+ {"time":"2026-03-26T12:31:25.872492492+08:00","level":"INFO","msg":"api: retrying error","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/2365972933-teleai/REG/m3lli51t/requirements.txt?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20260326%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20260326T043108Z&X-Goog-Expires=86399&X-Goog-Signature=b7efae9bb95ca96e0208d793a895c400eb9f42714e7f79e1ccff1fb97664053ab3d4a38f98cba45f7336911ceef6bb02894228ed5cf3b4dffdd02698b5a31a7faa483fd85396d291642957c4420cf884bf125233c9c9cec3ce04029a41984d86f1db898738ae442aae0015236bafa9031b20eedd7a081f77402ea66e0e54e4f65adadb2c11f9850863e10e75f24b9438f3cf3823f5bdf9d938245fc91ee6680fd95fa221fc69e487ffde730b8fc8bcd8fa79ae686fb7bd8ae6c1bd4b25d8eef9f64931d06496a6c52b7ed9dbe4a0716ed4f7fbb833b8470f3ae8fde3da8f15dc226596019add90ffbc322acc6147ab83fd50eca1690a3060e947a76b97d42b47&X-Goog-SignedHeaders=host&X-User=2365972933\": read tcp 172.20.98.27:43208->142.250.199.219:443: read: connection reset by peer"}
8
+ {"time":"2026-03-26T12:31:27.355342484+08:00","level":"INFO","msg":"api: retrying error","error":"Put \"https://storage.googleapis.com/wandb-production.appspot.com/2365972933-teleai/REG/m3lli51t/wandb-metadata.json?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=gorilla-files-url-signer-man%40wandb-production.iam.gserviceaccount.com%2F20260326%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20260326T043108Z&X-Goog-Expires=86399&X-Goog-Signature=6f5856f5b1dd6e9d2efd01092fa98674135a31cd288253e9013a5ba3024eebbd56748db3da929cb5342625eb9500b726db70e1b5e2434471f98fded30c38a53884c36993b0f6ad68e20df49dc9642e3f3251903b352dd72aad76871227838ccfca1a4097f4bddd24e22f39bdd506334722f15d05be1b2fab112907648c2cad8774ae360e76c816113cdb33b19a2a0fa98a4cf50f21b74fcb4d6e392257126d83935963e0430a4a48a58e2354d47db9580855540e2f2cfceee9b9d8e6d53cc3a72ed15a4b4341a08b25e608511d4961afcfaf8231b8b25e8a7d732c7e9fd007db6fbd2197776041c8fd046bd2a8a638599df98923c5ddb5f7bead505453ccd11f&X-Goog-SignedHeaders=host&X-User=2365972933\": read tcp 172.20.98.27:43200->142.250.199.219:443: read: connection reset by peer"}
New/REG/wandb/run-20260326_123101-m3lli51t/logs/debug.log ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2026-03-26 12:31:01,272 INFO MainThread:571602 [wandb_setup.py:_flush():81] Current SDK version is 0.25.0
2
+ 2026-03-26 12:31:01,272 INFO MainThread:571602 [wandb_setup.py:_flush():81] Configure stats pid to 571602
3
+ 2026-03-26 12:31:01,272 INFO MainThread:571602 [wandb_setup.py:_flush():81] Loading settings from environment variables
4
+ 2026-03-26 12:31:01,272 INFO MainThread:571602 [wandb_init.py:setup_run_log_directory():717] Logging user logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/New/REG/wandb/run-20260326_123101-m3lli51t/logs/debug.log
5
+ 2026-03-26 12:31:01,273 INFO MainThread:571602 [wandb_init.py:setup_run_log_directory():718] Logging internal logs to /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/New/REG/wandb/run-20260326_123101-m3lli51t/logs/debug-internal.log
6
+ 2026-03-26 12:31:01,273 INFO MainThread:571602 [wandb_init.py:init():844] calling init triggers
7
+ 2026-03-26 12:31:01,273 INFO MainThread:571602 [wandb_init.py:init():849] wandb.init called with sweep_config: {}
8
+ config: {'_wandb': {}}
9
+ 2026-03-26 12:31:01,273 INFO MainThread:571602 [wandb_init.py:init():892] starting backend
10
+ 2026-03-26 12:31:01,576 INFO MainThread:571602 [wandb_init.py:init():895] sending inform_init request
11
+ 2026-03-26 12:31:01,591 INFO MainThread:571602 [wandb_init.py:init():903] backend started and connected
12
+ 2026-03-26 12:31:01,593 INFO MainThread:571602 [wandb_init.py:init():973] updated telemetry
13
+ 2026-03-26 12:31:01,607 INFO MainThread:571602 [wandb_init.py:init():997] communicating run to backend with 90.0 second timeout
14
+ 2026-03-26 12:31:03,274 INFO MainThread:571602 [wandb_init.py:init():1042] starting run threads in backend
15
+ 2026-03-26 12:31:03,366 INFO MainThread:571602 [wandb_run.py:_console_start():2524] atexit reg
16
+ 2026-03-26 12:31:03,366 INFO MainThread:571602 [wandb_run.py:_redirect():2373] redirect: wrap_raw
17
+ 2026-03-26 12:31:03,366 INFO MainThread:571602 [wandb_run.py:_redirect():2442] Wrapping output streams.
18
+ 2026-03-26 12:31:03,366 INFO MainThread:571602 [wandb_run.py:_redirect():2465] Redirects installed.
19
+ 2026-03-26 12:31:03,372 INFO MainThread:571602 [wandb_init.py:init():1082] run started, returning control to user process
20
+ 2026-03-26 12:31:03,372 INFO MainThread:571602 [wandb_run.py:_config_callback():1403] config_cb None None {'output_dir': '/your_path/reg_xlarge_dinov2_base_align_8_cls', 'exp_name': 'linear-dinov2-b-enc8', 'logging_dir': 'logs', 'report_to': 'wandb', 'sampling_steps': 10000, 'resume_step': 0, 'model': 'SiT-XL/2', 'num_classes': 1000, 'encoder_depth': 8, 'fused_attn': True, 'qk_norm': False, 'ops_head': 16, 'data_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256', 'semantic_features_dir': '/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0', 'resolution': 256, 'batch_size': 256, 'allow_tf32': True, 'mixed_precision': 'fp16', 'epochs': 1400, 'max_train_steps': 1000000, 'checkpointing_steps': 10000, 'gradient_accumulation_steps': 1, 'learning_rate': 0.0001, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.0, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'seed': 0, 'num_workers': 4, 'path_type': 'linear', 'prediction': 'v', 'cfg_prob': 0.1, 'enc_type': 'dinov2-vit-b', 'proj_coeff': 0.5, 'weighting': 'uniform', 'legacy': False, 'cls': 0.03}
New/REG/wandb/run-20260326_130847-0e5vs4f8/files/requirements.txt ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dill==0.3.8
2
+ mkl-service==2.4.0
3
+ mpmath==1.3.0
4
+ typing_extensions==4.12.2
5
+ urllib3==2.3.0
6
+ torch==2.5.1
7
+ ptyprocess==0.7.0
8
+ traitlets==5.14.3
9
+ pyasn1==0.6.1
10
+ opencv-python-headless==4.12.0.88
11
+ nest-asyncio==1.6.0
12
+ kiwisolver==1.4.8
13
+ click==8.2.1
14
+ fire==0.7.1
15
+ diffusers==0.35.1
16
+ accelerate==1.7.0
17
+ ipykernel==6.29.5
18
+ peft==0.17.1
19
+ attrs==24.3.0
20
+ six==1.17.0
21
+ numpy==2.0.1
22
+ yarl==1.18.0
23
+ huggingface_hub==0.34.4
24
+ Bottleneck==1.4.2
25
+ numexpr==2.11.0
26
+ dataclasses==0.6
27
+ typing-inspection==0.4.1
28
+ safetensors==0.5.3
29
+ pyparsing==3.2.3
30
+ psutil==7.0.0
31
+ imageio==2.37.0
32
+ debugpy==1.8.14
33
+ cycler==0.12.1
34
+ pyasn1_modules==0.4.2
35
+ matplotlib-inline==0.1.7
36
+ matplotlib==3.10.3
37
+ jedi==0.19.2
38
+ tokenizers==0.21.2
39
+ seaborn==0.13.2
40
+ timm==1.0.15
41
+ aiohappyeyeballs==2.6.1
42
+ hf-xet==1.1.8
43
+ multidict==6.1.0
44
+ tqdm==4.67.1
45
+ wheel==0.45.1
46
+ simsimd==6.5.1
47
+ sentencepiece==0.2.1
48
+ grpcio==1.74.0
49
+ asttokens==3.0.0
50
+ absl-py==2.3.1
51
+ stack-data==0.6.3
52
+ pandas==2.3.0
53
+ importlib_metadata==8.7.0
54
+ pytorch-image-generation-metrics==0.6.1
55
+ frozenlist==1.5.0
56
+ MarkupSafe==3.0.2
57
+ setuptools==78.1.1
58
+ multiprocess==0.70.15
59
+ pip==25.1
60
+ requests==2.32.3
61
+ mkl_random==1.2.8
62
+ tensorboard-plugin-wit==1.8.1
63
+ ExifRead-nocycle==3.0.1
64
+ webdataset==0.2.111
65
+ threadpoolctl==3.6.0
66
+ pyarrow==21.0.0
67
+ executing==2.2.0
68
+ decorator==5.2.1
69
+ contourpy==1.3.2
70
+ annotated-types==0.7.0
71
+ scikit-learn==1.7.1
72
+ jupyter_client==8.6.3
73
+ albumentations==1.4.24
74
+ wandb==0.25.0
75
+ certifi==2025.8.3
76
+ idna==3.7
77
+ xxhash==3.5.0
78
+ Jinja2==3.1.6
79
+ python-dateutil==2.9.0.post0
80
+ aiosignal==1.4.0
81
+ triton==3.1.0
82
+ torchvision==0.20.1
83
+ stringzilla==3.12.6
84
+ pure_eval==0.2.3
85
+ braceexpand==0.1.7
86
+ zipp==3.22.0
87
+ oauthlib==3.3.1
88
+ Markdown==3.8.2
89
+ fsspec==2025.3.0
90
+ fonttools==4.58.2
91
+ comm==0.2.2
92
+ ipython==9.3.0
93
+ img2dataset==1.47.0
94
+ networkx==3.4.2
95
+ PySocks==1.7.1
96
+ tzdata==2025.2
97
+ smmap==5.0.2
98
+ mkl_fft==1.3.11
99
+ sentry-sdk==2.29.1
100
+ Pygments==2.19.1
101
+ pexpect==4.9.0
102
+ ftfy==6.3.1
103
+ einops==0.8.1
104
+ requests-oauthlib==2.0.0
105
+ gitdb==4.0.12
106
+ albucore==0.0.23
107
+ torchdiffeq==0.2.5
108
+ GitPython==3.1.44
109
+ bitsandbytes==0.47.0
110
+ pytorch-fid==0.3.0
111
+ clean-fid==0.1.35
112
+ pytorch-gan-metrics==0.5.4
113
+ Brotli==1.0.9
114
+ charset-normalizer==3.3.2
115
+ gmpy2==2.2.1
116
+ pillow==11.1.0
117
+ PyYAML==6.0.2
118
+ tornado==6.5.1
119
+ termcolor==3.1.0
120
+ setproctitle==1.3.6
121
+ scipy==1.15.3
122
+ regex==2024.11.6
123
+ protobuf==6.31.1
124
+ platformdirs==4.3.8
125
+ joblib==1.5.1
126
+ cachetools==4.2.4
127
+ ipython_pygments_lexers==1.1.1
128
+ google-auth==1.35.0
129
+ transformers==4.53.2
130
+ torch-fidelity==0.3.0
131
+ tensorboard==2.4.0
132
+ filelock==3.17.0
133
+ packaging==25.0
134
+ propcache==0.3.1
135
+ pytz==2025.2
136
+ aiohttp==3.11.10
137
+ wcwidth==0.2.13
138
+ clip==0.2.0
139
+ Werkzeug==3.1.3
140
+ tensorboard-data-server==0.6.1
141
+ sympy==1.13.1
142
+ pyzmq==26.4.0
143
+ pydantic_core==2.33.2
144
+ prompt_toolkit==3.0.51
145
+ parso==0.8.4
146
+ docker-pycreds==0.4.0
147
+ rsa==4.9.1
148
+ pydantic==2.11.5
149
+ jupyter_core==5.8.1
150
+ google-auth-oauthlib==0.4.6
151
+ datasets==4.0.0
152
+ torch-tb-profiler==0.4.3
153
+ autocommand==2.2.2
154
+ backports.tarfile==1.2.0
155
+ importlib_metadata==8.0.0
156
+ jaraco.collections==5.1.0
157
+ jaraco.context==5.3.0
158
+ jaraco.functools==4.0.1
159
+ more-itertools==10.3.0
160
+ packaging==24.2
161
+ platformdirs==4.2.2
162
+ typeguard==4.3.0
163
+ inflect==7.3.1
164
+ jaraco.text==3.12.1
165
+ tomli==2.0.1
166
+ typing_extensions==4.12.2
167
+ wheel==0.45.1
168
+ zipp==3.19.2
New/REG/wandb/run-20260326_130847-0e5vs4f8/files/wandb-metadata.json ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.15.0-94-generic-x86_64-with-glibc2.35",
3
+ "python": "CPython 3.12.9",
4
+ "startedAt": "2026-03-26T05:08:47.738290Z",
5
+ "args": [
6
+ "--report-to=wandb",
7
+ "--allow-tf32",
8
+ "--mixed-precision=fp16",
9
+ "--seed=0",
10
+ "--path-type=linear",
11
+ "--prediction=v",
12
+ "--weighting=uniform",
13
+ "--model=SiT-XL/2",
14
+ "--enc-type=dinov2-vit-b",
15
+ "--proj-coeff=0.5",
16
+ "--encoder-depth=8",
17
+ "--cls=0.03",
18
+ "--output-dir=/your_path/reg_xlarge_dinov2_base_align_8_cls",
19
+ "--exp-name=linear-dinov2-b-enc8",
20
+ "--batch-size=256",
21
+ "--data-dir=/gemini/space/zhaozy/dataset/Imagenet/imagenet_256",
22
+ "--semantic-features-dir=/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0"
23
+ ],
24
+ "program": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/New/REG/train.py",
25
+ "codePath": "train.py",
26
+ "codePathLocal": "train.py",
27
+ "git": {
28
+ "remote": "https://github.com/Martinser/REG.git",
29
+ "commit": "021ea2e50c38c5803bd9afff16316958a01fbd1d"
30
+ },
31
+ "email": "2365972933@qq.com",
32
+ "root": "/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/New/REG",
33
+ "host": "24c964746905d416ce09d045f9a06f23-taskrole1-0",
34
+ "executable": "/gemini/space/zhaozy/guzhenyu/envs/envs/SiT/bin/python",
35
+ "cpu_count": 96,
36
+ "cpu_count_logical": 192,
37
+ "gpu": "NVIDIA H100 80GB HBM3",
38
+ "gpu_count": 4,
39
+ "disk": {
40
+ "/": {
41
+ "total": "3838880616448",
42
+ "used": "367217483776"
43
+ }
44
+ },
45
+ "memory": {
46
+ "total": "2164115296256"
47
+ },
48
+ "gpu_nvidia": [
49
+ {
50
+ "name": "NVIDIA H100 80GB HBM3",
51
+ "memoryTotal": "85520809984",
52
+ "cudaCores": 16896,
53
+ "architecture": "Hopper",
54
+ "uuid": "GPU-757303bb-4ec2-808b-a17f-95f6f5bad6dc"
55
+ },
56
+ {
57
+ "name": "NVIDIA H100 80GB HBM3",
58
+ "memoryTotal": "85520809984",
59
+ "cudaCores": 16896,
60
+ "architecture": "Hopper",
61
+ "uuid": "GPU-a09f2421-99e6-a72e-63bd-fd7452510758"
62
+ },
63
+ {
64
+ "name": "NVIDIA H100 80GB HBM3",
65
+ "memoryTotal": "85520809984",
66
+ "cudaCores": 16896,
67
+ "architecture": "Hopper",
68
+ "uuid": "GPU-9c670cc7-60a8-17f8-9b39-7ced3744976d"
69
+ },
70
+ {
71
+ "name": "NVIDIA H100 80GB HBM3",
72
+ "memoryTotal": "85520809984",
73
+ "cudaCores": 16896,
74
+ "architecture": "Hopper",
75
+ "uuid": "GPU-e6b1d8da-68d7-ed83-90d0-a4dedf33120e"
76
+ }
77
+ ],
78
+ "cudaVersion": "13.0",
79
+ "writerId": "zx05al0qjdz8zwayv7jbdzbhr8qwhpnu"
80
+ }