# Source: NIRVANALAN, commit 11e6f7b ("init"), 14 kB
from numpy import sqrt
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Literal
from functools import partial
from pdb import set_trace as st
# from core.attention import MemEffAttention
from vit.vision_transformer import MemEffAttention
class MVAttention(nn.Module):
    """Self-attention applied jointly across all views of each instance.

    The input is a batch of per-view feature maps laid out as
    ``[B * num_frames, C, H, W]``. After group normalisation the
    ``num_frames`` views of every instance are flattened into one token
    sequence of shape ``[B, num_frames * H * W, C]``, passed through
    ``MemEffAttention`` and reshaped back to the input layout.

    Args:
        dim: Channel count ``C`` of the input feature maps.
        num_heads: Forwarded to ``MemEffAttention``.
        qkv_bias: Forwarded to ``MemEffAttention``.
        proj_bias: Forwarded to ``MemEffAttention``.
        attn_drop: Forwarded to ``MemEffAttention``.
        proj_drop: Forwarded to ``MemEffAttention``.
        groups: Number of groups of the ``GroupNorm`` applied before attention.
        eps: Epsilon of that ``GroupNorm``.
        residual: When true, the input is added to the attention output and
            the sum is multiplied by ``skip_scale``.
        skip_scale: Scale applied to the residual sum (unused when
            ``residual`` is false).
        num_frames: Number of views per instance.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        groups: int = 32,
        eps: float = 1e-5,
        residual: bool = True,
        skip_scale: float = 1,
        num_frames: int = 4,  # WARN: hardcoded!
    ):
        super().__init__()

        self.residual = residual
        self.skip_scale = skip_scale
        self.num_frames = num_frames

        self.norm = nn.GroupNorm(num_groups=groups,
                                 num_channels=dim,
                                 eps=eps,
                                 affine=True)
        # NOTE(review): MemEffAttention is defined outside this file; it is
        # assumed to map [B, N, C] tokens to [B, N, C] tokens.
        self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias,
                                    attn_drop, proj_drop)

    def forward(self, x):
        """Attend across views.

        Args:
            x: Tensor of shape ``[B * num_frames, C, H, W]``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If the leading dimension of ``x`` is not a multiple
                of ``num_frames``.
        """
        BV, C, H, W = x.shape
        # Fail early with a readable message instead of an opaque reshape
        # size error further down.
        if BV % self.num_frames != 0:
            raise ValueError(
                f"MVAttention expects the batch dimension ({BV}) to be a "
                f"multiple of num_frames ({self.num_frames})")
        B = BV // self.num_frames

        res = x
        x = self.norm(x)

        # [B*V, C, H, W] -> [B, V, H, W, C] -> [B, V*H*W, C]
        x = x.reshape(B, self.num_frames, C, H,
                      W).permute(0, 1, 3, 4, 2).reshape(B, -1, C)
        x = self.attn(x)
        # [B, V*H*W, C] -> [B, V, C, H, W] -> [B*V, C, H, W]
        x = x.reshape(B, self.num_frames, H, W,
                      C).permute(0, 1, 4, 2, 3).reshape(BV, C, H, W)

        if self.residual:
            x = (x + res) * self.skip_scale
        return x
class ResnetBlock(nn.Module):
    """Pre-activation residual block with optional 2x resampling.

    Layout: GroupNorm -> SiLU -> (resample) -> 3x3 conv -> GroupNorm -> SiLU
    -> 3x3 conv, added to a shortcut of the (resampled) input and multiplied
    by ``skip_scale``.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        resample: ``'up'`` applies nearest-neighbour 2x upsampling, ``'down'``
            applies 2x2 average pooling; any other value keeps the resolution.
        groups: Number of groups for both ``GroupNorm`` layers.
        eps: Epsilon for both ``GroupNorm`` layers.
        skip_scale: Factor multiplied onto the block output.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        resample: Literal['default', 'up', 'down'] = 'default',
        groups: int = 32,
        eps: float = 1e-5,
        skip_scale: float = 1,  # multiplied to output
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.skip_scale = skip_scale

        # Small factories so the two norm/conv pairs read symmetrically.
        make_norm = partial(nn.GroupNorm, groups, eps=eps, affine=True)
        make_conv = partial(nn.Conv2d, kernel_size=3, stride=1, padding=1)

        self.norm1 = make_norm(in_channels)
        self.conv1 = make_conv(in_channels, out_channels)
        self.norm2 = make_norm(out_channels)
        self.conv2 = make_conv(out_channels, out_channels)

        self.act = F.silu

        if resample == 'up':
            self.resample = partial(F.interpolate,
                                    scale_factor=2.0,
                                    mode="nearest")
        elif resample == 'down':
            self.resample = nn.AvgPool2d(kernel_size=2, stride=2)
        else:
            self.resample = None

        # A 1x1 conv is only needed when the channel count changes.
        if self.in_channels == self.out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels,
                                      out_channels,
                                      kernel_size=1,
                                      bias=True)

    def forward(self, x):
        """Run the block on ``x`` and return the scaled residual sum."""
        skip = x

        h = self.act(self.norm1(x))

        # Both branches are resampled so they can be summed at the end.
        if self.resample:
            skip = self.resample(skip)
            h = self.resample(h)

        h = self.conv1(h)
        h = self.conv2(self.act(self.norm2(h)))

        return (h + self.shortcut(skip)) * self.skip_scale
