import torch
import torch.nn as nn
import torch.nn.functional as F
from .BasePIFuNet import BasePIFuNet


class VhullPIFuNet(BasePIFuNet):
    '''
    The Vhull Piximp network is a minimal network demonstrating how the
    template works; it also helps debug the training/testing schemes.
    It does the following:
        1. Computes the masks of the input images and stores them in self.im_feat.
        2. Projects the query points with the calibration matrices and
           indexes into the per-view masks.
        3. Predicts whether each point falls inside the intersection of
           all masks (i.e. the visual hull).
    A minimal usage sketch is included at the bottom of this file.
    '''

    def __init__(self,
                 num_views,
                 projection_mode='orthogonal',
                 error_term=nn.MSELoss(),
                 ):
        super(VhullPIFuNet, self).__init__(
            projection_mode=projection_mode,
            error_term=error_term)
        self.name = 'vhull'

        self.num_views = num_views

        self.im_feat = None

    def filter(self, images):
        '''
        Extract the per-view mask features from the input images and
        store them in self.im_feat.
        :param images: [B, C, H, W] input images
        '''
        # If the images have an alpha channel, use it directly as the mask
        if images.shape[1] > 3:
            self.im_feat = images[:, 3:4, :, :]
        # Otherwise, fall back to the first channel as a coarse mask proxy
        else:
            self.im_feat = images[:, 0:1, :, :]
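
    # Note: self.im_feat holds one single-channel soft mask per view; the
    # num_views views of each subject are stacked along the batch dimension,
    # and query() below regroups them before taking the per-point product.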

    def query(self, points, calibs, transforms=None, labels=None):
        '''
        Given 3D points, query the network predictions for each point.
        Image features should be pre-computed via filter() before this call.
        query() may behave differently during training and testing.
        :param points: [B, 3, N] world-space coordinates of the points
        :param calibs: [B, 3, 4] calibration matrix for each image
        :param transforms: optional [B, 2, 3] image-space coordinate transforms
        :param labels: optional [B, Res, N] ground-truth labels
        :return: nothing; the [B, Res, N] predictions are stored in
                 self.preds and can be retrieved with get_preds()
        '''
        if labels is not None:
            self.labels = labels

        # Project the world-space points into each view's image plane
        xyz = self.projection(points, calibs, transforms)
        xy = xyz[:, :2, :]

        # Sample the stored mask features at the projected 2D locations
        point_local_feat = self.index(self.im_feat, xy)
        # Regroup the per-view samples of each subject along the channel axis:
        # [B * num_views, C, N] -> [B, C * num_views, N]
        local_shape = point_local_feat.shape
        point_feat = point_local_feat.view(
            local_shape[0] // self.num_views,
            local_shape[1] * self.num_views,
            -1)
        # Multiply the mask values across views: a point is predicted inside
        # only if it projects into the foreground of every view
        pred = torch.prod(point_feat, dim=1)

        self.preds = pred.unsqueeze(1)
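

# A minimal smoke-test sketch, not part of the original module: it drives
# filter()/query()/get_preds() with dummy RGBA views and identity orthogonal
# calibration matrices. All sizes, the torch.eye-based calibs, and the random
# inputs are illustrative assumptions, not values from the repository.
# Because of the relative import above, this only runs when the file is
# executed as a module inside its package (python -m <package>.VhullPIFuNet).
if __name__ == '__main__':
    num_views = 2
    net = VhullPIFuNet(num_views=num_views)

    # Two random RGBA views of one subject: [B * num_views, 4, H, W]
    images = torch.rand(num_views, 4, 256, 256)
    # Identity orthogonal calibration per view: [B * num_views, 3, 4]
    calibs = torch.eye(4)[:3].unsqueeze(0).repeat(num_views, 1, 1)
    # N query points in the [-1, 1] cube, shared across views: [B * num_views, 3, N]
    points = (torch.rand(1, 3, 1000) * 2.0 - 1.0).repeat(num_views, 1, 1)

    net.filter(images)          # cache the per-view alpha masks
    net.query(points, calibs)   # multiply the sampled mask values across views
    preds = net.get_preds()     # [B, 1, N]: high inside the hull, low outside
    print(preds.shape)          # torch.Size([1, 1, 1000])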