""" | |
model to deal with shapenet inputs and other datasets such as Behave and ProciGen | |
the model takes a different data dictionary in forward function | |
""" | |
import inspect
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_pndm import PNDMScheduler
from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
from torch import Tensor
from tqdm import tqdm
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.datasets.r2n2.utils import BlenderCamera

from .model import ConditionalPointCloudDiffusionModel
from .model_utils import get_num_points


class ConditionalPCDiffusionShapenet(ConditionalPointCloudDiffusionModel):
    def forward(self, batch, mode: str = 'train', **kwargs):
        """
        Take a batch of data from ShapeNet.
        """
        images = torch.stack(batch['images'], 0).to('cuda')
        masks = torch.stack(batch['masks'], 0).to('cuda')
        pc = Pointclouds([x.to('cuda') for x in batch['pclouds']])
        camera = BlenderCamera(
            torch.stack(batch['R']),
            torch.stack(batch['T']),
            torch.stack(batch['K']),
            device='cuda',
        )

        if mode == 'train':
            return self.forward_train(
                pc=pc,
                camera=camera,
                image_rgb=images,
                mask=masks,
                **kwargs)
        elif mode == 'sample':
            num_points = kwargs.pop('num_points', get_num_points(pc))
            return self.forward_sample(
                num_points=num_points,
                camera=camera,
                image_rgb=images,
                mask=masks,
                gt_pc=pc,
                **kwargs)
        else:
            raise NotImplementedError()
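

# A minimal usage sketch (not part of the original module): it shows the batch
# dictionary format this forward expects. The tensor shapes (image size, number
# of points) are placeholder assumptions inferred from the torch.stack and
# Pointclouds calls above, not taken from the actual dataset code.
def _example_shapenet_batch(batch_size: int = 2, num_points: int = 1024):
    return {
        'images': [torch.rand(3, 137, 137) for _ in range(batch_size)],     # RGB renderings (size assumed)
        'masks': [torch.ones(1, 137, 137) for _ in range(batch_size)],      # foreground masks
        'pclouds': [torch.rand(num_points, 3) for _ in range(batch_size)],  # GT point clouds
        'R': [torch.eye(3) for _ in range(batch_size)],                     # per-view rotations
        'T': [torch.zeros(3) for _ in range(batch_size)],                   # per-view translations
        'K': [torch.eye(4) for _ in range(batch_size)],                     # Blender projection matrices
    }
# Usage (assumes a constructed model and a CUDA device):
#   loss = model(_example_shapenet_batch(), mode='train')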


class ConditionalPCDiffusionBehave(ConditionalPointCloudDiffusionModel):
    """Diffusion model for the Behave dataset."""

    def forward(self, batch, mode: str = 'train', **kwargs):
        images = torch.stack(batch['images'], 0).to('cuda')
        masks = torch.stack(batch['masks'], 0).to('cuda')
        pc = self.get_input_pc(batch)
        camera = PerspectiveCameras(
            R=torch.stack(batch['R']),
            T=torch.stack(batch['T']),
            K=torch.stack(batch['K']),
            device='cuda',
            in_ndc=True
        )
        grid_df = torch.stack(batch['grid_df'], 0).to('cuda') if 'grid_df' in batch else None
        num_points = kwargs.pop('num_points', get_num_points(pc))

        if mode == 'train':
            return self.forward_train(
                pc=pc,
                camera=camera,
                image_rgb=images,
                mask=masks,
                grid_df=grid_df,
                **kwargs)
        elif mode == 'sample':
            return self.forward_sample(
                num_points=num_points,
                camera=camera,
                image_rgb=images,
                mask=masks,
                gt_pc=pc,
                **kwargs)
        else:
            raise NotImplementedError()

    def get_input_pc(self, batch):
        pc = Pointclouds([x.to('cuda') for x in batch['pclouds']])
        return pc
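
    # Sampling sketch (an assumption about usage, not from the original code):
    #   model = ConditionalPCDiffusionBehave(...)  # hypothetical construction
    #   with torch.no_grad():
    #       pred = model(batch, mode='sample', num_points=1500)
    # 'num_points' overrides the point count otherwise taken from the GT cloud
    # via get_num_points; in 'train' mode the same batch returns a loss instead.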


class ConditionalPCDiffusionSeparateSegm(ConditionalPCDiffusionBehave):
    """A separate model to predict binary labels; the final segmentation model."""

    def __init__(self,
                 beta_start: float,
                 beta_end: float,
                 beta_schedule: str,
                 point_cloud_model: str,
                 point_cloud_model_embed_dim: int,
                 **kwargs,  # projection arguments
                 ):
        super(ConditionalPCDiffusionSeparateSegm, self).__init__(
            beta_start, beta_end, beta_schedule,
            point_cloud_model, point_cloud_model_embed_dim, **kwargs)

        # add a separate model to predict the binary label
        from .point_cloud_transformer_model import PointCloudTransformerModel, PointCloudModel
        self.binary_model = PointCloudTransformerModel(
            num_layers=1,  # XH: use the default color model number of layers
            model_type=point_cloud_model,  # pvcnn
            embed_dim=point_cloud_model_embed_dim,  # same as the pc shape model
            in_channels=self.in_channels,
            out_channels=1,
        )
        self.binary_training_noise_std = kwargs.get("binary_training_noise_std", 0.1)

        # re-initialize the point cloud model
        assert self.predict_binary
        self.point_cloud_model = PointCloudModel(
            model_type=point_cloud_model,
            embed_dim=point_cloud_model_embed_dim,
            in_channels=self.in_channels,
            out_channels=self.out_channels - 1,  # not predicting binary from this anymore
            voxel_resolution_multiplier=kwargs.get('voxel_resolution_multiplier', 1)
        )
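
    # Note on the two-head design (explanatory comment, not original code):
    # self.point_cloud_model now denoises only the shape channels
    # (out_channels - 1), while self.binary_model maps the same conditioned
    # per-point input to a single logit, which reverse_step below appends as
    # the 4th (segmentation) channel.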

    def forward_train(
        self,
        pc: Pointclouds,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        return_intermediate_steps: bool = False,
        **kwargs
    ):
        # first run shape forward, then binary label forward
        assert not return_intermediate_steps
        assert self.predict_binary

        loss_shape = super(ConditionalPCDiffusionSeparateSegm, self).forward_train(
            pc, camera, image_rgb, mask, return_intermediate_steps, **kwargs)

        # binary label forward
        x_0 = self.point_cloud_to_tensor(pc, normalize=True, scale=True)
        x_points, x_colors = x_0[:, :, :3], x_0[:, :, 3:]

        # add noise to the points
        x_input = x_points + torch.randn_like(x_points) * self.binary_training_noise_std  # std=0.1
        x_input = self.get_input_with_conditioning(x_input, camera=camera,
                                                   image_rgb=image_rgb, mask=mask, t=None)

        # forward through the binary segmentation model
        pred_segm = self.binary_model(x_input)

        # look up per-point GT labels in the distance-field (df) grid
        df_grid = kwargs.get('grid_df', None).unsqueeze(1)  # (B, 1, resz, resy, resx)
        points = x_points.clone().detach() / self.scale_factor * 2  # normalize to [-1, 1]
        points[:, :, 0], points[:, :, 2] = points[:, :, 2].clone(), points[:, :, 0].clone()  # swap x and z; make sure clone is used!
        points = points.unsqueeze(1).unsqueeze(1)  # (B, 1, 1, N, 3)
        with torch.no_grad():
            df_interp = F.grid_sample(df_grid, points, padding_mode='border', align_corners=True).squeeze(1).squeeze(1)  # (B, 1, N)
            binary_label = df_interp[:, 0] > 0.5  # (B, N)
        binary_pred = torch.sigmoid(pred_segm.squeeze(-1))  # apply sigmoid to get per-point probabilities, (B, N)
        loss_binary = F.mse_loss(binary_pred, binary_label.float()) * self.lw_binary
        loss = loss_shape + loss_binary

        return loss, torch.tensor([loss_shape, loss_binary])

    def reverse_step(self, extra_step_kwargs, scheduler, t, x_t, x_t_input, **kwargs):
        """Return (B, N, 4); the 4th channel is the binary label."""
        B = x_t.shape[0]

        # forward pass through the shape model
        noise_pred = self.point_cloud_model(x_t_input, t.reshape(1).expand(B))
        if self.consistent_center:
            assert self.dm_pred_type != 'sample', 'incompatible dm prediction type!'
            # suggested by the CCD-3DR paper
            noise_pred = noise_pred - torch.mean(noise_pred, dim=1, keepdim=True)

        # step: make sure to only update the shape (first 3 channels)
        x_t = scheduler.step(noise_pred, t, x_t[:, :, :3], **extra_step_kwargs).prev_sample
        if self.consistent_center:
            x_t = x_t - torch.mean(x_t, dim=1, keepdim=True)

        # also append the binary prediction
        if kwargs.get('inference_binary', False):
            pred_segm = self.binary_model(x_t_input)
        else:
            pred_segm = torch.zeros_like(x_t[:, :, 0:1])
        x_t = torch.cat([x_t, torch.sigmoid(pred_segm)], -1)

        return x_t
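
    # Usage note (an assumption about the base class, whose sampling loop is
    # not shown here): forward_sample is expected to call reverse_step once per
    # scheduler timestep, so pass inference_binary=True through its kwargs to
    # populate the 4th channel; otherwise that channel defaults to
    # sigmoid(0) = 0.5 for every point.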

    def get_coord_feature(self, x_t):
        x_t_input = [x_t[:, :, :3]]
        return x_t_input

    def tensor_to_point_cloud(self, x: Tensor, /, denormalize: bool = False, unscale: bool = False):
        """
        Take the binary label into account.

        :param x: (B, N, 4), where the 4th channel is the binary segmentation (1 = human, 0 = object)
        :param denormalize: denormalize the per-point colors, from pc2
        :param unscale: undo point scaling, from pc2
        :return: pc with point colors if predicting binary labels or per-point colors
        """
        points = x[:, :, :3] / (self.scale_factor if unscale else 1)
        if self.predict_color:
            colors = self.denormalize(x[:, :, 3:]) if denormalize else x[:, :, 3:]
            return Pointclouds(points=points, features=colors)
        else:
            if self.predict_binary:
                assert x.shape[2] == 4
                # color-code the predicted binary labels: human is light blue, object is light green
                is_hum = x[:, :, 3] > 0.5
                features = []
                for mask in is_hum:
                    color = torch.zeros_like(x[0, :, :3]) + torch.tensor([0.5, 1.0, 0]).to(x.device)  # object: light green
                    color[mask, :] = torch.tensor([0.05, 1.0, 1.0]).to(x.device)  # human: light blue
                    features.append(color)
            else:
                assert x.shape[2] == 3
                features = None
            return Pointclouds(points=points, features=features)
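

# A standalone sketch (hypothetical helper, not part of the original module) of
# the label lookup used in forward_train above: F.grid_sample trilinearly
# samples a dense (B, 1, D, H, W) volume at continuous locations given in
# [-1, 1] and ordered (x, y, z), which is why the points in forward_train are
# rescaled by 2 / scale_factor and have their x/z coordinates swapped. The
# grid values and the 0.5 threshold here are toy assumptions.
def _demo_grid_label_lookup(batch_size: int = 2, num_points: int = 8, res: int = 16):
    grid = torch.rand(batch_size, 1, res, res, res)          # toy (B, 1, D, H, W) volume
    points = torch.rand(batch_size, num_points, 3) * 2 - 1   # (B, N, 3), already in [-1, 1]
    points = points[..., [2, 1, 0]]                          # swap x and z, as in forward_train
    points = points.unsqueeze(1).unsqueeze(1)                # (B, 1, 1, N, 3)
    with torch.no_grad():
        sampled = F.grid_sample(grid, points, padding_mode='border',
                                align_corners=True)          # (B, 1, 1, 1, N)
    return sampled.flatten(1) > 0.5                          # (B, N) binary labels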