from .vae_encoder import ControlnetXsVaeEncoderCogVideoX
from .vae_decoder import ControlnetXsVaeDecoderCogVideoX
from .transformer import CogVideoXTransformer3DControlnetXs
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
import os
from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from .transformer import CogVideoXControlnetXsLayerNormZero
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
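# ControlNet-XS branch for CogVideoX: a small control transformer runs alongside the main
# CogVideoX transformer, and zero-initialized down-/up-projectors exchange hidden states
# between the two branches for every layer index in [up_down_layer_start_idx, up_down_layer_end_idx].
# The camera condition is given as Plücker embeddings, compressed to the latent grid by a
# dedicated VAE encoder before patchification.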
class ZeroLayerNormDownProjector(nn.Module):
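    """AdaLN-zero style projector from the main-branch width (`dim_in`) down to the control-branch width (`dim_out`).

    The time embedding produces a scale/shift for the LayerNorm over `hidden_states` (main-branch
    features) and a per-channel gate for the MLP output, which is added to `hidden_states2`
    (control-branch features). Because the scale/shift and gate linears are zero-initialized in
    `ControlnetXs.init_weights`, the block is an identity mapping on `hidden_states2` at the start
    of training. `dropout` and `activation_fn` are accepted for config parity but not used.
    """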
def __init__(
self,
dim_in: int,
dim_out: int,
time_embed_dim: int,
dropout: float = 0.0,
activation_fn: str = "gelu-approximate",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
) -> None:
super().__init__()
# dim_in > dim_out
self.silu = nn.SiLU()
self.scale_shift_linear = nn.Linear(time_embed_dim, 2 * dim_in, bias=True)
self.gate_linear = nn.Linear(time_embed_dim, dim_out, bias=True)
self.norm = nn.LayerNorm(dim_in, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.mlp = nn.Sequential(
nn.Linear(dim_in, 4 * dim_out),
nn.GELU(approximate='tanh'),
nn.Linear(4 * dim_out, dim_out),
)
def forward(
self, hidden_states: torch.Tensor, temb: torch.Tensor, hidden_states2: torch.Tensor
    ) -> torch.Tensor:
silu_temb = self.silu(temb)
scale, shift = self.scale_shift_linear(silu_temb).chunk(2, dim=1)
gate = self.gate_linear(silu_temb)
return hidden_states2 + gate[:, None, :] * \
self.mlp(self.norm(hidden_states) * (1 + scale[:, None, :]) + shift[:, None, :])
class ZeroLayerNormUpProjector(nn.Module):
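    """Mirror of `ZeroLayerNormDownProjector` for the control-branch (`dim_in`) -> main-branch (`dim_out`) direction.

    Same modulation scheme, but the MLP widens from `dim_in` to `4 * dim_in` before projecting to
    `dim_out`; the zero-initialized gate again makes the block an identity mapping on
    `hidden_states2` at initialization.
    """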
def __init__(
self,
dim_in: int,
dim_out: int,
time_embed_dim: int,
dropout: float = 0.0,
activation_fn: str = "gelu-approximate",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
) -> None:
super().__init__()
# dim_in < dim_out
self.silu = nn.SiLU()
self.scale_shift_linear = nn.Linear(time_embed_dim, 2 * dim_in, bias=True)
self.gate_linear = nn.Linear(time_embed_dim, dim_out, bias=True)
self.norm = nn.LayerNorm(dim_in, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.mlp = nn.Sequential(
nn.Linear(dim_in, 4 * dim_in),
nn.GELU(approximate='tanh'),
nn.Linear(4 * dim_in, dim_out),
)
def forward(
self, hidden_states: torch.Tensor, temb: torch.Tensor, hidden_states2: torch.Tensor
    ) -> torch.Tensor:
silu_temb = self.silu(temb)
scale, shift = self.scale_shift_linear(silu_temb).chunk(2, dim=1)
gate = self.gate_linear(silu_temb)
return hidden_states2 + gate[:, None, :] * \
self.mlp(self.norm(hidden_states) * (1 + scale[:, None, :]) + shift[:, None, :])
class ControlnetXs(ModelMixin, ConfigMixin):
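    """ControlNet-XS adapter for CogVideoX.

    Bundles a camera-condition VAE encoder, a small control transformer, and one down-projector plus
    one up-projector per control layer. `forward` drives the main transformer itself (passed in as
    `main_transformer`), injecting control features into, and reading them back from, every main
    transformer block whose index lies in the configured layer range.
    """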
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
model_path: str,
main_transformer_config=None,
):
super().__init__()
self.vae_encoder = ControlnetXsVaeEncoderCogVideoX.from_config(os.path.join(model_path, "controlnetxs_vae_encoder"))
self.transformer = CogVideoXTransformer3DControlnetXs.from_config(os.path.join(model_path, "controlnetxs_transformer"))
self.main_transformer_config = main_transformer_config
inner_dim_for_transformer = self.transformer.config.num_attention_heads * self.transformer.config.attention_head_dim
inner_dim_for_main_transformer = self.main_transformer_config.num_attention_heads * self.main_transformer_config.attention_head_dim
self.up_down_layer_start_idx = self.transformer.config.up_down_layer_start_idx
self.up_down_layer_end_idx = self.transformer.config.up_down_layer_end_idx
self.down_projectors = nn.ModuleList([
ZeroLayerNormDownProjector(
dim_in=inner_dim_for_main_transformer,
dim_out=inner_dim_for_transformer,
time_embed_dim=self.transformer.config.time_embed_dim,
dropout=self.transformer.config.dropout,
activation_fn=self.transformer.config.activation_fn,
norm_elementwise_affine=self.transformer.config.norm_elementwise_affine,
norm_eps=self.transformer.config.norm_eps,
) for _ in range(self.transformer.config.num_layers)
])
self.up_projectors = nn.ModuleList([
ZeroLayerNormUpProjector(
dim_in=inner_dim_for_transformer,
dim_out=inner_dim_for_main_transformer,
time_embed_dim=self.transformer.config.time_embed_dim,
dropout=self.transformer.config.dropout,
activation_fn=self.transformer.config.activation_fn,
norm_elementwise_affine=self.transformer.config.norm_elementwise_affine,
norm_eps=self.transformer.config.norm_eps,
) for _ in range(self.transformer.config.num_layers)
])
self.emb_projector = nn.Linear(self.main_transformer_config.time_embed_dim, self.transformer.config.time_embed_dim)
self.gradient_checkpointing = False
self.init_weights()
def init_weights(self):
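        """Xavier-initialize all trainable linear layers, then zero-initialize the projector
        scale/shift and gate linears, the control blocks' AdaLN linears, and the GFT-beta embedding
        output so that the control branch contributes nothing to the main branch at the start of
        training."""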
def _basic_init(module):
if isinstance(module, nn.Linear):
if module.weight.requires_grad:
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
for down_projector in self.down_projectors:
nn.init.constant_(down_projector.scale_shift_linear.weight, 0)
nn.init.constant_(down_projector.scale_shift_linear.bias, 0)
nn.init.constant_(down_projector.gate_linear.weight, 0)
nn.init.constant_(down_projector.gate_linear.bias, 0)
for up_projector in self.up_projectors:
nn.init.constant_(up_projector.scale_shift_linear.weight, 0)
nn.init.constant_(up_projector.scale_shift_linear.bias, 0)
nn.init.constant_(up_projector.gate_linear.weight, 0)
nn.init.constant_(up_projector.gate_linear.bias, 0)
for block in self.transformer.transformer_blocks:
nn.init.constant_(block.norm1.linear.weight, 0)
nn.init.constant_(block.norm1.linear.bias, 0)
nn.init.constant_(block.norm2.linear.weight, 0)
nn.init.constant_(block.norm2.linear.bias, 0)
nn.init.constant_(self.transformer.camera_condition_gft_beta_embedding.linear_2.weight, 0)
nn.init.constant_(self.transformer.camera_condition_gft_beta_embedding.linear_2.bias, 0)
def set_main_transformer(self, main_transformer):
self.main_transformer = main_transformer
def forward(
self,
hidden_states: torch.Tensor, # B, F//4, C, H//8, W//8
encoder_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None,
ofs: Optional[Union[int, float, torch.LongTensor]] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        plucker_embedding: Optional[torch.Tensor] = None,  # B, C, F, H, W
image_rotary_emb_for_controlnetxs: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
main_transformer=None,
camera_condition_gft_beta: Optional[Union[int, float, torch.LongTensor]] = None, # Guidance-Free Training
        camera_condition_dropout: float = 0.0,
):
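        """Run the main CogVideoX transformer with ControlNet-XS feature exchange.

        `plucker_embedding` is the per-pixel camera-ray (Plücker) condition, VAE-encoded to the
        latent grid before patchification; `camera_condition_gft_beta` is embedded and added to the
        control-branch time embedding (Guidance-Free Training); `camera_condition_dropout` randomly
        drops the camera condition per sample. Returns the main transformer's unpatchified output,
        as a tuple when `return_dict=False`.
        """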
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(main_transformer, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_frames, channels, height, width = hidden_states.shape
# 1. Time embedding
timesteps = timestep
t_emb = main_transformer.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = main_transformer.time_embedding(t_emb, timestep_cond)
if main_transformer.ofs_embedding is not None:
ofs_emb = main_transformer.ofs_proj(ofs)
ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
ofs_emb = main_transformer.ofs_embedding(ofs_emb)
emb = emb + ofs_emb
        # Control-branch time embedding: project the main time embedding to the control-branch width
emb_for_controlnetxs = self.emb_projector(emb)
if camera_condition_gft_beta is None:
camera_condition_gft_beta = torch.ones(timesteps.shape[0]).to(device=hidden_states.device)
camera_condition_gft_beta_emb = self.transformer.camera_condition_gft_beta_proj(camera_condition_gft_beta)
camera_condition_gft_beta_emb = camera_condition_gft_beta_emb.to(dtype=hidden_states.dtype)
camera_condition_gft_beta_emb = self.transformer.camera_condition_gft_beta_embedding(camera_condition_gft_beta_emb)
emb_for_controlnetxs = emb_for_controlnetxs + camera_condition_gft_beta_emb
plucker_embedding = plucker_embedding.to(self.vae_encoder.device, dtype=self.vae_encoder.dtype)
        latent_plucker_embedding_dist = self.vae_encoder.encode(plucker_embedding).latent_dist  # [B, C=6, F, H, W] -> [B, 128, F//4, H//8, W//8]
latent_plucker_embedding = latent_plucker_embedding_dist.sample()
        patch_size_t = main_transformer.config.patch_size_t
        if patch_size_t is not None:
            # Pad the temporal dimension up to a multiple of patch_size_t by repeating the first frame
            ncopy = (patch_size_t - latent_plucker_embedding.shape[2] % patch_size_t) % patch_size_t
            if ncopy > 0:
                first_frame = latent_plucker_embedding[:, :, :1, :, :]  # [B, C, 1, H, W]
                latent_plucker_embedding = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent_plucker_embedding], dim=2)
            assert latent_plucker_embedding.shape[2] % patch_size_t == 0
latent_plucker_embedding = latent_plucker_embedding.permute(0, 2, 1, 3, 4) # [B, C=128, F//4, H//8, W//8] to [B, F//4, C=128, H//8, W//8]
num_frames = latent_plucker_embedding.shape[1]
if camera_condition_dropout > 0.0:
drop_ids = torch.rand(latent_plucker_embedding.shape[0]).to(latent_plucker_embedding.device) <= camera_condition_dropout
latent_plucker_embedding = torch.where(
drop_ids[:, None, None, None, None],
0.0,
1.0
).to(latent_plucker_embedding.dtype) * latent_plucker_embedding
hidden_states_for_controlnetxs = self.transformer.patch_embed(latent_plucker_embedding)
hidden_states_for_controlnetxs = self.transformer.embedding_dropout(hidden_states_for_controlnetxs)
# 2. Patch embedding
hidden_states = main_transformer.patch_embed(encoder_hidden_states, hidden_states)
hidden_states = main_transformer.embedding_dropout(hidden_states)
text_seq_length = encoder_hidden_states.shape[1]
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
# 3. Transformer blocks
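        # For each main block i in [up_down_layer_start_idx, up_down_layer_end_idx]:
        #   1) down-project the main hidden states and add them to the control hidden states,
        #   2) run control block i on the control hidden states,
        #   3) up-project the control hidden states and add them back to the main hidden states,
        # then run main block i as usual; blocks outside the range only run the main branch.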
for i, block in enumerate(main_transformer.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                if self.up_down_layer_start_idx <= i <= self.up_down_layer_end_idx:
hidden_states_for_controlnetxs = self.down_projectors[i](
hidden_states, emb_for_controlnetxs, hidden_states_for_controlnetxs
)
hidden_states_for_controlnetxs = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.transformer.transformer_blocks[i]),
hidden_states_for_controlnetxs,
emb_for_controlnetxs,
image_rotary_emb_for_controlnetxs,
**ckpt_kwargs,
)
hidden_states = self.up_projectors[i](
hidden_states_for_controlnetxs, emb_for_controlnetxs, hidden_states
)
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
emb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
                if self.up_down_layer_start_idx <= i <= self.up_down_layer_end_idx:
hidden_states_for_controlnetxs = self.down_projectors[i](
hidden_states, emb_for_controlnetxs, hidden_states_for_controlnetxs,
)
hidden_states_for_controlnetxs = self.transformer.transformer_blocks[i](
hidden_states=hidden_states_for_controlnetxs,
temb=emb_for_controlnetxs,
image_rotary_emb=image_rotary_emb_for_controlnetxs,
)
hidden_states = self.up_projectors[i](
hidden_states_for_controlnetxs, emb_for_controlnetxs, hidden_states
)
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
)
if not main_transformer.config.use_rotary_positional_embeddings:
# CogVideoX-2B
hidden_states = main_transformer.norm_final(hidden_states)
else:
# CogVideoX-5B
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
hidden_states = main_transformer.norm_final(hidden_states)
hidden_states = hidden_states[:, text_seq_length:]
# 4. Final block
hidden_states = main_transformer.norm_out(hidden_states, temb=emb)
hidden_states = main_transformer.proj_out(hidden_states)
# 5. Unpatchify
p = main_transformer.config.patch_size
p_t = main_transformer.config.patch_size_t
if p_t is None:
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
else:
output = hidden_states.reshape(
batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(main_transformer, lora_scale)
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
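# Illustrative usage sketch (kept as a comment, not executed). The model path, tensor names and
# shapes below are placeholders/assumptions, not values shipped with this module:
#
#   controlnet = ControlnetXs(model_path="path/to/controlnetxs", main_transformer_config=main_transformer.config)
#   controlnet.set_main_transformer(main_transformer)
#   pred = controlnet(
#       hidden_states=latents,                       # [B, F//4, C, H//8, W//8] video latents
#       encoder_hidden_states=prompt_embeds,
#       timestep=timesteps,
#       image_rotary_emb=rope_main,
#       plucker_embedding=plucker,                   # [B, 6, F, H, W] camera-ray condition
#       image_rotary_emb_for_controlnetxs=rope_ctrl,
#       main_transformer=main_transformer,
#   )["sample"]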