| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Simple implementation of AutoEncoderKL for LTX v0.95.""" |
|
|
| from einops import rearrange |
| import torch |
| import torch.nn as nn |
|
|
| from diffusers.configuration_utils import ConfigMixin, register_to_config |
| from diffusers.models.modeling_outputs import AutoencoderKLOutput |
| from diffusers.models.modeling_utils import ModelMixin |
|
|
| from diffnext.models.autoencoders.modeling_utils import DiagonalGaussianDistribution |
| from diffnext.models.autoencoders.modeling_utils import DecoderOutput, TilingMixin |
|
|
|
|
class Conv3d(nn.Conv3d):
    """3D convolution with replication padding along the time axis.

    H and W are zero-padded by ``kernel // 2`` through the regular convolution
    padding. The temporal axis receives no convolution padding; instead it is
    replication-padded before the convolution:

    * ``causal=True`` (default): all ``k_t - 1`` padding frames go in front,
      so an output frame depends only on the current and earlier input frames.
    * ``causal=False``: ``(k_t - 1) // 2`` frames are added on each side.

    With a temporal kernel of 1 both paddings are identities.
    """

    def __init__(self, *args, **kwargs):
        causal = kwargs.pop("causal", True)
        super(Conv3d, self).__init__(*args, **kwargs)
        self.causal = causal
        k_t, k_h, k_w = self.kernel_size
        # Conv padding: nothing on T (handled by pad1/pad2), half-kernel on H/W.
        self.padding = (0, k_h // 2, k_w // 2)
        if k_t == 1:
            self.pad1, self.pad2 = nn.Identity(), nn.Identity()
        else:
            # ReplicationPad3d order: (W_left, W_right, H_top, H_bottom, T_front, T_back).
            half = (k_t - 1) // 2
            self.pad1 = nn.ReplicationPad3d((0, 0, 0, 0, k_t - 1, 0))
            self.pad2 = nn.ReplicationPad3d((0, 0, 0, 0, half, half))

    def forward(self, x):
        temporal_pad = self.pad1 if self.causal else self.pad2
        return super().forward(temporal_pad(x))
|
|
|
|
class RMSNorm(nn.Module):
    """Parameter-free RMS normalization over the last dimension."""

    def forward(self, x):
        # Compute the statistic in float32 for stability, then cast it back so
        # the result keeps the input dtype.
        mean_square = x.float().square().mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + 1e-8).to(x.dtype)
        return x * inv_rms
|
|
|
|
class TimeEmbed(nn.Module):
    """Sinusoidal timestep embedding followed by a two-layer SiLU MLP."""

    def __init__(self, embed_dim, freq_dim=256):
        super(TimeEmbed, self).__init__()
        self.timestep_proj = nn.Module()
        self.timestep_proj.fc1 = nn.Linear(freq_dim, embed_dim)
        self.timestep_proj.fc2 = nn.Linear(embed_dim, embed_dim)
        # ``time_freq`` is a lazily built (1, freq_dim // 2) frequency table. It is a
        # plain attribute (not a buffer), so ``Module.to`` does not move it; it is
        # rebuilt whenever the incoming timestep lives on a different device.
        self.freq_dim, self.time_freq = freq_dim, None

    def get_freq_embed(self, timestep) -> torch.Tensor:
        """Return ``[cos(t * f), sin(t * f)]`` for a 1-D tensor of timesteps.

        Args:
            timestep: tensor of timesteps; a trailing axis is appended before
                multiplying by the frequency table.

        Returns:
            Tensor of shape ``timestep.shape + (2 * (freq_dim // 2),)`` cast to
            ``timestep.dtype``.
        """
        if self.time_freq is None or self.time_freq.device != timestep.device:
            # 9.210340371976184 == ln(10000): frequencies 10000 ** (-i / dim).
            dim, log_theta = self.freq_dim // 2, 9.210340371976184
            freq = torch.arange(dim, dtype=torch.float32, device=timestep.device)
            self.time_freq = freq.mul(-log_theta / dim).exp().unsqueeze_(0)
        emb = timestep.unsqueeze(-1).float() * self.time_freq
        return torch.cat([emb.cos(), emb.sin()], dim=-1).to(dtype=timestep.dtype)

    def forward(self, temb) -> torch.Tensor:
        """Project ``temb`` through the MLP.

        A 1-D ``temb`` is treated as raw timesteps and embedded first; any other
        rank is assumed to already be a frequency embedding of width ``freq_dim``.
        """
        x = self.get_freq_embed(temb) if temb.dim() == 1 else temb
        return self.timestep_proj.fc2(nn.functional.silu(self.timestep_proj.fc1(x)))
|
|
|
|
class ResBlock(nn.Module):
    """Residual block: two (norm -> [modulate] -> SiLU -> conv) stages plus identity.

    The non-causal variant owns a learned ``(4, dim)`` table that is added to
    the projected ``temb`` and split into (shift1, scale1, shift2, scale2).
    The identity shortcut requires ``dim == out_dim``.
    """

    def __init__(self, dim, out_dim, causal=True):
        super(ResBlock, self).__init__()
        self.norm1, self.norm2 = RMSNorm(), RMSNorm()
        self.conv1 = Conv3d(dim, out_dim, 3, causal=causal)
        self.conv2 = Conv3d(out_dim, out_dim, 3, causal=causal)
        self.nonlinearity, self.dropout = nn.SiLU(), nn.Dropout(0, inplace=True)
        if causal:
            self.scale_shift_table = None
        else:
            self.scale_shift_table = nn.Parameter(torch.randn(4, dim) / dim**0.5)

    def forward(self, x: torch.Tensor, temb: torch.Tensor = None) -> torch.Tensor:
        residual = x
        mods = ()
        if self.scale_shift_table is not None:
            # temb is expected to have 4 * dim features so it matches the flat table.
            table = self.scale_shift_table.view(1, -1)
            mods = temb.add(table)[..., None, None, None].chunk(4, 1)
        # RMSNorm works on the last axis, so move channels there and back.
        h = self.norm1(x.movedim(1, -1)).movedim(-1, 1)
        if mods:
            h = h.mul(1 + mods[1]).add_(mods[0])
        h = self.conv1(self.nonlinearity(h))
        h = self.norm2(h.movedim(1, -1)).movedim(-1, 1)
        if mods:
            h = h.mul(1 + mods[3]).add_(mods[2])
        h = self.dropout(self.nonlinearity(h))
        return self.conv2(h).add_(residual)
|
|
|
|
class MidBlock(nn.Module):
    """Bottleneck stack of ``depth`` ResBlocks.

    The non-causal (decoder) variant owns a TimeEmbed that projects ``temb``
    for its ResBlocks; the causal variant ignores ``temb``.
    """

    def __init__(self, dim, depth=1, causal=True):
        super(MidBlock, self).__init__()
        self.time_embed = None if causal else TimeEmbed(dim * 4)
        self.resnets = nn.ModuleList([ResBlock(dim, dim, causal=causal) for _ in range(depth)])

    def forward(self, x: torch.Tensor, temb: torch.Tensor = None) -> torch.Tensor:
        if self.time_embed is None:
            modulation = None
        else:
            modulation = self.time_embed(temb)
        for resnet in self.resnets:
            x = resnet(x, modulation)
        return x
|
|
|
|
class Downsample(nn.Module):
    """Space/time-to-depth downsample with a residual shortcut.

    Both branches fold each ``(r, p, q)`` patch into channels:

    * The shortcut averages consecutive groups of ``group_size`` folded channels.
    * The conv branch emits ``out_dim // prod(stride)`` channels before folding.

    Both branches end up with ``out_dim`` channels. Assumes ``out_dim`` and
    ``dim * prod(stride)`` divide evenly — TODO confirm for new configs.
    """

    def __init__(self, dim, out_dim, stride, causal=True):
        super(Downsample, self).__init__()
        if not isinstance(stride, (tuple, list)):
            stride = (stride,) * 3
        self.stride = stride
        volume = torch.Size(stride).numel()
        self.group_size = (dim * volume) // out_dim
        # Replicate the first frame r - 1 times so T becomes divisible by r.
        self.pad_t = stride[0] - 1
        self.conv = Conv3d(dim, out_dim // volume, 3, 1, causal=causal)
        pattern = "b c (t r) (h p) (w q) -> b (c r p q) t h w"
        self.patch_args = dict(r=stride[0], p=stride[1], q=stride[2], pattern=pattern)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.pad_t:
            x = nn.functional.pad(x, (0, 0, 0, 0, self.pad_t, 0), "replicate")
        grouped = rearrange(x, **self.patch_args).unflatten(1, (-1, self.group_size))
        shortcut = grouped.mean(dim=2)
        folded = rearrange(self.conv(x), **self.patch_args)
        return folded.add_(shortcut)
|
|
|
|
class Upsample(nn.Module):
    """Depth-to-space/time upsample with a residual shortcut.

    The conv branch produces ``out_dim * prod(stride)`` channels, which are
    unfolded into ``(r, p, q)`` patches. The shortcut unfolds the input
    directly and tiles its channels ``repeats`` times. When the temporal
    stride is > 1, the first ``r - 1`` output frames of both branches are
    dropped.
    """

    def __init__(self, dim, out_dim, stride, causal=False):
        super(Upsample, self).__init__()
        if not isinstance(stride, (tuple, list)):
            stride = (stride,) * 3
        self.stride = stride
        volume = torch.Size(stride).numel()
        self.repeats = (out_dim * volume) // dim
        self.slice_t = stride[0] - 1
        self.conv = Conv3d(dim, out_dim * volume, 3, 1, causal=causal)
        pattern = "b (c r p q) t h w -> b c (t r) (h p) (w q)"
        self.patch_args = dict(r=stride[0], p=stride[1], q=stride[2], pattern=pattern)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = rearrange(x, **self.patch_args).repeat(1, self.repeats, 1, 1, 1)
        unfolded = rearrange(self.conv(x), **self.patch_args)
        if self.slice_t:
            unfolded = unfolded[:, :, self.slice_t :]
            shortcut = shortcut[:, :, self.slice_t :]
        return unfolded.add_(shortcut)
|
|
|
|
class DownBlock(nn.Module):
    """``depth`` causal ResBlocks at ``dim`` channels, then an optional Downsample.

    ``downsample`` selects the stride: ``"spatial"`` -> (1, 2, 2),
    ``"temporal"`` -> (2, 1, 1), ``"spatiotemporal"`` -> 2 on every axis.
    An empty string means no downsampler.
    """

    def __init__(self, dim, out_dim, depth=1, causal=True, downsample=""):
        super(DownBlock, self).__init__()
        self.resnets = nn.ModuleList([ResBlock(dim, dim, causal=causal) for _ in range(depth)])
        self.downsamplers = nn.ModuleList()
        if downsample:
            strides = {"spatial": (1, 2, 2), "temporal": (2, 1, 1), "spatiotemporal": 2}
            self.downsamplers.append(Downsample(dim, out_dim, strides[downsample], causal=causal))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in (*self.resnets, *self.downsamplers):
            x = layer(x)
        return x
|
|
|
|
class UpBlock(nn.Module):
    """Optional stride-2 Upsample to ``out_dim``, then ``depth`` timestep-modulated ResBlocks.

    ``upscale_factor`` only controls whether the upsampler exists (> 1); the
    upsampler stride itself is always 2.
    """

    def __init__(self, dim, out_dim, depth=1, causal=False, upscale_factor=2):
        super(UpBlock, self).__init__()
        self.time_embed = TimeEmbed(out_dim * 4)
        self.resnets, self.upsamplers = nn.ModuleList(), nn.ModuleList()
        if upscale_factor > 1:
            self.upsamplers.append(Upsample(dim, out_dim, 2, causal=causal))
        self.resnets.extend(ResBlock(out_dim, out_dim, causal=causal) for _ in range(depth))

    def forward(self, x: torch.Tensor, temb: torch.Tensor = None) -> torch.Tensor:
        for upsampler in self.upsamplers:
            x = upsampler(x)
        modulation = self.time_embed(temb)
        for resnet in self.resnets:
            x = resnet(x, modulation)
        return x
|
|
|
|
class Encoder(nn.Module):
    """Causal VAE encoder.

    Steps:

    1. Patchify H/W by ``patch_size`` into channels.
    2. Run up to four DownBlocks (spatial, temporal, spatiotemporal, spatiotemporal).
    3. Run a causal MidBlock.
    4. Apply RMSNorm, SiLU and a conv that emits ``out_dim + 1`` channels.

    How the extra channel is consumed is decided by the caller (presumably the
    latent distribution) — not visible here.
    """

    def __init__(self, dim, out_dim, block_dims, block_depths, patch_size=4):
        super(Encoder, self).__init__()
        self.patch_args = {"p": patch_size, "q": patch_size}
        stages = ["spatial", "temporal", "spatiotemporal", "spatiotemporal"]
        self.conv_in = Conv3d(dim * patch_size**2, block_dims[0], 3, 1)
        self.down_blocks = nn.ModuleList(
            [
                DownBlock(width, block_dims[idx + 1], depth, downsample=kind)
                for idx, (width, depth, kind) in enumerate(zip(block_dims, block_depths, stages))
            ]
        )
        self.mid_block = MidBlock(block_dims[-1], block_depths[-1])
        self.norm_out, self.conv_act = RMSNorm(), nn.SiLU()
        self.conv_out = Conv3d(block_dims[-1], out_dim + 1, 3, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Channel order (c, q, p): the W-patch index varies slower than the H-patch index.
        h = rearrange(x, "b c t (h p) (w q) -> b (c q p) t h w", **self.patch_args)
        h = self.conv_in(h)
        for block in self.down_blocks:
            h = block(h)
        h = self.mid_block(h)
        h = self.norm_out(h.movedim(1, -1)).movedim(-1, 1)
        return self.conv_out(self.conv_act(h))
|
|
|
|
class Decoder(nn.Module):
    """Timestep-conditioned VAE decoder.

    Stages:

    1. Conv in, then a non-causal MidBlock at ``block_dims[-1]`` channels.
    2. UpBlocks that each halve the channel count.
    3. A final RMSNorm modulated by a learned (shift, scale) table plus the
       timestep embedding.
    4. SiLU, a conv out, and un-patchify H/W by ``patch_size``.
    """

    def __init__(self, dim, out_dim, block_dims, block_depths, patch_size=4):
        super(Decoder, self).__init__()
        block_dims = tuple(reversed(block_dims))
        self.patch_args = {"p": patch_size, "q": patch_size}
        self.conv_in = Conv3d(dim, block_dims[0], 3, 1, causal=False)
        self.mid_block = MidBlock(block_dims[0], block_depths[-1], causal=False)
        self.up_blocks = nn.ModuleList([])
        for in_dim, depth in zip(block_dims, block_depths[:-1]):
            self.up_blocks.append(UpBlock(in_dim, in_dim // 2, depth, upscale_factor=2))
        self.norm_out, self.conv_act = RMSNorm(), nn.SiLU()
        self.conv_out = Conv3d(block_dims[-1], out_dim * patch_size**2, 3, 1, causal=False)
        self.time_embed = TimeEmbed(block_dims[-1] * 2)
        self.scale_shift_table = nn.Parameter(torch.randn(2, block_dims[-1]))
        self.timestep_scale = nn.Parameter(torch.tensor(1000, dtype=torch.float32))

    def forward(self, x: torch.Tensor, temb: torch.Tensor = None) -> torch.Tensor:
        """Decode latents ``x`` conditioned on per-sample timesteps ``temb``.

        Args:
            x: 5-D latent tensor with channels on dim 1.
            temb: 1-D tensor with one timestep per batch item. ``None`` now
                means timestep 0 for every sample; previously it raised a
                ``TypeError``.
        """
        if temb is None:
            temb = x.new_zeros(x.size(0))
        x = self.conv_in(x)
        # Shared sinusoidal embedding; each sub-block projects it with its own MLP.
        temb = self.time_embed.get_freq_embed(temb * self.timestep_scale)
        x = self.mid_block(x, temb)
        for up_block in self.up_blocks:
            x = up_block(x, temb)
        x = self.norm_out(x.movedim(1, -1)).movedim(-1, 1)
        temb = self.time_embed(temb)
        stats = temb.add(self.scale_shift_table.view(1, -1))[..., None, None, None].chunk(2, 1)
        x = x.mul(1 + stats[1]).add_(stats[0])
        x = self.conv_out(self.conv_act(x))
        return rearrange(x, "b (c q p) t h w -> b c t (h p) (w q)", **self.patch_args)
|
|
|
|
class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, TilingMixin):
    """AutoEncoder KL for LTX-Video.

    Wraps a causal ``Encoder`` and a timestep-conditioned ``Decoder``. Encoding
    and decoding are routed through ``TilingMixin`` (``tiled_encoder`` /
    ``tiled_decoder``).
    """

    @register_to_config
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=("LTXVideoDownBlock3D",) * 4,
        block_out_channels=(128, 256, 512, 1024, 2048),
        layers_per_block=(4, 6, 6, 2, 2),
        decoder_block_out_channels=(128, 256, 512, 1024),
        decoder_layers_per_block=(5, 5, 5, 5),
        act_fn="silu",
        latent_channels=128,
        sample_size=1024,
        scaling_factor=1.0,
        shift_factor=None,
        latents_mean=None,
        latents_std=None,
        patch_size=4,
    ):
        super(AutoencoderKLLTXVideo, self).__init__()
        TilingMixin.__init__(self, sample_min_t=249, latent_min_t=32, sample_ovr_t=1)
        # Forward the configured patch size (previously ignored; default unchanged).
        channels, layers = block_out_channels, layers_per_block
        self.encoder = Encoder(in_channels, latent_channels, channels, layers, patch_size)
        channels, layers = decoder_block_out_channels, decoder_layers_per_block
        self.decoder = Decoder(latent_channels, out_channels, channels, layers, patch_size)
        # Placeholder buffers created with torch.zeros/ones from the config values;
        # presumably filled from the checkpoint — TODO confirm.
        self.register_buffer("shift_factors", torch.zeros(latents_mean) if latents_mean else None)
        self.register_buffer("scaling_factors", torch.ones(latents_std) if latents_std else None)
        self.latent_dist = DiagonalGaussianDistribution

    def scale_(self, x) -> torch.Tensor:
        """Scale latents in place: ``(x - shift) * scale``.

        Uses the per-channel buffers when present, else the scalar config values.
        """
        if self.shift_factors is not None:
            return x.sub_(self.shift_factors).mul_(self.scaling_factors)
        if self.config.shift_factor:
            x.sub_(self.config.shift_factor)
        return x.mul_(self.config.scaling_factor)

    def unscale_(self, x) -> torch.Tensor:
        """Invert ``scale_`` in place: ``x / scale + shift``."""
        if self.shift_factors is not None:
            return x.div_(self.scaling_factors).add_(self.shift_factors)
        x.mul_(1 / self.config.scaling_factor)
        return x.add_(self.config.shift_factor) if self.config.shift_factor else x

    def encode(self, x) -> AutoencoderKLOutput:
        """Encode the input samples into a latent distribution."""
        z = self.tiled_encoder(self.forward(x))
        posterior = self.latent_dist(z)
        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(self, z, temb: torch.Tensor = None) -> DecoderOutput:
        """Decode latents into samples.

        Args:
            z: latent tensor. A 4-D input gets a singleton axis inserted at
                dim 2, which is removed again from the output.
            temb: optional per-sample decoder timesteps; defaults to zeros.
        """
        if temb is None:
            temb = torch.zeros(z.size(0), dtype=z.dtype, device=z.device)
        extra_dim = 2 if z.dim() == 4 else None
        # Out-of-place: unsqueeze_ used to change the shape of the caller's tensor.
        z = z.unsqueeze(extra_dim) if extra_dim is not None else z
        x = self.tiled_decoder(self.forward(z), temb=temb)
        x = x.squeeze(extra_dim) if extra_dim is not None else x
        return DecoderOutput(sample=x)

    def forward(self, x):
        """Identity; encode/decode use it as a hook before tiling."""
        return x
|
|