Spaces:
Running
Running
# pylint: disable=R0801
# src/models/unet_3d_blocks.py
"""
This module defines various 3D UNet blocks used in the video model.

The blocks include:
- UNetMidBlock3DCrossAttn: The middle block of the UNet with cross attention.
- CrossAttnDownBlock3D: The downsampling block with cross attention.
- DownBlock3D: The standard downsampling block without cross attention.
- CrossAttnUpBlock3D: The upsampling block with cross attention.
- UpBlock3D: The standard upsampling block without cross attention.

These blocks are used to construct the 3D UNet architecture for video-related tasks.
"""
import torch | |
from einops import rearrange | |
from torch import nn | |
from .motion_module import get_motion_module | |
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D | |
from .transformer_3d import Transformer3DModel | |
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    audio_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    unet_use_cross_frame_attention=None,
    unet_use_temporal_attention=None,
    use_inflated_groupnorm=None,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
    use_audio_module=None,
    depth=0,
    stack_enable_blocks_name=None,
    stack_enable_blocks_depth=None,
):
    """
    Build the down block named by ``down_block_type``.

    A leading ``"UNetRes"`` prefix on the name is ignored. Two names are
    recognised: ``"DownBlock3D"`` and ``"CrossAttnDownBlock3D"``.

    Parameters:
    - down_block_type (str): Name of the block class to instantiate.
    - num_layers, in_channels, out_channels, temb_channels, add_downsample,
      resnet_eps, resnet_act_fn, resnet_groups, downsample_padding,
      resnet_time_scale_shift, use_inflated_groupnorm, use_motion_module,
      motion_module_type, motion_module_kwargs: Forwarded unchanged to
      whichever block is built.
    - cross_attention_dim: Required (must not be None) for
      ``CrossAttnDownBlock3D``; unused for ``DownBlock3D``.
    - attn_num_head_channels, audio_attention_dim, dual_cross_attention,
      use_linear_projection, only_cross_attention, upcast_attention,
      unet_use_cross_frame_attention, unet_use_temporal_attention,
      use_audio_module, depth, stack_enable_blocks_name,
      stack_enable_blocks_depth: Forwarded only to ``CrossAttnDownBlock3D``.

    Returns:
    - nn.Module: The constructed ``DownBlock3D`` or ``CrossAttnDownBlock3D``.

    Raises:
    - ValueError: If the block name is not recognised, or if
      ``CrossAttnDownBlock3D`` is requested without ``cross_attention_dim``.
    """
    legacy_prefix = "UNetRes"
    if down_block_type.startswith(legacy_prefix):
        down_block_type = down_block_type[len(legacy_prefix):]

    # Reject unknown names up front so the rest of the function only has to
    # distinguish between the two supported blocks.
    if down_block_type not in ("DownBlock3D", "CrossAttnDownBlock3D"):
        raise ValueError(f"{down_block_type} does not exist.")

    # Keyword arguments accepted by both block classes.
    shared_kwargs = {
        "num_layers": num_layers,
        "in_channels": in_channels,
        "out_channels": out_channels,
        "temb_channels": temb_channels,
        "add_downsample": add_downsample,
        "resnet_eps": resnet_eps,
        "resnet_act_fn": resnet_act_fn,
        "resnet_groups": resnet_groups,
        "downsample_padding": downsample_padding,
        "resnet_time_scale_shift": resnet_time_scale_shift,
        "use_inflated_groupnorm": use_inflated_groupnorm,
        "use_motion_module": use_motion_module,
        "motion_module_type": motion_module_type,
        "motion_module_kwargs": motion_module_kwargs,
    }

    if down_block_type == "DownBlock3D":
        return DownBlock3D(**shared_kwargs)

    # Only CrossAttnDownBlock3D remains.
    if cross_attention_dim is None:
        raise ValueError(
            "cross_attention_dim must be specified for CrossAttnDownBlock3D"
        )
    return CrossAttnDownBlock3D(
        cross_attention_dim=cross_attention_dim,
        audio_attention_dim=audio_attention_dim,
        attn_num_head_channels=attn_num_head_channels,
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        only_cross_attention=only_cross_attention,
        upcast_attention=upcast_attention,
        unet_use_cross_frame_attention=unet_use_cross_frame_attention,
        unet_use_temporal_attention=unet_use_temporal_attention,
        use_audio_module=use_audio_module,
        depth=depth,
        stack_enable_blocks_name=stack_enable_blocks_name,
        stack_enable_blocks_depth=stack_enable_blocks_depth,
        **shared_kwargs,
    )
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    audio_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    unet_use_cross_frame_attention=None,
    unet_use_temporal_attention=None,
    use_inflated_groupnorm=None,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
    use_audio_module=None,
    depth=0,
    stack_enable_blocks_name=None,
    stack_enable_blocks_depth=None,
):
    """
    Build the up block named by ``up_block_type``.

    A leading ``"UNetRes"`` prefix on the name is ignored. Two names are
    recognised: ``"UpBlock3D"`` and ``"CrossAttnUpBlock3D"``.

    Parameters:
    - up_block_type (str): Name of the block class to instantiate.
    - num_layers, in_channels, out_channels, prev_output_channel,
      temb_channels, add_upsample, resnet_eps, resnet_act_fn, resnet_groups,
      resnet_time_scale_shift, use_inflated_groupnorm, use_motion_module,
      motion_module_type, motion_module_kwargs: Forwarded unchanged to
      whichever block is built.
    - cross_attention_dim: Required (must not be None) for
      ``CrossAttnUpBlock3D``; unused for ``UpBlock3D``.
    - attn_num_head_channels, audio_attention_dim, dual_cross_attention,
      use_linear_projection, only_cross_attention, upcast_attention,
      unet_use_cross_frame_attention, unet_use_temporal_attention,
      use_audio_module, depth, stack_enable_blocks_name,
      stack_enable_blocks_depth: Forwarded only to ``CrossAttnUpBlock3D``.

    Returns:
    - nn.Module: The constructed ``UpBlock3D`` or ``CrossAttnUpBlock3D``.

    Raises:
    - ValueError: If the block name is not recognised, or if
      ``CrossAttnUpBlock3D`` is requested without ``cross_attention_dim``.
    """
    legacy_prefix = "UNetRes"
    if up_block_type.startswith(legacy_prefix):
        up_block_type = up_block_type[len(legacy_prefix):]

    # Reject unknown names up front so the rest of the function only has to
    # distinguish between the two supported blocks.
    if up_block_type not in ("UpBlock3D", "CrossAttnUpBlock3D"):
        raise ValueError(f"{up_block_type} does not exist.")

    # Keyword arguments accepted by both block classes.
    shared_kwargs = {
        "num_layers": num_layers,
        "in_channels": in_channels,
        "out_channels": out_channels,
        "prev_output_channel": prev_output_channel,
        "temb_channels": temb_channels,
        "add_upsample": add_upsample,
        "resnet_eps": resnet_eps,
        "resnet_act_fn": resnet_act_fn,
        "resnet_groups": resnet_groups,
        "resnet_time_scale_shift": resnet_time_scale_shift,
        "use_inflated_groupnorm": use_inflated_groupnorm,
        "use_motion_module": use_motion_module,
        "motion_module_type": motion_module_type,
        "motion_module_kwargs": motion_module_kwargs,
    }

    if up_block_type == "UpBlock3D":
        return UpBlock3D(**shared_kwargs)

    # Only CrossAttnUpBlock3D remains.
    if cross_attention_dim is None:
        raise ValueError(
            "cross_attention_dim must be specified for CrossAttnUpBlock3D"
        )
    return CrossAttnUpBlock3D(
        cross_attention_dim=cross_attention_dim,
        audio_attention_dim=audio_attention_dim,
        attn_num_head_channels=attn_num_head_channels,
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        only_cross_attention=only_cross_attention,
        upcast_attention=upcast_attention,
        unet_use_cross_frame_attention=unet_use_cross_frame_attention,
        unet_use_temporal_attention=unet_use_temporal_attention,
        use_audio_module=use_audio_module,
        depth=depth,
        stack_enable_blocks_name=stack_enable_blocks_name,
        stack_enable_blocks_depth=stack_enable_blocks_depth,
        **shared_kwargs,
    )
