|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, Optional, Tuple, Union

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import is_torch_version, logging
from diffusers.utils.torch_utils import maybe_allow_in_graph
from torch import nn

from .modules import AdaLayerNorm, CogVideoXLayerNormZero, CogVideoXPatchEmbed, get_3d_sincos_pos_embed
|
|
|
# Module-level logger obtained through diffusers' logging utilities, named after this module.
logger = logging.get_logger(__name__)
|
|
|
|
|
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
    r"""
    Transformer block used in the [CogVideoX](https://github.com/THUDM/CogVideo) model.

    Text tokens and video tokens are concatenated (text first) and processed jointly by one self-attention layer
    and one feed-forward layer. Before each of these sub-layers, a `CogVideoXLayerNormZero` conditioned on the
    timestep embedding `temb` normalizes both streams and returns separate gates for the video and text residuals.

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in the timestep embedding `temb` consumed by the normalization layers.
        dropout (`float`, *optional*, defaults to `0.0`):
            The dropout probability used in the feed-forward layer.
        activation_fn (`str`, *optional*, defaults to `"gelu-approximate"`):
            Activation function to be used in the feed-forward layer.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether the attention projections should contain a bias parameter (passed to `Attention` as `bias`).
        qk_norm (`bool`, defaults to `True`):
            Whether or not to apply `"layer_norm"` normalization after the query and key projections in attention.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon value for the `CogVideoXLayerNormZero` normalization layers.
        final_dropout (`bool`, defaults to `True`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*, defaults to `None`):
            Custom hidden dimension of the feed-forward layer. If not provided, `4 * dim` is used.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in the feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in the attention output projection layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # 1. Self attention over the concatenated [text, video] sequence.
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            # Attention uses its own fixed epsilon, independent of `norm_eps`.
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
        )

        # 2. Feed forward over the concatenated [text, video] sequence.
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply joint text/video attention followed by a joint feed-forward, each with a gated residual.

        Args:
            hidden_states (`torch.Tensor`):
                Video tokens; the sequence axis is dim 1 (presumably `(batch, video_seq_len, dim)`).
            encoder_hidden_states (`torch.Tensor`):
                Text tokens; the sequence axis is dim 1 (presumably `(batch, text_seq_len, dim)`).
            temb (`torch.Tensor`):
                Timestep embedding passed to both `CogVideoXLayerNormZero` layers.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: the updated `(hidden_states, encoder_hidden_states)`.
        """
        # norm & modulate; also yields separate residual gates for the video and text streams.
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )

        # Text tokens come first in the joint sequence, so this is the split point afterwards.
        text_length = norm_encoder_hidden_states.size(1)

        # attention over the joint sequence (pure self-attention, hence encoder_hidden_states=None)
        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        attn_output = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=None,
        )

        hidden_states = hidden_states + gate_msa * attn_output[:, text_length:]
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length]

        # norm & modulate for the feed-forward sub-layer
        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )

        # feed-forward over the joint sequence
        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)

        hidden_states = hidden_states + gate_ff * ff_output[:, text_length:]
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length]

        return hidden_states, encoder_hidden_states
