# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

import torch.nn as nn
import pytorch_lightning as pl

from .geometry import index, orthogonal, perspective


class BasePIFuNet(pl.LightningModule):
    """Base network exposing the filter / query / forward / get_error interface.

    ``filter`` and ``query`` are placeholders here (both return ``None``);
    presumably concrete networks override them -- confirm against subclasses.
    """

    def __init__(
        self,
        projection_mode='orthogonal',
        error_term=None,
    ):
        """
        :param projection_mode: Either orthogonal or perspective.
            It will call the corresponding function for projection.
            Any value other than 'orthogonal' selects the perspective
            projection.
        :param error_term: nn Loss between the predicted [B, Res, N] and the
            label [B, Res, N]. Defaults to a new ``nn.MSELoss()`` created for
            this instance.
        """
        super().__init__()

        # Build the default loss per instance. A default evaluated in the
        # signature would be one module object registered as a submodule of
        # every network, sharing its train/eval flag and hooks between them.
        if error_term is None:
            error_term = nn.MSELoss()

        self.name = 'base'
        self.error_term = error_term

        # Geometry helpers imported from .geometry, kept as attributes.
        self.index = index
        self.projection = (
            orthogonal if projection_mode == 'orthogonal' else perspective
        )

    def forward(self, points, images, calibs, transforms=None):
        """
        Run ``filter`` on the images, then ``query`` on the points.

        :param points: [B, 3, N] world space coordinates of points
        :param images: [B, C, H, W] input images
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :return: [B, Res, N] predictions for each point
        """
        features = self.filter(images)
        preds = self.query(features, points, calibs, transforms)
        return preds

    def filter(self, images):
        """
        Filter the input images.

        Placeholder in this base class: performs no computation.

        :param images: [B, C, H, W] input images
        :return: None in this base implementation
        """
        return None

    def query(self, features, points, calibs, transforms=None):
        """
        Given 3D points, query the network predictions for each point.
        Image features should be pre-computed before this call.
        query() function may behave differently during training/testing.

        Placeholder in this base class: performs no computation.

        :param features: image features, i.e. the value returned by
            ``filter`` (this is what ``forward`` passes in)
        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :return: [B, Res, N] predictions for each point; None in this base
            implementation
        """
        return None

    def get_error(self, preds, labels):
        """
        Get the network loss by applying ``error_term`` to the predictions
        and the labels.

        :param preds: [B, Res, N] predictions for each point
        :param labels: [B, Res, N] gt labeling
        :return: loss term
        """
        return self.error_term(preds, labels)