Spaces:
Sleeping
Sleeping
""" | |
Model that uses cross attention to jointly predict human + object point clouds.
""" | |
import inspect | |
import random | |
from typing import Optional | |
from torch import Tensor | |
import torch | |
import numpy as np | |
from pytorch3d.structures import Pointclouds | |
from pytorch3d.renderer import CamerasBase | |
from .model_diff_data import ConditionalPCDiffusionBehave | |
from .pvcnn.pvcnn_ho import PVCNN2HumObj | |
import torch.nn.functional as F | |
from pytorch3d.renderer import PerspectiveCameras | |
from .model_utils import get_num_points | |
from tqdm import tqdm | |
class CrossAttenHODiffusionModel(ConditionalPCDiffusionBehave): | |
def init_pcloud_model(self, kwargs, point_cloud_model, point_cloud_model_embed_dim):
    """Build the cross-attention point cloud network and read the visibility option.

    :param kwargs: dict of extra options; the keys read here are
        'voxel_resolution_multiplier', 'attn_type', 'attn_weight' and
        'point_visible_test'
    :param point_cloud_model: name of the backbone, only 'pvcnn' is accepted
    :param point_cloud_model_embed_dim: embedding dimension forwarded to the backbone
    :raises ValueError: if the backbone name is not 'pvcnn'
    """
    # Reject unsupported backbones before anything is constructed.
    if point_cloud_model != 'pvcnn':
        raise ValueError(f"Unknown point cloud model {point_cloud_model}!")

    self.point_cloud_model = PVCNN2HumObj(
        embed_dim=point_cloud_model_embed_dim,
        num_classes=self.out_channels,
        extra_feature_channels=(self.in_channels - 3),
        voxel_resolution_multiplier=kwargs.get('voxel_resolution_multiplier', 1),
        attn_type=kwargs.get('attn_type', 'simple-cross'),
        attn_weight=kwargs.get("attn_weight", 1.0),
    )

    # 'single': visibility test uses each point set on its own;
    # 'combine': human + object points are tested together.
    self.point_visible_test = kwargs.get("point_visible_test", 'single')
    assert self.point_visible_test in ['single', 'combine'], f'invalide point visible test option {self.point_visible_test}'
def forward_train(
        self,
        pc: Pointclouds,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        return_intermediate_steps: bool = False,
        **kwargs
):
    """Run one training step and return the summed human + object denoising loss.

    The object-side inputs are read from ``kwargs``: 'pc_obj' here, and
    'cent_hum', 'cent_obj', 'radius_hum', 'radius_obj' through
    ``pack_norm_params``; ``kwargs`` is also handed to
    ``get_image_conditioning``.

    :param pc: human point cloud
    :param camera: camera used for the human conditioning
    :param image_rgb: image used for the human conditioning
    :param mask: mask used for the human conditioning
    :param return_intermediate_steps: if True, return the human-branch
        intermediates instead of the per-branch losses
    :return: ``(loss, (x0_h, xt_h, noise, noise_pred_h))`` when
        ``return_intermediate_steps`` is set, otherwise
        ``(loss, tensor([loss_h, loss_o]))``
    :raises ValueError: if either prediction does not have the shape of the noise
    """
    assert not self.self_conditioning

    # Clean targets for both branches (normalized and scaled by the helper).
    x0_h = self.point_cloud_to_tensor(pc, normalize=True, scale=True)
    x0_o = self.point_cloud_to_tensor(kwargs.get('pc_obj'), normalize=True, scale=True)
    batch_size, _, _ = x0_h.shape

    # One noise tensor, drawn with the human shape, is shared by both branches.
    # NOTE(review): sharing the noise between human and object looks deliberate
    # (both losses regress onto it) -- confirm before changing.
    noise = torch.randn_like(x0_h)
    if self.consistent_center:
        # zero-mean noise, as suggested by https://arxiv.org/pdf/2308.07837.pdf
        noise = noise - noise.mean(dim=1, keepdim=True)

    # One random diffusion step per batch element.
    timestep = torch.randint(
        0, self.scheduler.num_train_timesteps, (batch_size,),
        device=self.device, dtype=torch.long,
    )

    xt_h = self.scheduler.add_noise(x0_h, noise, timestep)
    xt_o = self.scheduler.add_noise(x0_o, noise, timestep)

    norm_parms = self.pack_norm_params(kwargs)  # (2, B, 4)
    x_t_input_h, x_t_input_o = self.get_image_conditioning(
        camera, image_rgb, kwargs, mask, norm_parms, timestep, xt_h, xt_o)

    noise_pred_h, noise_pred_o = self.point_cloud_model(
        x_t_input_h, x_t_input_o, timestep, norm_parms)

    # Both predictions must line up with the noise they are regressed against.
    if noise_pred_h.shape != noise.shape:
        raise ValueError(f'{noise_pred_h.shape=} and {noise.shape=}')
    if noise_pred_o.shape != noise.shape:
        raise ValueError(f'{noise_pred_o.shape=} and {noise.shape=}')

    loss_h = F.mse_loss(noise_pred_h, noise)
    loss_o = F.mse_loss(noise_pred_o, noise)
    loss = loss_h + loss_o

    if return_intermediate_steps:
        return loss, (x0_h, xt_h, noise, noise_pred_h)
    return loss, torch.tensor([loss_h, loss_o])
