linyq's picture
Add files using upload-large-folder tool
ca5da2f verified
import torch
import torch.nn as nn
from diffusers import ModelMixin, ConfigMixin
from diffusers.configuration_utils import register_to_config
class ConditionalEmbedder(ModelMixin, ConfigMixin):
"""
Patchifies VAE-encoded conditions (source video or reference image)
into the DiT hidden dimension space via a Conv3d layer.
"""
@register_to_config
def __init__(
self,
in_dim: int = 48,
dim: int = 3072,
patch_size: list = [1, 2, 2],
zero_init: bool = True,
ref_pad_first: bool = False,
):
super().__init__()
kernel_size = tuple(patch_size)
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=kernel_size, stride=kernel_size
)
self.ref_pad_first = ref_pad_first
if zero_init:
nn.init.zeros_(self.patch_embedding.weight)
nn.init.zeros_(self.patch_embedding.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.patch_embedding(x)