|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.nn import init |
|
|
|
import einops |
|
|
|
def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    """Build a 2D convolution with a 3x3 kernel.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: convolution stride.
        padding: padding passed to ``nn.Conv2d``.
        bias: whether the layer gets a learnable bias.
        groups: number of channel groups passed to ``nn.Conv2d``.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    options = dict(
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups,
    )
    return nn.Conv2d(in_channels, out_channels, **options)
|
|
|
def upconv2x2(in_channels, out_channels, mode='transpose'):
    """Build a module that doubles the spatial resolution.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        mode: ``'transpose'`` selects a stride-2 transposed convolution with
            a 2x2 kernel; any other value selects bilinear upsampling
            followed by a 1x1 convolution.

    Returns:
        An ``nn.ConvTranspose2d`` or an ``nn.Sequential`` of
        ``nn.Upsample`` and a 1x1 convolution.
    """
    if mode != 'transpose':
        # Every non-'transpose' value falls through to the bilinear variant.
        resize = nn.Upsample(mode='bilinear', scale_factor=2)
        project = conv1x1(in_channels, out_channels)
        return nn.Sequential(resize, project)

    return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
|
|
|
def conv1x1(in_channels, out_channels, groups=1):
    """Build a pointwise (1x1 kernel, stride 1) 2D convolution.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        groups: number of channel groups passed to ``nn.Conv2d``.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    pointwise = nn.Conv2d(in_channels, out_channels, 1, stride=1, groups=groups)
    return pointwise
|
|
|
class ConvTriplane3dAware(nn.Module):
    """3D-aware triplane convolution (as described in RODIN).

    Each output plane is computed by its own 2D convolution applied to the
    channel-wise concatenation of that plane's features with axis-averaged
    features of the two other planes.
    """

    def __init__(self, internal_conv_f, in_channels, out_channels, order='xz'):
        """
        Args:
            internal_conv_f: callable ``(in_channels, out_channels)`` that
                returns a 2D convolution module; it is called three times,
                each time with ``3 * in_channels`` input channels.
            in_channels: channels of every input plane.
            out_channels: channels of every output plane.
            order: ``'xz'`` or ``'zx'``, the axis order of the third plane.
        """
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        assert order in ['xz', 'zx']
        self.order = order

        # One convolution per plane; each sees own + two pooled feature sets.
        per_plane = [
            internal_conv_f(3 * self.in_channels, self.out_channels)
            for _ in range(3)
        ]
        self.plane_convs = nn.ModuleList(per_plane)

    def forward(self, triplanes_list):
        """
        Args:
            triplanes_list: [(B,Ci,H,W)]*3 in xy,yz,(zx or xz) depending on order
        Returns:
            list of three (B,Co,H,W) tensors in the same plane order and
            axis layout as the input.
        """
        planes = list(triplanes_list)
        third = 2  # index of the plane whose axis order depends on `order`
        needs_swap = self.order == 'xz'

        if needs_swap:
            # Bring the third plane to zx layout so the cyclic scheme below
            # applies uniformly to all three planes.
            planes[third] = einops.rearrange(planes[third], 'b c x z -> b c z x')

        results = [None] * 3

        for idx in range(3):
            own = planes[idx]
            following = planes[(idx + 1) % 3]
            preceding = planes[(idx + 2) % 3]

            # Average the following plane over its last axis, then lay the
            # remaining profile along this plane's last axis, repeated over
            # every row. Concatenation below requires following.size(2) to
            # equal own.size(3).
            col_ctx = following.mean(dim=3, keepdim=True)
            col_ctx = einops.rearrange(col_ctx, 'b c k 1 -> b c 1 k')
            col_ctx = einops.repeat(col_ctx, 'b c 1 k -> b c j k', j=own.size(2))

            # Average the preceding plane over axis 2, then lay the remaining
            # profile along this plane's axis 2, repeated over every column.
            # Concatenation below requires preceding.size(3) to equal
            # own.size(2).
            row_ctx = preceding.mean(dim=2, keepdim=True)
            row_ctx = einops.rearrange(row_ctx, 'b c 1 j -> b c j 1')
            row_ctx = einops.repeat(row_ctx, 'b c j 1 -> b c j k', k=own.size(3))

            fused = torch.cat([own, col_ctx, row_ctx], dim=1)
            results[idx] = self.plane_convs[idx](fused)

        if needs_swap:
            # Undo the layout change applied to the third plane on entry.
            results[third] = einops.rearrange(results[third], 'b c z x -> b c x z')

        return results
