alandao committed on
Commit
ddf288e
1 Parent(s): ae6f85f

Upload modules.py with huggingface_hub

Files changed (1)
  1. modules.py +365 -0
modules.py ADDED
@@ -0,0 +1,365 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from fast_pytorch_kmeans import KMeans
+ from torch import einsum
+ import torch.distributed as dist
+ from einops import rearrange
+
+
+ def get_timestep_embedding(timesteps, embedding_dim):
+     """
+     Build sinusoidal timestep embeddings, as in Denoising Diffusion
+     Probabilistic Models (originally from Fairseq). This matches the
+     implementation in tensor2tensor, but differs slightly from the
+     description in Section 3.5 of "Attention Is All You Need".
+     """
+     assert len(timesteps.shape) == 1
+
+     half_dim = embedding_dim // 2
+     emb = math.log(10000) / (half_dim - 1)
+     emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+     emb = emb.to(device=timesteps.device)
+     emb = timesteps.float()[:, None] * emb[None, :]
+     emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+     if embedding_dim % 2 == 1:  # zero pad
+         emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+     return emb
+
+
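+ # Usage sketch (illustrative): for a 1-D batch of integer timesteps t of shape
+ # (B,), get_timestep_embedding(t, D) returns a (B, D) tensor whose first D//2
+ # columns are sines and last D//2 are cosines, e.g.
+ # get_timestep_embedding(torch.arange(4), 128) has shape (4, 128).
+
+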
+ def nonlinearity(x):
+     # swish
+     return x * torch.sigmoid(x)
+
+
+ def Normalize(in_channels):
+     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+ class Upsample(nn.Module):
+     def __init__(self, in_channels, with_conv):
+         super().__init__()
+         self.with_conv = with_conv
+         if self.with_conv:
+             self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+     def forward(self, x):
+         x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+         if self.with_conv:
+             x = self.conv(x)
+         return x
+
+
+ class Downsample(nn.Module):
+     def __init__(self, in_channels, with_conv):
+         super().__init__()
+         self.with_conv = with_conv
+         if self.with_conv:
+             # no asymmetric padding in torch conv, must do it ourselves
+             self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+     def forward(self, x):
+         if self.with_conv:
+             pad = (0, 1, 0, 1)
+             x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+             x = self.conv(x)
+         else:
+             x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+         return x
+
+
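+ # Note on the padding above: F.pad with (0, 1, 0, 1) adds one column on the
+ # right and one row at the bottom, so the 3x3, stride-2, padding-0 convolution
+ # halves even spatial sizes exactly, e.g. 512 -> (512 + 1 - 3) // 2 + 1 = 256.
+
+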
+ class ResnetBlock(nn.Module):
+     def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
+         super().__init__()
+         self.in_channels = in_channels
+         out_channels = in_channels if out_channels is None else out_channels
+         self.out_channels = out_channels
+         self.use_conv_shortcut = conv_shortcut
+
+         self.norm1 = Normalize(in_channels)
+         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+         self.norm2 = Normalize(out_channels)
+         self.dropout = torch.nn.Dropout(dropout)
+         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+         if self.in_channels != self.out_channels:
+             if self.use_conv_shortcut:
+                 self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+             else:
+                 self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, x):
+         h = x
+         h = self.norm1(h)
+         h = nonlinearity(h)
+         h = self.conv1(h)
+
+         h = self.norm2(h)
+         h = nonlinearity(h)
+         h = self.dropout(h)
+         h = self.conv2(h)
+
+         if self.in_channels != self.out_channels:
+             if self.use_conv_shortcut:
+                 x = self.conv_shortcut(x)
+             else:
+                 x = self.nin_shortcut(x)
+
+         return x + h
+
+
+ class AttnBlock(nn.Module):
+     def __init__(self, in_channels):
+         super().__init__()
+         self.in_channels = in_channels
+
+         self.norm = Normalize(in_channels)
+         self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+         self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+         self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+         self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, x):
+         h_ = x
+         h_ = self.norm(h_)
+         q = self.q(h_)
+         k = self.k(h_)
+         v = self.v(h_)
+
+         # compute attention
+         b, c, h, w = q.shape
+         q = q.reshape(b, c, h * w)
+         q = q.permute(0, 2, 1)      # b,hw,c
+         k = k.reshape(b, c, h * w)  # b,c,hw
+         w_ = torch.bmm(q, k)        # b,hw,hw   w_[b,i,j] = sum_c q[b,i,c] k[b,c,j]
+         w_ = w_ * (int(c) ** (-0.5))
+         w_ = torch.nn.functional.softmax(w_, dim=2)
+
+         # attend to values
+         v = v.reshape(b, c, h * w)
+         w_ = w_.permute(0, 2, 1)    # b,hw,hw (first hw of k, second of q)
+         h_ = torch.bmm(v, w_)       # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+         h_ = h_.reshape(b, c, h, w)
+
+         h_ = self.proj_out(h_)
+
+         return x + h_
+
+
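+ # The AttnBlock above is single-head scaled dot-product self-attention over the
+ # h*w spatial positions, with a residual connection; the q/k/v and output
+ # projections are expressed as 1x1 convolutions.
+
+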
+ class Swish(nn.Module):
+     def forward(self, x):
+         return x * torch.sigmoid(x)
+
+
+ class Encoder(nn.Module):
+     """
+     Encoder of VQ-GAN: maps a batch of images to the latent space.
+     Dimension transformations with the default arguments (3x512x512 input):
+     3x512x512 --Conv2d--> 128x512x512
+     for loop (two ResnetBlocks per stage, Downsample between stages),
+     feature maps after each stage:
+         128x256x256 --> 128x128x128 --> 256x64x64 --> 512x32x32 --> 512x32x32
+         (an AttnBlock follows each ResnetBlock once the 32x32 resolution
+         listed in attn_resolutions is reached)
+     --ResnetBlock--> 512x32x32
+     --AttnBlock--> 512x32x32
+     --ResnetBlock--> 512x32x32
+     --GroupNorm--> --Swish--> --Conv2d--> 256x32x32
+     """
+
+     def __init__(self, in_channels=3, channels=[128, 128, 128, 256, 512, 512], attn_resolutions=[32], resolution=512, dropout=0.0, num_res_blocks=2, z_channels=256, **kwargs):
+         super(Encoder, self).__init__()
+         layers = [nn.Conv2d(in_channels, channels[0], 3, 1, 1)]
+         for i in range(len(channels) - 1):
+             in_channels = channels[i]
+             out_channels = channels[i + 1]
+             for j in range(num_res_blocks):
+                 layers.append(ResnetBlock(in_channels=in_channels, out_channels=out_channels, dropout=dropout))
+                 in_channels = out_channels
+                 if resolution in attn_resolutions:
+                     layers.append(AttnBlock(in_channels))
+             if i < len(channels) - 2:
+                 layers.append(Downsample(channels[i + 1], with_conv=True))
+                 resolution //= 2
+         layers.append(ResnetBlock(in_channels=channels[-1], out_channels=channels[-1], dropout=dropout))
+         layers.append(AttnBlock(channels[-1]))
+         layers.append(ResnetBlock(in_channels=channels[-1], out_channels=channels[-1], dropout=dropout))
+         layers.append(Normalize(channels[-1]))
+         layers.append(Swish())
+         layers.append(nn.Conv2d(channels[-1], z_channels, 3, 1, 1))
+         self.model = nn.Sequential(*layers)
+
+     def forward(self, x):
+         return self.model(x)
+
+
+ class Decoder(nn.Module):
+     def __init__(self, out_channels=3, channels=[128, 128, 128, 256, 512, 512], attn_resolutions=[32], resolution=512, dropout=0.0, num_res_blocks=2, z_channels=256, **kwargs):
+         super(Decoder, self).__init__()
+         ch_mult = channels[1:]
+         num_resolutions = len(ch_mult)
+         block_in = ch_mult[num_resolutions - 1]
+         curr_res = resolution // 2 ** (num_resolutions - 1)
+
+         layers = [nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1),
+                   ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout),
+                   AttnBlock(block_in),
+                   ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)]
+
+         for i in reversed(range(num_resolutions)):
+             block_out = ch_mult[i]
+             for i_block in range(num_res_blocks + 1):
+                 layers.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
+                 block_in = block_out
+                 if curr_res in attn_resolutions:
+                     layers.append(AttnBlock(block_in))
+             if i > 0:
+                 layers.append(Upsample(block_in, with_conv=True))
+                 curr_res = curr_res * 2
+
+         layers.append(Normalize(block_in))
+         layers.append(Swish())
+         layers.append(nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1))
+
+         self.model = nn.Sequential(*layers)
+
+     def forward(self, x):
+         return self.model(x)
+
+
+ class Codebook(nn.Module):
+     """
+     Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
+     avoids costly matrix multiplications and allows for post-hoc remapping of indices.
+     """
+     def __init__(self, codebook_size, codebook_dim, beta, init_steps=2000, reservoir_size=2e5):
+         super().__init__()
+         self.codebook_size = codebook_size
+         self.codebook_dim = codebook_dim
+         self.beta = beta
+
+         self.embedding = nn.Embedding(self.codebook_size, self.codebook_dim)
+         self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
+
+         self.q_start_collect, self.q_init, self.q_re_end, self.q_re_step = init_steps, init_steps * 3, init_steps * 30, init_steps // 2
+         self.q_counter = 0
+         self.reservoir_size = int(reservoir_size)
+         self.reservoir = None
+
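+     # Warm-up schedule implied by the counters above (with the default
+     # init_steps=2000): encoder outputs start filling the reservoir after step
+     # 2000; until step 6000 (q_init) the quantizer is bypassed and returns its
+     # input with zero loss; from step 6000 to 60000 (q_re_end) the codebook is
+     # re-fit with k-means on the reservoir every 1000 steps (q_re_step).
+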
+     def forward(self, z):
+         z = rearrange(z, 'b c h w -> b h w c').contiguous()
+         batch_size = z.size(0)
+         z_flattened = z.view(-1, self.codebook_dim)
+         if self.training:
+             self.q_counter += 1
+             # x_flat = x.permute(0, 2, 3, 1).reshape(-1, z.shape(1))
+             if self.q_counter > self.q_start_collect:
+                 z_new = z_flattened.clone().detach().view(batch_size, -1, self.codebook_dim)
+                 z_new = z_new[:, torch.randperm(z_new.size(1))][:, :10].reshape(-1, self.codebook_dim)
+                 self.reservoir = z_new if self.reservoir is None else torch.cat([self.reservoir, z_new], dim=0)
+                 self.reservoir = self.reservoir[torch.randperm(self.reservoir.size(0))[:self.reservoir_size]].detach()
+             if self.q_counter < self.q_init:
+                 z_q = rearrange(z, 'b h w c -> b c h w').contiguous()
+                 return z_q, z_q.new_tensor(0), None  # z_q, loss, min_encoding_indices
+             else:
+                 # if self.q_counter < self.q_init + self.q_re_end:
+                 if self.q_init <= self.q_counter < self.q_re_end:
+                     if (self.q_counter - self.q_init) % self.q_re_step == 0 or self.q_counter == self.q_init + self.q_re_end - 1:
+                         kmeans = KMeans(n_clusters=self.codebook_size)
+                         # guard for single-process runs where the default process group is not initialized
+                         world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
+                         print("Updating codebook from reservoir.")
+                         if world_size > 1:
+                             global_reservoir = [torch.zeros_like(self.reservoir) for _ in range(world_size)]
+                             dist.all_gather(global_reservoir, self.reservoir.clone())
+                             global_reservoir = torch.cat(global_reservoir, dim=0)
+                         else:
+                             global_reservoir = self.reservoir
+                         kmeans.fit_predict(global_reservoir)  # the reservoir holds up to reservoir_size encoded latents
+                         self.embedding.weight.data = kmeans.centroids.detach()
+
+         d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+             torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
+             torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+         min_encoding_indices = torch.argmin(d, dim=1)
+         z_q = self.embedding(min_encoding_indices).view(z.shape)
+
+         # compute loss for embedding
+         loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+
+         # preserve gradients
+         z_q = z + (z_q - z).detach()
+
+         # reshape back to match original input shape
+         z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
+
+         return z_q, loss, min_encoding_indices
+
+     def get_codebook_entry(self, indices, shape):
+         # get quantized latent vectors
+         z_q = self.embedding(indices)
+
+         if shape is not None:
+             z_q = z_q.view(shape)
+             # reshape back to match original input shape
+             z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+         return z_q
+
+
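+ # Note: the loss in Codebook.forward is the usual VQ objective, with
+ # sg = stop-gradient (.detach()): ||sg(z_q) - z||^2 pulls the encoder output
+ # toward its nearest codebook entry, and beta * ||z_q - sg(z)||^2 moves the
+ # codebook entries toward the encoder outputs; the straight-through trick
+ # z + (z_q - z).detach() lets decoder gradients flow back to the encoder.
+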
+ if __name__ == '__main__':
+     enc = Encoder()
+     dec = Decoder()
+     print(sum([p.numel() for p in enc.parameters()]))
+     print(sum([p.numel() for p in dec.parameters()]))
+     x = torch.randn(1, 3, 512, 512)
+     res = enc(x)
+     print(res.shape)
+     res = dec(res)
+     print(res.shape)
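+
+     # Illustrative round trip through the quantizer; the codebook size and
+     # beta below are placeholder values, not settings prescribed by this file.
+     quant = Codebook(codebook_size=1024, codebook_dim=256, beta=0.25)
+     quant.eval()  # skip the k-means warm-up branch used during training
+     z_q, vq_loss, indices = quant(enc(x))
+     print(z_q.shape, vq_loss.item(), indices.shape)
+     print(dec(z_q).shape)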