keithhon committed
Commit bceda6b · 1 Parent(s): 9ee87b9

Upload dalle/models/stage1/layers.py with huggingface_hub
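For context, file uploads like this one are usually made with huggingface_hub's HfApi.upload_file. A minimal sketch, where the repo_id is a placeholder (the target repo is not shown in this view):

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="dalle/models/stage1/layers.py",   # local file to upload
    path_in_repo="dalle/models/stage1/layers.py",      # destination path in the repo
    repo_id="keithhon/<repo-name>",                    # placeholder, not the actual repo id
    commit_message="Upload dalle/models/stage1/layers.py with huggingface_hub",
)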

Files changed (1)
dalle/models/stage1/layers.py ADDED (+373, -0)
# ------------------------------------------------------------------------------------
# Modified from VQGAN (https://github.com/CompVis/taming-transformers)
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
# ------------------------------------------------------------------------------------

import torch
import torch.nn as nn
from typing import Tuple, Optional

def nonlinearity(x):
    # swish
    return x*torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32,
                              num_channels=in_channels,
                              eps=1e-6,
                              affine=True)

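# Note: x * sigmoid(x) above is the swish / SiLU activation (equivalent to
# torch.nn.functional.silu), and GroupNorm with num_groups=32 means every
# channel count passed through Normalize must be divisible by 32.
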
class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x

class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x

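# The (0, 1, 0, 1) pad above adds one column on the right and one row at the
# bottom, so the kernel_size=3 / stride=2 / padding=0 conv halves even spatial
# sizes exactly: e.g. H = W = 8 -> padded to 9x9 -> floor((9 - 3) / 2) + 1 = 4.
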
class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        assert temb_channels == 0
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=1,
                                     padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels,
                                                     out_channels,
                                                     kernel_size=3,
                                                     stride=1,
                                                     padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                    out_channels,
                                                    kernel_size=1,
                                                    stride=1,
                                                    padding=0)

    def forward(self, x, temb=None):
        assert temb is None

        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)
        return x+h

class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w)
        q = q.permute(0, 2, 1)    # b,hw,c
        k = k.reshape(b, c, h*w)  # b,c,hw
        w_ = torch.bmm(q, k)      # b,hw,hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h*w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)     # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)
        return x+h_

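# For reference, the bmm/permute bookkeeping in AttnBlock.forward is ordinary
# single-head scaled dot-product attention; an equivalent einsum sketch
# (illustrative only), with q, k, v already flattened to shape (b, c, h*w):
#
#   attn = torch.softmax(torch.einsum('bci,bcj->bij', q, k) * c ** -0.5, dim=2)
#   out = torch.einsum('bij,bcj->bci', attn, v).reshape(b, c, h, w)
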
class Encoder(nn.Module):
    def __init__(self,
                 *,  # forced to use named arguments
                 ch: int,
                 out_ch: int,
                 ch_mult: Tuple[int] = (1, 2, 4, 8),
                 num_res_blocks: int,
                 attn_resolutions: Tuple[int],
                 pdrop: float = 0.0,
                 resamp_with_conv: bool = True,
                 in_channels: int,
                 resolution: int,
                 z_channels: int,
                 double_z: Optional[bool] = None) -> None:
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=pdrop))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=pdrop)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=pdrop)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        assert x.shape[2] == x.shape[3] == self.resolution, \
            "{}, {}".format(x.shape, self.resolution)

        # downsampling
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions-1:
                h = self.down[i_level].downsample(h)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

class Decoder(nn.Module):
    def __init__(self,
                 *,  # forced to use named arguments
                 ch: int,
                 out_ch: int,
                 ch_mult: Tuple[int] = (1, 2, 4, 8),
                 num_res_blocks: int,
                 attn_resolutions: Tuple[int],
                 pdrop: float = 0.0,
                 resamp_with_conv: bool = True,
                 in_channels: int,
                 resolution: int,
                 z_channels: int,
                 double_z: bool) -> None:
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=pdrop)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=pdrop)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=pdrop))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, z):
        assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h
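
The file's public entry points are Encoder and Decoder. As a smoke test, the sketch below runs a dummy image through both; the hyperparameters are illustrative assumptions, not values taken from this repository's configs:

import torch
from dalle.models.stage1.layers import Encoder, Decoder

# small, assumed hyperparameters, chosen so the GroupNorm(32) channel constraint holds
enc = Encoder(ch=64, out_ch=3, ch_mult=(1, 2), num_res_blocks=1,
              attn_resolutions=(8,), pdrop=0.0, resamp_with_conv=True,
              in_channels=3, resolution=16, z_channels=8, double_z=False)
dec = Decoder(ch=64, out_ch=3, ch_mult=(1, 2), num_res_blocks=1,
              attn_resolutions=(8,), pdrop=0.0, resamp_with_conv=True,
              in_channels=3, resolution=16, z_channels=8, double_z=False)

x = torch.randn(1, 3, 16, 16)  # batch of one image at the declared resolution
z = enc(x)                     # -> (1, 8, 8, 8): spatial size halves once per extra ch_mult level
x_rec = dec(z)                 # -> (1, 3, 16, 16)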