def get_image_conditioning(self, camera, image_rgb, kwargs, mask, norm_parms, timestep, xt_h, xt_o):
    """
    Compute the conditioned network input for the human and the object points.

    In 'single' mode each point set is conditioned on its own. In 'combine'
    mode the other point set is first moved into the current set's normalized
    frame and appended, so the conditioning is computed on human + object
    together; only the rows of the set's own points are kept afterwards.

    :param camera: camera for the human branch
    :param image_rgb: image for the human branch
    :param kwargs: dict holding the object-branch inputs under the keys
        'camera_obj', 'rgb_obj' and 'mask_obj'
    :param mask: mask for the human branch
    :param norm_parms: (2, B, 4) tensor, index 0 = human, index 1 = object;
        the first 3 channels are the center, the last one the radius
    :param timestep: diffusion step(s), forwarded unchanged as ``t``
    :param xt_h: noisy human points, points along dim 1
    :param xt_o: noisy object points, points along dim 1
    :return: (x_t_input_h, x_t_input_o)
    :raises NotImplementedError: if ``self.point_visible_test`` is neither
        'single' nor 'combine'
    """
    if self.point_visible_test == 'single':
        # Visibility test is done independently for human and object
        x_t_input_h = self.get_input_with_conditioning(xt_h, camera=camera,
                                                       image_rgb=image_rgb, mask=mask, t=timestep)
        x_t_input_o = self.get_input_with_conditioning(xt_o, camera=kwargs.get('camera_obj'),
                                                       image_rgb=kwargs.get('rgb_obj'),
                                                       mask=kwargs.get('mask_obj'), t=timestep)
    elif self.point_visible_test == 'combine':
        # Each branch keeps only its own rows, so the two counts are tracked
        # separately; slicing the object branch with the human count would
        # return the wrong rows whenever the two sets differ in size.
        num_h = xt_h.shape[1]
        num_o = xt_o.shape[1]

        cent_h = norm_parms[0, :, :3].unsqueeze(1)
        radius_h = norm_parms[0, :, 3:].unsqueeze(1)
        cent_o = norm_parms[1, :, :3].unsqueeze(1)
        radius_o = norm_parms[1, :, 3:].unsqueeze(1)

        # for human: transform object points first to H+O space, then to human space
        xt_o_in_ho = xt_o * 2 * radius_o + cent_o
        xt_o_in_hum = (xt_o_in_ho - cent_h) / (2 * radius_h)
        # compute features for all points, keep only the human rows
        x_t_input_h = self.get_input_with_conditioning(torch.cat([xt_h, xt_o_in_hum], 1), camera=camera,
                                                       image_rgb=image_rgb, mask=mask, t=timestep)[:, :num_h]

        # for object: transform human points to H+O space, then to object space
        xt_h_in_ho = xt_h * 2 * radius_h + cent_h
        xt_h_in_obj = (xt_h_in_ho - cent_o) / (2 * radius_o)
        # compute features for all points, keep only the object rows
        x_t_input_o = self.get_input_with_conditioning(torch.cat([xt_o, xt_h_in_obj], 1),
                                                       camera=kwargs.get('camera_obj'),
                                                       image_rgb=kwargs.get('rgb_obj'),
                                                       mask=kwargs.get('mask_obj'), t=timestep)[:, :num_o]
    else:
        raise NotImplementedError(
            f'Unknown point visibility test option {self.point_visible_test!r}')
    return x_t_input_h, x_t_input_o
