HDM-interaction-recon / model /model_coloring.py
xiexh20's picture
add hdm demo v1
2fd6166
raw
history blame
3.2 kB
from typing import Optional
import torch
import torch.nn.functional as F
from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
from torch import Tensor
from .point_cloud_transformer_model import PointCloudTransformerModel
from .projection_model import PointCloudProjectionModel
class PointCloudColoringModel(PointCloudProjectionModel):
    """Predict per-point colors for a point cloud whose geometry is given.

    Built on ``PointCloudProjectionModel`` (which receives all projection
    keyword arguments). A ``PointCloudTransformerModel`` maps the
    conditioned point features to colors. At training time the model
    returns an MSE loss against the ground-truth colors. At inference time
    it returns a point cloud carrying the predicted colors.
    """

    def __init__(
        self,
        point_cloud_model: str,
        point_cloud_model_layers: int,
        point_cloud_model_embed_dim: int,
        **kwargs,  # forwarded verbatim to PointCloudProjectionModel
    ):
        """Build the coloring network.

        Args:
            point_cloud_model: passed as ``model_type`` to
                ``PointCloudTransformerModel``.
            point_cloud_model_layers: passed as ``num_layers``.
            point_cloud_model_embed_dim: passed as ``embed_dim``.
            **kwargs: projection arguments for the base class.

        Raises:
            NotImplementedError: if the base class is configured to predict
                shape, or is not configured to predict color.
        """
        super().__init__(**kwargs)

        # This model only supports the "color only" configuration.
        wants_color_only = (not self.predict_shape) and self.predict_color
        if not wants_color_only:
            raise NotImplementedError('Must predict color, not shape, for coloring')

        # Channel counts come from the base class, so the backbone matches
        # whatever conditioning features the projection model produces.
        # NOTE(review): the original author questioned why a transformer
        # backbone is used here rather than another point-cloud model.
        self.point_cloud_model = PointCloudTransformerModel(
            model_type=point_cloud_model,
            num_layers=point_cloud_model_layers,
            embed_dim=point_cloud_model_embed_dim,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
        )

    def _forward(
        self,
        pc: Pointclouds,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        return_point_cloud: bool = False,
        noise_std: float = 0.0,
    ):
        """Predict colors for ``pc``.

        Args:
            pc: input point cloud with positions and (at training time)
                ground-truth colors.
            camera, image_rgb, mask: forwarded to
                ``get_input_with_conditioning``.
            return_point_cloud: if True, return a ``Pointclouds`` with the
                predicted colors. Otherwise return the MSE color loss.
            noise_std: std-dev of Gaussian noise added to the point positions
                before conditioning (0.0 adds zero-scaled noise).

        Returns:
            Either the colored point cloud (inference) or a scalar loss tensor
            (training).
        """
        # Normalized/scaled tensor: channels 0..2 are positions, the rest
        # are colors (the indexing below assumes a rank-3 tensor).
        normalized = self.point_cloud_to_tensor(pc, normalize=True, scale=True)
        xyz = normalized[:, :, :3]
        rgb = normalized[:, :, 3:]

        # Jitter the positions. Presumably this simulates noise in predicted
        # geometry; TODO: move noise_std into configs.
        jitter = torch.randn_like(xyz) * noise_std
        noisy_xyz = xyz + jitter

        # Attach image/camera conditioning. No timestep is used here, hence
        # t=None.
        conditioned = self.get_input_with_conditioning(
            noisy_xyz,
            camera=camera,
            image_rgb=image_rgb,
            mask=mask,
            t=None,
        )
        pred_rgb = self.point_cloud_model(conditioned)

        if not return_point_cloud:
            # Training: supervise against the ground-truth colors.
            return F.mse_loss(pred_rgb, rgb)

        # Inference: pair the clean (un-jittered) positions with the predicted
        # colors and undo the normalization/scaling.
        colored = torch.cat((xyz, pred_rgb), dim=2)
        return self.tensor_to_point_cloud(colored, denormalize=True, unscale=True)

    def forward(self, batch: FrameData, **kwargs):
        """Unpack a ``FrameData`` batch and delegate to ``_forward``.

        Extra keyword arguments (e.g. ``return_point_cloud``, ``noise_std``)
        are passed through unchanged.
        """
        # Under multiprocessing the batch has been observed to arrive as a
        # plain dict; rebuild the FrameData so attribute access works.
        frame = FrameData(**batch) if isinstance(batch, dict) else batch
        return self._forward(
            pc=frame.sequence_point_cloud,
            camera=frame.camera,
            image_rgb=frame.image_rgb,
            mask=frame.fg_probability,
            **kwargs,
        )