radames's picture
initial commit
c7f097c
raw
history blame
8 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from .BasePIFuNet import BasePIFuNet
import functools
from .SurfaceClassifier import SurfaceClassifier
from .DepthNormalizer import DepthNormalizer
from ..net_util import *
class ResBlkPIFuNet(BasePIFuNet):
    '''
    PIFu network built from a ResnetFilter image encoder and a
    SurfaceClassifier head whose last operation is Tanh.

    Typical use is filter() -> attach() -> query(), which is exactly what
    forward() does. The helpers self.projection, self.index, self.get_preds
    and self.get_error are not defined in this class; presumably they are
    inherited from BasePIFuNet -- confirm there.
    '''

    def __init__(self, opt,
                 projection_mode='orthogonal'):
        '''
        Build the image filter, the per-point classifier and the depth
        normalizer, then initialize the weights with init_net.

        :param opt: options object; this constructor reads color_loss_type,
            norm_color, mlp_dim_color, num_views and no_residual from it and
            also hands it to ResnetFilter and DepthNormalizer
        :param projection_mode: forwarded unchanged to BasePIFuNet.__init__
        :raises NotImplementedError: if opt.color_loss_type is neither 'l1'
            nor 'mse'
        '''
        if opt.color_loss_type == 'l1':
            error_term = nn.L1Loss()
        elif opt.color_loss_type == 'mse':
            error_term = nn.MSELoss()
        else:
            # Fail with an explicit message instead of an UnboundLocalError
            # on error_term a few lines below.
            raise NotImplementedError(
                'color loss type [%s] is not implemented' % opt.color_loss_type)

        super(ResBlkPIFuNet, self).__init__(
            projection_mode=projection_mode,
            error_term=error_term)

        self.name = 'respifu'
        self.opt = opt

        norm_type = get_norm_layer(norm_type=opt.norm_color)
        self.image_filter = ResnetFilter(opt, norm_layer=norm_type)

        self.surface_classifier = SurfaceClassifier(
            filter_channels=self.opt.mlp_dim_color,
            num_views=self.opt.num_views,
            no_residual=self.opt.no_residual,
            last_op=nn.Tanh())

        self.normalizer = DepthNormalizer(opt)

        init_net(self)

    def filter(self, images):
        '''
        Filter the input images and store the resulting features
        in self.im_feat.
        :param images: [B, C, H, W] input images
        '''
        self.im_feat = self.image_filter(images)

    def attach(self, im_feat):
        '''
        Concatenate externally computed features with the stored ones.
        The given im_feat comes first along dim 1, followed by the current
        self.im_feat; the result replaces self.im_feat. self.im_feat must
        already be set (e.g. by filter()) before this call.
        :param im_feat: feature tensor to put in front of self.im_feat
        '''
        self.im_feat = torch.cat([im_feat, self.im_feat], 1)

    def query(self, points, calibs, transforms=None, labels=None):
        '''
        Given 3D points, query the network predictions for each point.
        Image features should be pre-computed before this call.
        The predictions are stored in self.preds (nothing is returned), and
        labels, when given, are stored in self.labels.
        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :param labels: Optional [B, Res, N] gt labeling
        '''
        if labels is not None:
            self.labels = labels

        xyz = self.projection(points, calibs, transforms)
        xy = xyz[:, :2, :]
        z = xyz[:, 2:3, :]

        z_feat = self.normalizer(z)

        # This is a list of [B, Feat_i, N] features
        point_local_feat_list = [self.index(self.im_feat, xy), z_feat]
        # [B, Feat_all, N]
        point_local_feat = torch.cat(point_local_feat_list, 1)

        self.preds = self.surface_classifier(point_local_feat)

    def forward(self, images, im_feat, points, calibs, transforms=None, labels=None):
        '''
        Run filter(), attach() and query() in sequence.
        :param images: input images passed to filter()
        :param im_feat: extra features passed to attach()
        :param points: query points passed to query()
        :param calibs: calibration matrices passed to query()
        :param transforms: Optional image space transforms passed to query()
        :param labels: Optional gt labeling passed to query()
        :return: (res, error), the values of self.get_preds() and
            self.get_error()
        '''
        self.filter(images)

        self.attach(im_feat)

        self.query(points, calibs, transforms, labels)

        res = self.get_preds()
        error = self.get_error()

        return res, error
