omnipart's picture
init
491eded
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import einops
def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    """Return a 2D convolution with a 3x3 kernel.

    With the default ``stride=1, padding=1`` the spatial size is preserved.
    ``stride``, ``padding``, ``bias`` and ``groups`` are forwarded unchanged
    to ``nn.Conv2d``.
    """
    options = dict(kernel_size=3, stride=stride, padding=padding,
                   bias=bias, groups=groups)
    return nn.Conv2d(in_channels, out_channels, **options)
def upconv2x2(in_channels, out_channels, mode='transpose'):
    """Return a 2x spatial upsampling module.

    ``mode='transpose'`` gives a learned 2x2 transposed convolution with
    stride 2. Any other value gives a bilinear 2x resize followed by a 1x1
    convolution from ``in_channels`` to ``out_channels``.
    """
    if mode != 'transpose':
        # NOTE(review): an earlier comment claimed out_channels always equals
        # in_channels on this path; nothing here enforces that.
        resize = nn.Upsample(mode='bilinear', scale_factor=2)
        return nn.Sequential(resize, conv1x1(in_channels, out_channels))
    return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
def conv1x1(in_channels, out_channels, groups=1):
    """Return a pointwise ``nn.Conv2d`` (1x1 kernel, stride 1, no padding).

    It mixes channels only and leaves H and W unchanged.
    """
    return nn.Conv2d(in_channels, out_channels, 1, stride=1, groups=groups)
class ConvTriplane3dAware(nn.Module):
    """3D-aware triplane convolution (as described in RODIN).

    Each of the three output planes comes from its own 2D convolution. That
    convolution is applied to the channel-wise concatenation of the plane
    itself and the two other planes. Each other plane is first averaged over
    the axis it does not share with this plane, then broadcast back over this
    plane's spatial grid.
    """

    def __init__(self, internal_conv_f, in_channels, out_channels, order='xz'):
        """
        Args:
            internal_conv_f: callable ``(in_ch, out_ch) -> nn.Module``. It is
                called three times with ``(3 * in_channels, out_channels)``,
                once per plane.
            in_channels: channel count of every input plane.
            out_channels: channel count of every output plane.
            order: 'xz' if the third plane is stored with axes (x, z), or
                'zx' if it is stored as (z, x).
        """
        super(ConvTriplane3dAware, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        assert order in ['xz', 'zx']
        self.order = order
        # one convolution per plane; each sees its own features + 2 pooled planes
        stacked_in = 3 * in_channels
        convs = [internal_conv_f(stacked_in, out_channels) for _ in range(3)]
        self.plane_convs = nn.ModuleList(convs)

    def forward(self, triplanes_list):
        """
        Args:
            triplanes_list: three (B, Ci, H, W) tensors ordered xy, yz, then
                zx or xz depending on ``self.order``.
        Returns:
            A list of three (B, Co, H, W) tensors in the same order and with the
            same third-plane layout as the input.
        """
        planes = list(triplanes_list)
        zx = 2
        flip_third = self.order == 'xz'
        if flip_third:
            # internally the third plane is handled with axes (z, x)
            planes[zx] = planes[zx].transpose(2, 3)

        outputs = []
        for idx, own in enumerate(planes):
            # If this plane spans axes (j, k) and is normal to axis i, then the
            # next plane spans (k, i) and the previous one spans (i, j).
            rows, cols = own.size(2), own.size(3)

            # next plane (k, i): mean over i -> (k, 1) -> (1, k) -> tile to (j, k)
            from_next = planes[(idx + 1) % 3].mean(dim=3, keepdim=True)
            from_next = from_next.transpose(2, 3).expand(-1, -1, rows, -1)

            # previous plane (i, j): mean over i -> (1, j) -> (j, 1) -> tile to (j, k)
            from_prev = planes[(idx + 2) % 3].mean(dim=2, keepdim=True)
            from_prev = from_prev.transpose(2, 3).expand(-1, -1, -1, cols)

            stacked = torch.cat((own, from_next, from_prev), dim=1)
            outputs.append(self.plane_convs[idx](stacked))

        if flip_third:
            # restore the caller's (x, z) layout for the third plane
            outputs[zx] = outputs[zx].transpose(2, 3)
        return outputs
def roll_triplanes(triplanes_list):
    """Pack three (B, C, H, W) planes into one rolled (B, C, 3*H, W) tensor.

    The planes are laid out one after another along the height axis, in
    input order.
    """
    planes_first = torch.stack(triplanes_list, dim=0)  # (tri, B, C, H, W)
    return einops.rearrange(planes_first, 'tri b c h w -> b c (tri h) w', tri=3)
def unroll_triplanes(rolled_triplane):
    """Split a rolled (B, C, 3*H, W) tensor back into three (B, C, H, W) planes.

    Returns a tuple of views, in the order the planes were rolled.
    """
    planes_first = einops.rearrange(
        rolled_triplane, 'b c (tri h) w -> tri b c h w', tri=3)
    return torch.unbind(planes_first, dim=0)
def conv1x1triplane3daware(in_channels, out_channels, order='xz', **kwargs):
    """Build a ``ConvTriplane3dAware`` whose per-plane convolutions are 1x1.

    Extra ``kwargs`` are forwarded to every internal ``conv1x1`` call.
    """
    def make_plane_conv(inp, out):
        return conv1x1(inp, out, **kwargs)

    return ConvTriplane3dAware(make_plane_conv, in_channels, out_channels,
                               order=order)
def Normalize(in_channels, num_groups=32):
    """Return a ``torch.nn.GroupNorm`` (eps=1e-6, affine) over ``in_channels``.

    GroupNorm requires ``in_channels % groups == 0``. The group count used is
    the largest divisor of ``in_channels`` that does not exceed
    ``num_groups``:
      * ``num_groups`` itself when it divides ``in_channels``;
      * ``in_channels`` when ``in_channels < num_groups``;
      * otherwise the next smaller divisor (e.g. 24 groups for 48 channels),
        where simply clamping to ``num_groups`` would make GroupNorm raise.

    Args:
        in_channels: number of channels to normalize (must be >= 1).
        num_groups: upper bound on the number of groups (must be >= 1).

    Raises:
        ValueError: if ``in_channels`` or ``num_groups`` is < 1.
    """
    if in_channels < 1 or num_groups < 1:
        raise ValueError(
            f"in_channels and num_groups must be >= 1, got {in_channels} and {num_groups}")
    groups = min(in_channels, num_groups)
    # walk down to a divisor; terminates at 1 at the latest
    while in_channels % groups:
        groups -= 1
    return torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=1e-6, affine=True)