|
|
|
def roll_triplanes(triplanes_list):
    """Merge three planes into one tensor by stacking them along the height.

    Args:
        triplanes_list: sequence of three tensors, each (b, c, h, w).

    Returns:
        Tensor of shape (b, c, 3*h, w) with the planes laid out one after
        another along axis 2.
    """
    stacked = torch.stack(triplanes_list, dim=2)  # (b, c, tri, h, w)
    rolled = einops.rearrange(stacked, 'b c tri h w -> b c (tri h) w', tri=3)
    return rolled
|
|
|
def unroll_triplanes(rolled_triplane):
    """Split a rolled triplane tensor back into its three planes.

    Args:
        rolled_triplane: tensor of shape (b, c, 3*h, w).

    Returns:
        Tuple of three tensors, each (b, c, h, w), as produced by
        ``torch.unbind`` over the plane axis.
    """
    per_plane = einops.rearrange(
        rolled_triplane, 'b c (tri h) w -> b c tri h w', tri=3)
    return torch.unbind(per_plane, dim=2)
|
|
|
def conv1x1triplane3daware(in_channels, out_channels, order='xz', **kwargs):
    """Build a ``ConvTriplane3dAware`` whose per-plane convolutions are 1x1.

    Args:
        in_channels: channels of every input plane.
        out_channels: channels of every output plane.
        order: axis order of the third plane, forwarded unchanged.
        **kwargs: extra keyword arguments forwarded to ``conv1x1``.

    Returns:
        The configured ``ConvTriplane3dAware`` module.
    """
    def make_plane_conv(n_in, n_out):
        return conv1x1(n_in, n_out, **kwargs)

    return ConvTriplane3dAware(
        make_plane_conv, in_channels, out_channels, order=order)