class DownBlock(nn.Module):
    """Encoder stage: a stack of ResnetBlocks with optional attention.

    Each layer is a ``ResnetBlock`` followed, when ``attention`` is set, by
    an ``MVAttention``. A stride-2 3x3 convolution optionally halves the
    resolution at the end of the stage.

    Args:
        in_channels: Channels entering the first layer.
        out_channels: Channels produced by every layer of the stage.
        num_layers: Number of ResnetBlock (+ attention) layers.
        downsample: Whether to append the stride-2 convolution.
        attention: Whether each layer is followed by ``MVAttention``.
        attention_heads: Head count passed to ``MVAttention``.
        skip_scale: Passed to every ``ResnetBlock`` and ``MVAttention``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        downsample: bool = True,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
    ):
        super().__init__()

        resnets = []
        attentions = []
        for layer_idx in range(num_layers):
            # Only the first layer sees the stage's input width.
            layer_in = out_channels if layer_idx else in_channels
            resnets.append(
                ResnetBlock(layer_in, out_channels, skip_scale=skip_scale))
            # A None entry keeps attns index-aligned with nets.
            attentions.append(
                MVAttention(out_channels,
                            attention_heads,
                            skip_scale=skip_scale) if attention else None)
        self.nets = nn.ModuleList(resnets)
        self.attns = nn.ModuleList(attentions)

        self.downsample = (nn.Conv2d(out_channels,
                                     out_channels,
                                     kernel_size=3,
                                     stride=2,
                                     padding=1) if downsample else None)

    def forward(self, x):
        """Run the stage.

        Returns:
            A tuple ``(x, xs)`` where ``x`` is the final feature map and
            ``xs`` lists the output of every layer, plus the downsampled map
            when downsampling is enabled.
        """
        outputs = []

        for resnet, attention in zip(self.nets, self.attns):
            x = resnet(x)
            if attention:
                x = attention(x)
            outputs.append(x)

        if self.downsample:
            x = self.downsample(x)
            outputs.append(x)

        return x, outputs
class MidBlock(nn.Module):
    """Bottleneck stage operating at a constant channel width.

    Consists of one leading ``ResnetBlock`` followed by ``num_layers``
    repetitions of (optional ``MVAttention``, ``ResnetBlock``).

    Args:
        in_channels: Channel width used throughout the stage.
        num_layers: Number of (attention, ResnetBlock) pairs after the
            leading ResnetBlock.
        attention: Whether each pair includes an ``MVAttention``.
        attention_heads: Head count passed to ``MVAttention``.
        skip_scale: Passed to every ``ResnetBlock`` and ``MVAttention``.
    """

    def __init__(
        self,
        in_channels: int,
        num_layers: int = 1,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
    ):
        super().__init__()

        # Leading ResnetBlock, which has no attention partner.
        resnets = [
            ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)
        ]
        attentions = []

        for _ in range(num_layers):
            resnets.append(
                ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
            # A None entry keeps attns aligned with nets[1:].
            attentions.append(
                MVAttention(in_channels,
                            attention_heads,
                            skip_scale=skip_scale) if attention else None)

        self.nets = nn.ModuleList(resnets)
        self.attns = nn.ModuleList(attentions)

    def forward(self, x):
        """Apply the leading ResnetBlock, then each attention/ResnetBlock pair."""
        first, *remaining = self.nets
        x = first(x)
        for attention, resnet in zip(self.attns, remaining):
            # Attention runs before its paired ResnetBlock.
            if attention:
                x = attention(x)
            x = resnet(x)
        return x
class UpBlock(nn.Module):
    """Decoder stage that merges skip features and optionally upsamples.

    Every layer concatenates the running feature map with one skip tensor
    along the channel axis, applies a ``ResnetBlock`` and, when enabled, an
    ``MVAttention``. The stage may end with nearest-neighbour 2x upsampling
    followed by a 3x3 convolution.

    Args:
        in_channels: Channels of the feature map entering the stage.
        prev_out_channels: Channels of the skip tensor consumed by the last
            layer; the other layers expect skips with ``out_channels``.
        out_channels: Channels produced by every layer.
        num_layers: Number of layers (and skip tensors consumed).
        upsample: Whether to append the upsampling step.
        attention: Whether each layer is followed by ``MVAttention``.
        attention_heads: Head count passed to ``MVAttention``.
        skip_scale: Passed to every ``ResnetBlock`` and ``MVAttention``.
    """

    def __init__(
        self,
        in_channels: int,
        prev_out_channels: int,
        out_channels: int,
        num_layers: int = 1,
        upsample: bool = True,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
    ):
        super().__init__()

        resnets = []
        attentions = []
        last_layer = num_layers - 1
        for layer_idx in range(num_layers):
            feat_channels = out_channels if layer_idx else in_channels
            skip_channels = (prev_out_channels
                             if layer_idx == last_layer else out_channels)
            resnets.append(
                ResnetBlock(feat_channels + skip_channels,
                            out_channels,
                            skip_scale=skip_scale))
            # A None entry keeps attns index-aligned with nets.
            attentions.append(
                MVAttention(out_channels,
                            attention_heads,
                            skip_scale=skip_scale) if attention else None)
        self.nets = nn.ModuleList(resnets)
        self.attns = nn.ModuleList(attentions)

        self.upsample = (nn.Conv2d(out_channels,
                                   out_channels,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1) if upsample else None)

    def forward(self, x, xs):
        """Run the stage.

        Args:
            x: Feature map entering the stage.
            xs: Sequence of skip tensors; they are consumed from the end,
                one per layer.

        Returns:
            The stage output feature map.
        """
        layers = zip(self.nets, self.attns)
        for depth, (resnet, attention) in enumerate(layers, start=1):
            # depth-th element from the end == popping the tail each step.
            skip = xs[-depth]
            x = resnet(torch.cat([x, skip], dim=1))
            if attention:
                x = attention(x)

        if self.upsample:
            x = F.interpolate(x, scale_factor=2.0, mode='nearest')
            x = self.upsample(x)

        return x
# it could be asymmetric!
class MVUNet(nn.Module):
    """UNet built from DownBlock, MidBlock and UpBlock stages.

    The decoder may have fewer stages than the encoder, in which case the
    output resolution is lower than the input resolution.

    Args:
        in_channels: Channels of the input image tensor.
        out_channels: Channels of the output tensor.
        down_channels: Channel width of every encoder stage.
        down_attention: Per encoder stage, whether to use attention.
        mid_attention: Whether the bottleneck uses attention.
        up_channels: Channel width of every decoder stage.
        up_attention: Per decoder stage, whether to use attention.
        layers_per_block: Layers per encoder stage; decoder stages use one
            more.
        skip_scale: Residual scale passed to every stage.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024),
        down_attention: Tuple[bool,
                              ...] = (False, False, False, True, True),
        mid_attention: bool = True,
        up_channels: Tuple[int, ...] = (1024, 512, 256),
        up_attention: Tuple[bool, ...] = (True, True, False),
        layers_per_block: int = 2,
        skip_scale: float = np.sqrt(0.5),
    ):
        super().__init__()

        # Stem.
        self.conv_in = nn.Conv2d(in_channels,
                                 down_channels[0],
                                 kernel_size=3,
                                 stride=1,
                                 padding=1)

        # Encoder: every stage except the last one downsamples.
        encoder = []
        final_down = len(down_channels) - 1
        prev_width = down_channels[0]
        for idx, width in enumerate(down_channels):
            encoder.append(
                DownBlock(
                    prev_width,
                    width,
                    num_layers=layers_per_block,
                    downsample=(idx != final_down),
                    attention=down_attention[idx],
                    skip_scale=skip_scale,
                ))
            prev_width = width
        self.down_blocks = nn.ModuleList(encoder)

        # Bottleneck.
        self.mid_block = MidBlock(down_channels[-1],
                                  attention=mid_attention,
                                  skip_scale=skip_scale)

        # Decoder: every stage except the last one upsamples.
        decoder = []
        final_up = len(up_channels) - 1
        prev_width = up_channels[0]
        for idx, width in enumerate(up_channels):
            # Clamp the index so an asymmetric (deeper) decoder falls back to
            # the first encoder width.
            skip_width = down_channels[max(-2 - idx, -len(down_channels))]
            decoder.append(
                UpBlock(
                    prev_width,
                    skip_width,
                    width,
                    num_layers=layers_per_block + 1,  # one more layer for up
                    upsample=(idx != final_up),
                    attention=up_attention[idx],
                    skip_scale=skip_scale,
                ))
            prev_width = width
        self.up_blocks = nn.ModuleList(decoder)

        # Head.
        self.norm_out = nn.GroupNorm(32, up_channels[-1], eps=1e-5)
        self.conv_out = nn.Conv2d(up_channels[-1],
                                  out_channels,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)

    def forward(self, x):
        """Map ``x`` of shape ``[B, Cin, H, W]`` to ``[B, Cout, H', W']``."""
        feat = self.conv_in(x)

        # Encoder: remember every intermediate map as a skip connection.
        skips = [feat]
        for down in self.down_blocks:
            feat, produced = down(feat)
            skips.extend(produced)

        feat = self.mid_block(feat)

        # Decoder: each stage takes as many trailing skips as it has layers.
        for up in self.up_blocks:
            n_layers = len(up.nets)
            taken = skips[-n_layers:]
            skips = skips[:-n_layers]
            feat = up(feat, taken)

        # Head: norm -> SiLU -> conv.
        return self.conv_out(F.silu(self.norm_out(feat)))
