Upload modules.py with huggingface_hub
modules.py
ADDED
@@ -0,0 +1,365 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from fast_pytorch_kmeans import KMeans
from torch import einsum
import torch.distributed as dist
from einops import rearrange


def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal timestep embeddings.
    This matches the implementation in Denoising Diffusion Probabilistic Models
    (taken from Fairseq/tensor2tensor), but differs slightly from the description
    in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
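
# Usage sketch (added for illustration, not part of the original upload):
#   t = torch.arange(4)                    # shape (4,)
#   emb = get_timestep_embedding(t, 128)   # shape (4, 128): [sin | cos] features
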
def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)      # b, hw, c
        k = k.reshape(b, c, h * w)  # b, c, hw
        w_ = torch.bmm(q, k)        # b, hw, hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)    # b, hw, hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)       # b, c, hw (hw of q)    h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_
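
# Equivalence sketch (added for illustration; assumes PyTorch >= 2.0): per spatial
# position, the bmm-based attention above is standard scaled dot-product attention:
#   q_ = q.reshape(b, c, h * w).transpose(1, 2)        # (b, hw, c)
#   k_ = k.reshape(b, c, h * w).transpose(1, 2)
#   v_ = v.reshape(b, c, h * w).transpose(1, 2)
#   out = F.scaled_dot_product_attention(q_, k_, v_)   # (b, hw, c)
#   h_ == out.transpose(1, 2).reshape(b, c, h, w)      # before proj_out
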
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class Encoder(nn.Module):
    """
    Encoder of VQ-GAN, mapping a batch of images to latent space.
    Dimension transformations with the default config (512x512 input):
        3x512x512 --Conv2d--> 128x512x512
        for loop (two ResnetBlocks, then Downsample, per stage):
            --ResBlock--> 128x512x512 --Downsample--> 128x256x256
            --ResBlock--> 128x256x256 --Downsample--> 128x128x128
            --ResBlock--> 256x128x128 --Downsample--> 256x64x64
            --ResBlock--> 512x64x64  --Downsample--> 512x32x32
            --ResBlock--> 512x32x32 --AttnBlock--> 512x32x32
        --ResBlock--> 512x32x32
        --AttnBlock--> 512x32x32
        --ResBlock--> 512x32x32
        --GroupNorm--> --Swish--> --Conv2d--> 256x32x32
    """

    def __init__(self, in_channels=3, channels=[128, 128, 128, 256, 512, 512], attn_resolutions=[32],
                 resolution=512, dropout=0.0, num_res_blocks=2, z_channels=256, **kwargs):
        super(Encoder, self).__init__()
        layers = [nn.Conv2d(in_channels, channels[0], 3, 1, 1)]
        for i in range(len(channels) - 1):
            in_channels = channels[i]
            out_channels = channels[i + 1]
            for j in range(num_res_blocks):
                layers.append(ResnetBlock(in_channels=in_channels, out_channels=out_channels, dropout=dropout))
                in_channels = out_channels
                if resolution in attn_resolutions:
                    layers.append(AttnBlock(in_channels))
            if i < len(channels) - 2:
                layers.append(Downsample(channels[i + 1], with_conv=True))
                resolution //= 2
        layers.append(ResnetBlock(in_channels=channels[-1], out_channels=channels[-1], dropout=dropout))
        layers.append(AttnBlock(channels[-1]))
        layers.append(ResnetBlock(in_channels=channels[-1], out_channels=channels[-1], dropout=dropout))
        layers.append(Normalize(channels[-1]))
        layers.append(Swish())
        layers.append(nn.Conv2d(channels[-1], z_channels, 3, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
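
# Note (added for clarity): with the default channel list the Encoder downsamples by
# 2 ** (len(channels) - 2) = 16x, so a 3x512x512 input maps to a 256x32x32 latent.
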
class Decoder(nn.Module):
    def __init__(self, out_channels=3, channels=[128, 128, 128, 256, 512, 512], attn_resolutions=[32],
                 resolution=512, dropout=0.0, num_res_blocks=2, z_channels=256, **kwargs):
        super(Decoder, self).__init__()
        ch_mult = channels[1:]
        num_resolutions = len(ch_mult)
        block_in = ch_mult[num_resolutions - 1]
        curr_res = resolution // 2 ** (num_resolutions - 1)

        layers = [nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1),
                  ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout),
                  AttnBlock(block_in),
                  ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)]

        for i in reversed(range(num_resolutions)):
            block_out = ch_mult[i]
            for i_block in range(num_res_blocks + 1):
                layers.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    layers.append(AttnBlock(block_in))
            if i > 0:
                layers.append(Upsample(block_in, with_conv=True))
                curr_res = curr_res * 2

        layers.append(Normalize(block_in))
        layers.append(Swish())
        layers.append(nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class Codebook(nn.Module):
    """
    Vector quantizer with a straight-through estimator and data-driven codebook
    (re-)initialization: during early training it collects encoder outputs in a
    reservoir and periodically re-fits the codebook to them with k-means.
    """
    def __init__(self, codebook_size, codebook_dim, beta, init_steps=2000, reservoir_size=2e5):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.codebook_size, self.codebook_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)

        self.q_start_collect, self.q_init, self.q_re_end, self.q_re_step = init_steps, init_steps * 3, init_steps * 30, init_steps // 2
        self.q_counter = 0
        self.reservoir_size = int(reservoir_size)
        self.reservoir = None

    def forward(self, z):
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        batch_size = z.size(0)
        z_flattened = z.view(-1, self.codebook_dim)
        if self.training:
            self.q_counter += 1
            if self.q_counter > self.q_start_collect:
                # sample up to 10 latents per image into the reservoir
                z_new = z_flattened.clone().detach().view(batch_size, -1, self.codebook_dim)
                z_new = z_new[:, torch.randperm(z_new.size(1))][:, :10].reshape(-1, self.codebook_dim)
                self.reservoir = z_new if self.reservoir is None else torch.cat([self.reservoir, z_new], dim=0)
                self.reservoir = self.reservoir[torch.randperm(self.reservoir.size(0))[:self.reservoir_size]].detach()
            if self.q_counter < self.q_init:
                # warm-up: pass latents through unquantized
                z_q = rearrange(z, 'b h w c -> b c h w').contiguous()
                return z_q, z_q.new_tensor(0), None  # z_q, loss, min_encoding_indices
            else:
                if self.q_init <= self.q_counter < self.q_re_end:
                    if (self.q_counter - self.q_init) % self.q_re_step == 0 or self.q_counter == self.q_init + self.q_re_end - 1:
                        kmeans = KMeans(n_clusters=self.codebook_size)
                        # guard for single-process runs where torch.distributed is not initialized
                        world_size = dist.get_world_size() if dist.is_initialized() else 1
                        print("Updating codebook from reservoir.")
                        if world_size > 1:
                            global_reservoir = [torch.zeros_like(self.reservoir) for _ in range(world_size)]
                            dist.all_gather(global_reservoir, self.reservoir.clone())
                            global_reservoir = torch.cat(global_reservoir, dim=0)
                        else:
                            global_reservoir = self.reservoir
                        kmeans.fit_predict(global_reservoir)  # fit to up to `reservoir_size` encoded latents
                        self.embedding.weight.data = kmeans.centroids.detach()

        # squared distances between latents and codebook entries: ||z||^2 + ||e||^2 - 2 z.e
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)

        # compute loss for embedding
        loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients (straight-through estimator)
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()

        return z_q, loss, min_encoding_indices

    def get_codebook_entry(self, indices, shape):
        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q
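
# Codebook schedule sketch (comments added for clarity; numbers follow the defaults above):
# with init_steps=2000 the Codebook (i) starts filling its reservoir with encoder outputs
# after step 2000 (q_start_collect), (ii) returns latents unquantized until step 6000
# (q_init), and (iii) between steps 6000 and 60000 (q_re_end) re-fits the codebook to the
# gathered reservoir with k-means every 1000 steps (q_re_step).
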
if __name__ == '__main__':
    enc = Encoder()
    dec = Decoder()
    print(sum(p.numel() for p in enc.parameters()))
    print(sum(p.numel() for p in dec.parameters()))
    x = torch.randn(1, 3, 512, 512)
    res = enc(x)
    print(res.shape)  # torch.Size([1, 256, 32, 32])
    res = dec(res)
    print(res.shape)  # torch.Size([1, 3, 512, 512])
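    # Illustrative extension (not part of the original upload): quantize the encoder
    # output with the Codebook; in eval mode the reservoir/k-means re-init path is
    # skipped, so this runs without torch.distributed being initialized.
    codebook = Codebook(codebook_size=1024, codebook_dim=256, beta=0.25).eval()
    z_q, cb_loss, indices = codebook(enc(x))
    print(z_q.shape, cb_loss.item(), indices.shape)  # (1, 256, 32, 32), scalar, (1024,)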