from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch

from diffusers.models.transformers.cogvideox_transformer_3d import (
    CogVideoXTransformer3DModel,
    Transformer2DModelOutput,
)
from diffusers.utils import is_torch_version


class CustomCogVideoXTransformer3DModel(CogVideoXTransformer3DModel):
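    """CogVideoX 3D transformer with selective fine-tuning and ControlNet-style
    residual injection.

    Extends ``CogVideoXTransformer3DModel`` with a helper that unfreezes a
    subset of parameters, and a ``forward`` that accepts per-block controlnet
    states and an optional conditioning start frame.
    """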

    def set_learnable_parameters(self, unfrozen_layers: int = 16):
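        """Mark the patch embedding and the first ``unfrozen_layers``
        transformer blocks (attention projections, QK norms, feed-forward, and
        its norm) as trainable.

        Assumes the caller has already frozen the model, e.g. via
        ``model.requires_grad_(False)``; this method only re-enables gradients.
        """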
        # Unfreeze the patch embedding so the model can adapt to new inputs.
        for param in self.patch_embed.parameters():
            param.requires_grad = True

        for i in range(unfrozen_layers):
            block = self.transformer_blocks[i]
            attn = block.attn1

            # Unfreeze the feed-forward sublayer and its preceding layer norm.
            for module in [block.norm2, block.ff]:
                for param in module.parameters():
                    param.requires_grad = True

            # Unfreeze the attention projections and QK norms where present.
            for name in ['to_q', 'to_k', 'to_v', 'norm_q', 'norm_k']:
                module = getattr(attn, name, None)
                if module is not None:
                    for param in module.parameters():
                        param.requires_grad = True
                else:
                    print(f"[Warning] {name} not found in attn1 of block {i}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        start_frame: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        controlnet_states: Optional[torch.Tensor] = None,
        controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
        return_dict: bool = True,
    ):
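        """Denoising forward pass with optional ControlNet-style residuals.

        ``hidden_states``: video latents of shape (batch, frames, channels,
        height, width). ``start_frame``, if given, is concatenated along the
        channel dimension as extra conditioning. ``controlnet_states`` is an
        optional per-block sequence of residuals added to the video tokens
        after each transformer block, scaled by ``controlnet_weights`` (a
        scalar, or one weight per block).
        """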
        batch_size, num_frames, channels, height, width = hidden_states.shape

        if start_frame is not None:
            hidden_states = torch.cat([start_frame, hidden_states], dim=2)

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # `time_proj` always returns float32; cast back so `time_embedding`
        # runs in the same dtype as the rest of the model (e.g. fp16).
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # 2. Patch embedding: embeds text and video tokens and concatenates
        # them into a single joint sequence.
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
        hidden_states = self.embedding_dropout(hidden_states)

        # Split the joint sequence back into text and video tokens.
        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        # 3. Transformer blocks (optionally gradient-checkpointed), with
        # controlnet residuals injected after each block.
        for i, block in enumerate(self.transformer_blocks):
            if self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )

            if (controlnet_states is not None) and (i < len(controlnet_states)):
                controlnet_states_block = controlnet_states[i]
                # Resolve this block's weight: indexable weights are looked up
                # per block, scalars apply uniformly to every block.
                controlnet_block_weight = 1.0
                if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
                    controlnet_block_weight = controlnet_weights[i]
                elif isinstance(controlnet_weights, (float, int)):
                    controlnet_block_weight = controlnet_weights

                hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight

        # 4. Final norm. Non-RoPE checkpoints (CogVideoX-2B) normalize the
        # video tokens directly; RoPE checkpoints (CogVideoX-5B) re-join the
        # text tokens, normalize the full sequence, then drop the text tokens.
        if not self.config.use_rotary_positional_embeddings:
            hidden_states = self.norm_final(hidden_states)
        else:
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = self.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]

        # 5. Output projection
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 6. Unpatchify: fold the token sequence back into (B, F, C, H, W).
        p = self.config.patch_size
        p_t = self.config.patch_size_t

        if p_t is None:
            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
            output = hidden_states.reshape(
                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
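

# Minimal usage sketch (assumptions: a diffusers-format CogVideoX checkpoint at
# `model_id`, and `latents`/`text_embeds`/`t`/`rotary_emb` prepared by the
# surrounding pipeline). Illustrates the intended call pattern only; it is not
# part of the model definition.
#
#     model = CustomCogVideoXTransformer3DModel.from_pretrained(
#         model_id, subfolder="transformer", torch_dtype=torch.float16
#     )
#     model.requires_grad_(False)              # freeze everything first
#     model.set_learnable_parameters(unfrozen_layers=16)
#
#     out = model(
#         hidden_states=latents,               # (B, F, C, H, W) video latents
#         encoder_hidden_states=text_embeds,   # (B, seq_len, dim) text embeddings
#         timestep=t,
#         image_rotary_emb=rotary_emb,         # required for RoPE variants (5B)
#     )
#     pred = out.sample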