class UNetMidBlock3DCrossAttn(nn.Module):
    """
    Middle block of the 3D UNet with cross attention.

    The block holds ``num_layers + 1`` ``ResnetBlock3D`` modules. Between each
    consecutive pair it holds one ``Transformer3DModel`` (cross attention) and,
    optionally, one audio ``Transformer3DModel`` and one motion module.
    The channel count is the same (``in_channels``) throughout.

    Parameters:
    - in_channels (int): Channel count of the input and of every sub-module.
    - temb_channels (int): Passed to ``ResnetBlock3D`` as ``temb_channels``.
    - dropout (float): Passed to ``ResnetBlock3D``.
    - num_layers (int): Number of (attention, audio, motion, resnet) groups.
    - resnet_eps (float): Passed to ``ResnetBlock3D`` as ``eps``.
    - resnet_time_scale_shift (str): Passed to ``ResnetBlock3D`` as
      ``time_embedding_norm``.
    - resnet_act_fn (str): Passed to ``ResnetBlock3D`` as ``non_linearity``.
    - resnet_groups (int): Passed to ``ResnetBlock3D`` as ``groups`` and to
      the transformers as ``norm_num_groups``. ``None`` selects
      ``min(in_channels // 4, 32)``.
    - resnet_pre_norm (bool): Passed to ``ResnetBlock3D`` as ``pre_norm``.
    - attn_num_head_channels (int): First positional argument of each
      ``Transformer3DModel``; the second is
      ``in_channels // attn_num_head_channels``.
    - output_scale_factor (float): Passed to ``ResnetBlock3D``.
    - cross_attention_dim (int): ``cross_attention_dim`` of the attention
      transformers.
    - audio_attention_dim (int): ``cross_attention_dim`` of the audio
      transformers.
    - dual_cross_attention (bool): Not supported; a true value raises
      ``NotImplementedError`` when ``num_layers > 0``.
    - use_linear_projection, upcast_attention: Forwarded to every
      ``Transformer3DModel``.
    - unet_use_cross_frame_attention, unet_use_temporal_attention: Forwarded
      to the attention transformers only.
    - use_inflated_groupnorm: Forwarded to ``ResnetBlock3D``.
    - use_motion_module (bool): When truthy, a motion module is created per
      layer via ``get_motion_module``.
    - motion_module_type, motion_module_kwargs: Forwarded to
      ``get_motion_module``.
    - use_audio_module (bool): When truthy, an audio transformer is created
      per layer.
    - depth, stack_enable_blocks_name, stack_enable_blocks_depth: Forwarded
      to the audio transformers (together with ``unet_block_name="mid"``).

    Forward method:
    Runs the first resnet, then for every layer: attention, optional audio
    module, optional motion module, resnet. Returns the resulting tensor.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        audio_attention_dim=1024,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        use_audio_module=None,
        depth=0,
        stack_enable_blocks_name=None,
        stack_enable_blocks_depth=None,
    ):
        super().__init__()
        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels
        resnet_groups = (
            resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        )

        def build_resnet():
            # Every resnet in this block is configured identically.
            return ResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                use_inflated_groupnorm=use_inflated_groupnorm,
            )

        # There is always at least one resnet.
        resnets = [build_resnet()]
        attentions = []
        motion_modules = []
        audio_modules = []

        # Sub-modules are created in the same interleaved order for every
        # layer (attention, audio, motion, resnet) so that seeded parameter
        # initialisation stays reproducible.
        for _ in range(num_layers):
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            audio_modules.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=audio_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                    use_audio_module=use_audio_module,
                    depth=depth,
                    unet_block_name="mid",
                    stack_enable_blocks_name=stack_enable_blocks_name,
                    stack_enable_blocks_depth=stack_enable_blocks_depth,
                )
                if use_audio_module
                else None
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=in_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )
            resnets.append(build_resnet())

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.audio_modules = nn.ModuleList(audio_modules)
        self.motion_modules = nn.ModuleList(motion_modules)

    def forward(
        self,
        hidden_states,
        temb=None,
        encoder_hidden_states=None,
        attention_mask=None,
        full_mask=None,
        face_mask=None,
        lip_mask=None,
        audio_embedding=None,
        motion_scale=None,
    ):
        """
        Forward pass for the UNetMidBlock3DCrossAttn class.

        Args:
            hidden_states (Tensor): Input tensor. It is indexed with
                ``shape[0]`` .. ``shape[4]``, so it must be 5-D.
            temb (Tensor, optional): Embedding handed to every resnet.
                Defaults to None.
            encoder_hidden_states (Tensor, optional): Conditioning passed to
                the attention transformers and the motion modules.
                Defaults to None.
            attention_mask (Tensor, optional): Passed to the audio modules.
                Defaults to None.
            full_mask (Tensor, optional): Passed to the audio modules.
                Defaults to None.
            face_mask (Tensor, optional): Passed to the audio modules.
                Defaults to None.
            lip_mask (Tensor, optional): Passed to the audio modules.
                Defaults to None.
            audio_embedding (Tensor, optional): Passed to the audio modules
                as ``encoder_hidden_states``. Defaults to None.
            motion_scale (optional): Passed to the audio modules.
                Defaults to None.

        Returns:
            Tensor: The output tensor after passing through the
            UNetMidBlock3DCrossAttn layers.
        """
        hidden_states = self.resnets[0](hidden_states, temb)

        for attn, resnet, audio_module, motion_module in zip(
            self.attentions, self.resnets[1:], self.audio_modules, self.motion_modules
        ):
            hidden_states, motion_frame = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                return_dict=False,
            )

            if len(motion_frame[0]) > 0:
                # Unflatten the token axis back into two spatial axes.
                motion_frames = rearrange(
                    motion_frame[0][0],
                    "b f (d1 d2) c -> b c f d1 d2",
                    d1=hidden_states.size(-1),
                )
            else:
                # No motion frames were returned: use an all-zero placeholder.
                # It is allocated directly with the device and dtype of
                # ``hidden_states`` to avoid a CPU allocation plus copy.
                # NOTE(review): the hard-coded 4 is presumably the expected
                # number of motion frames - confirm.
                motion_frames = hidden_states.new_zeros(
                    (
                        hidden_states.shape[0],
                        hidden_states.shape[1],
                        4,
                        hidden_states.shape[3],
                        hidden_states.shape[4],
                    )
                )
            n_motion_frames = motion_frames.size(2)

            if audio_module is not None:
                hidden_states = audio_module(
                    hidden_states,
                    encoder_hidden_states=audio_embedding,
                    attention_mask=attention_mask,
                    full_mask=full_mask,
                    face_mask=face_mask,
                    lip_mask=lip_mask,
                    motion_scale=motion_scale,
                    return_dict=False,
                )[0]

            if motion_module is not None:
                motion_frames = motion_frames.to(
                    device=hidden_states.device, dtype=hidden_states.dtype
                )
                # Prepend the motion frames along dim 2, run the motion
                # module, then drop the prepended frames again.
                _hidden_states = (
                    torch.cat([motion_frames, hidden_states], dim=2)
                    if n_motion_frames > 0
                    else hidden_states
                )
                hidden_states = motion_module(
                    _hidden_states, encoder_hidden_states=encoder_hidden_states
                )
                hidden_states = hidden_states[:, :, n_motion_frames:]

            hidden_states = resnet(hidden_states, temb)

        return hidden_states
class CrossAttnDownBlock3D(nn.Module):
    """
    A 3D downsampling block with cross attention for the U-Net architecture.

    Each of the ``num_layers`` layers consists of a ``ResnetBlock3D``, a
    cross-attention ``Transformer3DModel`` and, optionally, an audio
    ``Transformer3DModel`` and a motion module. An optional ``Downsample3D``
    follows the last layer.

    Parameters:
    - in_channels (int): Input channels of the first resnet.
    - out_channels (int): Output channels of every resnet (and input channels
      of every resnet after the first).
    - temb_channels, dropout, resnet_eps, resnet_time_scale_shift,
      resnet_act_fn, resnet_groups, resnet_pre_norm, output_scale_factor,
      use_inflated_groupnorm: Forwarded to ``ResnetBlock3D``.
    - num_layers (int): Number of layers in the block.
    - attn_num_head_channels (int): First positional argument of each
      ``Transformer3DModel``.
    - cross_attention_dim (int): ``cross_attention_dim`` of the attention
      transformers.
    - audio_attention_dim (int): ``cross_attention_dim`` of the audio
      transformers.
    - downsample_padding (int): ``padding`` of the ``Downsample3D`` layer.
    - add_downsample (bool): Whether to append a ``Downsample3D`` layer.
    - dual_cross_attention (bool): Not supported; a true value raises
      ``NotImplementedError`` when ``num_layers > 0``.
    - use_linear_projection, only_cross_attention, upcast_attention:
      Forwarded to every ``Transformer3DModel``.
    - unet_use_cross_frame_attention, unet_use_temporal_attention: Forwarded
      to the attention transformers only.
    - use_motion_module, motion_module_type, motion_module_kwargs: Control
      creation of the per-layer motion module via ``get_motion_module``.
    - use_audio_module (bool): When truthy, an audio transformer is created
      per layer.
    - depth, stack_enable_blocks_name, stack_enable_blocks_depth: Forwarded
      to the audio transformers (together with ``unet_block_name="down"``).

    Forward method:
    Runs every layer, collecting the output of each layer (and of the
    downsampler, if present). Gradient checkpointing is used when the module
    is in training mode and ``gradient_checkpointing`` is set.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        audio_attention_dim=1024,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        use_audio_module=None,
        depth=0,
        stack_enable_blocks_name=None,
        stack_enable_blocks_depth=None,
    ):
        super().__init__()
        resnets = []
        attentions = []
        audio_modules = []
        motion_modules = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            # Only the first resnet sees the block's input channel count.
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            # TODO: check dimensions.
            # NOTE(review): the second positional argument is derived from
            # ``in_channels`` here, whereas ``attentions`` above derives it
            # from ``out_channels``. Left unchanged because altering it may
            # change parameter shapes of existing checkpoints - confirm.
            audio_modules.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=audio_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    use_audio_module=use_audio_module,
                    depth=depth,
                    unet_block_name="down",
                    stack_enable_blocks_name=stack_enable_blocks_name,
                    stack_enable_blocks_depth=stack_enable_blocks_depth,
                )
                if use_audio_module
                else None
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.audio_modules = nn.ModuleList(audio_modules)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    @staticmethod
    def _checkpoint_fn(module, return_dict=None):
        """Wrap ``module`` as a positional-only callable for checkpointing.

        When ``return_dict`` is not None it is passed to ``module`` as a
        keyword argument on every call.
        """

        def custom_forward(*inputs):
            if return_dict is not None:
                return module(*inputs, return_dict=return_dict)
            return module(*inputs)

        return custom_forward

    def forward(
        self,
        hidden_states,
        temb=None,
        encoder_hidden_states=None,
        attention_mask=None,
        full_mask=None,
        face_mask=None,
        lip_mask=None,
        audio_embedding=None,
        motion_scale=None,
    ):
        """
        Defines the forward pass for the CrossAttnDownBlock3D class.

        Parameters:
        - hidden_states (torch.Tensor): The input tensor to the block.
        - temb (torch.Tensor, optional): Embedding handed to every resnet.
        - encoder_hidden_states (torch.Tensor, optional): Conditioning passed
          to the attention transformers and the motion modules.
        - attention_mask (torch.Tensor, optional): Passed to the audio modules.
        - full_mask (torch.Tensor, optional): Passed to the audio modules.
        - face_mask (torch.Tensor, optional): Passed to the audio modules.
        - lip_mask (torch.Tensor, optional): Passed to the audio modules.
        - audio_embedding (torch.Tensor, optional): Passed to the audio
          modules as their conditioning input.
        - motion_scale (optional): Passed to the audio modules, in both the
          checkpointed and the regular path.

        Returns:
        - tuple: ``(hidden_states, output_states)`` where ``output_states`` is
          a tuple holding the output of every layer followed by the
          downsampler output when a downsampler is present.
        """
        output_states = ()

        for resnet, attn, audio_module, motion_module in zip(
            self.resnets, self.attentions, self.audio_modules, self.motion_modules
        ):
            if self.training and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    self._checkpoint_fn(resnet), hidden_states, temb
                )
                hidden_states, motion_frame = torch.utils.checkpoint.checkpoint(
                    self._checkpoint_fn(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )

                if len(motion_frame[0]) > 0:
                    # Unflatten the token axis back into two spatial axes.
                    motion_frames = rearrange(
                        motion_frame[0][0],
                        "b f (d1 d2) c -> b c f d1 d2",
                        d1=hidden_states.size(-1),
                    )
                else:
                    # No motion frames were returned: use an all-zero
                    # placeholder, allocated directly with the device and
                    # dtype of ``hidden_states``.
                    # NOTE(review): the hard-coded 4 is presumably the
                    # expected number of motion frames - confirm.
                    motion_frames = hidden_states.new_zeros(
                        (
                            hidden_states.shape[0],
                            hidden_states.shape[1],
                            4,
                            hidden_states.shape[3],
                            hidden_states.shape[4],
                        )
                    )
                n_motion_frames = motion_frames.size(2)

                if audio_module is not None:
                    # Checkpointed callables take positional inputs only.
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        self._checkpoint_fn(audio_module, return_dict=False),
                        hidden_states,
                        audio_embedding,
                        attention_mask,
                        full_mask,
                        face_mask,
                        lip_mask,
                        motion_scale,
                    )[0]

                if motion_module is not None:
                    motion_frames = motion_frames.to(
                        device=hidden_states.device, dtype=hidden_states.dtype
                    )
                    # Prepend the motion frames along dim 2, run the motion
                    # module, then drop the prepended frames again.
                    _hidden_states = torch.cat([motion_frames, hidden_states], dim=2)
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        self._checkpoint_fn(motion_module),
                        _hidden_states,
                        encoder_hidden_states,
                    )
                    hidden_states = hidden_states[:, :, n_motion_frames:]
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                ).sample

                if audio_module is not None:
                    # ``motion_scale`` is forwarded here as well so that this
                    # path agrees with the checkpointed path above; it used
                    # to be silently dropped.
                    hidden_states = audio_module(
                        hidden_states,
                        audio_embedding,
                        attention_mask=attention_mask,
                        full_mask=full_mask,
                        face_mask=face_mask,
                        lip_mask=lip_mask,
                        motion_scale=motion_scale,
                        return_dict=False,
                    )[0]

                if motion_module is not None:
                    hidden_states = motion_module(
                        hidden_states, encoder_hidden_states=encoder_hidden_states
                    )

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            output_states += (hidden_states,)

        return hidden_states, output_states
