wangfuyun's picture
Upload 42 files
c0fdaf5 verified
raw
history blame
11 kB
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from einops import rearrange
class InflatedConv3d(nn.Conv2d):
def forward(self, x):
video_length = x.shape[2]
x = rearrange(x, "b c f h w -> (b f) c h w")
x = super().forward(x)
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
return x
class InflatedGroupNorm(nn.GroupNorm):
def forward(self, x):
video_length = x.shape[2]
x = rearrange(x, "b c f h w -> (b f) c h w")
x = super().forward(x)
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
return x
class Upsample3D(nn.Module):
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
conv = None
if use_conv_transpose:
raise NotImplementedError
elif use_conv:
self.conv = InflatedConv3d(
self.channels, self.out_channels, 3, padding=1)
def forward(self, hidden_states, output_size=None):
assert hidden_states.shape[1] == self.channels
if self.use_conv_transpose:
raise NotImplementedError
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()
# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if output_size is None:
hidden_states = F.interpolate(hidden_states, scale_factor=[
1.0, 2.0, 2.0], mode="nearest")
else:
hidden_states = F.interpolate(
hidden_states, size=output_size, mode="nearest")
# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)
# if self.use_conv:
# if self.name == "conv":
# hidden_states = self.conv(hidden_states)
# else:
# hidden_states = self.Conv2d_0(hidden_states)
hidden_states = self.conv(hidden_states)
return hidden_states
class Downsample3D(nn.Module):
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
if use_conv:
self.conv = InflatedConv3d(
self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
raise NotImplementedError
def forward(self, hidden_states):
assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0:
raise NotImplementedError
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states)
return hidden_states
class ResnetBlock3D(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
use_inflated_groupnorm=None,
use_temporal_conv=False,
use_temporal_mixer=False,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.output_scale_factor = output_scale_factor
self.use_temporal_mixer = use_temporal_mixer
if use_temporal_mixer:
self.temporal_mixer = AlphaBlender(0.3, "learned", None)
if groups_out is None:
groups_out = groups
assert use_inflated_groupnorm != None
if use_inflated_groupnorm:
self.norm1 = InflatedGroupNorm(
num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
else:
self.norm1 = torch.nn.GroupNorm(
num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
if use_temporal_conv:
self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=(
3, 1, 1), stride=1, padding=(1, 0, 0))
else:
self.conv1 = InflatedConv3d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
time_emb_proj_out_channels = out_channels
elif self.time_embedding_norm == "scale_shift":
time_emb_proj_out_channels = out_channels * 2
else:
raise ValueError(
f"unknown time_embedding_norm : {self.time_embedding_norm} ")
self.time_emb_proj = torch.nn.Linear(
temb_channels, time_emb_proj_out_channels)
else:
self.time_emb_proj = None
if use_inflated_groupnorm:
self.norm2 = InflatedGroupNorm(
num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
else:
self.norm2 = torch.nn.GroupNorm(
num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
if use_temporal_conv:
self.conv2 = nn.Conv3d(in_channels, out_channels, kernel_size=(
3, 1, 1), stride=1, padding=(1, 0, 0))
else:
self.conv2 = InflatedConv3d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = InflatedConv3d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, input_tensor, temb):
if self.use_temporal_mixer:
residual = input_tensor
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[
:, :, None, None, None]
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / \
self.output_scale_factor
if self.use_temporal_mixer:
output_tensor = self.temporal_mixer(residual, output_tensor, None)
# return residual + 0.0 * self.temporal_mixer(residual, output_tensor, None)
return output_tensor
class Mish(torch.nn.Module):
def forward(self, hidden_states):
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"]
def __init__(
self,
alpha: float,
merge_strategy: str = "learned_with_images",
rearrange_pattern: str = "b t -> (b t) 1 1",
):
super().__init__()
self.merge_strategy = merge_strategy
self.rearrange_pattern = rearrange_pattern
self.scaler = 10.
assert (
merge_strategy in self.strategies
), f"merge_strategy needs to be in {self.strategies}"
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif (
self.merge_strategy == "learned"
or self.merge_strategy == "learned_with_images"
):
self.register_parameter(
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
)
else:
raise ValueError(f"unknown merge strategy {self.merge_strategy}")
def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
if self.merge_strategy == "fixed":
alpha = self.mix_factor
elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor*self.scaler)
elif self.merge_strategy == "learned_with_images":
assert image_only_indicator is not None, "need image_only_indicator ..."
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
)
alpha = rearrange(alpha, self.rearrange_pattern)
else:
raise NotImplementedError
return alpha
def forward(
self,
x_spatial: torch.Tensor,
x_temporal: torch.Tensor,
image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator)
x = (
alpha.to(x_spatial.dtype) * x_spatial
+ (1.0 - alpha).to(x_spatial.dtype) * x_temporal
)
return x