def forward(self, batch, mode: str = 'train', **kwargs):
    """Unpack a batch dict and dispatch to training or one of the sampling modes.

    All tensors are moved to 'cuda'. The batch entries read with
    ``torch.stack`` are sequences of per-sample tensors.

    Modes:
      * 'train'          -- ``forward_train`` with the batch inputs (``kwargs``
                            is not forwarded)
      * 'sample'         -- ``forward_sample`` with the batch inputs
      * 'interm-gt'      -- same inputs, with ``sample_from_interm=True``
      * 'interm-pred'    -- cameras from 'T_hum_scaled' / 'T_obj_scaled',
                            point clouds from 'pred_hum' / 'pred_obj' and
                            centers / radii from the '*_pred' entries, with
                            ``sample_from_interm=True``
      * 'interm-pred-ts' -- predicted point clouds, centers and radii but the
                            batch cameras, with ``sample_from_interm=False``

    :param batch: dict of per-sample lists, see the keys read below
    :param mode: one of the modes listed above
    :param kwargs: forwarded to ``forward_sample``; 'num_points' is popped
        first and defaults to ``get_num_points(pc)``
    :raises NotImplementedError: for any other ``mode``
    """
    def stack_to_gpu(key):
        # stack a list of per-sample tensors and move the result to the GPU
        return torch.stack(batch[key], 0).to('cuda')

    def build_camera(trans_key, intr_key):
        # the rotation is shared, translation / intrinsics are branch specific
        return PerspectiveCameras(
            R=torch.stack(batch['R']),
            T=torch.stack(batch[trans_key]),
            K=torch.stack(batch[intr_key]),
            device='cuda',
            in_ndc=True,
        )

    def load_pclouds(key):
        return Pointclouds([points.to('cuda') for points in batch[key]])

    # ---- human branch inputs ----
    images = stack_to_gpu('images')
    masks = stack_to_gpu('masks')
    pc = self.get_input_pc(batch)
    camera = build_camera('T_hum', 'K_hum')
    grid_df = stack_to_gpu('grid_df') if 'grid_df' in batch else None
    num_points = kwargs.pop('num_points', get_num_points(pc))

    # ---- object branch inputs ----
    rgb_obj = stack_to_gpu('images_obj')
    masks_obj = stack_to_gpu('masks_obj')
    pc_obj = load_pclouds('pclouds_obj')
    camera_obj = build_camera('T_obj', 'K_obj')

    # ---- normalization parameters ----
    cent_hum = stack_to_gpu('cent_hum')
    cent_obj = stack_to_gpu('cent_obj')  # B, 3
    radius_hum = stack_to_gpu('radius_hum')  # B, 1
    radius_obj = stack_to_gpu('radius_obj')

    if mode == 'train':
        return self.forward_train(
            pc=pc,
            camera=camera,
            image_rgb=images,
            mask=masks,
            grid_df=grid_df,
            rgb_obj=rgb_obj,
            mask_obj=masks_obj,
            pc_obj=pc_obj,
            camera_obj=camera_obj,
            cent_hum=cent_hum,
            cent_obj=cent_obj,
            radius_hum=radius_hum,
            radius_obj=radius_obj,
        )

    if mode not in ('sample', 'interm-gt', 'interm-pred', 'interm-pred-ts'):
        raise NotImplementedError

    if mode == 'interm-pred':
        # use camera from predicted; the camera should be human/object specific!!!
        camera = build_camera('T_hum_scaled', 'K_hum')
        camera_obj = build_camera('T_obj_scaled', 'K_obj')

    if mode in ('interm-pred', 'interm-pred-ts'):
        # use pc, center and radius from predicted
        # ('interm-pred-ts' keeps the batch cameras built above)
        pc = load_pclouds('pred_hum')
        pc_obj = load_pclouds('pred_obj')
        cent_hum = stack_to_gpu('cent_hum_pred')
        cent_obj = stack_to_gpu('cent_obj_pred')  # B, 3
        radius_hum = stack_to_gpu('radius_hum_pred')  # B, 1
        radius_obj = stack_to_gpu('radius_obj_pred')

    sample_args = dict(
        num_points=num_points,
        camera=camera,
        image_rgb=images,
        mask=masks,
        gt_pc=pc,
        rgb_obj=rgb_obj,
        mask_obj=masks_obj,
        pc_obj=pc_obj,
        camera_obj=camera_obj,
        cent_hum=cent_hum,
        cent_obj=cent_obj,
        radius_hum=radius_hum,
        radius_obj=radius_obj,
    )
    if mode != 'sample':
        # plain 'sample' leaves the flag to the caller's kwargs;
        # 'interm-pred-ts' passes an explicit False
        sample_args['sample_from_interm'] = mode != 'interm-pred-ts'
    return self.forward_sample(**sample_args, **kwargs)
