dg845 committed on
Commit
10ccafc
1 Parent(s): f6adb30

Upload 9 files

Browse files

Add unidiffuser original code since I can't figure out how to package it correctly

dpm_solver_pp.py ADDED
@@ -0,0 +1,952 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ import numpy as np
5
+ import torch.distributed as dist
6
+
7
+
8
+ def interpolate_fn(x: torch.Tensor, xp: torch.Tensor, yp: torch.Tensor) -> torch.Tensor:
9
+ """Performs piecewise linear interpolation for x, using xp and yp keypoints (knots).
10
+ Performs separate interpolation for each channel.
11
+ Args:
12
+ x: [N, C] points to be calibrated (interpolated). Batch with C channels.
13
+ xp: [C, K] x coordinates of the PWL knots. C is the number of channels, K is the number of knots.
14
+ yp: [C, K] y coordinates of the PWL knots. C is the number of channels, K is the number of knots.
15
+ Returns:
16
+ Interpolated points of the shape [N, C].
17
+ The piecewise linear function extends for the whole x axis (the outermost keypoints define the outermost
18
+ infinite lines).
19
+ For example:
20
+ >>> interpolate_fn(torch.tensor([[0.5]]), torch.tensor([[0.0, 1.0]]), torch.tensor([[0.0, 2.0]]))
21
+ tensor([[1.0000]])
22
+ >>> interpolate_fn(torch.tensor([[-10]]), torch.tensor([[0.0, 1.0]]), torch.tensor([[0.0, 2.0]]))
23
+ tensor([[-20.0000]])
24
+ """
25
+ x_breakpoints = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((x.shape[0], 1, 1))], dim=2)
26
+ num_x_points = xp.shape[1]
27
+ sorted_x_breakpoints, x_indices = torch.sort(x_breakpoints, dim=2)
28
+ x_idx = torch.argmin(x_indices, dim=2)
29
+ cand_start_idx = x_idx - 1
30
+ start_idx = torch.where(
31
+ torch.eq(x_idx, 0),
32
+ torch.tensor(1, device=x.device),
33
+ torch.where(
34
+ torch.eq(x_idx, num_x_points), torch.tensor(num_x_points - 2, device=x.device), cand_start_idx,
35
+ ),
36
+ )
37
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
38
+ start_x = torch.gather(sorted_x_breakpoints, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
39
+ end_x = torch.gather(sorted_x_breakpoints, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
40
+ start_idx2 = torch.where(
41
+ torch.eq(x_idx, 0),
42
+ torch.tensor(0, device=x.device),
43
+ torch.where(
44
+ torch.eq(x_idx, num_x_points), torch.tensor(num_x_points - 2, device=x.device), cand_start_idx,
45
+ ),
46
+ )
47
+ y_positions_expanded = yp.unsqueeze(0).expand(x.shape[0], -1, -1)
48
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
49
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
50
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
51
+ return cand
52
+
53
+
54
+ class NoiseScheduleVP:
55
+ def __init__(self, schedule='discrete', beta_0=1e-4, beta_1=2e-2, total_N=1000, betas=None, alphas_cumprod=None):
56
+ """Create a wrapper class for the forward SDE (VP type).
57
+
58
+ The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
59
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
60
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
61
+
62
+ log_alpha_t = self.marginal_log_mean_coeff(t)
63
+ sigma_t = self.marginal_std(t)
64
+ lambda_t = self.marginal_lambda(t)
65
+
66
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
67
+
68
+ t = self.inverse_lambda(lambda_t)
69
+
70
+ ===============================================================
71
+
72
+ We support three types of VPSDEs: discrete (given betas or alphas_cumprod), linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
73
+ schedule are the default settings in DDPM and improved-DDPM:
74
+
75
+ beta_0: A `float` number. The smallest beta for the linear schedule.
76
+ beta_1: A `float` number. The largest beta for the linear schedule.
77
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
78
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
79
+ T: A `float` number. The ending time of the forward process.
80
+
81
+ Note that the original DDPM (linear schedule) used the discrete-time label (0 to 999). We convert the discrete-time
82
+ label to continuous time (following Song et al., 2021), so the beta here is 1000x larger than in DDPM.
83
+
84
+ ===============================================================
85
+
86
+ Args:
87
+ schedule: A `str`. The noise schedule of the forward SDE ('discrete', 'linear' or 'cosine').
88
+
89
+ Returns:
90
+ A wrapper object of the forward SDE (VP type).
91
+ """
92
+ if schedule not in ['linear', 'discrete', 'cosine']:
93
+ raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete', 'linear' or 'cosine'".format(schedule))
94
+ self.total_N = total_N
95
+ self.beta_0 = beta_0 * 1000.
96
+ self.beta_1 = beta_1 * 1000.
97
+
98
+ if schedule == 'discrete':
99
+ if betas is not None:
100
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
101
+ else:
102
+ assert alphas_cumprod is not None
103
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
104
+ self.total_N = len(log_alphas)
105
+ self.t_discrete = torch.linspace(1. / self.total_N, 1., self.total_N).reshape((1, -1))
106
+ self.log_alpha_discrete = log_alphas.reshape((1, -1))
107
+
108
+ self.cosine_s = 0.008
109
+ self.cosine_beta_max = 999.
110
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
111
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
112
+ self.schedule = schedule
113
+ if schedule == 'cosine':
114
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
115
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
116
+ self.T = 0.9946
117
+ else:
118
+ self.T = 1.
119
+
120
+ def marginal_log_mean_coeff(self, t):
121
+ """
122
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
123
+ """
124
+ if self.schedule == 'linear':
125
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
126
+ elif self.schedule == 'discrete':
127
+ return interpolate_fn(t.reshape((-1, 1)), self.t_discrete.clone().to(t.device), self.log_alpha_discrete.clone().to(t.device)).reshape((-1,))
128
+ elif self.schedule == 'cosine':
129
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
130
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
131
+ return log_alpha_t
132
+ else:
133
+ raise ValueError("Unsupported noise schedule {}".format(self.schedule))
134
+
135
+ def marginal_alpha(self, t):
136
+ return torch.exp(self.marginal_log_mean_coeff(t))
137
+
138
+ def marginal_std(self, t):
139
+ """
140
+ Compute sigma_t of a given continuous-time label t in [0, T].
141
+ """
142
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
143
+
144
+ def marginal_lambda(self, t):
145
+ """
146
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
147
+ """
148
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
149
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
150
+ return log_mean_coeff - log_std
151
+
152
+ def inverse_lambda(self, lamb):
153
+ """
154
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
155
+ """
156
+ if self.schedule == 'linear':
157
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
158
+ Delta = self.beta_0**2 + tmp
159
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
160
+ elif self.schedule == 'discrete':
161
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
162
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_discrete.clone().to(lamb.device), [1]), torch.flip(self.t_discrete.clone().to(lamb.device), [1]))
163
+ return t.reshape((-1,))
164
+ else:
165
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
166
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
167
+ t = t_fn(log_alpha)
168
+ return t
169
+
170
+
171
+ def model_wrapper(model, noise_schedule=None, is_cond_classifier=False, classifier_fn=None, classifier_scale=1., time_input_type='1', total_N=1000, model_kwargs={}, is_deis=False):
172
+ """Create a wrapper function for the noise prediction model.
173
+
174
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
175
+ first wrap the model function into a function that accepts continuous time as the input.
176
+
177
+ The input `model` has the following format:
178
+
179
+ ``
180
+ model(x, t_input, **model_kwargs) -> noise
181
+ ``
182
+
183
+ where `x` and `noise` have the same shape, and `t_input` is the time label of the model.
184
+ (may be discrete-time labels (i.e. 0 to 999) or continuous-time labels (i.e. epsilon to T).)
185
+
186
+ We wrap the model function to the following format:
187
+
188
+ ``
189
+ def model_fn(x, t_continuous) -> noise:
190
+ t_input = get_model_input_time(t_continuous)
191
+ return model(x, t_input, **model_kwargs)
192
+ ``
193
+
194
+ where `t_continuous` is the continuous-time label (i.e. epsilon to T), and we use `model_fn` for DPM-Solver.
195
+
196
+ For DPMs with classifier guidance, we also combine the model output with the classifier gradient as used in [1].
197
+
198
+ [1] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," in Advances in Neural
199
+ Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
200
+
201
+ ===============================================================
202
+
203
+ Args:
204
+ model: A noise prediction model with the following format:
205
+ ``
206
+ def model(x, t_input, **model_kwargs):
207
+ return noise
208
+ ``
209
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP. Only used for the classifier guidance.
210
+ is_cond_classifier: A `bool`. Whether to use the classifier guidance.
211
+ classifier_fn: A classifier function. Only used for the classifier guidance. The format is:
212
+ ``
213
+ def classifier_fn(x, t_input):
214
+ return logits
215
+ ``
216
+ classifier_scale: A `float`. The scale for the classifier guidance.
217
+ time_input_type: A `str`. The type for the time input of the model. We support three types:
218
+ - '0': The continuous-time type. In this case, the model is trained on the continuous time,
219
+ so `t_input` = `t_continuous`.
220
+ - '1': The Type-1 discrete type described in the Appendix of DPM-Solver paper.
221
+ **For discrete-time DPMs, we recommend to use this type for DPM-Solver**.
222
+ - '2': The Type-2 discrete type described in the Appendix of DPM-Solver paper.
223
+ total_N: An `int`. The total number of discrete time steps of the DPM (default is 1000), used when `time_input_type`
224
+ is '1' or '2'.
225
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
226
+ Returns:
227
+ A function that accepts the continuous time as the input, with the following format:
228
+ ``
229
+ def model_fn(x, t_continuous):
230
+ t_input = get_model_input_time(t_continuous)
231
+ return model(x, t_input, **model_kwargs)
232
+ ``
233
+ """
234
+ def get_model_input_time(t_continuous):
235
+ """
236
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
237
+ """
238
+ if time_input_type == '0':
239
+ # time_input_type == '0' means that the model is a continuous-time model.
+ # For continuous-time DPMs, the model input time equals the continuous time.
241
+ return t_continuous
242
+ elif time_input_type == '1':
243
+ # Type-1 discrete label, as detailed in the Appendix of DPM-Solver.
244
+ return 1000. * torch.max(t_continuous - 1. / total_N, torch.zeros_like(t_continuous).to(t_continuous))
245
+ elif time_input_type == '2':
246
+ # Type-2 discrete label, as detailed in the Appendix of DPM-Solver.
247
+ max_N = (total_N - 1) / total_N * 1000.
248
+ return max_N * t_continuous
249
+ else:
250
+ raise ValueError("Unsupported time input type {}, must be '0' or '1' or '2'".format(time_input_type))
251
+
252
+ def cond_fn(x, t_discrete, y):
253
+ """
254
+ Compute the gradient of the classifier, multiplied by the scale of the classifier guidance.
255
+ """
256
+ assert y is not None
257
+ with torch.enable_grad():
258
+ x_in = x.detach().requires_grad_(True)
259
+ logits = classifier_fn(x_in, t_discrete)
260
+ log_probs = F.log_softmax(logits, dim=-1)
261
+ selected = log_probs[range(len(logits)), y.view(-1)]
262
+ return classifier_scale * torch.autograd.grad(selected.sum(), x_in)[0]
263
+
264
+ def model_fn(x, t_continuous):
265
+ """
266
+ The noise prediction model function that is used for DPM-Solver.
267
+ """
268
+ if t_continuous.reshape((-1,)).shape[0] == 1:
269
+ t_continuous = torch.ones((x.shape[0],)).to(x.device) * t_continuous
270
+ if is_cond_classifier:
271
+ y = model_kwargs.get("y", None)
272
+ if y is None:
273
+ raise ValueError("For classifier guidance, the label y has to be in the input.")
274
+ t_discrete = get_model_input_time(t_continuous)
275
+ noise_uncond = model(x, t_discrete, **model_kwargs)
276
+ cond_grad = cond_fn(x, t_discrete, y)
277
+ if is_deis:
278
+ sigma_t = noise_schedule.marginal_std(t_continuous / 1000.)
279
+ else:
280
+ sigma_t = noise_schedule.marginal_std(t_continuous)
281
+ dims = len(cond_grad.shape) - 1
282
+ return noise_uncond - sigma_t[(...,) + (None,)*dims] * cond_grad
283
+ else:
284
+ t_discrete = get_model_input_time(t_continuous)
285
+ return model(x, t_discrete, **model_kwargs)
286
+
287
+ return model_fn
288
+
289
+
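A minimal usage sketch for the wrapper above, assuming a hypothetical discrete-time noise-prediction network `my_unet` trained with 1000 DDPM steps (the network and the beta schedule below are placeholders, not part of this commit):

import torch

betas = torch.linspace(1e-4, 2e-2, 1000)                        # assumed linear DDPM betas
noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)

def my_unet(x, t_input):                                        # placeholder epsilon model
    return torch.zeros_like(x)

# Wrap the discrete-time model into a continuous-time noise prediction function.
model_fn = model_wrapper(my_unet, noise_schedule=noise_schedule,
                         time_input_type='1', total_N=1000)

x = torch.randn(4, 3, 32, 32)
t_continuous = torch.full((4,), 0.5)
eps_pred = model_fn(x, t_continuous)                            # same shape as x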
290
+ class DPM_Solver:
291
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
292
+ """Construct a DPM-Solver.
293
+
294
+ Args:
295
+ model_fn: A noise prediction model function which accepts the continuous-time input
296
+ (t in [epsilon, T]):
297
+ ``
298
+ def model_fn(x, t_continuous):
299
+ return noise
300
+ ``
301
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
302
+ """
303
+ self.model = model_fn
304
+ self.noise_schedule = noise_schedule
305
+ self.predict_x0 = predict_x0
306
+ self.thresholding = thresholding
307
+ self.max_val = max_val
308
+
309
+ def model_fn(self, x, t):
310
+ if self.predict_x0:
311
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
312
+ noise = self.model(x, t)
313
+ dims = len(x.shape) - 1
314
+ x0 = (x - sigma_t[(...,) + (None,)*dims] * noise) / alpha_t[(...,) + (None,)*dims]
315
+ if self.thresholding:
316
+ p = 0.995
317
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
318
+ s = torch.maximum(s, torch.ones_like(s).to(s.device))[(...,) + (None,)*dims]
319
+ x0 = torch.clamp(x0, -s, s) / (s / self.max_val)
320
+ return x0
321
+ else:
322
+ return self.model(x, t)
323
+
324
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
325
+ """Compute the intermediate time steps for sampling.
326
+
327
+ Args:
328
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
329
+ - 'logSNR': uniform logSNR for the time steps, **recommended for DPM-Solver**.
330
+ - 'time_uniform': uniform time for the time steps. (Used in DDIM and DDPM.)
331
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
332
+ t_T: A `float`. The starting time of the sampling (default is T).
333
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
334
+ N: An `int`. The number of time-step intervals (the returned tensor has N + 1 points).
335
+ device: A torch device.
336
+ Returns:
337
+ A pytorch tensor of the time steps, with the shape (N + 1,).
338
+ """
339
+ if skip_type == 'logSNR':
340
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
341
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
342
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
343
+ # print(torch.min(torch.abs(logSNR_steps - self.noise_schedule.marginal_lambda(self.noise_schedule.inverse_lambda(logSNR_steps)))).item())
344
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
345
+ elif skip_type == 't2':
346
+ t_order = 2
347
+ t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
348
+ return t
349
+ elif skip_type == 'time_uniform':
350
+ return torch.linspace(t_T, t_0, N + 1).to(device)
351
+ elif skip_type == 'time_quadratic':
352
+ t = torch.linspace(t_0, t_T, 10000000).to(device)
353
+ quadratic_t = torch.sqrt(t)
354
+ quadratic_steps = torch.linspace(quadratic_t[0], quadratic_t[-1], N + 1).to(device)
355
+ return torch.flip(torch.cat([t[torch.searchsorted(quadratic_t, quadratic_steps)[:-1]], t_T * torch.ones((1,)).to(device)], dim=0), dims=[0])
356
+ else:
357
+ raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
358
+
359
+ def get_time_steps_for_dpm_solver_fast(self, skip_type, t_T, t_0, steps, order, device):
360
+ """
361
+ Compute the intermediate time steps and the order of each step for sampling by DPM-Solver-fast.
362
+
363
+ We recommend DPM-Solver-fast for fast sampling of DPMs. Given a fixed number of function evaluations by `steps`,
364
+ the sampling procedure by DPM-Solver-fast is:
365
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
366
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
367
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
368
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
369
+
370
+ ============================================
371
+ Args:
372
+ t_T: A `float`. The starting time of the sampling (default is T).
373
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
374
+ steps: A `int`. The total number of function evaluations (NFE).
375
+ device: A torch device.
376
+ Returns:
377
+ orders: A list of the solver order of each step.
378
+ timesteps: A pytorch tensor of the time steps, with the shape of (K + 1,).
379
+ """
380
+ if order == 3:
381
+ K = steps // 3 + 1
382
+ if steps % 3 == 0:
383
+ orders = [3,] * (K - 2) + [2, 1]
384
+ elif steps % 3 == 1:
385
+ orders = [3,] * (K - 1) + [1]
386
+ else:
387
+ orders = [3,] * (K - 1) + [2]
388
+ timesteps = self.get_time_steps(skip_type, t_T, t_0, K, device)
389
+ return orders, timesteps
390
+ elif order == 2:
391
+ K = steps // 2
392
+ if steps % 2 == 0:
393
+ orders = [2,] * K
394
+ else:
395
+ orders = [2,] * K + [1]
396
+ timesteps = self.get_time_steps(skip_type, t_T, t_0, K, device)
397
+ return orders, timesteps
398
+ else:
399
+ raise ValueError("order must be 2 or 3, got {}".format(order))
400
+
401
+ def denoise_fn(self, x, s, noise_s=None):
402
+ ns = self.noise_schedule
403
+ dims = len(x.shape) - 1
404
+ log_alpha_s = ns.marginal_log_mean_coeff(s)
405
+ sigma_s = ns.marginal_std(s)
406
+
407
+ if noise_s is None:
408
+ noise_s = self.model_fn(x, s)
409
+ x_0 = (
410
+ (x - sigma_s[(...,) + (None,)*dims] * noise_s) / torch.exp(log_alpha_s)[(...,) + (None,)*dims]
411
+ )
412
+ return x_0
413
+
414
+ def dpm_solver_first_update(self, x, s, t, noise_s=None, return_noise=False):
415
+ """
416
+ A single step for DPM-Solver-1.
417
+
418
+ Args:
419
+ x: A pytorch tensor. The initial value at time `s`.
420
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
421
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
422
+ return_noise: A `bool`. If true, also return the predicted noise at time `s`.
423
+ Returns:
424
+ x_t: A pytorch tensor. The approximated solution at time `t`.
425
+ """
426
+ ns = self.noise_schedule
427
+ dims = len(x.shape) - 1
428
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
429
+ h = lambda_t - lambda_s
430
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
431
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
432
+ alpha_t = torch.exp(log_alpha_t)
433
+
434
+ if self.predict_x0:
435
+ phi_1 = (torch.exp(-h) - 1.) / (-1.)
436
+ if noise_s is None:
437
+ noise_s = self.model_fn(x, s)
438
+ x_t = (
439
+ (sigma_t / sigma_s)[(...,) + (None,)*dims] * x
440
+ + (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
441
+ )
442
+ if return_noise:
443
+ return x_t, {'noise_s': noise_s}
444
+ else:
445
+ return x_t
446
+ else:
447
+ phi_1 = torch.expm1(h)
448
+ if noise_s is None:
449
+ noise_s = self.model_fn(x, s)
450
+ x_t = (
451
+ torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
452
+ - (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
453
+ )
454
+ if return_noise:
455
+ return x_t, {'noise_s': noise_s}
456
+ else:
457
+ return x_t
458
+
459
+ def dpm_solver_second_update(self, x, s, t, r1=0.5, noise_s=None, return_noise=False, solver_type='dpm_solver'):
460
+ """
461
+ A single step for DPM-Solver-2.
462
+
463
+ Args:
464
+ x: A pytorch tensor. The initial value at time `s`.
465
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
466
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
467
+ r1: A `float`. The hyperparameter of the second-order solver. We recommend the default setting `0.5`.
468
+ noise_s: A pytorch tensor. The predicted noise at time `s`.
469
+ If `noise_s` is None, we compute the predicted noise by `x` and `s`; otherwise we directly use it.
470
+ return_noise: A `bool`. If true, also return the predicted noise at time `s` and `s1` (the intermediate time).
471
+ Returns:
472
+ x_t: A pytorch tensor. The approximated solution at time `t`.
473
+ """
474
+ if r1 is None:
475
+ r1 = 0.5
476
+ ns = self.noise_schedule
477
+ dims = len(x.shape) - 1
478
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
479
+ h = lambda_t - lambda_s
480
+ lambda_s1 = lambda_s + r1 * h
481
+ s1 = ns.inverse_lambda(lambda_s1)
482
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t)
483
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
484
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
485
+
486
+ if self.predict_x0:
487
+ phi_11 = torch.expm1(-r1 * h)
488
+ phi_1 = torch.expm1(-h)
489
+
490
+ if noise_s is None:
491
+ noise_s = self.model_fn(x, s)
492
+ x_s1 = (
493
+ (sigma_s1 / sigma_s)[(...,) + (None,)*dims] * x
494
+ - (alpha_s1 * phi_11)[(...,) + (None,)*dims] * noise_s
495
+ )
496
+ noise_s1 = self.model_fn(x_s1, s1)
497
+ if solver_type == 'dpm_solver':
498
+ x_t = (
499
+ (sigma_t / sigma_s)[(...,) + (None,)*dims] * x
500
+ - (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
501
+ - (0.5 / r1) * (alpha_t * phi_1)[(...,) + (None,)*dims] * (noise_s1 - noise_s)
502
+ )
503
+ elif solver_type == 'taylor':
504
+ x_t = (
505
+ (sigma_t / sigma_s)[(...,) + (None,)*dims] * x
506
+ - (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
507
+ + (1. / r1) * (alpha_t * ((torch.exp(-h) - 1.) / h + 1.))[(...,) + (None,)*dims] * (noise_s1 - noise_s)
508
+ )
509
+ else:
510
+ raise ValueError("solver_type must be either dpm_solver or taylor, got {}".format(solver_type))
511
+ else:
512
+ phi_11 = torch.expm1(r1 * h)
513
+ phi_1 = torch.expm1(h)
514
+
515
+ if noise_s is None:
516
+ noise_s = self.model_fn(x, s)
517
+ x_s1 = (
518
+ torch.exp(log_alpha_s1 - log_alpha_s)[(...,) + (None,)*dims] * x
519
+ - (sigma_s1 * phi_11)[(...,) + (None,)*dims] * noise_s
520
+ )
521
+ noise_s1 = self.model_fn(x_s1, s1)
522
+ if solver_type == 'dpm_solver':
523
+ x_t = (
524
+ torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
525
+ - (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
526
+ - (0.5 / r1) * (sigma_t * phi_1)[(...,) + (None,)*dims] * (noise_s1 - noise_s)
527
+ )
528
+ elif solver_type == 'taylor':
529
+ x_t = (
530
+ torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
531
+ - (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
532
+ - (1. / r1) * (sigma_t * ((torch.exp(h) - 1.) / h - 1.))[(...,) + (None,)*dims] * (noise_s1 - noise_s)
533
+ )
534
+ else:
535
+ raise ValueError("solver_type must be either dpm_solver or taylor, got {}".format(solver_type))
536
+ if return_noise:
537
+ return x_t, {'noise_s': noise_s, 'noise_s1': noise_s1}
538
+ else:
539
+ return x_t
540
+
541
+
542
+ def dpm_multistep_second_update(self, x, noise_prev_list, t_prev_list, t, solver_type="dpm_solver"):
543
+ ns = self.noise_schedule
544
+ dims = len(x.shape) - 1
545
+ noise_prev_1, noise_prev_0 = noise_prev_list
546
+ t_prev_1, t_prev_0 = t_prev_list
547
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
548
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
549
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
550
+ alpha_t = torch.exp(log_alpha_t)
551
+
552
+ h_0 = lambda_prev_0 - lambda_prev_1
553
+ h = lambda_t - lambda_prev_0
554
+ r0 = h_0 / h
555
+ D1_0 = (1. / r0)[(...,) + (None,)*dims] * (noise_prev_0 - noise_prev_1)
556
+ if self.predict_x0:
557
+ if solver_type == 'taylor':
558
+ x_t = (
559
+ (sigma_t / sigma_prev_0)[(...,) + (None,)*dims] * x
560
+ - (alpha_t * (torch.exp(-h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
561
+ + (alpha_t * ((torch.exp(-h) - 1.) / h + 1.))[(...,) + (None,)*dims] * D1_0
562
+ )
563
+ elif solver_type == 'dpm_solver':
564
+ x_t = (
565
+ (sigma_t / sigma_prev_0)[(...,) + (None,)*dims] * x
566
+ - (alpha_t * (torch.exp(-h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
567
+ - 0.5 * (alpha_t * (torch.exp(-h) - 1.))[(...,) + (None,)*dims] * D1_0
568
+ )
569
+ else:
570
+ if solver_type == 'taylor':
571
+ x_t = (
572
+ torch.exp(log_alpha_t - log_alpha_prev_0)[(...,) + (None,)*dims] * x
573
+ - (sigma_t * (torch.exp(h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
574
+ - (sigma_t * ((torch.exp(h) - 1.) / h - 1.))[(...,) + (None,)*dims] * D1_0
575
+ )
576
+ elif solver_type == 'dpm_solver':
577
+ x_t = (
578
+ torch.exp(log_alpha_t - log_alpha_prev_0)[(...,) + (None,)*dims] * x
579
+ - (sigma_t * (torch.exp(h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
580
+ - 0.5 * (sigma_t * (torch.exp(h) - 1.))[(...,) + (None,)*dims] * D1_0
581
+ )
582
+ return x_t
583
+
584
+
585
+ def dpm_multistep_third_update(self, x, noise_prev_list, t_prev_list, t, solver_type='dpm_solver'):
586
+ ns = self.noise_schedule
587
+ dims = len(x.shape) - 1
588
+ noise_prev_2, noise_prev_1, noise_prev_0 = noise_prev_list
589
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
590
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
591
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
592
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
593
+ alpha_t = torch.exp(log_alpha_t)
594
+
595
+ h_1 = lambda_prev_1 - lambda_prev_2
596
+ h_0 = lambda_prev_0 - lambda_prev_1
597
+ h = lambda_t - lambda_prev_0
598
+ r0, r1 = h_0 / h, h_1 / h
599
+ D1_0 = (1. / r0)[(...,) + (None,)*dims] * (noise_prev_0 - noise_prev_1)
600
+ D1_1 = (1. / r1)[(...,) + (None,)*dims] * (noise_prev_1 - noise_prev_2)
601
+ D1 = D1_0 + (r0 / (r0 + r1))[(...,) + (None,)*dims] * (D1_0 - D1_1)
602
+ D2 = (1. / (r0 + r1))[(...,) + (None,)*dims] * (D1_0 - D1_1)
603
+ if self.predict_x0:
604
+ x_t = (
605
+ (sigma_t / sigma_prev_0)[(...,) + (None,)*dims] * x
606
+ - (alpha_t * (torch.exp(-h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
607
+ + (alpha_t * ((torch.exp(-h) - 1.) / h + 1.))[(...,) + (None,)*dims] * D1
608
+ - (alpha_t * ((torch.exp(-h) - 1. + h) / h**2 - 0.5))[(...,) + (None,)*dims] * D2
609
+ )
610
+ else:
611
+ x_t = (
612
+ torch.exp(log_alpha_t - log_alpha_prev_0)[(...,) + (None,)*dims] * x
613
+ - (sigma_t * (torch.exp(h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
614
+ - (sigma_t * ((torch.exp(h) - 1.) / h - 1.))[(...,) + (None,)*dims] * D1
615
+ - (sigma_t * ((torch.exp(h) - 1. - h) / h**2 - 0.5))[(...,) + (None,)*dims] * D2
616
+ )
617
+ return x_t
618
+
619
+ def dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., noise_s=None, noise_s1=None, noise_s2=None, return_noise=False, solver_type='dpm_solver'):
620
+ """
621
+ A single step for DPM-Solver-3.
622
+
623
+ Args:
624
+ x: A pytorch tensor. The initial value at time `s`.
625
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
626
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
627
+ r1: A `float`. The hyperparameter of the third-order solver. We recommend the default setting `1 / 3`.
628
+ r2: A `float`. The hyperparameter of the third-order solver. We recommend the default setting `2 / 3`.
629
+ noise_s: A pytorch tensor. The predicted noise at time `s`.
630
+ If `noise_s` is None, we compute the predicted noise by `x` and `s`; otherwise we directly use it.
631
+ noise_s1: A pytorch tensor. The predicted noise at time `s1` (the intermediate time given by `r1`).
632
+ If `noise_s1` is None, we compute the predicted noise by `s1`; otherwise we directly use it.
633
+ Returns:
634
+ x_t: A pytorch tensor. The approximated solution at time `t`.
635
+ """
636
+ if r1 is None:
637
+ r1 = 1. / 3.
638
+ if r2 is None:
639
+ r2 = 2. / 3.
640
+ ns = self.noise_schedule
641
+ dims = len(x.shape) - 1
642
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
643
+ h = lambda_t - lambda_s
644
+ lambda_s1 = lambda_s + r1 * h
645
+ lambda_s2 = lambda_s + r2 * h
646
+ s1 = ns.inverse_lambda(lambda_s1)
647
+ s2 = ns.inverse_lambda(lambda_s2)
648
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
649
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t)
650
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
651
+
652
+ if self.predict_x0:
653
+ phi_11 = torch.expm1(-r1 * h)
654
+ phi_12 = torch.expm1(-r2 * h)
655
+ phi_1 = torch.expm1(-h)
656
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
657
+ phi_2 = phi_1 / h + 1.
658
+ phi_3 = phi_2 / h - 0.5
659
+
660
+ if noise_s is None:
661
+ noise_s = self.model_fn(x, s)
662
+ if noise_s1 is None:
663
+ x_s1 = (
664
+ (sigma_s1 / sigma_s)[(...,) + (None,)*dims] * x
665
+ - (alpha_s1 * phi_11)[(...,) + (None,)*dims] * noise_s
666
+ )
667
+ noise_s1 = self.model_fn(x_s1, s1)
668
+ if noise_s2 is None:
669
+ x_s2 = (
670
+ (sigma_s2 / sigma_s)[(...,) + (None,)*dims] * x
671
+ - (alpha_s2 * phi_12)[(...,) + (None,)*dims] * noise_s
672
+ + r2 / r1 * (alpha_s2 * phi_22)[(...,) + (None,)*dims] * (noise_s1 - noise_s)
673
+ )
674
+ noise_s2 = self.model_fn(x_s2, s2)
675
+ if solver_type == 'dpm_solver':
676
+ x_t = (
677
+ (sigma_t / sigma_s)[(...,) + (None,)*dims] * x
678
+ - (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
679
+ + (1. / r2) * (alpha_t * phi_2)[(...,) + (None,)*dims] * (noise_s2 - noise_s)
680
+ )
681
+ elif solver_type == 'taylor':
682
+ D1_0 = (1. / r1) * (noise_s1 - noise_s)
683
+ D1_1 = (1. / r2) * (noise_s2 - noise_s)
684
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
685
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
686
+ x_t = (
687
+ (sigma_t / sigma_s)[(...,) + (None,)*dims] * x
688
+ - (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
689
+ + (alpha_t * phi_2)[(...,) + (None,)*dims] * D1
690
+ - (alpha_t * phi_3)[(...,) + (None,)*dims] * D2
691
+ )
692
+ else:
693
+ raise ValueError("solver_type must be either dpm_solver or taylor, got {}".format(solver_type))
694
+ else:
695
+ phi_11 = torch.expm1(r1 * h)
696
+ phi_12 = torch.expm1(r2 * h)
697
+ phi_1 = torch.expm1(h)
698
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
699
+ phi_2 = phi_1 / h - 1.
700
+ phi_3 = phi_2 / h - 0.5
701
+
702
+ if noise_s is None:
703
+ noise_s = self.model_fn(x, s)
704
+ if noise_s1 is None:
705
+ x_s1 = (
706
+ torch.exp(log_alpha_s1 - log_alpha_s)[(...,) + (None,)*dims] * x
707
+ - (sigma_s1 * phi_11)[(...,) + (None,)*dims] * noise_s
708
+ )
709
+ noise_s1 = self.model_fn(x_s1, s1)
710
+ if noise_s2 is None:
711
+ x_s2 = (
712
+ torch.exp(log_alpha_s2 - log_alpha_s)[(...,) + (None,)*dims] * x
713
+ - (sigma_s2 * phi_12)[(...,) + (None,)*dims] * noise_s
714
+ - r2 / r1 * (sigma_s2 * phi_22)[(...,) + (None,)*dims] * (noise_s1 - noise_s)
715
+ )
716
+ noise_s2 = self.model_fn(x_s2, s2)
717
+ if solver_type == 'dpm_solver':
718
+ x_t = (
719
+ torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
720
+ - (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
721
+ - (1. / r2) * (sigma_t * phi_2)[(...,) + (None,)*dims] * (noise_s2 - noise_s)
722
+ )
723
+ elif solver_type == 'taylor':
724
+ D1_0 = (1. / r1) * (noise_s1 - noise_s)
725
+ D1_1 = (1. / r2) * (noise_s2 - noise_s)
726
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
727
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
728
+ x_t = (
729
+ torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
730
+ - (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
731
+ - (sigma_t * phi_2)[(...,) + (None,)*dims] * D1
732
+ - (sigma_t * phi_3)[(...,) + (None,)*dims] * D2
733
+ )
734
+ else:
735
+ raise ValueError("solver_type must be either dpm_solver or taylor, got {}".format(solver_type))
736
+
737
+ if return_noise:
738
+ return x_t, {'noise_s': noise_s, 'noise_s1': noise_s1, 'noise_s2': noise_s2}
739
+ else:
740
+ return x_t
741
+
742
+ def dpm_solver_update(self, x, s, t, order, return_noise=False, solver_type='dpm_solver', r1=None, r2=None):
743
+ """
744
+ A single step for DPM-Solver of the given order `order`.
745
+
746
+ Args:
747
+ x: A pytorch tensor. The initial value at time `s`.
748
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
749
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
750
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
751
+ Returns:
752
+ x_t: A pytorch tensor. The approximated solution at time `t`.
753
+ """
754
+ if order == 1:
755
+ return self.dpm_solver_first_update(x, s, t, return_noise=return_noise)
756
+ elif order == 2:
757
+ return self.dpm_solver_second_update(x, s, t, return_noise=return_noise, solver_type=solver_type, r1=r1)
758
+ elif order == 3:
759
+ return self.dpm_solver_third_update(x, s, t, return_noise=return_noise, solver_type=solver_type, r1=r1, r2=r2)
760
+ else:
761
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
762
+
763
+ def dpm_multistep_update(self, x, noise_prev_list, t_prev_list, t, order, solver_type='taylor'):
764
+ """
765
+ A single multistep step for DPM-Solver of the given order `order`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+ noise_prev_list: A list of pytorch tensors. The previous model outputs (the most recent is last).
+ t_prev_list: A list of pytorch tensors. The previous time steps (the most recent is last), each with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
772
+ Returns:
773
+ x_t: A pytorch tensor. The approximated solution at time `t`.
774
+ """
775
+ if order == 1:
776
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, noise_s=noise_prev_list[-1])
777
+ elif order == 2:
778
+ return self.dpm_multistep_second_update(x, noise_prev_list, t_prev_list, t, solver_type=solver_type)
779
+ elif order == 3:
780
+ return self.dpm_multistep_third_update(x, noise_prev_list, t_prev_list, t, solver_type=solver_type)
781
+ else:
782
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
783
+
784
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpm_solver'):
785
+ """
786
+ The adaptive step size solver based on DPM-Solver.
787
+
788
+ Args:
789
+ x: A pytorch tensor. The initial value at time `t_T`.
790
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
791
+ t_T: A `float`. The starting time of the sampling (default is T).
792
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
793
+ h_init: A `float`. The initial step size (for logSNR).
794
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
795
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
796
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
797
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
798
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
799
+ Returns:
800
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
801
+
802
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
803
+ """
804
+ ns = self.noise_schedule
805
+ s = t_T * torch.ones((x.shape[0],)).to(x)
806
+ lambda_s = ns.marginal_lambda(s)
807
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
808
+ h = h_init * torch.ones_like(s).to(x)
809
+ x_prev = x
810
+ nfe = 0
811
+ if order == 2:
812
+ r1 = 0.5
813
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_noise=True)
814
+ higher_update = lambda x, s, t, **kwargs: self.dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
815
+ elif order == 3:
816
+ r1, r2 = 1. / 3., 2. / 3.
817
+ lower_update = lambda x, s, t: self.dpm_solver_second_update(x, s, t, r1=r1, return_noise=True, solver_type=solver_type)
818
+ higher_update = lambda x, s, t, **kwargs: self.dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs)
819
+ else:
820
+ raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
821
+ while torch.abs((s - t_0)).mean() > t_err:
822
+ t = ns.inverse_lambda(lambda_s + h)
823
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
824
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
825
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
826
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
827
+ E = norm_fn((x_higher - x_lower) / delta).max()
828
+ if torch.all(E <= 1.):
829
+ x = x_higher
830
+ s = t
831
+ x_prev = x_lower
832
+ lambda_s = ns.marginal_lambda(s)
833
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
834
+ nfe += order
835
+ print('adaptive solver nfe', nfe)
836
+ return x
837
+
838
+ def sample(self, x, steps=10, eps=1e-4, T=None, order=3, skip_type='time_uniform',
839
+ denoise=False, method='fast', solver_type='dpm_solver', atol=0.0078,
840
+ rtol=0.05, timesteps=None,
841
+ ):
842
+ """
843
+ Compute the sample at time `eps` by DPM-Solver, given the initial `x` at time `T`.
844
+
845
+ We support the following algorithms:
846
+
847
+ - Adaptive step size DPM-Solver (i.e. DPM-Solver-12 and DPM-Solver-23)
848
+
849
+ - Fixed order DPM-Solver (i.e. DPM-Solver-1, DPM-Solver-2 and DPM-Solver-3).
850
+
851
+ - Fast version of DPM-Solver (i.e. DPM-Solver-fast), which combines
852
+ different orders of DPM-Solver.
853
+
854
+ **We recommend DPM-Solver-fast for both fast sampling in few steps (<=20) and fast convergence in many steps (50 to 100).**
855
+
856
+ Choosing the algorithms:
857
+
858
+ - If `method` == 'adaptive':
859
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
860
+ If `order`=2, we use DPM-Solver-12 which combines DPM-Solver-1 and DPM-Solver-2.
861
+ If `order`=3, we use DPM-Solver-23 which combines DPM-Solver-2 and DPM-Solver-3.
862
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
863
+ (NFE) and the sample quality.
864
+
865
+ - If `method` == 'fast':
866
+ We use DPM-Solver-fast with number of function evaluations (NFE) = `steps` and maximum order `order`.
867
+ We use `skip_type` to determine the spacing of the time steps for DPM-Solver-fast.
868
+ Given a fixed NFE=`steps`, the sampling procedure by DPM-Solver-fast is:
869
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
870
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
871
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
872
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
873
+
874
+ - If `method` == 'singlestep':
875
+ We use DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
876
+ We support three types of `skip_type`:
877
+ - 'logSNR': uniform logSNR for the time steps, **recommended for DPM-Solver**.
878
+ - 'time_uniform': uniform time for the time steps. (Used in DDIM and DDPM.)
879
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM.)
880
+
881
+ =====================================================
882
+ Args:
883
+ x: A pytorch tensor. The initial value at time `T` (a sample from the normal distribution).
884
+ steps: A `int`. The total number of function evaluations (NFE).
885
+ eps: A `float`. The ending time of the sampling.
886
+ We recommend `eps`=1e-3 when `steps` <= 15; and `eps`=1e-4 when `steps` > 15.
887
+ T: A `float`. The starting time of the sampling. Default is `None`.
888
+ If `T` is None, we use self.noise_schedule.T.
889
+ order: A `int`. The order of DPM-Solver.
890
+ skip_type: A `str`. The type for the spacing of the time steps. Default is 'time_uniform'.
891
+ method: A `str`. The sampling method: 'adaptive', 'fast' (recommended), 'multistep', or 'singlestep'.
893
+ atol: A `float`. The absolute tolerance of the adaptive step size solver.
894
+ rtol: A `float`. The relative tolerance of the adaptive step size solver.
895
+ Returns:
896
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
897
+
898
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
899
+ """
900
+ t_0 = eps
901
+ t_T = self.noise_schedule.T if T is None else T
902
+ device = x.device
903
+ if method == 'adaptive':
904
+ with torch.no_grad():
905
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type)
906
+ elif method == 'multistep':
907
+ assert steps >= order
908
+ if timesteps is None:
909
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
910
+ assert timesteps.shape[0] - 1 == steps
911
+ with torch.no_grad():
912
+ vec_t = timesteps[0].expand((x.shape[0]))
913
+ noise_prev_list = [self.model_fn(x, vec_t)]
914
+ t_prev_list = [vec_t]
915
+ for init_order in range(1, order):
916
+ vec_t = timesteps[init_order].expand(x.shape[0])
917
+ x = self.dpm_multistep_update(x, noise_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type)
918
+ noise_prev_list.append(self.model_fn(x, vec_t))
919
+ t_prev_list.append(vec_t)
920
+ for step in range(order, steps + 1):
921
+ vec_t = timesteps[step].expand(x.shape[0])
922
+ x = self.dpm_multistep_update(x, noise_prev_list, t_prev_list, vec_t, order, solver_type=solver_type)
923
+ for i in range(order - 1):
924
+ t_prev_list[i] = t_prev_list[i + 1]
925
+ noise_prev_list[i] = noise_prev_list[i + 1]
926
+ t_prev_list[-1] = vec_t
927
+ if step < steps:
928
+ noise_prev_list[-1] = self.model_fn(x, vec_t)
929
+ elif method == 'fast':
930
+ orders, _ = self.get_time_steps_for_dpm_solver_fast(skip_type=skip_type, t_T=t_T, t_0=t_0, steps=steps, order=order, device=device)
931
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
932
+ with torch.no_grad():
933
+ i = 0
934
+ for order in orders:
935
+ vec_s, vec_t = torch.ones((x.shape[0],)).to(device) * timesteps[i], torch.ones((x.shape[0],)).to(device) * timesteps[i + order]
936
+ h = self.noise_schedule.marginal_lambda(timesteps[i + order]) - self.noise_schedule.marginal_lambda(timesteps[i])
937
+ r1 = None if order <= 1 else (self.noise_schedule.marginal_lambda(timesteps[i + 1]) - self.noise_schedule.marginal_lambda(timesteps[i])) / h
938
+ r2 = None if order <= 2 else (self.noise_schedule.marginal_lambda(timesteps[i + 2]) - self.noise_schedule.marginal_lambda(timesteps[i])) / h
939
+ x = self.dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
940
+ i += order
941
+ elif method == 'singlestep':
942
+ N_steps = steps // order
943
+ orders = [order,] * N_steps
944
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=N_steps, device=device)
945
+ assert len(timesteps) - 1 == N_steps
946
+ with torch.no_grad():
947
+ for i, order in enumerate(orders):
948
+ vec_s, vec_t = torch.ones((x.shape[0],)).to(device) * timesteps[i], torch.ones((x.shape[0],)).to(device) * timesteps[i + 1]
949
+ x = self.dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type)
950
+ if denoise:
951
+ x = self.denoise_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
952
+ return x
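A hedged end-to-end sketch of how the classes in this file might be combined; the epsilon model below is a stand-in and the hyperparameters are illustrative, not the UniDiffuser configuration:

import torch

betas = torch.linspace(1e-4, 2e-2, 1000)
ns = NoiseScheduleVP(schedule='discrete', betas=betas)

def dummy_eps_model(x, t_input):                    # placeholder noise-prediction network
    return torch.randn_like(x)

model_fn = model_wrapper(dummy_eps_model, noise_schedule=ns,
                         time_input_type='1', total_N=1000)
solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)

x_T = torch.randn(2, 3, 32, 32)                     # Gaussian noise at time T
x_0 = solver.sample(x_T, steps=20, eps=1e-4, order=3,
                    skip_type='time_uniform', method='fast',
                    solver_type='dpm_solver')       # approximate sample at time eps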
libs/autoencoder.py ADDED
@@ -0,0 +1,519 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from einops import rearrange
5
+
6
+
7
+ class LinearAttention(nn.Module):
8
+ def __init__(self, dim, heads=4, dim_head=32):
9
+ super().__init__()
10
+ self.heads = heads
11
+ hidden_dim = dim_head * heads
12
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
13
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
14
+
15
+ def forward(self, x):
16
+ b, c, h, w = x.shape
17
+ qkv = self.to_qkv(x)
18
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
19
+ k = k.softmax(dim=-1)
20
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
21
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
22
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
23
+ return self.to_out(out)
24
+
25
+
26
+ def nonlinearity(x):
27
+ # swish
28
+ return x*torch.sigmoid(x)
29
+
30
+
31
+ def Normalize(in_channels, num_groups=32):
32
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
33
+
34
+
35
+ class Upsample(nn.Module):
36
+ def __init__(self, in_channels, with_conv):
37
+ super().__init__()
38
+ self.with_conv = with_conv
39
+ if self.with_conv:
40
+ self.conv = torch.nn.Conv2d(in_channels,
41
+ in_channels,
42
+ kernel_size=3,
43
+ stride=1,
44
+ padding=1)
45
+
46
+ def forward(self, x):
47
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
48
+ if self.with_conv:
49
+ x = self.conv(x)
50
+ return x
51
+
52
+
53
+ class Downsample(nn.Module):
54
+ def __init__(self, in_channels, with_conv):
55
+ super().__init__()
56
+ self.with_conv = with_conv
57
+ if self.with_conv:
58
+ # no asymmetric padding in torch conv, must do it ourselves
59
+ self.conv = torch.nn.Conv2d(in_channels,
60
+ in_channels,
61
+ kernel_size=3,
62
+ stride=2,
63
+ padding=0)
64
+
65
+ def forward(self, x):
66
+ if self.with_conv:
67
+ pad = (0,1,0,1)
68
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
69
+ x = self.conv(x)
70
+ else:
71
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
72
+ return x
73
+
74
+
75
+ class ResnetBlock(nn.Module):
76
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
77
+ dropout, temb_channels=512):
78
+ super().__init__()
79
+ self.in_channels = in_channels
80
+ out_channels = in_channels if out_channels is None else out_channels
81
+ self.out_channels = out_channels
82
+ self.use_conv_shortcut = conv_shortcut
83
+
84
+ self.norm1 = Normalize(in_channels)
85
+ self.conv1 = torch.nn.Conv2d(in_channels,
86
+ out_channels,
87
+ kernel_size=3,
88
+ stride=1,
89
+ padding=1)
90
+ if temb_channels > 0:
91
+ self.temb_proj = torch.nn.Linear(temb_channels,
92
+ out_channels)
93
+ self.norm2 = Normalize(out_channels)
94
+ self.dropout = torch.nn.Dropout(dropout)
95
+ self.conv2 = torch.nn.Conv2d(out_channels,
96
+ out_channels,
97
+ kernel_size=3,
98
+ stride=1,
99
+ padding=1)
100
+ if self.in_channels != self.out_channels:
101
+ if self.use_conv_shortcut:
102
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
103
+ out_channels,
104
+ kernel_size=3,
105
+ stride=1,
106
+ padding=1)
107
+ else:
108
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
109
+ out_channels,
110
+ kernel_size=1,
111
+ stride=1,
112
+ padding=0)
113
+
114
+ def forward(self, x, temb):
115
+ h = x
116
+ h = self.norm1(h)
117
+ h = nonlinearity(h)
118
+ h = self.conv1(h)
119
+
120
+ if temb is not None:
121
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
122
+
123
+ h = self.norm2(h)
124
+ h = nonlinearity(h)
125
+ h = self.dropout(h)
126
+ h = self.conv2(h)
127
+
128
+ if self.in_channels != self.out_channels:
129
+ if self.use_conv_shortcut:
130
+ x = self.conv_shortcut(x)
131
+ else:
132
+ x = self.nin_shortcut(x)
133
+
134
+ return x+h
135
+
136
+
137
+ class LinAttnBlock(LinearAttention):
138
+ """to match AttnBlock usage"""
139
+ def __init__(self, in_channels):
140
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
141
+
142
+
143
+ class AttnBlock(nn.Module):
144
+ def __init__(self, in_channels):
145
+ super().__init__()
146
+ self.in_channels = in_channels
147
+
148
+ self.norm = Normalize(in_channels)
149
+ self.q = torch.nn.Conv2d(in_channels,
150
+ in_channels,
151
+ kernel_size=1,
152
+ stride=1,
153
+ padding=0)
154
+ self.k = torch.nn.Conv2d(in_channels,
155
+ in_channels,
156
+ kernel_size=1,
157
+ stride=1,
158
+ padding=0)
159
+ self.v = torch.nn.Conv2d(in_channels,
160
+ in_channels,
161
+ kernel_size=1,
162
+ stride=1,
163
+ padding=0)
164
+ self.proj_out = torch.nn.Conv2d(in_channels,
165
+ in_channels,
166
+ kernel_size=1,
167
+ stride=1,
168
+ padding=0)
169
+
170
+
171
+ def forward(self, x):
172
+ h_ = x
173
+ h_ = self.norm(h_)
174
+ q = self.q(h_)
175
+ k = self.k(h_)
176
+ v = self.v(h_)
177
+
178
+ # compute attention
179
+ b,c,h,w = q.shape
180
+ q = q.reshape(b,c,h*w)
181
+ q = q.permute(0,2,1) # b,hw,c
182
+ k = k.reshape(b,c,h*w) # b,c,hw
183
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
184
+ w_ = w_ * (int(c)**(-0.5))
185
+ w_ = torch.nn.functional.softmax(w_, dim=2)
186
+
187
+ # attend to values
188
+ v = v.reshape(b,c,h*w)
189
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
190
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
191
+ h_ = h_.reshape(b,c,h,w)
192
+
193
+ h_ = self.proj_out(h_)
194
+
195
+ return x+h_
196
+
197
+
198
+ def make_attn(in_channels, attn_type="vanilla"):
199
+ assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
200
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
201
+ if attn_type == "vanilla":
202
+ return AttnBlock(in_channels)
203
+ elif attn_type == "none":
204
+ return nn.Identity(in_channels)
205
+ else:
206
+ return LinAttnBlock(in_channels)
207
+
208
+
209
+ class Encoder(nn.Module):
210
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
211
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
212
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
213
+ **ignore_kwargs):
214
+ super().__init__()
215
+ if use_linear_attn: attn_type = "linear"
216
+ self.ch = ch
217
+ self.temb_ch = 0
218
+ self.num_resolutions = len(ch_mult)
219
+ self.num_res_blocks = num_res_blocks
220
+ self.resolution = resolution
221
+ self.in_channels = in_channels
222
+
223
+ # downsampling
224
+ self.conv_in = torch.nn.Conv2d(in_channels,
225
+ self.ch,
226
+ kernel_size=3,
227
+ stride=1,
228
+ padding=1)
229
+
230
+ curr_res = resolution
231
+ in_ch_mult = (1,)+tuple(ch_mult)
232
+ self.in_ch_mult = in_ch_mult
233
+ self.down = nn.ModuleList()
234
+ for i_level in range(self.num_resolutions):
235
+ block = nn.ModuleList()
236
+ attn = nn.ModuleList()
237
+ block_in = ch*in_ch_mult[i_level]
238
+ block_out = ch*ch_mult[i_level]
239
+ for i_block in range(self.num_res_blocks):
240
+ block.append(ResnetBlock(in_channels=block_in,
241
+ out_channels=block_out,
242
+ temb_channels=self.temb_ch,
243
+ dropout=dropout))
244
+ block_in = block_out
245
+ if curr_res in attn_resolutions:
246
+ attn.append(make_attn(block_in, attn_type=attn_type))
247
+ down = nn.Module()
248
+ down.block = block
249
+ down.attn = attn
250
+ if i_level != self.num_resolutions-1:
251
+ down.downsample = Downsample(block_in, resamp_with_conv)
252
+ curr_res = curr_res // 2
253
+ self.down.append(down)
254
+
255
+ # middle
256
+ self.mid = nn.Module()
257
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
258
+ out_channels=block_in,
259
+ temb_channels=self.temb_ch,
260
+ dropout=dropout)
261
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
262
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
263
+ out_channels=block_in,
264
+ temb_channels=self.temb_ch,
265
+ dropout=dropout)
266
+
267
+ # end
268
+ self.norm_out = Normalize(block_in)
269
+ self.conv_out = torch.nn.Conv2d(block_in,
270
+ 2*z_channels if double_z else z_channels,
271
+ kernel_size=3,
272
+ stride=1,
273
+ padding=1)
274
+
275
+ def forward(self, x):
276
+ # timestep embedding
277
+ temb = None
278
+
279
+ # downsampling
280
+ hs = [self.conv_in(x)]
281
+ for i_level in range(self.num_resolutions):
282
+ for i_block in range(self.num_res_blocks):
283
+ h = self.down[i_level].block[i_block](hs[-1], temb)
284
+ if len(self.down[i_level].attn) > 0:
285
+ h = self.down[i_level].attn[i_block](h)
286
+ hs.append(h)
287
+ if i_level != self.num_resolutions-1:
288
+ hs.append(self.down[i_level].downsample(hs[-1]))
289
+
290
+ # middle
291
+ h = hs[-1]
292
+ h = self.mid.block_1(h, temb)
293
+ h = self.mid.attn_1(h)
294
+ h = self.mid.block_2(h, temb)
295
+
296
+ # end
297
+ h = self.norm_out(h)
298
+ h = nonlinearity(h)
299
+ h = self.conv_out(h)
300
+ return h
301
+
302
+
303
+ class Decoder(nn.Module):
304
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
305
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
306
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
307
+ attn_type="vanilla", **ignorekwargs):
308
+ super().__init__()
309
+ if use_linear_attn: attn_type = "linear"
310
+ self.ch = ch
311
+ self.temb_ch = 0
312
+ self.num_resolutions = len(ch_mult)
313
+ self.num_res_blocks = num_res_blocks
314
+ self.resolution = resolution
315
+ self.in_channels = in_channels
316
+ self.give_pre_end = give_pre_end
317
+ self.tanh_out = tanh_out
318
+
319
+ # compute in_ch_mult, block_in and curr_res at lowest res
320
+ in_ch_mult = (1,)+tuple(ch_mult)
321
+ block_in = ch*ch_mult[self.num_resolutions-1]
322
+ curr_res = resolution // 2**(self.num_resolutions-1)
323
+ self.z_shape = (1,z_channels,curr_res,curr_res)
324
+ print("Working with z of shape {} = {} dimensions.".format(
325
+ self.z_shape, np.prod(self.z_shape)))
326
+
327
+ # z to block_in
328
+ self.conv_in = torch.nn.Conv2d(z_channels,
329
+ block_in,
330
+ kernel_size=3,
331
+ stride=1,
332
+ padding=1)
333
+
334
+ # middle
335
+ self.mid = nn.Module()
336
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
337
+ out_channels=block_in,
338
+ temb_channels=self.temb_ch,
339
+ dropout=dropout)
340
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
341
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
342
+ out_channels=block_in,
343
+ temb_channels=self.temb_ch,
344
+ dropout=dropout)
345
+
346
+ # upsampling
347
+ self.up = nn.ModuleList()
348
+ for i_level in reversed(range(self.num_resolutions)):
349
+ block = nn.ModuleList()
350
+ attn = nn.ModuleList()
351
+ block_out = ch*ch_mult[i_level]
352
+ for i_block in range(self.num_res_blocks+1):
353
+ block.append(ResnetBlock(in_channels=block_in,
354
+ out_channels=block_out,
355
+ temb_channels=self.temb_ch,
356
+ dropout=dropout))
357
+ block_in = block_out
358
+ if curr_res in attn_resolutions:
359
+ attn.append(make_attn(block_in, attn_type=attn_type))
360
+ up = nn.Module()
361
+ up.block = block
362
+ up.attn = attn
363
+ if i_level != 0:
364
+ up.upsample = Upsample(block_in, resamp_with_conv)
365
+ curr_res = curr_res * 2
366
+ self.up.insert(0, up) # prepend to get consistent order
367
+
368
+ # end
369
+ self.norm_out = Normalize(block_in)
370
+ self.conv_out = torch.nn.Conv2d(block_in,
371
+ out_ch,
372
+ kernel_size=3,
373
+ stride=1,
374
+ padding=1)
375
+
376
+ def forward(self, z):
377
+ #assert z.shape[1:] == self.z_shape[1:]
378
+ self.last_z_shape = z.shape
379
+
380
+ # timestep embedding
381
+ temb = None
382
+
383
+ # z to block_in
384
+ h = self.conv_in(z)
385
+
386
+ # middle
387
+ h = self.mid.block_1(h, temb)
388
+ h = self.mid.attn_1(h)
389
+ h = self.mid.block_2(h, temb)
390
+
391
+ # upsampling
392
+ for i_level in reversed(range(self.num_resolutions)):
393
+ for i_block in range(self.num_res_blocks+1):
394
+ h = self.up[i_level].block[i_block](h, temb)
395
+ if len(self.up[i_level].attn) > 0:
396
+ h = self.up[i_level].attn[i_block](h)
397
+ if i_level != 0:
398
+ h = self.up[i_level].upsample(h)
399
+
400
+ # end
401
+ if self.give_pre_end:
402
+ return h
403
+
404
+ h = self.norm_out(h)
405
+ h = nonlinearity(h)
406
+ h = self.conv_out(h)
407
+ if self.tanh_out:
408
+ h = torch.tanh(h)
409
+ return h
410
+
411
+
412
+ class FrozenAutoencoderKL(nn.Module):
413
+ def __init__(self, ddconfig, embed_dim, pretrained_path, scale_factor=0.18215):
414
+ super().__init__()
415
+ print(f'Create autoencoder with scale_factor={scale_factor}')
416
+ self.encoder = Encoder(**ddconfig)
417
+ self.decoder = Decoder(**ddconfig)
418
+ assert ddconfig["double_z"]
419
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
420
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
421
+ self.embed_dim = embed_dim
422
+ self.scale_factor = scale_factor
423
+ m, u = self.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
424
+ assert len(m) == 0 and len(u) == 0
425
+ self.eval()
426
+ self.requires_grad_(False)
427
+
428
+ def encode_moments(self, x):
429
+ h = self.encoder(x)
430
+ moments = self.quant_conv(h)
431
+ return moments
432
+
433
+ def sample(self, moments):
434
+ mean, logvar = torch.chunk(moments, 2, dim=1)
435
+ logvar = torch.clamp(logvar, -30.0, 20.0)
436
+ std = torch.exp(0.5 * logvar)
437
+ z = mean + std * torch.randn_like(mean)
438
+ z = self.scale_factor * z
439
+ return z
440
+
441
+ def encode(self, x):
442
+ moments = self.encode_moments(x)
443
+ z = self.sample(moments)
444
+ return z
445
+
446
+ def decode(self, z):
447
+ z = (1. / self.scale_factor) * z
448
+ z = self.post_quant_conv(z)
449
+ dec = self.decoder(z)
450
+ return dec
451
+
452
+ def forward(self, inputs, fn):
453
+ if fn == 'encode_moments':
454
+ return self.encode_moments(inputs)
455
+ elif fn == 'encode':
456
+ return self.encode(inputs)
457
+ elif fn == 'decode':
458
+ return self.decode(inputs)
459
+ else:
460
+ raise NotImplementedError
461
+
462
+
463
+ def get_model(pretrained_path, scale_factor=0.18215):
464
+ ddconfig = dict(
465
+ double_z=True,
466
+ z_channels=4,
467
+ resolution=256,
468
+ in_channels=3,
469
+ out_ch=3,
470
+ ch=128,
471
+ ch_mult=[1, 2, 4, 4],
472
+ num_res_blocks=2,
473
+ attn_resolutions=[],
474
+ dropout=0.0
475
+ )
476
+ return FrozenAutoencoderKL(ddconfig, 4, pretrained_path, scale_factor)
477
+
478
+
479
+ def main():
480
+ import torchvision.transforms as transforms
481
+ from torchvision.utils import save_image
482
+ import os
483
+ from PIL import Image
484
+
485
+ model = get_model('assets/stable-diffusion/autoencoder_kl.pth')
486
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
487
+ model = model.to(device)
488
+
489
+ scale_factor = 0.18215
490
+ T = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor()])
491
+ path = 'imgs'
492
+ fnames = os.listdir(path)
493
+ for fname in fnames:
494
+ p = os.path.join(path, fname)
495
+ img = Image.open(p)
496
+ img = T(img)
497
+ img = img * 2. - 1
498
+ img = img[None, ...]
499
+ img = img.to(device)
500
+
501
+ # with torch.cuda.amp.autocast():
502
+ # moments = model.encode_moments(img)
503
+ # mean, logvar = torch.chunk(moments, 2, dim=1)
504
+ # logvar = torch.clamp(logvar, -30.0, 20.0)
505
+ # std = torch.exp(0.5 * logvar)
506
+ # zs = [(mean + std * torch.randn_like(mean)) * scale_factor for _ in range(4)]
507
+ # recons = [model.decode(z) for z in zs]
508
+
509
+ with torch.cuda.amp.autocast():
510
+ print('test encode & decode')
511
+ recons = [model.decode(model.encode(img)) for _ in range(4)]
512
+
513
+ out = torch.cat([img, *recons], dim=0)
514
+ out = (out + 1) * 0.5
515
+ save_image(out, f'recons_{fname}')
516
+
517
+
518
+ if __name__ == "__main__":
519
+ main()
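
The main() above is a quick self-test of the frozen KL autoencoder. As a minimal, hedged usage sketch (the checkpoint path is the one hard-coded in main() above and must exist locally; shapes follow the ddconfig in get_model, where a 256x256 RGB image maps to a 4x32x32 latent):

import torch
from libs.autoencoder import get_model

# Sketch only: 'assets/stable-diffusion/autoencoder_kl.pth' is the path assumed by main() above.
model = get_model('assets/stable-diffusion/autoencoder_kl.pth')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

img = torch.rand(1, 3, 256, 256, device=device) * 2. - 1.  # dummy RGB image scaled to [-1, 1]
with torch.no_grad():
    z = model(img, fn='encode')   # sampled latent of shape [1, 4, 32, 32], already multiplied by scale_factor
    rec = model(z, fn='decode')   # reconstruction of shape [1, 3, 256, 256], roughly in [-1, 1]
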
libs/caption_decoder.py ADDED
@@ -0,0 +1,283 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as nnf
6
+
7
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
8
+ from transformers import default_data_collator
9
+ from transformers import EarlyStoppingCallback
10
+
11
+ data_collator = default_data_collator
12
+ es = EarlyStoppingCallback(early_stopping_patience=5)
13
+ import json
14
+ import argparse
15
+ from typing import Union, Optional
16
+ from collections import OrderedDict
17
+
18
+
19
+ # %% model initial
20
+ class ClipCaptionModel(nn.Module):
21
+ """A GPT-2 caption decoder: encode_prefix/decode_prefix map CLIP text features to and from an optional low-dimensional bottleneck, and the GPT-2 LM head decodes the prefix into text.
22
+ """
23
+
24
+ def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
25
+ return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
26
+
27
+ def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None,
28
+ labels: Optional[torch.Tensor] = None):
29
+ """
30
+ : param tokens: (Tensor) [N x max_seq_len] eg. [4 X 33]
31
+ : param prefix: (Tensor) [N x prefix_length x 768] eg. [4 x 77 x 768]
32
+ : param mask: (Tensor) [N x (prefix_length + max_seq_len) x 768] eg. [4 x 110 x768]
33
+
34
+ : attribute embedding_text: (Tensor) [N x max_seq_len x 768] eg. [4 x 33 x 768]
35
+ : attribute embedding_cat: (Tensor) [N x (prefix_length + max_seq_len) x 768] eg. [4 x 110 x 768]
36
+ """
37
+ embedding_text = self.gpt.transformer.wte(tokens)
38
+ hidden = self.encode_prefix(prefix)
39
+ prefix = self.decode_prefix(hidden)
40
+ embedding_cat = torch.cat((prefix, embedding_text), dim=1)
41
+
42
+ if labels is not None:
43
+ dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
44
+ labels = torch.cat((dummy_token, tokens), dim=1)
45
+ out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
46
+ if self.hidden_dim is not None:
47
+ return out, hidden
48
+ else:
49
+ return out
50
+
51
+ def encode_decode_prefix(self, prefix):
52
+ return self.decode_prefix(self.encode_prefix(prefix))
53
+
54
+ def __init__(self, prefix_length: int, hidden_dim=None):
55
+ super(ClipCaptionModel, self).__init__()
56
+ self.prefix_length = prefix_length
57
+ eos = '<|EOS|>'
58
+ special_tokens_dict = {'eos_token': eos}
59
+ base_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
60
+ base_tokenizer.add_special_tokens(special_tokens_dict)
61
+ self.gpt = GPT2LMHeadModel.from_pretrained('gpt2', eos_token_id=base_tokenizer.eos_token_id)
62
+ self.gpt.resize_token_embeddings(len(base_tokenizer))
63
+
64
+ self.hidden_dim = hidden_dim
65
+ self.encode_prefix = nn.Linear(768, hidden_dim) if hidden_dim is not None else nn.Identity()
66
+ self.decode_prefix = nn.Linear(hidden_dim, 768) if hidden_dim is not None else nn.Identity()
67
+
68
+
69
+
70
+
71
+ def load_model(config_path: str, epoch_or_latest: Union[str, int] = '_latest'):
72
+ with open(config_path) as f:
73
+ config = json.load(f)
74
+ parser = argparse.ArgumentParser()
75
+ parser.set_defaults(**config)
76
+ args = parser.parse_args()
77
+ if type(epoch_or_latest) is int:
78
+ epoch_or_latest = f"-{epoch_or_latest:03d}"
79
+ model_path = os.path.join(args.out_dir, f"{args.prefix}{epoch_or_latest}.pt")
80
+ model = ClipCaptionModel(args.prefix_length)
81
+ if os.path.isfile(model_path):
82
+ print(f"loading model from {model_path}")
83
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
84
+ else:
85
+ print(f"{model_path} does not exist")
86
+ return model, parser
87
+
88
+
89
+ def generate_beam(
90
+ model,
91
+ tokenizer,
92
+ beam_size: int = 5,
93
+ prompt=None,
94
+ embed=None,
95
+ entry_length=67,
96
+ temperature=1.0,
97
+ stop_token: str = '<|EOS|>',
98
+ ):
99
+ model.eval()
100
+ stop_token_index = tokenizer.encode(stop_token)[0]
101
+ tokens = None
102
+ scores = None
103
+ device = next(model.parameters()).device
104
+ seq_lengths = torch.ones(beam_size, device=device)
105
+ is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
106
+ with torch.no_grad():
107
+ if embed is not None:
108
+ generated = embed
109
+ else:
110
+ if tokens is None:
111
+ tokens = torch.tensor(tokenizer.encode(prompt))
112
+ tokens = tokens.unsqueeze(0).to(device)
113
+ generated = model.gpt.transformer.wte(tokens)
114
+ # pbar = tqdm(range(entry_length))
115
+ # pbar.set_description("generating text ...")
116
+ for i in range(entry_length):
117
+ # print(generated.shape)
118
+ outputs = model.gpt(inputs_embeds=generated)
119
+ logits = outputs.logits
120
+ logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
121
+ logits = logits.softmax(-1).log()
122
+ if scores is None:
123
+ scores, next_tokens = logits.topk(beam_size, -1)
124
+ generated = generated.expand(beam_size, *generated.shape[1:])
125
+ next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
126
+ if tokens is None:
127
+ tokens = next_tokens
128
+ else:
129
+ tokens = tokens.expand(beam_size, *tokens.shape[1:])
130
+ tokens = torch.cat((tokens, next_tokens), dim=1)
131
+ else:
132
+ logits[is_stopped] = -float(np.inf)
133
+ logits[is_stopped, 0] = 0
134
+ scores_sum = scores[:, None] + logits
135
+ seq_lengths[~is_stopped] += 1
136
+ scores_sum_average = scores_sum / seq_lengths[:, None]
137
+ scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
138
+ beam_size, -1
139
+ )
140
+ next_tokens_source = next_tokens // scores_sum.shape[1]
141
+ seq_lengths = seq_lengths[next_tokens_source]
142
+ next_tokens = next_tokens % scores_sum.shape[1]
143
+ next_tokens = next_tokens.unsqueeze(1)
144
+ tokens = tokens[next_tokens_source]
145
+ tokens = torch.cat((tokens, next_tokens), dim=1)
146
+ generated = generated[next_tokens_source]
147
+ scores = scores_sum_average * seq_lengths
148
+ is_stopped = is_stopped[next_tokens_source]
149
+ next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(
150
+ generated.shape[0], 1, -1
151
+ )
152
+ generated = torch.cat((generated, next_token_embed), dim=1)
153
+ is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
154
+ if is_stopped.all():
155
+ break
156
+ scores = scores / seq_lengths
157
+ output_list = tokens.cpu().numpy()
158
+ output_texts = [
159
+ tokenizer.decode(output[: int(length)], skip_special_tokens=True)
160
+ for output, length in zip(output_list, seq_lengths)
161
+ ]
162
+ order = scores.argsort(descending=True)
163
+ output_texts = [output_texts[i] for i in order]
164
+ model.train()
165
+ return output_texts
166
+
167
+
168
+ def generate2(
169
+ model,
170
+ tokenizer,
171
+ tokens=None,
172
+ prompt=None,
173
+ embed=None,
174
+ entry_count=1,
175
+ entry_length=67, # maximum number of words
176
+ top_p=0.8,
177
+ temperature=1.0,
178
+ stop_token: str = '<|EOS|>',
179
+ ):
180
+ model.eval()
181
+ generated_num = 0
182
+ generated_list = []
183
+ stop_token_index = tokenizer.encode(stop_token)[0]
184
+ filter_value = -float("Inf")
185
+ device = next(model.parameters()).device
186
+
187
+ with torch.no_grad():
188
+
189
+ for entry_idx in range(entry_count):
190
+ if embed is not None:
191
+ generated = embed
192
+ else:
193
+ if tokens is None:
194
+ tokens = torch.tensor(tokenizer.encode(prompt))
195
+ tokens = tokens.unsqueeze(0).to(device)
196
+
197
+ generated = model.gpt.transformer.wte(tokens)
198
+
199
+ for i in range(entry_length):
200
+
201
+ outputs = model.gpt(inputs_embeds=generated)
202
+ logits = outputs.logits
203
+ logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
204
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
205
+ cumulative_probs = torch.cumsum(
206
+ nnf.softmax(sorted_logits, dim=-1), dim=-1
207
+ )
208
+ sorted_indices_to_remove = cumulative_probs > top_p
209
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
210
+ ..., :-1
211
+ ].clone()
212
+ sorted_indices_to_remove[..., 0] = 0
213
+
214
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
215
+ logits[:, indices_to_remove] = filter_value
216
+ next_token = torch.argmax(logits, -1).unsqueeze(0)
217
+ next_token_embed = model.gpt.transformer.wte(next_token)
218
+ if tokens is None:
219
+ tokens = next_token
220
+ else:
221
+ tokens = torch.cat((tokens, next_token), dim=1)
222
+ generated = torch.cat((generated, next_token_embed), dim=1)
223
+ if stop_token_index == next_token.item():
224
+ break
225
+
226
+ output_list = list(tokens.squeeze().cpu().numpy())
227
+ output_text = tokenizer.decode(output_list)
228
+ generated_list.append(output_text)
229
+
230
+ return generated_list[0]
231
+
232
+
233
+ class CaptionDecoder(object):
234
+ def __init__(self, device, pretrained_path, hidden_dim=-1):
235
+ if hidden_dim < 0:
236
+ hidden_dim = None
237
+ # tokenizer initialize
238
+ eos = '<|EOS|>'
239
+ special_tokens_dict = {'eos_token': eos}
240
+ self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
241
+ self.tokenizer.add_special_tokens(special_tokens_dict)
242
+
243
+ # model initialize
244
+ feature_length = 77
245
+ # modelFile = "assets/caption_decoder/coco_v2_latest.pt"
246
+ self.caption_model = ClipCaptionModel(feature_length, hidden_dim=hidden_dim)
247
+ # print("Load Model...")
248
+ ckpt = torch.load(pretrained_path, map_location='cpu')
249
+ state_dict = OrderedDict()
250
+ for k, v in ckpt.items():
251
+ new_k = k[7:]  # drop the first 7 characters of each key (e.g. a 'module.' prefix from DataParallel checkpoints)
252
+ state_dict[new_k] = v
253
+ mk, uk = self.caption_model.load_state_dict(state_dict, strict=False)
254
+ assert len(mk) == 0
255
+ assert all([name.startswith('clip') for name in uk])
256
+ self.caption_model.eval()
257
+ self.caption_model.to(device)
258
+ self.caption_model.requires_grad_(False)
259
+ self.device = device
260
+
261
+ def encode_prefix(self, features):
262
+ return self.caption_model.encode_prefix(features)
263
+
264
+ def generate_captions(self, features): # the low dimension representation of clip feature
265
+ """
266
+ generate captions given features
267
+ : param features : (tensor([B x L x D]))
268
+ : return generated_text: (list([L]))
269
+ """
270
+
271
+ # generate config
272
+ use_beam_search = True
273
+
274
+ features = torch.split(features, 1, dim=0)
275
+ generated_captions = []
276
+ with torch.no_grad():
277
+ for feature in features:
278
+ feature = self.caption_model.decode_prefix(feature.to(self.device)) # back to the clip feature
279
+ if use_beam_search:
280
+ generated_captions.append(generate_beam(self.caption_model, self.tokenizer, embed=feature)[0])
281
+ else:
282
+ generated_captions.append(generate2(self.caption_model, self.tokenizer, embed=feature))
283
+ return generated_captions
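
CaptionDecoder is the inference-time wrapper used by the sampler below: encode_prefix compresses 768-dim CLIP text features into the low-dimensional text latent, and generate_captions lifts that latent back with decode_prefix and decodes it with beam search. A minimal sketch (the checkpoint path is the one mentioned in the commented-out modelFile line above and is a placeholder; hidden_dim=64 is an assumption matching the low-dimensional text latent used by the sampler, not a value read from this file):

import torch
from libs.caption_decoder import CaptionDecoder

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Placeholder checkpoint path; hidden_dim=64 is assumed.
decoder = CaptionDecoder(device, pretrained_path='assets/caption_decoder/coco_v2_latest.pt', hidden_dim=64)

clip_text = torch.randn(2, 77, 768, device=device)  # e.g. FrozenCLIPEmbedder output
prefix = decoder.encode_prefix(clip_text)            # [2, 77, 64] low-dimensional text latent
captions = decoder.generate_captions(prefix)         # list of 2 generated strings
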
libs/clip.py ADDED
@@ -0,0 +1,38 @@
+ import torch.nn as nn
+ from transformers import CLIPTokenizer, CLIPTextModel
+
+
+ class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+ class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
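
FrozenCLIPEmbedder simply tokenizes a list of prompts to 77 tokens and returns the CLIP text transformer's last hidden state. A small sketch (note the constructor defaults to device="cuda"; pass device="cpu" on CPU-only machines, and call .to(device) yourself if the transformer should be moved):

from libs.clip import FrozenCLIPEmbedder

clip_text_model = FrozenCLIPEmbedder(device='cpu')  # weights are frozen in __init__ via freeze()
features = clip_text_model.encode(['a photo of a cat', 'an oil painting of a castle'])
print(features.shape)  # torch.Size([2, 77, 768])
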
libs/timm.py ADDED
@@ -0,0 +1,112 @@
1
+ # code from timm 0.3.2
2
+ import torch
3
+ import torch.nn as nn
4
+ import math
5
+ import warnings
6
+
7
+
8
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
9
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
10
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
11
+ def norm_cdf(x):
12
+ # Computes standard normal cumulative distribution function
13
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
14
+
15
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
16
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
17
+ "The distribution of values may be incorrect.",
18
+ stacklevel=2)
19
+
20
+ with torch.no_grad():
21
+ # Values are generated by using a truncated uniform distribution and
22
+ # then using the inverse CDF for the normal distribution.
23
+ # Get upper and lower cdf values
24
+ l = norm_cdf((a - mean) / std)
25
+ u = norm_cdf((b - mean) / std)
26
+
27
+ # Uniformly fill tensor with values from [l, u], then translate to
28
+ # [2l-1, 2u-1].
29
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
30
+
31
+ # Use inverse cdf transform for normal distribution to get truncated
32
+ # standard normal
33
+ tensor.erfinv_()
34
+
35
+ # Transform to proper mean, std
36
+ tensor.mul_(std * math.sqrt(2.))
37
+ tensor.add_(mean)
38
+
39
+ # Clamp to ensure it's in the proper range
40
+ tensor.clamp_(min=a, max=b)
41
+ return tensor
42
+
43
+
44
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
45
+ # type: (Tensor, float, float, float, float) -> Tensor
46
+ r"""Fills the input Tensor with values drawn from a truncated
47
+ normal distribution. The values are effectively drawn from the
48
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
49
+ with values outside :math:`[a, b]` redrawn until they are within
50
+ the bounds. The method used for generating the random values works
51
+ best when :math:`a \leq \text{mean} \leq b`.
52
+ Args:
53
+ tensor: an n-dimensional `torch.Tensor`
54
+ mean: the mean of the normal distribution
55
+ std: the standard deviation of the normal distribution
56
+ a: the minimum cutoff value
57
+ b: the maximum cutoff value
58
+ Examples:
59
+ >>> w = torch.empty(3, 5)
60
+ >>> nn.init.trunc_normal_(w)
61
+ """
62
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
63
+
64
+
65
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
66
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
67
+
68
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
69
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
70
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
71
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
72
+ 'survival rate' as the argument.
73
+
74
+ """
75
+ if drop_prob == 0. or not training:
76
+ return x
77
+ keep_prob = 1 - drop_prob
78
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
79
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
80
+ random_tensor.floor_() # binarize
81
+ output = x.div(keep_prob) * random_tensor
82
+ return output
83
+
84
+
85
+ class DropPath(nn.Module):
86
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
87
+ """
88
+ def __init__(self, drop_prob=None):
89
+ super(DropPath, self).__init__()
90
+ self.drop_prob = drop_prob
91
+
92
+ def forward(self, x):
93
+ return drop_path(x, self.drop_prob, self.training)
94
+
95
+
96
+ class Mlp(nn.Module):
97
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
98
+ super().__init__()
99
+ out_features = out_features or in_features
100
+ hidden_features = hidden_features or in_features
101
+ self.fc1 = nn.Linear(in_features, hidden_features)
102
+ self.act = act_layer()
103
+ self.fc2 = nn.Linear(hidden_features, out_features)
104
+ self.drop = nn.Dropout(drop)
105
+
106
+ def forward(self, x):
107
+ x = self.fc1(x)
108
+ x = self.act(x)
109
+ x = self.drop(x)
110
+ x = self.fc2(x)
111
+ x = self.drop(x)
112
+ return x
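
These three helpers vendored from timm 0.3.2 (trunc_normal_, DropPath, Mlp) are the only pieces the U-ViT definitions below depend on. A short sketch of their behaviour, with arbitrary shapes:

import torch
from libs.timm import trunc_normal_, DropPath, Mlp

w = trunc_normal_(torch.empty(3, 5), std=0.02)  # in-place init, values truncated to the default [a, b] = [-2, 2]

mlp = Mlp(in_features=768, hidden_features=3072)
x = torch.randn(4, 16, 768)
y = mlp(x)                  # same shape as x: [4, 16, 768]

drop = DropPath(drop_prob=0.1)
drop.train()                # DropPath is the identity in eval mode or when drop_prob == 0
out = x + drop(y)           # stochastic-depth residual update
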
libs/uvit_multi_post_ln.py ADDED
@@ -0,0 +1,277 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ from .timm import trunc_normal_, DropPath, Mlp
5
+ import einops
6
+ import torch.utils.checkpoint
7
+ import torch.nn.functional as F
8
+
9
+ if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
10
+ ATTENTION_MODE = 'flash'
11
+ else:
12
+ try:
13
+ import xformers
14
+ import xformers.ops
15
+ ATTENTION_MODE = 'xformers'
16
+ except ImportError:
17
+ ATTENTION_MODE = 'math'
18
+ print(f'attention mode is {ATTENTION_MODE}')
19
+
20
+
21
+ def timestep_embedding(timesteps, dim, max_period=10000):
22
+ """
23
+ Create sinusoidal timestep embeddings.
24
+
25
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
26
+ These may be fractional.
27
+ :param dim: the dimension of the output.
28
+ :param max_period: controls the minimum frequency of the embeddings.
29
+ :return: an [N x dim] Tensor of positional embeddings.
30
+ """
31
+ half = dim // 2
32
+ freqs = torch.exp(
33
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
34
+ ).to(device=timesteps.device)
35
+ args = timesteps[:, None].float() * freqs[None]
36
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
37
+ if dim % 2:
38
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
39
+ return embedding
40
+
41
+
42
+ def patchify(imgs, patch_size):
43
+ x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
44
+ return x
45
+
46
+
47
+ def unpatchify(x, in_chans):
48
+ patch_size = int((x.shape[2] // in_chans) ** 0.5)
49
+ h = w = int(x.shape[1] ** .5)
50
+ assert h * w == x.shape[1] and patch_size ** 2 * in_chans == x.shape[2]
51
+ x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
52
+ return x
53
+
54
+
55
+ def interpolate_pos_emb(pos_emb, old_shape, new_shape):
56
+ pos_emb = einops.rearrange(pos_emb, 'B (H W) C -> B C H W', H=old_shape[0], W=old_shape[1])
57
+ pos_emb = F.interpolate(pos_emb, new_shape, mode='bilinear')
58
+ pos_emb = einops.rearrange(pos_emb, 'B C H W -> B (H W) C')
59
+ return pos_emb
60
+
61
+
62
+ class Attention(nn.Module):
63
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
64
+ super().__init__()
65
+ self.num_heads = num_heads
66
+ head_dim = dim // num_heads
67
+ self.scale = qk_scale or head_dim ** -0.5
68
+
69
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
70
+ self.attn_drop = nn.Dropout(attn_drop)
71
+ self.proj = nn.Linear(dim, dim)
72
+ self.proj_drop = nn.Dropout(proj_drop)
73
+
74
+ def forward(self, x):
75
+ B, L, C = x.shape
76
+
77
+ qkv = self.qkv(x)
78
+ if ATTENTION_MODE == 'flash':
79
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
80
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
81
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
82
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
83
+ elif ATTENTION_MODE == 'xformers':
84
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
85
+ q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
86
+ x = xformers.ops.memory_efficient_attention(q, k, v)
87
+ x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
88
+ elif ATTENTION_MODE == 'math':
89
+ with torch.amp.autocast(device_type='cuda', enabled=False):
90
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
91
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
92
+ attn = (q @ k.transpose(-2, -1)) * self.scale
93
+ attn = attn.softmax(dim=-1)
94
+ attn = self.attn_drop(attn)
95
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
96
+ else:
97
+ raise NotImplementedError
98
+
99
+ x = self.proj(x)
100
+ x = self.proj_drop(x)
101
+ return x
102
+
103
+
104
+ class Block(nn.Module):
105
+
106
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
107
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
108
+ super().__init__()
109
+ self.norm1 = norm_layer(dim) if skip else None
110
+ self.norm2 = norm_layer(dim)
111
+
112
+ self.attn = Attention(
113
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
114
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
115
+ self.norm3 = norm_layer(dim)
116
+ mlp_hidden_dim = int(dim * mlp_ratio)
117
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
118
+ self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
119
+ self.use_checkpoint = use_checkpoint
120
+
121
+ def forward(self, x, skip=None):
122
+ if self.use_checkpoint:
123
+ return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
124
+ else:
125
+ return self._forward(x, skip)
126
+
127
+ def _forward(self, x, skip=None):
128
+ if self.skip_linear is not None:
129
+ x = self.skip_linear(torch.cat([x, skip], dim=-1))
130
+ x = self.norm1(x)
131
+ x = x + self.drop_path(self.attn(x))
132
+ x = self.norm2(x)
133
+
134
+ x = x + self.drop_path(self.mlp(x))
135
+ x = self.norm3(x)
136
+
137
+ return x
138
+
139
+
140
+ class PatchEmbed(nn.Module):
141
+ """ Image to Patch Embedding
142
+ """
143
+ def __init__(self, patch_size, in_chans=3, embed_dim=768):
144
+ super().__init__()
145
+ self.patch_size = patch_size
146
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
147
+
148
+ def forward(self, x):
149
+ B, C, H, W = x.shape
150
+ assert H % self.patch_size == 0 and W % self.patch_size == 0
151
+ x = self.proj(x).flatten(2).transpose(1, 2)
152
+ return x
153
+
154
+
155
+ class UViT(nn.Module):
156
+ def __init__(self, img_size, in_chans, patch_size, embed_dim=768, depth=12,
157
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, pos_drop_rate=0., drop_rate=0., attn_drop_rate=0.,
158
+ norm_layer=nn.LayerNorm, mlp_time_embed=False, use_checkpoint=False,
159
+ text_dim=None, num_text_tokens=None, clip_img_dim=None):
160
+ super().__init__()
161
+ self.in_chans = in_chans
162
+ self.patch_size = patch_size
163
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
164
+
165
+ self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
166
+ self.img_size = (img_size, img_size) if isinstance(img_size, int) else img_size # the default img size
167
+ assert self.img_size[0] % patch_size == 0 and self.img_size[1] % patch_size == 0
168
+ self.num_patches = (self.img_size[0] // patch_size) * (self.img_size[1] // patch_size)
169
+
170
+ self.time_img_embed = nn.Sequential(
171
+ nn.Linear(embed_dim, 4 * embed_dim),
172
+ nn.SiLU(),
173
+ nn.Linear(4 * embed_dim, embed_dim),
174
+ ) if mlp_time_embed else nn.Identity()
175
+
176
+ self.time_text_embed = nn.Sequential(
177
+ nn.Linear(embed_dim, 4 * embed_dim),
178
+ nn.SiLU(),
179
+ nn.Linear(4 * embed_dim, embed_dim),
180
+ ) if mlp_time_embed else nn.Identity()
181
+
182
+ self.text_embed = nn.Linear(text_dim, embed_dim)
183
+ self.text_out = nn.Linear(embed_dim, text_dim)
184
+
185
+ self.clip_img_embed = nn.Linear(clip_img_dim, embed_dim)
186
+ self.clip_img_out = nn.Linear(embed_dim, clip_img_dim)
187
+
188
+ self.num_text_tokens = num_text_tokens
189
+ self.num_tokens = 1 + 1 + num_text_tokens + 1 + self.num_patches
190
+
191
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
192
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
193
+
194
+ self.in_blocks = nn.ModuleList([
195
+ Block(
196
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
197
+ drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, use_checkpoint=use_checkpoint)
198
+ for _ in range(depth // 2)])
199
+
200
+ self.mid_block = Block(
201
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
202
+ drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, use_checkpoint=use_checkpoint)
203
+
204
+ self.out_blocks = nn.ModuleList([
205
+ Block(
206
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
207
+ drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, skip=True, use_checkpoint=use_checkpoint)
208
+ for _ in range(depth // 2)])
209
+
210
+ self.norm = norm_layer(embed_dim)
211
+ self.patch_dim = patch_size ** 2 * in_chans
212
+ self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True)
213
+
214
+ trunc_normal_(self.pos_embed, std=.02)
215
+ self.apply(self._init_weights)
216
+
217
+ def _init_weights(self, m):
218
+ if isinstance(m, nn.Linear):
219
+ trunc_normal_(m.weight, std=.02)
220
+ if isinstance(m, nn.Linear) and m.bias is not None:
221
+ nn.init.constant_(m.bias, 0)
222
+ elif isinstance(m, nn.LayerNorm):
223
+ nn.init.constant_(m.bias, 0)
224
+ nn.init.constant_(m.weight, 1.0)
225
+
226
+ @torch.jit.ignore
227
+ def no_weight_decay(self):
228
+ return {'pos_embed'}
229
+
230
+ def forward(self, img, clip_img, text, t_img, t_text):
231
+ _, _, H, W = img.shape
232
+
233
+ img = self.patch_embed(img)
234
+
235
+ t_img_token = self.time_img_embed(timestep_embedding(t_img, self.embed_dim))
236
+ t_img_token = t_img_token.unsqueeze(dim=1)
237
+ t_text_token = self.time_text_embed(timestep_embedding(t_text, self.embed_dim))
238
+ t_text_token = t_text_token.unsqueeze(dim=1)
239
+
240
+ text = self.text_embed(text)
241
+ clip_img = self.clip_img_embed(clip_img)
242
+ x = torch.cat((t_img_token, t_text_token, text, clip_img, img), dim=1)
243
+
244
+ num_text_tokens, num_img_tokens = text.size(1), img.size(1)
245
+
246
+ if H == self.img_size[0] and W == self.img_size[1]:
247
+ pos_embed = self.pos_embed
248
+ else: # interpolate the positional embedding when the input image is not of the default shape
249
+ pos_embed_others, pos_embed_patches = torch.split(self.pos_embed, [1 + 1 + num_text_tokens + 1, self.num_patches], dim=1)
250
+ pos_embed_patches = interpolate_pos_emb(pos_embed_patches, (self.img_size[0] // self.patch_size, self.img_size[1] // self.patch_size),
251
+ (H // self.patch_size, W // self.patch_size))
252
+ pos_embed = torch.cat((pos_embed_others, pos_embed_patches), dim=1)
253
+
254
+ x = x + pos_embed
255
+ x = self.pos_drop(x)
256
+
257
+ skips = []
258
+ for blk in self.in_blocks:
259
+ x = blk(x)
260
+ skips.append(x)
261
+
262
+ x = self.mid_block(x)
263
+
264
+ for blk in self.out_blocks:
265
+ x = blk(x, skips.pop())
266
+
267
+ x = self.norm(x)
268
+
269
+ t_img_token_out, t_text_token_out, text_out, clip_img_out, img_out = x.split((1, 1, num_text_tokens, 1, num_img_tokens), dim=1)
270
+
271
+ img_out = self.decoder_pred(img_out)
272
+ img_out = unpatchify(img_out, self.in_chans)
273
+
274
+ clip_img_out = self.clip_img_out(clip_img_out)
275
+
276
+ text_out = self.text_out(text_out)
277
+ return img_out, clip_img_out, text_out
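
Inside this UViT the token sequence is laid out as [t_img, t_text, text tokens, clip_img, image patches], i.e. num_tokens = 1 + 1 + num_text_tokens + 1 + num_patches, and the forward pass returns one denoising output per modality. A minimal instantiation sketch; the sizes below (64x64x4 latents, 77 text tokens of width 64, a 512-dim CLIP image feature, a small depth) are illustrative assumptions, not values taken from a config in this upload:

import torch
from libs.uvit_multi_post_ln import UViT

nnet = UViT(img_size=64, in_chans=4, patch_size=2, embed_dim=512, depth=4, num_heads=8,
            mlp_time_embed=True, text_dim=64, num_text_tokens=77, clip_img_dim=512)

img = torch.randn(2, 4, 64, 64)      # VAE latent
clip_img = torch.randn(2, 1, 512)    # CLIP image feature
text = torch.randn(2, 77, 64)        # low-dimensional text latent
t_img = torch.randint(0, 1000, (2,))
t_text = torch.randint(0, 1000, (2,))

img_out, clip_img_out, text_out = nnet(img, clip_img, text, t_img, t_text)
# img_out: [2, 4, 64, 64], clip_img_out: [2, 1, 512], text_out: [2, 77, 64]
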
libs/uvit_multi_post_ln_v1.py ADDED
@@ -0,0 +1,285 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+ from .timm import trunc_normal_, DropPath, Mlp
5
+ import einops
6
+ import torch.utils.checkpoint
7
+ import torch.nn.functional as F
8
+
9
+ if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
10
+ ATTENTION_MODE = 'flash'
11
+ else:
12
+ try:
13
+ import xformers
14
+ import xformers.ops
15
+ ATTENTION_MODE = 'xformers'
16
+ except ImportError:
17
+ ATTENTION_MODE = 'math'
18
+ print(f'attention mode is {ATTENTION_MODE}')
19
+
20
+
21
+ def timestep_embedding(timesteps, dim, max_period=10000):
22
+ """
23
+ Create sinusoidal timestep embeddings.
24
+
25
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
26
+ These may be fractional.
27
+ :param dim: the dimension of the output.
28
+ :param max_period: controls the minimum frequency of the embeddings.
29
+ :return: an [N x dim] Tensor of positional embeddings.
30
+ """
31
+ half = dim // 2
32
+ freqs = torch.exp(
33
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
34
+ ).to(device=timesteps.device)
35
+ args = timesteps[:, None].float() * freqs[None]
36
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
37
+ if dim % 2:
38
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
39
+ return embedding
40
+
41
+
42
+ def patchify(imgs, patch_size):
43
+ x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
44
+ return x
45
+
46
+
47
+ def unpatchify(x, in_chans):
48
+ patch_size = int((x.shape[2] // in_chans) ** 0.5)
49
+ h = w = int(x.shape[1] ** .5)
50
+ assert h * w == x.shape[1] and patch_size ** 2 * in_chans == x.shape[2]
51
+ x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
52
+ return x
53
+
54
+
55
+ def interpolate_pos_emb(pos_emb, old_shape, new_shape):
56
+ pos_emb = einops.rearrange(pos_emb, 'B (H W) C -> B C H W', H=old_shape[0], W=old_shape[1])
57
+ pos_emb = F.interpolate(pos_emb, new_shape, mode='bilinear')
58
+ pos_emb = einops.rearrange(pos_emb, 'B C H W -> B (H W) C')
59
+ return pos_emb
60
+
61
+
62
+ class Attention(nn.Module):
63
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
64
+ super().__init__()
65
+ self.num_heads = num_heads
66
+ head_dim = dim // num_heads
67
+ self.scale = qk_scale or head_dim ** -0.5
68
+
69
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
70
+ self.attn_drop = nn.Dropout(attn_drop)
71
+ self.proj = nn.Linear(dim, dim)
72
+ self.proj_drop = nn.Dropout(proj_drop)
73
+
74
+ def forward(self, x):
75
+ B, L, C = x.shape
76
+
77
+ qkv = self.qkv(x)
78
+ if ATTENTION_MODE == 'flash':
79
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
80
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
81
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
82
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
83
+ elif ATTENTION_MODE == 'xformers':
84
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
85
+ q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
86
+ x = xformers.ops.memory_efficient_attention(q, k, v)
87
+ x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
88
+ elif ATTENTION_MODE == 'math':
89
+ with torch.amp.autocast(device_type='cuda', enabled=False):
90
+ qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
91
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
92
+ attn = (q @ k.transpose(-2, -1)) * self.scale
93
+ attn = attn.softmax(dim=-1)
94
+ attn = self.attn_drop(attn)
95
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
96
+ else:
97
+ raise NotImplementedError
98
+
99
+ x = self.proj(x)
100
+ x = self.proj_drop(x)
101
+ return x
102
+
103
+
104
+ class Block(nn.Module):
105
+
106
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
107
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
108
+ super().__init__()
109
+ self.norm1 = norm_layer(dim) if skip else None
110
+ self.norm2 = norm_layer(dim)
111
+
112
+ self.attn = Attention(
113
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
114
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
115
+ self.norm3 = norm_layer(dim)
116
+ mlp_hidden_dim = int(dim * mlp_ratio)
117
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
118
+ self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
119
+ self.use_checkpoint = use_checkpoint
120
+
121
+ def forward(self, x, skip=None):
122
+ if self.use_checkpoint:
123
+ return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
124
+ else:
125
+ return self._forward(x, skip)
126
+
127
+ def _forward(self, x, skip=None):
128
+ if self.skip_linear is not None:
129
+ x = self.skip_linear(torch.cat([x, skip], dim=-1))
130
+ x = self.norm1(x)
131
+ x = x + self.drop_path(self.attn(x))
132
+ x = self.norm2(x)
133
+
134
+ x = x + self.drop_path(self.mlp(x))
135
+ x = self.norm3(x)
136
+
137
+ return x
138
+
139
+
140
+ class PatchEmbed(nn.Module):
141
+ """ Image to Patch Embedding
142
+ """
143
+ def __init__(self, patch_size, in_chans=3, embed_dim=768):
144
+ super().__init__()
145
+ self.patch_size = patch_size
146
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
147
+
148
+ def forward(self, x):
149
+ B, C, H, W = x.shape
150
+ assert H % self.patch_size == 0 and W % self.patch_size == 0
151
+ x = self.proj(x).flatten(2).transpose(1, 2)
152
+ return x
153
+
154
+
155
+ class UViT(nn.Module):
156
+ def __init__(self, img_size, in_chans, patch_size, embed_dim=768, depth=12,
157
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, pos_drop_rate=0., drop_rate=0., attn_drop_rate=0.,
158
+ norm_layer=nn.LayerNorm, mlp_time_embed=False, use_checkpoint=False,
159
+ text_dim=None, num_text_tokens=None, clip_img_dim=None):
160
+ super().__init__()
161
+ self.in_chans = in_chans
162
+ self.patch_size = patch_size
163
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
164
+
165
+ self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
166
+ self.img_size = (img_size, img_size) if isinstance(img_size, int) else img_size # the default img size
167
+ assert self.img_size[0] % patch_size == 0 and self.img_size[1] % patch_size == 0
168
+ self.num_patches = (self.img_size[0] // patch_size) * (self.img_size[1] // patch_size)
169
+
170
+ self.time_img_embed = nn.Sequential(
171
+ nn.Linear(embed_dim, 4 * embed_dim),
172
+ nn.SiLU(),
173
+ nn.Linear(4 * embed_dim, embed_dim),
174
+ ) if mlp_time_embed else nn.Identity()
175
+
176
+ self.time_text_embed = nn.Sequential(
177
+ nn.Linear(embed_dim, 4 * embed_dim),
178
+ nn.SiLU(),
179
+ nn.Linear(4 * embed_dim, embed_dim),
180
+ ) if mlp_time_embed else nn.Identity()
181
+
182
+ self.text_embed = nn.Linear(text_dim, embed_dim)
183
+ self.text_out = nn.Linear(embed_dim, text_dim)
184
+
185
+ self.clip_img_embed = nn.Linear(clip_img_dim, embed_dim)
186
+ self.clip_img_out = nn.Linear(embed_dim, clip_img_dim)
187
+
188
+ self.num_text_tokens = num_text_tokens
189
+ self.num_tokens = 1 + 1 + num_text_tokens + 1 + self.num_patches
190
+
191
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
192
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
193
+
194
+ self.in_blocks = nn.ModuleList([
195
+ Block(
196
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
197
+ drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, use_checkpoint=use_checkpoint)
198
+ for _ in range(depth // 2)])
199
+
200
+ self.mid_block = Block(
201
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
202
+ drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, use_checkpoint=use_checkpoint)
203
+
204
+ self.out_blocks = nn.ModuleList([
205
+ Block(
206
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
207
+ drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, skip=True, use_checkpoint=use_checkpoint)
208
+ for _ in range(depth // 2)])
209
+
210
+ self.norm = norm_layer(embed_dim)
211
+ self.patch_dim = patch_size ** 2 * in_chans
212
+ self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True)
213
+
214
+ trunc_normal_(self.pos_embed, std=.02)
215
+ self.apply(self._init_weights)
216
+
217
+ self.token_embedding = nn.Embedding(2, embed_dim)
218
+ self.pos_embed_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
219
+
220
+ def _init_weights(self, m):
221
+ if isinstance(m, nn.Linear):
222
+ trunc_normal_(m.weight, std=.02)
223
+ if isinstance(m, nn.Linear) and m.bias is not None:
224
+ nn.init.constant_(m.bias, 0)
225
+ elif isinstance(m, nn.LayerNorm):
226
+ nn.init.constant_(m.bias, 0)
227
+ nn.init.constant_(m.weight, 1.0)
228
+
229
+ @torch.jit.ignore
230
+ def no_weight_decay(self):
231
+ return {'pos_embed'}
232
+
233
+ def forward(self, img, clip_img, text, t_img, t_text, data_type):
234
+ _, _, H, W = img.shape
235
+
236
+ img = self.patch_embed(img)
237
+
238
+ t_img_token = self.time_img_embed(timestep_embedding(t_img, self.embed_dim))
239
+ t_img_token = t_img_token.unsqueeze(dim=1)
240
+ t_text_token = self.time_text_embed(timestep_embedding(t_text, self.embed_dim))
241
+ t_text_token = t_text_token.unsqueeze(dim=1)
242
+
243
+ text = self.text_embed(text)
244
+ clip_img = self.clip_img_embed(clip_img)
245
+
246
+ token_embed = self.token_embedding(data_type).unsqueeze(dim=1)
247
+
248
+ x = torch.cat((t_img_token, t_text_token, token_embed, text, clip_img, img), dim=1)
249
+
250
+ num_text_tokens, num_img_tokens = text.size(1), img.size(1)
251
+
252
+ pos_embed = torch.cat(
253
+ [self.pos_embed[:, :1 + 1, :], self.pos_embed_token, self.pos_embed[:, 1 + 1:, :]], dim=1)
254
+ if H == self.img_size[0] and W == self.img_size[1]:
255
+ pass
256
+ else: # interpolate the positional embedding when the input image is not of the default shape
257
+ pos_embed_others, pos_embed_patches = torch.split(pos_embed, [1 + 1 + 1 + num_text_tokens + 1, self.num_patches], dim=1)
258
+ pos_embed_patches = interpolate_pos_emb(pos_embed_patches, (self.img_size[0] // self.patch_size, self.img_size[1] // self.patch_size),
259
+ (H // self.patch_size, W // self.patch_size))
260
+ pos_embed = torch.cat((pos_embed_others, pos_embed_patches), dim=1)
261
+
262
+ x = x + pos_embed
263
+ x = self.pos_drop(x)
264
+
265
+ skips = []
266
+ for blk in self.in_blocks:
267
+ x = blk(x)
268
+ skips.append(x)
269
+
270
+ x = self.mid_block(x)
271
+
272
+ for blk in self.out_blocks:
273
+ x = blk(x, skips.pop())
274
+
275
+ x = self.norm(x)
276
+
277
+ t_img_token_out, t_text_token_out, token_embed_out, text_out, clip_img_out, img_out = x.split((1, 1, 1, num_text_tokens, 1, num_img_tokens), dim=1)
278
+
279
+ img_out = self.decoder_pred(img_out)
280
+ img_out = unpatchify(img_out, self.in_chans)
281
+
282
+ clip_img_out = self.clip_img_out(clip_img_out)
283
+
284
+ text_out = self.text_out(text_out)
285
+ return img_out, clip_img_out, text_out
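
The v1 variant differs from the previous file only by the extra data-type token: a two-entry nn.Embedding plus its own positional embedding are spliced in after the two timestep tokens, so the sequence becomes [t_img, t_text, data_type, text tokens, clip_img, image patches] and forward takes an additional per-sample integer data_type. A sketch of the call under the same assumed sizes as the previous example:

import torch
from libs.uvit_multi_post_ln_v1 import UViT

nnet = UViT(img_size=64, in_chans=4, patch_size=2, embed_dim=512, depth=4, num_heads=8,
            mlp_time_embed=True, text_dim=64, num_text_tokens=77, clip_img_dim=512)

img = torch.randn(2, 4, 64, 64)
clip_img = torch.randn(2, 1, 512)
text = torch.randn(2, 77, 64)
t_img = torch.randint(0, 1000, (2,))
t_text = torch.randint(0, 1000, (2,))
data_type = torch.zeros(2, dtype=torch.long)  # index into the two learned data-type embeddings

img_out, clip_img_out, text_out = nnet(img, clip_img, text, t_img, t_text, data_type)
# same output shapes as before: [2, 4, 64, 64], [2, 1, 512], [2, 77, 64]
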
unidiffuser/sample_v1.py ADDED
@@ -0,0 +1,437 @@
1
+ import ml_collections
2
+ import torch
3
+ import random
4
+ from absl import logging
5
+ import einops
6
+ from torchvision.utils import save_image, make_grid
7
+ import torchvision.transforms as standard_transforms
8
+ import numpy as np
9
+ import clip
10
+ from PIL import Image
11
+ import time
12
+ import os
13
+
14
+ from libs.autoencoder import get_model
15
+ from libs.clip import FrozenCLIPEmbedder
16
+ from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
17
+ from utils import center_crop, set_logger, get_nnet
18
+
19
+
20
+ def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
21
+ _betas = (
22
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
23
+ )
24
+ return _betas.numpy()
25
+
26
+
27
+ def prepare_contexts(config, clip_text_model, clip_img_model, clip_img_model_preprocess, autoencoder):
28
+ resolution = config.z_shape[-1] * 8
29
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
30
+
31
+ contexts = torch.randn(config.n_samples, 77, config.clip_text_dim).to(device)
32
+ img_contexts = torch.randn(config.n_samples, 2 * config.z_shape[0], config.z_shape[1], config.z_shape[2])
33
+ clip_imgs = torch.randn(config.n_samples, 1, config.clip_img_dim)
34
+
35
+ if config.mode in ['t2i', 't2i2t']:
36
+ prompts = [config.prompt] * config.n_samples
37
+ contexts = clip_text_model.encode(prompts)
38
+
39
+ elif config.mode in ['i2t', 'i2t2i']:
40
+ from PIL import Image
41
+ img_contexts = []
42
+ clip_imgs = []
43
+
44
+ def get_img_feature(image):
45
+ image = np.array(image).astype(np.uint8)
46
+ image = center_crop(resolution, resolution, image)
47
+ clip_img_feature = clip_img_model.encode_image(clip_img_model_preprocess(Image.fromarray(image)).unsqueeze(0).to(device))
48
+
49
+ image = (image / 127.5 - 1.0).astype(np.float32)
50
+ image = einops.rearrange(image, 'h w c -> 1 c h w')
51
+ image = torch.tensor(image, device=device)
52
+ moments = autoencoder.encode_moments(image)
53
+
54
+ return clip_img_feature, moments
55
+
56
+ image = Image.open(config.img).convert('RGB')
57
+ clip_img, img_context = get_img_feature(image)
58
+
59
+ img_contexts.append(img_context)
60
+ clip_imgs.append(clip_img)
61
+ img_contexts = img_contexts * config.n_samples
62
+ clip_imgs = clip_imgs * config.n_samples
63
+
64
+ img_contexts = torch.concat(img_contexts, dim=0)
65
+ clip_imgs = torch.stack(clip_imgs, dim=0)
66
+
67
+ return contexts, img_contexts, clip_imgs
68
+
69
+
70
+ def unpreprocess(v): # to B C H W and [0, 1]
71
+ v = 0.5 * (v + 1.)
72
+ v.clamp_(0., 1.)
73
+ return v
74
+
75
+
76
+ def set_seed(seed: int):
77
+ random.seed(seed)
78
+ np.random.seed(seed)
79
+ torch.manual_seed(seed)
80
+ torch.cuda.manual_seed_all(seed)
81
+
82
+
83
+ def evaluate(config):
84
+ if config.get('benchmark', False):
85
+ torch.backends.cudnn.benchmark = True
86
+ torch.backends.cudnn.deterministic = False
87
+
88
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
89
+ set_seed(config.seed)
90
+
91
+ config = ml_collections.FrozenConfigDict(config)
92
+ set_logger(log_level='info')
93
+
94
+ _betas = stable_diffusion_beta_schedule()
95
+ N = len(_betas)
96
+
97
+ nnet = get_nnet(**config.nnet)
98
+ logging.info(f'load nnet from {config.nnet_path}')
99
+ nnet.load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
100
+ nnet.to(device)
101
+ nnet.eval()
102
+
103
+ use_caption_decoder = config.text_dim < config.clip_text_dim or config.mode != 't2i'
104
+ if use_caption_decoder:
105
+ from libs.caption_decoder import CaptionDecoder
106
+ caption_decoder = CaptionDecoder(device=device, **config.caption_decoder)
107
+ else:
108
+ caption_decoder = None
109
+
110
+ clip_text_model = FrozenCLIPEmbedder(device=device)
111
+ clip_text_model.eval()
112
+ clip_text_model.to(device)
113
+
114
+ autoencoder = get_model(**config.autoencoder)
115
+ autoencoder.to(device)
116
+
117
+ clip_img_model, clip_img_model_preprocess = clip.load("ViT-B/32", device=device, jit=False)
118
+
119
+ empty_context = clip_text_model.encode([''])[0]
120
+
121
+ def split(x):
122
+ C, H, W = config.z_shape
123
+ z_dim = C * H * W
124
+ z, clip_img = x.split([z_dim, config.clip_img_dim], dim=1)
125
+ z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
126
+ clip_img = einops.rearrange(clip_img, 'B (L D) -> B L D', L=1, D=config.clip_img_dim)
127
+ return z, clip_img
128
+
129
+
130
+ def combine(z, clip_img):
131
+ z = einops.rearrange(z, 'B C H W -> B (C H W)')
132
+ clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
133
+ return torch.concat([z, clip_img], dim=-1)
134
+
135
+
136
+ def t2i_nnet(x, timesteps, text): # text is the low dimension version of the text clip embedding
137
+ """
138
+ 1. calculate the conditional model output
139
+ 2. calculate unconditional model output
140
+ config.sample.t2i_cfg_mode == 'empty_token': using the original cfg with the empty string
141
+ config.sample.t2i_cfg_mode == 'true_uncond': using the unconditional model learned by our method
142
+ 3. return linear combination of conditional output and unconditional output
143
+ """
144
+ z, clip_img = split(x)
145
+
146
+ t_text = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)
147
+
148
+ z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=t_text,
149
+ data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + config.data_type)
150
+ x_out = combine(z_out, clip_img_out)
151
+
152
+ if config.sample.scale == 0.:
153
+ return x_out
154
+
155
+ if config.sample.t2i_cfg_mode == 'empty_token':
156
+ _empty_context = einops.repeat(empty_context, 'L D -> B L D', B=x.size(0))
157
+ if use_caption_decoder:
158
+ _empty_context = caption_decoder.encode_prefix(_empty_context)
159
+ z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(z, clip_img, text=_empty_context, t_img=timesteps, t_text=t_text,
160
+ data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + config.data_type)
161
+ x_out_uncond = combine(z_out_uncond, clip_img_out_uncond)
162
+ elif config.sample.t2i_cfg_mode == 'true_uncond':
163
+ text_N = torch.randn_like(text) # 3 other possible choices
164
+ z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(z, clip_img, text=text_N, t_img=timesteps, t_text=torch.ones_like(timesteps) * N,
165
+ data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + config.data_type)
166
+ x_out_uncond = combine(z_out_uncond, clip_img_out_uncond)
167
+ else:
168
+ raise NotImplementedError
169
+
170
+ return x_out + config.sample.scale * (x_out - x_out_uncond)
171
+
172
+
173
+ def i_nnet(x, timesteps):
174
+ z, clip_img = split(x)
175
+ text = torch.randn(x.size(0), 77, config.text_dim, device=device)
176
+ t_text = torch.ones_like(timesteps) * N
177
+ z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=t_text,
178
+ data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + config.data_type)
179
+ x_out = combine(z_out, clip_img_out)
180
+ return x_out
181
+
182
+ def t_nnet(x, timesteps):
183
+ z = torch.randn(x.size(0), *config.z_shape, device=device)
184
+ clip_img = torch.randn(x.size(0), 1, config.clip_img_dim, device=device)
185
+ z_out, clip_img_out, text_out = nnet(z, clip_img, text=x, t_img=torch.ones_like(timesteps) * N, t_text=timesteps,
186
+ data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + config.data_type)
187
+ return text_out
188
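+ # i_nnet and t_nnet implement unconditional image-only / text-only generation: the other modality
+ # is fed as pure noise at the maximum timestep N, which the joint model treats as "no condition".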
+
189
+ def i2t_nnet(x, timesteps, z, clip_img):
190
+ """
191
+ 1. calculate the conditional model output
192
+ 2. calculate unconditional model output
193
+ 3. return linear combination of conditional output and unconditional output
194
+ """
195
+ t_img = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)
196
+
197
+ z_out, clip_img_out, text_out = nnet(z, clip_img, text=x, t_img=t_img, t_text=timesteps,
198
+ data_type=torch.zeros_like(t_img, device=device, dtype=torch.int) + config.data_type)
199
+
200
+ if config.sample.scale == 0.:
201
+ return text_out
202
+
203
+ z_N = torch.randn_like(z) # 3 other possible choices
204
+ clip_img_N = torch.randn_like(clip_img)
205
+ z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(z_N, clip_img_N, text=x, t_img=torch.ones_like(timesteps) * N, t_text=timesteps,
206
+ data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + config.data_type)
207
+
208
+ return text_out + config.sample.scale * (text_out - text_out_uncond)
209
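+ # i2t_nnet mirrors t2i_nnet for the image-to-text direction: the conditioning image branch stays
+ # clean at t_img=0 while the text branch diffuses, and guidance extrapolates against a pass whose
+ # image inputs are replaced with noise at timestep N.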
+
210
+ def split_joint(x):
211
+ C, H, W = config.z_shape
212
+ z_dim = C * H * W
213
+ z, clip_img, text = x.split([z_dim, config.clip_img_dim, 77 * config.text_dim], dim=1)
214
+ z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
215
+ clip_img = einops.rearrange(clip_img, 'B (L D) -> B L D', L=1, D=config.clip_img_dim)
216
+ text = einops.rearrange(text, 'B (L D) -> B L D', L=77, D=config.text_dim)
217
+ return z, clip_img, text
218
+
219
+ def combine_joint(z, clip_img, text):
220
+ z = einops.rearrange(z, 'B C H W -> B (C H W)')
221
+ clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
222
+ text = einops.rearrange(text, 'B L D -> B (L D)')
223
+ return torch.concat([z, clip_img, text], dim=-1)
224
+
225
+ def joint_nnet(x, timesteps):
226
+ z, clip_img, text = split_joint(x)
227
+ z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=timesteps,
228
+ data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + config.data_type)
229
+ x_out = combine_joint(z_out, clip_img_out, text_out)
230
+
231
+ if config.sample.scale == 0.:
232
+ return x_out
233
+
234
+ z_noise = torch.randn(x.size(0), *config.z_shape, device=device)
235
+ clip_img_noise = torch.randn(x.size(0), 1, config.clip_img_dim, device=device)
236
+ text_noise = torch.randn(x.size(0), 77, config.text_dim, device=device)
237
+
238
+ _, _, text_out_uncond = nnet(z_noise, clip_img_noise, text=text, t_img=torch.ones_like(timesteps) * N, t_text=timesteps,
239
+ data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + config.data_type)
240
+ z_out_uncond, clip_img_out_uncond, _ = nnet(z, clip_img, text=text_noise, t_img=timesteps, t_text=torch.ones_like(timesteps) * N,
241
+ data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + config.data_type)
242
+
243
+ x_out_uncond = combine_joint(z_out_uncond, clip_img_out_uncond, text_out_uncond)
244
+
245
+ return x_out + config.sample.scale * (x_out - x_out_uncond)
246
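+ # Joint-mode guidance runs two extra passes: one replaces the image inputs with noise at timestep N
+ # to get a text prediction that ignores the image, the other replaces the text with noise to get an
+ # image prediction that ignores the text; combined they form the x_out_uncond used above.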
+
247
+ @torch.cuda.amp.autocast()
248
+ def encode(_batch):
249
+ return autoencoder.encode(_batch)
250
+
251
+ @torch.cuda.amp.autocast()
252
+ def decode(_batch):
253
+ return autoencoder.decode(_batch)
254
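+ # encode()/decode() wrap the KL autoencoder under CUDA autocast (mixed precision) to reduce memory;
+ # decode() maps sampled latents z back to pixel space.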
+
255
+
256
+ logging.info(config.sample)
257
+ logging.info(f'N={N}')
258
+
259
+ contexts, img_contexts, clip_imgs = prepare_contexts(config, clip_text_model, clip_img_model, clip_img_model_preprocess, autoencoder)
260
+
261
+ # contexts: the clip embedding of the conditioned texts (from prepare_contexts)
262
+ contexts_low_dim = contexts if not use_caption_decoder else caption_decoder.encode_prefix(contexts) # the low dimensional version of the contexts, which is the input to the nnet
263
+
264
+ # img_contexts is the autoencoder moment of the conditioned images
265
+ z_img = autoencoder.sample(img_contexts)
266
+ # clip_imgs: the clip embedding of the conditioned images
267
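+ # When use_caption_decoder is enabled, the 768-dim CLIP text embeddings are projected down to
+ # config.text_dim = 64 per token via caption_decoder.encode_prefix before entering the nnet;
+ # caption_decoder.generate_captions later decodes generated 64-dim text features back into strings.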
+
268
+ if config.mode in ['t2i', 't2i2t']:
269
+ _n_samples = contexts_low_dim.size(0)
270
+ elif config.mode in ['i2t', 'i2t2i']:
271
+ _n_samples = img_contexts.size(0)
272
+ else:
273
+ _n_samples = config.n_samples
274
+
275
+
276
+ def sample_fn(mode, **kwargs):
277
+
278
+ _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
279
+ _clip_img_init = torch.randn(_n_samples, 1, config.clip_img_dim, device=device)
280
+ _text_init = torch.randn(_n_samples, 77, config.text_dim, device=device)
281
+ if mode == 'joint':
282
+ _x_init = combine_joint(_z_init, _clip_img_init, _text_init)
283
+ elif mode in ['t2i', 'i']:
284
+ _x_init = combine(_z_init, _clip_img_init)
285
+ elif mode in ['i2t', 't']:
286
+ _x_init = _text_init
287
+ noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())
288
+
289
+ def model_fn(x, t_continuous):
290
+ t = t_continuous * N
291
+ if mode == 'joint':
292
+ return joint_nnet(x, t)
293
+ elif mode == 't2i':
294
+ return t2i_nnet(x, t, **kwargs)
295
+ elif mode == 'i2t':
296
+ return i2t_nnet(x, t, **kwargs)
297
+ elif mode == 'i':
298
+ return i_nnet(x, t)
299
+ elif mode == 't':
300
+ return t_nnet(x, t)
301
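+ # DPM-Solver works in continuous time on (0, 1]; model_fn rescales t_continuous by N to the
+ # discrete step index expected by the nnet. The call below integrates from T=1 down to eps=1/N
+ # in config.sample.sample_steps solver steps.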
+
302
+ dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
303
+ with torch.no_grad():
304
+ with torch.autocast(device_type=device):
305
+ start_time = time.time()
306
+ x = dpm_solver.sample(_x_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)
307
+ end_time = time.time()
308
+ print(f'\ngenerate {_n_samples} samples with {config.sample.sample_steps} steps takes {end_time - start_time:.2f}s')
309
+
310
+ # os.makedirs(config.output_path, exist_ok=True)
311
+ if mode == 'joint':
312
+ _z, _clip_img, _text = split_joint(x)
313
+ return _z, _clip_img, _text
314
+ elif mode in ['t2i', 'i']:
315
+ _z, _clip_img = split(x)
316
+ return _z, _clip_img
317
+ elif mode in ['i2t', 't']:
318
+ return x
319
+
320
+ output_images = None
321
+ output_text = None
322
+
323
+ if config.mode in ['joint']:
324
+ _z, _clip_img, _text = sample_fn(config.mode)
325
+ samples = unpreprocess(decode(_z))
326
+ prompts = caption_decoder.generate_captions(_text)
327
+ # Just get the first output image for now
328
+ output_images = samples
329
+ output_text = prompts
330
+
331
+ elif config.mode in ['t2i', 'i', 'i2t2i']:
332
+ if config.mode == 't2i':
333
+ _z, _clip_img = sample_fn(config.mode, text=contexts_low_dim) # conditioned on the text embedding
334
+ elif config.mode == 'i':
335
+ _z, _clip_img = sample_fn(config.mode)
336
+ elif config.mode == 'i2t2i':
337
+ _text = sample_fn('i2t', z=z_img, clip_img=clip_imgs) # conditioned on the image embedding
338
+ _z, _clip_img = sample_fn('t2i', text=_text)
339
+ samples = unpreprocess(decode(_z))
340
+ output_images = samples
341
+
342
+
343
+ elif config.mode in ['i2t', 't', 't2i2t']:
344
+ if config.mode == 'i2t':
345
+ _text = sample_fn(config.mode, z=z_img, clip_img=clip_imgs) # conditioned on the image embedding
346
+ elif config.mode == 't':
347
+ _text = sample_fn(config.mode)
348
+ elif config.mode == 't2i2t':
349
+ _z, _clip_img = sample_fn('t2i', text=contexts_low_dim)
350
+ _text = sample_fn('i2t', z=_z, clip_img=_clip_img)
351
+ samples = caption_decoder.generate_captions(_text)
352
+ output_text = samples
353
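+ # The round-trip modes chain two samplers: 'i2t2i' first captions the input image (i2t) and then
+ # regenerates an image from that caption (t2i); 't2i2t' first generates an image from the prompt
+ # and then captions it. Only the final modality is returned.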
+
354
+ print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB')
355
+ print(f'\nresults are saved in {os.path.join(config.output_path, config.mode)} :)')
356
+
357
+ # Convert sample images to PIL
358
+ if output_images is not None:
359
+ # convert each sample tensor to a PIL image and assign the result back so it is actually returned
360
+ output_images = [standard_transforms.ToPILImage()(sample) for sample in output_images]
361
+
362
+ return output_images, output_text
363
+
364
+
365
+
366
+ def d(**kwargs):
367
+ """Helper for creating a config dict."""
368
+ return ml_collections.ConfigDict(initial_dictionary=kwargs)
369
+
370
+
371
+ def get_config():
372
+ config = ml_collections.ConfigDict()
373
+
374
+ config.seed = 1234
375
+ config.pred = 'noise_pred'
376
+ config.z_shape = (4, 64, 64)
377
+ config.clip_img_dim = 512
378
+ config.clip_text_dim = 768
379
+ config.text_dim = 64 # reduce dimension
380
+ config.data_type = 1
381
+
382
+ config.autoencoder = d(
383
+ pretrained_path='models/autoencoder_kl.pth',
384
+ )
385
+
386
+ config.caption_decoder = d(
387
+ pretrained_path="models/caption_decoder.pth",
388
+ hidden_dim=config.get_ref('text_dim')
389
+ )
390
+
391
+ config.nnet = d(
392
+ name='uvit_multi_post_ln_v1',
393
+ img_size=64,
394
+ in_chans=4,
395
+ patch_size=2,
396
+ embed_dim=1536,
397
+ depth=30,
398
+ num_heads=24,
399
+ mlp_ratio=4,
400
+ qkv_bias=False,
401
+ pos_drop_rate=0.,
402
+ drop_rate=0.,
403
+ attn_drop_rate=0.,
404
+ mlp_time_embed=False,
405
+ text_dim=config.get_ref('text_dim'),
406
+ num_text_tokens=77,
407
+ clip_img_dim=config.get_ref('clip_img_dim'),
408
+ use_checkpoint=True
409
+ )
410
+
411
+ config.sample = d(
412
+ sample_steps=50,
413
+ scale=7.,
414
+ t2i_cfg_mode='true_uncond'
415
+ )
416
+
417
+ return config
418
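+ # Note: text_dim = 64 is the reduced per-token text width consumed by the U-ViT (see encode_prefix
+ # above), while clip_text_dim = 768 is the raw CLIP text embedding width. z_shape = (4, 64, 64) is
+ # the KL autoencoder latent; assuming the usual 8x-downsampling autoencoder, this corresponds to
+ # 512x512 RGB images.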
+
419
+
420
+ def sample(mode, prompt, image, sample_steps=50, scale=7.0, seed=None):
421
+ config = get_config()
422
+
423
+ config.nnet_path = "models/uvit_v1.pth"
424
+ config.n_samples = 1
425
+ config.nrow = 1
426
+
427
+ config.mode = mode
428
+ config.prompt = prompt
429
+ config.img = image
430
+
431
+ config.sample.sample_steps = sample_steps
432
+ config.sample.scale = scale
433
+ if seed is not None:
434
+ config.seed = seed
435
+
436
+ return evaluate(config) # hand (output_images, output_text) back to the caller
437
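+ # Hypothetical usage sketch (assumes the checkpoints referenced in get_config() exist under models/
+ # and that evaluate() / prepare_contexts() are defined earlier in this module). The prompt and the
+ # image path below are illustrative placeholders only:
+ #
+ #     images, texts = sample('t2i', prompt='an astronaut riding a horse', image=None)
+ #     _, captions = sample('i2t', prompt=None, image='path/to/input.jpg')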
+
utils.py ADDED
@@ -0,0 +1,80 @@
1
+ from absl import logging
2
+ import numpy as np
3
+ from PIL import Image, ImageDraw, ImageFont
4
+
5
+
6
+ def center_crop(width, height, img):
7
+ resample = Image.LANCZOS # resampling filter; Image.BOX would be the other supported choice here
8
+ crop = np.min(img.shape[:2])
9
+ img = img[(img.shape[0] - crop) // 2: (img.shape[0] + crop) // 2,
10
+ (img.shape[1] - crop) // 2: (img.shape[1] + crop) // 2] # center crop
11
+ try:
12
+ img = Image.fromarray(img, 'RGB')
13
+ except Exception: # fall back when the array cannot be read as 3-channel RGB (e.g. grayscale input)
14
+ img = Image.fromarray(img)
15
+ img = img.resize((width, height), resample) # resize the center crop from [crop, crop] to [width, height]
16
+
17
+ return np.array(img).astype(np.uint8)
18
+
19
+
20
+ def set_logger(log_level='info', fname=None):
21
+ import logging as _logging
22
+ handler = logging.get_absl_handler()
23
+ formatter = _logging.Formatter('%(asctime)s - %(filename)s - %(message)s')
24
+ handler.setFormatter(formatter)
25
+ logging.set_verbosity(log_level)
26
+ if fname is not None:
27
+ handler = _logging.FileHandler(fname)
28
+ handler.setFormatter(formatter)
29
+ logging.get_absl_logger().addHandler(handler)
30
+
31
+
32
+ def get_nnet(name, **kwargs):
33
+ if name == 'uvit_multi_post_ln':
34
+ from libs.uvit_multi_post_ln import UViT
35
+ return UViT(**kwargs)
36
+ elif name == 'uvit_multi_post_ln_v1':
37
+ from libs.uvit_multi_post_ln_v1 import UViT
38
+ return UViT(**kwargs)
39
+ else:
40
+ raise NotImplementedError(name)
41
+
42
+
43
+ def drawRoundRec(draw, color, x, y, w, h, r):
44
+ drawObject = draw
45
+
46
+ # draw the four corner circles (diameter r)
47
+ drawObject.ellipse((x, y, x + r, y + r), fill=color)
48
+ drawObject.ellipse((x + w - r, y, x + w, y + r), fill=color)
49
+ drawObject.ellipse((x, y + h - r, x + r, y + h), fill=color)
50
+ drawObject.ellipse((x + w - r, y + h - r, x + w, y + h), fill=color)
51
+
52
+ # fill the body with two overlapping rectangles
53
+ drawObject.rectangle((x + r / 2, y, x + w - (r / 2), y + h), fill=color)
54
+ drawObject.rectangle((x, y + r / 2, x + w, y + h - (r / 2)), fill=color)
55
+
56
+
57
+ def add_water(img, text='UniDiffuser', pos=3):
58
+ width, height = img.size
59
+ scale = 4
60
+ scale_size = 0.5
61
+ img = img.resize((width * scale, height * scale), Image.LANCZOS)
62
+ result = Image.new(img.mode, (width * scale, height * scale), color=(255, 255, 255))
63
+ result.paste(img, box=(0, 0))
64
+
65
+ delta_w = int(width * scale * 0.27 * scale_size) # text width
66
+ delta_h = width * scale * 0.05 * scale_size # text height
67
+ postions = np.array([[0, 0], [0, height * scale - delta_h], [width * scale - delta_w, 0],
68
+ [width * scale - delta_w, height * scale - delta_h]])
69
+ postion = postions[pos]
70
+ # draw the watermark text
71
+ draw = ImageDraw.Draw(result)
72
+ fillColor = (107, 92, 231)
73
+ setFont = ImageFont.truetype("assets/ArialBoldMT.ttf", int(width * scale * 0.05 * scale_size))
74
+ delta = 20 * scale_size
75
+ padding = 15 * scale_size
76
+ drawRoundRec(draw, (223, 230, 233), postion[0] - delta - padding, postion[1] - delta - padding,
77
+ w=delta_w + 2 * padding, h=delta_h + 2 * padding, r=50 * scale_size)
78
+ draw.text((postion[0] - delta, postion[1] - delta), text, font=setFont, fill=fillColor)
79
+
80
+ return result.resize((width, height), Image.LANCZOS)
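+ # add_water() stamps a small rounded label (default text 'UniDiffuser') into one corner of the
+ # image: pos indexes the corners as 0=top-left, 1=bottom-left, 2=top-right, 3=bottom-right;
+ # drawing happens at 4x resolution (scale=4) and the result is resized back down for smoother text.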