from typing import Optional

import torch
import torch.nn.functional as F
from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
from torch import Tensor

from .point_cloud_transformer_model import PointCloudTransformerModel
from .projection_model import PointCloudProjectionModel

class PointCloudColoringModel(PointCloudProjectionModel):
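    """Predicts per-point colors for an input point cloud, conditioned on an
    image via the projection features of PointCloudProjectionModel."""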
    
    def __init__(
        self,
        point_cloud_model: str,
        point_cloud_model_layers: int,
        point_cloud_model_embed_dim: int,
        **kwargs,  # projection arguments
    ):
        super().__init__(**kwargs)
        
        # Checks: this model predicts colors only, not shape
        if self.predict_shape or not self.predict_color:
            raise NotImplementedError(
                'PointCloudColoringModel requires predict_color=True and predict_shape=False')

        # Point cloud model that maps the conditioned per-point features to colors
        self.point_cloud_model = PointCloudTransformerModel(
            num_layers=point_cloud_model_layers,
            model_type=point_cloud_model,
            embed_dim=point_cloud_model_embed_dim,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
        )

    def _forward(
        self, 
        pc: Pointclouds,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        return_point_cloud: bool = False,
        noise_std: float = 0.0,
    ):

        # Normalize colors and convert to tensor
        x = self.point_cloud_to_tensor(pc, normalize=True, scale=True)
        x_points, x_colors = x[:, :, :3], x[:, :, 3:]

        # Add Gaussian noise to the points, likely to simulate the noise of
        # predicted point clouds at inference time. TODO: add noise_std to the configs.
        x_input = x_points + torch.randn_like(x_points) * noise_std

        # Condition the (noised) points on the camera, image, and mask.
        # t=None because the coloring model has no diffusion timestep.
        x_input = self.get_input_with_conditioning(x_input, camera=camera,
                                                   image_rgb=image_rgb, mask=mask, t=None)

        # Predict a color for each point
        pred_colors = self.point_cloud_model(x_input)

        # During inference, we return the point cloud with the predicted colors
        if return_point_cloud:
            pred_pointcloud = self.tensor_to_point_cloud(
                torch.cat((x_points, pred_colors), dim=2), denormalize=True, unscale=True)
            return pred_pointcloud

        # During training, we have ground truth colors and return the loss
        loss = F.mse_loss(pred_colors, x_colors)
        return loss

    def forward(self, batch: FrameData, **kwargs):
        """A wrapper around the forward method"""
        if isinstance(batch, dict):  # fixes a bug with multiprocessing where batch becomes a dict
            batch = FrameData(**batch)  # it really makes no sense, I do not understand it
        return self._forward(
            pc=batch.sequence_point_cloud, 
            camera=batch.camera,
            image_rgb=batch.image_rgb, 
            mask=batch.fg_probability,
            **kwargs,
        )
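
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how this model might be constructed and called. The
# constructor values below are hypothetical, `batch` stands for a FrameData
# batch from an Implicitron data loader, and the projection keyword arguments
# forwarded via **kwargs depend on what PointCloudProjectionModel accepts.
#
#   model = PointCloudColoringModel(
#       point_cloud_model='pvcnn',           # hypothetical model_type
#       point_cloud_model_layers=1,
#       point_cloud_model_embed_dim=64,
#       predict_shape=False,                 # required by the checks above
#       predict_color=True,
#   )
#   loss = model(batch)                                  # training: MSE loss on colors
#   colored_pc = model(batch, return_point_cloud=True)   # inference: colored Pointclouds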