def forward_sample(
        self,
        num_points: int,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        # Optional overrides
        scheduler: Optional[str] = 'ddpm',
        # Inference parameters
        num_inference_steps: Optional[int] = 1000,
        eta: Optional[float] = 0.0,  # for DDIM
        # Whether to return all the intermediate steps in generation
        return_sample_every_n_steps: int = -1,
        # Whether to disable tqdm
        disable_tqdm: bool = False,
        gt_pc: Pointclouds = None,
        **kwargs
):
    """Run the reverse diffusion jointly for human and object points.

    Both point sets are denoised step by step, then each one is mapped back
    to the shared H+O frame with its own center / radius and the two are
    concatenated. A 4th channel marks the first ``num_points`` (human) rows
    with 1 and the object rows with 0.

    Entries read from ``kwargs`` in this method: 'sample_from_interm',
    'noise_step' (only when 'sample_from_interm' is truthy), 'pc_obj',
    'cent_hum', 'radius_hum', 'cent_obj', 'radius_obj'. ``kwargs`` is also
    passed on to ``get_image_conditioning``, ``pack_norm_params`` and
    ``reverse_step``.

    :param num_points: number of points per branch
    :param camera: camera for the human branch
    :param image_rgb: image for the human branch; also fixes batch size and device
    :param mask: mask for the human branch
    :param scheduler: key into ``self.schedulers_map``, or None for ``self.scheduler``
    :param num_inference_steps: forwarded to ``setup_reverse_process``
    :param eta: forwarded to ``setup_reverse_process`` (DDIM)
    :param return_sample_every_n_steps: if > 0, also return every n-th
        intermediate result plus the last one
    :param disable_tqdm: silence the progress bar
    :param gt_pc: human point cloud handed to ``initialize_x_T``
    :return: the final point cloud, or ``(final, intermediates)`` when
        ``return_sample_every_n_steps > 0``
    """
    assert not self.self_conditioning

    # Get scheduler from mapping, or use self.scheduler if None
    scheduler = self.scheduler if scheduler is None else self.schedulers_map[scheduler]

    # Get the size of the noise
    N = num_points
    B = 1 if image_rgb is None else image_rgb.shape[0]
    D = self.get_x_T_channel()
    device = self.device if image_rgb is None else image_rgb.device

    # sample from full steps or only a few steps
    sample_from_interm = kwargs.get('sample_from_interm', False)
    interm_steps = kwargs.get('noise_step') if sample_from_interm else -1
    xt_h = self.initialize_x_T(device, gt_pc, (B, N, D), interm_steps, scheduler)
    xt_o = self.initialize_x_T(device, kwargs.get('pc_obj', None), (B, N, D), interm_steps, scheduler)

    # the segmentation mask: 1 for the human rows, 0 for the object rows
    segm_mask = torch.zeros(B, 2 * N, 1, device=device)
    segm_mask[:, :N] = 1.0

    def merge_human_object(points_h, points_o):
        # move both sets back to the H+O frame, concatenate and append the mask
        merged = torch.cat([
            self.denormalize_pclouds(points_h, kwargs.get('cent_hum'), kwargs.get('radius_hum')),
            self.denormalize_pclouds(points_o, kwargs.get('cent_obj'), kwargs.get('radius_obj')),
        ], 1)
        return torch.cat([merged, segm_mask], -1)

    # Set timesteps
    extra_step_kwargs = self.setup_reverse_process(eta, num_inference_steps, scheduler)

    # The schedule that is actually iterated. When sampling from an
    # intermediate step it is shorter than scheduler.timesteps, so the
    # "last step" index must come from this sequence, not from the scheduler.
    timesteps = self.get_reverse_timesteps(scheduler, interm_steps)
    last_index = len(timesteps) - 1

    # Loop over timesteps
    all_outputs = []
    return_all_outputs = (return_sample_every_n_steps > 0)
    progress_bar = tqdm(timesteps, desc=f'Sampling ({xt_h.shape})', disable=disable_tqdm)

    norm_parms = self.pack_norm_params(kwargs)
    for i, t in enumerate(progress_bar):
        x_t_input_h, x_t_input_o = self.get_image_conditioning(camera, image_rgb,
                                                               kwargs, mask,
                                                               norm_parms,
                                                               t,
                                                               xt_h, xt_o)
        # One reverse step with conditioning
        xt_h, xt_o = self.reverse_step(extra_step_kwargs, scheduler, t, torch.stack([xt_h, xt_o], 0),
                                       torch.stack([x_t_input_h, x_t_input_o], 0), **kwargs)  # (B, N, D), D=3

        if return_all_outputs and (i % return_sample_every_n_steps == 0 or i == last_index):
            all_outputs.append(merge_human_object(xt_h, xt_o))

    # Convert output back into a point cloud; scaling was already undone by
    # denormalize_pclouds, hence denormalize=False / unscale=False here.
    output = self.tensor_to_point_cloud(merge_human_object(xt_h, xt_o), denormalize=False, unscale=False)
    if return_all_outputs:
        all_outputs = torch.stack(all_outputs, dim=1)  # (B, sample_steps, 2N, 4)
        all_outputs = [self.tensor_to_point_cloud(o, denormalize=False, unscale=False) for o in all_outputs]

    return (output, all_outputs) if return_all_outputs else output
