| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch.nn as nn |
| | import pytorch_lightning as pl |
| |
|
| | from .geometry import index, orthogonal, perspective |
| |
|
| |
|
class BasePIFuNet(pl.LightningModule):
    """Base class for PIFu-style networks.

    ``forward`` runs ``filter`` on the images and then ``query`` on the
    resulting features. The base implementations of ``filter`` and ``query``
    return None; subclasses are expected to override them.
    """

    def __init__(
        self,
        projection_mode='orthogonal',
        error_term=None,
    ):
        """
        :param projection_mode:
            Either 'orthogonal' or 'perspective'. Selects the projection
            function stored on ``self.projection``. Any value other than
            'orthogonal' selects the perspective projection.
        :param error_term:
            nn Loss between the predicted [B, Res, N] and the label [B, Res, N].
            When None (the default), a new ``nn.MSELoss()`` is created for
            this instance.
        """
        super().__init__()
        self.name = 'base'

        # Build the default loss per instance. A default of ``nn.MSELoss()`` in
        # the signature would be evaluated once at definition time, and that
        # same module object would be registered as a child of every instance.
        self.error_term = nn.MSELoss() if error_term is None else error_term

        # Expose the geometry helpers as attributes, presumably so subclasses
        # can call self.index / self.projection.
        self.index = index
        self.projection = (
            orthogonal if projection_mode == 'orthogonal' else perspective
        )

    def forward(self, points, images, calibs, transforms=None):
        '''
        Extract image features with ``filter`` and evaluate them at ``points``
        with ``query``.
        :param points: [B, 3, N] world space coordinates of points
        :param images: [B, C, H, W] input images
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :return: [B, Res, N] predictions for each point
        '''
        features = self.filter(images)
        preds = self.query(features, points, calibs, transforms)
        return preds

    def filter(self, images):
        '''
        Filter the input images into features consumed by ``query``.
        The base implementation returns None; subclasses override it.
        :param images: [B, C, H, W] input images
        :return: image features passed to ``query`` (None in the base class)
        '''
        return None

    def query(self, features, points, calibs, transforms=None):
        '''
        Given 3D points, query the network predictions for each point.
        Image features should be pre-computed (see ``filter``) before this call.
        Subclass implementations may behave differently during training/testing.
        The base implementation returns None.
        :param features: image features as returned by ``filter``
        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :return: [B, Res, N] predictions for each point
        '''
        return None

    def get_error(self, preds, labels):
        '''
        Compute the loss between predictions and labels using ``self.error_term``.
        :param preds: [B, Res, N] network predictions
        :param labels: [B, Res, N] ground-truth labels
        :return: loss term
        '''
        return self.error_term(preds, labels)
| |
|