radames's picture
initial commit
c7f097c
raw history blame
No virus
3.47 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from .BasePIFuNet import BasePIFuNet
from .SurfaceClassifier import SurfaceClassifier
from .DepthNormalizer import DepthNormalizer
from .ConvFilters import *
from ..net_util import init_net
class ConvPIFuNet(BasePIFuNet):
    '''
    Conv PIFu ("Piximp") network: the standard three-phase network that we will use.

    The image filter is a multi-layer convolutional network (MultiConv, ResNet
    or VGG16, selected by opt.netIMF). During feature extraction, the features
    from every level of the image feature pyramid are sampled at each point's
    projected location and aggregated by concatenation.

    It does the following:
    1. Compute the image feature pyramid and store it in self.im_feat_list.
    2. Project the query points with the calibration, sample (index) each
       feature map at the projected locations, and concatenate the samples
       together with a depth feature.
    3. Classify each point with the surface classifier.
    '''

    def __init__(self,
                 opt,
                 projection_mode='orthogonal',
                 error_term=None,
                 ):
        '''
        :param opt: option namespace. This class reads num_views, mlp_dim,
            no_residual and netIMF (plus enc_dim when netIMF == 'multiconv'),
            and passes the whole object to DepthNormalizer.
        :param projection_mode: forwarded unchanged to BasePIFuNet.
        :param error_term: loss module forwarded to BasePIFuNet. When None
            (the default) a fresh nn.MSELoss() is created for this instance.
        '''
        # Build the default loss per instance. A default of nn.MSELoss() in the
        # signature is evaluated once at class-definition time, so every
        # default-constructed network would share the same module object.
        if error_term is None:
            error_term = nn.MSELoss()

        super().__init__(
            projection_mode=projection_mode,
            error_term=error_term)

        self.name = 'convpifu'

        self.opt = opt
        self.num_views = self.opt.num_views

        self.image_filter = self.define_imagefilter(opt)

        # nn.Sigmoid() is handed to the classifier as last_op; presumably it is
        # applied to the classifier output so predictions lie in (0, 1) --
        # confirm in SurfaceClassifier.
        self.surface_classifier = SurfaceClassifier(
            filter_channels=self.opt.mlp_dim,
            num_views=self.opt.num_views,
            no_residual=self.opt.no_residual,
            last_op=nn.Sigmoid())

        self.normalizer = DepthNormalizer(opt)

        # Filled by filter(): a list of [B x Feat_i x H x W] features.
        self.im_feat_list = []

        # Presumably initializes the network's weights -- see net_util.init_net.
        init_net(self)

    def define_imagefilter(self, opt):
        '''
        Build the image encoder selected by opt.netIMF.

        :param opt: option namespace; reads netIMF (and enc_dim for 'multiconv').
        :return: the image filter module.
        :raises NotImplementedError: if opt.netIMF does not name a known encoder.
        '''
        if opt.netIMF == 'multiconv':
            return MultiConv(opt.enc_dim)
        if 'resnet' in opt.netIMF:
            # Any name containing 'resnet' is accepted; the full name is passed
            # through to ResNet(model=...).
            return ResNet(model=opt.netIMF)
        if opt.netIMF == 'vgg16':
            return Vgg16()
        # Report the option actually dispatched on. The previous message read
        # opt.imf_type, a different attribute, so an unknown encoder name could
        # surface as an AttributeError instead of this error.
        raise NotImplementedError('model name [%s] is not recognized' % opt.netIMF)

    def filter(self, images):
        '''
        Run the image filter on the input images and store all of its
        intermediate features in self.im_feat_list for later query() calls.

        :param images: [B, C, H, W] input images
        '''
        self.im_feat_list = self.image_filter(images)

    def query(self, points, calibs, transforms=None, labels=None):
        '''
        Given 3D points, query the network predictions for each point.

        filter() must have been called first so that self.im_feat_list holds
        the image features.

        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :param labels: Optional [B, Res, N] gt labeling. Stored in self.labels
            when given; when None, any previously stored labels are kept.
        :return: [B, Res, N] predictions for each point (also stored in self.preds)
        '''
        if labels is not None:
            self.labels = labels

        # Project into image space. The first two rows are the image-plane
        # coordinates used for sampling; the third row is depth.
        xyz = self.projection(points, calibs, transforms)
        xy = xyz[:, :2, :]
        z = xyz[:, 2:3, :]

        z_feat = self.normalizer(z)

        # Sample every feature map at the projected locations:
        # a list of [B, Feat_i, N] features, followed by the depth feature.
        point_local_feat_list = [self.index(im_feat, xy) for im_feat in self.im_feat_list]
        point_local_feat_list.append(z_feat)

        # [B, Feat_all, N]
        point_local_feat = torch.cat(point_local_feat_list, 1)

        self.preds = self.surface_classifier(point_local_feat)
        return self.preds