def get_reverse_timesteps(self, scheduler, interm_steps:int): | |
""" | |
:param scheduler: | |
:param interm_steps: start from some intermediate steps | |
:return: | |
""" | |
if interm_steps > 0: | |
timesteps = torch.from_numpy(np.arange(0, interm_steps)[::-1].copy()).to(self.device) | |
else: | |
timesteps = scheduler.timesteps.to(self.device) | |
return timesteps | |
def pack_norm_params(self, kwargs: dict, scale=True):
    """Pack the human and object normalization parameters into one tensor.

    For each of human and object the center (multiplied by
    ``self.scale_factor`` when ``scale`` is set) is concatenated with the
    radius along the last dim; human comes first, object second.

    :param kwargs: dict with 'cent_hum', 'radius_hum', 'cent_obj', 'radius_obj'
    :param scale: multiply the centers by ``self.scale_factor`` if True
    :return: stacked tensor, (2, B, 4) for (B, 3) centers and (B, 1) radii
    """
    factor = self.scale_factor if scale else 1.0
    key_pairs = (('cent_hum', 'radius_hum'), ('cent_obj', 'radius_obj'))
    packed = [
        torch.cat([kwargs.get(cent_key) * factor, kwargs.get(radius_key)], -1)
        for cent_key, radius_key in key_pairs
    ]
    return torch.stack(packed, 0)  # (2, B, 4)
