# aleafy's picture
# Start fresh
# 0a63786
# Part of the implementation is borrowed and modified from stable-diffusion,
# publicly available at https://github.com/Stability-AI/stablediffusion.
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
# Only the video UNet is exported by `from <module> import *`.
__all__ = ['UNetSD']
def exists(x):
    """Return True when *x* is anything other than ``None``."""
    return not (x is None)
def default(val, d):
    """Return *val* unless it is ``None``; otherwise fall back to *d*.

    A callable fallback is invoked lazily and its result returned, so costly
    defaults are only built when actually needed.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val
class UNetSD(nn.Module):
    """Conditional 3D UNet operating on video tensors of shape ``(b, c, f, h, w)``.

    Spatial layers run on frames folded into the batch axis (``(b f) c h w``);
    temporal layers temporarily unfold them back to ``b c f h w``.

    Args:
        in_dim: channels of the input tensor ``x``.
        dim: base channel width; level widths are ``dim * dim_mult[i]``.
        y_dim: stored only; not used by the layers built here.
        context_dim: feature size of the cross-attention context ``y``. It is
            used by every SpatialTransformer (encoder, middle and decoder).
        out_dim: channels of the output tensor.
        dim_mult: per-level channel multipliers (any sequence; copied to a list).
        num_heads: heads of the first temporal transformer; defaults to
            ``dim // 32``.
        head_dim: channels per attention head.
        num_res_blocks: residual blocks per encoder level (decoder uses one more).
        attn_scales: resolutions (relative to the input) at which spatial and
            temporal attention are inserted (any sequence; copied to a list).
        use_scale_shift_norm: stored only; ResBlocks are built with it disabled.
        dropout: dropout rate of the residual blocks.
        temporal_attn_times: stored only.
        temporal_attention: insert TemporalTransformer layers.
        use_checkpoint: stored only.
        use_image_dataset: passed to ResBlocks and, as ``multiply_zero``, to
            the temporal transformers (their branch is multiplied by zero).
        use_fps_condition: add an fps embedding to the timestep embedding.
        use_sim_mask: stored only.
    """

    def __init__(self,
                 in_dim=7,
                 dim=512,
                 y_dim=512,
                 context_dim=512,
                 out_dim=6,
                 dim_mult=(1, 2, 3, 4),
                 num_heads=None,
                 head_dim=64,
                 num_res_blocks=3,
                 attn_scales=(1 / 2, 1 / 4, 1 / 8),
                 use_scale_shift_norm=True,
                 dropout=0.1,
                 temporal_attn_times=2,
                 temporal_attention=True,
                 use_checkpoint=False,
                 use_image_dataset=False,
                 use_fps_condition=False,
                 use_sim_mask=False):
        # Copy sequence arguments: defaults stay immutable, callers' lists are
        # never aliased, and list concatenation below works for tuples too.
        dim_mult = list(dim_mult)
        attn_scales = list(attn_scales)
        embed_dim = dim * 4
        num_heads = num_heads if num_heads else dim // 32
        super(UNetSD, self).__init__()
        self.in_dim = in_dim
        self.dim = dim
        self.y_dim = y_dim
        self.context_dim = context_dim
        self.embed_dim = embed_dim
        self.out_dim = out_dim
        self.dim_mult = dim_mult
        self.num_heads = num_heads
        # parameters for spatial/temporal attention
        self.head_dim = head_dim
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.use_scale_shift_norm = use_scale_shift_norm
        self.temporal_attn_times = temporal_attn_times
        self.temporal_attention = temporal_attention
        self.use_checkpoint = use_checkpoint
        self.use_image_dataset = use_image_dataset
        self.use_fps_condition = use_fps_condition
        self.use_sim_mask = use_sim_mask
        use_linear_in_temporal = False
        transformer_depth = 1
        disabled_sa = False

        # channel widths per level
        enc_dims = [dim * u for u in [1] + dim_mult]
        dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        # channels of every encoder output; popped (LIFO) by the decoder skips
        shortcut_dims = []
        # current resolution relative to the input, used to place attention
        scale = 1.0

        # embeddings
        self.time_embed = nn.Sequential(
            nn.Linear(dim, embed_dim), nn.SiLU(),
            nn.Linear(embed_dim, embed_dim))
        if self.use_fps_condition:
            self.fps_embedding = nn.Sequential(
                nn.Linear(dim, embed_dim), nn.SiLU(),
                nn.Linear(embed_dim, embed_dim))
            # zero-init so the fps branch starts as a no-op
            nn.init.zeros_(self.fps_embedding[-1].weight)
            nn.init.zeros_(self.fps_embedding[-1].bias)

        # encoder
        self.input_blocks = nn.ModuleList()
        init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
        if temporal_attention:
            init_block.append(
                TemporalTransformer(
                    dim,
                    num_heads,
                    head_dim,
                    depth=transformer_depth,
                    context_dim=context_dim,
                    disable_self_attn=disabled_sa,
                    use_linear=use_linear_in_temporal,
                    multiply_zero=use_image_dataset))
        self.input_blocks.append(init_block)
        shortcut_dims.append(dim)
        for i, (in_dim, out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
            for j in range(num_res_blocks):
                # residual (+attention) blocks
                block = nn.ModuleList([
                    ResBlock(
                        in_dim,
                        embed_dim,
                        dropout,
                        out_channels=out_dim,
                        use_scale_shift_norm=False,
                        use_image_dataset=use_image_dataset,
                    )
                ])
                if scale in attn_scales:
                    block.append(
                        SpatialTransformer(
                            out_dim,
                            out_dim // head_dim,
                            head_dim,
                            depth=1,
                            context_dim=self.context_dim,
                            disable_self_attn=False,
                            use_linear=True))
                    if self.temporal_attention:
                        block.append(
                            TemporalTransformer(
                                out_dim,
                                out_dim // head_dim,
                                head_dim,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_linear=use_linear_in_temporal,
                                multiply_zero=use_image_dataset))
                in_dim = out_dim
                self.input_blocks.append(block)
                shortcut_dims.append(out_dim)
                # downsample at the end of every level except the last
                if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
                    downsample = Downsample(
                        out_dim, True, dims=2, out_channels=out_dim)
                    shortcut_dims.append(out_dim)
                    scale /= 2.0
                    self.input_blocks.append(downsample)

        # middle
        self.middle_block = nn.ModuleList([
            ResBlock(
                out_dim,
                embed_dim,
                dropout,
                use_scale_shift_norm=False,
                use_image_dataset=use_image_dataset,
            ),
            SpatialTransformer(
                out_dim,
                out_dim // head_dim,
                head_dim,
                depth=1,
                context_dim=self.context_dim,
                disable_self_attn=False,
                use_linear=True)
        ])
        if self.temporal_attention:
            self.middle_block.append(
                TemporalTransformer(
                    out_dim,
                    out_dim // head_dim,
                    head_dim,
                    depth=transformer_depth,
                    context_dim=context_dim,
                    disable_self_attn=disabled_sa,
                    use_linear=use_linear_in_temporal,
                    multiply_zero=use_image_dataset,
                ))
        self.middle_block.append(
            ResBlock(
                out_dim,
                embed_dim,
                dropout,
                use_scale_shift_norm=False,
                use_image_dataset=use_image_dataset,
            ))

        # decoder
        self.output_blocks = nn.ModuleList()
        for i, (in_dim, out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
            for j in range(num_res_blocks + 1):
                # residual (+attention) blocks; input is concatenated with a skip
                block = nn.ModuleList([
                    ResBlock(
                        in_dim + shortcut_dims.pop(),
                        embed_dim,
                        dropout,
                        out_dim,
                        use_scale_shift_norm=False,
                        use_image_dataset=use_image_dataset,
                    )
                ])
                if scale in attn_scales:
                    block.append(
                        SpatialTransformer(
                            out_dim,
                            out_dim // head_dim,
                            head_dim,
                            depth=1,
                            # previously hard-coded to 1024, which only worked
                            # when context_dim happened to be 1024
                            context_dim=self.context_dim,
                            disable_self_attn=False,
                            use_linear=True))
                    if self.temporal_attention:
                        block.append(
                            TemporalTransformer(
                                out_dim,
                                out_dim // head_dim,
                                head_dim,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_linear=use_linear_in_temporal,
                                multiply_zero=use_image_dataset))
                in_dim = out_dim
                # upsample at the end of every level except the last
                if i != len(dim_mult) - 1 and j == num_res_blocks:
                    upsample = Upsample(
                        out_dim, True, dims=2, out_channels=out_dim)
                    scale *= 2.0
                    block.append(upsample)
                self.output_blocks.append(block)

        # head
        self.out = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(),
            nn.Conv2d(out_dim, self.out_dim, 3, padding=1))
        # zero out the last layer params
        nn.init.zeros_(self.out[-1].weight)

    def forward(
        self,
        x,
        t,
        y,
        fps=None,
        video_mask=None,
        focus_present_mask=None,
        prob_focus_present=0.,
        mask_last_frame_num=0  # mask last frame num
    ):
        """Run the UNet on a batch of videos.

        Args:
            x: input tensor of shape ``(b, c, f, h, w)``.
            t: timesteps for ``sinusoidal_embedding`` (presumably one per
                video, shape ``(b,)`` -- TODO confirm).
            y: cross-attention context, or a ``(key_context, value_context)``
                pair; it is repeated ``f`` times along the batch axis.
            fps: optional frame-rate values, used only when the model was
                built with ``use_fps_condition``.
            video_mask, focus_present_mask: forwarded to ``_forward_single``;
                no layer in this file consumes them.
            prob_focus_present: probability at which a given batch sample will
                focus on the present (0. is all off, 1. is completely arrested
                attention across time).
            mask_last_frame_num: if > 0, the last ``mask_last_frame_num``
                entries of ``video_mask`` (first axis) are set to False in
                place, and ``prob_focus_present`` is ignored.

        Returns:
            Tensor of shape ``(b, out_dim, f, h, w)``.
        """
        batch, device = x.shape[0], x.device
        # stored so _forward_single can unfold (b f) for temporal layers
        self.batch = batch
        # image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored
        if mask_last_frame_num > 0:
            focus_present_mask = None
            if video_mask is not None:
                video_mask[-mask_last_frame_num:] = False
        else:
            focus_present_mask = default(
                focus_present_mask, lambda: prob_mask_like(
                    (batch, ), prob_focus_present, device=device))
        time_rel_pos_bias = None

        # embeddings
        if self.use_fps_condition and fps is not None:
            e = self.time_embed(sinusoidal_embedding(
                t, self.dim)) + self.fps_embedding(
                    sinusoidal_embedding(fps, self.dim))
        else:
            e = self.time_embed(sinusoidal_embedding(t, self.dim))
        context = y

        # repeat f times for spatial e and context
        f = x.shape[2]
        e = e.repeat_interleave(repeats=f, dim=0)
        if isinstance(context, (tuple, list)):
            context = (
                context[0].repeat_interleave(repeats=f, dim=0),
                context[1].repeat_interleave(repeats=f, dim=0),
            )
        else:
            context = context.repeat_interleave(repeats=f, dim=0)

        # always in shape (b f) c h w, except for temporal layer
        x = rearrange(x, 'b c f h w -> (b f) c h w')

        # encoder
        xs = []
        for block in self.input_blocks:
            x = self._forward_single(block, x, e, context, time_rel_pos_bias,
                                     focus_present_mask, video_mask)
            xs.append(x)

        # middle
        for block in self.middle_block:
            x = self._forward_single(block, x, e, context, time_rel_pos_bias,
                                     focus_present_mask, video_mask)

        # decoder
        for block in self.output_blocks:
            x = torch.cat([x, xs.pop()], dim=1)
            x = self._forward_single(
                block,
                x,
                e,
                context,
                time_rel_pos_bias,
                focus_present_mask,
                video_mask,
                reference=xs[-1] if len(xs) > 0 else None)

        # head
        x = self.out(x)

        # reshape back to (b c f h w)
        x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)
        return x

    def _forward_single(self,
                        module,
                        x,
                        e,
                        context,
                        time_rel_pos_bias,
                        focus_present_mask,
                        video_mask,
                        reference=None):
        """Apply one layer, dispatching on its type for the extra arguments.

        ``nn.ModuleList`` containers are applied element by element
        (recursively). ``time_rel_pos_bias``, ``focus_present_mask`` and
        ``video_mask`` are only threaded through; no branch uses them.
        """
        if isinstance(module, ResidualBlock):
            x = x.contiguous()
            x = module(x, e, reference)
        elif isinstance(module, ResBlock):
            x = x.contiguous()
            x = module(x, e, self.batch)
        elif isinstance(module, SpatialTransformer):
            x = module(x, context)
        elif isinstance(module, TemporalTransformer):
            x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
            x = module(x, context)
            x = rearrange(x, 'b c f h w -> (b f) c h w')
        elif isinstance(module, CrossAttention):
            x = module(x, context)
        elif isinstance(module, BasicTransformerBlock):
            x = module(x, context)
        elif isinstance(module, FeedForward):
            x = module(x, context)
        elif isinstance(module, Upsample):
            x = module(x)
        elif isinstance(module, Downsample):
            x = module(x)
        elif isinstance(module, Resample):
            x = module(x, reference)
        elif isinstance(module, nn.ModuleList):
            for block in module:
                x = self._forward_single(block, x, e, context,
                                         time_rel_pos_bias, focus_present_mask,
                                         video_mask, reference)
        else:
            x = module(x)
        return x
def sinusoidal_embedding(timesteps, dim):
    """Return ``(len(timesteps), dim)`` sinusoidal features for 1-D *timesteps*.

    Columns ``[0, dim // 2)`` hold cosines and ``[dim // 2, 2 * (dim // 2))``
    hold sines of ``timesteps * 10000 ** (-k / (dim // 2))``. An odd *dim*
    receives one trailing zero column.
    """
    half_dim = dim // 2
    steps = timesteps.float()
    exponents = torch.arange(half_dim).to(steps).div(half_dim)
    freqs = torch.pow(10000, -exponents)
    angles = torch.outer(steps, freqs)
    emb = torch.cat([angles.cos(), angles.sin()], dim=1)
    if dim % 2 != 0:
        # pad with a zero column so the width matches an odd `dim`
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)
    return emb
class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention over token sequences.

    Queries are projected from ``x``. Keys and values are projected from
    ``context``, or from ``x`` itself when no context is given
    (self-attention). ``context`` may also be a ``(key_source, value_source)``
    pair: keys come from the first entry, values from the second.

    When ``ptp_sa_replace`` is enabled, self-attention maps of batch groups
    0 and 2 are copied onto groups 1 and 3; the batch is assumed to consist
    of exactly 4 groups.
    """

    def __init__(self,
                 query_dim,
                 context_dim=None,
                 heads=8,
                 dim_head=64,
                 dropout=0.):
        super().__init__()
        inner_dim = heads * dim_head
        kv_dim = default(context_dim, query_dim)
        self.scale = dim_head**-0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(kv_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(kv_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        # prompt-to-prompt self-attention replacement switch
        self.ptp_sa_replace = False
        self.num_frames = 1  # for ptp sa replacement use

    def forward(self, x, context=None, mask=None):
        n_heads = self.heads
        is_self_attn = context is None
        q = self.to_q(x)
        context = default(context, x)
        if isinstance(context, (list, tuple)):
            # old prompt's mapping drives the keys, new prompt the values
            key_src, value_src = context[0], context[1]
        else:
            key_src = value_src = context
        k = self.to_k(key_src)
        v = self.to_v(value_src)
        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=n_heads)
                   for t in (q, k, v))

        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
        del q, k

        if is_self_attn and self.ptp_sa_replace:
            spatial = x.shape[0] < x.shape[1]
            if spatial:
                sim = rearrange(sim, '(b f h) l d -> b f h l d', b=4, f=self.num_frames, h=n_heads)
            else:
                sim = rearrange(sim, '(b l) f d -> b l f d', b=4)
            groups = sim.chunk(4)
            sim = torch.cat((groups[0], groups[0], groups[2], groups[2]))
            if spatial:
                sim = rearrange(sim, 'b f h l d -> (b f h) l d')
            else:
                sim = rearrange(sim, 'b l f d -> (b l) f d')

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            fill_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=n_heads)
            sim.masked_fill_(~mask, fill_value)

        attn = sim.softmax(dim=-1)
        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=n_heads)
        return self.to_out(out)
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data in spatial axis.

    Input/output shape is ``(b, c, h, w)``. First, normalize and project the
    input (aka embedding) and reshape to ``(b, h*w, inner_dim)``. Then apply
    the transformer blocks. Finally, project back, reshape to an image and
    add the input residually.

    NEW: use_linear for more efficiency instead of the 1x1 convs.

    :param in_channels: channel count of the input.
    :param n_heads: attention heads.
    :param d_head: width per head; inner width is ``n_heads * d_head``.
    :param depth: number of BasicTransformerBlocks.
    :param context_dim: cross-attention context size, either one value shared
        by all blocks or a list with one entry per block. ``None`` makes the
        blocks attend to the features themselves.
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 disable_self_attn=False,
                 use_linear=False,
                 use_checkpoint=True):
        super().__init__()
        # One context size per block; a single value (including None) is
        # shared, so depth > 1 and context-free use no longer fail.
        if not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim[d],
                disable_self_attn=disable_self_attn,
                checkpoint=use_checkpoint) for d in range(depth)
        ])
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(
                    inner_dim, in_channels, kernel_size=1, stride=1,
                    padding=0))
        else:
            # maps inner_dim back to in_channels (arguments were swapped before;
            # identical shapes whenever inner_dim == in_channels)
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context] * len(self.transformer_blocks)
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
class TemporalTransformer(nn.Module):
    """
    Transformer block for video features along the temporal axis.

    Input/output shape is ``(b, c, f, h, w)``. Every spatial location becomes
    an independent sequence of ``f`` tokens: normalize, project in, apply the
    transformer blocks, project out and add the input back.

    :param in_channels: channel count ``c`` of the input.
    :param n_heads: attention heads.
    :param d_head: width per head; inner width is ``n_heads * d_head``.
    :param depth: number of BasicTransformerBlocks.
    :param context_dim: cross-attention context size (one value shared by all
        blocks, or a list with one per block). Ignored when ``only_self_att``.
    :param disable_self_attn: forwarded to the BasicTransformerBlocks.
    :param use_linear: use ``nn.Linear`` projections instead of 1x1 ``Conv1d``.
    :param only_self_att: ignore any context and run pure self-attention.
    :param multiply_zero: multiply the transformer branch by zero, so the
        output equals the input (used for image-only training).
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 disable_self_attn=False,
                 use_linear=False,
                 use_checkpoint=True,
                 only_self_att=True,
                 multiply_zero=False):
        super().__init__()
        self.multiply_zero = multiply_zero
        self.only_self_att = only_self_att
        if self.only_self_att:
            context_dim = None
        # one context size per block; a single value is shared by all blocks
        if not isinstance(context_dim, list):
            context_dim = [context_dim] * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        if not use_linear:
            self.proj_in = nn.Conv1d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim[d],
                disable_self_attn=disable_self_attn,
                checkpoint=use_checkpoint) for d in range(depth)
        ])
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv1d(
                    inner_dim, in_channels, kernel_size=1, stride=1,
                    padding=0))
        else:
            # maps inner_dim back to in_channels (arguments were swapped before)
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if self.only_self_att:
            context = None
        if not isinstance(context, list):
            context = [context] * len(self.transformer_blocks)
        b, c, f, h, w = x.shape
        x_in = x
        x = self.norm(x)

        # Fold every spatial location into the batch: x -> ((b h w), f, inner_dim)
        if self.use_linear:
            x = rearrange(x, 'b c f h w -> (b h w) f c').contiguous()
            x = self.proj_in(x)
        else:
            x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
            x = self.proj_in(x)
            x = rearrange(x, 'bhw c f -> bhw f c').contiguous()

        if self.only_self_att:
            for block in self.transformer_blocks:
                x = block(x)
            x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
        else:
            x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
            for i, block in enumerate(self.transformer_blocks):
                # Context arrives with frames folded into the batch axis; the
                # frame count comes from the input (self.frames never existed).
                # A local is used so the caller's context list is not mutated.
                ctx = rearrange(
                    context[i], '(b f) l con -> b f l con', f=f).contiguous()
                # calculate each batch one by one (since number in shape could
                # not be greater than 65,535 for some package)
                for j in range(b):
                    ctx_j = repeat(
                        ctx[j],
                        'f l con -> (f r) l con',
                        r=(h * w) // f,
                        f=f).contiguous()
                    x[j] = block(x[j], context=ctx_j)

        # x is now (b, h*w, f, inner_dim)
        if self.use_linear:
            x = self.proj_out(x)
            x = rearrange(x, 'b (h w) f c -> b c f h w', h=h, w=w).contiguous()
        else:
            x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
            x = self.proj_out(x)
            x = rearrange(
                x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()

        if self.multiply_zero:
            x = 0.0 * x + x_in
        else:
            x = x + x_in
        return x
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention, cross-attention, feed-forward.

    Each sub-layer sees a LayerNorm-ed copy of the running features, and its
    output is added back residually. ``attn1`` is self-attention unless
    ``disable_self_attn`` is set, in which case it also attends to
    ``context``. ``attn2`` attends to ``context``, or to the features when no
    context is given. ``checkpoint`` is stored but not used here.
    """

    def __init__(self,
                 dim,
                 n_heads,
                 d_head,
                 dropout=0.,
                 context_dim=None,
                 gated_ff=True,
                 checkpoint=True,
                 disable_self_attn=False):
        super().__init__()
        self.disable_self_attn = disable_self_attn
        # attn1 only receives a context size when it acts as cross-attention
        attn1_context_dim = context_dim if disable_self_attn else None
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=attn1_context_dim)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        attn1_context = context if self.disable_self_attn else None
        x = self.attn1(self.norm1(x), context=attn1_context) + x
        x = self.attn2(self.norm2(x), context=context) + x
        return self.ff(self.norm3(x)) + x