class DownBlock3D(nn.Module):
    """
    A 3D downsampling block for the U-Net architecture without attention.

    Each of the ``num_layers`` layers is a ``ResnetBlock3D`` optionally
    followed by a motion module. An optional ``Downsample3D`` follows the
    last layer.

    Parameters:
    - in_channels (int): Input channels of the first resnet.
    - out_channels (int): Output channels of every resnet (and input channels
      of every resnet after the first).
    - temb_channels (int): Passed to ``ResnetBlock3D`` as ``temb_channels``.
    - dropout (float): Passed to ``ResnetBlock3D``.
    - num_layers (int): Number of layers in the block.
    - resnet_eps (float): Passed to ``ResnetBlock3D`` as ``eps``.
    - resnet_time_scale_shift (str): Passed to ``ResnetBlock3D`` as
      ``time_embedding_norm``.
    - resnet_act_fn (str): Passed to ``ResnetBlock3D`` as ``non_linearity``.
    - resnet_groups (int): Passed to ``ResnetBlock3D`` as ``groups``.
    - resnet_pre_norm (bool): Passed to ``ResnetBlock3D`` as ``pre_norm``.
    - output_scale_factor (float): Passed to ``ResnetBlock3D``.
    - add_downsample (bool): Whether to append a ``Downsample3D`` layer.
    - downsample_padding (int): ``padding`` of the ``Downsample3D`` layer.
    - use_inflated_groupnorm (bool): Passed to ``ResnetBlock3D``.
    - use_motion_module (bool): When truthy, a motion module is created per
      layer via ``get_motion_module``.
    - motion_module_type (str): Forwarded to ``get_motion_module``.
    - motion_module_kwargs (dict): Forwarded to ``get_motion_module``.

    Forward method:
    Runs every resnet (through gradient checkpointing when training with
    ``gradient_checkpointing`` set), applies the layer's motion module if it
    exists, and finally the downsampler if present. The output of each stage
    is collected and returned alongside the final tensor.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()

        resnet_layers = []
        motion_layers = []
        for layer_idx in range(num_layers):
            # Only the first resnet sees the block's input channel count.
            layer_in_channels = in_channels if layer_idx == 0 else out_channels

            # The resnet is built before the motion module of the same layer
            # so that the order of parameter initialisation is preserved.
            resnet_layers.append(
                ResnetBlock3D(
                    in_channels=layer_in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )

            if use_motion_module:
                motion_layer = get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
            else:
                motion_layer = None
            motion_layers.append(motion_layer)

        self.resnets = nn.ModuleList(resnet_layers)
        self.motion_modules = nn.ModuleList(motion_layers)

        self.downsamplers = None
        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                    )
                ]
            )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        temb=None,
        encoder_hidden_states=None,
    ):
        """
        Forward method for the DownBlock3D class.

        Args:
            hidden_states (Tensor): The input tensor to the block.
            temb (Tensor, optional): Embedding handed to every resnet.
            encoder_hidden_states (Tensor, optional): Passed to the motion
                modules as ``encoder_hidden_states``.

        Returns:
            tuple: ``(hidden_states, output_states)`` where ``output_states``
            is a tuple holding the output of every layer followed by the
            downsampler output when a downsampler is present.
        """
        collected = []

        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            if self.training and self.gradient_checkpointing:
                # The resnet is called with exactly (hidden_states, temb).
                hidden_states = torch.utils.checkpoint.checkpoint(
                    resnet, hidden_states, temb
                )
            else:
                hidden_states = resnet(hidden_states, temb)

            if motion_module is not None:
                hidden_states = motion_module(
                    hidden_states, encoder_hidden_states=encoder_hidden_states
                )

            collected.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            collected.append(hidden_states)

        return hidden_states, tuple(collected)
class CrossAttnUpBlock3D(nn.Module):
    """
    3D upsampling block with cross attention for the U-Net architecture. Every layer concatenates
    one skip connection onto the hidden states, runs a residual block and a cross-attention
    transformer, and then applies the optional audio and motion modules. An optional upsampler
    is applied after the last layer.

    Parameters:
    - in_channels (int): Channel count of the skip connection consumed by the last layer. It is
      also used to derive the head dimension of the audio attention module.
    - out_channels (int): Number of output channels of every layer.
    - prev_output_channel (int): Channel count of the hidden states entering the first layer.
    - temb_channels (int): Number of channels of the embedding given to the residual blocks.
    - dropout (float): Dropout rate forwarded to the residual blocks.
    - num_layers (int): Number of (resnet, attention, audio, motion) layers in the block.
    - resnet_eps (float): Epsilon forwarded to the residual blocks.
    - resnet_time_scale_shift (str): Forwarded to the residual blocks as `time_embedding_norm`.
    - resnet_act_fn (str): Forwarded to the residual blocks as `non_linearity`.
    - resnet_groups (int): Group count forwarded to the residual blocks and, as
      `norm_num_groups`, to the attention modules.
    - resnet_pre_norm (bool): Forwarded to the residual blocks as `pre_norm`.
    - attn_num_head_channels (int): First positional argument of every Transformer3DModel
      (presumably the number of attention heads).
    - cross_attention_dim (int): `cross_attention_dim` of the cross-attention transformer.
    - audio_attention_dim (int): `cross_attention_dim` of the audio attention transformer.
    - output_scale_factor (float): Forwarded to the residual blocks.
    - add_upsample (bool): Whether to append an Upsample3D layer.
    - dual_cross_attention (bool): Not supported; raises NotImplementedError when a layer is built.
    - use_linear_projection (bool): Forwarded to the attention modules.
    - only_cross_attention (bool): Forwarded to the attention modules.
    - upcast_attention (bool): Forwarded to the attention modules.
    - unet_use_cross_frame_attention (bool): Forwarded to the cross-attention transformer.
    - unet_use_temporal_attention (bool): Forwarded to the cross-attention transformer.
    - use_motion_module (bool): Whether to create a motion module for every layer.
    - use_inflated_groupnorm (bool): Forwarded to the residual blocks.
    - motion_module_type (str): Forwarded to get_motion_module.
    - motion_module_kwargs (dict): Forwarded to get_motion_module.
    - use_audio_module (bool): Whether to create an audio attention module for every layer.
    - depth (int): Forwarded to the audio attention module.
    - stack_enable_blocks_name (str): Forwarded to the audio attention module.
    - stack_enable_blocks_depth (int): Forwarded to the audio attention module.

    Forward method:
    The forward method consumes one residual tensor per layer (last entry first), processes the
    concatenated tensor through the residual, attention, audio and motion modules and finally
    through the optional upsampler. It supports gradient checkpointing during training.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        audio_attention_dim=1024,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_motion_module=None,
        use_inflated_groupnorm=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        use_audio_module=None,
        depth=0,
        stack_enable_blocks_name=None,
        stack_enable_blocks_depth=None,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        resnets = []
        attentions = []
        audio_modules = []
        motion_modules = []

        for i in range(num_layers):
            # The last layer consumes the skip connection that has `in_channels` channels,
            # all earlier layers consume skips with `out_channels` channels.
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # Only the first layer sees the output of the previous block.
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )

            if dual_cross_attention:
                raise NotImplementedError

            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )

            # NOTE(review): the second argument is derived from `in_channels` here while the
            # cross-attention module above uses `out_channels`. Left untouched because changing
            # it would change parameter shapes of existing checkpoints - confirm it is intended.
            audio_modules.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=audio_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    use_audio_module=use_audio_module,
                    depth=depth,
                    unet_block_name="up",
                    stack_enable_blocks_name=stack_enable_blocks_name,
                    stack_enable_blocks_depth=stack_enable_blocks_depth,
                )
                if use_audio_module
                else None
            )

            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        # Disabled audio/motion modules are stored as None entries so that all four lists keep
        # the same length and can be zipped in forward().
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.audio_modules = nn.ModuleList(audio_modules)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
            )
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    @staticmethod
    def _checkpointable(module, return_dict=None):
        """Wrap `module` in a positional-only callable usable with torch.utils.checkpoint."""

        def custom_forward(*inputs):
            if return_dict is not None:
                return module(*inputs, return_dict=return_dict)
            return module(*inputs)

        return custom_forward

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
        full_mask=None,
        face_mask=None,
        lip_mask=None,
        audio_embedding=None,
        motion_scale=None,
    ):
        """
        Forward pass for the CrossAttnUpBlock3D class.

        Args:
            self (CrossAttnUpBlock3D): An instance of the CrossAttnUpBlock3D class.
            hidden_states (Tensor): The input hidden states tensor.
            res_hidden_states_tuple (Tuple[Tensor]): Residual (skip) tensors; one is consumed per
                layer, starting from the last entry.
            temb (Tensor, optional): Embedding passed to every residual block. Defaults to None.
            encoder_hidden_states (Tensor, optional): Conditioning passed to the cross-attention
                and motion modules. Defaults to None.
            upsample_size (int, optional): Passed to the upsampler. Defaults to None.
            attention_mask (Tensor, optional): Passed to the audio module. Defaults to None.
            full_mask (Tensor, optional): Passed to the audio module. Defaults to None.
            face_mask (Tensor, optional): Passed to the audio module. Defaults to None.
            lip_mask (Tensor, optional): Passed to the audio module. Defaults to None.
            audio_embedding (Tensor, optional): Conditioning of the audio module. Defaults to None.
            motion_scale (optional): Passed to the audio module on the gradient-checkpointing
                path only. Defaults to None.

        Returns:
            Tensor: The output tensor after passing through the CrossAttnUpBlock3D.
        """
        use_checkpoint = self.training and self.gradient_checkpointing
        checkpoint = torch.utils.checkpoint.checkpoint

        for resnet, attn, audio_module, motion_module in zip(
            self.resnets, self.attentions, self.audio_modules, self.motion_modules
        ):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if use_checkpoint:
                hidden_states = checkpoint(
                    self._checkpointable(resnet), hidden_states, temb
                )

                hidden_states, motion_frame = checkpoint(
                    self._checkpointable(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )

                if len(motion_frame[0]) > 0:
                    # NOTE(review): `d1` is taken from the last axis of hidden_states; if the
                    # flattened axis is laid out as (height width) this is only right for
                    # square feature maps - confirm against Transformer3DModel.
                    motion_frames = rearrange(
                        motion_frame[0][0],
                        "b f (d1 d2) c -> b c f d1 d2",
                        d1=hidden_states.size(-1),
                    )
                else:
                    # No motion frames were returned: use 4 all-zero placeholder frames
                    # (the count is hard-coded).
                    motion_frames = torch.zeros(
                        hidden_states.shape[0],
                        hidden_states.shape[1],
                        4,
                        hidden_states.shape[3],
                        hidden_states.shape[4],
                    )

                n_motion_frames = motion_frames.size(2)

                if audio_module is not None:
                    hidden_states = checkpoint(
                        self._checkpointable(audio_module, return_dict=False),
                        hidden_states,
                        audio_embedding,
                        attention_mask,
                        full_mask,
                        face_mask,
                        lip_mask,
                        motion_scale,
                    )[0]

                # add motion module
                if motion_module is not None:
                    motion_frames = motion_frames.to(
                        device=hidden_states.device, dtype=hidden_states.dtype
                    )
                    # Prepend the motion frames along the frame axis, run the motion module,
                    # then strip the prepended frames again.
                    _hidden_states = (
                        torch.cat([motion_frames, hidden_states], dim=2)
                        if n_motion_frames > 0
                        else hidden_states
                    )
                    hidden_states = checkpoint(
                        self._checkpointable(motion_module),
                        _hidden_states,
                        encoder_hidden_states,
                    )
                    hidden_states = hidden_states[:, :, n_motion_frames:]
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                ).sample

                if audio_module is not None:
                    # NOTE(review): unlike the checkpointed path above, `motion_scale` is not
                    # forwarded here, so it has no effect outside checkpointed training -
                    # confirm whether that is intended.
                    hidden_states = audio_module(
                        hidden_states,
                        encoder_hidden_states=audio_embedding,
                        attention_mask=attention_mask,
                        full_mask=full_mask,
                        face_mask=face_mask,
                        lip_mask=lip_mask,
                    ).sample

                # add motion module
                if motion_module is not None:
                    hidden_states = motion_module(
                        hidden_states, encoder_hidden_states=encoder_hidden_states
                    )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
