radames's picture
initial commit
c7f097c
raw
history blame
2.59 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..geometry import index, orthogonal, perspective
class BasePIFuNet(nn.Module):
    """Base class for PIFu-style networks that predict values at 3D query points.

    The base implementations of :meth:`filter` and :meth:`query` are no-ops;
    subclasses are expected to override them. ``query()`` is expected to store
    its output in ``self.preds`` (and the ground truth in ``self.labels`` when
    given), which is what :meth:`get_preds` and :meth:`get_error` read.
    """

    def __init__(self,
                 projection_mode='orthogonal',
                 error_term=None,
                 ):
        """
        :param projection_mode:
            Either 'orthogonal' or 'perspective'. Selects the projection
            function stored in ``self.projection``. NOTE: any value other
            than 'orthogonal' falls back to perspective projection.
        :param error_term:
            nn Loss between the predicted [B, Res, N] and the label [B, Res, N].
            Defaults to a fresh ``nn.MSELoss()`` per instance. The default is
            created here rather than in the signature so that instances do
            not share (and each register) one module object.
        """
        super().__init__()
        self.name = 'base'

        if error_term is None:
            error_term = nn.MSELoss()
        # Assigning an nn.Module registers it as a submodule of this network.
        self.error_term = error_term

        # Geometry helpers imported from the sibling ``geometry`` module.
        self.index = index
        self.projection = orthogonal if projection_mode == 'orthogonal' else perspective

        # Populated by subclass query() implementations; None until then.
        self.preds = None
        self.labels = None

    def forward(self, points, images, calibs, transforms=None):
        '''
        Run filter() on the images, then query() on the points, and return
        the stored predictions.

        :param points: [B, 3, N] world space coordinates of points
        :param images: [B, C, H, W] input images
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :return: [B, Res, N] predictions for each point (whatever self.preds
            holds after query(); None for the base class)
        '''
        self.filter(images)
        self.query(points, calibs, transforms)
        return self.get_preds()

    def filter(self, images):
        '''
        Filter the input images and store all intermediate features.

        No-op in the base class; subclasses are expected to override it.

        :param images: [B, C, H, W] input images
        '''

    def query(self, points, calibs, transforms=None, labels=None):
        '''
        Given 3D points, query the network predictions for each point.

        Image features should be pre-computed (via filter()) before this call.
        Implementations store all intermediate features, and query() may
        behave differently during training/testing.
        No-op in the base class; subclasses are expected to override it and
        set ``self.preds`` (and ``self.labels`` when labels are given).

        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :param labels: Optional [B, Res, N] gt labeling
        :return: [B, Res, N] predictions for each point
        '''

    def get_preds(self):
        '''
        Get the predictions from the last query.

        :return: [B, Res, N] network prediction for the last query, or None
            if nothing has populated ``self.preds`` yet
        '''
        return self.preds

    def get_error(self):
        '''
        Get the network loss from the last query.

        :return: ``self.error_term(self.preds, self.labels)``
        '''
        return self.error_term(self.preds, self.labels)