frankleeeee committed on
Commit 8ff4a33
1 Parent(s): a0ea5f2

Upload VQVAE

Files changed (6)
  1. _utils.py +86 -0
  2. attention.py +567 -0
  3. config.json +21 -0
  4. configuration_vqvae.py +22 -0
  5. model.safetensors +3 -0
  6. modeling_vqvae.py +321 -0
_utils.py ADDED
@@ -0,0 +1,86 @@
+ # Shifts src dim to dest dim
+ # i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c)
+ def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
+     n_dims = len(x.shape)
+     if src_dim < 0:
+         src_dim = n_dims + src_dim
+     if dest_dim < 0:
+         dest_dim = n_dims + dest_dim
+
+     assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
+
+     dims = list(range(n_dims))
+     del dims[src_dim]
+
+     permutation = []
+     ctr = 0
+     for i in range(n_dims):
+         if i == dest_dim:
+             permutation.append(src_dim)
+         else:
+             permutation.append(dims[ctr])
+             ctr += 1
+     x = x.permute(permutation)
+     if make_contiguous:
+         x = x.contiguous()
+     return x
+
+ # reshapes tensor start from dim i (inclusive)
+ # to dim j (exclusive) to the desired shape
+ # e.g. if x.shape = (b, thw, c) then
+ # view_range(x, 1, 2, (t, h, w)) returns
+ # x of shape (b, t, h, w, c)
+ def view_range(x, i, j, shape):
+     shape = tuple(shape)
+
+     n_dims = len(x.shape)
+     if i < 0:
+         i = n_dims + i
+
+     if j is None:
+         j = n_dims
+     elif j < 0:
+         j = n_dims + j
+
+     assert 0 <= i < j <= n_dims
+
+     x_shape = x.shape
+     target_shape = x_shape[:i] + shape + x_shape[j:]
+     return x.view(target_shape)
+
+
+ def tensor_slice(x, begin, size):
+     # slices x starting at `begin` with extents `size`; -1 means "to the end of that dim"
+     assert all([b >= 0 for b in begin])
+     size = [l - b if s == -1 else s
+             for s, b, l in zip(size, begin, x.shape)]
+     assert all([s >= 0 for s in size])
+
+     slices = [slice(b, b + s) for b, s in zip(begin, size)]
+     return x[slices]
+
+
+ import math
+ import numpy as np
+ import skvideo.io
+ def save_video_grid(video, fname, nrow=None):
+     # video: (b, c, t, h, w) in [0, 1]; writes a padded grid of the b clips to `fname`
+     b, c, t, h, w = video.shape
+     video = video.permute(0, 2, 3, 4, 1)
+     video = (video.cpu().numpy() * 255).astype('uint8')
+
+     if nrow is None:
+         nrow = math.ceil(math.sqrt(b))
+     ncol = math.ceil(b / nrow)
+     padding = 1
+     video_grid = np.zeros((t, (padding + h) * nrow + padding,
+                            (padding + w) * ncol + padding, c), dtype='uint8')
+     for i in range(b):
+         r = i // ncol
+         c = i % ncol
+
+         start_r = (padding + h) * r
+         start_c = (padding + w) * c
+         video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i]
+
+     skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '5'})
+     print('saved videos to', fname)
+
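For reference, a minimal sketch of how the shape helpers above behave, assuming torch is installed and the functions are imported from this file (inside the repo the import is relative, `._utils`):

import torch
from _utils import shift_dim, view_range  # relative import (._utils) inside the repo

x = torch.randn(2, 3, 4, 8, 8)                     # (b, c, t, h, w)
y = shift_dim(x, 1, -1)                            # channels-last: (2, 4, 8, 8, 3)
z = view_range(y.flatten(1, 3), 1, 2, (4, 8, 8))   # undo the flatten: (2, 4, 8, 8, 3)
assert z.shape == y.shape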
attention.py ADDED
@@ -0,0 +1,567 @@
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.checkpoint import checkpoint
+
+ from ._utils import shift_dim, view_range, tensor_slice
+
+
+ class AttentionStack(nn.Module):
+     def __init__(
+         self, shape, embd_dim, n_head, n_layer, dropout,
+         attn_type, attn_dropout, class_cond_dim, frame_cond_shape,
+     ):
+         super().__init__()
+         self.shape = shape
+         self.embd_dim = embd_dim
+         self.use_frame_cond = frame_cond_shape is not None
+
+         self.right_shift = RightShift(embd_dim)
+         self.pos_embd = AddBroadcastPosEmbed(
+             shape=shape, embd_dim=embd_dim
+         )
+
+         self.attn_nets = nn.ModuleList(
+             [
+                 AttentionBlock(
+                     shape=shape,
+                     embd_dim=embd_dim,
+                     n_head=n_head,
+                     n_layer=n_layer,
+                     dropout=dropout,
+                     attn_type=attn_type,
+                     attn_dropout=attn_dropout,
+                     class_cond_dim=class_cond_dim,
+                     frame_cond_shape=frame_cond_shape
+                 )
+                 for i in range(n_layer)
+             ]
+         )
+
+     def forward(self, x, cond, decode_step, decode_idx):
+         """
+         Args
+         ------
+             x: (b, d1, d2, ..., dn, embd_dim)
+             cond: a dictionary of conditioning tensors
+
+             (below is used only when sampling for fast decoding)
+             decode_step: the enumerated raster-scan index of the position currently being sampled
+             decode_idx: a tuple giving the position currently being sampled
+         """
+         x = self.right_shift(x, decode_step)
+         x = self.pos_embd(x, decode_step, decode_idx)
+         for net in self.attn_nets:
+             x = net(x, cond, decode_step, decode_idx)
+
+         return x
+
+
+ class AttentionBlock(nn.Module):
+     def __init__(self, shape, embd_dim, n_head, n_layer, dropout,
+                  attn_type, attn_dropout, class_cond_dim, frame_cond_shape):
+         super().__init__()
+         self.use_frame_cond = frame_cond_shape is not None
+
+         self.pre_attn_norm = LayerNorm(embd_dim, class_cond_dim)
+         self.post_attn_dp = nn.Dropout(dropout)
+         self.attn = MultiHeadAttention(shape, embd_dim, embd_dim, n_head,
+                                        n_layer, causal=True, attn_type=attn_type,
+                                        attn_kwargs=dict(attn_dropout=attn_dropout))
+
+         if frame_cond_shape is not None:
+             enc_len = np.prod(frame_cond_shape[:-1])
+             self.pre_enc_norm = LayerNorm(embd_dim, class_cond_dim)
+             self.post_enc_dp = nn.Dropout(dropout)
+             self.enc_attn = MultiHeadAttention(shape, embd_dim, frame_cond_shape[-1],
+                                                n_head, n_layer, attn_type='full',
+                                                attn_kwargs=dict(attn_dropout=0.), causal=False)
+
+         self.pre_fc_norm = LayerNorm(embd_dim, class_cond_dim)
+         self.post_fc_dp = nn.Dropout(dropout)
+         self.fc_block = nn.Sequential(
+             nn.Linear(in_features=embd_dim, out_features=embd_dim * 4),
+             GeLU2(),
+             nn.Linear(in_features=embd_dim * 4, out_features=embd_dim),
+         )
+
+     def forward(self, x, cond, decode_step, decode_idx):
+         h = self.pre_attn_norm(x, cond)
+         if self.training:
+             h = checkpoint(self.attn, h, h, h, decode_step, decode_idx)
+         else:
+             h = self.attn(h, h, h, decode_step, decode_idx)
+         h = self.post_attn_dp(h)
+         x = x + h
+
+         if self.use_frame_cond:
+             h = self.pre_enc_norm(x, cond)
+             if self.training:
+                 h = checkpoint(self.enc_attn, h, cond['frame_cond'], cond['frame_cond'],
+                                decode_step, decode_idx)
+             else:
+                 h = self.enc_attn(h, cond['frame_cond'], cond['frame_cond'],
+                                   decode_step, decode_idx)
+             h = self.post_enc_dp(h)
+             x = x + h
+
+         h = self.pre_fc_norm(x, cond)
+         if self.training:
+             h = checkpoint(self.fc_block, h)
+         else:
+             h = self.fc_block(h)
+         h = self.post_fc_dp(h)
+         x = x + h
+
+         return x
+
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, shape, dim_q, dim_kv, n_head, n_layer,
+                  causal, attn_type, attn_kwargs):
+         super().__init__()
+         self.causal = causal
+         self.shape = shape
+
+         self.d_k = dim_q // n_head
+         self.d_v = dim_kv // n_head
+         self.n_head = n_head
+
+         self.w_qs = nn.Linear(dim_q, n_head * self.d_k, bias=False)  # q
+         self.w_qs.weight.data.normal_(std=1.0 / np.sqrt(dim_q))
+
+         self.w_ks = nn.Linear(dim_kv, n_head * self.d_k, bias=False)  # k
+         self.w_ks.weight.data.normal_(std=1.0 / np.sqrt(dim_kv))
+
+         self.w_vs = nn.Linear(dim_kv, n_head * self.d_v, bias=False)  # v
+         self.w_vs.weight.data.normal_(std=1.0 / np.sqrt(dim_kv))
+
+         self.fc = nn.Linear(n_head * self.d_v, dim_q, bias=True)  # c
+         self.fc.weight.data.normal_(std=1.0 / np.sqrt(dim_q * n_layer))
+
+         if attn_type == 'full':
+             self.attn = FullAttention(shape, causal, **attn_kwargs)
+         elif attn_type == 'axial':
+             assert not causal, 'causal axial attention is not supported'
+             self.attn = AxialAttention(len(shape), **attn_kwargs)
+         elif attn_type == 'sparse':
+             self.attn = SparseAttention(shape, n_head, causal, **attn_kwargs)
+
+         self.cache = None
+
+     def forward(self, q, k, v, decode_step=None, decode_idx=None):
+         """ Compute multi-head attention
+         Args
+             q, k, v: a [b, d1, ..., dn, c] tensor or
+                      a [b, 1, ..., 1, c] tensor if decode_step is not None
+
+         Returns
+             The output after performing attention
+         """
+
+         # compute k, q, v
+         d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+         q = view_range(self.w_qs(q), -1, None, (n_head, d_k))
+         k = view_range(self.w_ks(k), -1, None, (n_head, d_k))
+         v = view_range(self.w_vs(v), -1, None, (n_head, d_v))
+
+         # b x n_head x seq_len x d
+         # (b, *d_shape, n_head, d) -> (b, n_head, *d_shape, d)
+         q = shift_dim(q, -2, 1)
+         k = shift_dim(k, -2, 1)
+         v = shift_dim(v, -2, 1)
+
+         # fast decoding
+         if decode_step is not None:
+             if decode_step == 0:
+                 if self.causal:
+                     k_shape = (q.shape[0], n_head, *self.shape, self.d_k)
+                     v_shape = (q.shape[0], n_head, *self.shape, self.d_v)
+                     self.cache = dict(k=torch.zeros(k_shape, dtype=k.dtype, device=q.device),
+                                       v=torch.zeros(v_shape, dtype=v.dtype, device=q.device))
+                 else:
+                     # cache only once in the non-causal case
+                     self.cache = dict(k=k.clone(), v=v.clone())
+             if self.causal:
+                 idx = (slice(None, None), slice(None, None), *[slice(i, i + 1) for i in decode_idx])
+                 self.cache['k'][idx] = k
+                 self.cache['v'][idx] = v
+             k, v = self.cache['k'], self.cache['v']
+
+         a = self.attn(q, k, v, decode_step, decode_idx)
+
+         # (b, *d_shape, n_head, d) -> (b, *d_shape, n_head * d)
+         a = shift_dim(a, 1, -2).flatten(start_dim=-2)
+         a = self.fc(a)  # (b x seq_len x embd_dim)
+
+         return a
+
+ ############## Attention #######################
+ class FullAttention(nn.Module):
+     def __init__(self, shape, causal, attn_dropout):
+         super().__init__()
+         self.causal = causal
+         self.attn_dropout = attn_dropout
+
+         seq_len = np.prod(shape)
+         if self.causal:
+             self.register_buffer('mask', torch.tril(torch.ones(seq_len, seq_len)))
+
+     def forward(self, q, k, v, decode_step, decode_idx):
+         mask = self.mask if self.causal else None
+         if decode_step is not None and mask is not None:
+             mask = mask[[decode_step]]
+
+         old_shape = q.shape[2:-1]
+         q = q.flatten(start_dim=2, end_dim=-2)
+         k = k.flatten(start_dim=2, end_dim=-2)
+         v = v.flatten(start_dim=2, end_dim=-2)
+
+         out = scaled_dot_product_attention(q, k, v, mask=mask,
+                                            attn_dropout=self.attn_dropout,
+                                            training=self.training)
+
+         return view_range(out, 2, 3, old_shape)
+
+ class AxialAttention(nn.Module):
+     def __init__(self, n_dim, axial_dim):
+         super().__init__()
+         if axial_dim < 0:
+             axial_dim = 2 + n_dim + 1 + axial_dim
+         else:
+             axial_dim += 2  # account for batch, head, dim
+         self.axial_dim = axial_dim
+
+     def forward(self, q, k, v, decode_step, decode_idx):
+         q = shift_dim(q, self.axial_dim, -2).flatten(end_dim=-3)
+         k = shift_dim(k, self.axial_dim, -2).flatten(end_dim=-3)
+         v = shift_dim(v, self.axial_dim, -2)
+         old_shape = list(v.shape)
+         v = v.flatten(end_dim=-3)
+
+         out = scaled_dot_product_attention(q, k, v, training=self.training)
+         out = out.view(*old_shape)
+         out = shift_dim(out, -2, self.axial_dim)
+         return out
+
+
+ class SparseAttention(nn.Module):
+     ops = dict()
+     attn_mask = dict()
+     block_layout = dict()
+
+     def __init__(self, shape, n_head, causal, num_local_blocks=4, block=32,
+                  attn_dropout=0.):  # does not use attn_dropout
+         super().__init__()
+         self.causal = causal
+         self.shape = shape
+
+         self.sparsity_config = StridedSparsityConfig(shape=shape, n_head=n_head,
+                                                      causal=causal, block=block,
+                                                      num_local_blocks=num_local_blocks)
+
+         if self.shape not in SparseAttention.block_layout:
+             SparseAttention.block_layout[self.shape] = self.sparsity_config.make_layout()
+         if causal and self.shape not in SparseAttention.attn_mask:
+             SparseAttention.attn_mask[self.shape] = self.sparsity_config.make_sparse_attn_mask()
+
+     def get_ops(self):
+         try:
+             from deepspeed.ops.sparse_attention import MatMul, Softmax
+         except ImportError:
+             raise Exception('Error importing deepspeed. Please install using `DS_BUILD_SPARSE_ATTN=1 pip install deepspeed`')
+         if self.shape not in SparseAttention.ops:
+             sparsity_layout = self.sparsity_config.make_layout()
+             sparse_dot_sdd_nt = MatMul(sparsity_layout,
+                                        self.sparsity_config.block,
+                                        'sdd',
+                                        trans_a=False,
+                                        trans_b=True)
+
+             sparse_dot_dsd_nn = MatMul(sparsity_layout,
+                                        self.sparsity_config.block,
+                                        'dsd',
+                                        trans_a=False,
+                                        trans_b=False)
+
+             sparse_softmax = Softmax(sparsity_layout, self.sparsity_config.block)
+
+             SparseAttention.ops[self.shape] = (sparse_dot_sdd_nt,
+                                                sparse_dot_dsd_nn,
+                                                sparse_softmax)
+         return SparseAttention.ops[self.shape]
+
+     def forward(self, q, k, v, decode_step, decode_idx):
+         if self.training and self.shape not in SparseAttention.ops:
+             self.get_ops()
+
+         SparseAttention.block_layout[self.shape] = SparseAttention.block_layout[self.shape].to(q)
+         if self.causal:
+             SparseAttention.attn_mask[self.shape] = SparseAttention.attn_mask[self.shape].to(q).type_as(q)
+         attn_mask = SparseAttention.attn_mask[self.shape] if self.causal else None
+
+         old_shape = q.shape[2:-1]
+         q = q.flatten(start_dim=2, end_dim=-2)
+         k = k.flatten(start_dim=2, end_dim=-2)
+         v = v.flatten(start_dim=2, end_dim=-2)
+
+         if decode_step is not None:
+             mask = self.sparsity_config.get_non_block_layout_row(SparseAttention.block_layout[self.shape], decode_step)
+             out = scaled_dot_product_attention(q, k, v, mask=mask, training=self.training)
+         else:
+             if q.shape != k.shape or k.shape != v.shape:
+                 raise Exception('SparseAttention only supports self-attention')
+             sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax = self.get_ops()
+             scaling = float(q.shape[-1]) ** -0.5
+
+             attn_output_weights = sparse_dot_sdd_nt(q, k)
+             if attn_mask is not None:
+                 attn_output_weights = attn_output_weights.masked_fill(attn_mask == 0,
+                                                                       float('-inf'))
+             attn_output_weights = sparse_softmax(
+                 attn_output_weights,
+                 scale=scaling
+             )
+
+             out = sparse_dot_dsd_nn(attn_output_weights, v)
+
+         return view_range(out, 2, 3, old_shape)
+
+
+ class StridedSparsityConfig(object):
+     """
+     Strided Sparse configuration specified in https://arxiv.org/abs/1904.10509 that
+     generalizes to arbitrary dimensions
+     """
+     def __init__(self, shape, n_head, causal, block, num_local_blocks):
+         self.n_head = n_head
+         self.shape = shape
+         self.causal = causal
+         self.block = block
+         self.num_local_blocks = num_local_blocks
+
+         assert self.num_local_blocks >= 1, 'Must have at least 1 local block'
+         assert self.seq_len % self.block == 0, 'seq len must be divisible by block size'
+
+         self._block_shape = self._compute_block_shape()
+         self._block_shape_cum = self._block_shape_cum_sizes()
+
+     @property
+     def seq_len(self):
+         return np.prod(self.shape)
+
+     @property
+     def num_blocks(self):
+         return self.seq_len // self.block
+
+     def set_local_layout(self, layout):
+         num_blocks = self.num_blocks
+         for row in range(0, num_blocks):
+             end = min(row + self.num_local_blocks, num_blocks)
+             for col in range(
+                     max(0, row - self.num_local_blocks),
+                     (row + 1 if self.causal else end)):
+                 layout[:, row, col] = 1
+         return layout
+
+     def set_global_layout(self, layout):
+         num_blocks = self.num_blocks
+         n_dim = len(self._block_shape)
+         for row in range(num_blocks):
+             assert self._to_flattened_idx(self._to_unflattened_idx(row)) == row
+             cur_idx = self._to_unflattened_idx(row)
+             # no strided attention over last dim
+             for d in range(n_dim - 1):
+                 end = self._block_shape[d]
+                 for i in range(0, (cur_idx[d] + 1 if self.causal else end)):
+                     new_idx = list(cur_idx)
+                     new_idx[d] = i
+                     new_idx = tuple(new_idx)
+
+                     col = self._to_flattened_idx(new_idx)
+                     layout[:, row, col] = 1
+
+         return layout
+
+     def make_layout(self):
+         layout = torch.zeros((self.n_head, self.num_blocks, self.num_blocks), dtype=torch.int64)
+         layout = self.set_local_layout(layout)
+         layout = self.set_global_layout(layout)
+         return layout
+
+     def make_sparse_attn_mask(self):
+         block_layout = self.make_layout()
+         assert block_layout.shape[1] == block_layout.shape[2] == self.num_blocks
+
+         num_dense_blocks = block_layout.sum().item()
+         attn_mask = torch.ones(num_dense_blocks, self.block, self.block)
+         counter = 0
+         for h in range(self.n_head):
+             for i in range(self.num_blocks):
+                 for j in range(self.num_blocks):
+                     elem = block_layout[h, i, j].item()
+                     if elem == 1:
+                         assert i >= j
+                         if i == j:  # need to mask within block on diagonals
+                             attn_mask[counter] = torch.tril(attn_mask[counter])
+                         counter += 1
+         assert counter == num_dense_blocks
+
+         return attn_mask.unsqueeze(0)
+
+     def get_non_block_layout_row(self, block_layout, row):
+         block_row = row // self.block
+         block_row = block_layout[:, [block_row]]  # n_head x 1 x n_blocks
+         block_row = block_row.repeat_interleave(self.block, dim=-1)
+         block_row[:, :, row + 1:] = 0.
+         return block_row
+
+     ############# Helper functions ##########################
+
+     def _compute_block_shape(self):
+         n_dim = len(self.shape)
+         cum_prod = 1
+         for i in range(n_dim - 1, -1, -1):
+             cum_prod *= self.shape[i]
+             if cum_prod > self.block:
+                 break
+         assert cum_prod % self.block == 0
+         new_shape = (*self.shape[:i], cum_prod // self.block)
+
+         assert np.prod(new_shape) == np.prod(self.shape) // self.block
+
+         return new_shape
+
+     def _block_shape_cum_sizes(self):
+         bs = np.flip(np.array(self._block_shape))
+         return tuple(np.flip(np.cumprod(bs)[:-1])) + (1,)
+
+     def _to_flattened_idx(self, idx):
+         assert len(idx) == len(self._block_shape), f"{len(idx)} != {len(self._block_shape)}"
+         flat_idx = 0
+         for i in range(len(self._block_shape)):
+             flat_idx += idx[i] * self._block_shape_cum[i]
+         return flat_idx
+
+     def _to_unflattened_idx(self, flat_idx):
+         assert flat_idx < np.prod(self._block_shape)
+         idx = []
+         for i in range(len(self._block_shape)):
+             idx.append(flat_idx // self._block_shape_cum[i])
+             flat_idx %= self._block_shape_cum[i]
+         return tuple(idx)
+
+
+ ################ Spatiotemporal broadcasted positional embeddings ###############
+ class AddBroadcastPosEmbed(nn.Module):
+     def __init__(self, shape, embd_dim, dim=-1):
+         super().__init__()
+         assert dim in [-1, 1]  # only first or last dim supported
+         self.shape = shape
+         self.n_dim = n_dim = len(shape)
+         self.embd_dim = embd_dim
+         self.dim = dim
+
+         assert embd_dim % n_dim == 0, f"{embd_dim} % {n_dim} != 0"
+         self.emb = nn.ParameterDict({
+             f'd_{i}': nn.Parameter(torch.randn(shape[i], embd_dim // n_dim) * 0.01
+                                    if dim == -1 else
+                                    torch.randn(embd_dim // n_dim, shape[i]) * 0.01)
+             for i in range(n_dim)
+         })
+
+     def forward(self, x, decode_step=None, decode_idx=None):
+         embs = []
+         for i in range(self.n_dim):
+             e = self.emb[f'd_{i}']
+             if self.dim == -1:
+                 # (1, 1, ..., 1, self.shape[i], 1, ..., -1)
+                 e = e.view(1, *((1,) * i), self.shape[i], *((1,) * (self.n_dim - i - 1)), -1)
+                 e = e.expand(1, *self.shape, -1)
+             else:
+                 e = e.view(1, -1, *((1,) * i), self.shape[i], *((1,) * (self.n_dim - i - 1)))
+                 e = e.expand(1, -1, *self.shape)
+             embs.append(e)
+
+         embs = torch.cat(embs, dim=self.dim)
+         if decode_step is not None:
+             embs = tensor_slice(embs, [0, *decode_idx, 0],
+                                 [x.shape[0], *(1,) * self.n_dim, x.shape[-1]])
+
+         return x + embs
+
+ ################# Helper Functions ###################################
+ def scaled_dot_product_attention(q, k, v, mask=None, attn_dropout=0., training=True):
+     # Performs scaled dot-product attention over the second to last dimension dn
+
+     # (b, n_head, d1, ..., dn, d)
+     attn = torch.matmul(q, k.transpose(-1, -2))
+     attn = attn / np.sqrt(q.shape[-1])
+     if mask is not None:
+         attn = attn.masked_fill(mask == 0, float('-inf'))
+     attn_float = F.softmax(attn, dim=-1)
+     attn = attn_float.type_as(attn)  # b x n_head x d1 x ... x dn x d
+     attn = F.dropout(attn, p=attn_dropout, training=training)
+
+     a = torch.matmul(attn, v)  # b x n_head x d1 x ... x dn x d
+
+     return a
+
+
+ class RightShift(nn.Module):
+     def __init__(self, embd_dim):
+         super().__init__()
+         self.embd_dim = embd_dim
+         self.sos = nn.Parameter(torch.FloatTensor(embd_dim).normal_(std=0.02), requires_grad=True)
+
+     def forward(self, x, decode_step):
+         if decode_step is not None and decode_step > 0:
+             return x
+
+         x_shape = list(x.shape)
+         x = x.flatten(start_dim=1, end_dim=-2)  # (b, seq_len, embd_dim)
+         sos = torch.ones(x_shape[0], 1, self.embd_dim, dtype=torch.float32).to(self.sos) * self.sos
+         sos = sos.type_as(x)
+         x = torch.cat([sos, x[:, :-1, :]], axis=1)
+         x = x.view(*x_shape)
+
+         return x
+
+
+ class GeLU2(nn.Module):
+     def forward(self, x):
+         return (1.702 * x).sigmoid() * x
+
+
+ class LayerNorm(nn.Module):
+     def __init__(self, embd_dim, class_cond_dim):
+         super().__init__()
+         self.conditional = class_cond_dim is not None
+
+         if self.conditional:
+             self.w = nn.Linear(class_cond_dim, embd_dim, bias=False)
+             nn.init.constant_(self.w.weight.data, 1. / np.sqrt(class_cond_dim))
+             self.wb = nn.Linear(class_cond_dim, embd_dim, bias=False)
+         else:
+             self.g = nn.Parameter(torch.ones(embd_dim, dtype=torch.float32), requires_grad=True)
+             self.b = nn.Parameter(torch.zeros(embd_dim, dtype=torch.float32), requires_grad=True)
+
+     def forward(self, x, cond):
+         if self.conditional:  # (b, cond_dim)
+             g = 1 + self.w(cond['class_cond']).view(x.shape[0], *(1,) * (len(x.shape) - 2), x.shape[-1])  # (b, ..., embd_dim)
+             b = self.wb(cond['class_cond']).view(x.shape[0], *(1,) * (len(x.shape) - 2), x.shape[-1])
+         else:
+             g = self.g  # (embd_dim,)
+             b = self.b
+
+         x_float = x.float()
+
+         mu = x_float.mean(dim=-1, keepdims=True)
+         s = (x_float - mu).square().mean(dim=-1, keepdims=True)
+         x_float = (x_float - mu) * (1e-5 + s.rsqrt())  # (b, ..., embd_dim)
+         x_float = x_float * g + b
+
+         x = x_float.type_as(x)
+         return x
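For reference, a minimal sketch of driving MultiHeadAttention together with the broadcast positional embedding on a toy (t, h, w) grid; the sizes below are illustrative assumptions, not taken from this commit:

import torch
from attention import MultiHeadAttention, AddBroadcastPosEmbed  # relative (.attention) inside the repo

shape, embd_dim, n_head = (4, 8, 8), 96, 4          # toy grid; embd_dim must be divisible by len(shape)
pos = AddBroadcastPosEmbed(shape=shape, embd_dim=embd_dim)
attn = MultiHeadAttention(shape, dim_q=embd_dim, dim_kv=embd_dim, n_head=n_head,
                          n_layer=1, causal=True, attn_type='full',
                          attn_kwargs=dict(attn_dropout=0.))

x = torch.randn(2, *shape, embd_dim)                # (b, t, h, w, c) convention used above
y = attn(pos(x), pos(x), pos(x))                    # same shape as x: (2, 4, 8, 8, 96)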
config.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "architectures": [
+     "VQVAE"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_vqvae.VQVAEConfig",
+     "AutoModel": "modeling_vqvae.VQVAE"
+   },
+   "downsample": [
+     2,
+     4,
+     4
+   ],
+   "embedding_dim": 256,
+   "model_type": "VQVAE",
+   "n_codes": 2048,
+   "n_hiddens": 240,
+   "n_res_layers": 4,
+   "torch_dtype": "float32",
+   "transformers_version": "4.37.2"
+ }
configuration_vqvae.py ADDED
@@ -0,0 +1,22 @@
+ from transformers import PretrainedConfig
+ from typing import List
+
+
+ class VQVAEConfig(PretrainedConfig):
+     model_type = "VQVAE"
+
+     def __init__(
+         self,
+         embedding_dim: int = 256,
+         n_codes: int = 2048,
+         n_hiddens: int = 240,
+         n_res_layers: int = 4,
+         downsample: List[int] = [2, 4, 4],
+         **kwargs,
+     ):
+         self.embedding_dim = embedding_dim
+         self.n_codes = n_codes
+         self.n_hiddens = n_hiddens
+         self.n_res_layers = n_res_layers
+         self.downsample = downsample
+         super().__init__(**kwargs)
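The defaults above mirror the values in config.json; a minimal sketch of constructing the configuration directly (assuming transformers is installed; inside the repo the import is relative):

from configuration_vqvae import VQVAEConfig

config = VQVAEConfig()           # embedding_dim=256, n_codes=2048, n_hiddens=240,
                                 # n_res_layers=4, downsample=[2, 4, 4]
print(config.model_type)         # "VQVAE"
print(config.to_json_string())   # serializes the same fields shown in config.json above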
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b9fda02bdef17ca1378a9392adc5b1d9692fa194ccaabff3b8352ce7548af0de
+ size 88842260
modeling_vqvae.py ADDED
@@ -0,0 +1,321 @@
+ import os
+ import math
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.distributed as dist
+ import gdown
+
+ from .attention import MultiHeadAttention
+ from ._utils import shift_dim
+ from transformers import PreTrainedModel
+ from typing import Tuple
+ from .configuration_vqvae import VQVAEConfig
+
+
+ _VQVAE = {
+     'bair_stride4x2x2': '1iIAYJ2Qqrx5Q94s5eIXQYJgAydzvT_8L',      # trained on 16 frames of 64 x 64 images
+     'ucf101_stride4x4x4': '1uuB_8WzHP_bbBmfuaIV7PK_Itl3DyHY5',    # trained on 16 frames of 128 x 128 images
+     'kinetics_stride4x4x4': '1DOvOZnFAIQmux6hG7pN_HkyJZy3lXbCB',  # trained on 16 frames of 128 x 128 images
+     'kinetics_stride2x4x4': '1jvtjjtrtE4cy6pl7DK_zWFEPY3RZt2pB'   # trained on 16 frames of 128 x 128 images
+ }
+
+ def download(id, fname, root=None):
+     """
+     Download the VQVAE weights from Google Drive.
+
+     Args:
+         id (str): the ID of the file to download
+         fname (str): the name of the file to save
+         root (str): the directory to save the file to
+     """
+     if root is None:
+         root = os.path.expanduser('~/.cache/sora')
+     os.makedirs(root, exist_ok=True)
+     destination = os.path.join(root, fname)
+
+     if os.path.exists(destination):
+         return destination
+
+     gdown.download(id=id, output=destination, quiet=False)
+     return destination
+
+
+ class VQVAE(PreTrainedModel):
+     config_class = VQVAEConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.embedding_dim = config.embedding_dim
+         self.n_codes = config.n_codes
+
+         self.encoder = Encoder(config.n_hiddens, config.n_res_layers, config.downsample)
+         self.decoder = Decoder(config.n_hiddens, config.n_res_layers, config.downsample)
+
+         self.pre_vq_conv = SamePadConv3d(config.n_hiddens, config.embedding_dim, 1)
+         self.post_vq_conv = SamePadConv3d(config.embedding_dim, config.n_hiddens, 1)
+
+         self.codebook = Codebook(config.n_codes, config.embedding_dim)
+
+     @property
+     def latent_shape(self):
+         # NOTE: expects a self.args namespace (sequence_length, resolution, downsample)
+         # that is not populated by VQVAEConfig
+         input_shape = (self.args.sequence_length, self.args.resolution,
+                        self.args.resolution)
+         return tuple([s // d for s, d in zip(input_shape,
+                                              self.args.downsample)])
+
+     def encode(self, x, include_embeddings=False):
+         h = self.pre_vq_conv(self.encoder(x))
+         vq_output = self.codebook(h)
+         if include_embeddings:
+             return vq_output['encodings'], vq_output['embeddings']
+         else:
+             return vq_output['encodings']
+
+     def decode(self, encodings):
+         h = F.embedding(encodings, self.codebook.embeddings)
+         h = self.post_vq_conv(shift_dim(h, -1, 1))
+         return self.decoder(h)
+
+     def forward(self, x):
+         z = self.pre_vq_conv(self.encoder(x))
+         vq_output = self.codebook(z)
+         x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings']))
+         recon_loss = F.mse_loss(x_recon, x) / 0.06
+
+         return recon_loss, x_recon, vq_output
+
+
+ class AxialBlock(nn.Module):
+     def __init__(self, n_hiddens, n_head):
+         super().__init__()
+         kwargs = dict(shape=(0,) * 3, dim_q=n_hiddens,
+                       dim_kv=n_hiddens, n_head=n_head,
+                       n_layer=1, causal=False, attn_type='axial')
+         self.attn_w = MultiHeadAttention(attn_kwargs=dict(axial_dim=-2),
+                                          **kwargs)
+         self.attn_h = MultiHeadAttention(attn_kwargs=dict(axial_dim=-3),
+                                          **kwargs)
+         self.attn_t = MultiHeadAttention(attn_kwargs=dict(axial_dim=-4),
+                                          **kwargs)
+
+     def forward(self, x):
+         x = shift_dim(x, 1, -1)
+         x = self.attn_w(x, x, x) + self.attn_h(x, x, x) + self.attn_t(x, x, x)
+         x = shift_dim(x, -1, 1)
+         return x
+
+
+ class AttentionResidualBlock(nn.Module):
+     def __init__(self, n_hiddens):
+         super().__init__()
+         self.block = nn.Sequential(
+             nn.BatchNorm3d(n_hiddens),
+             nn.ReLU(),
+             SamePadConv3d(n_hiddens, n_hiddens // 2, 3, bias=False),
+             nn.BatchNorm3d(n_hiddens // 2),
+             nn.ReLU(),
+             SamePadConv3d(n_hiddens // 2, n_hiddens, 1, bias=False),
+             nn.BatchNorm3d(n_hiddens),
+             nn.ReLU(),
+             AxialBlock(n_hiddens, 2)
+         )
+
+     def forward(self, x):
+         return x + self.block(x)
+
+ class Codebook(nn.Module):
+     def __init__(self, n_codes, embedding_dim):
+         super().__init__()
+         self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim))
+         self.register_buffer('N', torch.zeros(n_codes))
+         self.register_buffer('z_avg', self.embeddings.data.clone())
+
+         self.n_codes = n_codes
+         self.embedding_dim = embedding_dim
+         self._need_init = True
+
+     def _tile(self, x):
+         d, ew = x.shape
+         if d < self.n_codes:
+             n_repeats = (self.n_codes + d - 1) // d
+             std = 0.01 / np.sqrt(ew)
+             x = x.repeat(n_repeats, 1)
+             x = x + torch.randn_like(x) * std
+         return x
+
+     def _init_embeddings(self, z):
+         # z: [b, c, t, h, w]
+         self._need_init = False
+         flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
+         y = self._tile(flat_inputs)
+
+         d = y.shape[0]
+         _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
+         if dist.is_initialized():
+             dist.broadcast(_k_rand, 0)
+         self.embeddings.data.copy_(_k_rand)
+         self.z_avg.data.copy_(_k_rand)
+         self.N.data.copy_(torch.ones(self.n_codes))
+
+     def forward(self, z):
+         # z: [b, c, t, h, w]
+         if self._need_init and self.training:
+             self._init_embeddings(z)
+         flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
+         distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \
+                     - 2 * flat_inputs @ self.embeddings.t() \
+                     + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
+
+         encoding_indices = torch.argmin(distances, dim=1)
+         encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs)
+         encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:])
+
+         embeddings = F.embedding(encoding_indices, self.embeddings)
+         embeddings = shift_dim(embeddings, -1, 1)
+
+         commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())
+
+         # EMA codebook update
+         if self.training:
+             n_total = encode_onehot.sum(dim=0)
+             encode_sum = flat_inputs.t() @ encode_onehot
+             if dist.is_initialized():
+                 dist.all_reduce(n_total)
+                 dist.all_reduce(encode_sum)
+
+             self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
+             self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)
+
+             n = self.N.sum()
+             weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
+             encode_normalized = self.z_avg / weights.unsqueeze(1)
+             self.embeddings.data.copy_(encode_normalized)
+
+             y = self._tile(flat_inputs)
+             _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
+             if dist.is_initialized():
+                 dist.broadcast(_k_rand, 0)
+
+             usage = (self.N.view(self.n_codes, 1) >= 1).float()
+             self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))
+
+         embeddings_st = (embeddings - z).detach() + z
+
+         avg_probs = torch.mean(encode_onehot, dim=0)
+         perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+         return dict(embeddings=embeddings_st, encodings=encoding_indices,
+                     commitment_loss=commitment_loss, perplexity=perplexity)
+
+     def dictionary_lookup(self, encodings):
+         embeddings = F.embedding(encodings, self.embeddings)
+         return embeddings
+
+ class Encoder(nn.Module):
+     def __init__(self, n_hiddens, n_res_layers, downsample):
+         super().__init__()
+         n_times_downsample = np.array([int(math.log2(d)) for d in downsample])
+         self.convs = nn.ModuleList()
+         max_ds = n_times_downsample.max()
+         for i in range(max_ds):
+             in_channels = 3 if i == 0 else n_hiddens
+             stride = tuple([2 if d > 0 else 1 for d in n_times_downsample])
+             conv = SamePadConv3d(in_channels, n_hiddens, 4, stride=stride)
+             self.convs.append(conv)
+             n_times_downsample -= 1
+         self.conv_last = SamePadConv3d(in_channels, n_hiddens, kernel_size=3)
+
+         self.res_stack = nn.Sequential(
+             *[AttentionResidualBlock(n_hiddens)
+               for _ in range(n_res_layers)],
+             nn.BatchNorm3d(n_hiddens),
+             nn.ReLU()
+         )
+
+     def forward(self, x):
+         h = x
+         for conv in self.convs:
+             h = F.relu(conv(h))
+         h = self.conv_last(h)
+         h = self.res_stack(h)
+         return h
+
+
+ class Decoder(nn.Module):
+     def __init__(self, n_hiddens, n_res_layers, upsample):
+         super().__init__()
+         self.res_stack = nn.Sequential(
+             *[AttentionResidualBlock(n_hiddens)
+               for _ in range(n_res_layers)],
+             nn.BatchNorm3d(n_hiddens),
+             nn.ReLU()
+         )
+
+         n_times_upsample = np.array([int(math.log2(d)) for d in upsample])
+         max_us = n_times_upsample.max()
+         self.convts = nn.ModuleList()
+         for i in range(max_us):
+             out_channels = 3 if i == max_us - 1 else n_hiddens
+             us = tuple([2 if d > 0 else 1 for d in n_times_upsample])
+             convt = SamePadConvTranspose3d(n_hiddens, out_channels, 4,
+                                            stride=us)
+             self.convts.append(convt)
+             n_times_upsample -= 1
+
+     def forward(self, x):
+         h = self.res_stack(x)
+         for i, convt in enumerate(self.convts):
+             h = convt(h)
+             if i < len(self.convts) - 1:
+                 h = F.relu(h)
+         return h
+
+
+ # Does not support dilation
+ class SamePadConv3d(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True):
+         super().__init__()
+         if isinstance(kernel_size, int):
+             kernel_size = (kernel_size,) * 3
+         if isinstance(stride, int):
+             stride = (stride,) * 3
+
+         # assumes that the input shape is divisible by stride
+         total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
+         pad_input = []
+         for p in total_pad[::-1]:  # reverse since F.pad starts from last dim
+             pad_input.append((p // 2 + p % 2, p // 2))
+         pad_input = sum(pad_input, tuple())
+         self.pad_input = pad_input
+
+         self.conv = nn.Conv3d(in_channels, out_channels, kernel_size,
+                               stride=stride, padding=0, bias=bias)
+
+     def forward(self, x):
+         return self.conv(F.pad(x, self.pad_input))
+
+
+ class SamePadConvTranspose3d(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True):
+         super().__init__()
+         if isinstance(kernel_size, int):
+             kernel_size = (kernel_size,) * 3
+         if isinstance(stride, int):
+             stride = (stride,) * 3
+
+         total_pad = tuple([k - s for k, s in zip(kernel_size, stride)])
+         pad_input = []
+         for p in total_pad[::-1]:  # reverse since F.pad starts from last dim
+             pad_input.append((p // 2 + p % 2, p // 2))
+         pad_input = sum(pad_input, tuple())
+         self.pad_input = pad_input
+
+         self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size,
+                                         stride=stride, bias=bias,
+                                         padding=tuple([k - 1 for k in kernel_size]))
+
+     def forward(self, x):
+         return self.convt(F.pad(x, self.pad_input))
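
A minimal end-to-end sketch of loading and running the model; the repository id below is a placeholder, and trust_remote_code is needed because the VQVAE class is resolved through the auto_map entries in config.json:

import torch
from transformers import AutoModel

vqvae = AutoModel.from_pretrained("<repo-id>", trust_remote_code=True).eval()

x = torch.randn(1, 3, 16, 128, 128)       # (b, c, t, h, w) toy video tensor
with torch.no_grad():
    codes = vqvae.encode(x)                # discrete code indices, (1, 8, 32, 32) with downsample [2, 4, 4]
    recon = vqvae.decode(codes)            # reconstruction, (1, 3, 16, 128, 128)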