class ResnetBlock(nn.Module):
    """Residual block: a two-convolution stack whose output is added to its input."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, last=False):
        """Create the block and store its convolutional stack in ``self.conv_block``.

        All arguments are handed to ``build_conv_block`` unchanged; the
        shortcut addition itself happens in ``forward``.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, last)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, last=False):
        """Assemble the inner stack of the residual block.

        Layer order: [pad] conv3x3 -> norm -> ReLU -> [Dropout(0.5)] ->
        [pad] conv3x3 -> [norm]. The trailing norm is left out when
        ``last`` is true.

        Parameters:
            dim (int)           -- channel count, identical for input and output of both convs.
            padding_type (str)  -- reflect | replicate | zero
            norm_layer          -- callable returning a normalization layer for ``dim`` channels
            use_dropout (bool)  -- insert a Dropout(0.5) layer after the first ReLU
            use_bias (bool)     -- whether the conv layers have a bias term
            last (bool)         -- skip the normalization after the second conv

        Returns an ``nn.Sequential`` holding the layers.
        Raises NotImplementedError for any other ``padding_type``.
        """

        def padded_conv():
            # 'zero' uses the convolution's own padding; the other modes put
            # an explicit padding layer in front of an unpadded convolution.
            if padding_type == 'zero':
                return [nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias)]
            if padding_type == 'reflect':
                border = nn.ReflectionPad2d(1)
            elif padding_type == 'replicate':
                border = nn.ReplicationPad2d(1)
            else:
                raise NotImplementedError('padding [%s] is not implemented' % padding_type)
            return [border, nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias)]

        # First half: convolution, normalization, activation.
        layers = padded_conv()
        layers.append(norm_layer(dim))
        layers.append(nn.ReLU(True))

        if use_dropout:
            layers.append(nn.Dropout(0.5))

        # Second half: convolution, optionally followed by normalization.
        layers.extend(padded_conv())
        if not last:
            layers.append(norm_layer(dim))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the conv stack applied to it (skip connection)."""
        return x + self.conv_block(x)
class ResnetFilter(nn.Module):
    """Resnet-based image filter: a 7x7 convolution stem, two stride-2 downsampling convolutions, then Resnet blocks.

    There are no upsampling layers; the output has ``ngf * 4`` channels.
    We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
    """

    def __init__(self, opt, input_nc=3, output_nc=256, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False,
                 n_blocks=6, padding_type='reflect'):
        """Construct the Resnet-based filter

        Parameters:
            opt                 -- options object; only ``opt.use_tanh`` is read here. When true,
                                   a final Tanh layer is appended.
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- accepted for interface compatibility; not used by this constructor
            ngf (int)           -- the number of filters produced by the first conv layer; each
                                   downsampling layer doubles it
            norm_layer          -- normalization layer class, or a functools.partial wrapping one
            use_dropout (bool)  -- if use dropout layers
            n_blocks (int)      -- the number of ResNet blocks (must be >= 0)
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
        """
        assert (n_blocks >= 0)
        super(ResnetFilter, self).__init__()

        # Convolutions carry a bias only when the normalization is
        # InstanceNorm2d. isinstance (rather than an exact type comparison)
        # also unwraps subclasses of functools.partial.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):  # add ResNet blocks
            # Only the final block is built with last=True, which makes it
            # skip the normalization after its second convolution.
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
                                  use_dropout=use_dropout, use_bias=use_bias,
                                  last=(i == n_blocks - 1))]

        if opt.use_tanh:
            model += [nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward: apply the sequential model to ``input``."""
        return self.model(input)