|
|
|
|
|
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
    """
    A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).

    Text embeddings and patchified video latents are embedded into one joint sequence (text first), processed by a
    stack of `CogVideoXBlock`s, and the video tokens are finally projected back to latent patches.

    Parameters:
        num_attention_heads (`int`, defaults to `30`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head. The model width is `num_attention_heads * attention_head_dim`.
        in_channels (`int`, *optional*, defaults to `16`):
            The number of channels in the input latents.
        out_channels (`int`, *optional*, defaults to `16`):
            The number of channels in the output latents.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Passed to the sinusoidal timestep projection (`Timesteps`).
        freq_shift (`int`, defaults to `0`):
            Frequency shift passed to the sinusoidal timestep projection (`Timesteps`).
        time_embed_dim (`int`, defaults to `512`):
            Output dimension of `TimestepEmbedding`; conditions every block and the output norm.
        text_embed_dim (`int`, defaults to `4096`):
            The number of channels in the text embeddings (`encoder_hidden_states`).
        num_layers (`int`, defaults to `30`):
            The number of `CogVideoXBlock`s to use.
        dropout (`float`, defaults to `0.0`):
            Dropout probability applied to the embeddings and inside each block's feed-forward layer.
        attention_bias (`bool`, defaults to `True`):
            Whether the attention projections in each block should contain a bias parameter.
        sample_width (`int`, defaults to `90`):
            Latent width used to precompute the positional embeddings.
        sample_height (`int`, defaults to `60`):
            Latent height used to precompute the positional embeddings.
        sample_frames (`int`, defaults to `49`):
            Frame count used to precompute the positional embeddings; it yields
            `(sample_frames - 1) // temporal_compression_ratio + 1` temporal positions.
        patch_size (`int`, defaults to `2`):
            Spatial patch size used by the patch embedding and the output projection.
        temporal_compression_ratio (`int`, defaults to `4`):
            Factor used to derive the number of temporal positions from `sample_frames`.
        max_text_seq_length (`int`, defaults to `226`):
            Number of text rows reserved (all zeros) at the front of the positional-embedding table.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in the feed-forward layers.
        timestep_activation_fn (`str`, defaults to `"silu"`):
            Activation function passed to `TimestepEmbedding`.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether or not to use elementwise affine in normalization layers.
        norm_eps (`float`, defaults to `1e-5`):
            The epsilon value to use in normalization layers.
        spatial_interpolation_scale (`float`, defaults to `1.875`):
            Spatial scale passed to `get_3d_sincos_pos_embed`.
        temporal_interpolation_scale (`float`, defaults to `1.0`):
            Temporal scale passed to `get_3d_sincos_pos_embed`.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 30,
        attention_head_dim: int = 64,
        in_channels: Optional[int] = 16,
        out_channels: Optional[int] = 16,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        text_embed_dim: int = 4096,
        num_layers: int = 30,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        # Number of video tokens for the configured sample size:
        # spatial patches per frame times the number of temporally compressed frames.
        post_patch_height = sample_height // patch_size
        post_patch_width = sample_width // patch_size
        post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
        self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames

        # 1. Patch embedding for the text embeddings and video latents (see `CogVideoXPatchEmbed`).
        self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. Fixed 3D sin-cos positional table. The first `max_text_seq_length` rows stay zero so text tokens get
        # no positional offset; the remaining `num_patches` rows cover the video tokens.
        spatial_pos_embedding = get_3d_sincos_pos_embed(
            inner_dim,
            (post_patch_width, post_patch_height),
            post_time_compression_frames,
            spatial_interpolation_scale,
            temporal_interpolation_scale,
        )
        spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
        pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
        pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
        # Non-persistent: rebuilt at construction time rather than stored in checkpoints.
        self.register_buffer("pos_embedding", pos_embedding, persistent=False)

        # 3. Time embeddings
        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

        # 4. Transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                CogVideoXBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

        # 5. Output blocks: timestep-conditioned norm, then projection to `patch_size**2 * out_channels`
        # values per video token.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * inner_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # The flag is only read by this model's `forward`, so it is stored on `self` regardless of `module`.
        self.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
        """
        Run the transformer on a batch of video latents conditioned on text embeddings and a timestep.

        Args:
            hidden_states (`torch.Tensor`):
                Video latents of shape `(batch_size, num_frames, channels, height, width)`.
            encoder_hidden_states (`torch.Tensor`):
                Text embeddings, presumably `(batch_size, text_seq_length, text_embed_dim)`.
            timestep (`int`, `float` or `torch.LongTensor`):
                Timestep(s) fed to the sinusoidal projection `time_proj`.
            timestep_cond (`torch.Tensor`, *optional*):
                Extra conditioning forwarded to `time_embedding`.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `Transformer2DModelOutput` instead of a plain tuple.

        Returns:
            `Transformer2DModelOutput` if `return_dict` is True, otherwise a 1-tuple. The sample has shape
            `(batch_size, num_frames, out_channels, height, width)`.
        """
        batch_size, num_frames, _, height, width = hidden_states.shape
        p = self.config.patch_size
        max_text_seq_length = self.config.max_text_seq_length

        # 1. Time embedding
        t_emb = self.time_proj(timestep)
        # Cast to the activation dtype before the learned embedding; the sinusoidal projection presumably does
        # not follow the model dtype.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # 2. Patch embedding: produces one joint sequence of text tokens followed by video patch tokens.
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)

        # 3. Positional embedding. Derive the text length from the joint sequence instead of assuming it equals
        # `max_text_seq_length`. Text positions receive zeros (exactly what the table's first rows hold); video
        # positions always read from the video section of the table.
        seq_length = height * width * num_frames // (p**2)
        text_seq_length = hidden_states.shape[1] - seq_length
        video_pos_embeds = self.pos_embedding[:, max_text_seq_length : max_text_seq_length + seq_length]
        text_pos_embeds = video_pos_embeds.new_zeros(1, text_seq_length, video_pos_embeds.shape[-1])
        pos_embeds = torch.cat([text_pos_embeds, video_pos_embeds], dim=1)
        hidden_states = hidden_states + pos_embeds
        hidden_states = self.embedding_dropout(hidden_states)

        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        # 4. Transformer blocks
        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)

            return custom_forward

        for block in self.transformer_blocks:
            if self.training and self.gradient_checkpointing:
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                )

        hidden_states = self.norm_final(hidden_states)

        # 5. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 6. Unpatchify. The channel axis is inferred (-1) because `proj_out` emits `out_channels` values per
        # pixel, which need not equal the input channel count (e.g. when in_channels != out_channels).
        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
|
|