from typing import Optional

import torch
import torch.nn.functional as F
from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
from torch import Tensor

from .point_cloud_transformer_model import PointCloudTransformerModel
from .projection_model import PointCloudProjectionModel


class PointCloudColoringModel(PointCloudProjectionModel):
    """Projection model that predicts per-point colors for a point cloud.

    The point positions are taken as given; a ``PointCloudTransformerModel``
    maps the (conditioned) positions to colors, which are either compared
    against the ground-truth colors (training) or attached to the points and
    returned as a new point cloud (inference).
    """

    def __init__(
        self,
        point_cloud_model: str,
        point_cloud_model_layers: int,
        point_cloud_model_embed_dim: int,
        **kwargs,  # forwarded unchanged to PointCloudProjectionModel
    ):
        """Build the projection base and the color-prediction network.

        Args:
            point_cloud_model: Passed to the network as ``model_type``.
            point_cloud_model_layers: Passed to the network as ``num_layers``.
            point_cloud_model_embed_dim: Passed to the network as
                ``embed_dim``.
            **kwargs: Arguments for the projection base class.

        Raises:
            NotImplementedError: If the configuration asks for shape
                prediction, or does not ask for color prediction.
        """
        super().__init__(**kwargs)

        # This model only makes sense in the "color only" configuration.
        unsupported_config = self.predict_shape or not self.predict_color
        if unsupported_config:
            raise NotImplementedError('Must predict color, not shape, for coloring')

        # Network that turns per-point input features into per-point colors.
        # NOTE(review): in_channels / out_channels are presumably set by the
        # base-class constructor -- confirm in PointCloudProjectionModel.
        # The original author questioned why a transformer is used here.
        self.point_cloud_model = PointCloudTransformerModel(
            num_layers=point_cloud_model_layers,
            model_type=point_cloud_model,
            embed_dim=point_cloud_model_embed_dim,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
        )

    def _forward(
        self,
        pc: Pointclouds,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        return_point_cloud: bool = False,
        noise_std: float = 0.0,
    ):
        """Predict colors for ``pc`` and return a loss or a colored cloud.

        Args:
            pc: Point cloud whose tensor form carries coordinates in the
                first three channels and colors in the remaining ones.
            camera: Camera handed to the conditioning step.
            image_rgb: Image handed to the conditioning step.
            mask: Mask handed to the conditioning step.
            return_point_cloud: When true, return the point cloud with the
                predicted colors instead of the training loss.
            noise_std: Scale of the Gaussian noise added to the coordinates
                before conditioning.

        Returns:
            The MSE between predicted and ground-truth colors, or, when
            ``return_point_cloud`` is set, the result of
            ``tensor_to_point_cloud`` on coordinates + predicted colors.
        """
        # Tensor form of the cloud (normalized and scaled by the base-class
        # helper), split into coordinate and color channels.
        features = self.point_cloud_to_tensor(pc, normalize=True, scale=True)
        coords = features[:, :, :3]
        target_colors = features[:, :, 3:]

        # Perturb the coordinates; possibly meant to mimic the noise of a
        # predicted point cloud. TODO: Add to configs.
        jitter = torch.randn_like(coords) * noise_std
        model_input = coords + jitter

        # Conditioning. ``t=None`` is passed explicitly: per the original
        # author's note (XH), the call without ``t`` had to be edited to run.
        model_input = self.get_input_with_conditioning(
            model_input,
            camera=camera,
            image_rgb=image_rgb,
            mask=mask,
            t=None,
        )

        predicted_colors = self.point_cloud_model(model_input)

        if return_point_cloud:
            # Inference: pair the un-noised coordinates with predicted colors.
            colored = torch.cat([coords, predicted_colors], dim=2)
            return self.tensor_to_point_cloud(
                colored, denormalize=True, unscale=True
            )

        # Training: regress the predicted colors onto the ground truth.
        return F.mse_loss(predicted_colors, target_colors)

    def forward(self, batch: FrameData, **kwargs):
        """Unpack ``batch`` and delegate to ``_forward``.

        Args:
            batch: Frame data, or a plain dict of its fields.
            **kwargs: Extra keyword arguments forwarded to ``_forward``.

        Returns:
            Whatever ``_forward`` returns.
        """
        # Works around a multiprocessing issue noted by the original author,
        # where the batch arrives as a dict instead of a FrameData.
        frame = FrameData(**batch) if isinstance(batch, dict) else batch

        return self._forward(
            pc=frame.sequence_point_cloud,
            camera=frame.camera,
            image_rgb=frame.image_rgb,
            mask=frame.fg_probability,
            **kwargs,
        )