|
|
|
def Normalize(in_channels, num_groups=32):
    """Build the GroupNorm layer used throughout this file.

    Args:
        in_channels: number of channels to normalise.
        num_groups: upper bound on the number of groups.

    Returns:
        ``torch.nn.GroupNorm`` with ``eps=1e-6`` and learnable affine
        parameters. The group count is the largest divisor of
        ``in_channels`` that does not exceed ``num_groups``.
    """
    num_groups = min(in_channels, num_groups)
    # GroupNorm rejects a group count that does not divide the channel count
    # (e.g. 48 channels with 32 groups), so step down to the nearest divisor.
    # Channel counts that were already divisible are unaffected.
    while num_groups > 1 and in_channels % num_groups != 0:
        num_groups -= 1
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
|
|
def nonlinearity(x):
    """Apply the swish activation ``x * sigmoid(x)`` element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
|
|
|
class Upsample(nn.Module):
    """2x nearest-neighbour upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        """
        Args:
            in_channels: channel count of the input (and of the output).
            with_conv: if truthy, a channel-preserving 3x3 convolution is
                applied after the interpolation.
        """
        super().__init__()
        self.with_conv = with_conv
        if not with_conv:
            # Pure interpolation: no parameters and no `conv` attribute.
            return
        self.conv = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return ``x`` with its last two axes doubled in size."""
        enlarged = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(enlarged) if self.with_conv else enlarged
|
|
|
class Downsample(nn.Module):
    """2x spatial downsampling via strided convolution or average pooling."""

    def __init__(self, in_channels, with_conv):
        """
        Args:
            in_channels: channel count of the input (and of the output).
            with_conv: if truthy, downsample with a stride-2 3x3 convolution;
                otherwise use 2x2 average pooling.
        """
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # The conv itself is unpadded; forward() pads asymmetrically.
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        """Return ``x`` downsampled along its last two axes."""
        if not self.with_conv:
            return F.avg_pool2d(x, kernel_size=2, stride=2)

        # One zero column on the right and one zero row at the bottom.
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
|
|
class ResnetBlock3dAware(nn.Module):
    """Residual block on rolled triplanes with a 3D-aware middle convolution.

    The residual branch is: norm -> swish -> 3x3 conv, norm -> swish ->
    3D-aware 1x1 triplane conv, norm -> swish -> 3x3 conv. The skip path is
    the identity, or a 1x1 convolution when the channel count changes.
    """

    def __init__(self, in_channels, out_channels=None):
        """
        Args:
            in_channels: channels of the input tensor.
            out_channels: channels of the output tensor; defaults to
                ``in_channels`` when ``None``.
        """
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Submodules are created in the same order as they run in forward().
        self.norm1 = Normalize(in_channels)
        self.conv1 = conv3x3(in_channels, out_channels)

        self.norm_mid = Normalize(out_channels)
        self.conv_3daware = conv1x1triplane3daware(out_channels, out_channels)

        self.norm2 = Normalize(out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)

        if in_channels != out_channels:
            # Pointwise projection so the skip path matches the branch width.
            self.nin_shortcut = torch.nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """
        Args:
            x: rolled triplane tensor (b, c, 3*h, w).
        Returns:
            Rolled triplane tensor with ``out_channels`` channels.
        """
        branch = self.conv1(nonlinearity(self.norm1(x)))

        # The 3D-aware conv works on a list of planes, so unroll around it.
        planes = unroll_triplanes(nonlinearity(self.norm_mid(branch)))
        branch = roll_triplanes(self.conv_3daware(planes))

        branch = self.conv2(nonlinearity(self.norm2(branch)))

        if self.in_channels == self.out_channels:
            skip = x
        else:
            skip = self.nin_shortcut(x)
        return skip + branch
|
|
|
class DownConv3dAware(nn.Module):
    """Encoder stage: a 3D-aware residual block plus optional 2x downsampling.

    The downsampling is applied to each of the three planes separately.
    """

    def __init__(self, in_channels, out_channels, downsample=True, with_conv=False):
        """
        Args:
            in_channels: channels of the rolled input.
            out_channels: channels produced by the residual block.
            downsample: whether forward() downsamples after the block.
            with_conv: forwarded to ``Downsample``; selects a strided
                convolution instead of average pooling.
        """
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.block = ResnetBlock3dAware(in_channels=in_channels,
                                        out_channels=out_channels)

        self.do_downsample = downsample
        self.downsample = Downsample(out_channels, with_conv=with_conv)

    def forward(self, x):
        """
        rolled input, rolled output

        Args:
            x: rolled (b c (tri*h) w)
        Returns:
            Tuple ``(x, before_pool)``: the (optionally downsampled) rolled
            output and the rolled block output prior to downsampling.
        """
        x = self.block(x)
        before_pool = x

        if self.do_downsample:
            # Fold the plane axis into the batch axis so that Downsample sees
            # exactly `out_channels` channels per sample. Folding planes into
            # channels would feed 3*out_channels channels to a conv built for
            # out_channels and fail when with_conv=True; for average pooling
            # both layouts give the same values.
            x = einops.rearrange(x, 'b c (tri h) w -> (b tri) c h w', tri=3)
            x = self.downsample(x)
            x = einops.rearrange(x, '(b tri) c h w -> b c (tri h) w', tri=3)
        return x, before_pool
|
|
|
class UpConv3dAware(nn.Module):
    """Decoder stage: 2x upsampling, skip merge, norm, 3D-aware residual block."""

    def __init__(self, in_channels, out_channels,
                 merge_mode='concat', with_conv=False):
        """
        Args:
            in_channels: channels of the decoder tensor to be upsampled.
            out_channels: channels produced by the residual block.
            merge_mode: ``'concat'`` concatenates the upsampled tensor with
                the encoder tensor along channels; any other value adds them.
            with_conv: forwarded to ``Upsample``.
        """
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode

        self.upsample = Upsample(in_channels, with_conv)

        # Width of the merged tensor: concatenation stacks the encoder skip
        # (out_channels wide) on top of the upsampled decoder tensor.
        merged_channels = in_channels
        if self.merge_mode == 'concat':
            merged_channels = in_channels + out_channels

        self.norm1 = Normalize(merged_channels)
        self.block = ResnetBlock3dAware(in_channels=merged_channels,
                                        out_channels=out_channels)

    def forward(self, from_down, from_up):
        """ Forward pass
        rolled inputs, rolled output
        rolled (b c (tri*h) w)
        Arguments:
            from_down: tensor from the encoder pathway
            from_up: tensor from the decoder pathway; it is upsampled here
        """
        upsampled = self.upsample(from_up)

        if self.merge_mode != 'concat':
            # NOTE(review): addition needs both tensors to have the same
            # channel count; UNetTriplane3dAware builds this stage with
            # outs = ins // 2, so the additive mode looks unusable from
            # there - confirm before relying on it.
            merged = upsampled + from_down
        else:
            merged = torch.cat((upsampled, from_down), 1)

        return self.block(self.norm1(merged))
|
|
|
class UNetTriplane3dAware(nn.Module):
    """U-Net over rolled triplanes built from 3D-aware residual stages."""

    def __init__(self, out_channels, in_channels=3, depth=5,
                 start_filts=64,
                 use_initial_conv=False,
                 merge_mode='concat', **kwargs):
        """
        Arguments:
            out_channels: int, number of channels in the output tensor.
            in_channels: int, number of channels in the input tensor.
            depth: int, number of encoder stages; every stage except the
                last one downsamples by 2. Must be at least 1.
            start_filts: int, number of channels produced by the first
                encoder stage; each further stage doubles it.
            use_initial_conv: if truthy, a 1x1 convolution first maps the
                input to ``start_filts`` channels.
            merge_mode: forwarded to every ``UpConv3dAware``.
            **kwargs: accepted for call-site compatibility and ignored.

        Raises:
            ValueError: if ``depth`` is smaller than 1.
        """
        super().__init__()

        if depth < 1:
            # Without any encoder stage there is no channel count from which
            # to build the output head.
            raise ValueError(f"depth must be at least 1, got {depth}")

        self.out_channels = out_channels
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        self.use_initial_conv = use_initial_conv
        if use_initial_conv:
            self.conv_initial = conv1x1(self.in_channels, self.start_filts)

        # Encoder: channel width doubles at every stage.
        down_convs = []
        outs = None
        for i in range(depth):
            if i == 0:
                ins = self.start_filts if use_initial_conv else self.in_channels
            else:
                ins = outs
            outs = self.start_filts * (2 ** i)
            down_convs.append(
                DownConv3dAware(ins, outs, downsample=i < depth - 1))

        # Decoder: channel width halves at every stage.
        up_convs = []
        for _ in range(depth - 1):
            ins = outs
            outs = ins // 2
            up_convs.append(UpConv3dAware(ins, outs, merge_mode=merge_mode))

        # Registered in this order so parameter names and initialisation
        # order stay the same as before.
        self.down_convs = nn.ModuleList(down_convs)
        self.up_convs = nn.ModuleList(up_convs)

        self.norm_out = Normalize(outs)
        self.conv_final = conv1x1(outs, self.out_channels)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        """Xavier-normal weights and zero bias for ``nn.Conv2d`` modules.

        Modules of any other type are left untouched.
        """
        if isinstance(m, nn.Conv2d):
            init.xavier_normal_(m.weight)
            # A convolution created with bias=False has no bias tensor.
            if m.bias is not None:
                init.constant_(m.bias, 0)

    def reset_params(self):
        """Re-initialise every submodule via ``weight_init``."""
        for m in self.modules():
            self.weight_init(m)

    def forward(self, x):
        """
        Args:
            x: Stacked triplane expected to be in (B,3,C,H,W)
        Returns:
            Tensor laid out as (B,3,out_channels,H',W').
            NOTE(review): H and W presumably need to be divisible by
            2**(depth-1) so the upsampled decoder tensors line up with the
            encoder skips - confirm against callers.
        """
        # Roll the three planes along the height axis.
        x = einops.rearrange(x, 'b tri c h w -> b c (tri h) w', tri=3)

        if self.use_initial_conv:
            x = self.conv_initial(x)

        encoder_outs = []
        for down in self.down_convs:
            x, before_pool = down(x)
            encoder_outs.append(before_pool)

        # The last encoder output is the bottleneck itself, so skips start
        # from the second-to-last entry and walk backwards.
        for i, up in enumerate(self.up_convs):
            x = up(encoder_outs[-(i + 2)], x)

        x = self.norm_out(x)
        x = self.conv_final(nonlinearity(x))

        # Unroll back to one plane per index of axis 1.
        return einops.rearrange(x, 'b c (tri h) w -> b tri c h w', tri=3)
|
|
|
|
|
def setup_unet(output_channels, input_channels, unet_cfg):
    """Instantiate the U-Net described by ``unet_cfg``.

    Args:
        output_channels: number of output channels of the network.
        input_channels: number of input channels of the network.
        unet_cfg: mapping read for the keys ``'use_3d_aware'``, ``'rolled'``,
            ``'depth'``, ``'use_initial_conv'`` and
            ``'start_hidden_channels'``.

    Returns:
        A ``UNetTriplane3dAware`` instance.

    Raises:
        NotImplementedError: if ``unet_cfg['use_3d_aware']`` is falsy.
        AssertionError: if ``unet_cfg['rolled']`` is falsy.
    """
    if not unet_cfg['use_3d_aware']:
        # Only the 3D-aware variant exists.
        raise NotImplementedError

    assert unet_cfg['rolled']
    return UNetTriplane3dAware(
        out_channels=output_channels,
        in_channels=input_channels,
        depth=unet_cfg['depth'],
        use_initial_conv=unet_cfg['use_initial_conv'],
        start_filts=unet_cfg['start_hidden_channels'],
    )
|
|
|
|