def reverse_step(self, extra_step_kwargs, scheduler, t, x_t, x_t_input, **kwargs):
    """Perform one reverse diffusion step for the human and the object points.

    :param extra_step_kwargs: extra keyword arguments for ``scheduler.step``
    :param scheduler: object whose ``step(...)`` result exposes ``prev_sample``
    :param t: current timestep; broadcast to one entry per batch element
    :param x_t: current samples, index 0 = human and index 1 = object along dim 0
    :param x_t_input: conditioned network inputs, stacked the same way
    :param kwargs: must hold the entries read by ``pack_norm_params``
    :return: (xt_h, xt_o), the samples at the previous timestep

    NOTE(review): the earlier note gave the layout as (2, B, D, N), but the
    mean below runs over dim=1 of each (B, ...) tensor and the caller labels
    the result (B, N, D) -- presumably (2, B, N, D); confirm.
    """
    norm_parms = self.pack_norm_params(kwargs)  # (2, B, 4)
    batch_size = x_t.shape[1]
    batch_t = t.reshape(1).expand(batch_size)

    noise_pred_h, noise_pred_o = self.point_cloud_model(
        x_t_input[0], x_t_input[1], batch_t, norm_parms)

    recenter = self.consistent_center
    if recenter:
        assert self.dm_pred_type != 'sample', 'incompatible dm predition type!'

    # Human first, object second: scheduler.step is called in this order.
    stepped = []
    for prediction, current in ((noise_pred_h, x_t[0]), (noise_pred_o, x_t[1])):
        if recenter:
            prediction = prediction - prediction.mean(dim=1, keepdim=True)
        previous = scheduler.step(prediction, t, current, **extra_step_kwargs).prev_sample
        if recenter:
            previous = previous - previous.mean(dim=1, keepdim=True)
        stepped.append(previous)

    xt_h, xt_o = stepped
    return xt_h, xt_o
def denormalize_pclouds(self, x: Tensor, cent, radius, unscale: bool = True):
    """
    Undo the scaling, then map the points back to the original H+O coordinate.

    Only the first three channels of ``x`` are used.

    :param x: points with at least 3 channels in the last dim
    :param cent: (B, 3) center added after scaling
    :param radius: (B, 1) radius; points are multiplied by twice this value
    :param unscale: divide by ``self.scale_factor`` first if True
    :return: points in the H+O coordinate
    """
    divisor = self.scale_factor if unscale else 1
    # denormalize: scale down
    unit_points = x[:, :, :3] / divisor
    # scale by the diameter, then translate by the center
    stretched = unit_points * 2 * radius[..., None]
    return stretched + cent[:, None]
def tensor_to_point_cloud(self, x: Tensor, /, denormalize: bool = False, unscale: bool = False):
    """
    Convert a point tensor to ``Pointclouds``, colouring points by their binary label.

    When ``self.predict_color`` is set, the channels after the first three are
    used as features. Otherwise the 4th channel is read as a binary label
    (> 0.5 means human) and turned into a fixed colour per point.

    :param x: (B, N, 4) when colours are not predicted
    :param denormalize: apply ``self.denormalize`` to the features
        (colour-prediction branch only)
    :param unscale: divide the coordinates by ``self.scale_factor`` if True
    :return: ``Pointclouds`` with the coordinates and per-point features
    """
    divisor = self.scale_factor if unscale else 1
    xyz = x[:, :, :3] / divisor

    if self.predict_color:
        extra = x[:, :, 3:]
        if denormalize:
            extra = self.denormalize(extra)
        return Pointclouds(points=xyz, features=extra)

    assert x.shape[2] == 4
    # add color to predicted binary labels
    human_flags = x[:, :, 3] > 0.5
    # human is light blue, object light green
    object_rgb = x.new_zeros(x.shape[1], 3) + torch.tensor([0.5, 1.0, 0]).to(x.device)
    human_rgb = torch.tensor([0.05, 1.0, 1.0]).to(x.device).to(object_rgb.dtype)
    features = [
        torch.where(flags.unsqueeze(-1), human_rgb, object_rgb)
        for flags in human_flags
    ]
    return Pointclouds(points=xyz, features=features)