def nonlinearity(x):
    """Swish / SiLU activation: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return gate * x
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling, optionally followed by a 3x3 conv.

    The optional conv keeps the channel count and spatial size
    (stride 1, padding 1).
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(in_channels, in_channels,
                                  kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled) if self.with_conv else upsampled
class Downsample(nn.Module):
    """2x spatial downsampling.

    Uses either a 2x2 average pool or, when ``with_conv`` is true, a stride-2
    3x3 conv that keeps the channel count.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            # padding=0 here: the asymmetric (right/bottom) padding is applied
            # manually in forward, since Conv2d pads symmetrically.
            self.conv = nn.Conv2d(in_channels, in_channels,
                                  kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if not self.with_conv:
            return F.avg_pool2d(x, kernel_size=2, stride=2)
        # one zero column on the right and one zero row at the bottom
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class ResnetBlock3dAware(nn.Module):
    """Pre-activation residual block for rolled triplanes (B, C, 3*H, W).

    The residual branch runs three stages:
      1. GroupNorm -> Swish -> 3x3 conv, applied on the rolled tensor;
      2. GroupNorm -> Swish -> 1x1 3D-aware cross-plane conv, with the planes
         temporarily unrolled;
      3. GroupNorm -> Swish -> 3x3 conv, again on the rolled tensor.
    The shortcut is the identity, or a 1x1 conv when the channel count changes.

    NOTE(review): the 3x3 convs run on the rolled tensor, so rows next to a
    plane boundary also see rows of the adjacent plane.
    """

    def __init__(self, in_channels, out_channels=None):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Submodules are created in a fixed order; keep it, since default
        # parameter init draws from the global RNG in creation order.
        self.norm1 = Normalize(in_channels)
        self.conv1 = conv3x3(in_channels, out_channels)
        self.norm_mid = Normalize(out_channels)
        self.conv_3daware = conv1x1triplane3daware(out_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        if in_channels != out_channels:
            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels,
                                                kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # stage 1: in-plane 3x3
        h = self.conv1(nonlinearity(self.norm1(x)))
        # stage 2: cross-plane 1x1 (needs the three planes separated)
        h = nonlinearity(self.norm_mid(h))
        h = roll_triplanes(self.conv_3daware(unroll_triplanes(h)))
        # stage 3: in-plane 3x3
        h = self.conv2(nonlinearity(self.norm2(h)))
        if self.in_channels != self.out_channels:
            shortcut = self.nin_shortcut(x)
        else:
            shortcut = x
        return shortcut + h
class DownConv3dAware(nn.Module):
    """Encoder stage on rolled triplanes.

    Runs one ``ResnetBlock3dAware`` and then, if enabled, a 2x ``Downsample``
    applied to each plane separately.
    """

    def __init__(self, in_channels, out_channels, downsample=True, with_conv=False):
        """
        Args:
            in_channels: channels of the rolled input.
            out_channels: channels produced by the residual block.
            downsample: whether to halve the spatial size after the block.
            with_conv: forwarded to ``Downsample``. True selects a strided
                conv, False an average pool.
        """
        super(DownConv3dAware, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.block = ResnetBlock3dAware(in_channels=in_channels,
                                        out_channels=out_channels)
        self.do_downsample = downsample
        # Built for out_channels: forward feeds it one plane at a time.
        self.downsample = Downsample(out_channels, with_conv=with_conv)

    def forward(self, x):
        """
        Args:
            x: rolled triplanes (B, C_in, 3*H, W).
        Returns:
            A tuple ``(x, before_pool)``:
              * ``before_pool`` is the block output (B, C_out, 3*H, W), kept
                for the decoder skip connection;
              * ``x`` is that output downsampled per plane to
                (B, C_out, 3*(H//2), W//2), or ``before_pool`` itself when
                downsampling is disabled.
        """
        x = self.block(x)
        before_pool = x
        if self.do_downsample:
            # Fold the three planes into the batch axis. Each plane is then
            # downsampled on its own, so nothing pools across plane boundaries,
            # and the Downsample conv (with_conv=True) gets the out_channels
            # inputs it was built for. Folding into the channel axis instead
            # would feed it 3*out_channels channels and fail.
            x = einops.rearrange(x, 'b c (tri h) w -> (b tri) c h w', tri=3)
            x = self.downsample(x)
            x = einops.rearrange(x, '(b tri) c h w -> b c (tri h) w', tri=3)
        return x, before_pool
class UpConv3dAware(nn.Module):
    """Decoder stage on rolled triplanes.

    The stage 2x-upsamples the decoder tensor, merges it with the encoder
    skip, applies GroupNorm, and then runs one ``ResnetBlock3dAware``.
    """

    def __init__(self, in_channels, out_channels,
                 merge_mode='concat', with_conv=False):
        """
        Args:
            in_channels: channels of the decoder input ``from_up``.
            out_channels: channels of the encoder skip ``from_down`` and of
                this stage's output.
            merge_mode: 'concat' concatenates ``from_up`` and ``from_down`` on
                the channel axis. Any other value adds them; if
                ``in_channels != out_channels``, ``from_up`` is first projected
                to ``out_channels`` with a 1x1 conv.
            with_conv: forwarded to ``Upsample``.
        """
        super(UpConv3dAware, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.upsample = Upsample(in_channels, with_conv)
        if self.merge_mode == 'concat':
            merged_channels = in_channels + out_channels
        else:
            # Additive merge needs matching channels. The skip has out_channels,
            # so project the upsampled tensor when the counts differ. When they
            # are equal, no extra module is created.
            if in_channels != out_channels:
                self.up_proj = conv1x1(in_channels, out_channels)
            merged_channels = out_channels
        self.norm1 = Normalize(merged_channels)
        self.block = ResnetBlock3dAware(in_channels=merged_channels,
                                        out_channels=out_channels)

    def forward(self, from_down, from_up):
        """Forward pass; rolled (B, C, 3*H, W) inputs and output.

        Args:
            from_down: encoder skip tensor with ``out_channels`` channels, at
                twice the spatial size of ``from_up``.
            from_up: decoder tensor with ``in_channels`` channels.
        Returns:
            Rolled tensor with ``out_channels`` channels at the size of
            ``from_down``.
        """
        from_up = self.upsample(from_up)
        if self.merge_mode == 'concat':
            x = torch.cat((from_up, from_down), 1)
        else:
            if self.in_channels != self.out_channels:
                from_up = self.up_proj(from_up)
            x = from_up + from_down
        x = self.norm1(x)
        x = self.block(x)
        return x
class UNetTriplane3dAware(nn.Module):
    """U-Net over triplanes built from 3D-aware residual blocks.

    The input (B, 3, C, H, W) is rolled to (B, C, 3*H, W) and processed by
    ``depth`` encoder levels and ``depth - 1`` decoder levels. It is then
    unrolled back to (B, 3, out_channels, H, W).
    """

    def __init__(self, out_channels, in_channels=3, depth=5,
                 start_filts=64,
                 use_initial_conv=False,
                 merge_mode='concat', **kwargs):
        """
        Arguments:
            out_channels: int, channels of each output plane.
            in_channels: int, channels of each input plane (default 3).
            depth: int >= 1, number of encoder levels. Level i has
                ``start_filts * 2**i`` channels. There are ``depth - 1``
                downsampling steps and ``depth - 1`` decoder stages.
            start_filts: int, channel width of the first encoder level.
            use_initial_conv: if True, a 1x1 conv first maps ``in_channels``
                to ``start_filts``.
            merge_mode: 'concat' (default) concatenates skip features; any
                other value adds them.
            **kwargs: accepted for config compatibility and ignored.
        Raises:
            ValueError: if ``depth < 1``.
        """
        super(UNetTriplane3dAware, self).__init__()
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth
        self.use_initial_conv = use_initial_conv
        if use_initial_conv:
            self.conv_initial = conv1x1(self.in_channels, self.start_filts)

        # Encoder: channels double at each level; the deepest level keeps its
        # resolution (no downsampling).
        down_convs = []
        outs = self.start_filts if use_initial_conv else self.in_channels
        for i in range(depth):
            ins, outs = outs, self.start_filts * (2 ** i)
            down_convs.append(DownConv3dAware(ins, outs, downsample=i < depth - 1))

        # Decoder: channels halve at each stage.
        up_convs = []
        for _ in range(depth - 1):
            ins = outs
            outs = ins // 2
            up_convs.append(UpConv3dAware(ins, outs, merge_mode=merge_mode))

        # Register as submodules (the order matches the original construction).
        self.down_convs = nn.ModuleList(down_convs)
        self.up_convs = nn.ModuleList(up_convs)
        self.norm_out = Normalize(outs)
        self.conv_final = conv1x1(outs, self.out_channels)
        self.reset_params()

    @staticmethod
    def weight_init(m):
        """Xavier-normal init for ``nn.Conv2d`` weights and zero biases.

        Convs built with ``bias=False`` are skipped for the bias.
        """
        if isinstance(m, nn.Conv2d):
            init.xavier_normal_(m.weight)
            if m.bias is not None:
                init.constant_(m.bias, 0)

    def reset_params(self):
        """Apply ``weight_init`` to every submodule (including self)."""
        for m in self.modules():
            self.weight_init(m)

    def forward(self, x):
        """
        Args:
            x: stacked triplanes (B, 3, C_in, H, W). H and W should be
               divisible by 2**(depth-1); otherwise the upsampled decoder
               tensors will not match the skip tensors in size.
        Returns:
            (B, 3, out_channels, H, W). These are raw outputs of the final
            1x1 conv, with no softmax or other output activation.
        """
        # Roll the planes along the height axis: (B, C, 3*H, W)
        x = einops.rearrange(x, 'b tri c h w -> b c (tri h) w', tri=3)
        if self.use_initial_conv:
            x = self.conv_initial(x)

        # Encoder pathway; keep every pre-downsampling output for the skips.
        encoder_outs = []
        for module in self.down_convs:
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        # The deepest skip equals x itself, so the decoder consumes the
        # remaining skips from second-deepest to shallowest.
        for i, module in enumerate(self.up_convs):
            x = module(encoder_outs[-(i + 2)], x)

        x = self.norm_out(x)
        x = self.conv_final(nonlinearity(x))
        # Unroll back to (B, 3, out_channels, H, W)
        return einops.rearrange(x, 'b c (tri h) w -> b tri c h w', tri=3)
def setup_unet(output_channels, input_channels, unet_cfg):
    """Build the triplane U-Net described by ``unet_cfg``.

    Keys read: 'use_3d_aware', 'rolled', 'depth', 'use_initial_conv' and
    'start_hidden_channels'. Only the 3D-aware, rolled variant is supported.
    Otherwise this raises ``NotImplementedError``, or fails the assertion
    when 'rolled' is false.
    """
    if not unet_cfg['use_3d_aware']:
        raise NotImplementedError
    # the 3D-aware network operates on rolled triplanes only
    assert unet_cfg['rolled']
    return UNetTriplane3dAware(
        out_channels=output_channels,
        in_channels=input_channels,
        depth=unet_cfg['depth'],
        use_initial_conv=unet_cfg['use_initial_conv'],
        start_filts=unet_cfg['start_hidden_channels'],
    )