# feedforward
class GEGLU(nn.Module):
    """GELU-gated linear unit.

    Projects to ``2 * dim_out`` features, then multiplies the first half by
    the GELU of the second half.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, 2 * dim_out)

    def forward(self, x):
        value, gate = torch.chunk(self.proj(x), 2, dim=-1)
        return value * F.gelu(gate)
def zero_module(module):
    """Set every parameter of *module* to zero in place and return *module*."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
class FeedForward(nn.Module):
    """Position-wise MLP: input projection (GELU or GEGLU), dropout, output projection.

    The hidden width is ``int(dim * mult)``. The output width defaults to
    ``dim``.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        dim_out = default(dim_out, dim)
        if glu:
            project_in = GEGLU(dim, hidden)
        else:
            project_in = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())
        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(hidden, dim_out),
        )

    def forward(self, x):
        return self.net(x)
class Upsample(nn.Module):
    """
    A nearest-neighbour 2x upsampling layer with an optional convolution.

    :param channels: channels in the inputs (and outputs, unless out_channels).
    :param use_conv: a bool determining if a 3-wide convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    :param out_channels: channels produced by the convolution.
    :param padding: padding of the convolution.
    :raises ValueError: if use_conv is set and dims is not 1, 2 or 3.
    """

    def __init__(self,
                 channels,
                 use_conv,
                 dims=2,
                 out_channels=None,
                 padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            # Match the convolution rank to the signal rank; a Conv2d was
            # previously built for every `dims`, which fails on 1D/3D inputs.
            conv_cls = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}.get(dims)
            if conv_cls is None:
                raise ValueError(f'unsupported dimensions: {dims}')
            self.conv = conv_cls(
                self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # keep the first (temporal) axis, double the inner two
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
                mode='nearest')
        else:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.use_conv:
            x = self.conv(x)
        return x
class ResBlock(nn.Module):
    """
    A residual block that can optionally change the number of channels.

    It operates on frames folded into the batch axis (``(b f) c h w``). When
    ``use_temporal_conv`` is set, it finishes with a TemporalConvBlock_v2
    applied in ``b c f h w`` layout.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout (spatial path only; the temporal conv
        block is built with a fixed dropout of 0.1).
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: if True, the embedding yields a (scale, shift)
        pair applied after the first norm of ``out_layers``; otherwise it is
        added to the features.
    :param dims: determines if the signal is 1D, 2D, or 3D. Only the up/down
        resamplers and the ``use_conv`` skip honor it; the main convolutions
        are always ``nn.Conv2d``.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    :param use_temporal_conv: if True, use the temporal convolution.
    :param use_image_dataset: forwarded to TemporalConvBlock_v2, which then
        multiplies its convolution branch by zero.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
        use_temporal_conv=True,
        use_image_dataset=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.use_temporal_conv = use_temporal_conv

        # norm -> activation -> 3x3 conv that maps to out_channels
        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            nn.Conv2d(channels, self.out_channels, 3, padding=1),
        )

        # optional parameter-free resampling of both the main path and the skip
        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # timestep embedding -> per-channel bias, or (scale, shift) pairs
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                2 * self.out_channels
                if use_scale_shift_norm else self.out_channels,
            ),
        )
        # the last conv is zero-initialised, so at init the residual branch
        # outputs zeros and h equals the skip connection
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)

        if self.use_temporal_conv:
            # NOTE: attribute name is misspelled ("temopral") but is kept as-is,
            # since renaming it would change the module's state-dict keys.
            # NOTE(review): dropout is fixed at 0.1 here rather than `dropout`;
            # confirm this is intended.
            self.temopral_conv = TemporalConvBlock_v2(
                self.out_channels,
                self.out_channels,
                dropout=0.1,
                use_image_dataset=use_image_dataset)

    def forward(self, x, emb, batch_size):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features, where N = b * f (frames
            folded into the batch axis).
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :param batch_size: the number of videos b, used to unfold N into
            (b, f) for the temporal convolution.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return self._forward(x, emb, batch_size)

    def _forward(self, x, emb, batch_size):
        if self.updown:
            # resample between the activation and the conv of in_layers
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # broadcast the embedding over the spatial axes
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        h = self.skip_connection(x) + h
        if self.use_temporal_conv:
            h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size)
            h = self.temopral_conv(h)
            h = rearrange(h, 'b c f h w -> (b f) c h w')
        return h
