Upload 9 files
Add unidiffuser original code since I can't figure out how to package it correctly
- dpm_solver_pp.py +952 -0
- libs/autoencoder.py +519 -0
- libs/caption_decoder.py +283 -0
- libs/clip.py +38 -0
- libs/timm.py +112 -0
- libs/uvit_multi_post_ln.py +277 -0
- libs/uvit_multi_post_ln_v1.py +285 -0
- unidiffuser/sample_v1.py +437 -0
- utils.py +80 -0
dpm_solver_pp.py
ADDED
@@ -0,0 +1,952 @@
import torch
import torch.nn.functional as F
import math
import numpy as np
import torch.distributed as dist


def interpolate_fn(x: torch.Tensor, xp: torch.Tensor, yp: torch.Tensor) -> torch.Tensor:
    """Performs piecewise linear interpolation for x, using xp and yp keypoints (knots).
    Performs separate interpolation for each channel.
    Args:
        x: [N, C] points to be calibrated (interpolated). Batch with C channels.
        xp: [C, K] x coordinates of the PWL knots. C is the number of channels, K is the number of knots.
        yp: [C, K] y coordinates of the PWL knots. C is the number of channels, K is the number of knots.
    Returns:
        Interpolated points of the shape [N, C].
    The piecewise linear function extends for the whole x axis (the outermost keypoints define the outermost
    infinite lines).
    For example:
        >>> calibrate1d(torch.tensor([[0.5]]), torch.tensor([[0.0, 1.0]]), torch.tensor([[0.0, 2.0]]))
        tensor([[1.0000]])
        >>> calibrate1d(torch.tensor([[-10]]), torch.tensor([[0.0, 1.0]]), torch.tensor([[0.0, 2.0]]))
        tensor([[-20.0000]])
    """
    x_breakpoints = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((x.shape[0], 1, 1))], dim=2)
    num_x_points = xp.shape[1]
    sorted_x_breakpoints, x_indices = torch.sort(x_breakpoints, dim=2)
    x_idx = torch.argmin(x_indices, dim=2)
    cand_start_idx = x_idx - 1
    start_idx = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(x_idx, num_x_points), torch.tensor(num_x_points - 2, device=x.device), cand_start_idx,
        ),
    )
    end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
    start_x = torch.gather(sorted_x_breakpoints, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
    end_x = torch.gather(sorted_x_breakpoints, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
    start_idx2 = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(x_idx, num_x_points), torch.tensor(num_x_points - 2, device=x.device), cand_start_idx,
        ),
    )
    y_positions_expanded = yp.unsqueeze(0).expand(x.shape[0], -1, -1)
    start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
    end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
    return cand

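A quick sanity check of `interpolate_fn` against its own docstring (illustrative values): it interpolates inside the knot range and extrapolates linearly outside it.

xp = torch.tensor([[0.0, 1.0]])  # [C=1, K=2] knot x-coordinates
yp = torch.tensor([[0.0, 2.0]])  # [C=1, K=2] knot y-coordinates
print(interpolate_fn(torch.tensor([[0.5]]), xp, yp))    # tensor([[1.0000]])
print(interpolate_fn(torch.tensor([[-10.0]]), xp, yp))  # tensor([[-20.0000]]) -- linear extrapolation
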
class NoiseScheduleVP:
    def __init__(self, schedule='discrete', beta_0=1e-4, beta_1=2e-2, total_N=1000, betas=None, alphas_cumprod=None):
        """Create a wrapper class for the forward SDE (VP type).

        The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:

            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t = self.marginal_std(t)
            lambda_t = self.marginal_lambda(t)

        Moreover, as lambda(t) is an invertible function, we also support its inverse function:

            t = self.inverse_lambda(lambda_t)

        ===============================================================

        We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
        schedule are the default settings in DDPM and improved-DDPM:

            beta_min: A `float` number. The smallest beta for the linear schedule.
            beta_max: A `float` number. The largest beta for the linear schedule.
            cosine_s: A `float` number. The hyperparameter in the cosine schedule.
            cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
            T: A `float` number. The ending time of the forward process.

        Note that the original DDPM (linear schedule) used discrete-time labels (0 to 999). We convert the discrete-time
        labels to continuous time (following Song et al., 2021), so the beta here is 1000x larger than that in DDPM.

        ===============================================================

        Args:
            schedule: A `str`. The noise schedule of the forward SDE ('discrete', 'linear' or 'cosine').

        Returns:
            A wrapper object of the forward SDE (VP type).
        """
        if schedule not in ['linear', 'discrete', 'cosine']:
            raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete', 'linear' or 'cosine'".format(schedule))
        self.total_N = total_N
        self.beta_0 = beta_0 * 1000.
        self.beta_1 = beta_1 * 1000.

        if schedule == 'discrete':
            if betas is not None:
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.t_discrete = torch.linspace(1. / self.total_N, 1., self.total_N).reshape((1, -1))
            self.log_alpha_discrete = log_alphas.reshape((1, -1))

        self.cosine_s = 0.008
        self.cosine_beta_max = 999.
        self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
        self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
        self.schedule = schedule
        if schedule == 'cosine':
            # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
            # Note that T = 0.9946 may not be the optimal setting. However, we find it works well.
            self.T = 0.9946
        else:
            self.T = 1.

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        if self.schedule == 'linear':
            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == 'discrete':
            return interpolate_fn(t.reshape((-1, 1)), self.t_discrete.clone().to(t.device), self.log_alpha_discrete.clone().to(t.device)).reshape((-1,))
        elif self.schedule == 'cosine':
            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t
        else:
            raise ValueError("Unsupported noise schedule {}".format(self.schedule))

    def marginal_alpha(self, t):
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == 'linear':
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == 'discrete':
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_discrete.clone().to(lamb.device), [1]), torch.flip(self.t_discrete.clone().to(lamb.device), [1]))
            return t.reshape((-1,))
        else:
            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
            t = t_fn(log_alpha)
            return t

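A minimal usage sketch for `NoiseScheduleVP` (the linear betas below are an assumption in the style of DDPM, not taken from this repo): `inverse_lambda` should invert `marginal_lambda` up to interpolation error.

betas = torch.linspace(1e-4, 2e-2, 1000)               # assumed DDPM-style beta schedule
ns = NoiseScheduleVP(schedule='discrete', betas=betas)
t = torch.tensor([0.5])
lam = ns.marginal_lambda(t)                            # half-logSNR at t
print(ns.inverse_lambda(lam))                          # ~tensor([0.5000]), up to interpolation error
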
def model_wrapper(model, noise_schedule=None, is_cond_classifier=False, classifier_fn=None, classifier_scale=1., time_input_type='1', total_N=1000, model_kwargs={}, is_deis=False):
    """Create a wrapper function for the noise prediction model.

    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
    first wrap the model function to a function that accepts the continuous time as the input.

    The input `model` has the following format:

        ``
        model(x, t_input, **model_kwargs) -> noise
        ``

    where `x` and `noise` have the same shape, and `t_input` is the time label of the model
    (which may be discrete-time labels (i.e. 0 to 999) or continuous-time labels (i.e. epsilon to T)).

    We wrap the model function to the following format:

        ``
        def model_fn(x, t_continuous) -> noise:
            t_input = get_model_input_time(t_continuous)
            return model(x, t_input, **model_kwargs)
        ``

    where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.

    For DPMs with classifier guidance, we also combine the model output with the classifier gradient as used in [1].

    [1] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," in Advances in Neural
    Information Processing Systems, vol. 34, 2021, pp. 8780-8794.

    ===============================================================

    Args:
        model: A noise prediction model with the following format:
            ``
            def model(x, t_input, **model_kwargs):
                return noise
            ``
        noise_schedule: A noise schedule object, such as NoiseScheduleVP. Only used for the classifier guidance.
        is_cond_classifier: A `bool`. Whether to use the classifier guidance.
        classifier_fn: A classifier function. Only used for the classifier guidance. The format is:
            ``
            def classifier_fn(x, t_input):
                return logits
            ``
        classifier_scale: A `float`. The scale for the classifier guidance.
        time_input_type: A `str`. The type for the time input of the model. We support three types:
            - '0': The continuous-time type. In this case, the model is trained on the continuous time,
                so `t_input` = `t_continuous`.
            - '1': The Type-1 discrete type described in the Appendix of the DPM-Solver paper.
                **For discrete-time DPMs, we recommend using this type for DPM-Solver**.
            - '2': The Type-2 discrete type described in the Appendix of the DPM-Solver paper.
        total_N: An `int`. The total number of the discrete-time DPMs (default is 1000), used when `time_input_type`
            is '1' or '2'.
        model_kwargs: A `dict`. A dict for the other inputs of the model function.
    Returns:
        A function that accepts the continuous time as the input, with the following format:
            ``
            def model_fn(x, t_continuous):
                t_input = get_model_input_time(t_continuous)
                return model(x, t_input, **model_kwargs)
            ``
    """
    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        """
        if time_input_type == '0':
            # time_input_type == '0' means that the model is a continuous-time model.
            # For continuous-time DPMs, the continuous time equals the discrete time.
            return t_continuous
        elif time_input_type == '1':
            # Type-1 discrete label, as detailed in the Appendix of DPM-Solver.
            return 1000. * torch.max(t_continuous - 1. / total_N, torch.zeros_like(t_continuous).to(t_continuous))
        elif time_input_type == '2':
            # Type-2 discrete label, as detailed in the Appendix of DPM-Solver.
            max_N = (total_N - 1) / total_N * 1000.
            return max_N * t_continuous
        else:
            raise ValueError("Unsupported time input type {}, must be '0' or '1' or '2'".format(time_input_type))

    def cond_fn(x, t_discrete, y):
        """
        Compute the gradient of the classifier, multiplied with the scale of the classifier guidance.
        """
        assert y is not None
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            logits = classifier_fn(x_in, t_discrete)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), y.view(-1)]
            return classifier_scale * torch.autograd.grad(selected.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = torch.ones((x.shape[0],)).to(x.device) * t_continuous
        if is_cond_classifier:
            y = model_kwargs.get("y", None)
            if y is None:
                raise ValueError("For classifier guidance, the label y has to be in the input.")
            t_discrete = get_model_input_time(t_continuous)
            noise_uncond = model(x, t_discrete, **model_kwargs)
            cond_grad = cond_fn(x, t_discrete, y)
            if is_deis:
                sigma_t = noise_schedule.marginal_std(t_continuous / 1000.)
            else:
                sigma_t = noise_schedule.marginal_std(t_continuous)
            dims = len(cond_grad.shape) - 1
            return noise_uncond - sigma_t[(...,) + (None,)*dims] * cond_grad
        else:
            t_discrete = get_model_input_time(t_continuous)
            return model(x, t_discrete, **model_kwargs)

    return model_fn

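A sketch of how a discrete-time noise-prediction network would be wrapped (`unet` here is a hypothetical stand-in, not defined in this file):

# Assumed interface: unet(x, t_input, **kwargs) -> predicted noise, with t_input in [0, 999].
model_fn = model_wrapper(unet, noise_schedule=ns, time_input_type='1', total_N=1000)
# model_fn(x, t_continuous) now accepts continuous time in [eps, 1] and can drive DPM_Solver.
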
class DPM_Solver:
    def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
        """Construct a DPM-Solver.

        Args:
            model_fn: A noise prediction model function which accepts the continuous-time input
                (t in [epsilon, T]):
                ``
                def model_fn(x, t_continuous):
                    return noise
                ``
            noise_schedule: A noise schedule object, such as NoiseScheduleVP.
        """
        self.model = model_fn
        self.noise_schedule = noise_schedule
        self.predict_x0 = predict_x0
        self.thresholding = thresholding
        self.max_val = max_val

    def model_fn(self, x, t):
        if self.predict_x0:
            alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
            noise = self.model(x, t)
            dims = len(x.shape) - 1
            x0 = (x - sigma_t[(...,) + (None,)*dims] * noise) / alpha_t[(...,) + (None,)*dims]
            if self.thresholding:
                p = 0.995
                s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
                s = torch.maximum(s, torch.ones_like(s).to(s.device))[(...,) + (None,)*dims]
                x0 = torch.clamp(x0, -s, s) / (s / self.max_val)
            return x0
        else:
            return self.model(x, t)

    def get_time_steps(self, skip_type, t_T, t_0, N, device):
        """Compute the intermediate time steps for sampling.

        Args:
            skip_type: A `str`. The type for the spacing of the time steps. We support three types:
                - 'logSNR': uniform logSNR for the time steps, **recommended for DPM-Solver**.
                - 'time_uniform': uniform time for the time steps. (Used in DDIM and DDPM.)
                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
                (The code additionally accepts 't2', a quadratic spacing in t.)
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            N: An `int`. The total number of the spacing of the time steps.
            device: A torch device.
        Returns:
            A pytorch tensor of the time steps, with the shape (N + 1,).
        """
        if skip_type == 'logSNR':
            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            # print(torch.min(torch.abs(logSNR_steps - self.noise_schedule.marginal_lambda(self.noise_schedule.inverse_lambda(logSNR_steps)))).item())
            return self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == 't2':
            t_order = 2
            t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
            return t
        elif skip_type == 'time_uniform':
            return torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == 'time_quadratic':
            t = torch.linspace(t_0, t_T, 10000000).to(device)
            quadratic_t = torch.sqrt(t)
            quadratic_steps = torch.linspace(quadratic_t[0], quadratic_t[-1], N + 1).to(device)
            return torch.flip(torch.cat([t[torch.searchsorted(quadratic_t, quadratic_steps)[:-1]], t_T * torch.ones((1,)).to(device)], dim=0), dims=[0])
        else:
            raise ValueError("Unsupported skip_type {}, needs to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))

    def get_time_steps_for_dpm_solver_fast(self, skip_type, t_T, t_0, steps, order, device):
        """
        Compute the intermediate time steps and the order of each step for sampling by DPM-Solver-fast.

        We recommend DPM-Solver-fast for fast sampling of DPMs. Given a fixed number of function evaluations by `steps`,
        the sampling procedure by DPM-Solver-fast is:
            - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
            - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
            - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
            - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.

        ============================================
        Args:
            skip_type: A `str`. The spacing type passed to `get_time_steps`.
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            steps: An `int`. The total number of function evaluations (NFE).
            order: An `int`. The maximum solver order (2 or 3).
            device: A torch device.
        Returns:
            orders: A list of the solver order of each step.
            timesteps: A pytorch tensor of the time steps, with the shape of (K + 1,).
        """
        if order == 3:
            K = steps // 3 + 1
            if steps % 3 == 0:
                orders = [3,] * (K - 2) + [2, 1]
            elif steps % 3 == 1:
                orders = [3,] * (K - 1) + [1]
            else:
                orders = [3,] * (K - 1) + [2]
            timesteps = self.get_time_steps(skip_type, t_T, t_0, K, device)
            return orders, timesteps
        elif order == 2:
            K = steps // 2
            if steps % 2 == 0:
                orders = [2,] * K
            else:
                orders = [2,] * K + [1]
            timesteps = self.get_time_steps(skip_type, t_T, t_0, K, device)
            return orders, timesteps
        else:
            raise ValueError("order must be >= 2")

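For example (an illustrative check of the order-splitting rule above, assuming the `model_fn` and `ns` from the earlier sketches): with `steps=20` and `order=3`, steps % 3 == 2, so K = 7 and the budget splits into six third-order steps plus one second-order step.

solver = DPM_Solver(model_fn, ns)
orders, ts = solver.get_time_steps_for_dpm_solver_fast('logSNR', t_T=1., t_0=1e-4, steps=20, order=3, device='cpu')
print(orders)  # [3, 3, 3, 3, 3, 3, 2] -- the NFE sums to 20
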
    def denoise_fn(self, x, s, noise_s=None):
        """
        Return the predicted clean sample x_0 at time `s` (used for the optional final denoising step).
        """
        ns = self.noise_schedule
        dims = len(x.shape) - 1
        log_alpha_s = ns.marginal_log_mean_coeff(s)
        sigma_s = ns.marginal_std(s)

        if noise_s is None:
            noise_s = self.model_fn(x, s)
        x_0 = (
            (x - sigma_s[(...,) + (None,)*dims] * noise_s) / torch.exp(log_alpha_s)[(...,) + (None,)*dims]
        )
        return x_0

    def dpm_solver_first_update(self, x, s, t, noise_s=None, return_noise=False):
        """
        A single step for DPM-Solver-1.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            return_noise: A `bool`. If true, also return the predicted noise at time `s`.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        ns = self.noise_schedule
        dims = len(x.shape) - 1
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
        sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

        if self.predict_x0:
            phi_1 = (torch.exp(-h) - 1.) / (-1.)
            if noise_s is None:
                noise_s = self.model_fn(x, s)
            x_t = (
                (sigma_t / sigma_s)[(...,) + (None,)*dims] * x
                + (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
            )
            if return_noise:
                return x_t, {'noise_s': noise_s}
            else:
                return x_t
        else:
            phi_1 = torch.expm1(h)
            if noise_s is None:
                noise_s = self.model_fn(x, s)
            x_t = (
                torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
                - (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
            )
            if return_noise:
                return x_t, {'noise_s': noise_s}
            else:
                return x_t

    def dpm_solver_second_update(self, x, s, t, r1=0.5, noise_s=None, return_noise=False, solver_type='dpm_solver'):
        """
        A single step for DPM-Solver-2.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            r1: A `float`. The hyperparameter of the second-order solver. We recommend the default setting `0.5`.
            noise_s: A pytorch tensor. The predicted noise at time `s`.
                If `noise_s` is None, we compute the predicted noise by `x` and `s`; otherwise we directly use it.
            return_noise: A `bool`. If true, also return the predicted noise at time `s` and `s1` (the intermediate time).
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if r1 is None:
            r1 = 0.5
        ns = self.noise_schedule
        dims = len(x.shape) - 1
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        lambda_s1 = lambda_s + r1 * h
        s1 = ns.inverse_lambda(lambda_s1)
        log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t)
        sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
        alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)

        if self.predict_x0:
            phi_11 = torch.expm1(-r1 * h)
            phi_1 = torch.expm1(-h)

            if noise_s is None:
                noise_s = self.model_fn(x, s)
            x_s1 = (
                (sigma_s1 / sigma_s)[(...,) + (None,)*dims] * x
                - (alpha_s1 * phi_11)[(...,) + (None,)*dims] * noise_s
            )
            noise_s1 = self.model_fn(x_s1, s1)
            if solver_type == 'dpm_solver':
                x_t = (
                    (sigma_t / sigma_s)[(...,) + (None,)*dims] * x
                    - (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
                    - (0.5 / r1) * (alpha_t * phi_1)[(...,) + (None,)*dims] * (noise_s1 - noise_s)
                )
            elif solver_type == 'taylor':
                x_t = (
                    (sigma_t / sigma_s)[(...,) + (None,)*dims] * x
                    - (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
                    + (1. / r1) * (alpha_t * ((torch.exp(-h) - 1.) / h + 1.))[(...,) + (None,)*dims] * (noise_s1 - noise_s)
                )
            else:
                raise ValueError("solver_type must be either dpm_solver or taylor, got {}".format(solver_type))
        else:
            phi_11 = torch.expm1(r1 * h)
            phi_1 = torch.expm1(h)

            if noise_s is None:
                noise_s = self.model_fn(x, s)
            x_s1 = (
                torch.exp(log_alpha_s1 - log_alpha_s)[(...,) + (None,)*dims] * x
                - (sigma_s1 * phi_11)[(...,) + (None,)*dims] * noise_s
            )
            noise_s1 = self.model_fn(x_s1, s1)
            if solver_type == 'dpm_solver':
                x_t = (
                    torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
                    - (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
                    - (0.5 / r1) * (sigma_t * phi_1)[(...,) + (None,)*dims] * (noise_s1 - noise_s)
                )
            elif solver_type == 'taylor':
                x_t = (
                    torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
                    - (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
                    - (1. / r1) * (sigma_t * ((torch.exp(h) - 1.) / h - 1.))[(...,) + (None,)*dims] * (noise_s1 - noise_s)
                )
            else:
                raise ValueError("solver_type must be either dpm_solver or taylor, got {}".format(solver_type))
        if return_noise:
            return x_t, {'noise_s': noise_s, 'noise_s1': noise_s1}
        else:
            return x_t

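For reference, the single-step updates above discretize the exact diffusion-ODE solution in the half-logSNR variable lambda (as derived in the DPM-Solver paper); in the noise-prediction branch, with h = lambda_t - lambda_s, the first-order step is

x_t = \frac{\alpha_t}{\alpha_s}\, x_s \;-\; \sigma_t \left(e^{h} - 1\right) \hat\epsilon_\theta(x_s, s)

which is exactly the `torch.expm1(h)` branch of `dpm_solver_first_update`; the second-order variants add a correction term proportional to `noise_s1 - noise_s`.
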
    def dpm_multistep_second_update(self, x, noise_prev_list, t_prev_list, t, solver_type="dpm_solver"):
        """
        A single step of the second-order multistep DPM-Solver, reusing the two previous model outputs.
        """
        ns = self.noise_schedule
        dims = len(x.shape) - 1
        noise_prev_1, noise_prev_0 = noise_prev_list
        t_prev_1, t_prev_0 = t_prev_list
        lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

        h_0 = lambda_prev_0 - lambda_prev_1
        h = lambda_t - lambda_prev_0
        r0 = h_0 / h
        D1_0 = (1. / r0)[(...,) + (None,)*dims] * (noise_prev_0 - noise_prev_1)
        if self.predict_x0:
            if solver_type == 'taylor':
                x_t = (
                    (sigma_t / sigma_prev_0)[(...,) + (None,)*dims] * x
                    - (alpha_t * (torch.exp(-h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
                    + (alpha_t * ((torch.exp(-h) - 1.) / h + 1.))[(...,) + (None,)*dims] * D1_0
                )
            elif solver_type == 'dpm_solver':
                x_t = (
                    (sigma_t / sigma_prev_0)[(...,) + (None,)*dims] * x
                    - (alpha_t * (torch.exp(-h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
                    - 0.5 * (alpha_t * (torch.exp(-h) - 1.))[(...,) + (None,)*dims] * D1_0
                )
        else:
            if solver_type == 'taylor':
                x_t = (
                    torch.exp(log_alpha_t - log_alpha_prev_0)[(...,) + (None,)*dims] * x
                    - (sigma_t * (torch.exp(h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
                    - (sigma_t * ((torch.exp(h) - 1.) / h - 1.))[(...,) + (None,)*dims] * D1_0
                )
            elif solver_type == 'dpm_solver':
                x_t = (
                    torch.exp(log_alpha_t - log_alpha_prev_0)[(...,) + (None,)*dims] * x
                    - (sigma_t * (torch.exp(h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
                    - 0.5 * (sigma_t * (torch.exp(h) - 1.))[(...,) + (None,)*dims] * D1_0
                )
        return x_t

    def dpm_multistep_third_update(self, x, noise_prev_list, t_prev_list, t, solver_type='dpm_solver'):
        """
        A single step of the third-order multistep DPM-Solver, reusing the three previous model outputs.
        """
        ns = self.noise_schedule
        dims = len(x.shape) - 1
        noise_prev_2, noise_prev_1, noise_prev_0 = noise_prev_list
        t_prev_2, t_prev_1, t_prev_0 = t_prev_list
        lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

        h_1 = lambda_prev_1 - lambda_prev_2
        h_0 = lambda_prev_0 - lambda_prev_1
        h = lambda_t - lambda_prev_0
        r0, r1 = h_0 / h, h_1 / h
        D1_0 = (1. / r0)[(...,) + (None,)*dims] * (noise_prev_0 - noise_prev_1)
        D1_1 = (1. / r1)[(...,) + (None,)*dims] * (noise_prev_1 - noise_prev_2)
        D1 = D1_0 + (r0 / (r0 + r1))[(...,) + (None,)*dims] * (D1_0 - D1_1)
        D2 = (1. / (r0 + r1))[(...,) + (None,)*dims] * (D1_0 - D1_1)
        if self.predict_x0:
            x_t = (
                (sigma_t / sigma_prev_0)[(...,) + (None,)*dims] * x
                - (alpha_t * (torch.exp(-h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
                + (alpha_t * ((torch.exp(-h) - 1.) / h + 1.))[(...,) + (None,)*dims] * D1
                - (alpha_t * ((torch.exp(-h) - 1. + h) / h**2 - 0.5))[(...,) + (None,)*dims] * D2
            )
        else:
            x_t = (
                torch.exp(log_alpha_t - log_alpha_prev_0)[(...,) + (None,)*dims] * x
                - (sigma_t * (torch.exp(h) - 1.))[(...,) + (None,)*dims] * noise_prev_0
                - (sigma_t * ((torch.exp(h) - 1.) / h - 1.))[(...,) + (None,)*dims] * D1
                - (sigma_t * ((torch.exp(h) - 1. - h) / h**2 - 0.5))[(...,) + (None,)*dims] * D2
            )
        return x_t

    def dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., noise_s=None, noise_s1=None, noise_s2=None, return_noise=False, solver_type='dpm_solver'):
        """
        A single step for DPM-Solver-3.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            r1: A `float`. The hyperparameter of the third-order solver. We recommend the default setting `1 / 3`.
            r2: A `float`. The hyperparameter of the third-order solver. We recommend the default setting `2 / 3`.
            noise_s: A pytorch tensor. The predicted noise at time `s`.
                If `noise_s` is None, we compute the predicted noise by `x` and `s`; otherwise we directly use it.
            noise_s1: A pytorch tensor. The predicted noise at time `s1` (the intermediate time given by `r1`).
                If `noise_s1` is None, we compute the predicted noise at `s1`; otherwise we directly use it.
            noise_s2: A pytorch tensor. The predicted noise at time `s2` (the intermediate time given by `r2`).
                If `noise_s2` is None, we compute the predicted noise at `s2`; otherwise we directly use it.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if r1 is None:
            r1 = 1. / 3.
        if r2 is None:
            r2 = 2. / 3.
        ns = self.noise_schedule
        dims = len(x.shape) - 1
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        lambda_s1 = lambda_s + r1 * h
        lambda_s2 = lambda_s + r2 * h
        s1 = ns.inverse_lambda(lambda_s1)
        s2 = ns.inverse_lambda(lambda_s2)
        log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
        sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t)
        alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)

        if self.predict_x0:
            phi_11 = torch.expm1(-r1 * h)
            phi_12 = torch.expm1(-r2 * h)
            phi_1 = torch.expm1(-h)
            phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
            phi_2 = phi_1 / h + 1.
            phi_3 = phi_2 / h - 0.5

            if noise_s is None:
                noise_s = self.model_fn(x, s)
            if noise_s1 is None:
                x_s1 = (
                    (sigma_s1 / sigma_s)[(...,) + (None,)*dims] * x
                    - (alpha_s1 * phi_11)[(...,) + (None,)*dims] * noise_s
                )
                noise_s1 = self.model_fn(x_s1, s1)
            if noise_s2 is None:
                x_s2 = (
                    (sigma_s2 / sigma_s)[(...,) + (None,)*dims] * x
                    - (alpha_s2 * phi_12)[(...,) + (None,)*dims] * noise_s
                    + r2 / r1 * (alpha_s2 * phi_22)[(...,) + (None,)*dims] * (noise_s1 - noise_s)
                )
                noise_s2 = self.model_fn(x_s2, s2)
            if solver_type == 'dpm_solver':
                x_t = (
                    (sigma_t / sigma_s)[(...,) + (None,)*dims] * x
                    - (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
                    + (1. / r2) * (alpha_t * phi_2)[(...,) + (None,)*dims] * (noise_s2 - noise_s)
                )
            elif solver_type == 'taylor':
                D1_0 = (1. / r1) * (noise_s1 - noise_s)
                D1_1 = (1. / r2) * (noise_s2 - noise_s)
                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
                D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
                x_t = (
                    (sigma_t / sigma_s)[(...,) + (None,)*dims] * x
                    - (alpha_t * phi_1)[(...,) + (None,)*dims] * noise_s
                    + (alpha_t * phi_2)[(...,) + (None,)*dims] * D1
                    - (alpha_t * phi_3)[(...,) + (None,)*dims] * D2
                )
            else:
                raise ValueError("solver_type must be either dpm_solver or taylor, got {}".format(solver_type))
        else:
            phi_11 = torch.expm1(r1 * h)
            phi_12 = torch.expm1(r2 * h)
            phi_1 = torch.expm1(h)
            phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
            phi_2 = phi_1 / h - 1.
            phi_3 = phi_2 / h - 0.5

            if noise_s is None:
                noise_s = self.model_fn(x, s)
            if noise_s1 is None:
                x_s1 = (
                    torch.exp(log_alpha_s1 - log_alpha_s)[(...,) + (None,)*dims] * x
                    - (sigma_s1 * phi_11)[(...,) + (None,)*dims] * noise_s
                )
                noise_s1 = self.model_fn(x_s1, s1)
            if noise_s2 is None:
                x_s2 = (
                    torch.exp(log_alpha_s2 - log_alpha_s)[(...,) + (None,)*dims] * x
                    - (sigma_s2 * phi_12)[(...,) + (None,)*dims] * noise_s
                    - r2 / r1 * (sigma_s2 * phi_22)[(...,) + (None,)*dims] * (noise_s1 - noise_s)
                )
                noise_s2 = self.model_fn(x_s2, s2)
            if solver_type == 'dpm_solver':
                x_t = (
                    torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
                    - (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
                    - (1. / r2) * (sigma_t * phi_2)[(...,) + (None,)*dims] * (noise_s2 - noise_s)
                )
            elif solver_type == 'taylor':
                D1_0 = (1. / r1) * (noise_s1 - noise_s)
                D1_1 = (1. / r2) * (noise_s2 - noise_s)
                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
                D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
                x_t = (
                    torch.exp(log_alpha_t - log_alpha_s)[(...,) + (None,)*dims] * x
                    - (sigma_t * phi_1)[(...,) + (None,)*dims] * noise_s
                    - (sigma_t * phi_2)[(...,) + (None,)*dims] * D1
                    - (sigma_t * phi_3)[(...,) + (None,)*dims] * D2
                )
            else:
                raise ValueError("solver_type must be either dpm_solver or taylor, got {}".format(solver_type))

        if return_noise:
            return x_t, {'noise_s': noise_s, 'noise_s1': noise_s1, 'noise_s2': noise_s2}
        else:
            return x_t

    def dpm_solver_update(self, x, s, t, order, return_noise=False, solver_type='dpm_solver', r1=None, r2=None):
        """
        A single step for DPM-Solver of the given order `order`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if order == 1:
            return self.dpm_solver_first_update(x, s, t, return_noise=return_noise)
        elif order == 2:
            return self.dpm_solver_second_update(x, s, t, return_noise=return_noise, solver_type=solver_type, r1=r1)
        elif order == 3:
            return self.dpm_solver_third_update(x, s, t, return_noise=return_noise, solver_type=solver_type, r1=r1, r2=r2)
        else:
            raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))

    def dpm_multistep_update(self, x, noise_prev_list, t_prev_list, t, order, solver_type='taylor'):
        """
        A single step for the multistep DPM-Solver of the given order `order`.

        Args:
            x: A pytorch tensor. The value at the previous time step.
            noise_prev_list: A list of pytorch tensors. The previous model outputs (most recent last).
            t_prev_list: A list of pytorch tensors. The previous time steps (most recent last).
            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
            order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if order == 1:
            return self.dpm_solver_first_update(x, t_prev_list[-1], t, noise_s=noise_prev_list[-1])
        elif order == 2:
            return self.dpm_multistep_second_update(x, noise_prev_list, t_prev_list, t, solver_type=solver_type)
        elif order == 3:
            return self.dpm_multistep_third_update(x, noise_prev_list, t_prev_list, t, solver_type=solver_type)
        else:
            raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))

    def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpm_solver'):
        """
        The adaptive step size solver based on DPM-Solver.

        Args:
            x: A pytorch tensor. The initial value at time `t_T`.
            order: An `int`. The (higher) order of the solver. We only support order == 2 or 3.
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            h_init: A `float`. The initial step size (for logSNR).
            atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
            rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
            theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
            t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
                current time and `t_0` is less than `t_err`. The default setting is 1e-5.
        Returns:
            x_0: A pytorch tensor. The approximated solution at time `t_0`.

        [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
        """
        ns = self.noise_schedule
        s = t_T * torch.ones((x.shape[0],)).to(x)
        lambda_s = ns.marginal_lambda(s)
        lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
        h = h_init * torch.ones_like(s).to(x)
        x_prev = x
        nfe = 0
        if order == 2:
            r1 = 0.5
            lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_noise=True)
            higher_update = lambda x, s, t, **kwargs: self.dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
        elif order == 3:
            r1, r2 = 1. / 3., 2. / 3.
            lower_update = lambda x, s, t: self.dpm_solver_second_update(x, s, t, r1=r1, return_noise=True, solver_type=solver_type)
            higher_update = lambda x, s, t, **kwargs: self.dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs)
        else:
            raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
        while torch.abs((s - t_0)).mean() > t_err:
            t = ns.inverse_lambda(lambda_s + h)
            x_lower, lower_noise_kwargs = lower_update(x, s, t)
            x_higher = higher_update(x, s, t, **lower_noise_kwargs)
            delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
            norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
            E = norm_fn((x_higher - x_lower) / delta).max()
            if torch.all(E <= 1.):
                # Accept the step; otherwise retry from the same state with a smaller step size.
                x = x_higher
                s = t
                x_prev = x_lower
                lambda_s = ns.marginal_lambda(s)
            h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
            nfe += order
        print('adaptive solver nfe', nfe)
        return x

    def sample(self, x, steps=10, eps=1e-4, T=None, order=3, skip_type='time_uniform',
        denoise=False, method='fast', solver_type='dpm_solver', atol=0.0078,
        rtol=0.05, timesteps=None,
    ):
        """
        Compute the sample at time `eps` by DPM-Solver, given the initial `x` at time `T`.

        We support the following algorithms, selected by `method`:

        - 'adaptive': adaptive step size DPM-Solver (i.e. DPM-Solver-12 and DPM-Solver-23).

        - 'singlestep': fixed order DPM-Solver (i.e. DPM-Solver-1, DPM-Solver-2 and DPM-Solver-3).

        - 'fast': fast version of DPM-Solver (i.e. DPM-Solver-fast), which combines different orders of DPM-Solver.

        - 'multistep': multistep DPM-Solver, which reuses the model outputs of the previous steps.

        **We recommend DPM-Solver-fast for both fast sampling in few steps (<=20) and fast convergence in many steps (50 to 100).**

        Choosing the algorithms:

        - If `method` is 'adaptive':
            We ignore `steps` and use the adaptive step size DPM-Solver with a higher order of `order`.
            If `order` = 2, we use DPM-Solver-12 which combines DPM-Solver-1 and DPM-Solver-2.
            If `order` = 3, we use DPM-Solver-23 which combines DPM-Solver-2 and DPM-Solver-3.
            You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
            (NFE) and the sample quality.

        - If `method` is 'fast':
            We use DPM-Solver-fast with number of function evaluations (NFE) = `steps`.
            Given a fixed NFE = `steps`, the sampling procedure by DPM-Solver-fast is:
                - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
                - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
                - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
                - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.

        - If `method` is 'singlestep':
            We use DPM-Solver-`order` for `order` = 1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
            We support three types of `skip_type`:
                - 'logSNR': uniform logSNR for the time steps, **recommended for DPM-Solver**.
                - 'time_uniform': uniform time for the time steps. (Used in DDIM and DDPM.)
                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM.)

        =====================================================
        Args:
            x: A pytorch tensor. The initial value at time `T` (a sample from the normal distribution).
            steps: An `int`. The total number of function evaluations (NFE).
            eps: A `float`. The ending time of the sampling.
                We recommend `eps` = 1e-3 when `steps` <= 15; and `eps` = 1e-4 when `steps` > 15.
            T: A `float`. The starting time of the sampling. Default is `None`.
                If `T` is None, we use self.noise_schedule.T.
            order: An `int`. The order of DPM-Solver.
            skip_type: A `str`. The type for the spacing of the time steps. Default is 'time_uniform'.
            method: A `str`. The sampling algorithm: 'adaptive', 'fast', 'multistep' or 'singlestep'.
            atol: A `float`. The absolute tolerance of the adaptive step size solver.
            rtol: A `float`. The relative tolerance of the adaptive step size solver.
            timesteps: An optional pytorch tensor of precomputed time steps for the 'multistep' method.
        Returns:
            x_0: A pytorch tensor. The approximated solution at time `t_0`.

        [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
        """
        t_0 = eps
        t_T = self.noise_schedule.T if T is None else T
        device = x.device
        if method == 'adaptive':
            with torch.no_grad():
                x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type)
        elif method == 'multistep':
            assert steps >= order
            # `timesteps` may be passed in by the caller; otherwise compute them from `skip_type`.
            if timesteps is None:
                timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
            assert timesteps.shape[0] - 1 == steps
            with torch.no_grad():
                vec_t = timesteps[0].expand((x.shape[0]))
                noise_prev_list = [self.model_fn(x, vec_t)]
                t_prev_list = [vec_t]
                # Warm up the multistep solver with lower-order updates.
                for init_order in range(1, order):
                    vec_t = timesteps[init_order].expand(x.shape[0])
                    x = self.dpm_multistep_update(x, noise_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type)
                    noise_prev_list.append(self.model_fn(x, vec_t))
                    t_prev_list.append(vec_t)
                for step in range(order, steps + 1):
                    vec_t = timesteps[step].expand(x.shape[0])
                    x = self.dpm_multistep_update(x, noise_prev_list, t_prev_list, vec_t, order, solver_type=solver_type)
                    for i in range(order - 1):
                        t_prev_list[i] = t_prev_list[i + 1]
                        noise_prev_list[i] = noise_prev_list[i + 1]
                    t_prev_list[-1] = vec_t
                    # We do not need to evaluate the model at the final step.
                    if step < steps:
                        noise_prev_list[-1] = self.model_fn(x, vec_t)
        elif method == 'fast':
            orders, _ = self.get_time_steps_for_dpm_solver_fast(skip_type=skip_type, t_T=t_T, t_0=t_0, steps=steps, order=order, device=device)
            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
            with torch.no_grad():
                i = 0
                for order in orders:
                    vec_s, vec_t = torch.ones((x.shape[0],)).to(device) * timesteps[i], torch.ones((x.shape[0],)).to(device) * timesteps[i + order]
                    h = self.noise_schedule.marginal_lambda(timesteps[i + order]) - self.noise_schedule.marginal_lambda(timesteps[i])
                    r1 = None if order <= 1 else (self.noise_schedule.marginal_lambda(timesteps[i + 1]) - self.noise_schedule.marginal_lambda(timesteps[i])) / h
                    r2 = None if order <= 2 else (self.noise_schedule.marginal_lambda(timesteps[i + 2]) - self.noise_schedule.marginal_lambda(timesteps[i])) / h
                    x = self.dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
                    i += order
        elif method == 'singlestep':
            N_steps = steps // order
            orders = [order,] * N_steps
            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=N_steps, device=device)
            assert len(timesteps) - 1 == N_steps
            with torch.no_grad():
                for i, order in enumerate(orders):
                    vec_s, vec_t = torch.ones((x.shape[0],)).to(device) * timesteps[i], torch.ones((x.shape[0],)).to(device) * timesteps[i + 1]
                    x = self.dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type)
        if denoise:
            x = self.denoise_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
        return x
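An end-to-end sampling sketch tying the pieces together (the `unet` stand-in and the shapes are illustrative assumptions, not part of this file):

ns = NoiseScheduleVP(schedule='discrete', betas=torch.linspace(1e-4, 2e-2, 1000))
model_fn = model_wrapper(unet, noise_schedule=ns, time_input_type='1')  # unet: assumed noise predictor
solver = DPM_Solver(model_fn, ns)
x_T = torch.randn(4, 4, 64, 64)  # assumed latent shape [B, C, H, W]
x_0 = solver.sample(x_T, steps=20, eps=1e-4, method='fast', skip_type='logSNR')
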
libs/autoencoder.py
ADDED
@@ -0,0 +1,519 @@
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)

26 |
+
def nonlinearity(x):
|
27 |
+
# swish
|
28 |
+
return x*torch.sigmoid(x)
|
29 |
+
|
30 |
+
|
31 |
+
def Normalize(in_channels, num_groups=32):
|
32 |
+
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
33 |
+
|
34 |
+
|
35 |
+
class Upsample(nn.Module):
|
36 |
+
def __init__(self, in_channels, with_conv):
|
37 |
+
super().__init__()
|
38 |
+
self.with_conv = with_conv
|
39 |
+
if self.with_conv:
|
40 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
41 |
+
in_channels,
|
42 |
+
kernel_size=3,
|
43 |
+
stride=1,
|
44 |
+
padding=1)
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
48 |
+
if self.with_conv:
|
49 |
+
x = self.conv(x)
|
50 |
+
return x
|
51 |
+
|
52 |
+
|
53 |
+
class Downsample(nn.Module):
|
54 |
+
def __init__(self, in_channels, with_conv):
|
55 |
+
super().__init__()
|
56 |
+
self.with_conv = with_conv
|
57 |
+
if self.with_conv:
|
58 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
59 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
60 |
+
in_channels,
|
61 |
+
kernel_size=3,
|
62 |
+
stride=2,
|
63 |
+
padding=0)
|
64 |
+
|
65 |
+
def forward(self, x):
|
66 |
+
if self.with_conv:
|
67 |
+
pad = (0,1,0,1)
|
68 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
69 |
+
x = self.conv(x)
|
70 |
+
else:
|
71 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
72 |
+
return x
|
73 |
+
|
74 |
+
|
75 |
+
class ResnetBlock(nn.Module):
|
76 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
77 |
+
dropout, temb_channels=512):
|
78 |
+
super().__init__()
|
79 |
+
self.in_channels = in_channels
|
80 |
+
out_channels = in_channels if out_channels is None else out_channels
|
81 |
+
self.out_channels = out_channels
|
82 |
+
self.use_conv_shortcut = conv_shortcut
|
83 |
+
|
84 |
+
self.norm1 = Normalize(in_channels)
|
85 |
+
self.conv1 = torch.nn.Conv2d(in_channels,
|
86 |
+
out_channels,
|
87 |
+
kernel_size=3,
|
88 |
+
stride=1,
|
89 |
+
padding=1)
|
90 |
+
if temb_channels > 0:
|
91 |
+
self.temb_proj = torch.nn.Linear(temb_channels,
|
92 |
+
out_channels)
|
93 |
+
self.norm2 = Normalize(out_channels)
|
94 |
+
self.dropout = torch.nn.Dropout(dropout)
|
95 |
+
self.conv2 = torch.nn.Conv2d(out_channels,
|
96 |
+
out_channels,
|
97 |
+
kernel_size=3,
|
98 |
+
stride=1,
|
99 |
+
padding=1)
|
100 |
+
if self.in_channels != self.out_channels:
|
101 |
+
if self.use_conv_shortcut:
|
102 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
103 |
+
out_channels,
|
104 |
+
kernel_size=3,
|
105 |
+
stride=1,
|
106 |
+
padding=1)
|
107 |
+
else:
|
108 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
109 |
+
out_channels,
|
110 |
+
kernel_size=1,
|
111 |
+
stride=1,
|
112 |
+
padding=0)
|
113 |
+
|
114 |
+
def forward(self, x, temb):
|
115 |
+
h = x
|
116 |
+
h = self.norm1(h)
|
117 |
+
h = nonlinearity(h)
|
118 |
+
h = self.conv1(h)
|
119 |
+
|
120 |
+
if temb is not None:
|
121 |
+
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
122 |
+
|
123 |
+
h = self.norm2(h)
|
124 |
+
h = nonlinearity(h)
|
125 |
+
h = self.dropout(h)
|
126 |
+
h = self.conv2(h)
|
127 |
+
|
128 |
+
if self.in_channels != self.out_channels:
|
129 |
+
if self.use_conv_shortcut:
|
130 |
+
x = self.conv_shortcut(x)
|
131 |
+
else:
|
132 |
+
x = self.nin_shortcut(x)
|
133 |
+
|
134 |
+
return x+h
|
135 |
+
|
136 |
+
|
137 |
+
class LinAttnBlock(LinearAttention):
|
138 |
+
"""to match AttnBlock usage"""
|
139 |
+
def __init__(self, in_channels):
|
140 |
+
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
|
141 |
+
|
142 |
+
|
143 |
+
class AttnBlock(nn.Module):
|
144 |
+
def __init__(self, in_channels):
|
145 |
+
super().__init__()
|
146 |
+
self.in_channels = in_channels
|
147 |
+
|
148 |
+
self.norm = Normalize(in_channels)
|
149 |
+
self.q = torch.nn.Conv2d(in_channels,
|
150 |
+
in_channels,
|
151 |
+
kernel_size=1,
|
152 |
+
stride=1,
|
153 |
+
padding=0)
|
154 |
+
self.k = torch.nn.Conv2d(in_channels,
|
155 |
+
in_channels,
|
156 |
+
kernel_size=1,
|
157 |
+
stride=1,
|
158 |
+
padding=0)
|
159 |
+
self.v = torch.nn.Conv2d(in_channels,
|
160 |
+
in_channels,
|
161 |
+
kernel_size=1,
|
162 |
+
stride=1,
|
163 |
+
padding=0)
|
164 |
+
self.proj_out = torch.nn.Conv2d(in_channels,
|
165 |
+
in_channels,
|
166 |
+
kernel_size=1,
|
167 |
+
stride=1,
|
168 |
+
padding=0)
|
169 |
+
|
170 |
+
|
171 |
+
def forward(self, x):
|
172 |
+
h_ = x
|
173 |
+
h_ = self.norm(h_)
|
174 |
+
q = self.q(h_)
|
175 |
+
k = self.k(h_)
|
176 |
+
v = self.v(h_)
|
177 |
+
|
178 |
+
# compute attention
|
179 |
+
b,c,h,w = q.shape
|
180 |
+
q = q.reshape(b,c,h*w)
|
181 |
+
q = q.permute(0,2,1) # b,hw,c
|
182 |
+
k = k.reshape(b,c,h*w) # b,c,hw
|
183 |
+
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
184 |
+
w_ = w_ * (int(c)**(-0.5))
|
185 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
186 |
+
|
187 |
+
# attend to values
|
188 |
+
v = v.reshape(b,c,h*w)
|
189 |
+
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
|
190 |
+
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
191 |
+
h_ = h_.reshape(b,c,h,w)
|
192 |
+
|
193 |
+
h_ = self.proj_out(h_)
|
194 |
+
|
195 |
+
return x+h_
|
196 |
+
|
197 |
+
|
198 |
+
def make_attn(in_channels, attn_type="vanilla"):
|
199 |
+
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
|
200 |
+
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
201 |
+
if attn_type == "vanilla":
|
202 |
+
return AttnBlock(in_channels)
|
203 |
+
elif attn_type == "none":
|
204 |
+
return nn.Identity(in_channels)
|
205 |
+
else:
|
206 |
+
return LinAttnBlock(in_channels)
|
207 |
+
|
208 |
+
|
209 |
+
class Encoder(nn.Module):
|
210 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
211 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
212 |
+
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
|
213 |
+
**ignore_kwargs):
|
214 |
+
super().__init__()
|
215 |
+
if use_linear_attn: attn_type = "linear"
|
216 |
+
self.ch = ch
|
217 |
+
self.temb_ch = 0
|
218 |
+
self.num_resolutions = len(ch_mult)
|
219 |
+
self.num_res_blocks = num_res_blocks
|
220 |
+
self.resolution = resolution
|
221 |
+
self.in_channels = in_channels
|
222 |
+
|
223 |
+
# downsampling
|
224 |
+
self.conv_in = torch.nn.Conv2d(in_channels,
|
225 |
+
self.ch,
|
226 |
+
kernel_size=3,
|
227 |
+
stride=1,
|
228 |
+
padding=1)
|
229 |
+
|
230 |
+
curr_res = resolution
|
231 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
232 |
+
self.in_ch_mult = in_ch_mult
|
233 |
+
self.down = nn.ModuleList()
|
234 |
+
for i_level in range(self.num_resolutions):
|
235 |
+
block = nn.ModuleList()
|
236 |
+
attn = nn.ModuleList()
|
237 |
+
block_in = ch*in_ch_mult[i_level]
|
238 |
+
block_out = ch*ch_mult[i_level]
|
239 |
+
for i_block in range(self.num_res_blocks):
|
240 |
+
block.append(ResnetBlock(in_channels=block_in,
|
241 |
+
out_channels=block_out,
|
242 |
+
temb_channels=self.temb_ch,
|
243 |
+
dropout=dropout))
|
244 |
+
block_in = block_out
|
245 |
+
if curr_res in attn_resolutions:
|
246 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
247 |
+
down = nn.Module()
|
248 |
+
down.block = block
|
249 |
+
down.attn = attn
|
250 |
+
if i_level != self.num_resolutions-1:
|
251 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
252 |
+
curr_res = curr_res // 2
|
253 |
+
self.down.append(down)
|
254 |
+
|
255 |
+
# middle
|
256 |
+
self.mid = nn.Module()
|
257 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
258 |
+
out_channels=block_in,
|
259 |
+
temb_channels=self.temb_ch,
|
260 |
+
dropout=dropout)
|
261 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
262 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
263 |
+
out_channels=block_in,
|
264 |
+
temb_channels=self.temb_ch,
|
265 |
+
dropout=dropout)
|
266 |
+
|
267 |
+
# end
|
268 |
+
self.norm_out = Normalize(block_in)
|
269 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
270 |
+
2*z_channels if double_z else z_channels,
|
271 |
+
kernel_size=3,
|
272 |
+
stride=1,
|
273 |
+
padding=1)
|
274 |
+
|
275 |
+
def forward(self, x):
|
276 |
+
# timestep embedding
|
277 |
+
temb = None
|
278 |
+
|
279 |
+
# downsampling
|
280 |
+
hs = [self.conv_in(x)]
|
281 |
+
for i_level in range(self.num_resolutions):
|
282 |
+
for i_block in range(self.num_res_blocks):
|
283 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
284 |
+
if len(self.down[i_level].attn) > 0:
|
285 |
+
h = self.down[i_level].attn[i_block](h)
|
286 |
+
hs.append(h)
|
287 |
+
if i_level != self.num_resolutions-1:
|
288 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
289 |
+
|
290 |
+
# middle
|
291 |
+
h = hs[-1]
|
292 |
+
h = self.mid.block_1(h, temb)
|
293 |
+
h = self.mid.attn_1(h)
|
294 |
+
h = self.mid.block_2(h, temb)
|
295 |
+
|
296 |
+
# end
|
297 |
+
h = self.norm_out(h)
|
298 |
+
h = nonlinearity(h)
|
299 |
+
h = self.conv_out(h)
|
300 |
+
return h
|
301 |
+
|
302 |
+
|
303 |
+
class Decoder(nn.Module):
|
304 |
+
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
305 |
+
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
306 |
+
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
307 |
+
attn_type="vanilla", **ignorekwargs):
|
308 |
+
super().__init__()
|
309 |
+
if use_linear_attn: attn_type = "linear"
|
310 |
+
self.ch = ch
|
311 |
+
self.temb_ch = 0
|
312 |
+
self.num_resolutions = len(ch_mult)
|
313 |
+
self.num_res_blocks = num_res_blocks
|
314 |
+
self.resolution = resolution
|
315 |
+
self.in_channels = in_channels
|
316 |
+
self.give_pre_end = give_pre_end
|
317 |
+
self.tanh_out = tanh_out
|
318 |
+
|
319 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
320 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
321 |
+
block_in = ch*ch_mult[self.num_resolutions-1]
|
322 |
+
curr_res = resolution // 2**(self.num_resolutions-1)
|
323 |
+
self.z_shape = (1,z_channels,curr_res,curr_res)
|
324 |
+
print("Working with z of shape {} = {} dimensions.".format(
|
325 |
+
self.z_shape, np.prod(self.z_shape)))
|
326 |
+
|
327 |
+
# z to block_in
|
328 |
+
self.conv_in = torch.nn.Conv2d(z_channels,
|
329 |
+
block_in,
|
330 |
+
kernel_size=3,
|
331 |
+
stride=1,
|
332 |
+
padding=1)
|
333 |
+
|
334 |
+
# middle
|
335 |
+
self.mid = nn.Module()
|
336 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
337 |
+
out_channels=block_in,
|
338 |
+
temb_channels=self.temb_ch,
|
339 |
+
dropout=dropout)
|
340 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
341 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
342 |
+
out_channels=block_in,
|
343 |
+
temb_channels=self.temb_ch,
|
344 |
+
dropout=dropout)
|
345 |
+
|
346 |
+
# upsampling
|
347 |
+
self.up = nn.ModuleList()
|
348 |
+
for i_level in reversed(range(self.num_resolutions)):
|
349 |
+
block = nn.ModuleList()
|
350 |
+
attn = nn.ModuleList()
|
351 |
+
block_out = ch*ch_mult[i_level]
|
352 |
+
for i_block in range(self.num_res_blocks+1):
|
353 |
+
block.append(ResnetBlock(in_channels=block_in,
|
354 |
+
out_channels=block_out,
|
355 |
+
temb_channels=self.temb_ch,
|
356 |
+
dropout=dropout))
|
357 |
+
block_in = block_out
|
358 |
+
if curr_res in attn_resolutions:
|
359 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
360 |
+
up = nn.Module()
|
361 |
+
up.block = block
|
362 |
+
up.attn = attn
|
363 |
+
if i_level != 0:
|
364 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
365 |
+
curr_res = curr_res * 2
|
366 |
+
self.up.insert(0, up) # prepend to get consistent order
|
367 |
+
|
368 |
+
# end
|
369 |
+
self.norm_out = Normalize(block_in)
|
370 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
371 |
+
out_ch,
|
372 |
+
kernel_size=3,
|
373 |
+
stride=1,
|
374 |
+
padding=1)
|
375 |
+
|
376 |
+
def forward(self, z):
|
377 |
+
#assert z.shape[1:] == self.z_shape[1:]
|
378 |
+
self.last_z_shape = z.shape
|
379 |
+
|
380 |
+
# timestep embedding
|
381 |
+
temb = None
|
382 |
+
|
383 |
+
# z to block_in
|
384 |
+
h = self.conv_in(z)
|
385 |
+
|
386 |
+
# middle
|
387 |
+
h = self.mid.block_1(h, temb)
|
388 |
+
h = self.mid.attn_1(h)
|
389 |
+
h = self.mid.block_2(h, temb)
|
390 |
+
|
391 |
+
# upsampling
|
392 |
+
for i_level in reversed(range(self.num_resolutions)):
|
393 |
+
for i_block in range(self.num_res_blocks+1):
|
394 |
+
h = self.up[i_level].block[i_block](h, temb)
|
395 |
+
if len(self.up[i_level].attn) > 0:
|
396 |
+
h = self.up[i_level].attn[i_block](h)
|
397 |
+
if i_level != 0:
|
398 |
+
h = self.up[i_level].upsample(h)
|
399 |
+
|
400 |
+
# end
|
401 |
+
if self.give_pre_end:
|
402 |
+
return h
|
403 |
+
|
404 |
+
h = self.norm_out(h)
|
405 |
+
h = nonlinearity(h)
|
406 |
+
h = self.conv_out(h)
|
407 |
+
if self.tanh_out:
|
408 |
+
h = torch.tanh(h)
|
409 |
+
return h
|
410 |
+
|
411 |
+
|
412 |
+
class FrozenAutoencoderKL(nn.Module):
|
413 |
+
def __init__(self, ddconfig, embed_dim, pretrained_path, scale_factor=0.18215):
|
414 |
+
super().__init__()
|
415 |
+
print(f'Create autoencoder with scale_factor={scale_factor}')
|
416 |
+
self.encoder = Encoder(**ddconfig)
|
417 |
+
self.decoder = Decoder(**ddconfig)
|
418 |
+
assert ddconfig["double_z"]
|
419 |
+
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
|
420 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
421 |
+
self.embed_dim = embed_dim
|
422 |
+
self.scale_factor = scale_factor
|
423 |
+
m, u = self.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
|
424 |
+
assert len(m) == 0 and len(u) == 0
|
425 |
+
self.eval()
|
426 |
+
self.requires_grad_(False)
|
427 |
+
|
428 |
+
def encode_moments(self, x):
|
429 |
+
h = self.encoder(x)
|
430 |
+
moments = self.quant_conv(h)
|
431 |
+
return moments
|
432 |
+
|
433 |
+
def sample(self, moments):
|
434 |
+
mean, logvar = torch.chunk(moments, 2, dim=1)
|
435 |
+
logvar = torch.clamp(logvar, -30.0, 20.0)
|
436 |
+
std = torch.exp(0.5 * logvar)
|
437 |
+
z = mean + std * torch.randn_like(mean)
|
438 |
+
z = self.scale_factor * z
|
439 |
+
return z
|
440 |
+
|
441 |
+
def encode(self, x):
|
442 |
+
moments = self.encode_moments(x)
|
443 |
+
z = self.sample(moments)
|
444 |
+
return z
|
445 |
+
|
446 |
+
def decode(self, z):
|
447 |
+
z = (1. / self.scale_factor) * z
|
448 |
+
z = self.post_quant_conv(z)
|
449 |
+
dec = self.decoder(z)
|
450 |
+
return dec
|
451 |
+
|
452 |
+
def forward(self, inputs, fn):
|
453 |
+
if fn == 'encode_moments':
|
454 |
+
return self.encode_moments(inputs)
|
455 |
+
elif fn == 'encode':
|
456 |
+
return self.encode(inputs)
|
457 |
+
elif fn == 'decode':
|
458 |
+
return self.decode(inputs)
|
459 |
+
else:
|
460 |
+
raise NotImplementedError
|
461 |
+
|
462 |
+
|
463 |
+
def get_model(pretrained_path, scale_factor=0.18215):
|
464 |
+
ddconfig = dict(
|
465 |
+
double_z=True,
|
466 |
+
z_channels=4,
|
467 |
+
resolution=256,
|
468 |
+
in_channels=3,
|
469 |
+
out_ch=3,
|
470 |
+
ch=128,
|
471 |
+
ch_mult=[1, 2, 4, 4],
|
472 |
+
num_res_blocks=2,
|
473 |
+
attn_resolutions=[],
|
474 |
+
dropout=0.0
|
475 |
+
)
|
476 |
+
return FrozenAutoencoderKL(ddconfig, 4, pretrained_path, scale_factor)
|
477 |
+
|
478 |
+
|
479 |
+
def main():
|
480 |
+
import torchvision.transforms as transforms
|
481 |
+
from torchvision.utils import save_image
|
482 |
+
import os
|
483 |
+
from PIL import Image
|
484 |
+
|
485 |
+
model = get_model('assets/stable-diffusion/autoencoder_kl.pth')
|
486 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
487 |
+
model = model.to(device)
|
488 |
+
|
489 |
+
scale_factor = 0.18215
|
490 |
+
T = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor()])
|
491 |
+
path = 'imgs'
|
492 |
+
fnames = os.listdir(path)
|
493 |
+
for fname in fnames:
|
494 |
+
p = os.path.join(path, fname)
|
495 |
+
img = Image.open(p)
|
496 |
+
img = T(img)
|
497 |
+
img = img * 2. - 1
|
498 |
+
img = img[None, ...]
|
499 |
+
img = img.to(device)
|
500 |
+
|
501 |
+
# with torch.cuda.amp.autocast():
|
502 |
+
# moments = model.encode_moments(img)
|
503 |
+
# mean, logvar = torch.chunk(moments, 2, dim=1)
|
504 |
+
# logvar = torch.clamp(logvar, -30.0, 20.0)
|
505 |
+
# std = torch.exp(0.5 * logvar)
|
506 |
+
# zs = [(mean + std * torch.randn_like(mean)) * scale_factor for _ in range(4)]
|
507 |
+
# recons = [model.decode(z) for z in zs]
|
508 |
+
|
509 |
+
with torch.cuda.amp.autocast():
|
510 |
+
print('test encode & decode')
|
511 |
+
recons = [model.decode(model.encode(img)) for _ in range(4)]
|
512 |
+
|
513 |
+
out = torch.cat([img, *recons], dim=0)
|
514 |
+
out = (out + 1) * 0.5
|
515 |
+
save_image(out, f'recons_{fname}')
|
516 |
+
|
517 |
+
|
518 |
+
if __name__ == "__main__":
|
519 |
+
main()
|
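A hedged round-trip sketch for the frozen KL autoencoder above. It assumes the checkpoint path used in `main()` is present; the 256x256 input matches `ddconfig`, whose `ch_mult` of [1, 2, 4, 4] implies 8x spatial downsampling, so latents come out as [B, 4, 32, 32]:

import torch
from libs.autoencoder import get_model

vae = get_model('assets/stable-diffusion/autoencoder_kl.pth')  # path from main() above
img = torch.rand(1, 3, 256, 256) * 2 - 1  # stand-in for an image batch scaled to [-1, 1]
z = vae.encode(img)                       # [1, 4, 32, 32]; already scaled by scale_factor=0.18215
rec = vae.decode(z)                       # [1, 3, 256, 256]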
libs/caption_decoder.py
ADDED
@@ -0,0 +1,283 @@
import os
import numpy as np
import torch
from torch import nn
from torch.nn import functional as nnf

from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import default_data_collator
from transformers import EarlyStoppingCallback

data_collator = default_data_collator
es = EarlyStoppingCallback(early_stopping_patience=5)
import json
import argparse
from typing import Union, Optional
from collections import OrderedDict


# %% model initial
class ClipCaptionModel(nn.Module):
    """
    """

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None):
        """
        :param tokens: (Tensor) [N x max_seq_len], e.g. [4 x 33]
        :param prefix: (Tensor) [N x prefix_length x 768], e.g. [4 x 77 x 768]
        :param mask: (Tensor) [N x (prefix_length + max_seq_len) x 768], e.g. [4 x 110 x 768]

        :attribute embedding_text: (Tensor) [N x max_seq_len x 768], e.g. [4 x 33 x 768]
        :attribute embedding_cat: (Tensor) [N x (prefix_length + max_seq_len) x 768], e.g. [4 x 110 x 768]
        """
        embedding_text = self.gpt.transformer.wte(tokens)
        hidden = self.encode_prefix(prefix)
        prefix = self.decode_prefix(hidden)
        embedding_cat = torch.cat((prefix, embedding_text), dim=1)

        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
        if self.hidden_dim is not None:
            return out, hidden
        else:
            return out

    def encode_decode_prefix(self, prefix):
        return self.decode_prefix(self.encode_prefix(prefix))

    def __init__(self, prefix_length: int, hidden_dim=None):
        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length
        eos = '<|EOS|>'
        special_tokens_dict = {'eos_token': eos}
        base_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        base_tokenizer.add_special_tokens(special_tokens_dict)
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2', eos_token_id=base_tokenizer.eos_token_id)
        self.gpt.resize_token_embeddings(len(base_tokenizer))

        self.hidden_dim = hidden_dim
        self.encode_prefix = nn.Linear(768, hidden_dim) if hidden_dim is not None else nn.Identity()
        self.decode_prefix = nn.Linear(hidden_dim, 768) if hidden_dim is not None else nn.Identity()


def load_model(config_path: str, epoch_or_latest: Union[str, int] = '_latest'):
    with open(config_path) as f:
        config = json.load(f)
    parser = argparse.ArgumentParser()
    parser.set_defaults(**config)
    args = parser.parse_args()
    if type(epoch_or_latest) is int:
        epoch_or_latest = f"-{epoch_or_latest:03d}"
    model_path = os.path.join(args.out_dir, f"{args.prefix}{epoch_or_latest}.pt")
    model = ClipCaptionModel(args.prefix_length)
    if os.path.isfile(model_path):
        print(f"loading model from {model_path}")
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    else:
        print(f"{model_path} does not exist")
    return model, parser


def generate_beam(
        model,
        tokenizer,
        beam_size: int = 5,
        prompt=None,
        embed=None,
        entry_length=67,
        temperature=1.0,
        stop_token: str = '<|EOS|>',
):
    model.eval()
    stop_token_index = tokenizer.encode(stop_token)[0]
    tokens = None
    scores = None
    device = next(model.parameters()).device
    seq_lengths = torch.ones(beam_size, device=device)
    is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
    with torch.no_grad():
        if embed is not None:
            generated = embed
        else:
            if tokens is None:
                tokens = torch.tensor(tokenizer.encode(prompt))
                tokens = tokens.unsqueeze(0).to(device)
                generated = model.gpt.transformer.wte(tokens)
        # pbar = tqdm(range(entry_length))
        # pbar.set_description("generating text ...")
        for i in range(entry_length):
            # print(generated.shape)
            outputs = model.gpt(inputs_embeds=generated)
            logits = outputs.logits
            logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            logits = logits.softmax(-1).log()
            if scores is None:
                scores, next_tokens = logits.topk(beam_size, -1)
                generated = generated.expand(beam_size, *generated.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                if tokens is None:
                    tokens = next_tokens
                else:
                    tokens = tokens.expand(beam_size, *tokens.shape[1:])
                    tokens = torch.cat((tokens, next_tokens), dim=1)
            else:
                logits[is_stopped] = -float(np.inf)
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
                    beam_size, -1
                )
                next_tokens_source = next_tokens // scores_sum.shape[1]
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = next_tokens % scores_sum.shape[1]
                next_tokens = next_tokens.unsqueeze(1)
                tokens = tokens[next_tokens_source]
                tokens = torch.cat((tokens, next_tokens), dim=1)
                generated = generated[next_tokens_source]
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]
            next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(
                generated.shape[0], 1, -1
            )
            generated = torch.cat((generated, next_token_embed), dim=1)
            is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
            if is_stopped.all():
                break
    scores = scores / seq_lengths
    output_list = tokens.cpu().numpy()
    output_texts = [
        tokenizer.decode(output[: int(length)], skip_special_tokens=True)
        for output, length in zip(output_list, seq_lengths)
    ]
    order = scores.argsort(descending=True)
    output_texts = [output_texts[i] for i in order]
    model.train()
    return output_texts


def generate2(
        model,
        tokenizer,
        tokens=None,
        prompt=None,
        embed=None,
        entry_count=1,
        entry_length=67,  # maximum number of words
        top_p=0.8,
        temperature=1.0,
        stop_token: str = '<|EOS|>',
):
    model.eval()
    generated_num = 0
    generated_list = []
    stop_token_index = tokenizer.encode(stop_token)[0]
    filter_value = -float("Inf")
    device = next(model.parameters()).device

    with torch.no_grad():

        for entry_idx in range(entry_count):
            if embed is not None:
                generated = embed
            else:
                if tokens is None:
                    tokens = torch.tensor(tokenizer.encode(prompt))
                    tokens = tokens.unsqueeze(0).to(device)

                generated = model.gpt.transformer.wte(tokens)

            for i in range(entry_length):

                outputs = model.gpt(inputs_embeds=generated)
                logits = outputs.logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(
                    nnf.softmax(sorted_logits, dim=-1), dim=-1
                )
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value
                next_token = torch.argmax(logits, -1).unsqueeze(0)
                next_token_embed = model.gpt.transformer.wte(next_token)
                if tokens is None:
                    tokens = next_token
                else:
                    tokens = torch.cat((tokens, next_token), dim=1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                if stop_token_index == next_token.item():
                    break

            output_list = list(tokens.squeeze().cpu().numpy())
            output_text = tokenizer.decode(output_list)
            generated_list.append(output_text)

    return generated_list[0]


class CaptionDecoder(object):
    def __init__(self, device, pretrained_path, hidden_dim=-1):
        if hidden_dim < 0:
            hidden_dim = None
        # tokenizer initialize
        eos = '<|EOS|>'
        special_tokens_dict = {'eos_token': eos}
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.tokenizer.add_special_tokens(special_tokens_dict)

        # model initialize
        feature_length = 77
        # modelFile = "assets/caption_decoder/coco_v2_latest.pt"
        self.caption_model = ClipCaptionModel(feature_length, hidden_dim=hidden_dim)
        # print("Load Model...")
        ckpt = torch.load(pretrained_path, map_location='cpu')
        state_dict = OrderedDict()
        for k, v in ckpt.items():
            new_k = k[7:]  # strip the 7-char 'module.' prefix left by DataParallel checkpoints
            state_dict[new_k] = v
        mk, uk = self.caption_model.load_state_dict(state_dict, strict=False)
        assert len(mk) == 0
        assert all([name.startswith('clip') for name in uk])
        self.caption_model.eval()
        self.caption_model.to(device)
        self.caption_model.requires_grad_(False)
        self.device = device

    def encode_prefix(self, features):
        return self.caption_model.encode_prefix(features)

    def generate_captions(self, features):  # the low dimension representation of clip feature
        """
        generate captions given features
        :param features: (tensor([B x L x D]))
        :return generated_text: (list([L]))
        """

        # generate config
        use_beam_search = True

        features = torch.split(features, 1, dim=0)
        generated_captions = []
        with torch.no_grad():
            for feature in features:
                feature = self.caption_model.decode_prefix(feature.to(self.device))  # back to the clip feature
                if use_beam_search:
                    generated_captions.append(generate_beam(self.caption_model, self.tokenizer, embed=feature)[0])
                else:
                    generated_captions.append(generate2(self.caption_model, self.tokenizer, embed=feature))
        return generated_captions
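A hedged usage sketch for `CaptionDecoder`. The checkpoint path comes from the commented `modelFile` above, and `hidden_dim=64` is an assumed example value for the low-dimensional text latent, not a confirmed setting:

import torch
from libs.caption_decoder import CaptionDecoder

decoder = CaptionDecoder(device='cpu',
                         pretrained_path='assets/caption_decoder/coco_v2_latest.pt',
                         hidden_dim=64)                 # assumed example value
feats = torch.randn(2, 77, 64)                          # [B x L x D] low-dimensional text features
captions = decoder.generate_captions(feats)             # list of B decoded strings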
libs/clip.py
ADDED
@@ -0,0 +1,38 @@
import torch.nn as nn
from transformers import CLIPTokenizer, CLIPTextModel


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)
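A minimal sketch of the text encoder above. With the default `max_length=77` and the 768-wide hidden state of the ViT-L/14 text tower, `encode` returns a [B, 77, 768] tensor (the pretrained weights are downloaded on first use):

from libs.clip import FrozenCLIPEmbedder

clip_text = FrozenCLIPEmbedder(device='cpu')  # device='cpu' keeps the sketch GPU-free
z = clip_text.encode(['a photo of a corgi'])  # tensor of shape [1, 77, 768]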
libs/timm.py
ADDED
@@ -0,0 +1,112 @@
# code from timm 0.3.2
import torch
import torch.nn as nn
import math
import warnings


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
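A small sketch of the two helpers above: `trunc_normal_` initializes in place from a clamped normal, and `drop_path` zeroes whole samples with probability `drop_prob` while rescaling the survivors by 1/keep_prob, so the expectation of the output is preserved:

import torch
from libs.timm import trunc_normal_, drop_path

w = torch.empty(3, 5)
trunc_normal_(w, std=.02)                        # in-place init, clamped to [a, b] = [-2, 2]
x = torch.ones(8, 16)
y = drop_path(x, drop_prob=0.5, training=True)   # each row is either all 0.0 or all 2.0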
libs/uvit_multi_post_ln.py
ADDED
@@ -0,0 +1,277 @@
import torch
import torch.nn as nn
import math
from .timm import trunc_normal_, DropPath, Mlp
import einops
import torch.utils.checkpoint
import torch.nn.functional as F

if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
    ATTENTION_MODE = 'flash'
else:
    try:
        import xformers
        import xformers.ops
        ATTENTION_MODE = 'xformers'
    except ImportError:
        ATTENTION_MODE = 'math'
print(f'attention mode is {ATTENTION_MODE}')


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


def patchify(imgs, patch_size):
    x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
    return x


def unpatchify(x, in_chans):
    patch_size = int((x.shape[2] // in_chans) ** 0.5)
    h = w = int(x.shape[1] ** .5)
    assert h * w == x.shape[1] and patch_size ** 2 * in_chans == x.shape[2]
    x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
    return x


def interpolate_pos_emb(pos_emb, old_shape, new_shape):
    pos_emb = einops.rearrange(pos_emb, 'B (H W) C -> B C H W', H=old_shape[0], W=old_shape[1])
    pos_emb = F.interpolate(pos_emb, new_shape, mode='bilinear')
    pos_emb = einops.rearrange(pos_emb, 'B C H W -> B (H W) C')
    return pos_emb


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, L, C = x.shape

        qkv = self.qkv(x)
        if ATTENTION_MODE == 'flash':
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        elif ATTENTION_MODE == 'xformers':
            qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
            q, k, v = qkv[0], qkv[1], qkv[2]  # B L H D
            x = xformers.ops.memory_efficient_attention(q, k, v)
            x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
        elif ATTENTION_MODE == 'math':
            with torch.amp.autocast(device_type='cuda', enabled=False):
                qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
                q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
                attn = (q @ k.transpose(-2, -1)) * self.scale
                attn = attn.softmax(dim=-1)
                attn = self.attn_drop(attn)
                x = (attn @ v).transpose(1, 2).reshape(B, L, C)
        else:
            raise NotImplementedError

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
        super().__init__()
        self.norm1 = norm_layer(dim) if skip else None
        self.norm2 = norm_layer(dim)

        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm3 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
        self.use_checkpoint = use_checkpoint

    def forward(self, x, skip=None):
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
        else:
            return self._forward(x, skip)

    def _forward(self, x, skip=None):
        if self.skip_linear is not None:
            x = self.skip_linear(torch.cat([x, skip], dim=-1))
            x = self.norm1(x)
        x = x + self.drop_path(self.attn(x))
        x = self.norm2(x)

        x = x + self.drop_path(self.mlp(x))
        x = self.norm3(x)

        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, patch_size, in_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        assert H % self.patch_size == 0 and W % self.patch_size == 0
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class UViT(nn.Module):
    def __init__(self, img_size, in_chans, patch_size, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, pos_drop_rate=0., drop_rate=0., attn_drop_rate=0.,
                 norm_layer=nn.LayerNorm, mlp_time_embed=False, use_checkpoint=False,
                 text_dim=None, num_text_tokens=None, clip_img_dim=None):
        super().__init__()
        self.in_chans = in_chans
        self.patch_size = patch_size
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        self.img_size = (img_size, img_size) if isinstance(img_size, int) else img_size  # the default img size
        assert self.img_size[0] % patch_size == 0 and self.img_size[1] % patch_size == 0
        self.num_patches = (self.img_size[0] // patch_size) * (self.img_size[1] // patch_size)

        self.time_img_embed = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.SiLU(),
            nn.Linear(4 * embed_dim, embed_dim),
        ) if mlp_time_embed else nn.Identity()

        self.time_text_embed = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.SiLU(),
            nn.Linear(4 * embed_dim, embed_dim),
        ) if mlp_time_embed else nn.Identity()

        self.text_embed = nn.Linear(text_dim, embed_dim)
        self.text_out = nn.Linear(embed_dim, text_dim)

        self.clip_img_embed = nn.Linear(clip_img_dim, embed_dim)
        self.clip_img_out = nn.Linear(embed_dim, clip_img_dim)

        self.num_text_tokens = num_text_tokens
        self.num_tokens = 1 + 1 + num_text_tokens + 1 + self.num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        self.in_blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, use_checkpoint=use_checkpoint)
            for _ in range(depth // 2)])

        self.mid_block = Block(
            dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, use_checkpoint=use_checkpoint)

        self.out_blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, skip=True, use_checkpoint=use_checkpoint)
            for _ in range(depth // 2)])

        self.norm = norm_layer(embed_dim)
        self.patch_dim = patch_size ** 2 * in_chans
        self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True)

        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed'}

    def forward(self, img, clip_img, text, t_img, t_text):
        _, _, H, W = img.shape

        img = self.patch_embed(img)

        t_img_token = self.time_img_embed(timestep_embedding(t_img, self.embed_dim))
        t_img_token = t_img_token.unsqueeze(dim=1)
        t_text_token = self.time_text_embed(timestep_embedding(t_text, self.embed_dim))
        t_text_token = t_text_token.unsqueeze(dim=1)

        text = self.text_embed(text)
        clip_img = self.clip_img_embed(clip_img)
        x = torch.cat((t_img_token, t_text_token, text, clip_img, img), dim=1)

        num_text_tokens, num_img_tokens = text.size(1), img.size(1)

        if H == self.img_size[0] and W == self.img_size[1]:
            pos_embed = self.pos_embed
        else:  # interpolate the positional embedding when the input image is not of the default shape
            pos_embed_others, pos_embed_patches = torch.split(self.pos_embed, [1 + 1 + num_text_tokens + 1, self.num_patches], dim=1)
            pos_embed_patches = interpolate_pos_emb(pos_embed_patches, (self.img_size[0] // self.patch_size, self.img_size[1] // self.patch_size),
                                                    (H // self.patch_size, W // self.patch_size))
            pos_embed = torch.cat((pos_embed_others, pos_embed_patches), dim=1)

        x = x + pos_embed
        x = self.pos_drop(x)

        skips = []
        for blk in self.in_blocks:
            x = blk(x)
            skips.append(x)

        x = self.mid_block(x)

        for blk in self.out_blocks:
            x = blk(x, skips.pop())

        x = self.norm(x)

        t_img_token_out, t_text_token_out, text_out, clip_img_out, img_out = x.split((1, 1, num_text_tokens, 1, num_img_tokens), dim=1)

        img_out = self.decoder_pred(img_out)
        img_out = unpatchify(img_out, self.in_chans)

        clip_img_out = self.clip_img_out(clip_img_out)

        text_out = self.text_out(text_out)
        return img_out, clip_img_out, text_out
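A hedged shape check for the multimodal UViT above. The constructor values are example settings chosen to satisfy the assertions in `__init__`, not the shipped UniDiffuser configuration:

import torch
from libs.uvit_multi_post_ln import UViT

net = UViT(img_size=64, in_chans=4, patch_size=2, embed_dim=512, depth=12, num_heads=8,
           mlp_time_embed=True, text_dim=64, num_text_tokens=77, clip_img_dim=512)
img = torch.randn(2, 4, 64, 64)    # latent image -> 1024 patch tokens of dim 512
clip_img = torch.randn(2, 1, 512)  # one CLIP image token per sample
text = torch.randn(2, 77, 64)      # 77 low-dimensional text tokens
t = torch.ones(2)                  # per-modality diffusion timesteps
img_out, clip_img_out, text_out = net(img, clip_img, text, t_img=t, t_text=t)
# shapes: [2, 4, 64, 64], [2, 1, 512], [2, 77, 64]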
libs/uvit_multi_post_ln_v1.py
ADDED
@@ -0,0 +1,285 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import math
|
4 |
+
from .timm import trunc_normal_, DropPath, Mlp
|
5 |
+
import einops
|
6 |
+
import torch.utils.checkpoint
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
10 |
+
ATTENTION_MODE = 'flash'
|
11 |
+
else:
|
12 |
+
try:
|
13 |
+
import xformers
|
14 |
+
import xformers.ops
|
15 |
+
ATTENTION_MODE = 'xformers'
|
16 |
+
except:
|
17 |
+
ATTENTION_MODE = 'math'
|
18 |
+
print(f'attention mode is {ATTENTION_MODE}')
|
19 |
+
|
20 |
+
|
21 |
+
def timestep_embedding(timesteps, dim, max_period=10000):
|
22 |
+
"""
|
23 |
+
Create sinusoidal timestep embeddings.
|
24 |
+
|
25 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
26 |
+
These may be fractional.
|
27 |
+
:param dim: the dimension of the output.
|
28 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
29 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
30 |
+
"""
|
31 |
+
half = dim // 2
|
32 |
+
freqs = torch.exp(
|
33 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
34 |
+
).to(device=timesteps.device)
|
35 |
+
args = timesteps[:, None].float() * freqs[None]
|
36 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
37 |
+
if dim % 2:
|
38 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
39 |
+
return embedding
|
40 |
+
|
41 |
+
|
42 |
+
def patchify(imgs, patch_size):
|
43 |
+
x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)
|
44 |
+
return x
|
45 |
+
|
46 |
+
|
47 |
+
def unpatchify(x, in_chans):
|
48 |
+
patch_size = int((x.shape[2] // in_chans) ** 0.5)
|
49 |
+
h = w = int(x.shape[1] ** .5)
|
50 |
+
assert h * w == x.shape[1] and patch_size ** 2 * in_chans == x.shape[2]
|
51 |
+
x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, p1=patch_size, p2=patch_size)
|
52 |
+
return x
|
53 |
+
|
54 |
+
|
55 |
+
def interpolate_pos_emb(pos_emb, old_shape, new_shape):
|
56 |
+
pos_emb = einops.rearrange(pos_emb, 'B (H W) C -> B C H W', H=old_shape[0], W=old_shape[1])
|
57 |
+
pos_emb = F.interpolate(pos_emb, new_shape, mode='bilinear')
|
58 |
+
pos_emb = einops.rearrange(pos_emb, 'B C H W -> B (H W) C')
|
59 |
+
return pos_emb
|
60 |
+
|
61 |
+
|
62 |
+
class Attention(nn.Module):
|
63 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
64 |
+
super().__init__()
|
65 |
+
self.num_heads = num_heads
|
66 |
+
head_dim = dim // num_heads
|
67 |
+
self.scale = qk_scale or head_dim ** -0.5
|
68 |
+
|
69 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
70 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
71 |
+
self.proj = nn.Linear(dim, dim)
|
72 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
73 |
+
|
74 |
+
def forward(self, x):
|
75 |
+
B, L, C = x.shape
|
76 |
+
|
77 |
+
qkv = self.qkv(x)
|
78 |
+
if ATTENTION_MODE == 'flash':
|
79 |
+
qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
|
80 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
81 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
82 |
+
x = einops.rearrange(x, 'B H L D -> B L (H D)')
|
83 |
+
elif ATTENTION_MODE == 'xformers':
|
84 |
+
qkv = einops.rearrange(qkv, 'B L (K H D) -> K B L H D', K=3, H=self.num_heads)
|
85 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
|
86 |
+
x = xformers.ops.memory_efficient_attention(q, k, v)
|
87 |
+
x = einops.rearrange(x, 'B L H D -> B L (H D)', H=self.num_heads)
|
88 |
+
elif ATTENTION_MODE == 'math':
|
89 |
+
with torch.amp.autocast(device_type='cuda', enabled=False):
|
90 |
+
qkv = einops.rearrange(qkv, 'B L (K H D) -> K B H L D', K=3, H=self.num_heads).float()
|
91 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
92 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
93 |
+
attn = attn.softmax(dim=-1)
|
94 |
+
attn = self.attn_drop(attn)
|
95 |
+
x = (attn @ v).transpose(1, 2).reshape(B, L, C)
|
96 |
+
else:
|
97 |
+
raise NotImplemented
|
98 |
+
|
99 |
+
x = self.proj(x)
|
100 |
+
x = self.proj_drop(x)
|
101 |
+
return x
|
102 |
+
|
103 |
+
|
104 |
+
class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, use_checkpoint=False):
        super().__init__()
        self.norm1 = norm_layer(dim) if skip else None
        self.norm2 = norm_layer(dim)

        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm3 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
        self.use_checkpoint = use_checkpoint

    def forward(self, x, skip=None):
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, skip)
        else:
            return self._forward(x, skip)

    def _forward(self, x, skip=None):
        if self.skip_linear is not None:
            x = self.skip_linear(torch.cat([x, skip], dim=-1))
            x = self.norm1(x)
        x = x + self.drop_path(self.attn(x))
        x = self.norm2(x)

        x = x + self.drop_path(self.mlp(x))
        x = self.norm3(x)

        return x

class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, patch_size, in_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        assert H % self.patch_size == 0 and W % self.patch_size == 0
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

class UViT(nn.Module):
    def __init__(self, img_size, in_chans, patch_size, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, pos_drop_rate=0., drop_rate=0.,
                 attn_drop_rate=0., norm_layer=nn.LayerNorm, mlp_time_embed=False, use_checkpoint=False,
                 text_dim=None, num_text_tokens=None, clip_img_dim=None):
        super().__init__()
        self.in_chans = in_chans
        self.patch_size = patch_size
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        self.img_size = (img_size, img_size) if isinstance(img_size, int) else img_size  # the default img size
        assert self.img_size[0] % patch_size == 0 and self.img_size[1] % patch_size == 0
        self.num_patches = (self.img_size[0] // patch_size) * (self.img_size[1] // patch_size)

        self.time_img_embed = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.SiLU(),
            nn.Linear(4 * embed_dim, embed_dim),
        ) if mlp_time_embed else nn.Identity()

        self.time_text_embed = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.SiLU(),
            nn.Linear(4 * embed_dim, embed_dim),
        ) if mlp_time_embed else nn.Identity()

        self.text_embed = nn.Linear(text_dim, embed_dim)
        self.text_out = nn.Linear(embed_dim, text_dim)

        self.clip_img_embed = nn.Linear(clip_img_dim, embed_dim)
        self.clip_img_out = nn.Linear(embed_dim, clip_img_dim)

        self.num_text_tokens = num_text_tokens
        self.num_tokens = 1 + 1 + num_text_tokens + 1 + self.num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        self.in_blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, use_checkpoint=use_checkpoint)
            for _ in range(depth // 2)])

        self.mid_block = Block(
            dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, use_checkpoint=use_checkpoint)

        self.out_blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, skip=True,
                use_checkpoint=use_checkpoint)
            for _ in range(depth // 2)])

        self.norm = norm_layer(embed_dim)
        self.patch_dim = patch_size ** 2 * in_chans
        self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias=True)

        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

        self.token_embedding = nn.Embedding(2, embed_dim)
        self.pos_embed_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed'}

    def forward(self, img, clip_img, text, t_img, t_text, data_type):
        _, _, H, W = img.shape

        img = self.patch_embed(img)

        t_img_token = self.time_img_embed(timestep_embedding(t_img, self.embed_dim))
        t_img_token = t_img_token.unsqueeze(dim=1)
        t_text_token = self.time_text_embed(timestep_embedding(t_text, self.embed_dim))
        t_text_token = t_text_token.unsqueeze(dim=1)

        text = self.text_embed(text)
        clip_img = self.clip_img_embed(clip_img)

        token_embed = self.token_embedding(data_type).unsqueeze(dim=1)

        x = torch.cat((t_img_token, t_text_token, token_embed, text, clip_img, img), dim=1)

        num_text_tokens, num_img_tokens = text.size(1), img.size(1)

        pos_embed = torch.cat(
            [self.pos_embed[:, :1 + 1, :], self.pos_embed_token, self.pos_embed[:, 1 + 1:, :]], dim=1)
        if H == self.img_size[0] and W == self.img_size[1]:
            pass
        else:  # interpolate the positional embedding when the input image is not of the default shape
            pos_embed_others, pos_embed_patches = torch.split(
                pos_embed, [1 + 1 + 1 + num_text_tokens + 1, self.num_patches], dim=1)
            pos_embed_patches = interpolate_pos_emb(
                pos_embed_patches,
                (self.img_size[0] // self.patch_size, self.img_size[1] // self.patch_size),
                (H // self.patch_size, W // self.patch_size))
            pos_embed = torch.cat((pos_embed_others, pos_embed_patches), dim=1)

        x = x + pos_embed
        x = self.pos_drop(x)

        skips = []
        for blk in self.in_blocks:
            x = blk(x)
            skips.append(x)

        x = self.mid_block(x)

        for blk in self.out_blocks:
            x = blk(x, skips.pop())

        x = self.norm(x)

        t_img_token_out, t_text_token_out, token_embed_out, text_out, clip_img_out, img_out = \
            x.split((1, 1, 1, num_text_tokens, 1, num_img_tokens), dim=1)

        img_out = self.decoder_pred(img_out)
        img_out = unpatchify(img_out, self.in_chans)

        clip_img_out = self.clip_img_out(clip_img_out)

        text_out = self.text_out(text_out)
        return img_out, clip_img_out, text_out
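A minimal smoke-test sketch for the UViT above (not part of the upload); the tiny hyperparameters are illustrative, not the released checkpoint's, and it assumes the module is importable as libs.uvit_multi_post_ln_v1 from the repo root:

    import torch
    from libs.uvit_multi_post_ln_v1 import UViT

    net = UViT(img_size=16, in_chans=4, patch_size=2, embed_dim=64, depth=2, num_heads=4,
               text_dim=8, num_text_tokens=77, clip_img_dim=32)
    img = torch.randn(2, 4, 16, 16)             # latent image
    clip_img = torch.randn(2, 1, 32)            # CLIP image embedding
    text = torch.randn(2, 77, 8)                # low-dimensional text embedding
    t = torch.zeros(2)                          # per-modality timesteps
    data_type = torch.zeros(2, dtype=torch.int)
    img_out, clip_img_out, text_out = net(img, clip_img, text, t, t, data_type)
    # shapes: (2, 4, 16, 16), (2, 1, 32), (2, 77, 8)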
unidiffuser/sample_v1.py
ADDED
@@ -0,0 +1,437 @@
import ml_collections
import torch
import random
from absl import logging
import einops
from torchvision.utils import save_image, make_grid
import torchvision.transforms as standard_transforms
import numpy as np
import clip
from PIL import Image
import time
import os

from libs.autoencoder import get_model
from libs.clip import FrozenCLIPEmbedder
from dpm_solver_pp import NoiseScheduleVP, DPM_Solver
from utils import center_crop, set_logger, get_nnet

def stable_diffusion_beta_schedule(linear_start=0.00085, linear_end=0.0120, n_timestep=1000):
    _betas = (
        torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
    )
    return _betas.numpy()

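# Illustrative check (not in the original script): the schedule above returns a
# length-1000 float64 array whose endpoints match linear_start/linear_end up to
# float rounding, since the sqrt/square round trip preserves them:
#   betas = stable_diffusion_beta_schedule()
#   betas.shape == (1000,); betas[0] ≈ 8.5e-4; betas[-1] ≈ 1.2e-2
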
def prepare_contexts(config, clip_text_model, clip_img_model, clip_img_model_preprocess, autoencoder):
    resolution = config.z_shape[-1] * 8
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    contexts = torch.randn(config.n_samples, 77, config.clip_text_dim).to(device)
    img_contexts = torch.randn(config.n_samples, 2 * config.z_shape[0], config.z_shape[1], config.z_shape[2])
    clip_imgs = torch.randn(config.n_samples, 1, config.clip_img_dim)

    if config.mode in ['t2i', 't2i2t']:
        prompts = [config.prompt] * config.n_samples
        contexts = clip_text_model.encode(prompts)

    elif config.mode in ['i2t', 'i2t2i']:
        img_contexts = []
        clip_imgs = []

        def get_img_feature(image):
            image = np.array(image).astype(np.uint8)
            image = center_crop(resolution, resolution, image)
            clip_img_feature = clip_img_model.encode_image(
                clip_img_model_preprocess(Image.fromarray(image)).unsqueeze(0).to(device))

            image = (image / 127.5 - 1.0).astype(np.float32)
            image = einops.rearrange(image, 'h w c -> 1 c h w')
            image = torch.tensor(image, device=device)
            moments = autoencoder.encode_moments(image)

            return clip_img_feature, moments

        image = Image.open(config.img).convert('RGB')
        clip_img, img_context = get_img_feature(image)

        img_contexts.append(img_context)
        clip_imgs.append(clip_img)
        img_contexts = img_contexts * config.n_samples
        clip_imgs = clip_imgs * config.n_samples

        img_contexts = torch.concat(img_contexts, dim=0)
        clip_imgs = torch.stack(clip_imgs, dim=0)

    return contexts, img_contexts, clip_imgs

def unpreprocess(v):  # to B C H W and [0, 1]
    v = 0.5 * (v + 1.)
    v.clamp_(0., 1.)
    return v

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def evaluate(config):
    if config.get('benchmark', False):
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    set_seed(config.seed)

    config = ml_collections.FrozenConfigDict(config)
    set_logger(log_level='info')

    _betas = stable_diffusion_beta_schedule()
    N = len(_betas)

    nnet = get_nnet(**config.nnet)
    logging.info(f'load nnet from {config.nnet_path}')
    nnet.load_state_dict(torch.load(config.nnet_path, map_location='cpu'))
    nnet.to(device)
    nnet.eval()

    use_caption_decoder = config.text_dim < config.clip_text_dim or config.mode != 't2i'
    if use_caption_decoder:
        from libs.caption_decoder import CaptionDecoder  # absolute import; the relative form breaks when run as a script
        caption_decoder = CaptionDecoder(device=device, **config.caption_decoder)
    else:
        caption_decoder = None

    clip_text_model = FrozenCLIPEmbedder(device=device)
    clip_text_model.eval()
    clip_text_model.to(device)

    autoencoder = get_model(**config.autoencoder)
    autoencoder.to(device)

    clip_img_model, clip_img_model_preprocess = clip.load("ViT-B/32", device=device, jit=False)

    empty_context = clip_text_model.encode([''])[0]

    def split(x):
        C, H, W = config.z_shape
        z_dim = C * H * W
        z, clip_img = x.split([z_dim, config.clip_img_dim], dim=1)
        z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
        clip_img = einops.rearrange(clip_img, 'B (L D) -> B L D', L=1, D=config.clip_img_dim)
        return z, clip_img

    def combine(z, clip_img):
        z = einops.rearrange(z, 'B C H W -> B (C H W)')
        clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
        return torch.concat([z, clip_img], dim=-1)

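    # The split/combine pair above is a lossless reshape round trip; an
    # illustrative check (not in the original script), with the default
    # z_shape (4, 64, 64) and clip_img_dim 512:
    #   x = torch.randn(2, 4 * 64 * 64 + 512)
    #   z, ci = split(x)
    #   assert torch.equal(combine(z, ci), x)
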
    def t2i_nnet(x, timesteps, text):  # text is the low dimension version of the text clip embedding
        """
        1. calculate the conditional model output
        2. calculate the unconditional model output
            config.sample.t2i_cfg_mode == 'empty_token': use the original cfg with the empty string
            config.sample.t2i_cfg_mode == 'true_uncond': use the unconditional model learned by our method
        3. return a linear combination of the conditional and unconditional outputs
        """
        z, clip_img = split(x)

        t_text = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)

        z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=t_text,
                                             data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + config.data_type)
        x_out = combine(z_out, clip_img_out)

        if config.sample.scale == 0.:
            return x_out

        if config.sample.t2i_cfg_mode == 'empty_token':
            _empty_context = einops.repeat(empty_context, 'L D -> B L D', B=x.size(0))
            if use_caption_decoder:
                _empty_context = caption_decoder.encode_prefix(_empty_context)
            z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(z, clip_img, text=_empty_context, t_img=timesteps, t_text=t_text,
                                                                      data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + config.data_type)
            x_out_uncond = combine(z_out_uncond, clip_img_out_uncond)
        elif config.sample.t2i_cfg_mode == 'true_uncond':
            text_N = torch.randn_like(text)  # 3 other possible choices
            z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(z, clip_img, text=text_N, t_img=timesteps, t_text=torch.ones_like(timesteps) * N,
                                                                      data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + config.data_type)
            x_out_uncond = combine(z_out_uncond, clip_img_out_uncond)
        else:
            raise NotImplementedError

        return x_out + config.sample.scale * (x_out - x_out_uncond)

    def i_nnet(x, timesteps):
        z, clip_img = split(x)
        text = torch.randn(x.size(0), 77, config.text_dim, device=device)
        t_text = torch.ones_like(timesteps) * N
        z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=t_text,
                                             data_type=torch.zeros_like(t_text, device=device, dtype=torch.int) + config.data_type)
        x_out = combine(z_out, clip_img_out)
        return x_out

    def t_nnet(x, timesteps):
        z = torch.randn(x.size(0), *config.z_shape, device=device)
        clip_img = torch.randn(x.size(0), 1, config.clip_img_dim, device=device)
        z_out, clip_img_out, text_out = nnet(z, clip_img, text=x, t_img=torch.ones_like(timesteps) * N, t_text=timesteps,
                                             data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + config.data_type)
        return text_out

    def i2t_nnet(x, timesteps, z, clip_img):
        """
        1. calculate the conditional model output
        2. calculate the unconditional model output
        3. return a linear combination of the conditional and unconditional outputs
        """
        t_img = torch.zeros(timesteps.size(0), dtype=torch.int, device=device)

        z_out, clip_img_out, text_out = nnet(z, clip_img, text=x, t_img=t_img, t_text=timesteps,
                                             data_type=torch.zeros_like(t_img, device=device, dtype=torch.int) + config.data_type)

        if config.sample.scale == 0.:
            return text_out

        z_N = torch.randn_like(z)  # 3 other possible choices
        clip_img_N = torch.randn_like(clip_img)
        z_out_uncond, clip_img_out_uncond, text_out_uncond = nnet(z_N, clip_img_N, text=x, t_img=torch.ones_like(timesteps) * N, t_text=timesteps,
                                                                  data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + config.data_type)

        return text_out + config.sample.scale * (text_out - text_out_uncond)

    def split_joint(x):
        C, H, W = config.z_shape
        z_dim = C * H * W
        z, clip_img, text = x.split([z_dim, config.clip_img_dim, 77 * config.text_dim], dim=1)
        z = einops.rearrange(z, 'B (C H W) -> B C H W', C=C, H=H, W=W)
        clip_img = einops.rearrange(clip_img, 'B (L D) -> B L D', L=1, D=config.clip_img_dim)
        text = einops.rearrange(text, 'B (L D) -> B L D', L=77, D=config.text_dim)
        return z, clip_img, text

    def combine_joint(z, clip_img, text):
        z = einops.rearrange(z, 'B C H W -> B (C H W)')
        clip_img = einops.rearrange(clip_img, 'B L D -> B (L D)')
        text = einops.rearrange(text, 'B L D -> B (L D)')
        return torch.concat([z, clip_img, text], dim=-1)

    def joint_nnet(x, timesteps):
        z, clip_img, text = split_joint(x)
        z_out, clip_img_out, text_out = nnet(z, clip_img, text=text, t_img=timesteps, t_text=timesteps,
                                             data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + config.data_type)
        x_out = combine_joint(z_out, clip_img_out, text_out)

        if config.sample.scale == 0.:
            return x_out

        z_noise = torch.randn(x.size(0), *config.z_shape, device=device)
        clip_img_noise = torch.randn(x.size(0), 1, config.clip_img_dim, device=device)
        text_noise = torch.randn(x.size(0), 77, config.text_dim, device=device)

        _, _, text_out_uncond = nnet(z_noise, clip_img_noise, text=text, t_img=torch.ones_like(timesteps) * N, t_text=timesteps,
                                     data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + config.data_type)
        z_out_uncond, clip_img_out_uncond, _ = nnet(z, clip_img, text=text_noise, t_img=timesteps, t_text=torch.ones_like(timesteps) * N,
                                                    data_type=torch.zeros_like(timesteps, device=device, dtype=torch.int) + config.data_type)

        x_out_uncond = combine_joint(z_out_uncond, clip_img_out_uncond, text_out_uncond)

        return x_out + config.sample.scale * (x_out - x_out_uncond)

    @torch.cuda.amp.autocast()
    def encode(_batch):
        return autoencoder.encode(_batch)

    @torch.cuda.amp.autocast()
    def decode(_batch):
        return autoencoder.decode(_batch)

    logging.info(config.sample)
    logging.info(f'N={N}')

    contexts, img_contexts, clip_imgs = prepare_contexts(config, clip_text_model, clip_img_model, clip_img_model_preprocess, autoencoder)

    # contexts: the clip embedding of the conditioned texts
    # contexts_low_dim: the low dimensional version of the contexts, which is the input to the nnet
    contexts_low_dim = contexts if not use_caption_decoder else caption_decoder.encode_prefix(contexts)

    # img_contexts is the autoencoder moment of the conditioned image
    z_img = autoencoder.sample(img_contexts)
    # clip_imgs: the clip embedding of the conditioned image

    if config.mode in ['t2i', 't2i2t']:
        _n_samples = contexts_low_dim.size(0)
    elif config.mode in ['i2t', 'i2t2i']:
        _n_samples = img_contexts.size(0)
    else:
        _n_samples = config.n_samples

    def sample_fn(mode, **kwargs):

        _z_init = torch.randn(_n_samples, *config.z_shape, device=device)
        _clip_img_init = torch.randn(_n_samples, 1, config.clip_img_dim, device=device)
        _text_init = torch.randn(_n_samples, 77, config.text_dim, device=device)
        if mode == 'joint':
            _x_init = combine_joint(_z_init, _clip_img_init, _text_init)
        elif mode in ['t2i', 'i']:
            _x_init = combine(_z_init, _clip_img_init)
        elif mode in ['i2t', 't']:
            _x_init = _text_init
        noise_schedule = NoiseScheduleVP(schedule='discrete', betas=torch.tensor(_betas, device=device).float())

        def model_fn(x, t_continuous):
            t = t_continuous * N
            if mode == 'joint':
                return joint_nnet(x, t)
            elif mode == 't2i':
                return t2i_nnet(x, t, **kwargs)
            elif mode == 'i2t':
                return i2t_nnet(x, t, **kwargs)
            elif mode == 'i':
                return i_nnet(x, t)
            elif mode == 't':
                return t_nnet(x, t)

        dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True, thresholding=False)
        with torch.no_grad():
            with torch.autocast(device_type=device):
                start_time = time.time()
                x = dpm_solver.sample(_x_init, steps=config.sample.sample_steps, eps=1. / N, T=1.)
                end_time = time.time()
                print(f'\ngenerate {_n_samples} samples with {config.sample.sample_steps} steps takes {end_time - start_time:.2f}s')

        # os.makedirs(config.output_path, exist_ok=True)
        if mode == 'joint':
            _z, _clip_img, _text = split_joint(x)
            return _z, _clip_img, _text
        elif mode in ['t2i', 'i']:
            _z, _clip_img = split(x)
            return _z, _clip_img
        elif mode in ['i2t', 't']:
            return x

    output_images = None
    output_text = None

    if config.mode in ['joint']:
        _z, _clip_img, _text = sample_fn(config.mode)
        samples = unpreprocess(decode(_z))
        prompts = caption_decoder.generate_captions(_text)
        output_images = samples
        output_text = prompts

    elif config.mode in ['t2i', 'i', 'i2t2i']:
        if config.mode == 't2i':
            _z, _clip_img = sample_fn(config.mode, text=contexts_low_dim)  # conditioned on the text embedding
        elif config.mode == 'i':
            _z, _clip_img = sample_fn(config.mode)
        elif config.mode == 'i2t2i':
            _text = sample_fn('i2t', z=z_img, clip_img=clip_imgs)  # conditioned on the image embedding
            _z, _clip_img = sample_fn('t2i', text=_text)
        samples = unpreprocess(decode(_z))
        output_images = samples

    elif config.mode in ['i2t', 't', 't2i2t']:
        if config.mode == 'i2t':
            _text = sample_fn(config.mode, z=z_img, clip_img=clip_imgs)  # conditioned on the image embedding
        elif config.mode == 't':
            _text = sample_fn(config.mode)
        elif config.mode == 't2i2t':
            _z, _clip_img = sample_fn('t2i', text=contexts_low_dim)
            _text = sample_fn('i2t', z=_z, clip_img=_clip_img)
        samples = caption_decoder.generate_captions(_text)
        output_text = samples

    print(f'\nGPU memory usage: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB')
    if config.get('output_path', None):  # output_path is not set by get_config(); only report it when provided
        print(f'\nresults are saved in {os.path.join(config.output_path, config.mode)} :)')

    # Convert the sample images to PIL (the original loop only rebound the loop
    # variable, so the conversion was silently dropped)
    if output_images is not None:
        output_images = [standard_transforms.ToPILImage()(sample) for sample in output_images]

    return output_images, output_text

def d(**kwargs):
    """Helper for creating a config dict."""
    return ml_collections.ConfigDict(initial_dictionary=kwargs)


def get_config():
    config = ml_collections.ConfigDict()

    config.seed = 1234
    config.pred = 'noise_pred'
    config.z_shape = (4, 64, 64)
    config.clip_img_dim = 512
    config.clip_text_dim = 768
    config.text_dim = 64  # reduced dimension of the text embedding
    config.data_type = 1

    config.autoencoder = d(
        pretrained_path='models/autoencoder_kl.pth',
    )

    config.caption_decoder = d(
        pretrained_path="models/caption_decoder.pth",
        hidden_dim=config.get_ref('text_dim')
    )

    config.nnet = d(
        name='uvit_multi_post_ln_v1',
        img_size=64,
        in_chans=4,
        patch_size=2,
        embed_dim=1536,
        depth=30,
        num_heads=24,
        mlp_ratio=4,
        qkv_bias=False,
        pos_drop_rate=0.,
        drop_rate=0.,
        attn_drop_rate=0.,
        mlp_time_embed=False,
        text_dim=config.get_ref('text_dim'),
        num_text_tokens=77,
        clip_img_dim=config.get_ref('clip_img_dim'),
        use_checkpoint=True
    )

    config.sample = d(
        sample_steps=50,
        scale=7.,
        t2i_cfg_mode='true_uncond'
    )

    return config

def sample(mode, prompt, image, sample_steps=50, scale=7.0, seed=None):
    config = get_config()

    config.nnet_path = "models/uvit_v1.pth"
    config.n_samples = 1
    config.nrow = 1

    config.mode = mode
    config.prompt = prompt
    config.img = image

    config.sample.sample_steps = sample_steps
    config.sample.scale = scale
    if seed is not None:
        config.seed = seed

    # propagate the generated images/captions to the caller
    return evaluate(config)
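A usage sketch for the entry point above (not part of the upload); it assumes the checkpoints referenced in get_config() (models/uvit_v1.pth, models/autoencoder_kl.pth, models/caption_decoder.pth) are in place and the repo root is on sys.path:

    from unidiffuser.sample_v1 import sample

    # text-to-image; returns (list of PIL images or None, captions or None)
    images, text = sample('t2i', prompt='an astronaut riding a horse',
                          image=None, sample_steps=50, scale=7.0, seed=1234)
    images[0].save('t2i_sample.png')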
utils.py
ADDED
@@ -0,0 +1,80 @@
from absl import logging
import numpy as np
from PIL import Image, ImageDraw, ImageFont


def center_crop(width, height, img):
    resample = {'box': Image.BOX, 'lanczos': Image.LANCZOS}['lanczos']
    crop = np.min(img.shape[:2])
    img = img[(img.shape[0] - crop) // 2: (img.shape[0] + crop) // 2,
              (img.shape[1] - crop) // 2: (img.shape[1] + crop) // 2]  # center crop
    try:
        img = Image.fromarray(img, 'RGB')
    except ValueError:  # e.g. a grayscale input without an explicit channel axis
        img = Image.fromarray(img)
    img = img.resize((width, height), resample)  # resize the center crop from [crop, crop] to [width, height]

    return np.array(img).astype(np.uint8)

def set_logger(log_level='info', fname=None):
    import logging as _logging
    handler = logging.get_absl_handler()
    formatter = _logging.Formatter('%(asctime)s - %(filename)s - %(message)s')
    handler.setFormatter(formatter)
    logging.set_verbosity(log_level)
    if fname is not None:
        handler = _logging.FileHandler(fname)
        handler.setFormatter(formatter)
        logging.get_absl_logger().addHandler(handler)

def get_nnet(name, **kwargs):
    if name == 'uvit_multi_post_ln':
        from libs.uvit_multi_post_ln import UViT
        return UViT(**kwargs)
    elif name == 'uvit_multi_post_ln_v1':
        from libs.uvit_multi_post_ln_v1 import UViT
        return UViT(**kwargs)
    else:
        raise NotImplementedError(name)

def drawRoundRec(draw, color, x, y, w, h, r):
    drawObject = draw

    # rounded corners
    drawObject.ellipse((x, y, x + r, y + r), fill=color)
    drawObject.ellipse((x + w - r, y, x + w, y + r), fill=color)
    drawObject.ellipse((x, y + h - r, x + r, y + h), fill=color)
    drawObject.ellipse((x + w - r, y + h - r, x + w, y + h), fill=color)

    # rectangles
    drawObject.rectangle((x + r / 2, y, x + w - (r / 2), y + h), fill=color)
    drawObject.rectangle((x, y + r / 2, x + w, y + h - (r / 2)), fill=color)

def add_water(img, text='UniDiffuser', pos=3):
    width, height = img.size
    scale = 4
    scale_size = 0.5
    img = img.resize((width * scale, height * scale), Image.LANCZOS)
    result = Image.new(img.mode, (width * scale, height * scale), color=(255, 255, 255))
    result.paste(img, box=(0, 0))

    delta_w = int(width * scale * 0.27 * scale_size)  # text width
    delta_h = width * scale * 0.05 * scale_size  # text height
    positions = np.array([[0, 0], [0, height * scale - delta_h], [width * scale - delta_w, 0],
                          [width * scale - delta_w, height * scale - delta_h]])
    position = positions[pos]
    # watermark text
    draw = ImageDraw.Draw(result)
    fillColor = (107, 92, 231)
    setFont = ImageFont.truetype("assets/ArialBoldMT.ttf", int(width * scale * 0.05 * scale_size))
    delta = 20 * scale_size
    padding = 15 * scale_size
    drawRoundRec(draw, (223, 230, 233), position[0] - delta - padding, position[1] - delta - padding,
                 w=delta_w + 2 * padding, h=delta_h + 2 * padding, r=50 * scale_size)
    draw.text((position[0] - delta, position[1] - delta), text, font=setFont, fill=fillColor)

    return result.resize((width, height), Image.LANCZOS)
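A short usage sketch for the helpers above (illustrative; example.jpg is a placeholder, and add_water needs the assets/ArialBoldMT.ttf font hard-coded above):

    import numpy as np
    from PIL import Image
    from utils import center_crop, add_water

    img = np.array(Image.open('example.jpg').convert('RGB'))
    img = center_crop(512, 512, img)          # uint8 array, shape (512, 512, 3)
    marked = add_water(Image.fromarray(img))  # watermark in the bottom-right corner (pos=3)
    marked.save('example_marked.png')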