# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright © 2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG), acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

import pytorch_lightning as pl
import torch.nn as nn

from .geometry import index, orthogonal, perspective


class BasePIFuNet(pl.LightningModule):
    def __init__(
        self,
        projection_mode="orthogonal",
        error_term=nn.MSELoss(),
    ):
        """
        :param projection_mode:
        Either orthogonal or perspective.
        It will call the corresponding function for projection.
        :param error_term:
        nn Loss between the predicted [B, Res, N] and the label [B, Res, N]
        """
        super().__init__()
        self.name = "base"

        self.error_term = error_term

        self.index = index
        self.projection = orthogonal if projection_mode == "orthogonal" else perspective
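        # `index` samples per-point features from a 2D feature map at the
        # projected (u, v) locations via bilinear grid sampling; `orthogonal`
        # and `perspective` map [B, 3, N] world-space points into image space
        # using the [B, 3, 4] calibration matrices.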

    def forward(self, points, images, calibs, transforms=None):
        """
        :param points: [B, 3, N] world space coordinates of points
        :param images: [B, C, H, W] input images
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :return: [B, Res, N] predictions for each point
        """
        features = self.filter(images)
        preds = self.query(features, points, calibs, transforms)
        return preds

    def filter(self, images):
        """
        Filter the input images
        store all intermediate features.
        :param images: [B, C, H, W] input images
        """
        return None

    def query(self, features, points, calibs, transforms=None):
        """
        Given 3D points, query the network predictions for each point.
        Image features should be pre-computed before this call.
        store all intermediate features.
        query() function may behave differently during training/testing.
        :param points: [B, 3, N] world space coordinates of points
        :param calibs: [B, 3, 4] calibration matrices for each image
        :param transforms: Optional [B, 2, 3] image space coordinate transforms
        :param labels: Optional [B, Res, N] gt labeling
        :return: [B, Res, N] predictions for each point
        """
        return None

    def get_error(self, preds, labels):
        """
        Get the network loss from the last query
        :return: loss term
        """
        return self.error_term(preds, labels)
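

if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Illustrative sketch, not part of the original module: a minimal toy
    # subclass showing how filter()/query() plug into forward(). The
    # encoder/head layers and all shapes below are hypothetical stand-ins
    # for the real PIFu-style networks; it assumes .geometry provides
    # orthogonal(points, calibs, transforms=None) -> [B, 3, N] and
    # index(features, uv) -> [B, C, N], as used elsewhere in this class.
    # ------------------------------------------------------------------
    import torch

    class ToyPIFuNet(BasePIFuNet):
        def __init__(self):
            super().__init__(projection_mode="orthogonal")
            self.encoder = nn.Conv2d(3, 8, kernel_size=1)  # toy image filter
            self.head = nn.Conv1d(8, 1, kernel_size=1)     # toy per-point head

        def filter(self, images):
            # [B, 3, H, W] -> [B, 8, H, W]
            return self.encoder(images)

        def query(self, features, points, calibs, transforms=None):
            xyz = self.projection(points, calibs, transforms)  # [B, 3, N]
            uv = xyz[:, :2, :]                 # image-space sampling coords
            point_feat = self.index(features, uv)              # [B, 8, N]
            return self.head(point_feat)                       # [B, 1, N]

    # Smoke test; shapes follow the docstrings above. Identity calibration
    # keeps the points (already in [-1, 1]) valid for grid sampling.
    net = ToyPIFuNet()
    images = torch.rand(2, 3, 64, 64)            # [B, C, H, W]
    points = torch.rand(2, 3, 100) * 2.0 - 1.0   # [B, 3, N] in [-1, 1]
    calibs = torch.eye(4)[:3].unsqueeze(0).expand(2, -1, -1)  # [B, 3, 4]
    preds = net(points, images, calibs)
    print(preds.shape)  # expected: torch.Size([2, 1, 100])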