class Downsample(nn.Module):
    """
    A 2x downsampling layer with an optional convolution.

    :param channels: channels in the inputs (and outputs, unless out_channels).
    :param use_conv: a bool determining if a strided convolution is applied
        (otherwise average pooling is used).
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    :param out_channels: channels produced by the convolution; must equal
        channels when pooling.
    :param padding: padding of the convolution.
    :raises ValueError: if use_conv is set and dims is not 1, 2 or 3.
    """

    def __init__(self,
                 channels,
                 use_conv,
                 dims=2,
                 out_channels=None,
                 padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D signals keep their first (temporal) axis and halve the inner two
        stride = 2 if dims != 3 else (1, 2, 2)
        if self.use_conv:
            # Match the convolution rank to the signal rank; a Conv2d was
            # previously built for every `dims`, which fails with the 3-tuple
            # stride used for 3D signals.
            conv_cls = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}.get(dims)
            if conv_cls is None:
                raise ValueError(f'unsupported dimensions: {dims}')
            self.op = conv_cls(
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding)
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
class Resample(nn.Module):
    """Parameter-free spatial resampling over the last two axes.

    * ``'upsample'``: nearest-neighbour resize to ``reference.shape[-2:]``.
    * ``'downsample'``: adaptive average pooling to half of H and W.
    * ``'none'``: identity.
    """

    def __init__(self, in_dim, out_dim, mode):
        assert mode in ['none', 'upsample', 'downsample']
        super(Resample, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.mode = mode

    def forward(self, x, reference=None):
        if self.mode == 'downsample':
            half_size = tuple(u // 2 for u in x.shape[-2:])
            return F.adaptive_avg_pool2d(x, output_size=half_size)
        if self.mode == 'upsample':
            assert reference is not None
            return F.interpolate(x, size=reference.shape[-2:], mode='nearest')
        return x
class ResidualBlock(nn.Module):
    """Conv residual block conditioned on an embedding, with optional resampling.

    ``mode`` selects a parameter-free Resample applied both to the skip input
    and between the activation and conv of the first layer. The embedding is
    added to the features, or applied as (scale, shift) after the second norm
    when ``use_scale_shift_norm`` is set. The last conv weight is zeroed.
    """

    def __init__(self,
                 in_dim,
                 embed_dim,
                 out_dim,
                 use_scale_shift_norm=True,
                 mode='none',
                 dropout=0.0):
        super(ResidualBlock, self).__init__()
        self.in_dim = in_dim
        self.embed_dim = embed_dim
        self.out_dim = out_dim
        self.use_scale_shift_norm = use_scale_shift_norm
        self.mode = mode

        # layers
        self.layer1 = nn.Sequential(
            nn.GroupNorm(32, in_dim),
            nn.SiLU(),
            nn.Conv2d(in_dim, out_dim, 3, padding=1),
        )
        self.resample = Resample(in_dim, in_dim, mode)
        emb_width = out_dim * 2 if use_scale_shift_norm else out_dim
        self.embedding = nn.Sequential(nn.SiLU(), nn.Linear(embed_dim, emb_width))
        self.layer2 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_dim, out_dim, 3, padding=1),
        )
        if in_dim == out_dim:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_dim, out_dim, 1)

        # zero out the last layer params
        nn.init.zeros_(self.layer2[-1].weight)

    def forward(self, x, e, reference=None):
        skip = self.resample(x, reference)
        pre_conv, conv = self.layer1[:-1], self.layer1[-1]
        x = conv(self.resample(pre_conv(x), reference))
        e = self.embedding(e)[..., None, None].type(x.dtype)
        if self.use_scale_shift_norm:
            scale, shift = e.chunk(2, dim=1)
            norm, rest = self.layer2[0], self.layer2[1:]
            x = rest(norm(x) * (1 + scale) + shift)
        else:
            x = self.layer2(x + e)
        return x + self.shortcut(skip)
