|
|
|
import logging
import math  # used for batching in MemoryEfficientCrossAttentionWrapperFusion
from typing import Any, Callable, Iterable, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from packaging import version
|
logpy = logging.getLogger(__name__) |
|
|
|
try:
    import xformers
    import xformers.ops

    XFORMERS_IS_AVAILABLE = True
except ImportError:
    XFORMERS_IS_AVAILABLE = False
    logpy.warning("no module 'xformers'. Processing without...")
|
|
|
from lvdm.modules.attention_svd import LinearAttention, MemoryEfficientCrossAttention |
|
|
|
|
|
def nonlinearity(x):
    # swish / SiLU activation
    return x * torch.sigmoid(x)
|
|
|
|
|
def Normalize(in_channels, num_groups=32): |
|
return torch.nn.GroupNorm( |
|
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True |
|
) |
|
|
|
|
|
class ResnetBlock(nn.Module): |
|
def __init__( |
|
self, |
|
*, |
|
in_channels, |
|
out_channels=None, |
|
conv_shortcut=False, |
|
dropout, |
|
temb_channels=512, |
|
): |
|
super().__init__() |
|
self.in_channels = in_channels |
|
out_channels = in_channels if out_channels is None else out_channels |
|
self.out_channels = out_channels |
|
self.use_conv_shortcut = conv_shortcut |
|
|
|
self.norm1 = Normalize(in_channels) |
|
self.conv1 = torch.nn.Conv2d( |
|
in_channels, out_channels, kernel_size=3, stride=1, padding=1 |
|
) |
|
if temb_channels > 0: |
|
self.temb_proj = torch.nn.Linear(temb_channels, out_channels) |
|
self.norm2 = Normalize(out_channels) |
|
self.dropout = torch.nn.Dropout(dropout) |
|
self.conv2 = torch.nn.Conv2d( |
|
out_channels, out_channels, kernel_size=3, stride=1, padding=1 |
|
) |
|
if self.in_channels != self.out_channels: |
|
if self.use_conv_shortcut: |
|
self.conv_shortcut = torch.nn.Conv2d( |
|
in_channels, out_channels, kernel_size=3, stride=1, padding=1 |
|
) |
|
else: |
|
self.nin_shortcut = torch.nn.Conv2d( |
|
in_channels, out_channels, kernel_size=1, stride=1, padding=0 |
|
) |
|
|
|
def forward(self, x, temb): |
|
h = x |
|
h = self.norm1(h) |
|
h = nonlinearity(h) |
|
h = self.conv1(h) |
|
|
|
if temb is not None: |
|
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] |
|
|
|
h = self.norm2(h) |
|
h = nonlinearity(h) |
|
h = self.dropout(h) |
|
h = self.conv2(h) |
|
|
|
if self.in_channels != self.out_channels: |
|
if self.use_conv_shortcut: |
|
x = self.conv_shortcut(x) |
|
else: |
|
x = self.nin_shortcut(x) |
|
|
|
return x + h |
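
# Editor's note (illustrative sketch; shapes and values assumed, not from the
# original code): ResnetBlock keeps the spatial size and maps in_channels to
# out_channels with a residual shortcut, e.g.
#   block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.0, temb_channels=0)
#   y = block(torch.randn(2, 64, 32, 32), temb=None)   # -> (2, 128, 32, 32)
# With temb_channels=0 no temb_proj is created, so temb must be None here.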
|
|
|
|
|
class LinAttnBlock(LinearAttention): |
|
"""to match AttnBlock usage""" |
|
|
|
def __init__(self, in_channels): |
|
super().__init__(dim=in_channels, heads=1, dim_head=in_channels) |
|
|
|
|
|
class AttnBlock(nn.Module): |
|
def __init__(self, in_channels): |
|
super().__init__() |
|
self.in_channels = in_channels |
|
|
|
self.norm = Normalize(in_channels) |
|
self.q = torch.nn.Conv2d( |
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0 |
|
) |
|
self.k = torch.nn.Conv2d( |
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0 |
|
) |
|
self.v = torch.nn.Conv2d( |
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0 |
|
) |
|
self.proj_out = torch.nn.Conv2d( |
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0 |
|
) |
|
|
|
def attention(self, h_: torch.Tensor) -> torch.Tensor: |
|
h_ = self.norm(h_) |
|
q = self.q(h_) |
|
k = self.k(h_) |
|
v = self.v(h_) |
|
|
|
b, c, h, w = q.shape |
|
q, k, v = map( |
|
lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v) |
|
) |
|
h_ = torch.nn.functional.scaled_dot_product_attention( |
|
q, k, v |
|
) |
|
|
|
|
|
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) |
|
|
|
def forward(self, x, **kwargs): |
|
h_ = x |
|
h_ = self.attention(h_) |
|
h_ = self.proj_out(h_) |
|
return x + h_ |
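
# Editor's note: AttnBlock is single-head self-attention over the h*w spatial
# positions using torch.nn.functional.scaled_dot_product_attention (PyTorch >= 2.0).
# It is shape-preserving, (B, C, H, W) -> (B, C, H, W), with a residual
# connection around the attention and the 1x1 output projection.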
|
|
|
|
|
class MemoryEfficientAttnBlock(nn.Module): |
|
""" |
|
Uses xformers efficient implementation, |
|
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 |
|
Note: this is a single-head self-attention operation |
|
""" |
|
|
|
|
|
def __init__(self, in_channels): |
|
super().__init__() |
|
self.in_channels = in_channels |
|
|
|
self.norm = Normalize(in_channels) |
|
self.q = torch.nn.Conv2d( |
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0 |
|
) |
|
self.k = torch.nn.Conv2d( |
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0 |
|
) |
|
self.v = torch.nn.Conv2d( |
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0 |
|
) |
|
self.proj_out = torch.nn.Conv2d( |
|
in_channels, in_channels, kernel_size=1, stride=1, padding=0 |
|
) |
|
self.attention_op: Optional[Any] = None |
|
|
|
def attention(self, h_: torch.Tensor) -> torch.Tensor: |
|
h_ = self.norm(h_) |
|
q = self.q(h_) |
|
k = self.k(h_) |
|
v = self.v(h_) |
|
|
|
|
|
B, C, H, W = q.shape |
|
q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) |
|
|
|
q, k, v = map( |
|
lambda t: t.unsqueeze(3) |
|
.reshape(B, t.shape[1], 1, C) |
|
.permute(0, 2, 1, 3) |
|
.reshape(B * 1, t.shape[1], C) |
|
.contiguous(), |
|
(q, k, v), |
|
) |
|
out = xformers.ops.memory_efficient_attention( |
|
q, k, v, attn_bias=None, op=self.attention_op |
|
) |
|
|
|
out = ( |
|
out.unsqueeze(0) |
|
.reshape(B, 1, out.shape[1], C) |
|
.permute(0, 2, 1, 3) |
|
.reshape(B, out.shape[1], C) |
|
) |
|
return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) |
|
|
|
def forward(self, x, **kwargs): |
|
h_ = x |
|
h_ = self.attention(h_) |
|
h_ = self.proj_out(h_) |
|
return x + h_ |
|
|
|
|
|
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): |
|
    def forward(self, x, context=None, mask=None, **unused_kwargs):
        b, c, h, w = x.shape
        x_in = x
        x = rearrange(x, "b c h w -> b (h w) c")
        out = super().forward(x, context=context, mask=mask)
        out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
        # residual connection in the original (b, c, h, w) layout; the flattened
        # sequence tensor cannot be added to the rearranged output directly
        return x_in + out
|
|
|
|
|
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): |
|
assert attn_type in [ |
|
"vanilla", |
|
"vanilla-xformers", |
|
"memory-efficient-cross-attn", |
|
"linear", |
|
"none", |
|
"memory-efficient-cross-attn-fusion", |
|
], f"attn_type {attn_type} unknown" |
|
if ( |
|
version.parse(torch.__version__) < version.parse("2.0.0") |
|
and attn_type != "none" |
|
): |
|
assert XFORMERS_IS_AVAILABLE, ( |
|
f"We do not support vanilla attention in {torch.__version__} anymore, " |
|
f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'" |
|
) |
|
|
|
logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels") |
|
if attn_type == "vanilla": |
|
assert attn_kwargs is None |
|
return AttnBlock(in_channels) |
|
elif attn_type == "vanilla-xformers": |
|
logpy.info( |
|
f"building MemoryEfficientAttnBlock with {in_channels} in_channels..." |
|
) |
|
return MemoryEfficientAttnBlock(in_channels) |
|
elif attn_type == "memory-efficient-cross-attn": |
|
attn_kwargs["query_dim"] = in_channels |
|
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) |
|
elif attn_type == "memory-efficient-cross-attn-fusion": |
|
attn_kwargs["query_dim"] = in_channels |
|
return MemoryEfficientCrossAttentionWrapperFusion(**attn_kwargs) |
|
elif attn_type == "none": |
|
return nn.Identity(in_channels) |
|
else: |
|
return LinAttnBlock(in_channels) |
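
# Editor's usage sketch (channel count assumed, not from the original code):
#   attn = make_attn(512, attn_type="vanilla")           # AttnBlock (SDPA)
#   attn = make_attn(512, attn_type="vanilla-xformers")  # MemoryEfficientAttnBlock
#   attn = make_attn(512, attn_type="none")              # nn.Identity
# For the cross-attention variants, attn_kwargs must be a dict; "query_dim" is
# filled in with in_channels before the wrapper is constructed.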
|
|
|
class MemoryEfficientCrossAttentionWrapperFusion(MemoryEfficientCrossAttention): |
|
|
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0, **kwargs): |
|
super().__init__(query_dim, context_dim, heads, dim_head, dropout, **kwargs) |
|
self.norm = Normalize(query_dim) |
|
nn.init.zeros_(self.to_out[0].weight) |
|
nn.init.zeros_(self.to_out[0].bias) |
|
|
|
def forward(self, x, context=None, mask=None): |
|
if self.training: |
|
return checkpoint(self._forward, x, context, mask, use_reentrant=False) |
|
else: |
|
return self._forward(x, context, mask) |
|
|
|
def _forward( |
|
self, |
|
x, |
|
context=None, |
|
mask=None, |
|
): |
|
bt, c, h, w = x.shape |
|
h_ = self.norm(x) |
|
h_ = rearrange(h_, "b c h w -> b (h w) c") |
|
q = self.to_q(h_) |
|
|
|
|
|
b, c, l, h, w = context.shape |
|
context = rearrange(context, "b c l h w -> (b l) (h w) c") |
|
k = self.to_k(context) |
|
v = self.to_v(context) |
|
k = rearrange(k, "(b l) d c -> b l d c", l=l) |
|
k = torch.cat([k[:, [0] * (bt//b)], k[:, [1]*(bt//b)]], dim=2) |
|
k = rearrange(k, "b l d c -> (b l) d c") |
|
|
|
v = rearrange(v, "(b l) d c -> b l d c", l=l) |
|
v = torch.cat([v[:, [0] * (bt//b)], v[:, [1]*(bt//b)]], dim=2) |
|
v = rearrange(v, "b l d c -> (b l) d c") |
|
|
|
|
|
b, _, _ = q.shape |
|
q, k, v = map( |
|
lambda t: t.unsqueeze(3) |
|
.reshape(b, t.shape[1], self.heads, self.dim_head) |
|
.permute(0, 2, 1, 3) |
|
.reshape(b * self.heads, t.shape[1], self.dim_head) |
|
.contiguous(), |
|
(q, k, v), |
|
) |
|
|
|
|
|
        if version.parse(xformers.__version__) >= version.parse("0.0.21"):
            # newer xformers builds limit the attention batch size, so process
            # very large (batch * heads) inputs in chunks of at most max_bs
            max_bs = 32768
|
N = q.shape[0] |
|
n_batches = math.ceil(N / max_bs) |
|
out = list() |
|
for i_batch in range(n_batches): |
|
batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs) |
|
out.append( |
|
xformers.ops.memory_efficient_attention( |
|
q[batch], |
|
k[batch], |
|
v[batch], |
|
attn_bias=None, |
|
op=self.attention_op, |
|
) |
|
) |
|
out = torch.cat(out, 0) |
|
else: |
|
out = xformers.ops.memory_efficient_attention( |
|
q, k, v, attn_bias=None, op=self.attention_op |
|
) |
|
|
|
|
|
if exists(mask): |
|
raise NotImplementedError |
|
out = ( |
|
out.unsqueeze(0) |
|
.reshape(b, self.heads, out.shape[1], self.dim_head) |
|
.permute(0, 2, 1, 3) |
|
.reshape(b, out.shape[1], self.heads * self.dim_head) |
|
) |
|
out = self.to_out(out) |
|
out = rearrange(out, "bt (h w) c -> bt c h w", h=h, w=w, c=c) |
|
return x + out |
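
# Editor's note on the fusion wrapper above: x is a batch of decoded frames of
# shape ((b * t), c, h, w) and context is (b, c, l, h, w), where entries l=0 and
# l=1 are used as reference frames (so l is expected to be 2, e.g. the first and
# last frame). Their key/value tokens are repeated for every decoded frame, and
# to_out is zero-initialized, so the block starts out as an identity (residual)
# mapping.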
|
|
|
class Combiner(nn.Module): |
|
def __init__(self, ch) -> None: |
|
super().__init__() |
|
        self.conv = nn.Conv2d(ch, ch, 1, padding=0)
|
|
|
nn.init.zeros_(self.conv.weight) |
|
nn.init.zeros_(self.conv.bias) |
|
|
|
def forward(self, x, context): |
|
if self.training: |
|
return checkpoint(self._forward, x, context, use_reentrant=False) |
|
else: |
|
return self._forward(x, context) |
|
|
|
def _forward(self, x, context): |
|
|
|
b, c, l, h, w = context.shape |
|
bt, c, h, w = x.shape |
|
context = rearrange(context, "b c l h w -> (b l) c h w") |
|
context = self.conv(context) |
|
context = rearrange(context, "(b l) c h w -> b c l h w", l=l) |
|
x = rearrange(x, "(b t) c h w -> b c t h w", t=bt//b) |
|
x[:,:,0] = x[:,:,0] + context[:,:,0] |
|
x[:,:,-1] = x[:,:,-1] + context[:,:,1] |
|
x = rearrange(x, "b c t h w -> (b t) c h w") |
|
return x |
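
# Editor's note: Combiner is the lightweight counterpart of the fusion attention
# above. It projects the two reference frames (context[:, :, 0] and
# context[:, :, 1]) with a zero-initialized 1x1 conv and adds them to the first
# and last decoded frame, so at initialization it is a no-op.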
|
|
|
|
|
class Decoder(nn.Module): |
|
def __init__( |
|
self, |
|
*, |
|
ch, |
|
out_ch, |
|
ch_mult=(1, 2, 4, 8), |
|
num_res_blocks, |
|
attn_resolutions, |
|
dropout=0.0, |
|
resamp_with_conv=True, |
|
in_channels, |
|
resolution, |
|
z_channels, |
|
give_pre_end=False, |
|
tanh_out=False, |
|
use_linear_attn=False, |
|
attn_type="vanilla-xformers", |
|
attn_level=[2,3], |
|
**ignorekwargs, |
|
): |
|
super().__init__() |
|
if use_linear_attn: |
|
attn_type = "linear" |
|
self.ch = ch |
|
self.temb_ch = 0 |
|
self.num_resolutions = len(ch_mult) |
|
self.num_res_blocks = num_res_blocks |
|
self.resolution = resolution |
|
self.in_channels = in_channels |
|
self.give_pre_end = give_pre_end |
|
self.tanh_out = tanh_out |
|
self.attn_level = attn_level |
|
|
|
in_ch_mult = (1,) + tuple(ch_mult) |
|
block_in = ch * ch_mult[self.num_resolutions - 1] |
|
curr_res = resolution // 2 ** (self.num_resolutions - 1) |
|
self.z_shape = (1, z_channels, curr_res, curr_res) |
|
logpy.info( |
|
"Working with z of shape {} = {} dimensions.".format( |
|
self.z_shape, np.prod(self.z_shape) |
|
) |
|
) |
|
|
|
make_attn_cls = self._make_attn() |
|
make_resblock_cls = self._make_resblock() |
|
make_conv_cls = self._make_conv() |
|
|
|
self.conv_in = torch.nn.Conv2d( |
|
z_channels, block_in, kernel_size=3, stride=1, padding=1 |
|
) |
|
|
|
|
|
self.mid = nn.Module() |
|
self.mid.block_1 = make_resblock_cls( |
|
in_channels=block_in, |
|
out_channels=block_in, |
|
temb_channels=self.temb_ch, |
|
dropout=dropout, |
|
) |
|
self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type) |
|
self.mid.block_2 = make_resblock_cls( |
|
in_channels=block_in, |
|
out_channels=block_in, |
|
temb_channels=self.temb_ch, |
|
dropout=dropout, |
|
) |
|
|
|
|
|
self.up = nn.ModuleList() |
|
self.attn_refinement = nn.ModuleList() |
|
for i_level in reversed(range(self.num_resolutions)): |
|
block = nn.ModuleList() |
|
attn = nn.ModuleList() |
|
block_out = ch * ch_mult[i_level] |
|
for i_block in range(self.num_res_blocks + 1): |
|
block.append( |
|
make_resblock_cls( |
|
in_channels=block_in, |
|
out_channels=block_out, |
|
temb_channels=self.temb_ch, |
|
dropout=dropout, |
|
) |
|
) |
|
block_in = block_out |
|
if curr_res in attn_resolutions: |
|
attn.append(make_attn_cls(block_in, attn_type=attn_type)) |
|
up = nn.Module() |
|
up.block = block |
|
up.attn = attn |
|
if i_level != 0: |
|
up.upsample = Upsample(block_in, resamp_with_conv) |
|
curr_res = curr_res * 2 |
|
self.up.insert(0, up) |
|
|
|
if i_level in self.attn_level: |
|
self.attn_refinement.insert(0, make_attn_cls(block_in, attn_type='memory-efficient-cross-attn-fusion', attn_kwargs={})) |
|
else: |
|
self.attn_refinement.insert(0, Combiner(block_in)) |
|
|
|
self.norm_out = Normalize(block_in) |
|
self.attn_refinement.append(Combiner(block_in)) |
|
self.conv_out = make_conv_cls( |
|
block_in, out_ch, kernel_size=3, stride=1, padding=1 |
|
) |
|
|
|
def _make_attn(self) -> Callable: |
|
return make_attn |
|
|
|
def _make_resblock(self) -> Callable: |
|
return ResnetBlock |
|
|
|
def _make_conv(self) -> Callable: |
|
return torch.nn.Conv2d |
|
|
|
def get_last_layer(self, **kwargs): |
|
return self.conv_out.weight |
|
|
|
def forward(self, z, ref_context=None, **kwargs): |
|
|
|
|
|
        # NOTE: ref_context is overridden to None here, which disables the
        # reference-frame refinement (attn_refinement / Combiner) branches below.
        ref_context = None
|
self.last_z_shape = z.shape |
|
|
|
temb = None |
|
|
|
|
|
h = self.conv_in(z) |
|
|
|
|
|
h = self.mid.block_1(h, temb, **kwargs) |
|
h = self.mid.attn_1(h, **kwargs) |
|
h = self.mid.block_2(h, temb, **kwargs) |
|
|
|
|
|
for i_level in reversed(range(self.num_resolutions)): |
|
for i_block in range(self.num_res_blocks + 1): |
|
h = self.up[i_level].block[i_block](h, temb, **kwargs) |
|
if len(self.up[i_level].attn) > 0: |
|
h = self.up[i_level].attn[i_block](h, **kwargs) |
|
if ref_context: |
|
h = self.attn_refinement[i_level](x=h, context=ref_context[i_level]) |
|
if i_level != 0: |
|
h = self.up[i_level].upsample(h) |
|
|
|
|
|
if self.give_pre_end: |
|
return h |
|
|
|
h = self.norm_out(h) |
|
h = nonlinearity(h) |
|
if ref_context: |
|
|
|
h = self.attn_refinement[-1](x=h, context=ref_context[-1]) |
|
h = self.conv_out(h, **kwargs) |
|
if self.tanh_out: |
|
h = torch.tanh(h) |
|
return h |
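
# Editor's usage sketch (illustrative values, assumes PyTorch >= 2.0; not from
# the original code):
#   dec = Decoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
#                 attn_resolutions=[], in_channels=3, resolution=256,
#                 z_channels=4, dropout=0.0, attn_type="vanilla")
#   x = dec(torch.randn(1, 4, 32, 32))   # -> (1, 3, 256, 256)
# When the reference-frame path is enabled, ref_context is expected to be a list
# with one entry per resolution level plus a final entry, matching
# self.attn_refinement.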
|
|
|
|
|
|
|
|
|
from abc import abstractmethod |
|
from lvdm.models.utils_diffusion import timestep_embedding |
|
|
|
from torch.utils.checkpoint import checkpoint |
|
from lvdm.basics import ( |
|
zero_module, |
|
conv_nd, |
|
linear, |
|
normalization, |
|
) |
|
from lvdm.modules.networks.openaimodel3d import Upsample, Downsample |
|
class TimestepBlock(nn.Module): |
|
""" |
|
Any module where forward() takes timestep embeddings as a second argument. |
|
""" |
|
|
|
@abstractmethod |
|
def forward(self, x: torch.Tensor, emb: torch.Tensor): |
|
""" |
|
Apply the module to `x` given `emb` timestep embeddings. |
|
""" |
|
|
|
class ResBlock(TimestepBlock): |
|
""" |
|
A residual block that can optionally change the number of channels. |
|
:param channels: the number of input channels. |
|
:param emb_channels: the number of timestep embedding channels. |
|
:param dropout: the rate of dropout. |
|
:param out_channels: if specified, the number of out channels. |
|
:param use_conv: if True and out_channels is specified, use a spatial |
|
convolution instead of a smaller 1x1 convolution to change the |
|
channels in the skip connection. |
|
:param dims: determines if the signal is 1D, 2D, or 3D. |
|
:param use_checkpoint: if True, use gradient checkpointing on this module. |
|
:param up: if True, use this block for upsampling. |
|
:param down: if True, use this block for downsampling. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
channels: int, |
|
emb_channels: int, |
|
dropout: float, |
|
out_channels: Optional[int] = None, |
|
use_conv: bool = False, |
|
use_scale_shift_norm: bool = False, |
|
dims: int = 2, |
|
use_checkpoint: bool = False, |
|
up: bool = False, |
|
down: bool = False, |
|
kernel_size: int = 3, |
|
exchange_temb_dims: bool = False, |
|
skip_t_emb: bool = False, |
|
): |
|
super().__init__() |
|
self.channels = channels |
|
self.emb_channels = emb_channels |
|
self.dropout = dropout |
|
self.out_channels = out_channels or channels |
|
self.use_conv = use_conv |
|
self.use_checkpoint = use_checkpoint |
|
self.use_scale_shift_norm = use_scale_shift_norm |
|
self.exchange_temb_dims = exchange_temb_dims |
|
|
|
if isinstance(kernel_size, Iterable): |
|
padding = [k // 2 for k in kernel_size] |
|
else: |
|
padding = kernel_size // 2 |
|
|
|
self.in_layers = nn.Sequential( |
|
normalization(channels), |
|
nn.SiLU(), |
|
conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), |
|
) |
|
|
|
self.updown = up or down |
|
|
|
if up: |
|
self.h_upd = Upsample(channels, False, dims) |
|
self.x_upd = Upsample(channels, False, dims) |
|
elif down: |
|
self.h_upd = Downsample(channels, False, dims) |
|
self.x_upd = Downsample(channels, False, dims) |
|
else: |
|
self.h_upd = self.x_upd = nn.Identity() |
|
|
|
self.skip_t_emb = skip_t_emb |
|
self.emb_out_channels = ( |
|
2 * self.out_channels if use_scale_shift_norm else self.out_channels |
|
) |
|
if self.skip_t_emb: |
|
|
|
assert not self.use_scale_shift_norm |
|
self.emb_layers = None |
|
self.exchange_temb_dims = False |
|
else: |
|
self.emb_layers = nn.Sequential( |
|
nn.SiLU(), |
|
linear( |
|
emb_channels, |
|
self.emb_out_channels, |
|
), |
|
) |
|
|
|
self.out_layers = nn.Sequential( |
|
normalization(self.out_channels), |
|
nn.SiLU(), |
|
nn.Dropout(p=dropout), |
|
zero_module( |
|
conv_nd( |
|
dims, |
|
self.out_channels, |
|
self.out_channels, |
|
kernel_size, |
|
padding=padding, |
|
) |
|
), |
|
) |
|
|
|
if self.out_channels == channels: |
|
self.skip_connection = nn.Identity() |
|
elif use_conv: |
|
self.skip_connection = conv_nd( |
|
dims, channels, self.out_channels, kernel_size, padding=padding |
|
) |
|
else: |
|
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) |
|
|
|
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: |
|
""" |
|
Apply the block to a Tensor, conditioned on a timestep embedding. |
|
:param x: an [N x C x ...] Tensor of features. |
|
:param emb: an [N x emb_channels] Tensor of timestep embeddings. |
|
:return: an [N x C x ...] Tensor of outputs. |
|
""" |
|
if self.use_checkpoint: |
|
return checkpoint(self._forward, x, emb, use_reentrant=False) |
|
else: |
|
return self._forward(x, emb) |
|
|
|
def _forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: |
|
if self.updown: |
|
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] |
|
h = in_rest(x) |
|
h = self.h_upd(h) |
|
x = self.x_upd(x) |
|
h = in_conv(h) |
|
else: |
|
h = self.in_layers(x) |
|
|
|
if self.skip_t_emb: |
|
emb_out = torch.zeros_like(h) |
|
else: |
|
emb_out = self.emb_layers(emb).type(h.dtype) |
|
while len(emb_out.shape) < len(h.shape): |
|
emb_out = emb_out[..., None] |
|
if self.use_scale_shift_norm: |
|
out_norm, out_rest = self.out_layers[0], self.out_layers[1:] |
|
scale, shift = torch.chunk(emb_out, 2, dim=1) |
|
h = out_norm(h) * (1 + scale) + shift |
|
h = out_rest(h) |
|
else: |
|
if self.exchange_temb_dims: |
|
emb_out = rearrange(emb_out, "b t c ... -> b c t ...") |
|
h = h + emb_out |
|
h = self.out_layers(h) |
|
return self.skip_connection(x) + h |
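
# Editor's note: with use_scale_shift_norm=True the embedding is split into
# (scale, shift) and applied FiLM-style, h = norm(h) * (1 + scale) + shift;
# otherwise it is simply added to the features. skip_t_emb=True (as used by
# VideoResBlock below) replaces the embedding output with zeros, i.e. no
# timestep conditioning.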
|
|
|
|
|
|
|
from lvdm.modules.attention_svd import * |
|
class VideoTransformerBlock(nn.Module): |
|
ATTENTION_MODES = { |
|
"softmax": CrossAttention, |
|
"softmax-xformers": MemoryEfficientCrossAttention, |
|
} |
|
|
|
def __init__( |
|
self, |
|
dim, |
|
n_heads, |
|
d_head, |
|
dropout=0.0, |
|
context_dim=None, |
|
gated_ff=True, |
|
checkpoint=True, |
|
timesteps=None, |
|
ff_in=False, |
|
inner_dim=None, |
|
attn_mode="softmax", |
|
disable_self_attn=False, |
|
disable_temporal_crossattention=False, |
|
switch_temporal_ca_to_sa=False, |
|
): |
|
super().__init__() |
|
|
|
attn_cls = self.ATTENTION_MODES[attn_mode] |
|
|
|
self.ff_in = ff_in or inner_dim is not None |
|
if inner_dim is None: |
|
inner_dim = dim |
|
|
|
assert int(n_heads * d_head) == inner_dim |
|
|
|
self.is_res = inner_dim == dim |
|
|
|
if self.ff_in: |
|
self.norm_in = nn.LayerNorm(dim) |
|
self.ff_in = FeedForward( |
|
dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff |
|
) |
|
|
|
self.timesteps = timesteps |
|
self.disable_self_attn = disable_self_attn |
|
if self.disable_self_attn: |
|
self.attn1 = attn_cls( |
|
query_dim=inner_dim, |
|
heads=n_heads, |
|
dim_head=d_head, |
|
context_dim=context_dim, |
|
dropout=dropout, |
|
) |
|
else: |
|
self.attn1 = attn_cls( |
|
query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout |
|
) |
|
|
|
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff) |
|
|
|
if disable_temporal_crossattention: |
|
if switch_temporal_ca_to_sa: |
|
raise ValueError |
|
else: |
|
self.attn2 = None |
|
else: |
|
self.norm2 = nn.LayerNorm(inner_dim) |
|
if switch_temporal_ca_to_sa: |
|
self.attn2 = attn_cls( |
|
query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout |
|
) |
|
else: |
|
self.attn2 = attn_cls( |
|
query_dim=inner_dim, |
|
context_dim=context_dim, |
|
heads=n_heads, |
|
dim_head=d_head, |
|
dropout=dropout, |
|
) |
|
|
|
self.norm1 = nn.LayerNorm(inner_dim) |
|
self.norm3 = nn.LayerNorm(inner_dim) |
|
self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa |
|
|
|
self.checkpoint = checkpoint |
|
if self.checkpoint: |
|
print(f"====>{self.__class__.__name__} is using checkpointing") |
|
else: |
|
print(f"====>{self.__class__.__name__} is NOT using checkpointing") |
|
|
|
    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        timesteps: Optional[int] = None,
    ) -> torch.Tensor:
|
if self.checkpoint: |
|
return checkpoint(self._forward, x, context, timesteps, use_reentrant=False) |
|
else: |
|
return self._forward(x, context, timesteps=timesteps) |
|
|
|
def _forward(self, x, context=None, timesteps=None): |
|
assert self.timesteps or timesteps |
|
assert not (self.timesteps and timesteps) or self.timesteps == timesteps |
|
timesteps = self.timesteps or timesteps |
|
B, S, C = x.shape |
|
x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps) |
|
|
|
if self.ff_in: |
|
x_skip = x |
|
x = self.ff_in(self.norm_in(x)) |
|
if self.is_res: |
|
x += x_skip |
|
|
|
if self.disable_self_attn: |
|
x = self.attn1(self.norm1(x), context=context) + x |
|
else: |
|
x = self.attn1(self.norm1(x)) + x |
|
|
|
if self.attn2 is not None: |
|
if self.switch_temporal_ca_to_sa: |
|
x = self.attn2(self.norm2(x)) + x |
|
else: |
|
x = self.attn2(self.norm2(x), context=context) + x |
|
x_skip = x |
|
x = self.ff(self.norm3(x)) |
|
if self.is_res: |
|
x += x_skip |
|
|
|
x = rearrange( |
|
x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps |
|
) |
|
return x |
|
|
|
def get_last_layer(self): |
|
return self.ff.net[-1].weight |
|
|
|
|
|
|
|
|
|
import functools |
|
def partialclass(cls, *args, **kwargs): |
|
class NewCls(cls): |
|
__init__ = functools.partialmethod(cls.__init__, *args, **kwargs) |
|
|
|
return NewCls |
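
# Editor's example (names assumed): partialclass pre-binds constructor
# arguments, e.g.
#   FixedBlock = partialclass(VideoBlock, alpha=0.5, merge_strategy="fixed")
#   block = FixedBlock(in_channels=64)
# which is how make_time_attn and VideoDecoder below specialize their blocks.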
|
|
|
|
|
class VideoResBlock(ResnetBlock): |
|
def __init__( |
|
self, |
|
out_channels, |
|
*args, |
|
dropout=0.0, |
|
video_kernel_size=3, |
|
alpha=0.0, |
|
merge_strategy="learned", |
|
**kwargs, |
|
): |
|
super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs) |
|
if video_kernel_size is None: |
|
video_kernel_size = [3, 1, 1] |
|
self.time_stack = ResBlock( |
|
channels=out_channels, |
|
emb_channels=0, |
|
dropout=dropout, |
|
dims=3, |
|
use_scale_shift_norm=False, |
|
use_conv=False, |
|
up=False, |
|
down=False, |
|
kernel_size=video_kernel_size, |
|
use_checkpoint=True, |
|
skip_t_emb=True, |
|
) |
|
|
|
self.merge_strategy = merge_strategy |
|
if self.merge_strategy == "fixed": |
|
self.register_buffer("mix_factor", torch.Tensor([alpha])) |
|
elif self.merge_strategy == "learned": |
|
self.register_parameter( |
|
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) |
|
) |
|
else: |
|
raise ValueError(f"unknown merge strategy {self.merge_strategy}") |
|
|
|
def get_alpha(self, bs): |
|
if self.merge_strategy == "fixed": |
|
return self.mix_factor |
|
elif self.merge_strategy == "learned": |
|
return torch.sigmoid(self.mix_factor) |
|
else: |
|
raise NotImplementedError() |
|
|
|
def forward(self, x, temb, skip_video=False, timesteps=None): |
|
if timesteps is None: |
|
timesteps = self.timesteps |
|
|
|
b, c, h, w = x.shape |
|
|
|
x = super().forward(x, temb) |
|
|
|
if not skip_video: |
|
x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) |
|
|
|
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) |
|
|
|
x = self.time_stack(x, temb) |
|
|
|
alpha = self.get_alpha(bs=b // timesteps) |
|
x = alpha * x + (1.0 - alpha) * x_mix |
|
|
|
x = rearrange(x, "b c t h w -> (b t) c h w") |
|
return x |
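
# Editor's note: VideoResBlock first runs the usual per-frame 2D ResnetBlock,
# then a 3D ResBlock over the time axis, and blends the two paths with
# alpha = sigmoid(mix_factor) (or a fixed alpha):
#   x = alpha * temporal + (1 - alpha) * spatial.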
|
|
|
|
|
class AE3DConv(torch.nn.Conv2d): |
|
def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): |
|
super().__init__(in_channels, out_channels, *args, **kwargs) |
|
if isinstance(video_kernel_size, Iterable): |
|
padding = [int(k // 2) for k in video_kernel_size] |
|
else: |
|
padding = int(video_kernel_size // 2) |
|
|
|
self.time_mix_conv = torch.nn.Conv3d( |
|
in_channels=out_channels, |
|
out_channels=out_channels, |
|
kernel_size=video_kernel_size, |
|
padding=padding, |
|
) |
|
|
|
def forward(self, input, timesteps, skip_video=False): |
|
x = super().forward(input) |
|
if skip_video: |
|
return x |
|
x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) |
|
x = self.time_mix_conv(x) |
|
return rearrange(x, "b c t h w -> (b t) c h w") |
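
# Editor's note: AE3DConv is the standard 2D output conv followed by a 3D conv
# that mixes information across frames; with skip_video=True it reduces to the
# plain per-frame Conv2d.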
|
|
|
|
|
class VideoBlock(AttnBlock): |
|
def __init__( |
|
self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned" |
|
): |
|
super().__init__(in_channels) |
|
|
|
self.time_mix_block = VideoTransformerBlock( |
|
dim=in_channels, |
|
n_heads=1, |
|
d_head=in_channels, |
|
checkpoint=True, |
|
ff_in=True, |
|
attn_mode="softmax", |
|
) |
|
|
|
time_embed_dim = self.in_channels * 4 |
|
self.video_time_embed = torch.nn.Sequential( |
|
torch.nn.Linear(self.in_channels, time_embed_dim), |
|
torch.nn.SiLU(), |
|
torch.nn.Linear(time_embed_dim, self.in_channels), |
|
) |
|
|
|
self.merge_strategy = merge_strategy |
|
if self.merge_strategy == "fixed": |
|
self.register_buffer("mix_factor", torch.Tensor([alpha])) |
|
elif self.merge_strategy == "learned": |
|
self.register_parameter( |
|
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) |
|
) |
|
else: |
|
raise ValueError(f"unknown merge strategy {self.merge_strategy}") |
|
|
|
def forward(self, x, timesteps, skip_video=False): |
|
if skip_video: |
|
return super().forward(x) |
|
|
|
x_in = x |
|
x = self.attention(x) |
|
h, w = x.shape[2:] |
|
x = rearrange(x, "b c h w -> b (h w) c") |
|
|
|
x_mix = x |
|
num_frames = torch.arange(timesteps, device=x.device) |
|
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) |
|
num_frames = rearrange(num_frames, "b t -> (b t)") |
|
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) |
|
emb = self.video_time_embed(t_emb) |
|
emb = emb[:, None, :] |
|
x_mix = x_mix + emb |
|
|
|
alpha = self.get_alpha() |
|
x_mix = self.time_mix_block(x_mix, timesteps=timesteps) |
|
x = alpha * x + (1.0 - alpha) * x_mix |
|
|
|
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) |
|
x = self.proj_out(x) |
|
|
|
return x_in + x |
|
|
|
def get_alpha( |
|
self, |
|
): |
|
if self.merge_strategy == "fixed": |
|
return self.mix_factor |
|
elif self.merge_strategy == "learned": |
|
return torch.sigmoid(self.mix_factor) |
|
else: |
|
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") |
|
|
|
|
|
class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock): |
|
def __init__( |
|
self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned" |
|
): |
|
super().__init__(in_channels) |
|
|
|
self.time_mix_block = VideoTransformerBlock( |
|
dim=in_channels, |
|
n_heads=1, |
|
d_head=in_channels, |
|
checkpoint=True, |
|
ff_in=True, |
|
attn_mode="softmax-xformers", |
|
) |
|
|
|
time_embed_dim = self.in_channels * 4 |
|
self.video_time_embed = torch.nn.Sequential( |
|
torch.nn.Linear(self.in_channels, time_embed_dim), |
|
torch.nn.SiLU(), |
|
torch.nn.Linear(time_embed_dim, self.in_channels), |
|
) |
|
|
|
self.merge_strategy = merge_strategy |
|
if self.merge_strategy == "fixed": |
|
self.register_buffer("mix_factor", torch.Tensor([alpha])) |
|
elif self.merge_strategy == "learned": |
|
self.register_parameter( |
|
"mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) |
|
) |
|
else: |
|
raise ValueError(f"unknown merge strategy {self.merge_strategy}") |
|
|
|
def forward(self, x, timesteps, skip_time_block=False): |
|
if skip_time_block: |
|
return super().forward(x) |
|
|
|
x_in = x |
|
x = self.attention(x) |
|
h, w = x.shape[2:] |
|
x = rearrange(x, "b c h w -> b (h w) c") |
|
|
|
x_mix = x |
|
num_frames = torch.arange(timesteps, device=x.device) |
|
num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) |
|
num_frames = rearrange(num_frames, "b t -> (b t)") |
|
t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) |
|
emb = self.video_time_embed(t_emb) |
|
emb = emb[:, None, :] |
|
x_mix = x_mix + emb |
|
|
|
alpha = self.get_alpha() |
|
x_mix = self.time_mix_block(x_mix, timesteps=timesteps) |
|
x = alpha * x + (1.0 - alpha) * x_mix |
|
|
|
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) |
|
x = self.proj_out(x) |
|
|
|
return x_in + x |
|
|
|
def get_alpha( |
|
self, |
|
): |
|
if self.merge_strategy == "fixed": |
|
return self.mix_factor |
|
elif self.merge_strategy == "learned": |
|
return torch.sigmoid(self.mix_factor) |
|
else: |
|
raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") |
|
|
|
|
|
def make_time_attn( |
|
in_channels, |
|
attn_type="vanilla", |
|
attn_kwargs=None, |
|
alpha: float = 0, |
|
merge_strategy: str = "learned", |
|
): |
|
assert attn_type in [ |
|
"vanilla", |
|
"vanilla-xformers", |
|
], f"attn_type {attn_type} not supported for spatio-temporal attention" |
|
print( |
|
f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels" |
|
) |
|
if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers": |
|
print( |
|
f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. " |
|
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}" |
|
) |
|
attn_type = "vanilla" |
|
|
|
if attn_type == "vanilla": |
|
assert attn_kwargs is None |
|
return partialclass( |
|
VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy |
|
) |
|
elif attn_type == "vanilla-xformers": |
|
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") |
|
return partialclass( |
|
MemoryEfficientVideoBlock, |
|
in_channels, |
|
alpha=alpha, |
|
merge_strategy=merge_strategy, |
|
) |
|
    else:
        raise NotImplementedError()
|
|
|
|
|
class Conv2DWrapper(torch.nn.Conv2d): |
|
def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor: |
|
return super().forward(input) |
|
|
|
|
|
class VideoDecoder(Decoder): |
|
available_time_modes = ["all", "conv-only", "attn-only"] |
|
|
|
def __init__( |
|
self, |
|
*args, |
|
        video_kernel_size: Union[int, list] = [3, 1, 1],
|
alpha: float = 0.0, |
|
merge_strategy: str = "learned", |
|
time_mode: str = "conv-only", |
|
**kwargs, |
|
): |
|
self.video_kernel_size = video_kernel_size |
|
self.alpha = alpha |
|
self.merge_strategy = merge_strategy |
|
self.time_mode = time_mode |
|
assert ( |
|
self.time_mode in self.available_time_modes |
|
), f"time_mode parameter has to be in {self.available_time_modes}" |
|
super().__init__(*args, **kwargs) |
|
|
|
def get_last_layer(self, skip_time_mix=False, **kwargs): |
|
if self.time_mode == "attn-only": |
|
raise NotImplementedError("TODO") |
|
else: |
|
return ( |
|
self.conv_out.time_mix_conv.weight |
|
if not skip_time_mix |
|
else self.conv_out.weight |
|
) |
|
|
|
def _make_attn(self) -> Callable: |
|
if self.time_mode not in ["conv-only", "only-last-conv"]: |
|
return partialclass( |
|
make_time_attn, |
|
alpha=self.alpha, |
|
merge_strategy=self.merge_strategy, |
|
) |
|
else: |
|
return super()._make_attn() |
|
|
|
def _make_conv(self) -> Callable: |
|
if self.time_mode != "attn-only": |
|
return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size) |
|
else: |
|
return Conv2DWrapper |
|
|
|
def _make_resblock(self) -> Callable: |
|
if self.time_mode not in ["attn-only", "only-last-conv"]: |
|
return partialclass( |
|
VideoResBlock, |
|
video_kernel_size=self.video_kernel_size, |
|
alpha=self.alpha, |
|
merge_strategy=self.merge_strategy, |
|
) |
|
else: |
|
return super()._make_resblock() |
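
# Editor's usage sketch (illustrative values, assumes PyTorch >= 2.0; not from
# the original code):
#   dec = VideoDecoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
#                      attn_resolutions=[], in_channels=3, resolution=256,
#                      z_channels=4, dropout=0.0, attn_type="vanilla",
#                      video_kernel_size=[3, 1, 1], time_mode="conv-only")
#   z = torch.randn(16, 4, 32, 32)      # 16 latent frames
#   x = dec(z, timesteps=16)            # -> (16, 3, 256, 256)
# With time_mode="conv-only" only the ResnetBlocks and the output conv become
# temporal (VideoResBlock / AE3DConv); the attention blocks stay spatial.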