valhalla committed
Commit 1ec58a5
1 parent: 41c77df

Create modeling_vae.py

Files changed (1):
  1. modeling_vae.py +858 -0
modeling_vae.py ADDED
@@ -0,0 +1,858 @@
# pytorch_diffusion + derived encoder decoder
import math

import numpy as np
import tqdm
import torch
import torch.nn as nn
from einops import rearrange  # used by VectorQuantizer below

from diffusers import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
from diffusers.modeling_utils import ModelMixin


def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
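
# Example (illustrative sketch, values chosen only to make the shapes concrete):
# for a 1-D batch of integer timesteps the function returns one sinusoidal
# embedding per timestep, e.g.
#
#   t = torch.tensor([0, 10, 100])
#   emb = get_timestep_embedding(t, embedding_dim=128)   # -> shape (3, 128)

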
def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


58
+ class Downsample(nn.Module):
59
+ def __init__(self, in_channels, with_conv):
60
+ super().__init__()
61
+ self.with_conv = with_conv
62
+ if self.with_conv:
63
+ # no asymmetric padding in torch conv, must do it ourselves
64
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
65
+
66
+ def forward(self, x):
67
+ if self.with_conv:
68
+ pad = (0, 1, 0, 1)
69
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
70
+ x = self.conv(x)
71
+ else:
72
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
73
+ return x
74
+
75
+
class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_
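

# Note (illustrative): AttnBlock is single-head self-attention over spatial
# positions. For an input of shape (b, c, h, w) it forms q, k, v with 1x1
# convolutions, computes an (h*w) x (h*w) attention matrix per sample, and adds
# the attended values back to the input as a residual; e.g. a (1, 64, 16, 16)
# feature map yields a 256 x 256 attention matrix.

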
class Model(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        use_timestep=True,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList(
                [
                    torch.nn.Linear(self.ch, self.temb_ch),
                    torch.nn.Linear(self.temb_ch, self.temb_ch),
                ]
            )

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]
                block.append(
                    ResnetBlock(
                        in_channels=block_in + skip_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, t=None):
        # assert x.shape[2] == x.shape[3] == self.resolution

        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Encoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        **ignore_kwargs,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)

        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
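

# Shape sketch (illustrative, values are assumptions): with resolution=256,
# ch_mult=(1, 2, 4, 8) and double_z=True, the encoder maps
# (b, in_channels, 256, 256) to (b, 2 * z_channels, 32, 32): the spatial size is
# divided by 2 ** (len(ch_mult) - 1) and the channel count is 2 * z_channels
# (mean and log-variance stacked, for the KL autoencoder below).

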
class Decoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        **ignorekwargs,
    ):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class VectorQuantizer(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(
                f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                f"Using {self.unknown_index} for unknown indices."
            )
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits == False, "Only for interface compatible with Gumbel"
        assert return_logits == False, "Only for interface compatible with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        z = rearrange(z, "b c h w -> b h w c").contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * torch.einsum("bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n"))
        )

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        # shape specifying (batch, height, width, channel)
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q
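

# Usage sketch (illustrative; the numbers are arbitrary assumptions chosen only
# to make the shapes concrete): quantize a latent feature map, then recover the
# codebook vectors from the returned indices.
#
#   quantizer = VectorQuantizer(n_e=512, e_dim=4, beta=0.25)
#   z = torch.randn(1, 4, 8, 8)                          # (batch, e_dim, height, width)
#   z_q, codebook_loss, (_, _, indices) = quantizer(z)   # z_q has the same shape as z
#   z_q_again = quantizer.get_codebook_entry(indices.view(-1), shape=(1, 8, 8, 4))

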
class VQModel(ModelMixin, ConfigMixin):
    def __init__(
        self,
        ch,
        out_ch,
        num_res_blocks,
        attn_resolutions,
        in_channels,
        resolution,
        z_channels,
        n_embed,
        embed_dim,
        remap=None,
        sane_index_shape=False,  # tell vector quantizer to return indices as bhw
        ch_mult=(1, 2, 4, 8),
        dropout=0.0,
        double_z=True,
        resamp_with_conv=True,
        give_pre_end=False,
    ):
        super().__init__()

        # register all __init__ params with self.register
        self.register(
            ch=ch,
            out_ch=out_ch,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            n_embed=n_embed,
            embed_dim=embed_dim,
            remap=remap,
            sane_index_shape=sane_index_shape,
            ch_mult=ch_mult,
            dropout=dropout,
            double_z=double_z,
            resamp_with_conv=resamp_with_conv,
            give_pre_end=give_pre_end,
        )

        # pass init params to Encoder
        self.encoder = Encoder(
            ch=ch,
            out_ch=out_ch,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            ch_mult=ch_mult,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            double_z=double_z,
            give_pre_end=give_pre_end,
        )

        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape)

        # pass init params to Decoder
        self.decoder = Decoder(
            ch=ch,
            out_ch=out_ch,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            ch_mult=ch_mult,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            give_pre_end=give_pre_end,
        )

        # encode()/decode() rely on these projections between the encoder/decoder
        # z space and the codebook dimension; as in the upstream taming-transformers
        # VQModel they assume double_z=False, i.e. the encoder emits z_channels maps.
        self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, h, force_not_quantize=False):
        # also go through quantization layer
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)

    def mode(self):
        return self.mean
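

# Usage sketch (illustrative; the values are assumptions chosen to make the
# shapes concrete): the distribution is parameterized by a tensor that stacks
# mean and log-variance along the channel axis, which is exactly what
# AutoencoderKL.encode() below produces.
#
#   params = torch.randn(1, 8, 16, 16)   # 4 mean channels + 4 log-variance channels
#   posterior = DiagonalGaussianDistribution(params)
#   z = posterior.sample()               # reparameterized sample, shape (1, 4, 16, 16)
#   kl = posterior.kl()                  # KL against a standard normal, shape (1,)

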
class AutoencoderKL(ModelMixin, ConfigMixin):
    def __init__(
        self,
        ch,
        out_ch,
        num_res_blocks,
        attn_resolutions,
        in_channels,
        resolution,
        z_channels,
        embed_dim,
        remap=None,
        sane_index_shape=False,  # tell vector quantizer to return indices as bhw
        ch_mult=(1, 2, 4, 8),
        dropout=0.0,
        double_z=True,
        resamp_with_conv=True,
        give_pre_end=False,
    ):
        super().__init__()

        # register all __init__ params with self.register
        self.register(
            ch=ch,
            out_ch=out_ch,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            embed_dim=embed_dim,
            remap=remap,
            sane_index_shape=sane_index_shape,
            ch_mult=ch_mult,
            dropout=dropout,
            double_z=double_z,
            resamp_with_conv=resamp_with_conv,
            give_pre_end=give_pre_end,
        )

        # pass init params to Encoder
        self.encoder = Encoder(
            ch=ch,
            out_ch=out_ch,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            ch_mult=ch_mult,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            double_z=double_z,
            give_pre_end=give_pre_end,
        )

        # pass init params to Decoder
        self.decoder = Decoder(
            ch=ch,
            out_ch=out_ch,
            num_res_blocks=num_res_blocks,
            attn_resolutions=attn_resolutions,
            in_channels=in_channels,
            resolution=resolution,
            z_channels=z_channels,
            ch_mult=ch_mult,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            give_pre_end=give_pre_end,
        )

        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior
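

if __name__ == "__main__":
    # Minimal usage sketch. The hyperparameters below are illustrative assumptions,
    # not the configuration of any released checkpoint, and the script assumes the
    # same (early) diffusers version this module was written against.
    model = AutoencoderKL(
        ch=32,
        out_ch=3,
        num_res_blocks=1,
        attn_resolutions=[],
        in_channels=3,
        resolution=64,
        z_channels=4,
        embed_dim=4,
        ch_mult=(1, 2),
    )
    x = torch.randn(1, 3, 64, 64)
    reconstruction, posterior = model(x)
    print(reconstruction.shape)   # torch.Size([1, 3, 64, 64])
    print(posterior.kl().shape)   # torch.Size([1])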