class AttentionBlock(nn.Module):
    """Multi-head self-attention over the pixels of a ``(B, C, H, W)`` map.

    Optional ``context`` tokens ``(B, L, context_dim)`` are projected to extra
    keys/values and prepended to the pixel keys/values. ``head_dim`` takes
    precedence over ``num_heads`` when both are given. The output projection
    weight is zero-initialised.
    """

    def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
        # consider head_dim first, then num_heads
        if head_dim:
            num_heads = dim // head_dim
        head_dim = dim // num_heads
        assert num_heads * head_dim == dim
        super(AttentionBlock, self).__init__()
        self.dim = dim
        self.context_dim = context_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        # applied to both q and k, giving an overall 1/sqrt(head_dim) scaling
        self.scale = math.pow(head_dim, -0.25)

        # layers
        self.norm = nn.GroupNorm(32, dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        if context_dim is not None:
            self.context_kv = nn.Linear(context_dim, dim * 2)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero out the last layer params
        nn.init.zeros_(self.proj.weight)

    def forward(self, x, context=None):
        r"""x: [B, C, H, W].
            context: [B, L, C] or None.
        """
        residual = x
        b, c, h, w = x.size()
        n, d = self.num_heads, self.head_dim

        # per-head q, k, v of shape [B, n, d, H*W]
        qkv = self.to_qkv(self.norm(x)).view(b, n * 3, d, h * w)
        q, k, v = qkv.chunk(3, dim=1)
        if context is not None:
            ctx_kv = self.context_kv(context).reshape(b, -1, n * 2, d)
            ck, cv = ctx_kv.permute(0, 2, 3, 1).chunk(2, dim=1)
            k = torch.cat([ck, k], dim=-1)
            v = torch.cat([cv, v], dim=-1)

        attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
        attn = F.softmax(attn, dim=-1)

        out = torch.matmul(v, attn.transpose(-1, -2)).reshape(b, c, h, w)
        return self.proj(out) + residual
class TemporalConvBlock_v2(nn.Module):
    """Residual stack of four temporal (3x1x1) convolutions over ``(b, c, f, h, w)``.

    Channels flow in_dim -> out_dim -> out_dim -> out_dim -> in_dim, so the
    residual add always matches the input. The last convolution is
    zero-initialised, which makes the block the identity at init. With
    ``use_image_dataset`` the convolution branch is multiplied by zero.
    """

    def __init__(self,
                 in_dim,
                 out_dim=None,
                 dropout=0.0,
                 use_image_dataset=False):
        super(TemporalConvBlock_v2, self).__init__()
        if out_dim is None:
            out_dim = in_dim  # int(1.5*in_dim)
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.use_image_dataset = use_image_dataset

        # conv layers (conv2/conv3 previously emitted in_dim channels, which
        # crashed the following GroupNorm(out_dim) whenever out_dim != in_dim;
        # shapes are unchanged when out_dim == in_dim)
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(),
            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        if self.use_image_dataset:
            x = identity + 0.0 * x
        else:
            x = identity + x
        return x
def prob_mask_like(shape, prob, device):
    """Return a boolean mask of *shape* whose entries are True with probability *prob*.

    ``prob == 1`` gives all True and ``prob == 0`` all False. Otherwise the
    mask is sampled, and if every entry came out True the first one is
    cleared.
    """
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    draws = torch.zeros(shape, device=device).float().uniform_(0, 1)
    mask = draws < prob
    # avoid masking everything, which would cause a find_unused_parameters error
    if mask.all():
        mask[0] = False
    return mask
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module, forwarding all other arguments.

    Raises ValueError for any other ``dims``.
    """
    for rank, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f'unsupported dimensions: {dims}')
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module, forwarding all other arguments.

    Raises ValueError for any other ``dims``.
    """
    for rank, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f'unsupported dimensions: {dims}')