|
from dataclasses import dataclass
|
|
from typing import Optional, Tuple
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from diffusers.utils import BaseOutput, is_torch_version
|
|
from diffusers.utils.torch_utils import randn_tensor
|
|
from diffusers.models.attention_processor import SpatialNorm
|
|
from .unet_causal_3d_blocks import (
|
|
CausalConv3d,
|
|
UNetMidBlockCausal3D,
|
|
get_down_block3d,
|
|
get_up_block3d,
|
|
)
|
|
|
|
|
|
@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor`):
            The decoded output sample from the last layer of the model. For the causal 3D
            decoder in this module the sample is expected to be 5-D, laid out as
            `(batch_size, num_channels, num_frames, height, width)`.
    """

    sample: torch.FloatTensor
|
|
|
|
|
|
class EncoderCausal3D(nn.Module):
    r"""
    The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of output (latent) channels. The final convolution emits
            ``2 * out_channels`` channels when ``double_z`` is True, else ``out_channels``.
        down_block_types: Block type names passed to ``get_down_block3d``, one per level.
        block_out_channels: Output channel count of each level; must have at least as many
            entries as ``down_block_types``.
        layers_per_block: Number of layers in each down block.
        norm_num_groups: Number of groups for the group normalisations.
        act_fn: Activation name forwarded to the down and mid blocks. The activation
            applied just before ``conv_out`` is always SiLU.
        double_z: Whether the final convolution doubles ``out_channels``.
        mid_block_add_attention: Forwarded as ``add_attention`` to the mid block.
        time_compression_ratio: Temporal downsampling factor. Only 4 is supported.
        spatial_compression_ratio: Spatial downsampling factor; ``int(log2(ratio))`` levels
            (counted from the first) halve height and width.

    Raises:
        ValueError: If ``time_compression_ratio`` is not 4.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        # Only the 4x temporal schedule is implemented; validate once, before building anything.
        if time_compression_ratio != 4:
            raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

        # Loop-invariant schedule parameters.
        num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
        num_time_downsample_layers = int(np.log2(time_compression_ratio))
        num_levels = len(block_out_channels)

        self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
        self.down_blocks = nn.ModuleList([])

        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == num_levels - 1

            # The first `num_spatial_downsample_layers` levels halve H and W.
            add_spatial_downsample = i < num_spatial_downsample_layers
            # The `num_time_downsample_layers` levels just before the final one halve T.
            add_time_downsample = (
                i >= num_levels - 1 - num_time_downsample_layers and not is_final_block
            )

            # Stride tuple is ordered (T, H, W).
            downsample_stride = (2 if add_time_downsample else 1,) + (
                (2, 2) if add_spatial_downsample else (1, 1)
            )

            down_block = get_down_block3d(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=add_spatial_downsample or add_time_downsample,
                downsample_stride=downsample_stride,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
        )

        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `EncoderCausal3D` class.

        Args:
            sample: 5-D input tensor. Strides are applied in (T, H, W) order, so the layout is
                presumably `(batch, channels, frames, height, width)`.

        Returns:
            The output of ``conv_out``: ``2 * out_channels`` channels when ``double_z`` is set,
            otherwise ``out_channels``.
        """
        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"

        sample = self.conv_in(sample)

        for down_block in self.down_blocks:
            sample = down_block(sample)

        sample = self.mid_block(sample)

        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample
|
|
|
|
|
|
class DecoderCausal3D(nn.Module):
    r"""
    The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Args:
        in_channels: Number of channels of the latent input.
        out_channels: Number of channels of the decoded output.
        up_block_types: Block type names passed to ``get_up_block3d``, one per level.
        block_out_channels: Channel counts per level (in encoder order; the decoder walks them
            in reverse). Must have at least as many entries as ``up_block_types``.
        layers_per_block: Each up block gets ``layers_per_block + 1`` layers.
        norm_num_groups: Number of groups for the group normalisations.
        act_fn: Activation name forwarded to the mid and up blocks. The activation applied
            just before ``conv_out`` is always SiLU.
        norm_type: ``"spatial"`` selects ``SpatialNorm`` for the output norm and makes blocks
            receive ``in_channels`` as ``temb_channels``; any other value uses GroupNorm for
            the output norm. The value is also forwarded as ``resnet_time_scale_shift``.
        mid_block_add_attention: Forwarded as ``add_attention`` to the mid block.
        time_compression_ratio: Temporal upsampling factor. Only 4 is supported.
        spatial_compression_ratio: Spatial upsampling factor; ``int(log2(ratio))`` levels
            (counted from the first) double height and width.

    Raises:
        ValueError: If ``time_compression_ratio`` is not 4.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        # Only the 4x temporal schedule is implemented; validate once, before building anything.
        if time_compression_ratio != 4:
            raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

        # Loop-invariant schedule parameters.
        num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
        num_time_upsample_layers = int(np.log2(time_compression_ratio))
        num_levels = len(block_out_channels)

        self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
        self.up_blocks = nn.ModuleList([])

        # In spatial-norm mode the norm layers are fed `latent_embeds`, whose channel count is
        # taken to be `in_channels`.
        temb_channels = in_channels if norm_type == "spatial" else None

        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
            add_attention=mid_block_add_attention,
        )

        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == num_levels - 1

            # The first `num_spatial_upsample_layers` levels double H and W.
            add_spatial_upsample = i < num_spatial_upsample_layers
            # The `num_time_upsample_layers` levels just before the final one double T.
            add_time_upsample = (
                i >= num_levels - 1 - num_time_upsample_layers and not is_final_block
            )

            # Scale-factor tuple is ordered (T, H, W).
            upsample_scale_factor = (2 if add_time_upsample else 1,) + (
                (2, 2) if add_spatial_upsample else (1, 1)
            )

            up_block = get_up_block3d(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=add_spatial_upsample or add_time_upsample,
                upsample_scale_factor=upsample_scale_factor,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
            )
            self.up_blocks.append(up_block)

        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)

        self.gradient_checkpointing = False

    def _call_block(
        self,
        block: nn.Module,
        sample: torch.FloatTensor,
        latent_embeds: Optional[torch.FloatTensor],
    ) -> torch.FloatTensor:
        """Run ``block(sample, latent_embeds)``, checkpointing it when training with gradient checkpointing."""
        if self.training and self.gradient_checkpointing:
            # Use the non-reentrant variant when the installed torch supports it (>= 1.11).
            ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
            return torch.utils.checkpoint.checkpoint(block, sample, latent_embeds, **ckpt_kwargs)
        return block(sample, latent_embeds)

    def forward(
        self,
        sample: torch.FloatTensor,
        latent_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        r"""The forward method of the `DecoderCausal3D` class.

        Args:
            sample: 5-D latent tensor.
            latent_embeds: Optional conditioning passed to the mid/up blocks and, when given,
                to ``conv_norm_out`` (required by the spatial-norm variant).

        Returns:
            The output of ``conv_out`` with ``out_channels`` channels.
        """
        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."

        sample = self.conv_in(sample)

        # Up blocks may hold parameters in a different dtype than the mid block; cast the
        # mid-block output to it. Fall back to the current dtype when there are no up blocks.
        first_up_param = next(self.up_blocks.parameters(), None)
        upscale_dtype = sample.dtype if first_up_param is None else first_up_param.dtype

        sample = self._call_block(self.mid_block, sample, latent_embeds)
        sample = sample.to(upscale_dtype)

        for up_block in self.up_blocks:
            sample = self._call_block(up_block, sample, latent_embeds)

        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample
|
|
|
|
|
|
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian parameterised by one tensor holding concatenated mean and log-variance.

    ``parameters`` is split into two halves with ``torch.chunk``: along dim 1 for 4-D/5-D
    tensors and along dim 2 for 3-D tensors. The first half is the mean, the second the
    log-variance, which is clamped to ``[-30, 20]`` before ``std``/``var`` are derived.

    Args:
        parameters: 3-, 4- or 5-D tensor; the split axis presumably has even size (an odd
            size gives unequal halves) -- confirm at call sites.
        deterministic: If True, ``std`` and ``var`` are zero, so ``sample()`` returns the
            mean and ``kl()``/``nll()`` return a zero tensor.

    Raises:
        NotImplementedError: If ``parameters`` is not 3-, 4- or 5-D.
    """

    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        if parameters.ndim == 3:
            dim = 2
        elif parameters.ndim in (4, 5):
            dim = 1
        else:
            raise NotImplementedError(
                f"Unsupported parameters.ndim={parameters.ndim}; expected 3, 4 or 5."
            )
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
        # Clamp for numerical stability of exp() below.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def _zero(self) -> torch.Tensor:
        """Shape-(1,) zero on the same device/dtype as ``parameters`` (deterministic-mode result)."""
        return torch.zeros(1, device=self.parameters.device, dtype=self.parameters.dtype)

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        """Return ``mean + std * noise`` with noise drawn by diffusers' ``randn_tensor``."""
        noise = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        return self.mean + self.std * noise

    def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
        """KL(self || other) summed over all non-batch dims.

        With ``other=None`` the reference is the standard normal N(0, I).
        """
        if self.deterministic:
            return self._zero()
        reduce_dim = list(range(1, self.mean.ndim))
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=reduce_dim,
            )
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=reduce_dim,
        )

    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
        """Gaussian negative log-likelihood of ``sample`` under this distribution, summed over ``dims``.

        NOTE: the default reduces dims 1-3 only; for 5-D tensors pass ``dims`` explicitly if
        the last axis should be reduced as well.
        """
        if self.deterministic:
            return self._zero()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        """Return the mode of the distribution, i.e. its mean."""
        return self.mean
|
|
|