from typing import Any, Dict, Optional, Tuple, Union

import torch
from torch import nn
from einops import rearrange
import torch.nn.functional as F

from diffusers import CogVideoXTransformer3DModel
from diffusers.configuration_utils import ConfigMixin, FrozenDict, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor2_0
from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero, AdaLayerNormZeroSingle
from diffusers.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput, CogVideoXBlock
from diffusers.utils import is_torch_version
from diffusers.utils.torch_utils import maybe_allow_in_graph

from .cogvideox_custom_modules import CustomCogVideoXPatchEmbed, CustomCogVideoXBlock

class CogVideoXControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 30,
        attention_head_dim: int = 64,
        in_channels: int = 16,
        out_channels: Optional[int] = 16,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        ofs_embed_dim: Optional[int] = None,
        text_embed_dim: int = 4096,
        num_layers: int = 30,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        patch_size_t: Optional[int] = None,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
        use_learned_positional_embeddings: bool = False,
        patch_bias: bool = True,
        out_proj_dim_factor: int = 8,
        out_proj_dim_zero_init: bool = True,
        notextinflow: bool = False,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        self.notextinflow = notextinflow

        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
            raise ValueError(
                "There are no CogVideoX checkpoints available with disabled rotary embeddings and "
                "learned positional embeddings. If you're using a custom model and/or believe this "
                "should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
            )
""" |
|
Delete below. |
|
In our case, FloVD, controlnet_hidden_states is already flow_latents encoded by 3D-Causal-VAE |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        self.patch_embed = CogVideoXPatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            bias=True,
            sample_width=sample_width,
            sample_height=sample_height,
            sample_frames=sample_frames,
            temporal_compression_ratio=temporal_compression_ratio,
            spatial_interpolation_scale=spatial_interpolation_scale,
            temporal_interpolation_scale=temporal_interpolation_scale,
            use_positional_embeddings=not use_rotary_positional_embeddings,
            use_learned_positional_embeddings=use_learned_positional_embeddings,
        )
        self.embedding_dropout = nn.Dropout(dropout)

        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
        self.transformer_blocks = nn.ModuleList(
            [
                CustomCogVideoXBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        # Per-block output projections mapping controlnet features to the dimension consumed by the
        # base transformer; zero-initialization keeps the controlnet contribution a no-op at the
        # start of training.
        self.out_projectors = None
        if out_proj_dim_factor is not None:
            out_proj_dim = num_attention_heads * out_proj_dim_factor
            self.out_projectors = nn.ModuleList(
                [nn.Linear(inner_dim, out_proj_dim) for _ in range(num_layers)]
            )
            if out_proj_dim_zero_init:
                for out_projector in self.out_projectors:
                    self.zeros_init_linear(out_projector)

        self.gradient_checkpointing = False
    def zeros_init_linear(self, linear: nn.Module):
        if isinstance(linear, (nn.Linear, nn.Conv1d)):
            if hasattr(linear, "weight"):
                nn.init.zeros_(linear.weight)
            # Guard against bias=False layers, where `bias` exists but is None.
            if getattr(linear, "bias", None) is not None:
                nn.init.zeros_(linear.bias)

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    def compress_time(self, x, num_frames):
        # Halve the temporal resolution of (batch * frames, channels, height, width) feature maps
        # with 1D average pooling; when the frame count is odd, the first frame is kept as-is.
        x = rearrange(x, '(b f) c h w -> b f c h w', f=num_frames)
        batch_size, frames, channels, height, width = x.shape
        x = rearrange(x, 'b f c h w -> (b h w) c f')

        if x.shape[-1] % 2 == 1:
            x_first, x_rest = x[..., 0], x[..., 1:]
            if x_rest.shape[-1] > 0:
                x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
            x = torch.cat([x_first[..., None], x_rest], dim=-1)
        else:
            x = F.avg_pool1d(x, kernel_size=2, stride=2)

        x = rearrange(x, '(b h w) c f -> (b f) c h w', b=batch_size, h=height, w=width)
        return x
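    # Note on `compress_time` above (illustrative, not part of the original code): for an odd frame
    # count such as 49, frame 0 is kept as-is and the remaining 48 frames are average-pooled in
    # pairs, so a (b*49, c, h, w) input becomes a (b*25, c, h, w) output.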
    @classmethod
    def from_pretrained(cls, model_path, subfolder, **additional_kwargs):
        # Build the controlnet from the base transformer's config (with `additional_kwargs`
        # overrides) and initialize it from the base weights; non-matching keys are skipped.
        base = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder=subfolder)
        controlnet_config = FrozenDict({**base.config, **additional_kwargs})
        model = cls(**controlnet_config)

        missing, unexpected = model.load_state_dict(base.state_dict(), strict=False)
        print(
            f"Loaded base weights from CogVideoXTransformer3DModel "
            f"({len(missing)} missing, {len(unexpected)} unexpected keys)."
        )

        del base
        torch.cuda.empty_cache()

        return model
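    # Usage sketch (illustrative only; the checkpoint id, subfolder, and overrides are assumptions,
    # not values taken from this file):
    #
    #   controlnet = CogVideoXControlnet.from_pretrained(
    #       "THUDM/CogVideoX-2b", subfolder="transformer", num_layers=6,
    #   )
    #
    # Keyword arguments passed here override the matching entries of the base transformer's config.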
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        controlnet_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        controlnet_valid_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        # In FloVD, `controlnet_hidden_states` is already flow latents encoded by the 3D causal VAE,
        # so it is concatenated with the noisy video latents along the channel dimension.
        batch_size, num_frames, channels, height, width = hidden_states.shape

        hidden_states = torch.cat([hidden_states, controlnet_hidden_states], dim=2)
        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # `Timesteps` has no weights and always returns float32, so cast to the hidden states'
        # dtype before running the embedding MLP.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # 2. Patch embedding (prepends the text tokens to the patchified video tokens)
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
        hidden_states = self.embedding_dropout(hidden_states)
        # Unchanged from the base transformer: split the text tokens from the video tokens. Note
        # that the video tokens carry both the noisy latents and the controlnet flow latents,
        # which were concatenated channel-wise above.
        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        # Optional attention mask restricting the controlnet to valid flow regions: downsample the
        # mask to the token grid (height and width halved, matching the default patch_size of 2),
        # binarize it, and flatten it to one entry per video token.
        attention_mask = None
        if controlnet_valid_mask is not None:
            mask_shape = controlnet_valid_mask.shape
            attention_mask = F.interpolate(
                controlnet_valid_mask,
                size=(mask_shape[2], mask_shape[3] // 2, mask_shape[4] // 2),
                mode='trilinear',
                align_corners=False,
            )
            attention_mask = (attention_mask >= 0.5).to(torch.bool)
            attention_mask = rearrange(attention_mask.squeeze(1), 'b f h w -> b (f h w)')

            if not self.notextinflow:
                # Prepend zero entries for the text tokens so the mask covers the full sequence.
                attention_mask = F.pad(attention_mask, (text_seq_length, 0), value=0.0)

        attention_kwargs = {
            'attention_mask': attention_mask,
            'notextinflow': self.notextinflow,
        }
        # 3. Transformer blocks: run the controlnet branch and collect each block's (projected)
        # features as the controlnet outputs.
        controlnet_hidden_states = ()
        for i, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    attention_kwargs,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                    attention_kwargs=attention_kwargs,
                )

            if self.out_projectors is not None:
                controlnet_hidden_states += (self.out_projectors[i](hidden_states),)
            else:
                controlnet_hidden_states += (hidden_states,)

        if not return_dict:
            return (controlnet_hidden_states,)
        return Transformer2DModelOutput(sample=controlnet_hidden_states)
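# Forward-pass sketch (illustrative only; the shapes and the `in_channels=32` override are
# assumptions, not values taken from this file). With `in_channels=32`, the patch embedding matches
# the channel-wise concatenation of the noisy video latents (16 ch) and the VAE-encoded flow
# latents (16 ch) performed inside `forward`:
#
#   controlnet = CogVideoXControlnet(num_layers=2, in_channels=32)
#   video_latents = torch.randn(1, 13, 16, 60, 90)  # (B, F, C, H, W)
#   flow_latents = torch.randn(1, 13, 16, 60, 90)   # flow encoded by the 3D causal VAE
#   text_embeds = torch.randn(1, 226, 4096)         # prompt embeddings
#   timestep = torch.tensor([999])
#   block_features = controlnet(video_latents, text_embeds, flow_latents, timestep, return_dict=False)[0]
#   # `block_features` is a tuple with one (B, num_video_tokens, out_proj_dim) tensor per block,
#   # intended to be injected into the base CogVideoX transformer.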