sam-motamed commited on
Commit
42ada15
·
verified ·
1 Parent(s): ee02363

Move cogvideox_transformer3d.py to diffusers/

Browse files
Files changed (1) hide show
  1. cogvideox_transformer3d.py +0 -845
cogvideox_transformer3d.py DELETED
@@ -1,845 +0,0 @@
1
- # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import glob
17
- import json
18
- import os
19
- from typing import Any, Dict, Optional, Tuple, Union
20
-
21
- import torch
22
- import torch.nn.functional as F
23
- from diffusers.configuration_utils import ConfigMixin, register_to_config
24
- from diffusers.models.attention import Attention, FeedForward
25
- from diffusers.models.attention_processor import (
26
- AttentionProcessor, CogVideoXAttnProcessor2_0,
27
- FusedCogVideoXAttnProcessor2_0)
28
- from diffusers.models.embeddings import (CogVideoXPatchEmbed,
29
- TimestepEmbedding, Timesteps,
30
- get_3d_sincos_pos_embed)
31
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
32
- from diffusers.models.modeling_utils import ModelMixin
33
- from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
34
- from diffusers.utils import is_torch_version, logging
35
- from diffusers.utils.torch_utils import maybe_allow_in_graph
36
- from torch import nn
37
-
38
- from dist_utils import (get_sequence_parallel_rank,
39
- get_sequence_parallel_world_size,
40
- get_sp_group,
41
- xFuserLongContextAttention)
42
- from dist_utils import CogVideoXMultiGPUsAttnProcessor2_0
43
-
44
-
45
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
-
47
-
48
- class CogVideoXPatchEmbed(nn.Module):
49
- def __init__(
50
- self,
51
- patch_size: int = 2,
52
- patch_size_t: Optional[int] = None,
53
- in_channels: int = 16,
54
- embed_dim: int = 1920,
55
- text_embed_dim: int = 4096,
56
- bias: bool = True,
57
- sample_width: int = 90,
58
- sample_height: int = 60,
59
- sample_frames: int = 49,
60
- temporal_compression_ratio: int = 4,
61
- max_text_seq_length: int = 226,
62
- spatial_interpolation_scale: float = 1.875,
63
- temporal_interpolation_scale: float = 1.0,
64
- use_positional_embeddings: bool = True,
65
- use_learned_positional_embeddings: bool = True,
66
- ) -> None:
67
- super().__init__()
68
-
69
- post_patch_height = sample_height // patch_size
70
- post_patch_width = sample_width // patch_size
71
- post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
72
- self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
73
- self.post_patch_height = post_patch_height
74
- self.post_patch_width = post_patch_width
75
- self.post_time_compression_frames = post_time_compression_frames
76
- self.patch_size = patch_size
77
- self.patch_size_t = patch_size_t
78
- self.embed_dim = embed_dim
79
- self.sample_height = sample_height
80
- self.sample_width = sample_width
81
- self.sample_frames = sample_frames
82
- self.temporal_compression_ratio = temporal_compression_ratio
83
- self.max_text_seq_length = max_text_seq_length
84
- self.spatial_interpolation_scale = spatial_interpolation_scale
85
- self.temporal_interpolation_scale = temporal_interpolation_scale
86
- self.use_positional_embeddings = use_positional_embeddings
87
- self.use_learned_positional_embeddings = use_learned_positional_embeddings
88
-
89
- if patch_size_t is None:
90
- # CogVideoX 1.0 checkpoints
91
- self.proj = nn.Conv2d(
92
- in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
93
- )
94
- else:
95
- # CogVideoX 1.5 checkpoints
96
- self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
97
-
98
- self.text_proj = nn.Linear(text_embed_dim, embed_dim)
99
-
100
- if use_positional_embeddings or use_learned_positional_embeddings:
101
- persistent = use_learned_positional_embeddings
102
- pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
103
- self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
104
-
105
- def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
106
- post_patch_height = sample_height // self.patch_size
107
- post_patch_width = sample_width // self.patch_size
108
- post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
109
- num_patches = post_patch_height * post_patch_width * post_time_compression_frames
110
-
111
- pos_embedding = get_3d_sincos_pos_embed(
112
- self.embed_dim,
113
- (post_patch_width, post_patch_height),
114
- post_time_compression_frames,
115
- self.spatial_interpolation_scale,
116
- self.temporal_interpolation_scale,
117
- )
118
- pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
119
- joint_pos_embedding = torch.zeros(
120
- 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
121
- )
122
- joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
123
-
124
- return joint_pos_embedding
125
-
126
- def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
127
- r"""
128
- Args:
129
- text_embeds (`torch.Tensor`):
130
- Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
131
- image_embeds (`torch.Tensor`):
132
- Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
133
- """
134
- text_embeds = self.text_proj(text_embeds)
135
-
136
- text_batch_size, text_seq_length, text_channels = text_embeds.shape
137
- batch_size, num_frames, channels, height, width = image_embeds.shape
138
-
139
- if self.patch_size_t is None:
140
- image_embeds = image_embeds.reshape(-1, channels, height, width)
141
- image_embeds = self.proj(image_embeds)
142
- image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
143
- image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
144
- image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
145
- else:
146
- p = self.patch_size
147
- p_t = self.patch_size_t
148
-
149
- image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
150
- # b, f, h, w, c => b, f // 2, 2, h // 2, 2, w // 2, 2, c
151
- image_embeds = image_embeds.reshape(
152
- batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
153
- )
154
- # b, f // 2, 2, h // 2, 2, w // 2, 2, c => b, f // 2, h // 2, w // 2, c, 2, 2, 2
155
- image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
156
- image_embeds = self.proj(image_embeds)
157
-
158
- embeds = torch.cat(
159
- [text_embeds, image_embeds], dim=1
160
- ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
161
-
162
- if self.use_positional_embeddings or self.use_learned_positional_embeddings:
163
- seq_length = height * width * num_frames // (self.patch_size**2)
164
- # pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
165
- pos_embeds = self.pos_embedding
166
- emb_size = embeds.size()[-1]
167
- pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
168
- pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
169
- pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.patch_size, width // self.patch_size], mode='trilinear', align_corners=False)
170
- pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
171
- pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1)
172
- pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
173
- embeds = embeds + pos_embeds
174
-
175
- return embeds
176
-
177
- @maybe_allow_in_graph
178
- class CogVideoXBlock(nn.Module):
179
- r"""
180
- Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
181
-
182
- Parameters:
183
- dim (`int`):
184
- The number of channels in the input and output.
185
- num_attention_heads (`int`):
186
- The number of heads to use for multi-head attention.
187
- attention_head_dim (`int`):
188
- The number of channels in each head.
189
- time_embed_dim (`int`):
190
- The number of channels in timestep embedding.
191
- dropout (`float`, defaults to `0.0`):
192
- The dropout probability to use.
193
- activation_fn (`str`, defaults to `"gelu-approximate"`):
194
- Activation function to be used in feed-forward.
195
- attention_bias (`bool`, defaults to `False`):
196
- Whether or not to use bias in attention projection layers.
197
- qk_norm (`bool`, defaults to `True`):
198
- Whether or not to use normalization after query and key projections in Attention.
199
- norm_elementwise_affine (`bool`, defaults to `True`):
200
- Whether to use learnable elementwise affine parameters for normalization.
201
- norm_eps (`float`, defaults to `1e-5`):
202
- Epsilon value for normalization layers.
203
- final_dropout (`bool` defaults to `False`):
204
- Whether to apply a final dropout after the last feed-forward layer.
205
- ff_inner_dim (`int`, *optional*, defaults to `None`):
206
- Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
207
- ff_bias (`bool`, defaults to `True`):
208
- Whether or not to use bias in Feed-forward layer.
209
- attention_out_bias (`bool`, defaults to `True`):
210
- Whether or not to use bias in Attention output projection layer.
211
- """
212
-
213
- def __init__(
214
- self,
215
- dim: int,
216
- num_attention_heads: int,
217
- attention_head_dim: int,
218
- time_embed_dim: int,
219
- dropout: float = 0.0,
220
- activation_fn: str = "gelu-approximate",
221
- attention_bias: bool = False,
222
- qk_norm: bool = True,
223
- norm_elementwise_affine: bool = True,
224
- norm_eps: float = 1e-5,
225
- final_dropout: bool = True,
226
- ff_inner_dim: Optional[int] = None,
227
- ff_bias: bool = True,
228
- attention_out_bias: bool = True,
229
- ):
230
- super().__init__()
231
-
232
- # 1. Self Attention
233
- self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
234
-
235
- self.attn1 = Attention(
236
- query_dim=dim,
237
- dim_head=attention_head_dim,
238
- heads=num_attention_heads,
239
- qk_norm="layer_norm" if qk_norm else None,
240
- eps=1e-6,
241
- bias=attention_bias,
242
- out_bias=attention_out_bias,
243
- processor=CogVideoXAttnProcessor2_0(),
244
- )
245
-
246
- # 2. Feed Forward
247
- self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
248
-
249
- self.ff = FeedForward(
250
- dim,
251
- dropout=dropout,
252
- activation_fn=activation_fn,
253
- final_dropout=final_dropout,
254
- inner_dim=ff_inner_dim,
255
- bias=ff_bias,
256
- )
257
-
258
- def forward(
259
- self,
260
- hidden_states: torch.Tensor,
261
- encoder_hidden_states: torch.Tensor,
262
- temb: torch.Tensor,
263
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
264
- ) -> torch.Tensor:
265
- text_seq_length = encoder_hidden_states.size(1)
266
-
267
- # norm & modulate
268
- norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
269
- hidden_states, encoder_hidden_states, temb
270
- )
271
-
272
- # attention
273
- attn_hidden_states, attn_encoder_hidden_states = self.attn1(
274
- hidden_states=norm_hidden_states,
275
- encoder_hidden_states=norm_encoder_hidden_states,
276
- image_rotary_emb=image_rotary_emb,
277
- )
278
-
279
- hidden_states = hidden_states + gate_msa * attn_hidden_states
280
- encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
281
-
282
- # norm & modulate
283
- norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
284
- hidden_states, encoder_hidden_states, temb
285
- )
286
-
287
- # feed-forward
288
- norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
289
- ff_output = self.ff(norm_hidden_states)
290
-
291
- hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
292
- encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
293
-
294
- return hidden_states, encoder_hidden_states
295
-
296
-
297
- class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
298
- """
299
- A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
300
-
301
- Parameters:
302
- num_attention_heads (`int`, defaults to `30`):
303
- The number of heads to use for multi-head attention.
304
- attention_head_dim (`int`, defaults to `64`):
305
- The number of channels in each head.
306
- in_channels (`int`, defaults to `16`):
307
- The number of channels in the input.
308
- out_channels (`int`, *optional*, defaults to `16`):
309
- The number of channels in the output.
310
- flip_sin_to_cos (`bool`, defaults to `True`):
311
- Whether to flip the sin to cos in the time embedding.
312
- time_embed_dim (`int`, defaults to `512`):
313
- Output dimension of timestep embeddings.
314
- text_embed_dim (`int`, defaults to `4096`):
315
- Input dimension of text embeddings from the text encoder.
316
- num_layers (`int`, defaults to `30`):
317
- The number of layers of Transformer blocks to use.
318
- dropout (`float`, defaults to `0.0`):
319
- The dropout probability to use.
320
- attention_bias (`bool`, defaults to `True`):
321
- Whether or not to use bias in the attention projection layers.
322
- sample_width (`int`, defaults to `90`):
323
- The width of the input latents.
324
- sample_height (`int`, defaults to `60`):
325
- The height of the input latents.
326
- sample_frames (`int`, defaults to `49`):
327
- The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
328
- instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
329
- but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
330
- K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
331
- patch_size (`int`, defaults to `2`):
332
- The size of the patches to use in the patch embedding layer.
333
- temporal_compression_ratio (`int`, defaults to `4`):
334
- The compression ratio across the temporal dimension. See documentation for `sample_frames`.
335
- max_text_seq_length (`int`, defaults to `226`):
336
- The maximum sequence length of the input text embeddings.
337
- activation_fn (`str`, defaults to `"gelu-approximate"`):
338
- Activation function to use in feed-forward.
339
- timestep_activation_fn (`str`, defaults to `"silu"`):
340
- Activation function to use when generating the timestep embeddings.
341
- norm_elementwise_affine (`bool`, defaults to `True`):
342
- Whether or not to use elementwise affine in normalization layers.
343
- norm_eps (`float`, defaults to `1e-5`):
344
- The epsilon value to use in normalization layers.
345
- spatial_interpolation_scale (`float`, defaults to `1.875`):
346
- Scaling factor to apply in 3D positional embeddings across spatial dimensions.
347
- temporal_interpolation_scale (`float`, defaults to `1.0`):
348
- Scaling factor to apply in 3D positional embeddings across temporal dimensions.
349
- """
350
-
351
- _supports_gradient_checkpointing = True
352
-
353
- @register_to_config
354
- def __init__(
355
- self,
356
- num_attention_heads: int = 30,
357
- attention_head_dim: int = 64,
358
- in_channels: int = 16,
359
- out_channels: Optional[int] = 16,
360
- flip_sin_to_cos: bool = True,
361
- freq_shift: int = 0,
362
- time_embed_dim: int = 512,
363
- text_embed_dim: int = 4096,
364
- num_layers: int = 30,
365
- dropout: float = 0.0,
366
- attention_bias: bool = True,
367
- sample_width: int = 90,
368
- sample_height: int = 60,
369
- sample_frames: int = 49,
370
- patch_size: int = 2,
371
- patch_size_t: Optional[int] = None,
372
- temporal_compression_ratio: int = 4,
373
- max_text_seq_length: int = 226,
374
- activation_fn: str = "gelu-approximate",
375
- timestep_activation_fn: str = "silu",
376
- norm_elementwise_affine: bool = True,
377
- norm_eps: float = 1e-5,
378
- spatial_interpolation_scale: float = 1.875,
379
- temporal_interpolation_scale: float = 1.0,
380
- use_rotary_positional_embeddings: bool = False,
381
- use_learned_positional_embeddings: bool = False,
382
- patch_bias: bool = True,
383
- add_noise_in_inpaint_model: bool = False,
384
- ):
385
- super().__init__()
386
- inner_dim = num_attention_heads * attention_head_dim
387
- self.patch_size_t = patch_size_t
388
- if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
389
- raise ValueError(
390
- "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
391
- "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
392
- "issue at https://github.com/huggingface/diffusers/issues."
393
- )
394
-
395
- # 1. Patch embedding
396
- self.patch_embed = CogVideoXPatchEmbed(
397
- patch_size=patch_size,
398
- patch_size_t=patch_size_t,
399
- in_channels=in_channels,
400
- embed_dim=inner_dim,
401
- text_embed_dim=text_embed_dim,
402
- bias=patch_bias,
403
- sample_width=sample_width,
404
- sample_height=sample_height,
405
- sample_frames=sample_frames,
406
- temporal_compression_ratio=temporal_compression_ratio,
407
- max_text_seq_length=max_text_seq_length,
408
- spatial_interpolation_scale=spatial_interpolation_scale,
409
- temporal_interpolation_scale=temporal_interpolation_scale,
410
- use_positional_embeddings=not use_rotary_positional_embeddings,
411
- use_learned_positional_embeddings=use_learned_positional_embeddings,
412
- )
413
- self.embedding_dropout = nn.Dropout(dropout)
414
-
415
- # 2. Time embeddings
416
- self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
417
- self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
418
-
419
- # 3. Define spatio-temporal transformers blocks
420
- self.transformer_blocks = nn.ModuleList(
421
- [
422
- CogVideoXBlock(
423
- dim=inner_dim,
424
- num_attention_heads=num_attention_heads,
425
- attention_head_dim=attention_head_dim,
426
- time_embed_dim=time_embed_dim,
427
- dropout=dropout,
428
- activation_fn=activation_fn,
429
- attention_bias=attention_bias,
430
- norm_elementwise_affine=norm_elementwise_affine,
431
- norm_eps=norm_eps,
432
- )
433
- for _ in range(num_layers)
434
- ]
435
- )
436
- self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
437
-
438
- # 4. Output blocks
439
- self.norm_out = AdaLayerNorm(
440
- embedding_dim=time_embed_dim,
441
- output_dim=2 * inner_dim,
442
- norm_elementwise_affine=norm_elementwise_affine,
443
- norm_eps=norm_eps,
444
- chunk_dim=1,
445
- )
446
-
447
- if patch_size_t is None:
448
- # For CogVideox 1.0
449
- output_dim = patch_size * patch_size * out_channels
450
- else:
451
- # For CogVideoX 1.5
452
- output_dim = patch_size * patch_size * patch_size_t * out_channels
453
-
454
- self.proj_out = nn.Linear(inner_dim, output_dim)
455
-
456
- self.gradient_checkpointing = False
457
- self.sp_world_size = 1
458
- self.sp_world_rank = 0
459
-
460
- def _set_gradient_checkpointing(self, module, value=False):
461
- self.gradient_checkpointing = value
462
-
463
- def enable_multi_gpus_inference(self,):
464
- self.sp_world_size = get_sequence_parallel_world_size()
465
- self.sp_world_rank = get_sequence_parallel_rank()
466
- self.set_attn_processor(CogVideoXMultiGPUsAttnProcessor2_0())
467
-
468
- @property
469
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
470
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
471
- r"""
472
- Returns:
473
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
474
- indexed by its weight name.
475
- """
476
- # set recursively
477
- processors = {}
478
-
479
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
480
- if hasattr(module, "get_processor"):
481
- processors[f"{name}.processor"] = module.get_processor()
482
-
483
- for sub_name, child in module.named_children():
484
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
485
-
486
- return processors
487
-
488
- for name, module in self.named_children():
489
- fn_recursive_add_processors(name, module, processors)
490
-
491
- return processors
492
-
493
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
494
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
495
- r"""
496
- Sets the attention processor to use to compute attention.
497
-
498
- Parameters:
499
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
500
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
501
- for **all** `Attention` layers.
502
-
503
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
504
- processor. This is strongly recommended when setting trainable attention processors.
505
-
506
- """
507
- count = len(self.attn_processors.keys())
508
-
509
- if isinstance(processor, dict) and len(processor) != count:
510
- raise ValueError(
511
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
512
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
513
- )
514
-
515
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
516
- if hasattr(module, "set_processor"):
517
- if not isinstance(processor, dict):
518
- module.set_processor(processor)
519
- else:
520
- module.set_processor(processor.pop(f"{name}.processor"))
521
-
522
- for sub_name, child in module.named_children():
523
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
524
-
525
- for name, module in self.named_children():
526
- fn_recursive_attn_processor(name, module, processor)
527
-
528
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
529
- def fuse_qkv_projections(self):
530
- """
531
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
532
- are fused. For cross-attention modules, key and value projection matrices are fused.
533
-
534
- <Tip warning={true}>
535
-
536
- This API is 🧪 experimental.
537
-
538
- </Tip>
539
- """
540
- self.original_attn_processors = None
541
-
542
- for _, attn_processor in self.attn_processors.items():
543
- if "Added" in str(attn_processor.__class__.__name__):
544
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
545
-
546
- self.original_attn_processors = self.attn_processors
547
-
548
- for module in self.modules():
549
- if isinstance(module, Attention):
550
- module.fuse_projections(fuse=True)
551
-
552
- self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
553
-
554
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
555
- def unfuse_qkv_projections(self):
556
- """Disables the fused QKV projection if enabled.
557
-
558
- <Tip warning={true}>
559
-
560
- This API is 🧪 experimental.
561
-
562
- </Tip>
563
-
564
- """
565
- if self.original_attn_processors is not None:
566
- self.set_attn_processor(self.original_attn_processors)
567
-
568
- def forward(
569
- self,
570
- hidden_states: torch.Tensor,
571
- encoder_hidden_states: torch.Tensor,
572
- timestep: Union[int, float, torch.LongTensor],
573
- timestep_cond: Optional[torch.Tensor] = None,
574
- inpaint_latents: Optional[torch.Tensor] = None,
575
- control_latents: Optional[torch.Tensor] = None,
576
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
577
- return_dict: bool = True,
578
- ):
579
- batch_size, num_frames, channels, height, width = hidden_states.shape
580
- if num_frames == 1 and self.patch_size_t is not None:
581
- hidden_states = torch.cat([hidden_states, torch.zeros_like(hidden_states)], dim=1)
582
- if inpaint_latents is not None:
583
- inpaint_latents = torch.concat([inpaint_latents, torch.zeros_like(inpaint_latents)], dim=1)
584
- if control_latents is not None:
585
- control_latents = torch.concat([control_latents, torch.zeros_like(control_latents)], dim=1)
586
- local_num_frames = num_frames + 1
587
- else:
588
- local_num_frames = num_frames
589
-
590
- # 1. Time embedding
591
- timesteps = timestep
592
- t_emb = self.time_proj(timesteps)
593
-
594
- # timesteps does not contain any weights and will always return f32 tensors
595
- # but time_embedding might actually be running in fp16. so we need to cast here.
596
- # there might be better ways to encapsulate this.
597
- t_emb = t_emb.to(dtype=hidden_states.dtype)
598
- emb = self.time_embedding(t_emb, timestep_cond)
599
-
600
- # 2. Patch embedding
601
- if inpaint_latents is not None:
602
- hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
603
- if control_latents is not None:
604
- hidden_states = torch.concat([hidden_states, control_latents], 2)
605
- hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
606
- hidden_states = self.embedding_dropout(hidden_states)
607
-
608
- text_seq_length = encoder_hidden_states.shape[1]
609
- encoder_hidden_states = hidden_states[:, :text_seq_length]
610
- hidden_states = hidden_states[:, text_seq_length:]
611
-
612
- # Context Parallel
613
- if self.sp_world_size > 1:
614
- hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
615
- if image_rotary_emb is not None:
616
- image_rotary_emb = (
617
- torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank],
618
- torch.chunk(image_rotary_emb[1], self.sp_world_size, dim=0)[self.sp_world_rank]
619
- )
620
-
621
- # 3. Transformer blocks
622
- for i, block in enumerate(self.transformer_blocks):
623
- if torch.is_grad_enabled() and self.gradient_checkpointing:
624
-
625
- def create_custom_forward(module):
626
- def custom_forward(*inputs):
627
- return module(*inputs)
628
-
629
- return custom_forward
630
-
631
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
632
- hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
633
- create_custom_forward(block),
634
- hidden_states,
635
- encoder_hidden_states,
636
- emb,
637
- image_rotary_emb,
638
- **ckpt_kwargs,
639
- )
640
- else:
641
- hidden_states, encoder_hidden_states = block(
642
- hidden_states=hidden_states,
643
- encoder_hidden_states=encoder_hidden_states,
644
- temb=emb,
645
- image_rotary_emb=image_rotary_emb,
646
- )
647
-
648
- if not self.config.use_rotary_positional_embeddings:
649
- # CogVideoX-2B
650
- hidden_states = self.norm_final(hidden_states)
651
- else:
652
- # CogVideoX-5B
653
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
654
- hidden_states = self.norm_final(hidden_states)
655
- hidden_states = hidden_states[:, text_seq_length:]
656
-
657
- # 4. Final block
658
- hidden_states = self.norm_out(hidden_states, temb=emb)
659
- hidden_states = self.proj_out(hidden_states)
660
-
661
- if self.sp_world_size > 1:
662
- hidden_states = get_sp_group().all_gather(hidden_states, dim=1)
663
-
664
- # 5. Unpatchify
665
- p = self.config.patch_size
666
- p_t = self.config.patch_size_t
667
-
668
- if p_t is None:
669
- output = hidden_states.reshape(batch_size, local_num_frames, height // p, width // p, -1, p, p)
670
- output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
671
- else:
672
- output = hidden_states.reshape(
673
- batch_size, (local_num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
674
- )
675
- output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
676
-
677
- if num_frames == 1:
678
- output = output[:, :num_frames, :]
679
-
680
- if not return_dict:
681
- return (output,)
682
- return Transformer2DModelOutput(sample=output)
683
-
684
- @classmethod
685
- def from_pretrained(
686
- cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
687
- low_cpu_mem_usage=False, torch_dtype=torch.bfloat16, use_vae_mask=False, stack_mask=False,
688
- ):
689
- if subfolder is not None:
690
- pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
691
- print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
692
-
693
- config_file = os.path.join(pretrained_model_path, 'config.json')
694
- if not os.path.isfile(config_file):
695
- raise RuntimeError(f"{config_file} does not exist")
696
- with open(config_file, "r") as f:
697
- config = json.load(f)
698
-
699
- if use_vae_mask:
700
- print('[DEBUG] use vae to encode mask')
701
- config['in_channels'] = 48
702
- elif stack_mask:
703
- print('[DEBUG] use stacking mask')
704
- config['in_channels'] = 36
705
-
706
- from diffusers.utils import WEIGHTS_NAME
707
- model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
708
- model_file_safetensors = model_file.replace(".bin", ".safetensors")
709
-
710
- if "dict_mapping" in transformer_additional_kwargs.keys():
711
- for key in transformer_additional_kwargs["dict_mapping"]:
712
- transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
713
-
714
- if low_cpu_mem_usage:
715
- try:
716
- import re
717
-
718
- from diffusers.models.modeling_utils import \
719
- load_model_dict_into_meta
720
- from diffusers.utils import is_accelerate_available
721
- if is_accelerate_available():
722
- import accelerate
723
-
724
- # Instantiate model with empty weights
725
- with accelerate.init_empty_weights():
726
- model = cls.from_config(config, **transformer_additional_kwargs)
727
-
728
- param_device = "cpu"
729
- if os.path.exists(model_file):
730
- state_dict = torch.load(model_file, map_location="cpu")
731
- elif os.path.exists(model_file_safetensors):
732
- from safetensors.torch import load_file, safe_open
733
- state_dict = load_file(model_file_safetensors)
734
- else:
735
- from safetensors.torch import load_file, safe_open
736
- model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
737
- state_dict = {}
738
- for _model_file_safetensors in model_files_safetensors:
739
- _state_dict = load_file(_model_file_safetensors)
740
- for key in _state_dict:
741
- state_dict[key] = _state_dict[key]
742
- model._convert_deprecated_attention_blocks(state_dict)
743
- # move the params from meta device to cpu
744
- missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
745
- if len(missing_keys) > 0:
746
- raise ValueError(
747
- f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
748
- f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
749
- " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
750
- " those weights or else make sure your checkpoint file is correct."
751
- )
752
-
753
- unexpected_keys = load_model_dict_into_meta(
754
- model,
755
- state_dict,
756
- device=param_device,
757
- dtype=torch_dtype,
758
- model_name_or_path=pretrained_model_path,
759
- )
760
-
761
- if cls._keys_to_ignore_on_load_unexpected is not None:
762
- for pat in cls._keys_to_ignore_on_load_unexpected:
763
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
764
-
765
- if len(unexpected_keys) > 0:
766
- print(
767
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
768
- )
769
- return model
770
- except Exception as e:
771
- print(
772
- f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
773
- )
774
-
775
- model = cls.from_config(config, **transformer_additional_kwargs)
776
- if os.path.exists(model_file):
777
- state_dict = torch.load(model_file, map_location="cpu")
778
- elif os.path.exists(model_file_safetensors):
779
- from safetensors.torch import load_file, safe_open
780
- state_dict = load_file(model_file_safetensors)
781
- else:
782
- from safetensors.torch import load_file, safe_open
783
- model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
784
- state_dict = {}
785
- for _model_file_safetensors in model_files_safetensors:
786
- _state_dict = load_file(_model_file_safetensors)
787
- for key in _state_dict:
788
- state_dict[key] = _state_dict[key]
789
-
790
- if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
791
- new_shape = model.state_dict()['patch_embed.proj.weight'].size()
792
- if len(new_shape) == 5:
793
- state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
794
- state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
795
- elif len(new_shape) == 2:
796
- if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
797
- if use_vae_mask:
798
- print('[DEBUG] patch_embed.proj.weight size does not match due to vae-encoded mask')
799
- latent_ch = 16
800
- feat_scale = 8
801
- feat_dim = int(latent_ch * feat_scale)
802
- old_total_dim = state_dict['patch_embed.proj.weight'].size(1)
803
- new_total_dim = model.state_dict()['patch_embed.proj.weight'].size(1)
804
- model.state_dict()['patch_embed.proj.weight'][:, :feat_dim] = state_dict['patch_embed.proj.weight'][:, :feat_dim]
805
- model.state_dict()['patch_embed.proj.weight'][:, -feat_dim:] = state_dict['patch_embed.proj.weight'][:, -feat_dim:]
806
- for i in range(feat_dim, new_total_dim - feat_dim, feat_scale):
807
- model.state_dict()['patch_embed.proj.weight'][:, i:i+feat_scale] = state_dict['patch_embed.proj.weight'][:, feat_dim:-feat_dim]
808
- state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
809
- else:
810
- model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1]] = state_dict['patch_embed.proj.weight']
811
- model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:] = 0
812
- state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
813
- else:
814
- model.state_dict()['patch_embed.proj.weight'][:, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1]]
815
- state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
816
- else:
817
- if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
818
- model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
819
- model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
820
- state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
821
- else:
822
- model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
823
- state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
824
-
825
- tmp_state_dict = {}
826
- for key in state_dict:
827
- if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
828
- tmp_state_dict[key] = state_dict[key]
829
- else:
830
- print(key, "Size don't match, skip")
831
-
832
- state_dict = tmp_state_dict
833
-
834
- m, u = model.load_state_dict(state_dict, strict=False)
835
- print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
836
- print(m)
837
-
838
- params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
839
- print(f"### All Parameters: {sum(params) / 1e6} M")
840
-
841
- params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
842
- print(f"### attn1 Parameters: {sum(params) / 1e6} M")
843
-
844
- model = model.to(torch_dtype)
845
- return model