frankleeeee committed
Commit 5c162ac
1 Parent(s): 9d245c1

Upload STDiT2

Files changed (6)
  1. config.json +39 -0
  2. configuration_stdit2.py +51 -0
  3. layers.py +652 -0
  4. model.safetensors +3 -0
  5. modeling_stdit2.py +327 -0
  6. utils.py +90 -0
config.json ADDED
@@ -0,0 +1,39 @@
{
  "architectures": [
    "STDiT2"
  ],
  "auto_map": {
    "AutoConfig": "configuration_stdit2.STDiT2Config",
    "AutoModel": "modeling_stdit2.STDiT2"
  },
  "caption_channels": 4096,
  "class_dropout_prob": 0.1,
  "depth": 28,
  "drop_path": 0.0,
  "enable_flashattn": false,
  "enable_layernorm_kernel": false,
  "enable_sequence_parallelism": false,
  "freeze": null,
  "hidden_size": 1152,
  "in_channels": 4,
  "input_size": [
    null,
    null,
    null
  ],
  "input_sq_size": 512,
  "mlp_ratio": 4.0,
  "model_max_length": 120,
  "model_type": "stdit2",
  "no_temporal_pos_emb": false,
  "num_heads": 16,
  "patch_size": [
    1,
    2,
    2
  ],
  "pred_sigma": true,
  "qk_norm": true,
  "torch_dtype": "float32",
  "transformers_version": "4.40.1"
}
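Because config.json registers the custom classes under auto_map, the checkpoint can be loaded through the transformers Auto classes with remote code enabled. A minimal loading sketch, assuming the repo id / local path below is a placeholder for wherever this commit lives:

# Minimal loading sketch; "path-or-repo-id" is a placeholder. trust_remote_code=True
# is required so configuration_stdit2.py and modeling_stdit2.py from this repo are used.
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("path-or-repo-id", trust_remote_code=True)
model = AutoModel.from_pretrained("path-or-repo-id", trust_remote_code=True)
print(type(model).__name__)  # STDiT2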
configuration_stdit2.py ADDED
@@ -0,0 +1,51 @@
import torch
from transformers import PretrainedConfig


class STDiT2Config(PretrainedConfig):

    model_type = "stdit2"

    def __init__(
        self,
        input_size=(None, None, None),
        input_sq_size=32,
        in_channels=4,
        patch_size=(1, 2, 2),
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        pred_sigma=True,
        drop_path=0.0,
        no_temporal_pos_emb=False,
        caption_channels=4096,
        model_max_length=120,
        freeze=None,
        qk_norm=False,
        enable_flashattn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
        **kwargs,
    ):
        self.input_size = input_size
        self.input_sq_size = input_sq_size
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.class_dropout_prob = class_dropout_prob
        self.pred_sigma = pred_sigma
        self.drop_path = drop_path
        self.no_temporal_pos_emb = no_temporal_pos_emb
        self.caption_channels = caption_channels
        self.model_max_length = model_max_length
        self.freeze = freeze
        self.qk_norm = qk_norm
        self.enable_flashattn = enable_flashattn
        self.enable_layernorm_kernel = enable_layernorm_kernel
        self.enable_sequence_parallelism = enable_sequence_parallelism
        super().__init__(**kwargs)
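Note that the defaults above differ from the shipped config.json in a few places (for example input_sq_size and qk_norm), so those values are overridden when the released checkpoint's config is built. A small sketch, assuming configuration_stdit2.py is importable from the working directory:

# Hypothetical sketch: build a config matching config.json by overriding the defaults.
from configuration_stdit2 import STDiT2Config

cfg = STDiT2Config(input_sq_size=512, qk_norm=True)
print(cfg.hidden_size, cfg.depth, cfg.num_heads)  # 1152 28 16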
layers.py ADDED
@@ -0,0 +1,652 @@
import math
import collections.abc

import torch
import torch.nn as nn
import torch.nn.functional as F
import functools

from einops import rearrange
from itertools import repeat
from functools import partial
from .utils import approx_gelu, get_layernorm, t2i_modulate
from typing import Optional


try:
    import xformers
    HAS_XFORMERS = True
except ImportError:
    HAS_XFORMERS = False


# =================
# STDiT2Block
# =================
class STDiT2Block(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        mlp_ratio=4.0,
        drop_path=0.0,
        enable_flashattn=False,
        enable_layernorm_kernel=False,
        enable_sequence_parallelism=False,
        rope=None,
        qk_norm=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.enable_flashattn = enable_flashattn
        self._enable_sequence_parallelism = enable_sequence_parallelism

        assert not self._enable_sequence_parallelism, "Sequence parallelism is not supported."
        if enable_sequence_parallelism:
            self.attn_cls = SeqParallelAttention
            self.mha_cls = SeqParallelMultiHeadCrossAttention
        else:
            self.attn_cls = Attention
            self.mha_cls = MultiHeadCrossAttention

        # spatial branch
        self.norm1 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.attn = self.attn_cls(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            enable_flashattn=enable_flashattn,
            qk_norm=qk_norm,
        )
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5)

        # cross attn
        self.cross_attn = self.mha_cls(hidden_size, num_heads)

        # mlp branch
        self.norm2 = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)
        self.mlp = Mlp(
            in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # temporal branch
        self.norm_temp = get_layernorm(hidden_size, eps=1e-6, affine=False, use_kernel=enable_layernorm_kernel)  # new
        self.attn_temp = self.attn_cls(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=True,
            enable_flashattn=self.enable_flashattn,
            rope=rope,
            qk_norm=qk_norm,
        )
        self.scale_shift_table_temporal = nn.Parameter(torch.randn(3, hidden_size) / hidden_size**0.5)  # new

    def t_mask_select(self, x_mask, x, masked_x, T, S):
        # x: [B, (T, S), C]
        # masked_x: [B, (T, S), C]
        # x_mask: [B, T]
        x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
        masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S)
        x = torch.where(x_mask[:, :, None, None], x, masked_x)
        x = rearrange(x, "B T S C -> B (T S) C")
        return x

    def forward(self, x, y, t, t_tmp, mask=None, x_mask=None, t0=None, t0_tmp=None, T=None, S=None):
        B, N, C = x.shape

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + t.reshape(B, 6, -1)
        ).chunk(6, dim=1)
        shift_tmp, scale_tmp, gate_tmp = (self.scale_shift_table_temporal[None] + t_tmp.reshape(B, 3, -1)).chunk(
            3, dim=1
        )
        if x_mask is not None:
            shift_msa_zero, scale_msa_zero, gate_msa_zero, shift_mlp_zero, scale_mlp_zero, gate_mlp_zero = (
                self.scale_shift_table[None] + t0.reshape(B, 6, -1)
            ).chunk(6, dim=1)
            shift_tmp_zero, scale_tmp_zero, gate_tmp_zero = (
                self.scale_shift_table_temporal[None] + t0_tmp.reshape(B, 3, -1)
            ).chunk(3, dim=1)

        # modulate
        x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
        if x_mask is not None:
            x_m_zero = t2i_modulate(self.norm1(x), shift_msa_zero, scale_msa_zero)
            x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S)

        # spatial branch
        x_s = rearrange(x_m, "B (T S) C -> (B T) S C", T=T, S=S)
        x_s = self.attn(x_s)
        x_s = rearrange(x_s, "(B T) S C -> B (T S) C", T=T, S=S)
        if x_mask is not None:
            x_s_zero = gate_msa_zero * x_s
            x_s = gate_msa * x_s
            x_s = self.t_mask_select(x_mask, x_s, x_s_zero, T, S)
        else:
            x_s = gate_msa * x_s
        x = x + self.drop_path(x_s)

        # modulate
        x_m = t2i_modulate(self.norm_temp(x), shift_tmp, scale_tmp)
        if x_mask is not None:
            x_m_zero = t2i_modulate(self.norm_temp(x), shift_tmp_zero, scale_tmp_zero)
            x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S)

        # temporal branch
        x_t = rearrange(x_m, "B (T S) C -> (B S) T C", T=T, S=S)
        x_t = self.attn_temp(x_t)
        x_t = rearrange(x_t, "(B S) T C -> B (T S) C", T=T, S=S)
        if x_mask is not None:
            x_t_zero = gate_tmp_zero * x_t
            x_t = gate_tmp * x_t
            x_t = self.t_mask_select(x_mask, x_t, x_t_zero, T, S)
        else:
            x_t = gate_tmp * x_t
        x = x + self.drop_path(x_t)

        # cross attn
        x = x + self.cross_attn(x, y, mask)

        # modulate
        x_m = t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)
        if x_mask is not None:
            x_m_zero = t2i_modulate(self.norm2(x), shift_mlp_zero, scale_mlp_zero)
            x_m = self.t_mask_select(x_mask, x_m, x_m_zero, T, S)

        # mlp
        x_mlp = self.mlp(x_m)
        if x_mask is not None:
            x_mlp_zero = gate_mlp_zero * x_mlp
            x_mlp = gate_mlp * x_mlp
            x_mlp = self.t_mask_select(x_mask, x_mlp, x_mlp_zero, T, S)
        else:
            x_mlp = gate_mlp * x_mlp
        x = x + self.drop_path(x_mlp)

        return x


# =================
# Attention
# =================
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = LlamaRMSNorm,
        enable_flashattn: bool = False,
        rope=None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.enable_flashattn = enable_flashattn

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.rope = False
        if rope is not None:
            self.rope = True
            self.rotary_emb = rope

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        # flash attn is not memory efficient for small sequences, this is empirical
        enable_flashattn = self.enable_flashattn and (N > B)
        qkv = self.qkv(x)
        qkv_shape = (B, N, 3, self.num_heads, self.head_dim)

        qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        if self.rope:
            q = self.rotary_emb(q)
            k = self.rotary_emb(k)
        q, k = self.q_norm(q), self.k_norm(k)

        if enable_flashattn:
            from flash_attn import flash_attn_func

            # (B, #heads, N, #dim) -> (B, N, #heads, #dim)
            q = q.permute(0, 2, 1, 3)
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)
            x = flash_attn_func(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
                softmax_scale=self.scale,
            )
        else:
            dtype = q.dtype
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)  # translate attn to float32
            attn = attn.to(torch.float32)
            attn = attn.softmax(dim=-1)
            attn = attn.to(dtype)  # cast back attn to original dtype
            attn = self.attn_drop(attn)
            x = attn @ v

        x_output_shape = (B, N, C)
        if not enable_flashattn:
            x = x.transpose(1, 2)
        x = x.reshape(x_output_shape)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


# ========================
# MultiHeadCrossAttention
# ========================
class MultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
        super(MultiHeadCrossAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.kv_linear = nn.Linear(d_model, d_model * 2)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(d_model, d_model)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, cond, mask=None):
        # query/value: img tokens; key: condition; mask: if padding tokens
        B, N, C = x.shape

        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(2)

        attn_bias = None
        if mask is not None:
            attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
        x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)

        x = x.view(B, -1, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


# =================
# Timm Components
# =================
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob,3):0.3f}'


def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse


to_2tuple = _ntuple(2)


class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


# =================
# Embedding
# =================
class CaptionEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """

    def __init__(
        self,
        in_channels,
        hidden_size,
        uncond_prob,
        act_layer=nn.GELU(approximate="tanh"),
        token_num=120,
    ):
        super().__init__()
        self.y_proj = Mlp(
            in_features=in_channels,
            hidden_features=hidden_size,
            out_features=hidden_size,
            act_layer=act_layer,
            drop=0,
        )
        self.register_buffer(
            "y_embedding",
            torch.randn(token_num, in_channels) / in_channels**0.5,
        )
        self.uncond_prob = uncond_prob

    def token_drop(self, caption, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
        else:
            drop_ids = force_drop_ids == 1
        caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
        return caption

    def forward(self, caption, train, force_drop_ids=None):
        if train:
            assert caption.shape[2:] == self.y_embedding.shape
        use_dropout = self.uncond_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            caption = self.token_drop(caption, force_drop_ids)
        caption = self.y_proj(caption)
        return caption


class PatchEmbed3D(nn.Module):
    """Video to Patch Embedding.

    Args:
        patch_size (int): Patch token size. Default: (2,4,4).
        in_chans (int): Number of input video channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(
        self,
        patch_size=(2, 4, 4),
        in_chans=3,
        embed_dim=96,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.flatten = flatten

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        """Forward function."""
        # padding
        _, _, D, H, W = x.size()
        if W % self.patch_size[2] != 0:
            x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
        if H % self.patch_size[1] != 0:
            x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
        if D % self.patch_size[0] != 0:
            x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))

        x = self.proj(x)  # (B C T H W)
        if self.norm is not None:
            D, Wh, Ww = x.size(2), x.size(3), x.size(4)
            x = x.flatten(2).transpose(1, 2)
            x = self.norm(x)
            x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
        return x


class T2IFinalLayer(nn.Module):
    """
    The final layer of PixArt.
    """

    def __init__(self, hidden_size, num_patch, out_channels, d_t=None, d_s=None):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
        self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
        self.out_channels = out_channels
        self.d_t = d_t
        self.d_s = d_s

    def t_mask_select(self, x_mask, x, masked_x, T, S):
        # x: [B, (T, S), C]
        # masked_x: [B, (T, S), C]
        # x_mask: [B, T]
        x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
        masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S)
        x = torch.where(x_mask[:, :, None, None], x, masked_x)
        x = rearrange(x, "B T S C -> B (T S) C")
        return x

    def forward(self, x, t, x_mask=None, t0=None, T=None, S=None):
        if T is None:
            T = self.d_t
        if S is None:
            S = self.d_s
        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
        x = t2i_modulate(self.norm_final(x), shift, scale)
        if x_mask is not None:
            shift_zero, scale_zero = (self.scale_shift_table[None] + t0[:, None]).chunk(2, dim=1)
            x_zero = t2i_modulate(self.norm_final(x), shift_zero, scale_zero)
            x = self.t_mask_select(x_mask, x, x_zero, T, S)
        x = self.linear(x)
        return x


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half)
        freqs = freqs.to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t, dtype):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        if t_freq.dtype != dtype:
            t_freq = t_freq.to(dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


class SizeEmbedder(TimestepEmbedder):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size)
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.outdim = hidden_size

    def forward(self, s, bs):
        if s.ndim == 1:
            s = s[:, None]
        assert s.ndim == 2
        if s.shape[0] != bs:
            s = s.repeat(bs // s.shape[0], 1)
            assert s.shape[0] == bs
        b, dims = s.shape[0], s.shape[1]
        s = rearrange(s, "b d -> (b d)")
        s_freq = self.timestep_embedding(s, self.frequency_embedding_size).to(self.dtype)
        s_emb = self.mlp(s_freq)
        s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
        return s_emb

    @property
    def dtype(self):
        return next(self.parameters()).dtype


class PositionEmbedding2D(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim
        assert dim % 4 == 0, "dim must be divisible by 4"
        half_dim = dim // 2
        inv_freq = 1.0 / (10000 ** (torch.arange(0, half_dim, 2).float() / half_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _get_sin_cos_emb(self, t: torch.Tensor):
        out = torch.einsum("i,d->id", t, self.inv_freq)
        emb_cos = torch.cos(out)
        emb_sin = torch.sin(out)
        return torch.cat((emb_sin, emb_cos), dim=-1)

    @functools.lru_cache(maxsize=512)
    def _get_cached_emb(
        self,
        device: torch.device,
        dtype: torch.dtype,
        h: int,
        w: int,
        scale: float = 1.0,
        base_size: Optional[int] = None,
    ):
        grid_h = torch.arange(h, device=device) / scale
        grid_w = torch.arange(w, device=device) / scale
        if base_size is not None:
            grid_h *= base_size / h
            grid_w *= base_size / w
        grid_h, grid_w = torch.meshgrid(
            grid_w,
            grid_h,
            indexing="ij",
        )  # here w goes first
        grid_h = grid_h.t().reshape(-1)
        grid_w = grid_w.t().reshape(-1)
        emb_h = self._get_sin_cos_emb(grid_h)
        emb_w = self._get_sin_cos_emb(grid_w)
        return torch.concat([emb_h, emb_w], dim=-1).unsqueeze(0).to(dtype)

    def forward(
        self,
        x: torch.Tensor,
        h: int,
        w: int,
        scale: Optional[float] = 1.0,
        base_size: Optional[int] = None,
    ) -> torch.Tensor:
        return self._get_cached_emb(x.device, x.dtype, h, w, scale, base_size)
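The embedding helpers at the bottom of this file are self-contained and can be sanity-checked in isolation. A minimal shape-check sketch, assuming these files are arranged as a package so the relative imports resolve (the package name "stdit2" is illustrative, not part of this commit):

# Hypothetical shape check for TimestepEmbedder and PositionEmbedding2D (torch only, no weights).
import torch
from stdit2.layers import TimestepEmbedder, PositionEmbedding2D  # "stdit2" is an assumed package name

t_embedder = TimestepEmbedder(hidden_size=1152)
t = torch.randint(0, 1000, (2,))                 # two diffusion timesteps
print(t_embedder(t, dtype=torch.float32).shape)  # torch.Size([2, 1152])

pos = PositionEmbedding2D(dim=1152)
x = torch.zeros(2, 8, 1152)                      # dummy tokens; only device/dtype are read
print(pos(x, h=4, w=2).shape)                    # torch.Size([1, 8, 1152]) for a 4x2 spatial grid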
model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e777ae49713478957c48f97eda4405e392e3fd12580e01be944465b741c6521c
size 3071846872
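The weights themselves are a Git LFS object of about 3.07 GB, consistent with the float32 torch_dtype in config.json. A hedged sketch for inspecting the tensors without building the model, assuming the safetensors package is installed and the file has been downloaded locally (the path is a placeholder):

# Lazily open the checkpoint and list its tensor names.
from safetensors import safe_open

with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    keys = list(f.keys())
    print(len(keys), "tensors; first:", keys[0])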
modeling_stdit2.py ADDED
@@ -0,0 +1,327 @@
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from einops import rearrange
from .configuration_stdit2 import STDiT2Config
from .layers import (
    STDiT2Block,
    CaptionEmbedder,
    PatchEmbed3D,
    T2IFinalLayer,
    TimestepEmbedder,
    SizeEmbedder,
    PositionEmbedding2D
)
from rotary_embedding_torch import RotaryEmbedding
from .utils import (
    get_2d_sincos_pos_embed,
    approx_gelu
)
from transformers import PreTrainedModel


class STDiT2(PreTrainedModel):

    config_class = STDiT2Config

    def __init__(
        self,
        config: STDiT2Config
    ):
        super().__init__(config)
        self.pred_sigma = config.pred_sigma
        self.in_channels = config.in_channels
        self.out_channels = config.in_channels * 2 if config.pred_sigma else config.in_channels
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.no_temporal_pos_emb = config.no_temporal_pos_emb
        self.depth = config.depth
        self.mlp_ratio = config.mlp_ratio
        self.enable_flashattn = config.enable_flashattn
        self.enable_layernorm_kernel = config.enable_layernorm_kernel
        self.enable_sequence_parallelism = config.enable_sequence_parallelism

        # support dynamic input
        self.patch_size = config.patch_size
        self.input_size = config.input_size
        self.input_sq_size = config.input_sq_size
        self.pos_embed = PositionEmbedding2D(config.hidden_size)

        self.x_embedder = PatchEmbed3D(config.patch_size, config.in_channels, config.hidden_size)
        self.t_embedder = TimestepEmbedder(config.hidden_size)
        self.t_block = nn.Sequential(nn.SiLU(), nn.Linear(config.hidden_size, 6 * config.hidden_size, bias=True))
        self.t_block_temp = nn.Sequential(nn.SiLU(), nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=True))  # new
        self.y_embedder = CaptionEmbedder(
            in_channels=config.caption_channels,
            hidden_size=config.hidden_size,
            uncond_prob=config.class_dropout_prob,
            act_layer=approx_gelu,
            token_num=config.model_max_length,
        )

        drop_path = [x.item() for x in torch.linspace(0, config.drop_path, config.depth)]
        self.rope = RotaryEmbedding(dim=self.hidden_size // self.num_heads)  # new
        self.blocks = nn.ModuleList(
            [
                STDiT2Block(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    drop_path=drop_path[i],
                    enable_flashattn=self.enable_flashattn,
                    enable_layernorm_kernel=self.enable_layernorm_kernel,
                    enable_sequence_parallelism=self.enable_sequence_parallelism,
                    rope=self.rope.rotate_queries_or_keys,
                    qk_norm=config.qk_norm,
                )
                for i in range(self.depth)
            ]
        )
        self.final_layer = T2IFinalLayer(config.hidden_size, np.prod(self.patch_size), self.out_channels)

        # multi_res
        assert self.hidden_size % 3 == 0, "hidden_size must be divisible by 3"
        self.csize_embedder = SizeEmbedder(self.hidden_size // 3)
        self.ar_embedder = SizeEmbedder(self.hidden_size // 3)
        self.fl_embedder = SizeEmbedder(self.hidden_size)  # new
        self.fps_embedder = SizeEmbedder(self.hidden_size)  # new

        # init model
        self.initialize_weights()
        self.initialize_temporal()
        if config.freeze is not None:
            assert config.freeze in ["not_temporal", "text"]
            if config.freeze == "not_temporal":
                self.freeze_not_temporal()
            elif config.freeze == "text":
                self.freeze_text()

        # sequence parallel related configs
        if self.enable_sequence_parallelism:
            self.sp_rank = dist.get_rank(get_sequence_parallel_group())
        else:
            self.sp_rank = None

    def get_dynamic_size(self, x):
        _, _, T, H, W = x.size()
        if T % self.patch_size[0] != 0:
            T += self.patch_size[0] - T % self.patch_size[0]
        if H % self.patch_size[1] != 0:
            H += self.patch_size[1] - H % self.patch_size[1]
        if W % self.patch_size[2] != 0:
            W += self.patch_size[2] - W % self.patch_size[2]
        T = T // self.patch_size[0]
        H = H // self.patch_size[1]
        W = W // self.patch_size[2]
        return (T, H, W)

    def forward(
        self, x, timestep, y, mask=None, x_mask=None, num_frames=None, height=None, width=None, ar=None, fps=None
    ):
        """
        Forward pass of STDiT2.
        Args:
            x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W]
            timestep (torch.Tensor): diffusion time steps; of shape [B]
            y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C]
            mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token]

        Returns:
            x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
        """
        B = x.shape[0]
        x = x.to(self.final_layer.linear.weight.dtype)
        timestep = timestep.to(self.final_layer.linear.weight.dtype)
        y = y.to(self.final_layer.linear.weight.dtype)

        # === process data info ===
        # 1. get dynamic size
        hw = torch.cat([height[:, None], width[:, None]], dim=1)
        rs = (height[0].item() * width[0].item()) ** 0.5
        csize = self.csize_embedder(hw, B)

        # 2. get aspect ratio
        ar = ar.unsqueeze(1)
        ar = self.ar_embedder(ar, B)
        data_info = torch.cat([csize, ar], dim=1)

        # 3. get number of frames
        fl = num_frames.unsqueeze(1)
        fps = fps.unsqueeze(1)
        fl = self.fl_embedder(fl, B)
        fl = fl + self.fps_embedder(fps, B)

        # === get dynamic shape size ===
        _, _, Tx, Hx, Wx = x.size()
        T, H, W = self.get_dynamic_size(x)
        S = H * W
        scale = rs / self.input_sq_size
        base_size = round(S**0.5)
        pos_emb = self.pos_embed(x, H, W, scale=scale, base_size=base_size)

        # embedding
        x = self.x_embedder(x)  # [B, N, C]
        x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
        x = x + pos_emb
        x = rearrange(x, "B T S C -> B (T S) C")

        # shard over the sequence dim if sp is enabled
        if self.enable_sequence_parallelism:
            x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")

        # prepare adaIN
        t = self.t_embedder(timestep, dtype=x.dtype)  # [B, C]
        t_spc = t + data_info  # [B, C]
        t_tmp = t + fl  # [B, C]
        t_spc_mlp = self.t_block(t_spc)  # [B, 6*C]
        t_tmp_mlp = self.t_block_temp(t_tmp)  # [B, 3*C]
        if x_mask is not None:
            t0_timestep = torch.zeros_like(timestep)
            t0 = self.t_embedder(t0_timestep, dtype=x.dtype)
            t0_spc = t0 + data_info
            t0_tmp = t0 + fl
            t0_spc_mlp = self.t_block(t0_spc)
            t0_tmp_mlp = self.t_block_temp(t0_tmp)
        else:
            t0_spc = None
            t0_tmp = None
            t0_spc_mlp = None
            t0_tmp_mlp = None

        # prepare y
        y = self.y_embedder(y, self.training)  # [B, 1, N_token, C]

        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = [y.shape[2]] * y.shape[0]
            y = y.squeeze(1).view(1, -1, x.shape[-1])

        # blocks
        for _, block in enumerate(self.blocks):
            x = block(
                x,
                y,
                t_spc_mlp,
                t_tmp_mlp,
                y_lens,
                x_mask,
                t0_spc_mlp,
                t0_tmp_mlp,
                T,
                S,
            )

        if self.enable_sequence_parallelism:
            x = gather_forward_split_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="up")
            # x.shape: [B, N, C]

        # final process
        x = self.final_layer(x, t, x_mask, t0_spc, T, S)  # [B, N, C=T_p * H_p * W_p * C_out]
        x = self.unpatchify(x, T, H, W, Tx, Hx, Wx)  # [B, C_out, T, H, W]

        # cast to float32 for better accuracy
        x = x.to(torch.float32)
        return x

    def unpatchify(self, x, N_t, N_h, N_w, R_t, R_h, R_w):
        """
        Args:
            x (torch.Tensor): of shape [B, N, C]

        Return:
            x (torch.Tensor): of shape [B, C_out, T, H, W]
        """

        # N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        T_p, H_p, W_p = self.patch_size
        x = rearrange(
            x,
            "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
            N_t=N_t,
            N_h=N_h,
            N_w=N_w,
            T_p=T_p,
            H_p=H_p,
            W_p=W_p,
            C_out=self.out_channels,
        )
        # unpad
        x = x[:, :, :R_t, :R_h, :R_w]
        return x

    def unpatchify_old(self, x):
        c = self.out_channels
        t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
        pt, ph, pw = self.patch_size

        x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
        x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
        return imgs

    def get_spatial_pos_embed(self, H, W, scale=1.0, base_size=None):
        pos_embed = get_2d_sincos_pos_embed(
            self.hidden_size,
            (H, W),
            scale=scale,
            base_size=base_size,
        )
        pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
        return pos_embed

    def freeze_not_temporal(self):
        for n, p in self.named_parameters():
            if "attn_temp" not in n:
                p.requires_grad = False

    def freeze_text(self):
        for n, p in self.named_parameters():
            if "cross_attn" in n:
                p.requires_grad = False

    def initialize_temporal(self):
        for block in self.blocks:
            nn.init.constant_(block.attn_temp.proj.weight, 0)
            nn.init.constant_(block.attn_temp.proj.bias, 0)

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_block[1].weight, std=0.02)
        nn.init.normal_(self.t_block_temp[1].weight, std=0.02)

        # Initialize caption embedding MLP:
        nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
        nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)

        # Zero-out adaLN modulation layers in PixArt blocks:
        for block in self.blocks:
            nn.init.constant_(block.cross_attn.proj.weight, 0)
            nn.init.constant_(block.cross_attn.proj.bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)
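get_dynamic_size above pads T, H, W up to multiples of patch_size before dividing, and unpatchify crops the padding back off at the end of the forward pass. The arithmetic can be illustrated without instantiating the model; the latent size below is only an example:

# Hypothetical example of the patch-grid arithmetic in get_dynamic_size
# with the released patch_size = (1, 2, 2).
import math

patch_size = (1, 2, 2)
T, H, W = 16, 30, 53                 # example latent (T, H, W)
T_p = math.ceil(T / patch_size[0])   # 16 temporal patches
H_p = math.ceil(H / patch_size[1])   # 15
W_p = math.ceil(W / patch_size[2])   # 27 (W is padded 53 -> 54 first)
S = H_p * W_p                        # 405 spatial tokens per frame
print(T_p, H_p, W_p, S)              # 16 15 27 405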
utils.py ADDED
@@ -0,0 +1,90 @@
import torch
import torch.nn as nn
import numpy as np


approx_gelu = lambda: nn.GELU(approximate="tanh")


def get_layernorm(hidden_size: int, eps: float, affine: bool, use_kernel: bool):
    if use_kernel:
        try:
            from apex.normalization import FusedLayerNorm

            return FusedLayerNorm(hidden_size, elementwise_affine=affine, eps=eps)
        except ImportError:
            raise RuntimeError("FusedLayerNorm not available. Please install apex.")
    else:
        return nn.LayerNorm(hidden_size, eps, elementwise_affine=affine)


def t2i_modulate(x, shift, scale):
    return x * (1 + scale) + shift


# ===============================================
# Sine/Cosine Positional Embedding Functions
# ===============================================
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scale=1.0, base_size=None):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    if not isinstance(grid_size, tuple):
        grid_size = (grid_size, grid_size)

    grid_h = np.arange(grid_size[0], dtype=np.float32) / scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / scale
    if base_size is not None:
        grid_h *= base_size / grid_size[0]
        grid_w *= base_size / grid_size[1]
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0):
    pos = np.arange(0, length)[..., None] / scale
    return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
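The sin/cos helpers are pure NumPy and have no relative imports, so they can be exercised directly. A minimal check, assuming utils.py sits in the working directory:

# Hypothetical check of the 2D sin/cos positional-embedding helper.
from utils import get_2d_sincos_pos_embed

pos = get_2d_sincos_pos_embed(embed_dim=1152, grid_size=(15, 27))
print(pos.shape)  # (405, 1152): one embedding row per spatial token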