class UpBlock3D(nn.Module):
    """
    Standard 3D upsampling block for the U-Net architecture (no cross attention). Every layer
    concatenates one skip connection onto the hidden states and runs a residual block followed
    by an optional motion module. An optional upsampler is applied after the last layer.

    Parameters:
    - in_channels (int): Channel count of the skip connection consumed by the last layer.
    - prev_output_channel (int): Channel count of the hidden states entering the first layer.
    - out_channels (int): Number of output channels of every layer.
    - temb_channels (int): Number of channels of the embedding given to the residual blocks.
    - dropout (float): Dropout rate forwarded to the residual blocks.
    - num_layers (int): Number of (resnet, motion module) layers in the block.
    - resnet_eps (float): Epsilon forwarded to the residual blocks.
    - resnet_time_scale_shift (str): Forwarded to the residual blocks as `time_embedding_norm`.
    - resnet_act_fn (str): Forwarded to the residual blocks as `non_linearity`.
    - resnet_groups (int): Group count forwarded to the residual blocks.
    - resnet_pre_norm (bool): Forwarded to the residual blocks as `pre_norm`.
    - output_scale_factor (float): Forwarded to the residual blocks.
    - add_upsample (bool): Whether to append an Upsample3D layer.
    - use_inflated_groupnorm (bool): Forwarded to the residual blocks.
    - use_motion_module (bool): Whether to create a motion module for every layer.
    - motion_module_type (str): Forwarded to get_motion_module.
    - motion_module_kwargs (dict): Forwarded to get_motion_module.

    Forward method:
    The forward method consumes one residual tensor per layer (last entry first), processes the
    concatenated tensor through the residual block and the optional motion module, and finally
    through the optional upsampler. It supports gradient checkpointing during training.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()

        resnets = []
        motion_modules = []

        for i in range(num_layers):
            # The last layer consumes the skip connection that has `in_channels` channels,
            # all earlier layers consume skips with `out_channels` channels.
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            # Only the first layer sees the output of the previous block.
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        # Disabled motion modules are stored as None entries so that both lists keep the same
        # length and can be zipped in forward().
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
            )
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    @staticmethod
    def _checkpointable(module):
        """Wrap `module` in a positional-only callable usable with torch.utils.checkpoint."""

        def custom_forward(*inputs):
            return module(*inputs)

        return custom_forward

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        upsample_size=None,
        encoder_hidden_states=None,
    ):
        """
        Forward pass for the UpBlock3D class.

        Args:
            self (UpBlock3D): An instance of the UpBlock3D class.
            hidden_states (Tensor): The input hidden states tensor.
            res_hidden_states_tuple (Tuple[Tensor]): Residual (skip) tensors; one is consumed per
                layer, starting from the last entry.
            temb (Tensor, optional): Embedding passed to every residual block. Defaults to None.
            upsample_size (int, optional): Passed to the upsampler. Defaults to None.
            encoder_hidden_states (Tensor, optional): Conditioning passed to the motion module.
                Defaults to None.

        Returns:
            Tensor: The output tensor after passing through the UpBlock3D layers.
        """
        use_checkpoint = self.training and self.gradient_checkpointing

        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if use_checkpoint:
                # NOTE(review): the motion module is only applied on the non-checkpointed
                # branch below, so it is skipped during checkpointed training - confirm this
                # is intended.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    self._checkpointable(resnet), hidden_states, temb
                )
            else:
                hidden_states = resnet(hidden_states, temb)
                if motion_module is not None:
                    hidden_states = motion_module(
                        hidden_states, encoder_hidden_states=encoder_hidden_states
                    )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states