xiangzai committed on
Commit 342a08e · verified · 1 Parent(s): 5484dca

Add files using upload-large-folder tool
REG/evaluations/README.md ADDED
@@ -0,0 +1,72 @@
# Evaluations

To compare different generative models, we use FID, sFID, Precision, Recall, and Inception Score. These metrics can all be calculated using batches of samples, which we store in `.npz` (numpy) files.

# Download batches

We provide pre-computed sample batches for the reference datasets, our diffusion models, and several baselines we compare against. These are all stored in `.npz` format.

Reference dataset batches contain pre-computed statistics over the whole dataset, as well as 10,000 images for computing Precision and Recall. All other batches contain 50,000 images which can be used to compute statistics and Precision/Recall.
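A sample batch is just a standard numpy archive. A minimal sketch of writing one in the layout `evaluator.py` expects (images stored under the `arr_0` key as uint8 NHWC arrays in [0, 255]; the filename and batch size here are illustrative):

```python
import numpy as np

# Hypothetical example: 8 random 64x64 RGB "samples" in the expected layout.
samples = np.random.randint(0, 256, size=(8, 64, 64, 3), dtype=np.uint8)
np.savez("my_samples.npz", arr_0=samples)

# Round trip: evaluator.py reads the batch back from the same "arr_0" key.
batch = np.load("my_samples.npz")["arr_0"]
```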
Here are links to download all of the sample and reference batches:

 * LSUN
   * LSUN bedroom: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/VIRTUAL_lsun_bedroom256.npz)
     * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/admnet_dropout_lsun_bedroom.npz)
     * [DDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/ddpm_lsun_bedroom.npz)
     * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/iddpm_lsun_bedroom.npz)
     * [StyleGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/stylegan_lsun_bedroom.npz)
   * LSUN cat: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/VIRTUAL_lsun_cat256.npz)
     * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/admnet_dropout_lsun_cat.npz)
     * [StyleGAN2](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/stylegan2_lsun_cat.npz)
   * LSUN horse: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/VIRTUAL_lsun_horse256.npz)
     * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_dropout_lsun_horse.npz)
     * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_lsun_horse.npz)

 * ImageNet
   * ImageNet 64x64: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz)
     * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/admnet_imagenet64.npz)
     * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/iddpm_imagenet64.npz)
     * [BigGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/biggan_deep_imagenet64.npz)
   * ImageNet 128x128: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz)
     * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_imagenet128.npz)
     * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_imagenet128.npz)
     * [ADM-G, 25 steps](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_25step_imagenet128.npz)
     * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/biggan_deep_trunc1_imagenet128.npz)
   * ImageNet 256x256: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz)
     * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_imagenet256.npz)
     * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_imagenet256.npz)
     * [ADM-G, 25 steps](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_25step_imagenet256.npz)
     * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_upsampled_imagenet256.npz)
     * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_upsampled_imagenet256.npz)
     * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/biggan_deep_trunc1_imagenet256.npz)
   * ImageNet 512x512: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz)
     * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_imagenet512.npz)
     * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_imagenet512.npz)
     * [ADM-G, 25 steps](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_25step_imagenet512.npz)
     * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_upsampled_imagenet512.npz)
     * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_upsampled_imagenet512.npz)
     * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/biggan_deep_trunc1_imagenet512.npz)

# Run evaluations

First, generate or download a batch of samples and download the corresponding reference batch for the given dataset. For this example, we'll use ImageNet 256x256, so the reference batch is `VIRTUAL_imagenet256_labeled.npz` and we can use the sample batch `admnet_guided_upsampled_imagenet256.npz`.

Next, run the `evaluator.py` script. The requirements of this script can be found in [requirements.txt](requirements.txt). Pass two arguments to the script: the reference batch and the sample batch. The script will download the InceptionV3 model used for evaluations into the current working directory (if it is not already present). This file is roughly 100MB.

The output of the script will look something like this, where the first `...` is a bunch of verbose TensorFlow logging:

```
$ python evaluator.py VIRTUAL_imagenet256_labeled.npz admnet_guided_upsampled_imagenet256.npz
...
computing reference batch activations...
computing/reading reference batch statistics...
computing sample batch activations...
computing/reading sample batch statistics...
Computing evaluations...
Inception Score: 215.8370361328125
FID: 3.9425574129223264
sFID: 6.140433703346162
Precision: 0.8265
Recall: 0.5309
```
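The FID and sFID values above are Fréchet distances between Gaussians fit to Inception features. As a sanity check, the formula can be sketched in a few lines of numpy; this mirrors `FIDStatistics.frechet_distance` in `evaluator.py`, and the toy `mu`/`sigma` values below are illustrative only:

```python
import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2*sqrt(sigma1 @ sigma2))."""
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        # Numerical error can introduce a tiny imaginary component.
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

# Identical Gaussians are at distance zero; shifting the mean by 1 in each of
# 4 dimensions adds ||diff||^2 = 4 while the trace terms cancel.
mu = np.zeros(4)
sigma = np.eye(4)
d_same = frechet_distance(mu, sigma, mu, sigma)
d_shift = frechet_distance(mu + 1.0, sigma, mu, sigma)
```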
REG/evaluations/evaluator.py ADDED
@@ -0,0 +1,679 @@
import argparse
import io
import os
import random
import warnings
import zipfile
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import partial
from multiprocessing import cpu_count
from multiprocessing.pool import ThreadPool
from typing import Iterable, Optional, Tuple

import numpy as np
import requests
import tensorflow.compat.v1 as tf
from scipy import linalg
from tqdm.auto import tqdm

INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
INCEPTION_V3_PATH = "classify_image_graph_def.pb"

FID_POOL_NAME = "pool_3:0"
FID_SPATIAL_NAME = "mixed_6/conv:0"

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ref_batch", help="path to reference batch npz file")
    parser.add_argument("--sample_batch", help="path to sample batch npz file")
    parser.add_argument("--save_path", help="directory to write the results file to")
    parser.add_argument("--cfg_cond", default=1, type=int)
    parser.add_argument("--step", default=1, type=int)
    parser.add_argument("--cfg", default=1.0, type=float)
    parser.add_argument("--cls_cfg", default=1.0, type=float)
    parser.add_argument("--gh", default=1.0, type=float)
    parser.add_argument("--num_steps", default=250, type=int)
    args = parser.parse_args()

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    config = tf.ConfigProto(
        allow_soft_placement=True  # allows DecodeJpeg to run on CPU in Inception graph
    )
    config.gpu_options.allow_growth = True
    evaluator = Evaluator(tf.Session(config=config))

    print("warming up TensorFlow...")
    # This will cause TF to print a bunch of verbose stuff now rather
    # than after the next print(), to help prevent confusion.
    evaluator.warmup()

    print("computing reference batch activations...")
    ref_acts = evaluator.read_activations(args.ref_batch)
    print("computing/reading reference batch statistics...")
    ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)

    print("computing sample batch activations...")
    sample_acts = evaluator.read_activations(args.sample_batch)
    print("computing/reading sample batch statistics...")
    sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)

    print("Computing evaluations...")
    inception_score = evaluator.compute_inception_score(sample_acts[0])
    fid = sample_stats.frechet_distance(ref_stats)
    sfid = sample_stats_spatial.frechet_distance(ref_stats_spatial)
    prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])

    print("Inception Score:", inception_score)
    print("FID:", fid)
    print("sFID:", sfid)
    print("Precision:", prec)
    print("Recall:", recall)

    suffix = "cfg_cond_true.txt" if args.cfg_cond else "cfg_cond_false.txt"
    file_path = os.path.join(
        args.save_path,
        f"{args.num_steps}{args.step}{args.cfg}{args.gh}{args.cls_cfg}{suffix}",
    )
    with open(file_path, "w") as file:
        file.write("Inception Score: {}\n".format(inception_score))
        file.write("FID: {}\n".format(fid))
        file.write("sFID: {}\n".format(sfid))
        file.write("Precision: {}\n".format(prec))
        file.write("Recall: {}\n".format(recall))

class InvalidFIDException(Exception):
    pass


class FIDStatistics:
    def __init__(self, mu: np.ndarray, sigma: np.ndarray):
        self.mu = mu
        self.sigma = sigma

    def frechet_distance(self, other, eps=1e-6):
        """
        Compute the Frechet distance between two sets of statistics.
        """
        # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
        mu1, sigma1 = self.mu, self.sigma
        mu2, sigma2 = other.mu, other.sigma

        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)

        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)

        assert (
            mu1.shape == mu2.shape
        ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
        assert (
            sigma1.shape == sigma2.shape
        ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"

        diff = mu1 - mu2

        # product might be almost singular
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            msg = (
                "fid calculation produces singular product; adding %s to diagonal of cov estimates"
                % eps
            )
            warnings.warn(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError("Imaginary component {}".format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)

        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

class Evaluator:
    def __init__(
        self,
        session,
        batch_size=64,
        softmax_batch_size=512,
    ):
        self.sess = session
        self.batch_size = batch_size
        self.softmax_batch_size = softmax_batch_size
        self.manifold_estimator = ManifoldEstimator(session)
        with self.sess.graph.as_default():
            self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
            self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
            self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
            self.softmax = _create_softmax_graph(self.softmax_input)

    def warmup(self):
        self.compute_activations(np.zeros([1, 8, 64, 64, 3]))

    def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
        with open_npz_array(npz_path, "arr_0") as reader:
            return self.compute_activations(reader.read_batches(self.batch_size))

    def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute image features for downstream evals.

        :param batches: an iterator over NHWC numpy arrays in [0, 255].
        :return: a tuple of numpy arrays of shape [N x X], where X is a feature
                 dimension. The tuple is (pool_3, spatial).
        """
        preds = []
        spatial_preds = []
        for batch in tqdm(batches):
            batch = batch.astype(np.float32)
            pred, spatial_pred = self.sess.run(
                [self.pool_features, self.spatial_features], {self.image_input: batch}
            )
            preds.append(pred.reshape([pred.shape[0], -1]))
            spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
        return (
            np.concatenate(preds, axis=0),
            np.concatenate(spatial_preds, axis=0),
        )

    def read_statistics(
        self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
    ) -> Tuple[FIDStatistics, FIDStatistics]:
        obj = np.load(npz_path)
        if "mu" in list(obj.keys()):
            return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
                obj["mu_s"], obj["sigma_s"]
            )
        return tuple(self.compute_statistics(x) for x in activations)

    def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
        mu = np.mean(activations, axis=0)
        sigma = np.cov(activations, rowvar=False)
        return FIDStatistics(mu, sigma)

    def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
        softmax_out = []
        for i in range(0, len(activations), self.softmax_batch_size):
            acts = activations[i : i + self.softmax_batch_size]
            softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
        preds = np.concatenate(softmax_out, axis=0)
        # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
        scores = []
        for i in range(0, len(preds), split_size):
            part = preds[i : i + split_size]
            kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
            kl = np.mean(np.sum(kl, 1))
            scores.append(np.exp(kl))
        return float(np.mean(scores))

    def compute_prec_recall(
        self, activations_ref: np.ndarray, activations_sample: np.ndarray
    ) -> Tuple[float, float]:
        radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
        radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
        pr = self.manifold_estimator.evaluate_pr(
            activations_ref, radii_1, activations_sample, radii_2
        )
        return (float(pr[0][0]), float(pr[1][0]))

class ManifoldEstimator:
    """
    A helper for comparing manifolds of feature vectors.

    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
    """

    def __init__(
        self,
        session,
        row_batch_size=10000,
        col_batch_size=10000,
        nhood_sizes=(3,),
        clamp_to_percentile=None,
        eps=1e-5,
    ):
        """
        Estimate the manifold of given feature vectors.

        :param session: the TensorFlow session.
        :param row_batch_size: row batch size to compute pairwise distances
                               (parameter to trade off between memory usage and performance).
        :param col_batch_size: column batch size to compute pairwise distances.
        :param nhood_sizes: number of neighbors used to estimate the manifold.
        :param clamp_to_percentile: prune hyperspheres that have radius larger than
                                    the given percentile.
        :param eps: small number for numerical stability.
        """
        self.distance_block = DistanceBlock(session)
        self.row_batch_size = row_batch_size
        self.col_batch_size = col_batch_size
        self.nhood_sizes = nhood_sizes
        self.num_nhoods = len(nhood_sizes)
        self.clamp_to_percentile = clamp_to_percentile
        self.eps = eps

    def warmup(self):
        feats, radii = (
            np.zeros([1, 2048], dtype=np.float32),
            np.zeros([1, 1], dtype=np.float32),
        )
        self.evaluate_pr(feats, radii, feats, radii)

    def manifold_radii(self, features: np.ndarray) -> np.ndarray:
        num_images = len(features)

        # Estimate manifold of features by calculating distances to k-NN of each sample.
        radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
        distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
        seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)

        for begin1 in range(0, num_images, self.row_batch_size):
            end1 = min(begin1 + self.row_batch_size, num_images)
            row_batch = features[begin1:end1]

            for begin2 in range(0, num_images, self.col_batch_size):
                end2 = min(begin2 + self.col_batch_size, num_images)
                col_batch = features[begin2:end2]

                # Compute distances between batches.
                distance_batch[
                    0 : end1 - begin1, begin2:end2
                ] = self.distance_block.pairwise_distances(row_batch, col_batch)

            # Find the k-nearest neighbor from the current batch.
            radii[begin1:end1, :] = np.concatenate(
                [
                    x[:, self.nhood_sizes]
                    for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
                ],
                axis=0,
            )

        if self.clamp_to_percentile is not None:
            max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
            radii[radii > max_distances] = 0
        return radii

    def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
        """
        Evaluate whether new feature vectors lie on the estimated manifold.
        """
        num_eval_images = eval_features.shape[0]
        num_ref_images = radii.shape[0]
        distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
        batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
        max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
        nearest_indices = np.zeros([num_eval_images], dtype=np.int32)

        for begin1 in range(0, num_eval_images, self.row_batch_size):
            end1 = min(begin1 + self.row_batch_size, num_eval_images)
            feature_batch = eval_features[begin1:end1]

            for begin2 in range(0, num_ref_images, self.col_batch_size):
                end2 = min(begin2 + self.col_batch_size, num_ref_images)
                ref_batch = features[begin2:end2]

                distance_batch[
                    0 : end1 - begin1, begin2:end2
                ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)

            # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
            # If a feature vector is inside a hypersphere of some reference sample, then
            # the new sample lies on the estimated manifold.
            # The radii of the hyperspheres are determined from distances of neighborhood size k.
            samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
            batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)

            max_realism_score[begin1:end1] = np.max(
                radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
            )
            nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)

        return {
            "fraction": float(np.mean(batch_predictions)),
            "batch_predictions": batch_predictions,
            "max_realism_score": max_realism_score,
            "nearest_indices": nearest_indices,
        }

    def evaluate_pr(
        self,
        features_1: np.ndarray,
        radii_1: np.ndarray,
        features_2: np.ndarray,
        radii_2: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Evaluate precision and recall efficiently.

        :param features_1: [N1 x D] feature vectors for reference batch.
        :param radii_1: [N1 x K1] radii for reference vectors.
        :param features_2: [N2 x D] feature vectors for the other batch.
        :param radii_2: [N2 x K2] radii for other vectors.
        :return: a tuple of arrays for (precision, recall):
                 - precision: an np.ndarray of length K1
                 - recall: an np.ndarray of length K2
        """
        features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool_)
        features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool_)
        for begin_1 in range(0, len(features_1), self.row_batch_size):
            end_1 = begin_1 + self.row_batch_size
            batch_1 = features_1[begin_1:end_1]
            for begin_2 in range(0, len(features_2), self.col_batch_size):
                end_2 = begin_2 + self.col_batch_size
                batch_2 = features_2[begin_2:end_2]
                batch_1_in, batch_2_in = self.distance_block.less_thans(
                    batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
                )
                features_1_status[begin_1:end_1] |= batch_1_in
                features_2_status[begin_2:end_2] |= batch_2_in
        return (
            np.mean(features_2_status.astype(np.float64), axis=0),
            np.mean(features_1_status.astype(np.float64), axis=0),
        )

class DistanceBlock:
    """
    Calculate pairwise distances between vectors.

    Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
    """

    def __init__(self, session):
        self.session = session

        # Initialize TF graph to calculate pairwise distances.
        with session.graph.as_default():
            self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
            self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
            distance_block_16 = _batch_pairwise_distances(
                tf.cast(self._features_batch1, tf.float16),
                tf.cast(self._features_batch2, tf.float16),
            )
            self.distance_block = tf.cond(
                tf.reduce_all(tf.math.is_finite(distance_block_16)),
                lambda: tf.cast(distance_block_16, tf.float32),
                lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
            )

            # Extra logic for less thans.
            self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
            self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
            dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
            self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
            self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)

    def pairwise_distances(self, U, V):
        """
        Evaluate pairwise distances between two batches of feature vectors.
        """
        return self.session.run(
            self.distance_block,
            feed_dict={self._features_batch1: U, self._features_batch2: V},
        )

    def less_thans(self, batch_1, radii_1, batch_2, radii_2):
        return self.session.run(
            [self._batch_1_in, self._batch_2_in],
            feed_dict={
                self._features_batch1: batch_1,
                self._features_batch2: batch_2,
                self._radii1: radii_1,
                self._radii2: radii_2,
            },
        )


def _batch_pairwise_distances(U, V):
    """
    Compute pairwise distances between two batches of feature vectors.
    """
    with tf.variable_scope("pairwise_dist_block"):
        # Squared norms of each row in U and V.
        norm_u = tf.reduce_sum(tf.square(U), 1)
        norm_v = tf.reduce_sum(tf.square(V), 1)

        # norm_u as a column vector and norm_v as a row vector.
        norm_u = tf.reshape(norm_u, [-1, 1])
        norm_v = tf.reshape(norm_v, [1, -1])

        # Pairwise squared Euclidean distances.
        D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)

        return D

class NpzArrayReader(ABC):
    @abstractmethod
    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        pass

    @abstractmethod
    def remaining(self) -> int:
        pass

    def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
        def gen_fn():
            while True:
                batch = self.read_batch(batch_size)
                if batch is None:
                    break
                yield batch

        rem = self.remaining()
        num_batches = rem // batch_size + int(rem % batch_size != 0)
        return BatchIterator(gen_fn, num_batches)


class BatchIterator:
    def __init__(self, gen_fn, length):
        self.gen_fn = gen_fn
        self.length = length

    def __len__(self):
        return self.length

    def __iter__(self):
        return self.gen_fn()


class StreamingNpzArrayReader(NpzArrayReader):
    def __init__(self, arr_f, shape, dtype):
        self.arr_f = arr_f
        self.shape = shape
        self.dtype = dtype
        self.idx = 0

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        if self.idx >= self.shape[0]:
            return None

        bs = min(batch_size, self.shape[0] - self.idx)
        self.idx += bs

        if self.dtype.itemsize == 0:
            return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)

        read_count = bs * np.prod(self.shape[1:])
        read_size = int(read_count * self.dtype.itemsize)
        data = _read_bytes(self.arr_f, read_size, "array data")
        return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])

    def remaining(self) -> int:
        return max(0, self.shape[0] - self.idx)


class MemoryNpzArrayReader(NpzArrayReader):
    def __init__(self, arr):
        self.arr = arr
        self.idx = 0

    @classmethod
    def load(cls, path: str, arr_name: str):
        with open(path, "rb") as f:
            arr = np.load(f)[arr_name]
        return cls(arr)

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        if self.idx >= self.arr.shape[0]:
            return None

        res = self.arr[self.idx : self.idx + batch_size]
        self.idx += batch_size
        return res

    def remaining(self) -> int:
        return max(0, self.arr.shape[0] - self.idx)


@contextmanager
def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
    with _open_npy_file(path, arr_name) as arr_f:
        version = np.lib.format.read_magic(arr_f)
        if version == (1, 0):
            header = np.lib.format.read_array_header_1_0(arr_f)
        elif version == (2, 0):
            header = np.lib.format.read_array_header_2_0(arr_f)
        else:
            yield MemoryNpzArrayReader.load(path, arr_name)
            return
        shape, fortran, dtype = header
        if fortran or dtype.hasobject:
            yield MemoryNpzArrayReader.load(path, arr_name)
        else:
            yield StreamingNpzArrayReader(arr_f, shape, dtype)

560
+ def _read_bytes(fp, size, error_template="ran out of data"):
561
+ """
562
+ Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
563
+
564
+ Read from file-like object until size bytes are read.
565
+ Raises ValueError if not EOF is encountered before size bytes are read.
566
+ Non-blocking objects only supported if they derive from io objects.
567
+ Required as e.g. ZipExtFile in python 2.6 can return less data than
568
+ requested.
569
+ """
570
+ data = bytes()
571
+ while True:
572
+ # io files (default in python3) return None or raise on
573
+ # would-block, python2 file will truncate, probably nothing can be
574
+ # done about that. note that regular files can't be non-blocking
575
+ try:
576
+ r = fp.read(size - len(data))
577
+ data += r
578
+ if len(r) == 0 or len(data) == size:
579
+ break
580
+ except io.BlockingIOError:
581
+ pass
582
+ if len(data) != size:
583
+ msg = "EOF: reading %s, expected %d bytes got %d"
584
+ raise ValueError(msg % (error_template, size, len(data)))
585
+ else:
586
+ return data
587
+
588
+
589
+ @contextmanager
590
+ def _open_npy_file(path: str, arr_name: str):
591
+ with open(path, "rb") as f:
592
+ with zipfile.ZipFile(f, "r") as zip_f:
593
+ if f"{arr_name}.npy" not in zip_f.namelist():
594
+ raise ValueError(f"missing {arr_name} in npz file")
595
+ with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
596
+ yield arr_f
597
+
598
+
599
+ def _download_inception_model():
600
+ if os.path.exists(INCEPTION_V3_PATH):
601
+ return
602
+ print("downloading InceptionV3 model...")
603
+ with requests.get(INCEPTION_V3_URL, stream=True) as r:
604
+ r.raise_for_status()
605
+ tmp_path = INCEPTION_V3_PATH + ".tmp"
606
+ with open(tmp_path, "wb") as f:
607
+ for chunk in tqdm(r.iter_content(chunk_size=8192)):
608
+ f.write(chunk)
609
+ os.rename(tmp_path, INCEPTION_V3_PATH)
610
+
611
+
612
+ def _create_feature_graph(input_batch):
613
+ _download_inception_model()
614
+ prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
615
+ with open(INCEPTION_V3_PATH, "rb") as f:
616
+ graph_def = tf.GraphDef()
617
+ graph_def.ParseFromString(f.read())
618
+ pool3, spatial = tf.import_graph_def(
619
+ graph_def,
620
+ input_map={"ExpandDims:0": input_batch},
621
+ return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
622
+ name=prefix,
623
+ )
624
+ _update_shapes(pool3)
625
+ spatial = spatial[..., :7]
626
+ return pool3, spatial
627
+
628
+
629
+ def _create_softmax_graph(input_batch):
630
+ _download_inception_model()
631
+ prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
632
+ with open(INCEPTION_V3_PATH, "rb") as f:
633
+ graph_def = tf.GraphDef()
634
+ graph_def.ParseFromString(f.read())
635
+ (matmul,) = tf.import_graph_def(
636
+ graph_def, return_elements=["softmax/logits/MatMul"], name=prefix
637
+ )
638
+ w = matmul.inputs[1]
639
+ logits = tf.matmul(input_batch, w)
640
+ return tf.nn.softmax(logits)
641
+
642
+
643
+ def _update_shapes(pool3):
644
+ # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
645
+ ops = pool3.graph.get_operations()
646
+ for op in ops:
647
+ for o in op.outputs:
648
+ shape = o.get_shape()
649
+ if shape._dims is not None: # pylint: disable=protected-access
650
+ # shape = [s.value for s in shape] TF 1.x
651
+ shape = [s for s in shape] # TF 2.x
652
+ new_shape = []
653
+ for j, s in enumerate(shape):
654
+ if s == 1 and j == 0:
655
+ new_shape.append(None)
656
+ else:
657
+ new_shape.append(s)
658
+ o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
659
+ return pool3
660
+
661
+
662
+ def _numpy_partition(arr, kth, **kwargs):
663
+ num_workers = min(cpu_count(), len(arr))
664
+ chunk_size = len(arr) // num_workers
665
+ extra = len(arr) % num_workers
666
+
667
+ start_idx = 0
668
+ batches = []
669
+ for i in range(num_workers):
670
+ size = chunk_size + (1 if i < extra else 0)
671
+ batches.append(arr[start_idx : start_idx + size])
672
+ start_idx += size
673
+
674
+ with ThreadPool(num_workers) as pool:
675
+ return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
676
+
677
+
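The chunking done by `_numpy_partition` can be restated sequentially (`chunked_partition` is a hypothetical name for this sketch); the thread pool in the original only parallelizes the final loop:

```python
import numpy as np

# Split the array into near-equal batches (the first `extra` batches get one
# extra element) and np.partition each batch independently.
def chunked_partition(arr, kth, num_workers):
    num_workers = min(num_workers, len(arr))
    chunk_size, extra = divmod(len(arr), num_workers)
    batches, start = [], 0
    for i in range(num_workers):
        size = chunk_size + (1 if i < extra else 0)
        batches.append(arr[start:start + size])
        start += size
    return [np.partition(b, kth=kth) for b in batches]
```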
678
+ if __name__ == "__main__":
679
+ main()
REG/evaluations/requirements.txt ADDED
@@ -0,0 +1,4 @@
1
+ tensorflow-gpu>=2.0
2
+ scipy
3
+ requests
4
+ tqdm
REG/models/clip_vit.py ADDED
@@ -0,0 +1,426 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+ import clip
10
+
11
+
12
+ class Bottleneck(nn.Module):
13
+ expansion = 4
14
+
15
+ def __init__(self, inplanes, planes, stride=1):
16
+ super().__init__()
17
+
18
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
19
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
20
+ self.bn1 = nn.BatchNorm2d(planes)
21
+ self.relu1 = nn.ReLU(inplace=True)
22
+
23
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
24
+ self.bn2 = nn.BatchNorm2d(planes)
25
+ self.relu2 = nn.ReLU(inplace=True)
26
+
27
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
28
+
29
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
30
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
31
+ self.relu3 = nn.ReLU(inplace=True)
32
+
33
+ self.downsample = None
34
+ self.stride = stride
35
+
36
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
37
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
38
+ self.downsample = nn.Sequential(OrderedDict([
39
+ ("-1", nn.AvgPool2d(stride)),
40
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
41
+ ("1", nn.BatchNorm2d(planes * self.expansion))
42
+ ]))
43
+
44
+ def forward(self, x: torch.Tensor):
45
+ identity = x
46
+
47
+ out = self.relu1(self.bn1(self.conv1(x)))
48
+ out = self.relu2(self.bn2(self.conv2(out)))
49
+ out = self.avgpool(out)
50
+ out = self.bn3(self.conv3(out))
51
+
52
+ if self.downsample is not None:
53
+ identity = self.downsample(x)
54
+
55
+ out += identity
56
+ out = self.relu3(out)
57
+ return out
58
+
59
+
60
+ class AttentionPool2d(nn.Module):
61
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
62
+ super().__init__()
63
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
64
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
66
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
67
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
68
+ self.num_heads = num_heads
69
+
70
+ def forward(self, x):
71
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
72
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
73
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
74
+ x, _ = F.multi_head_attention_forward(
75
+ query=x[:1], key=x, value=x,
76
+ embed_dim_to_check=x.shape[-1],
77
+ num_heads=self.num_heads,
78
+ q_proj_weight=self.q_proj.weight,
79
+ k_proj_weight=self.k_proj.weight,
80
+ v_proj_weight=self.v_proj.weight,
81
+ in_proj_weight=None,
82
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
83
+ bias_k=None,
84
+ bias_v=None,
85
+ add_zero_attn=False,
86
+ dropout_p=0,
87
+ out_proj_weight=self.c_proj.weight,
88
+ out_proj_bias=self.c_proj.bias,
89
+ use_separate_proj_weight=True,
90
+ training=self.training,
91
+ need_weights=False
92
+ )
93
+ return x.squeeze(0)
94
+
95
+
96
+ class ModifiedResNet(nn.Module):
97
+ """
98
+ A ResNet class that is similar to torchvision's but contains the following changes:
99
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
100
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
101
+ - The final pooling layer is a QKV attention instead of an average pool
102
+ """
103
+
104
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
105
+ super().__init__()
106
+ self.output_dim = output_dim
107
+ self.input_resolution = input_resolution
108
+
109
+ # the 3-layer stem
110
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
111
+ self.bn1 = nn.BatchNorm2d(width // 2)
112
+ self.relu1 = nn.ReLU(inplace=True)
113
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
114
+ self.bn2 = nn.BatchNorm2d(width // 2)
115
+ self.relu2 = nn.ReLU(inplace=True)
116
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
117
+ self.bn3 = nn.BatchNorm2d(width)
118
+ self.relu3 = nn.ReLU(inplace=True)
119
+ self.avgpool = nn.AvgPool2d(2)
120
+
121
+ # residual layers
122
+ self._inplanes = width # this is a *mutable* variable used during construction
123
+ self.layer1 = self._make_layer(width, layers[0])
124
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
125
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
126
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
127
+
128
+ embed_dim = width * 32 # the ResNet feature dimension
129
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
130
+
131
+ def _make_layer(self, planes, blocks, stride=1):
132
+ layers = [Bottleneck(self._inplanes, planes, stride)]
133
+
134
+ self._inplanes = planes * Bottleneck.expansion
135
+ for _ in range(1, blocks):
136
+ layers.append(Bottleneck(self._inplanes, planes))
137
+
138
+ return nn.Sequential(*layers)
139
+
140
+ def forward(self, x):
141
+ def stem(x):
142
+ x = self.relu1(self.bn1(self.conv1(x)))
143
+ x = self.relu2(self.bn2(self.conv2(x)))
144
+ x = self.relu3(self.bn3(self.conv3(x)))
145
+ x = self.avgpool(x)
146
+ return x
147
+
148
+ x = x.type(self.conv1.weight.dtype)
149
+ x = stem(x)
150
+ x = self.layer1(x)
151
+ x = self.layer2(x)
152
+ x = self.layer3(x)
153
+ x = self.layer4(x)
154
+ x = self.attnpool(x)
155
+
156
+ return x
157
+
158
+
159
+ class LayerNorm(nn.LayerNorm):
160
+ """Subclass torch's LayerNorm to handle fp16."""
161
+
162
+ def forward(self, x: torch.Tensor):
163
+ orig_type = x.dtype
164
+ ret = super().forward(x.type(torch.float32))
165
+ return ret.type(orig_type)
166
+
167
+
168
+ class QuickGELU(nn.Module):
169
+ def forward(self, x: torch.Tensor):
170
+ return x * torch.sigmoid(1.702 * x)
171
+
172
+
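QuickGELU is the sigmoid approximation `x * sigmoid(1.702 * x)` to the exact erf-based GELU; a scalar comparison (standard library only, hedged as an illustration) shows how close the two are:

```python
import math

# x * sigmoid(1.702 * x), written out with math.exp
def quick_gelu(x):
    return x / (1.0 + math.exp(-1.702 * x))

# exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
def exact_gelu(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))
```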
173
+ class ResidualAttentionBlock(nn.Module):
174
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
175
+ super().__init__()
176
+
177
+ self.attn = nn.MultiheadAttention(d_model, n_head)
178
+ self.ln_1 = LayerNorm(d_model)
179
+ self.mlp = nn.Sequential(OrderedDict([
180
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
181
+ ("gelu", QuickGELU()),
182
+ ("c_proj", nn.Linear(d_model * 4, d_model))
183
+ ]))
184
+ self.ln_2 = LayerNorm(d_model)
185
+ self.attn_mask = attn_mask
186
+
187
+ def attention(self, x: torch.Tensor):
188
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
189
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
190
+
191
+ def forward(self, x: torch.Tensor):
192
+ x = x + self.attention(self.ln_1(x))
193
+ x = x + self.mlp(self.ln_2(x))
194
+ return x
195
+
196
+
197
+ class Transformer(nn.Module):
198
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
199
+ super().__init__()
200
+ self.width = width
201
+ self.layers = layers
202
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
203
+
204
+ def forward(self, x: torch.Tensor):
205
+ return self.resblocks(x)
206
+
207
+
208
+ class UpdatedVisionTransformer(nn.Module):
209
+ def __init__(self, model):
210
+ super().__init__()
211
+ self.model = model
212
+
213
+ def forward(self, x: torch.Tensor):
214
+ x = self.model.conv1(x) # shape = [*, width, grid, grid]
215
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
216
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
217
+ x = torch.cat([self.model.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
218
+ x = x + self.model.positional_embedding.to(x.dtype)
219
+ x = self.model.ln_pre(x)
220
+
221
+ x = x.permute(1, 0, 2) # NLD -> LND
222
+ x = self.model.transformer(x)
223
+ x = x.permute(1, 0, 2)[:, 1:] # LND -> NLD
224
+
225
+ # x = self.ln_post(x[:, 0, :])
226
+
227
+ # if self.proj is not None:
228
+ # x = x @ self.proj
229
+
230
+ return x
231
+
232
+
233
+ class CLIP(nn.Module):
234
+ def __init__(self,
235
+ embed_dim: int,
236
+ # vision
237
+ image_resolution: int,
238
+ vision_layers: Union[Tuple[int, int, int, int], int],
239
+ vision_width: int,
240
+ vision_patch_size: int,
241
+ # text
242
+ context_length: int,
243
+ vocab_size: int,
244
+ transformer_width: int,
245
+ transformer_heads: int,
246
+ transformer_layers: int
247
+ ):
248
+ super().__init__()
249
+
250
+ self.context_length = context_length
251
+
252
+ if isinstance(vision_layers, (tuple, list)):
253
+ vision_heads = vision_width * 32 // 64
254
+ self.visual = ModifiedResNet(
255
+ layers=vision_layers,
256
+ output_dim=embed_dim,
257
+ heads=vision_heads,
258
+ input_resolution=image_resolution,
259
+ width=vision_width
260
+ )
261
+ else:
262
+ vision_heads = vision_width // 64
263
+ self.visual = UpdatedVisionTransformer(
264
+ input_resolution=image_resolution,
265
+ patch_size=vision_patch_size,
266
+ width=vision_width,
267
+ layers=vision_layers,
268
+ heads=vision_heads,
269
+ output_dim=embed_dim
270
+ )
271
+
272
+ self.transformer = Transformer(
273
+ width=transformer_width,
274
+ layers=transformer_layers,
275
+ heads=transformer_heads,
276
+ attn_mask=self.build_attention_mask()
277
+ )
278
+
279
+ self.vocab_size = vocab_size
280
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
281
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
282
+ self.ln_final = LayerNorm(transformer_width)
283
+
284
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
285
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
286
+
287
+ self.initialize_parameters()
288
+
289
+ def initialize_parameters(self):
290
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
291
+ nn.init.normal_(self.positional_embedding, std=0.01)
292
+
293
+ if isinstance(self.visual, ModifiedResNet):
294
+ if self.visual.attnpool is not None:
295
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
296
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
297
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
298
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
299
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
300
+
301
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
302
+ for name, param in resnet_block.named_parameters():
303
+ if name.endswith("bn3.weight"):
304
+ nn.init.zeros_(param)
305
+
306
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
307
+ attn_std = self.transformer.width ** -0.5
308
+ fc_std = (2 * self.transformer.width) ** -0.5
309
+ for block in self.transformer.resblocks:
310
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
311
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
312
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
313
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
314
+
315
+ if self.text_projection is not None:
316
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
317
+
318
+ def build_attention_mask(self):
319
+ # lazily create causal attention mask, with full attention between the vision tokens
320
+ # pytorch uses additive attention mask; fill with -inf
321
+ mask = torch.empty(self.context_length, self.context_length)
322
+ mask.fill_(float("-inf"))
323
+ mask.triu_(1) # zero out the lower diagonal
324
+ return mask
325
+
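The same additive causal mask can be sketched in NumPy: `-inf` strictly above the diagonal blocks attention to future tokens, zeros everywhere else, so token `i` can only attend to tokens `<= i`:

```python
import numpy as np

# fill with -inf, then zero out the diagonal and everything below it
n = 4
mask = np.triu(np.full((n, n), float("-inf")), k=1)
```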
326
+ @property
327
+ def dtype(self):
328
+ return self.visual.conv1.weight.dtype
329
+
330
+ def encode_image(self, image):
331
+ return self.visual(image.type(self.dtype))
332
+
333
+ def encode_text(self, text):
334
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
335
+
336
+ x = x + self.positional_embedding.type(self.dtype)
337
+ x = x.permute(1, 0, 2) # NLD -> LND
338
+ x = self.transformer(x)
339
+ x = x.permute(1, 0, 2) # LND -> NLD
340
+ x = self.ln_final(x).type(self.dtype)
341
+
342
+ # x.shape = [batch_size, n_ctx, transformer.width]
343
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
344
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
345
+
346
+ return x
347
+
348
+ def forward(self, image, text):
349
+ image_features = self.encode_image(image)
350
+ text_features = self.encode_text(text)
351
+
352
+ # normalized features
353
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
354
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
355
+
356
+ # cosine similarity as logits
357
+ logit_scale = self.logit_scale.exp()
358
+ logits_per_image = logit_scale * image_features @ text_features.t()
359
+ logits_per_text = logits_per_image.t()
360
+
361
+ # shape = [global_batch_size, global_batch_size]
362
+ return logits_per_image, logits_per_text
363
+
364
+
365
+ def convert_weights(model: nn.Module):
366
+ """Convert applicable model parameters to fp16"""
367
+
368
+ def _convert_weights_to_fp16(l):
369
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
370
+ l.weight.data = l.weight.data.half()
371
+ if l.bias is not None:
372
+ l.bias.data = l.bias.data.half()
373
+
374
+ if isinstance(l, nn.MultiheadAttention):
375
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
376
+ tensor = getattr(l, attr)
377
+ if tensor is not None:
378
+ tensor.data = tensor.data.half()
379
+
380
+ for name in ["text_projection", "proj"]:
381
+ if hasattr(l, name):
382
+ attr = getattr(l, name)
383
+ if attr is not None:
384
+ attr.data = attr.data.half()
385
+
386
+ model.apply(_convert_weights_to_fp16)
387
+
388
+
389
+ def build_model(state_dict: dict):
390
+ vit = "visual.proj" in state_dict
391
+
392
+ if vit:
393
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
394
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
395
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
396
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
397
+ image_resolution = vision_patch_size * grid_size
398
+ else:
399
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
400
+ vision_layers = tuple(counts)
401
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
402
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
403
+ vision_patch_size = None
404
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
405
+ image_resolution = output_width * 32
406
+
407
+ embed_dim = state_dict["text_projection"].shape[1]
408
+ context_length = state_dict["positional_embedding"].shape[0]
409
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
410
+ transformer_width = state_dict["ln_final.weight"].shape[0]
411
+ transformer_heads = transformer_width // 64
412
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
413
+
414
+ model = CLIP(
415
+ embed_dim,
416
+ image_resolution, vision_layers, vision_width, vision_patch_size,
417
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
418
+ )
419
+
420
+ for key in ["input_resolution", "context_length", "vocab_size"]:
421
+ if key in state_dict:
422
+ del state_dict[key]
423
+
424
+ convert_weights(model)
425
+ model.load_state_dict(state_dict)
426
+ return model.eval()
REG/models/jepa.py ADDED
@@ -0,0 +1,547 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+
8
+ import math
9
+ from functools import partial
10
+ import numpy as np
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+
15
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
16
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
17
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
18
+ def norm_cdf(x):
19
+ # Computes standard normal cumulative distribution function
20
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
21
+
22
+ with torch.no_grad():
23
+ # Values are generated by using a truncated uniform distribution and
24
+ # then using the inverse CDF for the normal distribution.
25
+ # Get upper and lower cdf values
26
+ l = norm_cdf((a - mean) / std)
27
+ u = norm_cdf((b - mean) / std)
28
+
29
+ # Uniformly fill tensor with values from [l, u], then translate to
30
+ # [2l-1, 2u-1].
31
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
32
+
33
+ # Use inverse cdf transform for normal distribution to get truncated
34
+ # standard normal
35
+ tensor.erfinv_()
36
+
37
+ # Transform to proper mean, std
38
+ tensor.mul_(std * math.sqrt(2.))
39
+ tensor.add_(mean)
40
+
41
+ # Clamp to ensure it's in the proper range
42
+ tensor.clamp_(min=a, max=b)
43
+ return tensor
44
+
45
+
46
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
47
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
48
+
49
+
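The inverse-CDF construction above can be restated with the standard library (a sketch, not the in-place tensor version): draw uniformly between Phi(a) and Phi(b), then map back through the normal inverse CDF so every sample is guaranteed to land in [a, b]:

```python
import random
from statistics import NormalDist

# hypothetical helper name for this sketch
def trunc_normal_samples(n, mean=0.0, std=1.0, a=-2.0, b=2.0, seed=0):
    nd = NormalDist(mu=mean, sigma=std)
    lo, hi = nd.cdf(a), nd.cdf(b)  # CDF values at the truncation bounds
    rng = random.Random(seed)
    # uniform in [Phi(a), Phi(b)] -> inverse CDF -> truncated normal in [a, b]
    return [nd.inv_cdf(rng.uniform(lo, hi)) for _ in range(n)]
```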
50
+ def repeat_interleave_batch(x, B, repeat):
51
+ N = len(x) // B
52
+ x = torch.cat([
53
+ torch.cat([x[i*B:(i+1)*B] for _ in range(repeat)], dim=0)
54
+ for i in range(N)
55
+ ], dim=0)
56
+ return x
57
+
58
+ def apply_masks(x, masks):
59
+ """
60
+ :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
61
+ :param masks: list of tensors containing indices of patches in [N] to keep
62
+ """
63
+ all_x = []
64
+ for m in masks:
65
+ mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1))
66
+ all_x += [torch.gather(x, dim=1, index=mask_keep)]
67
+ return torch.cat(all_x, dim=0)
68
+
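The gather performed by `apply_masks` can be restated in NumPy: each mask holds, per batch element, the indices of patches to keep; those rows are gathered and the results stacked along the batch axis (one copy of the batch per mask). `apply_masks_np` is a hypothetical name for this sketch:

```python
import numpy as np

def apply_masks_np(x, masks):
    # x: [B, N, D]; each m: [B, K] indices into the N patch positions
    kept = [np.take_along_axis(x, m[:, :, None], axis=1) for m in masks]
    return np.concatenate(kept, axis=0)
```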
69
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
70
+ """
71
+ grid_size: int of the grid height and width
72
+ return:
73
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
74
+ """
75
+ grid_h = np.arange(grid_size, dtype=float)
76
+ grid_w = np.arange(grid_size, dtype=float)
77
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
78
+ grid = np.stack(grid, axis=0)
79
+
80
+ grid = grid.reshape([2, 1, grid_size, grid_size])
81
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
82
+ if cls_token:
83
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
84
+ return pos_embed
85
+
86
+
87
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
88
+ assert embed_dim % 2 == 0
89
+
90
+ # use half of dimensions to encode grid_h
91
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
92
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
93
+
94
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
95
+ return emb
96
+
97
+
98
+ def get_1d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
99
+ """
100
+ grid_size: int of the grid length
101
+ return:
102
+ pos_embed: [grid_size, embed_dim] or [1+grid_size, embed_dim] (w/ or w/o cls_token)
103
+ """
104
+ grid = np.arange(grid_size, dtype=float)
105
+ pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
106
+ if cls_token:
107
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
108
+ return pos_embed
109
+
110
+
111
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
112
+ """
113
+ embed_dim: output dimension for each position
114
+ pos: a list of positions to be encoded: size (M,)
115
+ out: (M, D)
116
+ """
117
+ assert embed_dim % 2 == 0
118
+ omega = np.arange(embed_dim // 2, dtype=float)
119
+ omega /= embed_dim / 2.
120
+ omega = 1. / 10000**omega # (D/2,)
121
+
122
+ pos = pos.reshape(-1) # (M,)
123
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
124
+
125
+ emb_sin = np.sin(out) # (M, D/2)
126
+ emb_cos = np.cos(out) # (M, D/2)
127
+
128
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
129
+ return emb
130
+
131
+
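A self-contained restatement of the 1-d sin/cos table above: the D/2 frequencies are `omega_d = 1 / 10000**(2d/D)`, and each contributes a sin and a cos channel, giving an (M, D) embedding:

```python
import numpy as np

# hypothetical name for this condensed sketch of the function above
def sincos_1d(embed_dim, pos):
    omega = np.arange(embed_dim // 2, dtype=float) / (embed_dim / 2.0)
    omega = 1.0 / 10000 ** omega                       # (D/2,)
    out = np.outer(np.asarray(pos, dtype=float).reshape(-1), omega)  # (M, D/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)       # (M, D)
```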
132
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
133
+ # stochastic depth: drop whole samples with probability drop_prob
134
+ return x
135
+ keep_prob = 1 - drop_prob
136
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
137
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
138
+ random_tensor.floor_() # binarize
139
+ output = x.div(keep_prob) * random_tensor
140
+ return output
141
+
142
+
143
+ class DropPath(nn.Module):
144
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
145
+ """
146
+ def __init__(self, drop_prob=None):
147
+ super(DropPath, self).__init__()
148
+ self.drop_prob = drop_prob
149
+
150
+ def forward(self, x):
151
+ return drop_path(x, self.drop_prob, self.training)
152
+
153
+
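The stochastic-depth math above can be restated in NumPy (`drop_path_np` is a hypothetical name): keep each sample with probability 1 - p, rescale survivors by 1/(1 - p) so the expected output equals the input, and act as the identity at eval time or when p == 0:

```python
import numpy as np

def drop_path_np(x, drop_prob=0.0, training=False, rng=None):
    if drop_prob == 0.0 or not training:
        return x
    if rng is None:
        rng = np.random.default_rng(0)
    keep_prob = 1.0 - drop_prob
    # one 0/1 decision per sample, broadcast over the remaining dims
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = np.floor(keep_prob + rng.random(shape))
    return x / keep_prob * mask
```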
154
+ class MLP(nn.Module):
155
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
156
+ super().__init__()
157
+ out_features = out_features or in_features
158
+ hidden_features = hidden_features or in_features
159
+ self.fc1 = nn.Linear(in_features, hidden_features)
160
+ self.act = act_layer()
161
+ self.fc2 = nn.Linear(hidden_features, out_features)
162
+ self.drop = nn.Dropout(drop)
163
+
164
+ def forward(self, x):
165
+ x = self.fc1(x)
166
+ x = self.act(x)
167
+ x = self.drop(x)
168
+ x = self.fc2(x)
169
+ x = self.drop(x)
170
+ return x
171
+
172
+
173
+ class Attention(nn.Module):
174
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
175
+ super().__init__()
176
+ self.num_heads = num_heads
177
+ head_dim = dim // num_heads
178
+ self.scale = qk_scale or head_dim ** -0.5
179
+
180
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
181
+ self.attn_drop = nn.Dropout(attn_drop)
182
+ self.proj = nn.Linear(dim, dim)
183
+ self.proj_drop = nn.Dropout(proj_drop)
184
+
185
+ def forward(self, x):
186
+ B, N, C = x.shape
187
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
188
+ q, k, v = qkv[0], qkv[1], qkv[2]
189
+
190
+ attn = (q @ k.transpose(-2, -1)) * self.scale
191
+ attn = attn.softmax(dim=-1)
192
+ attn = self.attn_drop(attn)
193
+
194
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
195
+ x = self.proj(x)
196
+ x = self.proj_drop(x)
197
+ return x, attn
198
+
199
+
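A single-head restatement of the attention math above, `softmax(q k^T / sqrt(d)) v`, returning both the output and the weights just as `Attention.forward` does (`attention_np` is a hypothetical name):

```python
import numpy as np

def attention_np(q, k, v):
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)
    # numerically stable softmax over the key axis
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ v, w
```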
200
+ class Block(nn.Module):
201
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
202
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
203
+ super().__init__()
204
+ self.norm1 = norm_layer(dim)
205
+ self.attn = Attention(
206
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
207
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
208
+ self.norm2 = norm_layer(dim)
209
+ mlp_hidden_dim = int(dim * mlp_ratio)
210
+ self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
211
+
212
+ def forward(self, x, return_attention=False):
213
+ y, attn = self.attn(self.norm1(x))
214
+ if return_attention:
215
+ return attn
216
+ x = x + self.drop_path(y)
217
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
218
+ return x
219
+
220
+
221
+ class PatchEmbed(nn.Module):
222
+ """ Image to Patch Embedding
223
+ """
224
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
225
+ super().__init__()
226
+ num_patches = (img_size // patch_size) * (img_size // patch_size)
227
+ self.img_size = img_size
228
+ self.patch_size = patch_size
229
+ self.num_patches = num_patches
230
+
231
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
232
+
233
+ def forward(self, x):
234
+ B, C, H, W = x.shape
235
+ x = self.proj(x).flatten(2).transpose(1, 2)
236
+ return x
237
+
238
+
239
+ class ConvEmbed(nn.Module):
240
+ """
241
+ 3x3 Convolution stems for ViT following ViTC models
242
+ """
243
+
244
+ def __init__(self, channels, strides, img_size=224, in_chans=3, batch_norm=True):
245
+ super().__init__()
246
+ # Build the stems
247
+ stem = []
248
+ channels = [in_chans] + channels
249
+ for i in range(len(channels) - 2):
250
+ stem += [nn.Conv2d(channels[i], channels[i+1], kernel_size=3,
251
+ stride=strides[i], padding=1, bias=(not batch_norm))]
252
+ if batch_norm:
253
+ stem += [nn.BatchNorm2d(channels[i+1])]
254
+ stem += [nn.ReLU(inplace=True)]
255
+ stem += [nn.Conv2d(channels[-2], channels[-1], kernel_size=1, stride=strides[-1])]
256
+ self.stem = nn.Sequential(*stem)
257
+
258
+ # Compute the number of patches
259
+ stride_prod = int(np.prod(strides))
260
+ self.num_patches = (img_size[0] // stride_prod)**2
261
+
262
+ def forward(self, x):
263
+ p = self.stem(x)
264
+ return p.flatten(2).transpose(1, 2)
265
+
266
+
267
+ class VisionTransformerPredictor(nn.Module):
268
+ """ Vision Transformer """
269
+ def __init__(
270
+ self,
271
+ num_patches,
272
+ embed_dim=768,
273
+ predictor_embed_dim=384,
+         depth=6,
+         num_heads=12,
+         mlp_ratio=4.0,
+         qkv_bias=True,
+         qk_scale=None,
+         drop_rate=0.0,
+         attn_drop_rate=0.0,
+         drop_path_rate=0.0,
+         norm_layer=nn.LayerNorm,
+         init_std=0.02,
+         **kwargs
+     ):
+         super().__init__()
+         self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True)
+         self.mask_token = nn.Parameter(torch.zeros(1, 1, predictor_embed_dim))
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+         # --
+         self.predictor_pos_embed = nn.Parameter(torch.zeros(1, num_patches, predictor_embed_dim),
+                                                 requires_grad=False)
+         predictor_pos_embed = get_2d_sincos_pos_embed(self.predictor_pos_embed.shape[-1],
+                                                       int(num_patches**.5),
+                                                       cls_token=False)
+         self.predictor_pos_embed.data.copy_(torch.from_numpy(predictor_pos_embed).float().unsqueeze(0))
+         # --
+         self.predictor_blocks = nn.ModuleList([
+             Block(
+                 dim=predictor_embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+             for i in range(depth)])
+         self.predictor_norm = norm_layer(predictor_embed_dim)
+         self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True)
+         # ------
+         self.init_std = init_std
+         trunc_normal_(self.mask_token, std=self.init_std)
+         self.apply(self._init_weights)
+         self.fix_init_weight()
+
+     def fix_init_weight(self):
+         def rescale(param, layer_id):
+             param.div_(math.sqrt(2.0 * layer_id))
+
+         for layer_id, layer in enumerate(self.predictor_blocks):
+             rescale(layer.attn.proj.weight.data, layer_id + 1)
+             rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=self.init_std)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+         elif isinstance(m, nn.Conv2d):
+             trunc_normal_(m.weight, std=self.init_std)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+
+     def forward(self, x, masks_x, masks):
+         assert (masks is not None) and (masks_x is not None), 'Cannot run predictor without mask indices'
+
+         if not isinstance(masks_x, list):
+             masks_x = [masks_x]
+
+         if not isinstance(masks, list):
+             masks = [masks]
+
+         # -- batch size
+         B = len(x) // len(masks_x)
+
+         # -- map from encoder-dim to predictor-dim
+         x = self.predictor_embed(x)
+
+         # -- add positional embedding to x tokens
+         x_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1)
+         x += apply_masks(x_pos_embed, masks_x)
+
+         _, N_ctxt, D = x.shape
+
+         # -- concat mask tokens to x
+         pos_embs = self.predictor_pos_embed.repeat(B, 1, 1)
+         pos_embs = apply_masks(pos_embs, masks)
+         pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_x))
+         # --
+         pred_tokens = self.mask_token.repeat(pos_embs.size(0), pos_embs.size(1), 1)
+         # --
+         pred_tokens += pos_embs
+         x = x.repeat(len(masks), 1, 1)
+         x = torch.cat([x, pred_tokens], dim=1)
+
+         # -- fwd prop
+         for blk in self.predictor_blocks:
+             x = blk(x)
+         x = self.predictor_norm(x)
+
+         # -- return preds for mask tokens
+         x = x[:, N_ctxt:]
+         x = self.predictor_proj(x)
+
+         return x
+
+
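The `dpr` list built in the constructor above is the "stochastic depth decay rule": `torch.linspace(0, drop_path_rate, depth)` gives block `i` a drop-path probability that grows linearly from 0 (first block) to `drop_path_rate` (last block). A minimal torch-free sketch of that schedule (the helper name is ours, not the repo's):

```python
def drop_path_schedule(depth, drop_path_rate):
    """Per-block drop-path rates, matching torch.linspace(0, rate, depth)."""
    if depth == 1:
        return [0.0]
    return [drop_path_rate * i / (depth - 1) for i in range(depth)]

dpr = drop_path_schedule(6, 0.1)  # predictor default: depth=6
```

The first block is never dropped, and deeper blocks are regularized progressively harder.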
+ class VisionTransformer(nn.Module):
+     """ Vision Transformer """
+     def __init__(
+         self,
+         img_size=[224],
+         patch_size=16,
+         in_chans=3,
+         embed_dim=768,
+         predictor_embed_dim=384,
+         depth=12,
+         predictor_depth=12,
+         num_heads=12,
+         mlp_ratio=4.0,
+         qkv_bias=True,
+         qk_scale=None,
+         drop_rate=0.0,
+         attn_drop_rate=0.0,
+         drop_path_rate=0.0,
+         norm_layer=nn.LayerNorm,
+         init_std=0.02,
+         **kwargs
+     ):
+         super().__init__()
+         self.num_features = self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         # --
+         self.patch_embed = PatchEmbed(
+             img_size=img_size[0],
+             patch_size=patch_size,
+             in_chans=in_chans,
+             embed_dim=embed_dim)
+         num_patches = self.patch_embed.num_patches
+         # --
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False)
+         pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1],
+                                             int(self.patch_embed.num_patches**.5),
+                                             cls_token=False)
+         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+         # --
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+         self.blocks = nn.ModuleList([
+             Block(
+                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                 drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+             for i in range(depth)])
+         self.norm = norm_layer(embed_dim)
+         # ------
+         self.init_std = init_std
+         self.apply(self._init_weights)
+         self.fix_init_weight()
+
+     def fix_init_weight(self):
+         def rescale(param, layer_id):
+             param.div_(math.sqrt(2.0 * layer_id))
+
+         for layer_id, layer in enumerate(self.blocks):
+             rescale(layer.attn.proj.weight.data, layer_id + 1)
+             rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=self.init_std)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+         elif isinstance(m, nn.Conv2d):
+             trunc_normal_(m.weight, std=self.init_std)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+
+     def forward(self, x, masks=None):
+         if masks is not None:
+             if not isinstance(masks, list):
+                 masks = [masks]
+
+         # -- patchify x
+         x = self.patch_embed(x)
+         B, N, D = x.shape
+
+         # -- add positional embedding to x
+         pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
+         x = x + pos_embed
+
+         # -- mask x
+         if masks is not None:
+             x = apply_masks(x, masks)
+
+         # -- fwd prop
+         for i, blk in enumerate(self.blocks):
+             x = blk(x)
+
+         if self.norm is not None:
+             x = self.norm(x)
+
+         return x
+
+     def interpolate_pos_encoding(self, x, pos_embed):
+         npatch = x.shape[1] - 1
+         N = pos_embed.shape[1] - 1
+         if npatch == N:
+             return pos_embed
+         class_emb = pos_embed[:, 0]
+         pos_embed = pos_embed[:, 1:]
+         dim = x.shape[-1]
+         pos_embed = nn.functional.interpolate(
+             pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+             scale_factor=math.sqrt(npatch / N),
+             mode='bicubic',
+         )
+         pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+         return torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
+
+
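Both classes above call `apply_masks(...)`, a helper defined elsewhere in the repo (not in this chunk). A plausible NumPy sketch of what it does, purely as an assumption for readability: for each mask of per-sample keep-indices with shape `(B, K)`, gather those `K` tokens out of `x` of shape `(B, N, D)`, then concatenate the gathered batches along the batch axis.

```python
import numpy as np

def apply_masks_np(x, masks):
    """Hypothetical NumPy analogue of the repo's apply_masks helper."""
    out = []
    for m in masks:                                      # m: (B, K) integer indices
        idx = m[:, :, None].repeat(x.shape[-1], axis=2)  # broadcast to (B, K, D)
        out.append(np.take_along_axis(x, idx, axis=1))   # keep K tokens per sample
    return np.concatenate(out, axis=0)                   # (len(masks) * B, K, D)

x = np.arange(2 * 4 * 3, dtype=np.float32).reshape(2, 4, 3)  # (B=2, N=4, D=3)
masks = [np.array([[0, 2], [1, 3]])]                         # keep 2 tokens per sample
kept = apply_masks_np(x, masks)
```

Under this reading, the predictor's `N_ctxt` is simply the number of kept context tokens per sample.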
+ def vit_predictor(**kwargs):
+     model = VisionTransformerPredictor(
+         mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         **kwargs)
+     return model
+
+
+ def vit_tiny(patch_size=16, **kwargs):
+     model = VisionTransformer(
+         patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
+         qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
+
+
+ def vit_small(patch_size=16, **kwargs):
+     model = VisionTransformer(
+         patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
+         qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
+
+
+ def vit_base(patch_size=16, **kwargs):
+     model = VisionTransformer(
+         patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+         qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
+
+
+ def vit_large(patch_size=16, **kwargs):
+     model = VisionTransformer(
+         patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
+         qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
+
+
+ def vit_huge(patch_size=16, **kwargs):
+     model = VisionTransformer(
+         patch_size=patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
+         qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
+
+
+ def vit_giant(patch_size=16, **kwargs):
+     model = VisionTransformer(
+         patch_size=patch_size, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
+         qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
+
+
+ VIT_EMBED_DIMS = {
+     'vit_tiny': 192,
+     'vit_small': 384,
+     'vit_base': 768,
+     'vit_large': 1024,
+     'vit_huge': 1280,
+     'vit_giant': 1408,
+ }
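Each factory above fixes an `embed_dim`, mirrored by the `VIT_EMBED_DIMS` table, and `PatchEmbed` turns an `img_size` x `img_size` input into `(img_size / patch_size)^2` tokens of that width. Quick arithmetic for the defaults (the dict literal is copied from the file; the helper is ours):

```python
VIT_EMBED_DIMS = {
    'vit_tiny': 192,
    'vit_small': 384,
    'vit_base': 768,
    'vit_large': 1024,
    'vit_huge': 1280,
    'vit_giant': 1408,
}

def num_patches(img_size, patch_size):
    """Token count produced by a square patch embedding."""
    assert img_size % patch_size == 0
    return (img_size // patch_size) ** 2

tokens = num_patches(224, 16)  # default img_size=[224], patch_size=16
```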
REG/models/mae_vit.py ADDED
@@ -0,0 +1,71 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ # --------------------------------------------------------
+ # References:
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
+ # DeiT: https://github.com/facebookresearch/deit
+ # --------------------------------------------------------
+
+ from functools import partial
+
+ import torch
+ import torch.nn as nn
+
+ import timm.models.vision_transformer
+
+
+ class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
+     """ Vision Transformer with support for global average pooling
+     """
+     def __init__(self, global_pool=False, **kwargs):
+         super(VisionTransformer, self).__init__(**kwargs)
+
+         self.global_pool = global_pool
+         if self.global_pool:
+             norm_layer = kwargs['norm_layer']
+             embed_dim = kwargs['embed_dim']
+             self.fc_norm = norm_layer(embed_dim)
+
+             del self.norm  # remove the original norm
+
+     def forward_features(self, x):
+         B = x.shape[0]
+         x = self.patch_embed(x)
+
+         cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+         x = torch.cat((cls_tokens, x), dim=1)
+         x = x + self.pos_embed
+         x = self.pos_drop(x)
+
+         for blk in self.blocks:
+             x = blk(x)
+
+         x = x[:, 1:, :]  # .mean(dim=1)  # global pool without cls token
+
+         return x
+
+
+ def vit_base_patch16(**kwargs):
+     model = VisionTransformer(
+         num_classes=0,
+         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
+
+
+ def vit_large_patch16(**kwargs):
+     model = VisionTransformer(
+         num_classes=0,
+         patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
+
+
+ def vit_huge_patch14(**kwargs):
+     model = VisionTransformer(
+         patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
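`forward_features` above returns all patch tokens after slicing off the leading cls token; the commented-out `.mean(dim=1)` is the global-average-pool variant that would collapse them to one vector. A NumPy sketch of the two output shapes (the helper name is ours):

```python
import numpy as np

def drop_cls_and_pool(x, pool=False):
    """x: (B, 1 + N, D) token sequence with a leading cls token."""
    patches = x[:, 1:, :]                       # drop the cls token -> (B, N, D)
    return patches.mean(axis=1) if pool else patches

x = np.random.rand(2, 197, 768)                 # ViT-B/16 at 224px: 196 patches + cls
```

Keeping the per-patch tokens (the un-pooled branch the file actually takes) is what lets a downstream model consume a spatial feature map rather than a single embedding.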
REG/models/mocov3_vit.py ADDED
@@ -0,0 +1,207 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+ import torch
+ import torch.nn as nn
+ from functools import partial, reduce
+ from operator import mul
+
+ from timm.layers.helpers import to_2tuple
+ from timm.models.vision_transformer import VisionTransformer, _cfg
+ from timm.models.vision_transformer import PatchEmbed
+
+ __all__ = [
+     'vit_small',
+     'vit_base',
+     'vit_large',
+     'vit_conv_small',
+     'vit_conv_base',
+ ]
+
+
+ def patchify_avg(input_tensor, patch_size):
+     # Ensure input tensor is 4D: (batch_size, channels, height, width)
+     if input_tensor.dim() != 4:
+         raise ValueError("Input tensor must be 4D (batch_size, channels, height, width)")
+
+     # Get input tensor dimensions
+     batch_size, channels, height, width = input_tensor.shape
+
+     # Ensure patch_size is valid
+     patch_height, patch_width = patch_size, patch_size
+     if height % patch_height != 0 or width % patch_width != 0:
+         raise ValueError("Input tensor dimensions must be divisible by patch_size")
+
+     # Use unfold to create patches
+     patches = input_tensor.unfold(2, patch_height, patch_height).unfold(3, patch_width, patch_width)
+
+     # Reshape patches to desired format: (batch_size, num_patches, channels)
+     patches = patches.contiguous().view(
+         batch_size, channels, -1, patch_height, patch_width
+     ).mean(dim=-1).mean(dim=-1)
+     patches = patches.permute(0, 2, 1).contiguous()
+
+     return patches
+
+
+
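`patchify_avg` above averages each non-overlapping `patch_size` x `patch_size` patch down to one value per channel, yielding `(B, num_patches, C)`. A NumPy re-implementation of the same reduction, as a torch-free sanity check:

```python
import numpy as np

def patchify_avg_np(x, patch_size):
    """NumPy analogue of patchify_avg: mean over each square patch."""
    b, c, h, w = x.shape
    assert h % patch_size == 0 and w % patch_size == 0
    x = x.reshape(b, c, h // patch_size, patch_size, w // patch_size, patch_size)
    x = x.mean(axis=(3, 5))                           # average within each patch
    return x.reshape(b, c, -1).transpose(0, 2, 1)     # (B, num_patches, C)

x = np.arange(16, dtype=np.float64).reshape(1, 1, 4, 4)
patches = patchify_avg_np(x, 2)
```

Patches are enumerated row-major over the grid, matching the `unfold`/`view` ordering in the torch version.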
+ class VisionTransformerMoCo(VisionTransformer):
+     def __init__(self, stop_grad_conv1=False, **kwargs):
+         super().__init__(**kwargs)
+         # Use fixed 2D sin-cos position embedding
+         self.build_2d_sincos_position_embedding()
+
+         # weight initialization
+         for name, m in self.named_modules():
+             if isinstance(m, nn.Linear):
+                 if 'qkv' in name:
+                     # treat the weights of Q, K, V separately
+                     val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
+                     nn.init.uniform_(m.weight, -val, val)
+                 else:
+                     nn.init.xavier_uniform_(m.weight)
+                 nn.init.zeros_(m.bias)
+         nn.init.normal_(self.cls_token, std=1e-6)
+
+         if isinstance(self.patch_embed, PatchEmbed):
+             # xavier_uniform initialization
+             val = math.sqrt(6. / float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim))
+             nn.init.uniform_(self.patch_embed.proj.weight, -val, val)
+             nn.init.zeros_(self.patch_embed.proj.bias)
+
+             if stop_grad_conv1:
+                 self.patch_embed.proj.weight.requires_grad = False
+                 self.patch_embed.proj.bias.requires_grad = False
+
+     def build_2d_sincos_position_embedding(self, temperature=10000.):
+         h = self.patch_embed.img_size[0] // self.patch_embed.patch_size[0]
+         w = self.patch_embed.img_size[1] // self.patch_embed.patch_size[1]
+         grid_w = torch.arange(w, dtype=torch.float32)
+         grid_h = torch.arange(h, dtype=torch.float32)
+         grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
+         assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
+         pos_dim = self.embed_dim // 4
+         omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+         omega = 1. / (temperature**omega)
+         out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
+         out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
+         pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
+
+         # assert self.num_tokens == 1, 'Assuming one and only one token, [cls]'
+         pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
+         self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
+         self.pos_embed.requires_grad = False
+
+     def forward_diffusion_output(self, x):
+         x = x.reshape(*x.shape[0:2], -1).permute(0, 2, 1)
+         x = self._pos_embed(x)
+         x = self.patch_drop(x)
+         x = self.norm_pre(x)
+         x = self.blocks(x)
+         x = self.norm(x)
+         return x
+
+ class ConvStem(nn.Module):
+     """
+     ConvStem, from Early Convolutions Help Transformers See Better, Xiao et al. https://arxiv.org/abs/2106.14881
+     """
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
+         super().__init__()
+
+         assert patch_size == 16, 'ConvStem only supports patch size of 16'
+         assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem'
+
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+         self.num_patches = self.grid_size[0] * self.grid_size[1]
+         self.flatten = flatten
+
+         # build stem, similar to the design in https://arxiv.org/abs/2106.14881
+         stem = []
+         input_dim, output_dim = 3, embed_dim // 8
+         for l in range(4):
+             stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
+             stem.append(nn.BatchNorm2d(output_dim))
+             stem.append(nn.ReLU(inplace=True))
+             input_dim = output_dim
+             output_dim *= 2
+         stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
+         self.proj = nn.Sequential(*stem)
+
+         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         assert H == self.img_size[0] and W == self.img_size[1], \
+             f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+         x = self.proj(x)
+         if self.flatten:
+             x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
+         x = self.norm(x)
+         return x
+
+
+ def vit_small(**kwargs):
+     model = VisionTransformerMoCo(
+         img_size=256,
+         patch_size=16, embed_dim=384, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     model.default_cfg = _cfg()
+     return model
+
+ def vit_base(**kwargs):
+     model = VisionTransformerMoCo(
+         img_size=256,
+         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     model.default_cfg = _cfg()
+     return model
+
+ def vit_large(**kwargs):
+     model = VisionTransformerMoCo(
+         img_size=256,
+         patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     model.default_cfg = _cfg()
+     return model
+
+ def vit_conv_small(**kwargs):
+     # minus one ViT block
+     model = VisionTransformerMoCo(
+         patch_size=16, embed_dim=384, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs)
+     model.default_cfg = _cfg()
+     return model
+
+ def vit_conv_base(**kwargs):
+     # minus one ViT block
+     model = VisionTransformerMoCo(
+         patch_size=16, embed_dim=768, depth=11, num_heads=12, mlp_ratio=4, qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6), embed_layer=ConvStem, **kwargs)
+     model.default_cfg = _cfg()
+     return model
+
+ def build_mlp(num_layers, input_dim, mlp_dim, output_dim, last_bn=True):
+     mlp = []
+     for l in range(num_layers):
+         dim1 = input_dim if l == 0 else mlp_dim
+         dim2 = output_dim if l == num_layers - 1 else mlp_dim
+
+         mlp.append(nn.Linear(dim1, dim2, bias=False))
+
+         if l < num_layers - 1:
+             mlp.append(nn.BatchNorm1d(dim2))
+             mlp.append(nn.ReLU(inplace=True))
+         elif last_bn:
+             # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157
+             # for simplicity, we further removed gamma in BN
+             mlp.append(nn.BatchNorm1d(dim2, affine=False))
+
+     return nn.Sequential(*mlp)
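`build_mlp` above stacks `num_layers` linear layers `input_dim -> mlp_dim -> ... -> output_dim`, with BatchNorm+ReLU between hidden layers and, optionally, a final affine-free BatchNorm. A pure-Python sketch of just the `(in, out)` dimension sequence it produces:

```python
def mlp_layer_dims(num_layers, input_dim, mlp_dim, output_dim):
    """(in, out) dims of each Linear in build_mlp, in order."""
    dims = []
    for l in range(num_layers):
        dim1 = input_dim if l == 0 else mlp_dim
        dim2 = output_dim if l == num_layers - 1 else mlp_dim
        dims.append((dim1, dim2))
    return dims

dims = mlp_layer_dims(3, 768, 4096, 256)
```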
REG/models/sit.py ADDED
@@ -0,0 +1,420 @@
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ # --------------------------------------------------------
+ # References:
+ # GLIDE: https://github.com/openai/glide-text2im
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
+ # --------------------------------------------------------
+
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import math
+ from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
+
+
+ def build_mlp(hidden_size, projector_dim, z_dim):
+     return nn.Sequential(
+         nn.Linear(hidden_size, projector_dim),
+         nn.SiLU(),
+         nn.Linear(projector_dim, projector_dim),
+         nn.SiLU(),
+         nn.Linear(projector_dim, z_dim),
+     )
+
+ def modulate(x, shift, scale):
+     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
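`modulate(x, shift, scale)` above applies a per-sample affine transform to every token: `x * (1 + scale) + shift`, with `shift`/`scale` broadcast over the token axis. A NumPy sketch of the same broadcasting:

```python
import numpy as np

def modulate_np(x, shift, scale):
    """x: (N, T, D); shift, scale: (N, D), broadcast over tokens."""
    return x * (1 + scale[:, None, :]) + shift[:, None, :]

x = np.ones((2, 3, 4))
shift = np.full((2, 4), 0.5)
scale = np.full((2, 4), 1.0)
y = modulate_np(x, shift, scale)  # 1 * (1 + 1) + 0.5 everywhere
```

The `1 +` keeps the transform at identity when `scale` is zero, which pairs with the zero-initialized adaLN layers later in the file.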
+ #################################################################################
+ #               Embedding Layers for Timesteps and Class Labels                 #
+ #################################################################################
+ class TimestepEmbedder(nn.Module):
+     """
+     Embeds scalar timesteps into vector representations.
+     """
+     def __init__(self, hidden_size, frequency_embedding_size=256):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+         self.frequency_embedding_size = frequency_embedding_size
+
+     @staticmethod
+     def positional_embedding(t, dim, max_period=10000):
+         """
+         Create sinusoidal timestep embeddings.
+         :param t: a 1-D Tensor of N indices, one per batch element.
+                   These may be fractional.
+         :param dim: the dimension of the output.
+         :param max_period: controls the minimum frequency of the embeddings.
+         :return: an (N, D) Tensor of positional embeddings.
+         """
+         # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+         half = dim // 2
+         freqs = torch.exp(
+             -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+         ).to(device=t.device)
+         args = t[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+         return embedding
+
+     def forward(self, t):
+         self.timestep_embedding = self.positional_embedding
+         t_freq = self.timestep_embedding(t, dim=self.frequency_embedding_size).to(t.dtype)
+         t_emb = self.mlp(t_freq)
+         return t_emb
+
+
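`TimestepEmbedder.positional_embedding` above builds the GLIDE-style sinusoid: frequencies decay geometrically from 1 down to 1/`max_period`, and the embedding concatenates `[cos(t*f), sin(t*f)]`. A NumPy sketch of the same table:

```python
import numpy as np

def timestep_embedding_np(t, dim, max_period=10000):
    """NumPy analogue of the sinusoidal timestep embedding above."""
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half, dtype=np.float64) / half)
    args = np.asarray(t, dtype=np.float64)[:, None] * freqs[None]
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)  # (N, dim)

emb = timestep_embedding_np([0.0, 1.0], 8)
```

At `t = 0` the cosine half is all ones and the sine half all zeros, so distinct timesteps map to distinct, smoothly varying vectors.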
+ class LabelEmbedder(nn.Module):
+     """
+     Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+     """
+     def __init__(self, num_classes, hidden_size, dropout_prob):
+         super().__init__()
+         use_cfg_embedding = dropout_prob > 0
+         self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+         self.num_classes = num_classes
+         self.dropout_prob = dropout_prob
+
+     def token_drop(self, labels, force_drop_ids=None):
+         """
+         Drops labels to enable classifier-free guidance.
+         """
+         if force_drop_ids is None:
+             drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+         else:
+             drop_ids = force_drop_ids == 1
+         labels = torch.where(drop_ids, self.num_classes, labels)
+         return labels
+
+     def forward(self, labels, train, force_drop_ids=None):
+         use_dropout = self.dropout_prob > 0
+         if (train and use_dropout) or (force_drop_ids is not None):
+             labels = self.token_drop(labels, force_drop_ids)
+         embeddings = self.embedding_table(labels)
+         return embeddings
+
+
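`token_drop` above replaces a random subset of labels with the extra index `num_classes`, which acts as the "null" class the embedding table reserves for classifier-free guidance. A NumPy sketch of the same replacement:

```python
import numpy as np

def token_drop_np(labels, num_classes, dropout_prob, rng):
    """Replace a random subset of labels with the null-class index."""
    drop_ids = rng.random(labels.shape[0]) < dropout_prob
    return np.where(drop_ids, num_classes, labels)

rng = np.random.default_rng(0)
labels = np.array([3, 7, 7, 1])
always = token_drop_np(labels, num_classes=1000, dropout_prob=1.0, rng=rng)
never = token_drop_np(labels, num_classes=1000, dropout_prob=0.0, rng=rng)
```

At sampling time the same null index supplies the unconditional branch of the guidance formula.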
+ #################################################################################
+ #                                 Core SiT Model                                #
+ #################################################################################
+
+ class SiTBlock(nn.Module):
+     """
+     A SiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
+     """
+     def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
+         super().__init__()
+         self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.attn = Attention(
+             hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=block_kwargs["qk_norm"]
+         )
+         if "fused_attn" in block_kwargs.keys():
+             self.attn.fused_attn = block_kwargs["fused_attn"]
+         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         mlp_hidden_dim = int(hidden_size * mlp_ratio)
+         approx_gelu = lambda: nn.GELU(approximate="tanh")
+         self.mlp = Mlp(
+             in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0
+         )
+         self.adaLN_modulation = nn.Sequential(
+             nn.SiLU(),
+             nn.Linear(hidden_size, 6 * hidden_size, bias=True)
+         )
+
+     def forward(self, x, c):
+         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+             self.adaLN_modulation(c).chunk(6, dim=-1)
+         )
+         x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
+         x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
+
+         return x
+
+
+ class FinalLayer(nn.Module):
+     """
+     The final layer of SiT.
+     """
+     def __init__(self, hidden_size, patch_size, out_channels, cls_token_dim):
+         super().__init__()
+         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+         self.linear_cls = nn.Linear(hidden_size, cls_token_dim, bias=True)
+         self.adaLN_modulation = nn.Sequential(
+             nn.SiLU(),
+             nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+         )
+
+     def forward(self, x, c, cls=None):
+         shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+         x = modulate(self.norm_final(x), shift, scale)
+
+         if cls is None:
+             x = self.linear(x)
+             return x, None
+         else:
+             cls_token = self.linear_cls(x[:, 0]).unsqueeze(1)
+             x = self.linear(x[:, 1:])
+             return x, cls_token.squeeze(1)
+
+
+ class SiT(nn.Module):
+     """
+     Diffusion model with a Transformer backbone.
+     """
+     def __init__(
+         self,
+         path_type='edm',
+         input_size=32,
+         patch_size=2,
+         in_channels=4,
+         hidden_size=1152,
+         decoder_hidden_size=768,
+         encoder_depth=8,
+         depth=28,
+         num_heads=16,
+         mlp_ratio=4.0,
+         class_dropout_prob=0.1,
+         num_classes=1000,
+         use_cfg=False,
+         z_dims=[768],
+         projector_dim=2048,
+         cls_token_dim=768,
+         **block_kwargs  # fused_attn
+     ):
+         super().__init__()
+         self.path_type = path_type
+         self.in_channels = in_channels
+         self.out_channels = in_channels
+         self.patch_size = patch_size
+         self.num_heads = num_heads
+         self.use_cfg = use_cfg
+         self.num_classes = num_classes
+         self.z_dims = z_dims
+         self.encoder_depth = encoder_depth
+
+         self.x_embedder = PatchEmbed(
+             input_size, patch_size, in_channels, hidden_size, bias=True
+         )
+         self.t_embedder = TimestepEmbedder(hidden_size)  # timestep embedding type
+         self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
+         num_patches = self.x_embedder.num_patches
+         # Will use fixed sin-cos embedding:
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, hidden_size), requires_grad=False)
+
+         self.blocks = nn.ModuleList([
+             SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, **block_kwargs) for _ in range(depth)
+         ])
+         self.projectors = nn.ModuleList([
+             build_mlp(hidden_size, projector_dim, z_dim) for z_dim in z_dims
+         ])
+
+         z_dim = self.z_dims[0]
+         cls_token_dim = z_dim
+         self.final_layer = FinalLayer(decoder_hidden_size, patch_size, self.out_channels, cls_token_dim)
+
+         self.cls_projectors2 = nn.Linear(in_features=cls_token_dim, out_features=hidden_size, bias=True)
+         self.wg_norm = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
+
+         self.initialize_weights()
+
+     def initialize_weights(self):
+         # Initialize transformer layers:
+         def _basic_init(module):
+             if isinstance(module, nn.Linear):
+                 torch.nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+         self.apply(_basic_init)
+
+         # Initialize (and freeze) pos_embed by sin-cos embedding:
+         pos_embed = get_2d_sincos_pos_embed(
+             self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5), cls_token=1, extra_tokens=1
+         )
+         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+         # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
+         w = self.x_embedder.proj.weight.data
+         nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+         nn.init.constant_(self.x_embedder.proj.bias, 0)
+
+         # Initialize label embedding table:
+         nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
+
+         # Initialize timestep embedding MLP:
+         nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
+         nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
+
+         # Zero-out adaLN modulation layers in SiT blocks:
+         for block in self.blocks:
+             nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+             nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+         # Zero-out output layers:
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+         nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+         nn.init.constant_(self.final_layer.linear.weight, 0)
+         nn.init.constant_(self.final_layer.linear.bias, 0)
+         nn.init.constant_(self.final_layer.linear_cls.weight, 0)
+         nn.init.constant_(self.final_layer.linear_cls.bias, 0)
+
+     def unpatchify(self, x, patch_size=None):
+         """
+         x: (N, T, patch_size**2 * C)
+         imgs: (N, C, H, W)
+         """
+         c = self.out_channels
+         p = self.x_embedder.patch_size[0] if patch_size is None else patch_size
+         h = w = int(x.shape[1] ** 0.5)
+         assert h * w == x.shape[1]
+
+         x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+         x = torch.einsum('nhwpqc->nchpwq', x)
+         imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
+         return imgs
+
+     def forward(self, x, t, y, return_logvar=False, cls_token=None):
+         """
+         Forward pass of SiT.
+         x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+         t: (N,) tensor of diffusion timesteps
+         y: (N,) tensor of class labels
+         """
+
+         # concatenate the cls_token with the patch tokens
+         x = self.x_embedder(x)  # (N, T, D), where T = H * W / patch_size ** 2
+         if cls_token is not None:
+             cls_token = self.cls_projectors2(cls_token)
+             cls_token = self.wg_norm(cls_token)
+             cls_token = cls_token.unsqueeze(1)  # [b, length, d]
+             x = torch.cat((cls_token, x), dim=1)
+             x = x + self.pos_embed
+         else:
+             exit()
+         N, T, D = x.shape
+
+         # timestep and class embedding
+         t_embed = self.t_embedder(t)  # (N, D)
+         y = self.y_embedder(y, self.training)  # (N, D)
+         c = t_embed + y
+
+         for i, block in enumerate(self.blocks):
+             x = block(x, c)
+             if (i + 1) == self.encoder_depth:
+                 zs = [projector(x.reshape(-1, D)).reshape(N, T, -1) for projector in self.projectors]
+
+         x, cls_token = self.final_layer(x, c, cls=cls_token)
+         x = self.unpatchify(x)
+
+         return x, zs, cls_token
+
+
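`unpatchify` above inverts the patch embedding on the output side, turning `(N, T, p*p*C)` tokens back into `(N, C, H, W)` images via a reshape and an einsum transpose. A NumPy sketch with the matching patchify direction, checked as a round trip (both helper names are ours):

```python
import numpy as np

def patchify_np(imgs, p):
    """(N, C, H, W) -> (N, h*w, p*p*C) non-overlapping patch tokens."""
    n, c, H, W = imgs.shape
    h, w = H // p, W // p
    x = imgs.reshape(n, c, h, p, w, p)
    x = np.einsum('nchpwq->nhwpqc', x)
    return x.reshape(n, h * w, p * p * c)

def unpatchify_np(x, p, c):
    """Inverse of patchify_np, mirroring SiT.unpatchify."""
    n, t, _ = x.shape
    h = w = int(t ** 0.5)
    assert h * w == t
    x = x.reshape(n, h, w, p, p, c)
    x = np.einsum('nhwpqc->nchpwq', x)
    return x.reshape(n, c, h * p, w * p)

imgs = np.random.rand(2, 4, 8, 8)
tokens = patchify_np(imgs, 2)
recon = unpatchify_np(tokens, p=2, c=4)
```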
+ #################################################################################
+ #                   Sine/Cosine Positional Embedding Functions                  #
+ #################################################################################
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
+     """
+     grid_size: int of the grid height and width
+     return:
+     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+     """
+     grid_h = np.arange(grid_size, dtype=np.float32)
+     grid_w = np.arange(grid_size, dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+     grid = np.stack(grid, axis=0)
+
+     grid = grid.reshape([2, 1, grid_size, grid_size])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token and extra_tokens > 0:
+         pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+     assert embed_dim % 2 == 0
+
+     # use half of dimensions to encode grid_h
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+     return emb
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,)
+     out: (M, D)
+     """
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=np.float64)
+     omega /= embed_dim / 2.
+     omega = 1. / 10000**omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (M,)
+     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+     emb_sin = np.sin(out)  # (M, D/2)
+     emb_cos = np.cos(out)  # (M, D/2)
+
+     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+     return emb
+
+
+ #################################################################################
374
+ # SiT Configs #
375
+ #################################################################################
376
+
377
+ def SiT_XL_2(**kwargs):
378
+ return SiT(depth=28, hidden_size=1152, decoder_hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
379
+
380
+ def SiT_XL_4(**kwargs):
381
+ return SiT(depth=28, hidden_size=1152, decoder_hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
382
+
383
+ def SiT_XL_8(**kwargs):
384
+ return SiT(depth=28, hidden_size=1152, decoder_hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
385
+
386
+ def SiT_L_2(**kwargs):
387
+ return SiT(depth=24, hidden_size=1024, decoder_hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
388
+
389
+ def SiT_L_4(**kwargs):
390
+ return SiT(depth=24, hidden_size=1024, decoder_hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
391
+
392
+ def SiT_L_8(**kwargs):
393
+ return SiT(depth=24, hidden_size=1024, decoder_hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
394
+
395
+ def SiT_B_2(**kwargs):
396
+ return SiT(depth=12, hidden_size=768, decoder_hidden_size=768, patch_size=2, num_heads=12, **kwargs)
397
+
398
+ def SiT_B_4(**kwargs):
399
+ return SiT(depth=12, hidden_size=768, decoder_hidden_size=768, patch_size=4, num_heads=12, **kwargs)
400
+
401
+ def SiT_B_8(**kwargs):
402
+ return SiT(depth=12, hidden_size=768, decoder_hidden_size=768, patch_size=8, num_heads=12, **kwargs)
403
+
404
+ def SiT_S_2(**kwargs):
405
+ return SiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
406
+
407
+ def SiT_S_4(**kwargs):
408
+ return SiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
409
+
410
+ def SiT_S_8(**kwargs):
411
+ return SiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
412
+
413
+
414
+ SiT_models = {
415
+ 'SiT-XL/2': SiT_XL_2, 'SiT-XL/4': SiT_XL_4, 'SiT-XL/8': SiT_XL_8,
416
+ 'SiT-L/2': SiT_L_2, 'SiT-L/4': SiT_L_4, 'SiT-L/8': SiT_L_8,
417
+ 'SiT-B/2': SiT_B_2, 'SiT-B/4': SiT_B_4, 'SiT-B/8': SiT_B_8,
418
+ 'SiT-S/2': SiT_S_2, 'SiT-S/4': SiT_S_4, 'SiT-S/8': SiT_S_8,
419
+ }
420
+
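The MAE-style sine/cosine embedding above can be sanity-checked with a short self-contained NumPy sketch. The helper below compactly re-implements `get_1d_sincos_pos_embed_from_grid`; the 4×4 grid and 16-dim embedding are illustrative choices, not values from the repo:

```python
import numpy as np

def sincos_1d(embed_dim, pos):
    # Compact re-implementation of get_1d_sincos_pos_embed_from_grid:
    # half the channels carry sin, half carry cos, frequencies 1/10000^(2i/d).
    omega = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000 ** omega
    out = np.einsum("m,d->md", pos.reshape(-1), omega)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)

# 4x4 grid, 16-dim embedding: 8 dims encode the h axis, 8 the w axis.
grid = np.stack(np.meshgrid(np.arange(4, dtype=np.float32),
                            np.arange(4, dtype=np.float32)), axis=0)
grid = grid.reshape([2, 1, 4, 4])
emb = np.concatenate([sincos_1d(8, grid[0]), sincos_1d(8, grid[1])], axis=1)
print(emb.shape)  # → (16, 16): one embedding per grid cell
```

Position 0 maps to sin(0)=0 and cos(0)=1 in every frequency band, which is an easy spot check on the first row.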
back/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Sihyun Yu
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
back/README.md ADDED
@@ -0,0 +1,156 @@
+ <p align="center">
+ <h1 align="center">Representation Entanglement for Generation: Training Diffusion Transformers Is Much Easier Than You Think (NeurIPS 2025 Oral)
+ </h1>
+ <p align="center">
+ <a href='https://github.com/Martinser' style='text-decoration: none' >Ge Wu</a><sup>1</sup>&emsp;
+ <a href='https://github.com/ShenZhang-Shin' style='text-decoration: none' >Shen Zhang</a><sup>3</sup>&emsp;
+ <a href='' style='text-decoration: none' >Ruijing Shi</a><sup>1</sup>&emsp;
+ <a href='https://shgao.site/' style='text-decoration: none' >Shanghua Gao</a><sup>4</sup>&emsp;
+ <a href='https://zhenyuanchenai.github.io/' style='text-decoration: none' >Zhenyuan Chen</a><sup>1</sup>&emsp;
+ <a href='https://scholar.google.com/citations?user=6Z66DAwAAAAJ&hl=en' style='text-decoration: none' >Lei Wang</a><sup>1</sup>&emsp;
+ <a href='https://www.zhihu.com/people/chen-zhao-wei-16-2' style='text-decoration: none' >Zhaowei Chen</a><sup>3</sup>&emsp;
+ <a href='https://gao-hongcheng.github.io/' style='text-decoration: none' >Hongcheng Gao</a><sup>5</sup>&emsp;
+ <a href='https://scholar.google.com/citations?view_op=list_works&hl=zh-CN&hl=zh-CN&user=0xP6bxcAAAAJ' style='text-decoration: none' >Yao Tang</a><sup>3</sup>&emsp;
+ <a href='https://scholar.google.com/citations?user=6CIDtZQAAAAJ&hl=en' style='text-decoration: none' >Jian Yang</a><sup>1</sup>&emsp;
+ <a href='https://mmcheng.net/cmm/' style='text-decoration: none' >Ming-Ming Cheng</a><sup>1,2</sup>&emsp;
+ <a href='https://implus.github.io/' style='text-decoration: none' >Xiang Li</a><sup>1,2*</sup>&emsp;
+ <p align="center">
+ $^{1}$ VCIP, CS, Nankai University, $^{2}$ NKIARI, Shenzhen Futian, $^{3}$ JIIOV Technology,
+ $^{4}$ Harvard University, $^{5}$ University of Chinese Academy of Sciences
+ <p align='center'>
+ <div align="center">
+ <a href='https://arxiv.org/abs/2507.01467v2'><img src='https://img.shields.io/badge/arXiv-2507.01467v2-brown.svg?logo=arxiv&logoColor=white'></a>
+ <a href='https://huggingface.co/Martinser/REG/tree/main'><img src='https://img.shields.io/badge/🤗-Model-blue.svg'></a>
+ <a href='https://zhuanlan.zhihu.com/p/1952346823168595518'><img src='https://img.shields.io/badge/Zhihu-chinese_article-blue.svg?logo=zhihu&logoColor=white'></a>
+ </div>
+ <p align='center'>
+ </p>
+ </p>
+ </p>
+
+
+ ## 🚩 Overview
+
+ ![overview](fig/reg.png)
+
+ REPA and its variants effectively mitigate the training difficulties of diffusion models by incorporating external visual representations from pretrained models: the noisy hidden projections of the denoising network are aligned with clean image representations from a visual foundation model.
+ We argue that this external alignment, which is absent throughout the denoising inference process, falls short of fully harnessing the potential of discriminative representations.
+
+ In this work, we propose a straightforward method called Representation Entanglement for Generation (REG), which entangles low-level image latents with a single high-level class token from pretrained foundation models for denoising.
+ REG learns to produce coherent image-class pairs directly from pure noise,
+ substantially improving both generation quality and training efficiency.
+ This comes at negligible additional inference cost, **requiring only one additional token for denoising (<0.5\% increase in FLOPs and latency).**
+ The inference process concurrently reconstructs both the image latents and their corresponding global semantics, and the acquired semantic knowledge actively guides and enhances the image generation process.
+
+ On ImageNet $256{\times}256$, SiT-XL/2 + REG demonstrates remarkable convergence acceleration, **achieving $\textbf{63}\times$ and $\textbf{23}\times$ faster training than SiT-XL/2 and SiT-XL/2 + REPA, respectively.**
+ More impressively, SiT-L/2 + REG trained for merely 400K iterations outperforms SiT-XL/2 + REPA trained for 4M iterations ($\textbf{10}\times$ longer).
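The "one additional token" claim can be checked with simple arithmetic, assuming the standard 8× VAE downsampling and patch size 2 used by SiT-XL/2 (a back-of-the-envelope sketch, not code from the repo):

```python
# Back-of-the-envelope token count for SiT-XL/2 at 256x256
# (assumes the standard 8x VAE downsampling; illustrative only).
resolution, vae_down, patch_size = 256, 8, 2
latent_size = resolution // vae_down            # 32x32 latent
num_patches = (latent_size // patch_size) ** 2  # 256 patch tokens
with_cls = num_patches + 1                      # REG adds one class token
print(num_patches, with_cls, f"+{(with_cls / num_patches - 1) * 100:.2f}%")
# → 256 257 +0.39%
```

The linear-cost parts of the network thus grow by well under half a percent, consistent with the stated <0.5% overhead.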
+
+
+ ## 📰 News
+
+ - **[2025.08.05]** We have released the pre-trained weights of REG + SiT-XL/2 at 4M iterations (800 epochs).
+
+
+ ## 📝 Results
+
+ - FID=1.36 on ImageNet $256{\times}256$ by introducing a single class token.
+ - $\textbf{63}\times$ and $\textbf{23}\times$ faster training than SiT-XL/2 and SiT-XL/2 + REPA, respectively.
+
+ <div align="center">
+ <img src="fig/img.png" alt="Results">
+ </div>
+
+
+ ## 📋 Plan
+ - More training steps on ImageNet 256&512 and T2I.
+
+
+ ## 👊 Usage
+
+ ### 1. Environment setup
+
+ ```bash
+ conda create -n reg python=3.10.16 -y
+ conda activate reg
+ pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1
+ pip install -r requirements.txt
+ ```
+
+ ### 2. Dataset
+
+ #### Dataset download
+
+ Currently, we provide experiments for ImageNet. Place the data wherever you like and specify its location via the `--data-dir` argument of the training scripts.
+
+ #### Preprocessing data
+ Please refer to the preprocessing guide. Alternatively, you can directly download our processed data: the ImageNet data ([link](https://huggingface.co/WindATree/ImageNet-256-VAE/tree/main)) and the ImageNet data after the VAE encoder ([link](https://huggingface.co/WindATree/vae-sd/tree/main)).
+
+ ### 3. Training
+ Run train.sh:
+ ```bash
+ bash train.sh
+ ```
+
+ train.sh contains the following content.
+ ```bash
+ # SiT-L/XL use --encoder-depth=8, SiT-B uses 4
+ accelerate launch --multi_gpu --num_processes $NUM_GPUS train.py \
+   --report-to="wandb" \
+   --allow-tf32 \
+   --mixed-precision="fp16" \
+   --seed=0 \
+   --path-type="linear" \
+   --prediction="v" \
+   --weighting="uniform" \
+   --model="SiT-B/2" \
+   --enc-type="dinov2-vit-b" \
+   --proj-coeff=0.5 \
+   --encoder-depth=4 \
+   --output-dir="your_path" \
+   --exp-name="linear-dinov2-b-enc4" \
+   --batch-size=256 \
+   --data-dir="data_path/imagenet_vae" \
+   --cls=0.03
+ ```
+
+ This script automatically creates a folder under `exps` to save logs and checkpoints. You can adjust the following options:
+
+ - `--model`: `[SiT-B/2, SiT-L/2, SiT-XL/2]`
+ - `--enc-type`: `[dinov2-vit-b, clip-vit-L]`
+ - `--proj-coeff`: Any value larger than 0
+ - `--encoder-depth`: Any integer between 1 and the depth of the model
+ - `--output-dir`: Any directory where you want to save checkpoints and logs
+ - `--exp-name`: Any string name (the folder will be created under `output-dir`)
+ - `--cls`: Weight coefficient of the REG loss
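As a rough illustration of how `--proj-coeff` and `--cls` enter the objective, here is a hypothetical sketch of the weighted sum; the actual reduction lives in train.py (not part of this diff), and the function name and signature below are invented for illustration:

```python
# Hypothetical sketch: how the three loss terms might combine.
# The real combination is in train.py; this function is illustrative only.
def total_loss(denoising_loss, proj_loss, denoising_loss_cls,
               proj_coeff=0.5, cls=0.03):
    # --proj-coeff weights the REPA-style projection-alignment term,
    # --cls weights the REG class-token denoising term.
    return denoising_loss + proj_coeff * proj_loss + cls * denoising_loss_cls

print(round(total_loss(1.0, 0.2, 0.5), 3))  # → 1.115
```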
+
+
+ ### 4. Generate images and evaluation
+ You can generate images and obtain the final results with the following script.
+ The REG weights can be found at this [link](https://pan.baidu.com/s/1QX2p3ybh1KfNU7wsp5McWw?pwd=khpp) or on [HF](https://huggingface.co/Martinser/REG/tree/main).
+
+ ```bash
+ bash eval.sh
+ ```
+
+
+ ## Citation
+ If you find our work, this repository, or the pretrained models useful, please consider giving a star and a citation.
+ ```
+ @article{wu2025representation,
+   title={Representation Entanglement for Generation: Training Diffusion Transformers Is Much Easier Than You Think},
+   author={Wu, Ge and Zhang, Shen and Shi, Ruijing and Gao, Shanghua and Chen, Zhenyuan and Wang, Lei and Chen, Zhaowei and Gao, Hongcheng and Tang, Yao and Yang, Jian and others},
+   journal={arXiv preprint arXiv:2507.01467},
+   year={2025}
+ }
+ ```
+
+ ## Contact
+ If you have any questions, please open an issue on this repository, or contact us at gewu.nku@gmail.com or via WeChat (wg1158848).
+
+
+ ## Acknowledgements
+
+ Our code is based on [REPA](https://github.com/sihyun-yu/REPA), along with the [SiT](https://github.com/willisma/SiT), [DINOv2](https://github.com/facebookresearch/dinov2), [ADM](https://github.com/openai/guided-diffusion) and [U-ViT](https://github.com/baofff/U-ViT) repositories. We thank the authors for releasing their code. If you use our model and code, please consider citing these works as well.
+
back/eval.sh ADDED
@@ -0,0 +1,52 @@
+
+ random_number=$((RANDOM % 100 + 1200))
+ NUM_GPUS=8
+ STEP="4000000"
+ SAVE_PATH="your_path/reg_xlarge_dinov2_base_align_8_cls/linear-dinov2-b-enc8"
+ VAE_PATH="your_vae_path/"
+ NUM_STEP=250
+ MODEL_SIZE='XL'
+ CFG_SCALE=2.3
+ CLS_CFG_SCALE=2.3
+ GH=0.85
+
+ export NCCL_P2P_DISABLE=1
+
+ python -m torch.distributed.launch --master_port=$random_number --nproc_per_node=$NUM_GPUS generate.py \
+   --model SiT-XL/2 \
+   --num-fid-samples 50000 \
+   --ckpt ${SAVE_PATH}/checkpoints/${STEP}.pt \
+   --path-type=linear \
+   --encoder-depth=8 \
+   --projector-embed-dims=768 \
+   --per-proc-batch-size=64 \
+   --mode=sde \
+   --num-steps=${NUM_STEP} \
+   --cfg-scale=${CFG_SCALE} \
+   --cls-cfg-scale=${CLS_CFG_SCALE} \
+   --guidance-high=${GH} \
+   --sample-dir ${SAVE_PATH}/checkpoints \
+   --cls=768
+
+
+ python ./evaluations/evaluator.py \
+   --ref_batch your_path/VIRTUAL_imagenet256_labeled.npz \
+   --sample_batch ${SAVE_PATH}/checkpoints/SiT-${MODEL_SIZE}-2-${STEP}-size-256-vae-ema-cfg-${CFG_SCALE}-seed-0-sde-${GH}-${CLS_CFG_SCALE}.npz \
+   --save_path ${SAVE_PATH}/checkpoints \
+   --cfg_cond 1 \
+   --step ${STEP} \
+   --num_steps ${NUM_STEP} \
+   --cfg ${CFG_SCALE} \
+   --cls_cfg ${CLS_CFG_SCALE} \
+   --gh ${GH}
+
back/loss.py ADDED
@@ -0,0 +1,168 @@
+ import torch
+ import numpy as np
+ import torch.nn.functional as F
+
+ try:
+     from scipy.optimize import linear_sum_assignment
+ except ImportError:
+     linear_sum_assignment = None
+
+
+ def ot_pair_noise_to_cls(noise_cls, cls_gt):
+     """
+     Minibatch OT (matching sample_plan_with_scipy in conditional-flow-matching / torchcfm):
+     re-order the noise within the batch under a squared-Euclidean cost so that
+     noise_ot[i] and cls_gt[i] form an approximately optimal-transport pairing.
+     noise_cls, cls_gt: (N, D), or any shape whose trailing dims flatten to D.
+     """
+     n = noise_cls.shape[0]
+     if n <= 1:
+         return noise_cls, cls_gt
+     if linear_sum_assignment is None:
+         return noise_cls, cls_gt
+     x0 = noise_cls.detach().float().reshape(n, -1)
+     x1 = cls_gt.detach().float().reshape(n, -1)
+     M = torch.cdist(x0, x1) ** 2
+     _, j = linear_sum_assignment(M.cpu().numpy())
+     j = torch.as_tensor(j, device=noise_cls.device, dtype=torch.long)
+     return noise_cls[j], cls_gt
+
+
+ def mean_flat(x):
+     """
+     Take the mean over all non-batch dimensions.
+     """
+     return torch.mean(x, dim=list(range(1, len(x.size()))))
+
+
+ def sum_flat(x):
+     """
+     Take the sum over all non-batch dimensions.
+     """
+     return torch.sum(x, dim=list(range(1, len(x.size()))))
+
+
+ class SILoss:
+     def __init__(
+         self,
+         prediction='v',
+         path_type="linear",
+         weighting="uniform",
+         encoders=[],
+         accelerator=None,
+         latents_scale=None,
+         latents_bias=None,
+         t_c=0.5,
+         ot_cls=True,
+     ):
+         self.prediction = prediction
+         self.weighting = weighting
+         self.path_type = path_type
+         self.encoders = encoders
+         self.accelerator = accelerator
+         self.latents_scale = latents_scale
+         self.latents_bias = latents_bias
+         # t follows train.py / JsFlow: t=0 is the clean latent, t=1 is pure noise.
+         # t ∈ (t_c, 1]: the semantic cls evolves from noise to cls_gt along the
+         #   OT-paired path (generative semantic channel);
+         # t ∈ [0, t_c]: cls is held at the ground-truth cls_gt and the target
+         #   velocity is 0 (this channel is no longer interpolated).
+         tc = float(t_c)
+         self.t_c = min(max(tc, 1e-4), 1.0 - 1e-4)
+         self.ot_cls = bool(ot_cls)
+
+     def interpolant(self, t):
+         if self.path_type == "linear":
+             alpha_t = 1 - t
+             sigma_t = t
+             d_alpha_t = -1
+             d_sigma_t = 1
+         elif self.path_type == "cosine":
+             alpha_t = torch.cos(t * np.pi / 2)
+             sigma_t = torch.sin(t * np.pi / 2)
+             d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2)
+             d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2)
+         else:
+             raise NotImplementedError()
+
+         return alpha_t, sigma_t, d_alpha_t, d_sigma_t
+
+     def __call__(self, model, images, model_kwargs=None, zs=None, cls_token=None,
+                  time_input=None, noises=None):
+         if model_kwargs is None:
+             model_kwargs = {}
+         # sample timesteps
+         if time_input is None:
+             if self.weighting == "uniform":
+                 time_input = torch.rand((images.shape[0], 1, 1, 1))
+             elif self.weighting == "lognormal":
+                 # sample timestep according to log-normal distribution of sigmas following EDM
+                 rnd_normal = torch.randn((images.shape[0], 1, 1, 1))
+                 sigma = rnd_normal.exp()
+                 if self.path_type == "linear":
+                     time_input = sigma / (1 + sigma)
+                 elif self.path_type == "cosine":
+                     time_input = 2 / np.pi * torch.atan(sigma)
+
+         time_input = time_input.to(device=images.device, dtype=torch.float32)
+         cls_token = cls_token.to(device=images.device, dtype=torch.float32)
+
+         if noises is None:
+             noises = torch.randn_like(images)
+             noises_cls = torch.randn_like(cls_token)
+         else:
+             if isinstance(noises, (tuple, list)) and len(noises) == 2:
+                 noises, noises_cls = noises
+             else:
+                 noises_cls = torch.randn_like(cls_token)
+
+         alpha_t, sigma_t, d_alpha_t, d_sigma_t = self.interpolant(time_input)
+
+         model_input = alpha_t * images + sigma_t * noises
+         if self.prediction == 'v':
+             model_target = d_alpha_t * images + d_sigma_t * noises
+         else:
+             raise NotImplementedError()
+
+         N = images.shape[0]
+         t_flat = time_input.view(-1).float()
+         high_noise_mask = (t_flat > self.t_c).float().view(N, *([1] * (cls_token.dim() - 1)))
+         low_noise_mask = 1.0 - high_noise_mask
+
+         noise_cls_raw = noises_cls
+         if self.ot_cls:
+             noise_cls_paired, cls_gt_paired = ot_pair_noise_to_cls(noise_cls_raw, cls_token)
+         else:
+             noise_cls_paired, cls_gt_paired = noise_cls_raw, cls_token
+
+         tau_shape = (N,) + (1,) * max(0, cls_token.dim() - 1)
+         tau = (time_input.reshape(tau_shape) - self.t_c) / (1.0 - self.t_c + 1e-8)
+         tau = torch.clamp(tau, 0.0, 1.0)
+         alpha_sem = 1.0 - tau
+         sigma_sem = tau
+
+         cls_t_high = alpha_sem * cls_gt_paired + sigma_sem * noise_cls_paired
+         cls_t = high_noise_mask * cls_t_high + low_noise_mask * cls_token
+         cls_t = torch.nan_to_num(cls_t, nan=0.0, posinf=1e4, neginf=-1e4)
+         cls_t = torch.clamp(cls_t, -1e4, 1e4)
+
+         cls_for_model = cls_t * high_noise_mask + cls_t.detach() * low_noise_mask
+
+         inv_scale = 1.0 / (1.0 - self.t_c + 1e-8)
+         v_cls_high = (noise_cls_paired - cls_gt_paired) * inv_scale
+         v_cls_target = high_noise_mask * v_cls_high
+
+         model_output, zs_tilde, cls_output = model(
+             model_input, time_input.flatten(), **model_kwargs, cls_token=cls_for_model
+         )
+
+         # denoising loss
+         denoising_loss = mean_flat((model_output - model_target) ** 2)
+         denoising_loss_cls = mean_flat((cls_output - v_cls_target) ** 2)
+
+         # projection loss
+         proj_loss = 0.
+         bsz = zs[0].shape[0]
+         for i, (z, z_tilde) in enumerate(zip(zs, zs_tilde)):
+             for j, (z_j, z_tilde_j) in enumerate(zip(z, z_tilde)):
+                 z_tilde_j = torch.nn.functional.normalize(z_tilde_j, dim=-1)
+                 z_j = torch.nn.functional.normalize(z_j, dim=-1)
+                 proj_loss += mean_flat(-(z_j * z_tilde_j).sum(dim=-1))
+         proj_loss /= (len(zs) * bsz)
+
+         return denoising_loss, proj_loss, time_input, noises, denoising_loss_cls
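The semantic-channel schedule in `SILoss` can be sanity-checked numerically: for t > t_c, the target velocity `(noise - cls_gt) / (1 - t_c)` should equal the time derivative of the interpolated token. A self-contained NumPy sketch with scalar toy values (not repo code):

```python
import numpy as np

t_c = 0.5                 # matches SILoss's default t_c
cls_gt = np.array([2.0])  # toy ground-truth class token
noise = np.array([-1.0])  # toy OT-paired noise

def cls_t(t):
    # same interpolation as SILoss: tau rescales (t_c, 1] onto (0, 1]
    tau = np.clip((t - t_c) / (1.0 - t_c), 0.0, 1.0)
    return (1.0 - tau) * cls_gt + tau * noise

t, eps = 0.8, 1e-5
numeric = (cls_t(t + eps) - cls_t(t - eps)) / (2 * eps)  # finite difference
analytic = (noise - cls_gt) / (1.0 - t_c)                # v_cls_high in SILoss
print(float(numeric[0]), float(analytic[0]))  # both ≈ -6.0
```

Since the interpolation is linear in t on (t_c, 1], the finite difference matches the analytic velocity up to rounding, which is exactly the consistency the v-prediction target relies on.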
back/requirements.txt ADDED
@@ -0,0 +1,97 @@
+ absl-py==2.2.2
+ accelerate==1.2.1
+ aiohappyeyeballs==2.6.1
+ aiohttp==3.11.16
+ aiosignal==1.3.2
+ astunparse==1.6.3
+ async-timeout==5.0.1
+ attrs==25.3.0
+ certifi==2022.12.7
+ charset-normalizer==2.1.1
+ click==8.1.8
+ datasets==2.20.0
+ diffusers==0.32.1
+ dill==0.3.8
+ docker-pycreds==0.4.0
+ einops==0.8.1
+ filelock==3.13.1
+ flatbuffers==25.2.10
+ frozenlist==1.5.0
+ fsspec==2024.5.0
+ ftfy==6.3.1
+ gast==0.6.0
+ gitdb==4.0.12
+ gitpython==3.1.44
+ google-pasta==0.2.0
+ grpcio==1.71.0
+ h5py==3.13.0
+ huggingface-hub==0.27.1
+ idna==3.4
+ importlib-metadata==8.6.1
+ jinja2==3.1.4
+ joblib==1.4.2
+ keras==3.9.2
+ libclang==18.1.1
+ markdown==3.8
+ markdown-it-py==3.0.0
+ markupsafe==2.1.5
+ mdurl==0.1.2
+ ml-dtypes==0.3.2
+ mpmath==1.3.0
+ multidict==6.4.3
+ multiprocess==0.70.16
+ namex==0.0.8
+ networkx==3.3
+ numpy==1.26.4
+ opt-einsum==3.4.0
+ optree==0.15.0
+ packaging==24.2
+ pandas==2.2.3
+ pillow==11.0.0
+ platformdirs==4.3.7
+ propcache==0.3.1
+ protobuf==4.25.6
+ psutil==7.0.0
+ pyarrow==19.0.1
+ pyarrow-hotfix==0.6
+ pygments==2.19.1
+ python-dateutil==2.9.0.post0
+ pytz==2025.2
+ pyyaml==6.0.2
+ regex==2024.11.6
+ requests==2.32.3
+ rich==14.0.0
+ safetensors==0.5.3
+ scikit-learn==1.5.1
+ scipy==1.15.2
+ sentry-sdk==2.26.1
+ setproctitle==1.3.5
+ six==1.17.0
+ smmap==5.0.2
+ sympy==1.13.1
+ tensorboard==2.16.1
+ tensorboard-data-server==0.7.2
+ tensorflow==2.16.1
+ tensorflow-io-gcs-filesystem==0.37.1
+ termcolor==3.0.1
+ tf-keras==2.16.0
+ threadpoolctl==3.6.0
+ timm==1.0.12
+ tokenizers==0.21.0
+ tqdm==4.67.1
+ transformers==4.47.0
+ triton==2.1.0
+ typing-extensions==4.12.2
+ tzdata==2025.2
+ urllib3==1.26.13
+ wandb==0.17.6
+ wcwidth==0.2.13
+ werkzeug==3.1.3
+ wrapt==1.17.2
+ xformer==1.0.1
+ xformers==0.0.23
+ xxhash==3.5.0
+ yarl==1.20.0
+ zipp==3.21.0
+
back/sample_from_checkpoint.py ADDED
@@ -0,0 +1,596 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ 从 REG/train.py 保存的检查点加载权重,在指定目录生成若干 PNG。
4
+
5
+ 示例:
6
+ python sample_from_checkpoint.py \\
7
+ --ckpt exps/jsflow-experiment/checkpoints/0050000.pt \\
8
+ --out-dir ./samples_gen \\
9
+ --num-images 64 \\
10
+ --batch-size 8
11
+
12
+ # 按训练 t_c 分段分配步数(t=1→t_c 与 t_c→0;--t-c 可省略若检查点含 t_c):
13
+ python sample_from_checkpoint.py ... \\
14
+ --steps-before-tc 150 --steps-after-tc 100 --t-c 0.5
15
+
16
+ # 同一批初始噪声连跑两种 t_c 后段步数(输出到 out-dir 下子目录):
17
+ python sample_from_checkpoint.py ... \\
18
+ --steps-before-tc 150 --steps-after-tc 5 --dual-compare-after
19
+ # 分段时会在 at_tc/(或 at_tc/after_input、at_tc/after_equal_before)额外保存 t≈t_c 的解码图。
20
+
21
+ 检查点需包含 train.py 写入的键:ema(或 model)、args(推荐,用于自动还原结构)。
22
+ 若缺少 args,需通过命令行显式传入 --model、--resolution、--enc-type 等。
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import argparse
28
+ import os
29
+ import sys
30
+ import types
31
+ import numpy as np
32
+ import torch
33
+ from diffusers.models import AutoencoderKL
34
+ from PIL import Image
35
+ from tqdm import tqdm
36
+
37
+ from models.sit import SiT_models
38
+ from samplers import (
39
+ euler_maruyama_image_noise_before_tc_sampler,
40
+ euler_maruyama_image_noise_sampler,
41
+ euler_maruyama_sampler,
42
+ euler_ode_sampler,
43
+ )
44
+
45
+
46
+ def semantic_dim_from_enc_type(enc_type):
47
+ """与 train.py 一致:按 enc_type 推断语义/class token 维度。"""
48
+ if enc_type is None:
49
+ return 768
50
+ s = str(enc_type).lower()
51
+ if "vit-g" in s or "vitg" in s:
52
+ return 1536
53
+ if "vit-l" in s or "vitl" in s:
54
+ return 1024
55
+ if "vit-s" in s or "vits" in s:
56
+ return 384
57
+ return 768
58
+
59
+
60
+ def load_train_args_from_ckpt(ckpt: dict) -> argparse.Namespace | None:
61
+ a = ckpt.get("args")
62
+ if a is None:
63
+ return None
64
+ if isinstance(a, argparse.Namespace):
65
+ return a
66
+ if isinstance(a, dict):
67
+ return argparse.Namespace(**a)
68
+ if isinstance(a, types.SimpleNamespace):
69
+ return argparse.Namespace(**vars(a))
70
+ return None
71
+
72
+
73
+ def load_vae(device: torch.device):
74
+ """与 train.py 相同策略:优先本地 diffusers 缓存中的 sd-vae-ft-mse。"""
75
+ try:
76
+ from preprocessing import dnnlib
77
+
78
+ cache_dir = dnnlib.make_cache_dir_path("diffusers")
79
+ os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
80
+ os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
81
+ os.environ["HF_HOME"] = cache_dir
82
+ try:
83
+ vae = AutoencoderKL.from_pretrained(
84
+ "stabilityai/sd-vae-ft-mse",
85
+ cache_dir=cache_dir,
86
+ local_files_only=True,
87
+ ).to(device)
88
+ vae.eval()
89
+ print(f"Loaded VAE from local cache: {cache_dir}")
90
+ return vae
91
+ except Exception:
92
+ pass
93
+ candidate_dir = None
94
+ for root_dir in [
95
+ cache_dir,
96
+ os.path.join(os.path.expanduser("~"), ".cache", "dnnlib", "diffusers"),
97
+ os.path.join(os.path.expanduser("~"), ".cache", "diffusers"),
98
+ os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"),
99
+ ]:
100
+ if not os.path.isdir(root_dir):
101
+ continue
102
+ for root, _, files in os.walk(root_dir):
103
+ if "config.json" in files and "sd-vae-ft-mse" in root.replace("\\", "/"):
104
+ candidate_dir = root
105
+ break
106
+ if candidate_dir is not None:
107
+ break
108
+ if candidate_dir is not None:
109
+ vae = AutoencoderKL.from_pretrained(candidate_dir, local_files_only=True).to(device)
110
+ vae.eval()
111
+ print(f"Loaded VAE from {candidate_dir}")
112
+ return vae
113
+ except Exception as e:
114
+ print(f"VAE local cache search failed: {e}", file=sys.stderr)
115
+ try:
116
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
117
+ vae.eval()
118
+ print("Loaded VAE from Hub: stabilityai/sd-vae-ft-mse")
119
+ return vae
120
+ except Exception as e:
121
+ raise RuntimeError(
122
+ "无法加载 VAE stabilityai/sd-vae-ft-mse,请确认已下载或网络可用。"
123
+ ) from e
124
+
125
+
126
+ def build_model_from_train_args(ta: argparse.Namespace, device: torch.device):
127
+ res = int(getattr(ta, "resolution", 256))
128
+ latent_size = res // 8
129
+ enc_type = getattr(ta, "enc_type", "dinov2-vit-b")
130
+ z_dims = [semantic_dim_from_enc_type(enc_type)]
131
+ block_kwargs = {
132
+ "fused_attn": getattr(ta, "fused_attn", True),
133
+ "qk_norm": getattr(ta, "qk_norm", False),
134
+ }
135
+ cfg_prob = float(getattr(ta, "cfg_prob", 0.1))
136
+ if ta.model not in SiT_models:
137
+ raise ValueError(f"未知 model={ta.model!r},可选:{list(SiT_models.keys())}")
138
+ model = SiT_models[ta.model](
139
+ input_size=latent_size,
140
+ num_classes=int(getattr(ta, "num_classes", 1000)),
141
+ use_cfg=(cfg_prob > 0),
142
+ z_dims=z_dims,
143
+ encoder_depth=int(getattr(ta, "encoder_depth", 8)),
144
+ **block_kwargs,
145
+ ).to(device)
146
+ return model, z_dims[0]
147
+
148
+
149
+ def resolve_tc_schedule(cli, ta):
150
+ """
151
+ 若同时给出 --steps-before-tc 与 --steps-after-tc:在 t_c 处分段(--t-c 缺省则用检查点 args.t_c)。
152
+ 否则使用均匀 --num-steps(与旧版一致)。
153
+ """
154
+ sb = cli.steps_before_tc
155
+ sa = cli.steps_after_tc
156
+ tc = cli.t_c
157
+ if sb is None and sa is None:
158
+ return None, None, None
159
+ if sb is None or sa is None:
160
+ print(
161
+ "使用分段步数时必须同时指定 --steps-before-tc 与 --steps-after-tc。",
162
+ file=sys.stderr,
163
+ )
164
+ sys.exit(1)
165
+ if tc is None:
166
+ tc = getattr(ta, "t_c", None) if ta is not None else None
167
+ if tc is None:
168
+ print(
169
+ "分段采样需要 --t-c,或检查点 args 中含 t_c。",
170
+ file=sys.stderr,
171
+ )
172
+ sys.exit(1)
173
+ return float(tc), int(sb), int(sa)
174
+
175
+
176
+ def parse_cli():
+     p = argparse.ArgumentParser(description="Sample images from a REG checkpoint (optional ODE/EM/EM image-noise)")
+     p.add_argument("--ckpt", type=str, required=True, help="Path to a .pt file saved by train.py")
+     p.add_argument("--out-dir", type=str, required=True, help="Output directory for PNGs (created if missing)")
+     p.add_argument("--num-images", type=int, required=True, help="Total number of images to generate")
+     p.add_argument("--batch-size", type=int, default=16)
+     p.add_argument("--seed", type=int, default=0)
+     p.add_argument(
+         "--weights",
+         type=str,
+         choices=("ema", "model"),
+         default="ema",
+         help="Use the ema or model weights from the checkpoint",
+     )
+     p.add_argument("--device", type=str, default="cuda", help="e.g. cuda or cuda:0")
+     p.add_argument(
+         "--num-steps",
+         type=int,
+         default=50,
+         help="Number of Euler steps on a uniform time grid (used when --steps-before-tc/--steps-after-tc are not given)",
+     )
+     p.add_argument(
+         "--t-c",
+         type=float,
+         default=None,
+         help="Split time: two segments, t∈(t_c,1] and t∈[0,t_c]; defaults to checkpoint args.t_c (requires both segment step counts)",
+     )
+     p.add_argument(
+         "--steps-before-tc",
+         type=int,
+         default=None,
+         help="Steps to integrate from t=1 to t=t_c (paired with --steps-after-tc)",
+     )
+     p.add_argument(
+         "--steps-after-tc",
+         type=int,
+         default=None,
+         help="Steps to integrate from t=t_c to t=0 (via t_floor=0.04)",
+     )
+     p.add_argument("--cfg-scale", type=float, default=1.0)
+     p.add_argument("--cls-cfg-scale", type=float, default=0.0, help="CFG for the cls branch (requires cfg-scale>1 when >0)")
+     p.add_argument("--guidance-low", type=float, default=0.0)
+     p.add_argument("--guidance-high", type=float, default=1.0)
+     p.add_argument(
+         "--path-type",
+         type=str,
+         default=None,
+         choices=["linear", "cosine"],
+         help="Read from checkpoint args by default; can be overridden",
+     )
+     p.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
+     # Fallbacks when the checkpoint has no args
+     p.add_argument("--model", type=str, default=None, help="Required when the checkpoint has no args; must match a SiT_models key, e.g. SiT-XL/2")
+     p.add_argument("--resolution", type=int, default=None, choices=[256, 512])
+     p.add_argument("--num-classes", type=int, default=None)
+     p.add_argument("--encoder-depth", type=int, default=None)
+     p.add_argument("--enc-type", type=str, default=None)
+     p.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=None)
+     p.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=None)
+     p.add_argument("--cfg-prob", type=float, default=None)
+     p.add_argument(
+         "--sampler",
+         type=str,
+         default="em_image_noise",
+         choices=["ode", "em", "em_image_noise", "em_image_noise_before_tc"],
+         help="Sampler: ode=euler_sampler deterministic drift (linspace 1→0, or piecewise straight to 0 at t_c, no t_floor; grid differs from EM), "
+         "em=standard EM (image + cls noise), em_image_noise=image noise only, "
+         "em_image_noise_before_tc=deterministic image for t<=t_c and deterministic cls throughout",
+     )
+     p.add_argument(
+         "--dual-compare-after",
+         action="store_true",
+         help="Requires piecewise step counts: run the same batch of z/y/cls twice; after_input uses --steps-after-tc, "
+         "after_equal_before sets the after-segment step count equal to --steps-before-tc",
+     )
+     p.add_argument(
+         "--save-fixed-trajectory",
+         action="store_true",
+         help="Save fixed-step sampling trajectories (npy); only enabled for non-em samplers, written to out-dir/trajectory",
+     )
+     return p.parse_args()
+
+
+ def _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae):
+     imgs = vae.decode((latents - latents_bias) / latents_scale).sample
+     imgs = (imgs + 1) / 2.0
+     imgs = torch.clamp(imgs, 0, 1)
+     return (
+         (imgs * 255.0)
+         .round()
+         .to(torch.uint8)
+         .permute(0, 2, 3, 1)
+         .cpu()
+         .numpy()
+     )
+
+
+ def main():
+     cli = parse_cli()
+     device = torch.device(cli.device if torch.cuda.is_available() else "cpu")
+     if device.type == "cuda":
+         torch.backends.cuda.matmul.allow_tf32 = True
+
+     try:
+         ckpt = torch.load(cli.ckpt, map_location="cpu", weights_only=False)
+     except TypeError:
+         ckpt = torch.load(cli.ckpt, map_location="cpu")
+     ta = load_train_args_from_ckpt(ckpt)
+     if ta is None:
+         if cli.model is None or cli.resolution is None or cli.enc_type is None:
+             print(
+                 "The checkpoint has no args; please specify at least --model --resolution --enc-type "
+                 "(plus --num-classes --encoder-depth as needed)",
+                 file=sys.stderr,
+             )
+             sys.exit(1)
+         ta = argparse.Namespace(
+             model=cli.model,
+             resolution=cli.resolution,
+             num_classes=cli.num_classes if cli.num_classes is not None else 1000,
+             encoder_depth=cli.encoder_depth if cli.encoder_depth is not None else 8,
+             enc_type=cli.enc_type,
+             fused_attn=cli.fused_attn if cli.fused_attn is not None else True,
+             qk_norm=cli.qk_norm if cli.qk_norm is not None else False,
+             cfg_prob=cli.cfg_prob if cli.cfg_prob is not None else 0.1,
+         )
+     else:
+         if cli.model is not None:
+             ta.model = cli.model
+         if cli.resolution is not None:
+             ta.resolution = cli.resolution
+         if cli.num_classes is not None:
+             ta.num_classes = cli.num_classes
+         if cli.encoder_depth is not None:
+             ta.encoder_depth = cli.encoder_depth
+         if cli.enc_type is not None:
+             ta.enc_type = cli.enc_type
+         if cli.fused_attn is not None:
+             ta.fused_attn = cli.fused_attn
+         if cli.qk_norm is not None:
+             ta.qk_norm = cli.qk_norm
+         if cli.cfg_prob is not None:
+             ta.cfg_prob = cli.cfg_prob
+
+     path_type = cli.path_type if cli.path_type is not None else getattr(ta, "path_type", "linear")
+
+     tc_split = resolve_tc_schedule(cli, ta)
+     if cli.dual_compare_after and tc_split[0] is None:
+         print("--dual-compare-after requires --steps-before-tc and --steps-after-tc (piecewise sampling)", file=sys.stderr)
+         sys.exit(1)
+     if tc_split[0] is not None:
+         if cli.dual_compare_after:
+             print(
+                 f"Dual comparison: t_c={tc_split[0]}, before={tc_split[1]}, "
+                 f"after_input={tc_split[2]}, after_equal_before={tc_split[1]}"
+             )
+         else:
+             print(
+                 f"Time grid: t_c={tc_split[0]}, steps (1→t_c)={tc_split[1]}, (t_c→0)={tc_split[2]} "
+                 f"(≈{tc_split[1] + tc_split[2] + 1} model forward passes in total)"
+             )
+     else:
+         print(f"Time grid: uniform num_steps={cli.num_steps}")
+
+     if cli.sampler == "ode":
+         sampler_fn = euler_ode_sampler
+     elif cli.sampler == "em":
+         sampler_fn = euler_maruyama_sampler
+     elif cli.sampler == "em_image_noise_before_tc":
+         sampler_fn = euler_maruyama_image_noise_before_tc_sampler
+     else:
+         sampler_fn = euler_maruyama_image_noise_sampler
+
+     model, cls_dim = build_model_from_train_args(ta, device)
+     wkey = cli.weights
+     if wkey not in ckpt:
+         raise KeyError(f"Checkpoint has no '{wkey}' key; available keys: {list(ckpt.keys())}")
+     state = ckpt[wkey]
+     if cli.legacy:
+         from utils import load_legacy_checkpoints
+
+         state = load_legacy_checkpoints(
+             state_dict=state, encoder_depth=int(getattr(ta, "encoder_depth", 8))
+         )
+     model.load_state_dict(state, strict=True)
+     model.eval()
+
+     vae = load_vae(device)
+     latents_scale = torch.tensor([0.18215] * 4, device=device).view(1, 4, 1, 1)
+     latents_bias = torch.tensor([0.0] * 4, device=device).view(1, 4, 1, 1)
+
+     sampler_args = argparse.Namespace(cls_cfg_scale=float(cli.cls_cfg_scale))
+
+     at_tc_dir = at_tc_a = at_tc_b = None
+     traj_dir = traj_a = traj_b = None
+     if cli.dual_compare_after:
+         out_a = os.path.join(cli.out_dir, "after_input")
+         out_b = os.path.join(cli.out_dir, "after_equal_before")
+         os.makedirs(out_a, exist_ok=True)
+         os.makedirs(out_b, exist_ok=True)
+         if tc_split[0] is not None:
+             at_tc_a = os.path.join(cli.out_dir, "at_tc", "after_input")
+             at_tc_b = os.path.join(cli.out_dir, "at_tc", "after_equal_before")
+             os.makedirs(at_tc_a, exist_ok=True)
+             os.makedirs(at_tc_b, exist_ok=True)
+         if cli.save_fixed_trajectory and cli.sampler != "em":
+             traj_a = os.path.join(cli.out_dir, "trajectory", "after_input")
+             traj_b = os.path.join(cli.out_dir, "trajectory", "after_equal_before")
+             os.makedirs(traj_a, exist_ok=True)
+             os.makedirs(traj_b, exist_ok=True)
+     else:
+         os.makedirs(cli.out_dir, exist_ok=True)
+         if tc_split[0] is not None:
+             at_tc_dir = os.path.join(cli.out_dir, "at_tc")
+             os.makedirs(at_tc_dir, exist_ok=True)
+         if cli.save_fixed_trajectory and cli.sampler != "em":
+             traj_dir = os.path.join(cli.out_dir, "trajectory")
+             os.makedirs(traj_dir, exist_ok=True)
+     latent_size = int(getattr(ta, "resolution", 256)) // 8
+     n_total = int(cli.num_images)
+     b = max(1, int(cli.batch_size))
+
+     torch.manual_seed(cli.seed)
+     if device.type == "cuda":
+         torch.cuda.manual_seed_all(cli.seed)
+
+     written = 0
+     pbar = tqdm(total=n_total, desc="sampling")
+     while written < n_total:
+         cur = min(b, n_total - written)
+         z = torch.randn(cur, model.in_channels, latent_size, latent_size, device=device)
+         y = torch.randint(0, int(ta.num_classes), (cur,), device=device)
+         cls_z = torch.randn(cur, cls_dim, device=device)
+
+         with torch.no_grad():
+             base_kw = dict(
+                 num_steps=cli.num_steps,
+                 cfg_scale=cli.cfg_scale,
+                 guidance_low=cli.guidance_low,
+                 guidance_high=cli.guidance_high,
+                 path_type=path_type,
+                 cls_latents=cls_z,
+                 args=sampler_args,
+             )
+             if cli.dual_compare_after:
+                 tc_v, sb, sa_in = tc_split
+                 # Both full sampling runs consume RNG state; without a reset, the second
+                 # run's 1→t_c noise would differ from the first and z_tc/at_tc would not line up.
+                 # Snapshot the RNG after fixing z/y/cls_z and restore it before the second run,
+                 # so the t_c intermediate state matches (only the after-segment step count differs).
+                 _rng_cpu_dual = torch.random.get_rng_state()
+                 _rng_cuda_dual = (
+                     torch.cuda.get_rng_state_all()
+                     if device.type == "cuda"
+                     else None
+                 )
+                 for _run_i, (subdir, sa, tc_save_dir) in enumerate(
+                     (
+                         (out_a, sa_in, at_tc_a),
+                         (out_b, sb, at_tc_b),
+                     )
+                 ):
+                     if _run_i > 0:
+                         torch.random.set_rng_state(_rng_cpu_dual)
+                         if _rng_cuda_dual is not None:
+                             torch.cuda.set_rng_state_all(_rng_cuda_dual)
+                     em_kw = dict(base_kw)
+                     em_kw["t_c"] = tc_v
+                     em_kw["num_steps_before_tc"] = sb
+                     em_kw["num_steps_after_tc"] = sa
+                     if cli.sampler == "em_image_noise_before_tc":
+                         if cli.save_fixed_trajectory and cli.sampler != "em":
+                             latents, z_tc, cls_tc, cls_t0, traj = sampler_fn(
+                                 model,
+                                 z,
+                                 y,
+                                 **em_kw,
+                                 return_mid_state=True,
+                                 t_mid=float(tc_v),
+                                 return_cls_final=True,
+                                 return_trajectory=True,
+                             )
+                         else:
+                             latents, z_tc, cls_tc, cls_t0 = sampler_fn(
+                                 model,
+                                 z,
+                                 y,
+                                 **em_kw,
+                                 return_mid_state=True,
+                                 t_mid=float(tc_v),
+                                 return_cls_final=True,
+                             )
+                             traj = None
+                     else:
+                         if cli.save_fixed_trajectory and cli.sampler != "em":
+                             latents, z_tc, cls_tc, traj = sampler_fn(
+                                 model,
+                                 z,
+                                 y,
+                                 **em_kw,
+                                 return_mid_state=True,
+                                 t_mid=float(tc_v),
+                                 return_trajectory=True,
+                             )
+                         else:
+                             latents, z_tc, cls_tc = sampler_fn(
+                                 model,
+                                 z,
+                                 y,
+                                 **em_kw,
+                                 return_mid_state=True,
+                                 t_mid=float(tc_v),
+                             )
+                             traj = None
+                         cls_t0 = None
+                     latents = latents.to(torch.float32)
+                     imgs = _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae)
+                     for i in range(cur):
+                         Image.fromarray(imgs[i]).save(
+                             os.path.join(subdir, f"{written + i:06d}.png")
+                         )
+                     if tc_save_dir is not None and z_tc is not None:
+                         imgs_tc = _decode_to_uint8_hwc(
+                             z_tc.to(torch.float32), latents_bias, latents_scale, vae
+                         )
+                         for i in range(cur):
+                             Image.fromarray(imgs_tc[i]).save(
+                                 os.path.join(tc_save_dir, f"{written + i:06d}.png")
+                             )
+                     if traj is not None:
+                         traj_np = torch.stack(traj, dim=0).to(torch.float32).cpu().numpy()
+                         save_traj_dir = traj_a if subdir == out_a else traj_b
+                         np.save(os.path.join(save_traj_dir, f"{written:06d}_traj.npy"), traj_np)
+             else:
+                 em_kw = dict(base_kw)
+                 if tc_split[0] is not None:
+                     em_kw["t_c"] = tc_split[0]
+                     em_kw["num_steps_before_tc"] = tc_split[1]
+                     em_kw["num_steps_after_tc"] = tc_split[2]
+                     if cli.sampler == "em_image_noise_before_tc":
+                         if cli.save_fixed_trajectory and cli.sampler != "em":
+                             latents, z_tc, cls_tc, cls_t0, traj = sampler_fn(
+                                 model,
+                                 z,
+                                 y,
+                                 **em_kw,
+                                 return_mid_state=True,
+                                 t_mid=float(tc_split[0]),
+                                 return_cls_final=True,
+                                 return_trajectory=True,
+                             )
+                         else:
+                             latents, z_tc, cls_tc, cls_t0 = sampler_fn(
+                                 model,
+                                 z,
+                                 y,
+                                 **em_kw,
+                                 return_mid_state=True,
+                                 t_mid=float(tc_split[0]),
+                                 return_cls_final=True,
+                             )
+                             traj = None
+                     else:
+                         if cli.save_fixed_trajectory and cli.sampler != "em":
+                             latents, z_tc, cls_tc, traj = sampler_fn(
+                                 model,
+                                 z,
+                                 y,
+                                 **em_kw,
+                                 return_mid_state=True,
+                                 t_mid=float(tc_split[0]),
+                                 return_trajectory=True,
+                             )
+                         else:
+                             latents, z_tc, cls_tc = sampler_fn(
+                                 model,
+                                 z,
+                                 y,
+                                 **em_kw,
+                                 return_mid_state=True,
+                                 t_mid=float(tc_split[0]),
+                             )
+                             traj = None
+                         cls_t0 = None
+                     latents = latents.to(torch.float32)
+                     if z_tc is not None and at_tc_dir is not None:
+                         imgs_tc = _decode_to_uint8_hwc(
+                             z_tc.to(torch.float32), latents_bias, latents_scale, vae
+                         )
+                         for i in range(cur):
+                             Image.fromarray(imgs_tc[i]).save(
+                                 os.path.join(at_tc_dir, f"{written + i:06d}.png")
+                             )
+                     if traj is not None and traj_dir is not None:
+                         traj_np = torch.stack(traj, dim=0).to(torch.float32).cpu().numpy()
+                         np.save(os.path.join(traj_dir, f"{written:06d}_traj.npy"), traj_np)
+                 else:
+                     latents = sampler_fn(model, z, y, **em_kw).to(torch.float32)
+                 imgs = _decode_to_uint8_hwc(latents, latents_bias, latents_scale, vae)
+                 for i in range(cur):
+                     Image.fromarray(imgs[i]).save(
+                         os.path.join(cli.out_dir, f"{written + i:06d}.png")
+                     )
+         written += cur
+         pbar.update(cur)
+     pbar.close()
+     if cli.dual_compare_after:
+         msg = (
+             f"Done. Saved {written} images per run under {out_a} and {out_b} "
+             f"(parent: {cli.out_dir})"
+         )
+         if tc_split[0] is not None and at_tc_a is not None:
+             msg += f"; t≈t_c decoded under {at_tc_a} and {at_tc_b}"
+         print(msg)
+     else:
+         msg = f"Done. Saved {written} images under {cli.out_dir}"
+         if tc_split[0] is not None and at_tc_dir is not None:
+             msg += f"; t≈t_c decoded under {at_tc_dir}"
+         print(msg)
+
+
+ if __name__ == "__main__":
+     main()
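The `--t-c`/`--steps-before-tc`/`--steps-after-tc` flags above describe a two-segment Euler grid: t=1 → t_c, then t_c → a floor of 0.04. A minimal sketch of how such a grid could be built (the helper name `make_piecewise_grid` is illustrative; the actual grid construction lives in the sampler functions, which are not part of this diff):

```python
def make_piecewise_grid(t_c, steps_before, steps_after, t_floor=0.04):
    """Descending time grid: t=1 -> t_c in steps_before steps,
    then t_c -> t_floor in steps_after steps (t_c appears once)."""
    before = [1.0 - i * (1.0 - t_c) / steps_before for i in range(steps_before + 1)]
    after = [t_c - i * (t_c - t_floor) / steps_after for i in range(1, steps_after + 1)]
    return before + after

# 50 + 5 intervals -> 51 + 5 = 56 grid points, consistent with the
# "(≈ steps_before + steps_after + 1 model forward passes)" log line above
grid = make_piecewise_grid(t_c=0.75, steps_before=50, steps_after=5)
```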
back/samples.sh ADDED
@@ -0,0 +1,15 @@
+ #!/usr/bin/env bash
+ # For a dual-run step-count comparison use --dual-compare-after (see sample_from_checkpoint.py); outputs go to subdirectories of out-dir.
+
+ CUDA_VISIBLE_DEVICES=1 python sample_from_checkpoint.py \
+     --ckpt /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/exps/jsflow-experiment-0.75/checkpoints/0500000.pt \
+     --out-dir ./my_samples_test \
+     --num-images 24 \
+     --batch-size 4 \
+     --seed 0 \
+     --t-c 0.75 \
+     --steps-before-tc 50 \
+     --steps-after-tc 5 \
+     --sampler ode \
+     --cfg-scale 1.0 \
+     --dual-compare-after
back/samples_0.5.log ADDED
The diff for this file is too large to render. See raw diff
 
back/samples_ddp.sh ADDED
@@ -0,0 +1,32 @@
+ #!/usr/bin/env bash
+
+ # 4-GPU DDP single-path sampling (no dual-compare, no at_tc intermediate images)
+ CUDA_VISIBLE_DEVICES=0,1,2,3 nohup torchrun \
+     --nnodes=1 \
+     --nproc_per_node=4 \
+     --rdzv_endpoint=localhost:29110 \
+     sample_from_checkpoint_ddp.py \
+     --ckpt /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/exps/jsflow-experiment-0.75/checkpoints/0600000.pt \
+     --out-dir ./my_samples_600k_new \
+     --num-images 40000 \
+     --batch-size 64 \
+     --seed 0 \
+     --t-c 0.75 \
+     --steps-before-tc 100 \
+     --steps-after-tc 50 \
+     --sampler em_image_noise_before_tc \
+     --cfg-scale 1.0 \
+     > samples_0.75_new.log 2>&1 &
+
+ # nohup python sample_from_checkpoint_ddp.py \
+ #     --ckpt /gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/REG/exps/jsflow-experiment-0.5/checkpoints/0250000.pt \
+ #     --out-dir ./my_samples_5 \
+ #     --num-images 20000 \
+ #     --batch-size 16 \
+ #     --seed 0 \
+ #     --t-c 0.5 \
+ #     --steps-before-tc 100 \
+ #     --steps-after-tc 50 \
+ #     --sampler em_image_noise_before_tc \
+ #     --cfg-scale 1.0 \
+ #     > samples_0.5.log 2>&1 &
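The DDP launch above splits `--num-images 40000` across 4 ranks. One plausible sharding scheme (a sketch only; `sample_from_checkpoint_ddp.py` is not shown in this diff, so its actual partitioning may differ):

```python
def shard_counts(num_images, world_size):
    # Split the total evenly; the first (num_images % world_size) ranks take one extra image.
    base, rem = divmod(num_images, world_size)
    return [base + (1 if r < rem else 0) for r in range(world_size)]

# 40000 images over 4 ranks -> 10000 images per GPU
```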
back/train.py ADDED
@@ -0,0 +1,670 @@
+ import argparse
+ import copy
+ from copy import deepcopy
+ import logging
+ import os
+ from pathlib import Path
+ from collections import OrderedDict
+ import json
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ from tqdm.auto import tqdm
+ from torch.utils.data import DataLoader
+
+ from accelerate import Accelerator, DistributedDataParallelKwargs
+ from accelerate.logging import get_logger
+ from accelerate.utils import ProjectConfiguration, set_seed
+
+ from models.sit import SiT_models
+ from loss import SILoss
+ from utils import load_encoders
+
+ from dataset import CustomDataset
+ from diffusers.models import AutoencoderKL
+ # import wandb_utils
+ import wandb
+ import math
+ from torchvision.utils import make_grid
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+ from torchvision.transforms import Normalize
+ from PIL import Image
+
+ logger = get_logger(__name__)
+
+
+ def semantic_dim_from_enc_type(enc_type):
+     """Infer the class-token dimension from an enc_type string such as DINOv2 (matches the preprocessed features)."""
+     if enc_type is None:
+         return 768
+     s = str(enc_type).lower()
+     if "vit-g" in s or "vitg" in s:
+         return 1536
+     if "vit-l" in s or "vitl" in s:
+         return 1024
+     if "vit-s" in s or "vits" in s:
+         return 384
+     return 768
+
+
+ CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
+ CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)
+
+
+ def preprocess_raw_image(x, enc_type):
+     resolution = x.shape[-1]
+     if 'clip' in enc_type:
+         x = x / 255.
+         x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
+         x = Normalize(CLIP_DEFAULT_MEAN, CLIP_DEFAULT_STD)(x)
+     elif 'mocov3' in enc_type or 'mae' in enc_type:
+         x = x / 255.
+         x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+     elif 'dinov2' in enc_type:
+         x = x / 255.
+         x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+         x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
+     elif 'dinov1' in enc_type:
+         x = x / 255.
+         x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+     elif 'jepa' in enc_type:
+         x = x / 255.
+         x = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(x)
+         x = torch.nn.functional.interpolate(x, 224 * (resolution // 256), mode='bicubic')
+
+     return x
+
+
+ def array2grid(x):
+     nrow = round(math.sqrt(x.size(0)))
+     x = make_grid(x.clamp(0, 1), nrow=nrow, value_range=(0, 1))
+     x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
+     return x
+
+
+ @torch.no_grad()
+ def sample_posterior(moments, latents_scale=1., latents_bias=0.):
+     device = moments.device
+
+     mean, std = torch.chunk(moments, 2, dim=1)
+     z = mean + std * torch.randn_like(mean)
+     z = (z * latents_scale + latents_bias)
+     return z
+
+
+ @torch.no_grad()
+ def update_ema(ema_model, model, decay=0.9999):
+     """
+     Step the EMA model towards the current model.
+     """
+     ema_params = OrderedDict(ema_model.named_parameters())
+     model_params = OrderedDict(model.named_parameters())
+
+     for name, param in model_params.items():
+         name = name.replace("module.", "")
+         # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
+         ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
+
+
+ def create_logger(logging_dir):
+     """
+     Create a logger that writes to a log file and stdout.
+     """
+     logging.basicConfig(
+         level=logging.INFO,
+         format='[\033[34m%(asctime)s\033[0m] %(message)s',
+         datefmt='%Y-%m-%d %H:%M:%S',
+         handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
+     )
+     logger = logging.getLogger(__name__)
+     return logger
+
+
+ def requires_grad(model, flag=True):
+     """
+     Set requires_grad flag for all parameters in a model.
+     """
+     for p in model.parameters():
+         p.requires_grad = flag
+
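`update_ema` above applies the standard exponential-moving-average recurrence in-place via `mul_`/`add_`. The same recurrence on plain floats, as a small sanity check (a sketch independent of torch):

```python
def ema_step(ema, value, decay=0.9999):
    # ema <- decay * ema + (1 - decay) * value, matching update_ema above
    return decay * ema + (1 - decay) * value

x = 0.0
for _ in range(3):
    x = ema_step(x, 1.0, decay=0.5)
# three steps toward 1.0 with decay=0.5: 0.5, 0.75, 0.875
```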
+
+ #################################################################################
+ #                                 Training Loop                                 #
+ #################################################################################
+
+ def main(args):
+     # set accelerator
+     logging_dir = Path(args.output_dir, args.logging_dir)
+     accelerator_project_config = ProjectConfiguration(
+         project_dir=args.output_dir, logging_dir=logging_dir
+     )
+
+     accelerator = Accelerator(
+         gradient_accumulation_steps=args.gradient_accumulation_steps,
+         mixed_precision=args.mixed_precision,
+         log_with=args.report_to,
+         project_config=accelerator_project_config,
+         kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)]
+     )
+
+     if accelerator.is_main_process:
+         os.makedirs(args.output_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
+         save_dir = os.path.join(args.output_dir, args.exp_name)
+         os.makedirs(save_dir, exist_ok=True)
+         args_dict = vars(args)
+         # Save to a JSON file
+         json_dir = os.path.join(save_dir, "args.json")
+         with open(json_dir, 'w') as f:
+             json.dump(args_dict, f, indent=4)
+         checkpoint_dir = f"{save_dir}/checkpoints"  # Stores saved model checkpoints
+         os.makedirs(checkpoint_dir, exist_ok=True)
+         logger = create_logger(save_dir)
+         logger.info(f"Experiment directory created at {save_dir}")
+     device = accelerator.device
+     if torch.backends.mps.is_available():
+         accelerator.native_amp = False
+     if args.seed is not None:
+         set_seed(args.seed + accelerator.process_index)
+
+     # Create model:
+     assert args.resolution % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
+     latent_size = args.resolution // 8
+
+     train_dataset = CustomDataset(
+         args.data_dir, semantic_features_dir=args.semantic_features_dir
+     )
+     use_preprocessed_semantic = train_dataset.use_preprocessed_semantic
+
+     if use_preprocessed_semantic:
+         encoders, encoder_types, architectures = [], [], []
+         z_dims = [semantic_dim_from_enc_type(args.enc_type)]
+         if accelerator.is_main_process:
+             logger.info(
+                 f"Preprocessed semantic features: skip loading online encoder, z_dims={z_dims}"
+             )
+     elif args.enc_type is not None:
+         encoders, encoder_types, architectures = load_encoders(
+             args.enc_type, device, args.resolution
+         )
+         z_dims = [encoder.embed_dim for encoder in encoders]
+     else:
+         raise NotImplementedError()
+     block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
+     model = SiT_models[args.model](
+         input_size=latent_size,
+         num_classes=args.num_classes,
+         use_cfg=(args.cfg_prob > 0),
+         z_dims=z_dims,
+         encoder_depth=args.encoder_depth,
+         **block_kwargs
+     )
+
+     model = model.to(device)
+     ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
+     requires_grad(ema, False)
+
+     latents_scale = torch.tensor(
+         [0.18215, 0.18215, 0.18215, 0.18215]
+     ).view(1, 4, 1, 1).to(device)
+     latents_bias = torch.tensor(
+         [0., 0., 0., 0.]
+     ).view(1, 4, 1, 1).to(device)
+
+     # VAE decoder: decodes latents to images during sampling (same as root train.py / preprocessing: sd-vae-ft-mse)
+     try:
+         from preprocessing import dnnlib
+         cache_dir = dnnlib.make_cache_dir_path("diffusers")
+         os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
+         os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
+         os.environ["HF_HOME"] = cache_dir
+         try:
+             vae = AutoencoderKL.from_pretrained(
+                 "stabilityai/sd-vae-ft-mse",
+                 cache_dir=cache_dir,
+                 local_files_only=True,
+             ).to(device)
+             vae.eval()
+             if accelerator.is_main_process:
+                 logger.info(
+                     "Loaded VAE 'stabilityai/sd-vae-ft-mse' from local diffusers cache "
+                     f"at '{cache_dir}' for intermediate sampling."
+                 )
+         except Exception as e_main:
+             vae = None
+             candidate_dir = None
+             possible_roots = [
+                 cache_dir,
+                 os.path.join(os.path.expanduser("~"), ".cache", "dnnlib", "diffusers"),
+                 os.path.join(os.path.expanduser("~"), ".cache", "diffusers"),
+                 os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub"),
+             ]
+             checked_roots = []
+             for root_dir in possible_roots:
+                 if not os.path.isdir(root_dir):
+                     continue
+                 checked_roots.append(root_dir)
+                 for root, dirs, files in os.walk(root_dir):
+                     if "config.json" in files and "sd-vae-ft-mse" in root.replace("\\", "/"):
+                         candidate_dir = root
+                         break
+                 if candidate_dir is not None:
+                     break
+             if candidate_dir is not None:
+                 try:
+                     vae = AutoencoderKL.from_pretrained(
+                         candidate_dir,
+                         local_files_only=True,
+                     ).to(device)
+                     vae.eval()
+                     if accelerator.is_main_process:
+                         logger.info(
+                             "Loaded VAE 'stabilityai/sd-vae-ft-mse' from discovered local path "
+                             f"'{candidate_dir}'. Searched roots: {checked_roots}"
+                         )
+                 except Exception as e_fallback:
+                     if accelerator.is_main_process:
+                         logger.warning(
+                             "Tried to load VAE from discovered local path "
+                             f"'{candidate_dir}' but failed: {e_fallback}"
+                         )
+             if vae is None and accelerator.is_main_process:
+                 logger.warning(
+                     "Could not load VAE 'stabilityai/sd-vae-ft-mse' via repo name or local search. "
+                     f"Last repo-level error: {e_main}"
+                 )
+     except Exception as e:
+         vae = None
+         if accelerator.is_main_process:
+             logger.warning(
+                 f"Failed to initialize VAE loading logic (will skip image decoding): {e}"
+             )
+
+     # create loss function
+     loss_fn = SILoss(
+         prediction=args.prediction,
+         path_type=args.path_type,
+         encoders=encoders,
+         accelerator=accelerator,
+         latents_scale=latents_scale,
+         latents_bias=latents_bias,
+         weighting=args.weighting,
+         t_c=args.t_c,
+         ot_cls=args.ot_cls,
+     )
+     if accelerator.is_main_process:
+         logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+     # Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
+     if args.allow_tf32:
+         torch.backends.cuda.matmul.allow_tf32 = True
+         torch.backends.cudnn.allow_tf32 = True
+
+     optimizer = torch.optim.AdamW(
+         model.parameters(),
+         lr=args.learning_rate,
+         betas=(args.adam_beta1, args.adam_beta2),
+         weight_decay=args.adam_weight_decay,
+         eps=args.adam_epsilon,
+     )
+
+     # Setup data (train_dataset was already created above)
+     local_batch_size = int(args.batch_size // accelerator.num_processes)
+     train_dataloader = DataLoader(
+         train_dataset,
+         batch_size=local_batch_size,
+         shuffle=True,
+         num_workers=args.num_workers,
+         pin_memory=True,
+         drop_last=True
+     )
+     if accelerator.is_main_process:
+         logger.info(f"Dataset contains {len(train_dataset):,} images ({args.data_dir})")
+
+     # Prepare models for training:
+     update_ema(ema, model, decay=0)  # Ensure EMA is initialized with synced weights
+     model.train()  # important! This enables embedding dropout for classifier-free guidance
+     ema.eval()  # EMA model should always be in eval mode
+
+     # resume:
+     global_step = 0
+     if args.resume_step > 0:
+         ckpt_name = str(args.resume_step).zfill(7) + '.pt'
+         ckpt = torch.load(
+             f'{os.path.join(args.output_dir, args.exp_name)}/checkpoints/{ckpt_name}',
+             map_location='cpu',
+         )
+         model.load_state_dict(ckpt['model'])
+         ema.load_state_dict(ckpt['ema'])
+         optimizer.load_state_dict(ckpt['opt'])
+         global_step = ckpt['steps']
+
+     model, optimizer, train_dataloader = accelerator.prepare(
+         model, optimizer, train_dataloader
+     )
+
+     if accelerator.is_main_process:
+         tracker_config = vars(copy.deepcopy(args))
+         accelerator.init_trackers(
+             project_name="REG",
+             config=tracker_config,
+             init_kwargs={
+                 "wandb": {"name": f"{args.exp_name}"}
+             },
+         )
+
+     progress_bar = tqdm(
+         range(0, args.max_train_steps),
+         initial=global_step,
+         desc="Steps",
+         # Only show the progress bar once on each machine.
+         disable=not accelerator.is_local_main_process,
+     )
+
+     # Labels to condition the model with (feel free to change):
+     sample_batch_size = 64 // accelerator.num_processes
+     first_batch = next(iter(train_dataloader))
+     if len(first_batch) == 4:
+         gt_raw_images, gt_xs, _, _ = first_batch
+     else:
+         gt_raw_images, gt_xs, _ = first_batch
+     assert gt_raw_images.shape[-1] == args.resolution
+     gt_xs = gt_xs[:sample_batch_size]
+     gt_xs = sample_posterior(
+         gt_xs.to(device), latents_scale=latents_scale, latents_bias=latents_bias
+     )
+     ys = torch.randint(1000, size=(sample_batch_size,), device=device)
+     ys = ys.to(device)
+     # Create sampling noise:
+     n = ys.size(0)
+     xT = torch.randn((n, 4, latent_size, latent_size), device=device)
+
+ for epoch in range(args.epochs):
386
+ model.train()
387
+ for batch in train_dataloader:
388
+ if len(batch) == 4:
389
+ raw_image, x, r_preprocessed, y = batch
390
+ use_sem_file = True
391
+ else:
392
+ raw_image, x, y = batch
393
+ r_preprocessed = None
394
+ use_sem_file = False
395
+
396
+ raw_image = raw_image.to(device)
397
+ x = x.squeeze(dim=1).to(device).float()
398
+ y = y.to(device)
399
+ if args.legacy:
400
+ # In our early experiments, we accidentally apply label dropping twice:
401
+ # once in train.py and once in sit.py.
402
+ # We keep this option for exact reproducibility with previous runs.
403
+ drop_ids = torch.rand(y.shape[0], device=y.device) < args.cfg_prob
404
+ labels = torch.where(drop_ids, args.num_classes, y)
405
+ else:
406
+ labels = y
407
+ with torch.no_grad():
408
+ x = sample_posterior(x, latents_scale=latents_scale, latents_bias=latents_bias)
409
+ zs = []
410
+ if use_sem_file and r_preprocessed is not None:
411
+ cls_token = r_preprocessed.to(device).float()
412
+ if cls_token.dim() == 1:
413
+ cls_token = cls_token.unsqueeze(0)
414
+ while cls_token.dim() > 2:
415
+ cls_token = cls_token.squeeze(1)
416
+ base_m = model.module if hasattr(model, "module") else model
417
+ n_pad = base_m.x_embedder.num_patches
418
+ zs = [
419
+ torch.cat(
420
+ [
421
+ cls_token.unsqueeze(1),
422
+ cls_token.unsqueeze(1).expand(-1, n_pad, -1),
423
+ ],
424
+ dim=1,
425
+ )
426
+ ]
427
+ else:
428
+ with accelerator.autocast():
429
+ for encoder, encoder_type, arch in zip(
430
+ encoders, encoder_types, architectures
431
+ ):
432
+ raw_image_ = preprocess_raw_image(raw_image, encoder_type)
433
+ z = encoder.forward_features(raw_image_)
434
+ if 'dinov2' in encoder_type:
435
+ dense_z = z['x_norm_patchtokens']
436
+ cls_token = z['x_norm_clstoken']
437
+ dense_z = torch.cat([cls_token.unsqueeze(1), dense_z], dim=1)
438
+ else:
439
+ exit()
440
+ zs.append(dense_z)
441
+
442
+         with accelerator.accumulate(model):
+             model_kwargs = dict(y=labels)
+             loss1, proj_loss1, time_input, noises, loss2 = loss_fn(model, x, model_kwargs, zs=zs,
+                                                                    cls_token=cls_token,
+                                                                    time_input=None, noises=None)
+             loss_mean = loss1.mean()
+             loss_mean_cls = loss2.mean() * args.cls
+             proj_loss_mean = proj_loss1.mean() * args.proj_coeff
+             loss = loss_mean + proj_loss_mean + loss_mean_cls
+
+             ## optimization
+             accelerator.backward(loss)
+             if accelerator.sync_gradients:
+                 params_to_clip = model.parameters()
+                 grad_norm = accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+             optimizer.step()
+             optimizer.zero_grad(set_to_none=True)
+
+         if accelerator.sync_gradients:
+             update_ema(ema, model)  # update the EMA copy of the weights
+
+         if accelerator.sync_gradients:
+             progress_bar.update(1)
+             global_step += 1
+             if global_step % args.checkpointing_steps == 0 and global_step > 0:
+                 if accelerator.is_main_process:
+                     checkpoint = {
+                         "model": model.module.state_dict(),
+                         "ema": ema.state_dict(),
+                         "opt": optimizer.state_dict(),
+                         "args": args,
+                         "steps": global_step,
+                     }
+                     checkpoint_path = f"{checkpoint_dir}/{global_step:07d}.pt"
+                     torch.save(checkpoint, checkpoint_path)
+                     logger.info(f"Saved checkpoint to {checkpoint_path}")
+
+             if global_step == 1 or (global_step % args.sampling_steps == 0 and global_step > 0):
+                 t_mid_vis = float(args.t_c)
+                 tc_tag = f"{t_mid_vis:.4f}".rstrip("0").rstrip(".").replace(".", "_")
+                 logging.info(
+                     f"Generating EMA samples (Euler-Maruyama; t≈{t_mid_vis:g} → t=0)..."
+                 )
+                 ema.eval()
+                 with torch.no_grad():
+                     latent_size = args.resolution // 8
+                     n_samples = min(16, args.batch_size)
+                     base_model = model.module if hasattr(model, "module") else model
+                     cls_dim = base_model.z_dims[0]
+                     # Re-seed with the same value so latent and cls noise stay paired.
+                     shared_seed = torch.randint(0, 2**32, (1,), device=device).item()
+                     torch.manual_seed(shared_seed)
+                     z_init = torch.randn(n_samples, base_model.in_channels, latent_size, latent_size, device=device)
+                     torch.manual_seed(shared_seed)
+                     cls_init = torch.randn(n_samples, cls_dim, device=device)
+                     y_samples = torch.randint(0, args.num_classes, (n_samples,), device=device)
+
+                     from samplers import euler_maruyama_sampler
+                     z_0, z_mid, _ = euler_maruyama_sampler(
+                         ema,
+                         z_init,
+                         y_samples,
+                         num_steps=50,
+                         cfg_scale=1.0,
+                         guidance_low=0.0,
+                         guidance_high=1.0,
+                         path_type=args.path_type,
+                         cls_latents=cls_init,
+                         args=args,
+                         return_mid_state=True,
+                         t_mid=t_mid_vis,
+                     )
+
+                     samples_root = os.path.join(args.output_dir, args.exp_name, "samples")
+                     t0_dir = os.path.join(samples_root, "t0")
+                     t_mid_dir = os.path.join(samples_root, f"t0_{tc_tag}")
+                     os.makedirs(t0_dir, exist_ok=True)
+                     os.makedirs(t_mid_dir, exist_ok=True)
+
+                     if vae is not None:
+                         z_f = z_0.to(dtype=torch.float32)
+                         samples_final = vae.decode((z_f - latents_bias) / latents_scale).sample
+                         samples_final = (samples_final + 1) / 2.0
+                         samples_final = samples_final.clamp(0, 1)
+                         grid_final = array2grid(samples_final)
+                         Image.fromarray(grid_final).save(
+                             os.path.join(t0_dir, f"step_{global_step:07d}_t0.png")
+                         )
+
+                         if z_mid is not None:
+                             z_m = z_mid.to(dtype=torch.float32)
+                             samples_mid = vae.decode((z_m - latents_bias) / latents_scale).sample
+                             samples_mid = (samples_mid + 1) / 2.0
+                             samples_mid = samples_mid.clamp(0, 1)
+                             grid_mid = array2grid(samples_mid)
+                             Image.fromarray(grid_mid).save(
+                                 os.path.join(t_mid_dir, f"step_{global_step:07d}_t0_{tc_tag}.png")
+                             )
+                         else:
+                             logging.warning(
+                                 f"Sampling time grid did not bracket t_mid={t_mid_vis:g}; "
+                                 f"skipping the t0_{tc_tag} image this step."
+                             )
+
+                     del z_init, cls_init, y_samples, z_0
+                     if z_mid is not None:
+                         del z_mid
+                     if vae is not None:
+                         del samples_final, grid_final
+                         if "samples_mid" in locals():
+                             del samples_mid, grid_mid
+                     torch.cuda.empty_cache()
+
+         logs = {
+             "loss_final": accelerator.gather(loss).mean().detach().item(),
+             "loss_mean": accelerator.gather(loss_mean).mean().detach().item(),
+             "proj_loss": accelerator.gather(proj_loss_mean).mean().detach().item(),
+             "loss_mean_cls": accelerator.gather(loss_mean_cls).mean().detach().item(),
+             "grad_norm": accelerator.gather(grad_norm).mean().detach().item(),
+         }
+
+         log_message = ", ".join(f"{key}: {value:.6f}" for key, value in logs.items())
+         logging.info(f"Step: {global_step}, Training Logs: {log_message}")
+
+         progress_bar.set_postfix(**logs)
+         accelerator.log(logs, step=global_step)
+
+         if global_step >= args.max_train_steps:
+             break
+     # end of the dataloader loop; break out of the epoch loop on the same condition
+     if global_step >= args.max_train_steps:
+         break
+
+ model.eval()  # important! This disables randomized embedding dropout
+ # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
+
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+     logger.info("Done!")
+ accelerator.end_training()
+
+ def parse_args(input_args=None):
+     parser = argparse.ArgumentParser(description="Training")
+
+     # logging:
+     parser.add_argument("--output-dir", type=str, default="exps")
+     parser.add_argument("--exp-name", type=str, required=True)
+     parser.add_argument("--logging-dir", type=str, default="logs")
+     parser.add_argument("--report-to", type=str, default="wandb")
+     parser.add_argument("--sampling-steps", type=int, default=2000)
+     parser.add_argument("--resume-step", type=int, default=0)
+
+     # model
+     parser.add_argument("--model", type=str)
+     parser.add_argument("--num-classes", type=int, default=1000)
+     parser.add_argument("--encoder-depth", type=int, default=8)
+     parser.add_argument("--fused-attn", action=argparse.BooleanOptionalAction, default=True)
+     parser.add_argument("--qk-norm", action=argparse.BooleanOptionalAction, default=False)
+     parser.add_argument("--ops-head", type=int, default=16)
+
+     # dataset
+     parser.add_argument("--data-dir", type=str, default="../data/imagenet256")
+     parser.add_argument(
+         "--semantic-features-dir",
+         type=str,
+         default=None,
+         help="Directory of pre-extracted features such as the DINOv2 class token "
+              "(must contain dataset.json). When None (the default), "
+              "data-dir/imagenet_256_features/dinov2-vit-b_tmp/gpu0 is used automatically if it exists.",
+     )
+     parser.add_argument("--resolution", type=int, choices=[256, 512], default=256)
+     parser.add_argument("--batch-size", type=int, default=256)
+
+     # precision
+     parser.add_argument("--allow-tf32", action="store_true")
+     parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["no", "fp16", "bf16"])
+
+     # optimization
+     parser.add_argument("--epochs", type=int, default=1400)
+     parser.add_argument("--max-train-steps", type=int, default=1000000)
+     parser.add_argument("--checkpointing-steps", type=int, default=10000)
+     parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
+     parser.add_argument("--learning-rate", type=float, default=1e-4)
+     parser.add_argument("--adam-beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+     parser.add_argument("--adam-beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+     parser.add_argument("--adam-weight-decay", type=float, default=0., help="Weight decay to use.")
+     parser.add_argument("--adam-epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
+     parser.add_argument("--max-grad-norm", default=1.0, type=float, help="Max gradient norm.")
+
+     # seed
+     parser.add_argument("--seed", type=int, default=0)
+
+     # cpu
+     parser.add_argument("--num-workers", type=int, default=4)
+
+     # loss
+     parser.add_argument("--path-type", type=str, default="linear", choices=["linear", "cosine"])
+     parser.add_argument("--prediction", type=str, default="v", choices=["v"])  # currently we only support v-prediction
+     parser.add_argument("--cfg-prob", type=float, default=0.1)
+     parser.add_argument("--enc-type", type=str, default='dinov2-vit-b')
+     parser.add_argument("--proj-coeff", type=float, default=0.5)
+     parser.add_argument("--weighting", default="uniform", type=str, help="Loss weighting scheme.")
+     parser.add_argument("--legacy", action=argparse.BooleanOptionalAction, default=False)
+     parser.add_argument("--cls", type=float, default=0.03)
+     parser.add_argument(
+         "--t-c",
+         type=float,
+         default=0.5,
+         help="Semantic boundary time (same t convention as this script: t=1 noise -> t=0 data). "
+              "For t in (t_c, 1], cls is interpolated along OT-paired paths "
+              "(CFM/OT-CFM-style minibatch OT); for t in [0, t_c], cls is fixed to the "
+              "ground-truth encoder cls and the target cls velocity is 0.",
+     )
+     parser.add_argument(
+         "--ot-cls",
+         action=argparse.BooleanOptionalAction,
+         default=True,
+         help="For t > t_c, pair the cls noise with the in-batch cls_gt via minibatch "
+              "optimal transport (requires scipy); when disabled, fall back to pairing "
+              "with independent Gaussian noise.",
+     )
+     if input_args is not None:
+         args = parser.parse_args(input_args)
+     else:
+         args = parser.parse_args()
+
+     return args
+
+ if __name__ == "__main__":
+     args = parse_args()
+
+     main(args)
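The sample-directory tag used in the visualization step above is built from one dense string expression. Pulled out as a standalone helper (the name `tc_tag_of` is illustrative, not part of the diff), it reads:

```python
def tc_tag_of(t_mid: float) -> str:
    # Mirror train.py's tag construction: fixed 4-decimal formatting,
    # strip trailing zeros and a trailing dot, then replace '.' with '_'
    # to get a filesystem-friendly directory suffix.
    return f"{t_mid:.4f}".rstrip("0").rstrip(".").replace(".", "_")

print(tc_tag_of(0.75))  # -> 0_75
```

So `--t-c 0.75` yields a `samples/t0_0_75/` directory, while integer times collapse to a bare tag such as `1`.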
back/train.sh ADDED
@@ -0,0 +1,43 @@
+ #!/usr/bin/env bash
+ # REG/train.py: as in the main repository, the data root and the directory of
+ # pre-extracted cls features can be set independently.
+ # Data layout: VAE latents under ${DATA_DIR}/imagenet_256_vae/;
+ # img-feature-*.npy + dataset.json under ${SEMANTIC_FEATURES_DIR}/ (matching parallel_encode).
+
+ NUM_GPUS=4
+
+ # ------------ Adjust these paths for your machine ------------
+ DATA_DIR="/gemini/space/zhaozy/dataset/Imagenet/imagenet_256"
+ SEMANTIC_FEATURES_DIR="/gemini/space/zhaozy/dataset/Imagenet/imagenet_256/imagenet_256_features/dinov2-vit-b_tmp/gpu0"
+
+ # Background launch example (same style as the main experiment scripts):
+ # nohup bash train.sh > jsflow-experiment.log 2>&1 &
+
+ nohup accelerate launch --multi_gpu --num_processes "${NUM_GPUS}" --mixed_precision bf16 train.py \
+     --report-to wandb \
+     --allow-tf32 \
+     --mixed-precision bf16 \
+     --seed 0 \
+     --path-type linear \
+     --prediction v \
+     --weighting uniform \
+     --model SiT-XL/2 \
+     --enc-type dinov2-vit-b \
+     --encoder-depth 8 \
+     --proj-coeff 0.5 \
+     --output-dir exps \
+     --exp-name jsflow-experiment-0.75 \
+     --batch-size 256 \
+     --data-dir "${DATA_DIR}" \
+     --semantic-features-dir "${SEMANTIC_FEATURES_DIR}" \
+     --learning-rate 0.00005 \
+     --t-c 0.75 \
+     --cls 0.05 \
+     --ot-cls \
+     > jsflow-experiment.log 2>&1 &
+
+ # Notes:
+ # - To extract DINO features online instead of using pre-extracted ones, drop
+ #   --semantic-features-dir and make sure data-dir follows the original REG layout
+ #   (imagenet_256_vae + vae-sd).
+ # - To disable minibatch OT, append --no-ot-cls.
+ # - --weight-ratio / --semantic-reg-coeff / --repa-* from the main repository's train.py
+ #   are not implemented in this REG script; use --proj-coeff for the projection strength
+ #   and --cls for the cls flow-loss weight.
back/utils.py ADDED
@@ -0,0 +1,225 @@
+ import os
+ from torchvision.datasets.utils import download_url
+ import torch
+ import torchvision.models as torchvision_models
+ import timm
+ from models import mocov3_vit
+ import math
+ import warnings
+
+
+ # code from SiT repository
+ pretrained_models = {'last.pt'}
+
+ def download_model(model_name):
+     """
+     Downloads a pre-trained SiT model from the web.
+     """
+     assert model_name in pretrained_models
+     local_path = f'pretrained_models/{model_name}'
+     if not os.path.isfile(local_path):
+         os.makedirs('pretrained_models', exist_ok=True)
+         web_path = 'https://www.dl.dropboxusercontent.com/scl/fi/cxedbs4da5ugjq5wg3zrg/last.pt?rlkey=8otgrdkno0nd89po3dpwngwcc&st=apcc645o&dl=0'
+         download_url(web_path, 'pretrained_models', filename=model_name)
+     model = torch.load(local_path, map_location=lambda storage, loc: storage)
+     return model
+
+ def fix_mocov3_state_dict(state_dict):
+     for k in list(state_dict.keys()):
+         # retain only base_encoder up to before the embedding layer
+         if k.startswith('module.base_encoder'):
+             # fix naming bug in checkpoint
+             new_k = k[len("module.base_encoder."):]
+             if "blocks.13.norm13" in new_k:
+                 new_k = new_k.replace("norm13", "norm1")
+             if "blocks.13.mlp.fc13" in new_k:
+                 new_k = new_k.replace("fc13", "fc1")
+             if "blocks.14.norm14" in new_k:
+                 new_k = new_k.replace("norm14", "norm2")
+             if "blocks.14.mlp.fc14" in new_k:
+                 new_k = new_k.replace("fc14", "fc2")
+             # remove prefix
+             if 'head' not in new_k and new_k.split('.')[0] != 'fc':
+                 state_dict[new_k] = state_dict[k]
+         # delete renamed or unused k
+         del state_dict[k]
+     if 'pos_embed' in state_dict.keys():
+         state_dict['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
+             state_dict['pos_embed'], [16, 16],
+         )
+     return state_dict
+
+ @torch.no_grad()
+ def load_encoders(enc_type, device, resolution=256):
+     assert (resolution == 256) or (resolution == 512)
+
+     enc_names = enc_type.split(',')
+     encoders, architectures, encoder_types = [], [], []
+     for enc_name in enc_names:
+         encoder_type, architecture, model_config = enc_name.split('-')
+         # Currently, we only support 512x512 experiments with DINOv2 encoders.
+         if resolution == 512:
+             if encoder_type != 'dinov2':
+                 raise NotImplementedError(
+                     "Currently, we only support 512x512 experiments with DINOv2 encoders."
+                 )
+
+         architectures.append(architecture)
+         encoder_types.append(encoder_type)
+         if encoder_type == 'mocov3':
+             if architecture == 'vit':
+                 if model_config == 's':
+                     encoder = mocov3_vit.vit_small()
+                 elif model_config == 'b':
+                     encoder = mocov3_vit.vit_base()
+                 elif model_config == 'l':
+                     encoder = mocov3_vit.vit_large()
+                 ckpt = torch.load(f'./ckpts/mocov3_vit{model_config}.pth')
+                 state_dict = fix_mocov3_state_dict(ckpt['state_dict'])
+                 del encoder.head
+                 encoder.load_state_dict(state_dict, strict=True)
+                 encoder.head = torch.nn.Identity()
+             elif architecture == 'resnet':
+                 raise NotImplementedError()
+
+             encoder = encoder.to(device)
+             encoder.eval()
+
+         elif 'dinov2' in encoder_type:
+             if 'reg' in encoder_type:
+                 try:
+                     encoder = torch.hub.load('your_path/.cache/torch/hub/facebookresearch_dinov2_main',
+                                              f'dinov2_vit{model_config}14_reg', source='local')
+                 except Exception:
+                     encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14_reg')
+             else:
+                 try:
+                     encoder = torch.hub.load('your_path/.cache/torch/hub/facebookresearch_dinov2_main',
+                                              f'dinov2_vit{model_config}14', source='local')
+                 except Exception:
+                     encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14')
+
+             print(f"Now using {enc_name} as the alignment model")
+             del encoder.head
+             patch_resolution = 16 * (resolution // 256)
+             encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed(
+                 encoder.pos_embed.data, [patch_resolution, patch_resolution],
+             )
+             encoder.head = torch.nn.Identity()
+             encoder = encoder.to(device)
+             encoder.eval()
+
+         elif 'dinov1' == encoder_type:
+             from models import dinov1
+             encoder = dinov1.vit_base()
+             ckpt = torch.load(f'./ckpts/dinov1_vit{model_config}.pth')
+             if 'pos_embed' in ckpt.keys():
+                 ckpt['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
+                     ckpt['pos_embed'], [16, 16],
+                 )
+             del encoder.head
+             encoder.head = torch.nn.Identity()
+             encoder.load_state_dict(ckpt, strict=True)
+             encoder = encoder.to(device)
+             encoder.forward_features = encoder.forward
+             encoder.eval()
+
+         elif encoder_type == 'clip':
+             import clip
+             from models.clip_vit import UpdatedVisionTransformer
+             encoder_ = clip.load(f"ViT-{model_config}/14", device='cpu')[0].visual
+             encoder = UpdatedVisionTransformer(encoder_).to(device)
+             encoder.embed_dim = encoder.model.transformer.width
+             encoder.forward_features = encoder.forward
+             encoder.eval()
+
+         elif encoder_type == 'mae':
+             from models.mae_vit import vit_large_patch16
+             kwargs = dict(img_size=256)
+             encoder = vit_large_patch16(**kwargs).to(device)
+             with open(f"ckpts/mae_vit{model_config}.pth", "rb") as f:
+                 state_dict = torch.load(f)
+             if 'pos_embed' in state_dict["model"].keys():
+                 state_dict["model"]['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
+                     state_dict["model"]['pos_embed'], [16, 16],
+                 )
+             encoder.load_state_dict(state_dict["model"])
+
+             encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed(
+                 encoder.pos_embed.data, [16, 16],
+             )
+
+         elif encoder_type == 'jepa':
+             from models.jepa import vit_huge
+             kwargs = dict(img_size=[224, 224], patch_size=14)
+             encoder = vit_huge(**kwargs).to(device)
+             with open(f"ckpts/ijepa_vit{model_config}.pth", "rb") as f:
+                 state_dict = torch.load(f, map_location=device)
+             new_state_dict = dict()
+             for key, value in state_dict['encoder'].items():
+                 new_state_dict[key[7:]] = value
+             encoder.load_state_dict(new_state_dict)
+             encoder.forward_features = encoder.forward
+
+         encoders.append(encoder)
+
+     return encoders, encoder_types, architectures
+
+
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+     # Cut & paste from PyTorch official master until it's in a few official releases - RW
+     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+     def norm_cdf(x):
+         # Computes standard normal cumulative distribution function
+         return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+     if (mean < a - 2 * std) or (mean > b + 2 * std):
+         warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+                       "The distribution of values may be incorrect.",
+                       stacklevel=2)
+
+     with torch.no_grad():
+         # Values are generated by using a truncated uniform distribution and
+         # then using the inverse CDF for the normal distribution.
+         # Get upper and lower cdf values
+         l = norm_cdf((a - mean) / std)
+         u = norm_cdf((b - mean) / std)
+
+         # Uniformly fill tensor with values from [l, u], then translate to
+         # [2l-1, 2u-1].
+         tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+         # Use inverse cdf transform for normal distribution to get truncated
+         # standard normal
+         tensor.erfinv_()
+
+         # Transform to proper mean, std
+         tensor.mul_(std * math.sqrt(2.))
+         tensor.add_(mean)
+
+         # Clamp to ensure it's in the proper range
+         tensor.clamp_(min=a, max=b)
+         return tensor
+
+
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+ def load_legacy_checkpoints(state_dict, encoder_depth):
+     new_state_dict = dict()
+     for key, value in state_dict.items():
+         if 'decoder_blocks' in key:
+             parts = key.split('.')
+             new_idx = int(parts[1]) + encoder_depth
+             parts[0] = 'blocks'
+             parts[1] = str(new_idx)
+             new_key = '.'.join(parts)
+             new_state_dict[new_key] = value
+         else:
+             new_state_dict[key] = value
+     return new_state_dict
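The minibatch-OT pairing that train.py's `--ot-cls` flag toggles lives in the loss function, which is not part of this diff. As a rough sketch of the idea only, assuming a squared-Euclidean cost and scipy's Hungarian solver (the helper name `ot_pair_cls` is hypothetical, not the repository's API):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def ot_pair_cls(cls_noise: np.ndarray, cls_gt: np.ndarray) -> np.ndarray:
    """Reorder Gaussian cls noise so that row i is OT-matched to cls_gt row i."""
    # Squared-Euclidean cost between every (noise, ground-truth) pair.
    diff = cls_noise[:, None, :] - cls_gt[None, :, :]
    cost = (diff ** 2).sum(-1)
    # Hungarian algorithm: minimum-cost one-to-one assignment.
    row_ind, col_ind = linear_sum_assignment(cost)
    # perm[j] = index of the noise row assigned to ground-truth row j.
    perm = np.empty(len(col_ind), dtype=np.int64)
    perm[col_ind] = row_ind
    return cls_noise[perm]
```

After this reordering, interpolating the paired noise toward `cls_gt` for t in (t_c, 1] gives the OT-paired cls paths described in the `--t-c` help text; disabling `--ot-cls` corresponds to skipping the permutation.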