import torch
import torch.nn as nn
import torch.nn.functional as F

from .BasePIFuNet import BasePIFuNet
from .SurfaceClassifier import SurfaceClassifier
from .DepthNormalizer import DepthNormalizer
from .HGFilters import *
from ..net_util import init_net


class HGPIFuNet(BasePIFuNet):
    '''
    HG PIFu network uses Hourglass stacks as the image filter.
    It does the following:
        1. Compute image feature stacks and store them in self.im_feat_list;
           self.im_feat_list[-1] is the last stack (output stack).
        2. Calculate calibration (project the query points).
        3. If training, index on every intermediate stack;
           if testing, index on the last stack only.
        4. Classification.
        5. During training, the error is calculated on all stacks.
    '''

    def __init__(self,
                 opt,
                 projection_mode='orthogonal',
                 error_term=None,
                 ):
        '''
        Build the hourglass image filter, the surface classifier and the
        depth normalizer, then run init_net on the assembled network.

        :param opt: options object; this class reads opt.num_views,
            opt.mlp_dim, opt.no_residual and opt.skip_hourglass, and also
            hands opt to HGFilter and DepthNormalizer
        :param projection_mode: forwarded unchanged to BasePIFuNet
        :param error_term: loss callable forwarded to BasePIFuNet;
            None (the default) means a fresh nn.MSELoss()
        '''
        # Create the default loss per instance instead of in the signature:
        # a default-argument expression is evaluated once at definition time,
        # which would make every network share a single module object.
        if error_term is None:
            error_term = nn.MSELoss()

        super(HGPIFuNet, self).__init__(
            projection_mode=projection_mode,
            error_term=error_term)

        self.name = 'hgpifu'

        self.opt = opt
        self.num_views = self.opt.num_views

        self.image_filter = HGFilter(opt)

        self.surface_classifier = SurfaceClassifier(
            filter_channels=self.opt.mlp_dim,
            num_views=self.opt.num_views,
            no_residual=self.opt.no_residual,
            last_op=nn.Sigmoid())

        self.normalizer = DepthNormalizer(opt)

        # This is a list of [B x Feat_i x H x W] features
        self.im_feat_list = []
        # Extra outputs of the image filter; tmpx is sampled in query()
        # when opt.skip_hourglass is set.
        self.tmpx = None
        self.normx = None

        # One prediction per entry of im_feat_list, filled by query().
        self.intermediate_preds_list = []

        init_net(self)

    def filter(self, images):
        '''
        Filter the input images and store all intermediate features in
        self.im_feat_list (plus self.tmpx and self.normx).

        :param images: [B, C, H, W] input images
        '''
        self.im_feat_list, self.tmpx, self.normx = self.image_filter(images)
        # If it is not in training, only keep the last im_feat
        if not self.training:
            self.im_feat_list = [self.im_feat_list[-1]]

    def query(self, points, calibs, transforms=None, labels=None):
        '''
        Given 3D points, query the network predictions for each point.
        Image features should be pre-computed (see filter()) before this call.
        One prediction is produced per stored feature stack, so query()
        behaves differently during training and testing.

        Nothing is returned: the per-stack predictions are stored in
        self.intermediate_preds_list and the last one in self.preds.

        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :param labels: Optional [B, Res, N] gt labeling; when given it is
            stored in self.labels, when None self.labels is left untouched
        '''
        if labels is not None:
            self.labels = labels

        # self.projection is not defined in this class
        # (presumably inherited from BasePIFuNet).
        xyz = self.projection(points, calibs, transforms)
        xy = xyz[:, :2, :]
        z = xyz[:, 2:3, :]

        # Mask of points whose projected xy lies inside [-1, 1] x [-1, 1].
        in_img = (xy[:, 0] >= -1.0) & (xy[:, 0] <= 1.0) & (xy[:, 1] >= -1.0) & (xy[:, 1] <= 1.0)

        z_feat = self.normalizer(z, calibs=calibs)

        if self.opt.skip_hourglass:
            # Sampled once here and reused for every stack in the loop below.
            tmpx_local_feature = self.index(self.tmpx, xy)

        self.intermediate_preds_list = []

        for im_feat in self.im_feat_list:
            # [B, Feat_i + z, N]
            point_local_feat_list = [self.index(im_feat, xy), z_feat]

            if self.opt.skip_hourglass:
                point_local_feat_list.append(tmpx_local_feature)

            point_local_feat = torch.cat(point_local_feat_list, 1)

            # out of image plane is always set to 0
            pred = in_img[:, None].float() * self.surface_classifier(point_local_feat)
            self.intermediate_preds_list.append(pred)

        # Raises IndexError if no feature stack was stored before this call.
        self.preds = self.intermediate_preds_list[-1]

    def get_im_feat(self):
        '''
        Get the last image feature produced by the image filter.

        :return: [B, C_feat, H, W] image feature after filtering
        '''
        return self.im_feat_list[-1]

    def get_error(self):
        '''
        Hourglass has its own intermediate supervision scheme: the error
        term is evaluated against self.labels for every intermediate
        prediction and the results are averaged.

        Requires query() to have stored at least one prediction; with an
        empty self.intermediate_preds_list the division below fails.

        :return: mean of self.error_term over all intermediate predictions
        '''
        error = 0
        for preds in self.intermediate_preds_list:
            error += self.error_term(preds, self.labels)
        error /= len(self.intermediate_preds_list)

        return error

    def forward(self, images, points, calibs, transforms=None, labels=None):
        '''
        Run filter() then query(), and return the predictions and the error.

        get_error() reads self.labels, which query() only updates when
        labels is not None, so labels must be supplied on this call or
        have been stored by an earlier one.

        :param images: [B, C, H, W] input images
        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :param labels: Optional [B, Res, N] gt labeling
        :return: (res, error), the result of self.get_preds() and of
            self.get_error()
        '''
        # Phase 1: get image feature
        self.filter(images)

        # Phase 2: point query
        self.query(points=points, calibs=calibs, transforms=transforms, labels=labels)

        # get the prediction; get_preds is not defined in this class
        # (presumably inherited from BasePIFuNet)
        res = self.get_preds()

        # get the error
        error = self.get_error()

        return res, error