class LGM_MVEncoder(MVUNet):
    """Multi-view encoder built from the down and mid stages of ``MVUNet``.

    The decoder (``up_blocks``) of the parent is discarded. Per-view features
    from the mid block are projected by ``conv_out`` to the latent width and
    the ``num_frames`` views of every instance are then aggregated into a
    single feature map.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Forwarded to ``MVUNet``; the parent's output conv is
            replaced here, so it does not set the encoder output width.
        down_channels: Channel width of every encoder stage.
        down_attention: Per encoder stage, whether to use attention.
        mid_attention: Whether the bottleneck uses attention.
        up_channels: Forwarded to ``MVUNet``; ``up_channels[0]`` is also the
            input width of ``conv_out`` and therefore has to match the mid
            block width ``down_channels[-1]``.
        up_attention: Forwarded to ``MVUNet``.
        layers_per_block: Layers per encoder stage.
        skip_scale: Residual scale passed to every stage.
        z_channels: Latent channel count.
        double_z: When true the latent width is ``2 * z_channels``.
        add_fusion_layer: When true the views are fused by a 3x3 convolution
            over their channel-wise concatenation; otherwise they are fused
            by an element-wise maximum over views.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024),
        down_attention: Tuple[bool, ...] = (False, False, False, True, True),
        mid_attention: bool = True,
        up_channels: Tuple[int, ...] = (1024, 512, 256),
        up_attention: Tuple[bool, ...] = (True, True, False),
        layers_per_block: int = 2,
        skip_scale: float = np.sqrt(0.5),
        z_channels=4,
        double_z=True,
        add_fusion_layer=True,
    ):
        super().__init__(in_channels, out_channels, down_channels,
                         down_attention, mid_attention, up_channels,
                         up_attention, layers_per_block, skip_scale)
        # The encoder has no decoder path.
        del self.up_blocks
        # NOTE(review): the inherited norm_out (sized for up_channels[-1]) is
        # left in place but is not used by this class.

        self.num_frames = 4  # !hard coded

        latent_channels = 2 * z_channels if double_z else z_channels
        self.conv_out = torch.nn.Conv2d(up_channels[0],
                                        latent_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

        # Always define the attribute so forward() can test it.
        self.fusion_layer = None
        if add_fusion_layer:
            # Fuses the channel-wise concatenation of all views.
            self.fusion_layer = torch.nn.Conv2d(
                latent_channels * self.num_frames,
                latent_channels,
                kernel_size=3,
                stride=1,
                padding=1)

    def forward(self, x):
        """Encode multi-view inputs into one feature map per instance.

        Args:
            x: Tensor of shape ``[B * num_frames, Cin, H, W]``.

        Returns:
            Tensor of shape ``[B, latent_channels, H', W']``.

        Raises:
            ValueError: If the batch dimension is not a multiple of
                ``num_frames``.
        """
        # first
        x = self.conv_in(x)

        # down (skip connections are not needed without a decoder)
        for block in self.down_blocks:
            x, _ = block(x)

        # mid
        x = self.mid_block(x)

        # Project to the latent width the fusion layer was built for;
        # without this the concatenated views would carry
        # num_frames * mid_channels channels and the fusion conv would fail.
        x = self.conv_out(x)

        BV, C, H, W = x.shape
        if BV % self.num_frames != 0:
            raise ValueError(
                f"LGM_MVEncoder expects the batch dimension ({BV}) to be a "
                f"multiple of num_frames ({self.num_frames})")
        B = BV // self.num_frames

        # multi-view aggregation, as in pixel-nerf
        if self.fusion_layer is None:
            # max pooling over the views of each instance
            return x.reshape(B, self.num_frames, C, H, W).max(dim=1).values

        # conv pooling: consecutive views of an instance are contiguous in
        # the batch, so this reshape equals concatenating them along the
        # channel axis.
        return self.fusion_layer(x.reshape(B, self.num_frames * C, H, W))