import math
import copy
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from diffusers import CogVideoXTransformer3DModel
from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from diffusers.models.normalization import CogVideoXLayerNormZero
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import CogVideoXAttnProcessor2_0, Attention
from diffusers.models.embeddings import CogVideoXPatchEmbed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph

from peft.tuners.lora.layer import LoraLayer

import pdb


class enable_lora:
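    """Context manager that toggles the ``lora_enabled`` flag on a set of modules.

    Previous flag values are saved on ``__enter__`` and restored on ``__exit__``, so
    managers can be nested safely. The flag only has an effect if the wrapped modules
    actually consult ``lora_enabled`` in their forward pass (e.g. patched LoRA layers);
    plain modules simply carry the attribute.

    Usage sketch (``lora_modules`` is a hypothetical list of such modules, not
    something defined in this file):

        with enable_lora(lora_modules, enable=False):
            out = transformer(...)  # runs with LoRA contributions disabled
    """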
    def __init__(self, modules, enable=True):
        self.modules = modules
        self.enable = enable
        self.prev_states = {}

    def __enter__(self):
        for module in self.modules:
            # Remember each module's previous flag so nested managers restore it correctly.
            self.prev_states[module] = getattr(module, "lora_enabled", True)
            setattr(module, "lora_enabled", self.enable)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        for module in self.modules:
            setattr(module, "lora_enabled", self.prev_states[module])
        return False


class CustomCogVideoXPatchEmbed(CogVideoXPatchEmbed):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        patch_size = kwargs["patch_size"]
        patch_size_t = kwargs["patch_size_t"]
        bias = kwargs["bias"]
        in_channels = kwargs["in_channels"]
        embed_dim = kwargs["embed_dim"]

        # The flow stream carries half the channels of the video latents, so it gets its
        # own projection alongside the inherited `self.proj` used for image latents.
        if patch_size_t is None:
            # Spatial-only patchification: 2D convolution over (patch_size, patch_size) patches.
            self.flow_proj = nn.Conv2d(
                in_channels // 2, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
            )
        else:
            # Spatio-temporal patchification: linear projection over flattened patches.
            self.flow_proj = nn.Linear(in_channels // 2 * patch_size * patch_size * patch_size_t, embed_dim)

        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
            # Reuse the image/video part of the base positional table for the flow tokens
            # (the text prefix is skipped).
            flow_pos_embedding = self._get_positional_embeddings(
                self.sample_height, self.sample_width, self.sample_frames
            )[:, self.max_text_seq_length:]
            self.flow_pos_embedding = nn.Parameter(flow_pos_embedding)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor, flow_embeds: torch.Tensor):
        r"""
        Args:
            text_embeds (`torch.Tensor`):
                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
            image_embeds (`torch.Tensor`):
                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
            flow_embeds (`torch.Tensor`):
                Input flow embeddings. Expected shape: (batch_size, num_frames, channels // 2, height, width),
                i.e. half as many channels as `image_embeds`.
        """
        text_embeds = self.text_proj(text_embeds)
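        # Layout sketch (shapes are illustrative assumptions, not a fixed model config):
        # with text_embeds (B, L_text, D), image_embeds (B, F, C, H, W) and flow_embeds
        # (B, F, C // 2, H, W), the code below builds a single sequence along dim=1:
        #     [ text tokens | image patch tokens | flow patch tokens ]
        # and adds positional embeddings when they are enabled.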
        batch_size, num_frames, channels, height, width = image_embeds.shape

        if self.patch_size_t is None:
            # Spatial patchification: fold frames into the batch, apply the 2D conv,
            # then flatten the patches into a token sequence.
            image_embeds = image_embeds.reshape(-1, channels, height, width)
            image_embeds = self.proj(image_embeds)
            image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
            image_embeds = image_embeds.flatten(3).transpose(2, 3)
            image_embeds = image_embeds.flatten(1, 2)

            flow_embeds = flow_embeds.reshape(-1, channels // 2, height, width)
            flow_embeds = self.flow_proj(flow_embeds)
            flow_embeds = flow_embeds.view(batch_size, num_frames, *flow_embeds.shape[1:])
            flow_embeds = flow_embeds.flatten(3).transpose(2, 3)
            flow_embeds = flow_embeds.flatten(1, 2)
        else:
            p = self.patch_size
            p_t = self.patch_size_t

            # Spatio-temporal patchification followed by the linear projections.
            image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
            image_embeds = image_embeds.reshape(
                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
            )
            image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
            image_embeds = self.proj(image_embeds)

            flow_embeds = flow_embeds.permute(0, 1, 3, 4, 2)
            flow_embeds = flow_embeds.reshape(
                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels // 2
            )
            flow_embeds = flow_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
            flow_embeds = self.flow_proj(flow_embeds)

        embeds = torch.cat([text_embeds, image_embeds, flow_embeds], dim=1).contiguous()

        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
                raise ValueError(
                    "It is currently not possible to generate videos at a different resolution than the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'. "
                    "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
                )

            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1

            if (
                self.sample_height != height
                or self.sample_width != width
                or self.sample_frames != pre_time_compression_frames
            ):
                pos_embedding = self._get_positional_embeddings(
                    height, width, pre_time_compression_frames, device=embeds.device
                )
            else:
                pos_embedding = self.pos_embedding

            # The base table covers the text + image tokens; the learned flow table is
            # appended so it lines up with the flow tokens at the end of the sequence.
            flow_pos_embedding = self.flow_pos_embedding
            pos_embedding_total = torch.cat([pos_embedding, flow_pos_embedding], dim=1).to(dtype=embeds.dtype)
            embeds = embeds + pos_embedding_total

        return embeds


@maybe_allow_in_graph
class CustomCogVideoXBlock(nn.Module):
    r"""
    Transformer block used in the [CogVideoX](https://github.com/THUDM/CogVideo) model, extended to route
    `attention_kwargs` (e.g. `notextinflow`) to a `CustomCogVideoXAttnProcessor2_0`.

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in timestep embedding.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to be used in feed-forward.
        attention_bias (`bool`, defaults to `False`):
            Whether or not to use bias in attention projection layers.
        qk_norm (`bool`, defaults to `True`):
            Whether or not to use normalization after query and key projections in Attention.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon value for normalization layers.
        final_dropout (`bool`, defaults to `True`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*, defaults to `None`):
            Custom hidden dimension of the feed-forward layer. If not provided, `4 * dim` is used.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in the feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in the attention output projection layer.
    """
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # 1. Self-attention
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CustomCogVideoXAttnProcessor2_0(),
        )

        # 2. Feed-forward
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        text_seq_length = encoder_hidden_states.size(1)
        attention_kwargs = attention_kwargs or {}

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )

        # attention (extra kwargs such as `notextinflow` are forwarded to the processor)
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **attention_kwargs,
        )

        hidden_states = hidden_states + gate_msa * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )

        # feed-forward over the concatenated [text, video] sequence
        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)

        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]

        return hidden_states, encoder_hidden_states
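# Usage sketch (illustrative, not part of the original call sites): extra attention
# kwargs are routed from the block into its attention processor, e.g.
#
#     hidden_states, encoder_hidden_states = block(
#         hidden_states,
#         encoder_hidden_states,
#         temb,
#         image_rotary_emb=image_rotary_emb,
#         attention_kwargs={"notextinflow": True},
#     )
#
# where `notextinflow` is consumed by CustomCogVideoXAttnProcessor2_0 below.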
class CustomCogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization. When `notextinflow` is set, the text tokens are
    not concatenated to the hidden states and the rotary embedding is applied to the full sequence.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        notextinflow: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        text_seq_length = encoder_hidden_states.size(1)

        if not notextinflow:
            # Text and video tokens attend jointly over the concatenated sequence.
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        batch_size, sequence_length, _ = hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply the rotary embedding: skip the text prefix when it is part of the
        # sequence, otherwise rotate every token.
        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            if not notextinflow:
                query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
                if not attn.is_cross_attention:
                    key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
            else:
                query[:, :, :] = apply_rotary_emb(query[:, :, :], image_rotary_emb)
                if not attn.is_cross_attention:
                    key[:, :, :] = apply_rotary_emb(key[:, :, :], image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if not notextinflow:
            # Split the joint sequence back into text and video parts; when the text was
            # never concatenated, the incoming encoder_hidden_states are returned unchanged.
            encoder_hidden_states, hidden_states = hidden_states.split(
                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
            )

        return hidden_states, encoder_hidden_states
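# Attachment sketch (hedged; the exact wiring lives outside this file): the processor is
# normally installed when CustomCogVideoXBlock builds its Attention module via
# `processor=CustomCogVideoXAttnProcessor2_0()` above, or it can be set afterwards on an
# existing diffusers Attention instance with
# `attn.set_processor(CustomCogVideoXAttnProcessor2_0())`.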