# Source: HDM-interaction-recon/model/model_hoattn.py
# Author: xiexh20, commit 2fd6166 ("add hdm demo v1"), 20.9 kB
"""
model that use cross attention to predict human + object
"""
import inspect
import random
from typing import Optional
from torch import Tensor
import torch
import numpy as np
from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import CamerasBase
from .model_diff_data import ConditionalPCDiffusionBehave
from .pvcnn.pvcnn_ho import PVCNN2HumObj
import torch.nn.functional as F
from pytorch3d.renderer import PerspectiveCameras
from .model_utils import get_num_points
from tqdm import tqdm
class CrossAttenHODiffusionModel(ConditionalPCDiffusionBehave):
    """
    Diffusion model that denoises a human point cloud and an object point cloud together.

    A single cross-attention point cloud model (PVCNN2HumObj) predicts the noise for both branches.
    Each branch receives per-point image features computed by get_image_conditioning.

    Each point cloud lives in its own normalized frame, described by a center and a radius
    (see pack_norm_params). Sampled points are mapped back to the shared human+object (H+O) frame
    by denormalize_pclouds.
    """

    def init_pcloud_model(self, kwargs, point_cloud_model, point_cloud_model_embed_dim):
        """
        Build the cross-attention point cloud model and read the visibility-test option.

        :param kwargs: model config. Reads 'voxel_resolution_multiplier', 'attn_type', 'attn_weight'
            and 'point_visible_test'.
        :param point_cloud_model: name of the backbone; only 'pvcnn' is supported.
        :param point_cloud_model_embed_dim: embedding dimension passed to the backbone.
        :raises ValueError: if point_cloud_model is not 'pvcnn'.
        """
        if point_cloud_model != 'pvcnn':
            raise ValueError(f"Unknown point cloud model {point_cloud_model}!")
        self.point_cloud_model = PVCNN2HumObj(embed_dim=point_cloud_model_embed_dim,
                                              num_classes=self.out_channels,
                                              extra_feature_channels=(self.in_channels - 3),
                                              voxel_resolution_multiplier=kwargs.get('voxel_resolution_multiplier', 1),
                                              attn_type=kwargs.get('attn_type', 'simple-cross'),
                                              attn_weight=kwargs.get("attn_weight", 1.0)
                                              )
        # Selects which points the visibility test uses (see get_image_conditioning):
        #   'single'  - only the points of the branch itself;
        #   'combine' - human + object points together.
        self.point_visible_test = kwargs.get("point_visible_test", 'single')
        assert self.point_visible_test in ['single', 'combine'], f'invalide point visible test option {self.point_visible_test}'

    def forward_train(
        self,
        pc: Pointclouds,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        return_intermediate_steps: bool = False,
        **kwargs
    ):
        """
        Compute the noise-prediction MSE loss for one batch of human + object point clouds.

        Human inputs come from the explicit arguments. Object inputs are read from kwargs:
        'pc_obj', 'camera_obj', 'rgb_obj', 'mask_obj'. The normalization parameters
        'cent_hum', 'cent_obj', 'radius_hum' and 'radius_obj' are also read from kwargs.

        :param pc: human point cloud.
        :param camera: camera used to project the human points.
        :param image_rgb: conditioning image for the human branch.
        :param mask: conditioning mask for the human branch.
        :param return_intermediate_steps: if True, return (x0_h, xt_h, noise, noise_pred_h) as the
            second element instead of the per-branch losses.
        :return: (loss, extra). loss = loss_h + loss_o. extra is either the intermediate tensors or a
            detached CPU tensor [loss_h, loss_o].
        :raises ValueError: if a predicted noise tensor does not match the shape of the sampled noise.
        """
        assert not self.self_conditioning
        # Normalize and convert to tensors (point colors are not packed)
        x0_h = self.point_cloud_to_tensor(pc, normalize=True, scale=True)
        x0_o = self.point_cloud_to_tensor(kwargs.get('pc_obj'), normalize=True, scale=True)
        B = x0_h.shape[0]

        # Sample random noise
        noise = torch.randn_like(x0_h)
        if self.consistent_center:
            # modification suggested by https://arxiv.org/pdf/2308.07837.pdf
            noise = noise - torch.mean(noise, dim=1, keepdim=True)

        # Sample one random timestep per point cloud
        timestep = torch.randint(0, self.scheduler.num_train_timesteps, (B,),
                                 device=self.device, dtype=torch.long)

        # Add noise to the points.
        # NOTE(review): the same noise tensor is used for the human and the object branch, and it is
        # the regression target of both losses. forward_sample, by contrast, initializes each branch
        # with its own initialize_x_T call. Confirm this correlation is intended.
        xt_h = self.scheduler.add_noise(x0_h, noise, timestep)
        xt_o = self.scheduler.add_noise(x0_o, noise, timestep)
        norm_parms = self.pack_norm_params(kwargs)  # (2, B, 4)

        # Per-point image conditioning
        x_t_input_h, x_t_input_o = self.get_image_conditioning(camera, image_rgb, kwargs, mask, norm_parms,
                                                               timestep, xt_h, xt_o)

        # Diffusion prediction
        noise_pred_h, noise_pred_o = self.point_cloud_model(x_t_input_h, x_t_input_o, timestep, norm_parms)

        # Check
        if not noise_pred_h.shape == noise.shape:
            raise ValueError(f'{noise_pred_h.shape=} and {noise.shape=}')
        if not noise_pred_o.shape == noise.shape:
            raise ValueError(f'{noise_pred_o.shape=} and {noise.shape=}')

        # Loss
        loss_h = F.mse_loss(noise_pred_h, noise)
        loss_o = F.mse_loss(noise_pred_o, noise)
        loss = loss_h + loss_o

        if return_intermediate_steps:
            return loss, (x0_h, xt_h, noise, noise_pred_h)
        # Per-branch losses for logging. They are detached and copied to the CPU explicitly, rather
        # than letting torch.tensor() convert each grad-tracking scalar on its own.
        return loss, torch.stack([loss_h.detach(), loss_o.detach()]).cpu()

    @staticmethod
    def _change_frame(points: Tensor, src_parms: Tensor, dst_parms: Tensor) -> Tensor:
        """
        Re-express points normalized in one frame in another normalized frame.

        A frame is given by (B, 4) parameters [center(3), radius(1)], as packed by pack_norm_params.
        Points are first mapped to H+O space (x * 2 * radius + center), then normalized with the
        destination frame ((p - center) / (2 * radius)).
        """
        in_ho = points * 2 * src_parms[:, 3:].unsqueeze(1) + src_parms[:, :3].unsqueeze(1)
        return (in_ho - dst_parms[:, :3].unsqueeze(1)) / (2 * dst_parms[:, 3:].unsqueeze(1))

    def get_image_conditioning(self, camera, image_rgb, kwargs, mask, norm_parms, timestep, xt_h, xt_o):
        """
        Compute the conditioned network inputs (points + per-point image features) for both branches.

        :param camera: human camera.
        :param image_rgb: human conditioning image.
        :param kwargs: provides the object inputs 'camera_obj', 'rgb_obj' and 'mask_obj'.
        :param mask: human conditioning mask.
        :param norm_parms: (2, B, 4) normalization parameters from pack_norm_params.
        :param timestep: diffusion timestep(s), forwarded to get_input_with_conditioning.
        :param xt_h: noisy human points (B, N, 3) in the normalized human frame.
        :param xt_o: noisy object points in the normalized object frame.
        :return: (x_t_input_h, x_t_input_o). In 'combine' mode only the first N rows (the branch's
            own points) of each combined input are returned.
        :raises NotImplementedError: for an unknown self.point_visible_test.
        """
        if self.point_visible_test == 'single':
            # Visibility test is done independently for human and object
            x_t_input_h = self.get_input_with_conditioning(xt_h, camera=camera,
                                                           image_rgb=image_rgb, mask=mask, t=timestep)
            x_t_input_o = self.get_input_with_conditioning(xt_o, camera=kwargs.get('camera_obj'),
                                                           image_rgb=kwargs.get('rgb_obj'),
                                                           mask=kwargs.get('mask_obj'), t=timestep)
        elif self.point_visible_test == 'combine':
            # Combine human + object points for the visibility test; keep only the rows of each branch
            N = xt_h.shape[1]
            hum_parms, obj_parms = norm_parms[0], norm_parms[1]
            # Human branch: append the object points expressed in the human frame
            xt_o_in_hum = self._change_frame(xt_o, obj_parms, hum_parms)
            x_t_input_h = self.get_input_with_conditioning(torch.cat([xt_h, xt_o_in_hum], 1), camera=camera,
                                                           image_rgb=image_rgb, mask=mask, t=timestep)[:, :N]
            # Object branch: append the human points expressed in the object frame
            xt_h_in_obj = self._change_frame(xt_h, hum_parms, obj_parms)
            x_t_input_o = self.get_input_with_conditioning(torch.cat([xt_o, xt_h_in_obj], 1),
                                                           camera=kwargs.get('camera_obj'),
                                                           image_rgb=kwargs.get('rgb_obj'),
                                                           mask=kwargs.get('mask_obj'), t=timestep)[:, :N]
        else:
            raise NotImplementedError
        return x_t_input_h, x_t_input_o

    def forward(self, batch, mode: str = 'train', **kwargs):
        """
        Move a collated batch to CUDA and dispatch to forward_train or forward_sample.

        :param batch: dict of per-sample lists. Always reads 'images', 'masks', 'R', 'T_hum', 'K_hum',
            'images_obj', 'masks_obj', 'pclouds_obj', 'T_obj', 'K_obj', 'cent_hum', 'cent_obj',
            'radius_hum' and 'radius_obj'. Reads 'grid_df' if present. The 'interm-pred*' modes also
            read the predicted counterparts ('pred_hum', 'pred_obj', '*_pred', and 'T_*_scaled' for
            'interm-pred').
        :param mode: 'train', 'sample', 'interm-gt', 'interm-pred' or 'interm-pred-ts'.
        :param kwargs: 'num_points' is popped here; the rest is forwarded to forward_sample.
        :return: the result of forward_train ('train') or forward_sample (other modes).
        :raises NotImplementedError: for an unknown mode.
        """
        def stack_cuda(key):
            return torch.stack(batch[key], 0).to('cuda')

        def make_camera(t_key, k_key):
            # Human and object cameras share R but have their own translation and intrinsics
            return PerspectiveCameras(
                R=torch.stack(batch['R']),
                T=torch.stack(batch[t_key]),
                K=torch.stack(batch[k_key]),
                device='cuda',
                in_ndc=True
            )

        def load_pclouds(key):
            return Pointclouds([x.to('cuda') for x in batch[key]])

        images = stack_cuda('images')
        masks = stack_cuda('masks')
        pc = self.get_input_pc(batch)
        camera = make_camera('T_hum', 'K_hum')
        grid_df = stack_cuda('grid_df') if 'grid_df' in batch else None
        num_points = kwargs.pop('num_points', get_num_points(pc))
        rgb_obj = stack_cuda('images_obj')
        masks_obj = stack_cuda('masks_obj')
        pc_obj = load_pclouds('pclouds_obj')
        camera_obj = make_camera('T_obj', 'K_obj')
        # normalization parameters
        cent_hum = stack_cuda('cent_hum')  # B, 3
        cent_obj = stack_cuda('cent_obj')  # B, 3
        radius_hum = stack_cuda('radius_hum')  # B, 1
        radius_obj = stack_cuda('radius_obj')  # B, 1

        if mode == 'train':
            return self.forward_train(
                pc=pc,
                camera=camera,
                image_rgb=images,
                mask=masks,
                grid_df=grid_df,
                rgb_obj=rgb_obj,
                mask_obj=masks_obj,
                pc_obj=pc_obj,
                camera_obj=camera_obj,
                cent_hum=cent_hum,
                cent_obj=cent_obj,
                radius_hum=radius_hum,
                radius_obj=radius_obj,
            )
        if mode not in ('sample', 'interm-gt', 'interm-pred', 'interm-pred-ts'):
            raise NotImplementedError

        if mode == 'interm-pred':
            # Cameras use the predicted (scaled) translations; intrinsics stay human/object specific
            camera = make_camera('T_hum_scaled', 'K_hum')
            camera_obj = make_camera('T_obj_scaled', 'K_obj')
        if mode in ('interm-pred', 'interm-pred-ts'):
            # Use the predicted point clouds, centers and radii.
            # 'interm-pred-ts' keeps the GT cameras.
            pc = load_pclouds('pred_hum')
            pc_obj = load_pclouds('pred_obj')
            cent_hum = stack_cuda('cent_hum_pred')
            cent_obj = stack_cuda('cent_obj_pred')  # B, 3
            radius_hum = stack_cuda('radius_hum_pred')  # B, 1
            radius_obj = stack_cuda('radius_obj_pred')

        # 'sample' (projection with GT centers) leaves sample_from_interm to forward_sample's default.
        # The other modes set it explicitly: True for 'interm-gt' / 'interm-pred', and False for
        # 'interm-pred-ts', which samples from Gaussian noise using only the predicted translation/scale.
        interm_kwargs = {} if mode == 'sample' else {'sample_from_interm': mode != 'interm-pred-ts'}
        return self.forward_sample(
            num_points=num_points,
            camera=camera,
            image_rgb=images,
            mask=masks,
            gt_pc=pc,
            rgb_obj=rgb_obj,
            mask_obj=masks_obj,
            pc_obj=pc_obj,
            camera_obj=camera_obj,
            cent_hum=cent_hum,
            cent_obj=cent_obj,
            radius_hum=radius_hum,
            radius_obj=radius_obj,
            **interm_kwargs,
            **kwargs)

    def _denormalize_ho(self, xt_h: Tensor, xt_o: Tensor, kwargs: dict) -> Tensor:
        """Map human and object points back to the H+O frame and concatenate them (human first) along N."""
        return torch.cat([self.denormalize_pclouds(xt_h, kwargs.get('cent_hum'), kwargs.get('radius_hum')),
                          self.denormalize_pclouds(xt_o, kwargs.get('cent_obj'), kwargs.get('radius_obj'))], 1)

    def forward_sample(
        self,
        num_points: int,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        # Optional overrides
        scheduler: Optional[str] = 'ddpm',
        # Inference parameters
        num_inference_steps: Optional[int] = 1000,
        eta: Optional[float] = 0.0,  # for DDIM
        # Whether to return all the intermediate steps in generation
        return_sample_every_n_steps: int = -1,
        # Whether to disable tqdm
        disable_tqdm: bool = False,
        gt_pc: Pointclouds = None,
        **kwargs
    ):
        """
        Run the reverse diffusion process for the human and object branches.

        At the end, translation and scale are applied to put both branches back into the H+O frame.

        Object inputs and normalization parameters are read from kwargs, as in forward_train. The
        remaining kwargs used here are:
        - 'sample_from_interm': if True, start at step kwargs['noise_step'];
        - 'pc_obj': object point cloud forwarded to initialize_x_T.

        :param num_points: number of points per branch (N).
        :param scheduler: key into self.schedulers_map, or None to use self.scheduler.
        :param return_sample_every_n_steps: if > 0, also collect the state every n steps and always
            at the last step run.
        :param gt_pc: human point cloud forwarded to initialize_x_T.
        :return: a Pointclouds with 2N points per sample (the first N are human, labelled via the
            segmentation channel). If intermediate steps are requested, returns
            (output, list of Pointclouds per sample).
        """
        assert not self.self_conditioning

        # Get scheduler from mapping, or use self.scheduler if None
        scheduler = self.scheduler if scheduler is None else self.schedulers_map[scheduler]

        # Size of the noise
        N = num_points
        B = 1 if image_rgb is None else image_rgb.shape[0]
        D = self.get_x_T_channel()
        device = self.device if image_rgb is None else image_rgb.device

        # Sample over the full schedule (interm_steps=-1) or only the first `noise_step` timesteps
        sample_from_interm = kwargs.get('sample_from_interm', False)
        interm_steps = kwargs.get('noise_step') if sample_from_interm else -1
        xt_h = self.initialize_x_T(device, gt_pc, (B, N, D), interm_steps, scheduler)
        xt_o = self.initialize_x_T(device, kwargs.get('pc_obj', None), (B, N, D), interm_steps, scheduler)

        # Segmentation channel: 1 for the first N (human) points, 0 for the N object points
        segm_mask = torch.zeros(B, 2 * N, 1, device=device)
        segm_mask[:, :N] = 1.0

        # Set timesteps
        extra_step_kwargs = self.setup_reverse_process(eta, num_inference_steps, scheduler)
        timesteps = self.get_reverse_timesteps(scheduler, interm_steps)
        # Index of the last reverse step actually run. This differs from len(scheduler.timesteps) - 1
        # when sampling starts from an intermediate step.
        last_step = len(timesteps) - 1

        all_outputs = []
        return_all_outputs = (return_sample_every_n_steps > 0)
        progress_bar = tqdm(timesteps, desc=f'Sampling ({xt_h.shape})', disable=disable_tqdm)
        norm_parms = self.pack_norm_params(kwargs)
        for i, t in enumerate(progress_bar):
            x_t_input_h, x_t_input_o = self.get_image_conditioning(camera, image_rgb, kwargs, mask,
                                                                   norm_parms, t, xt_h, xt_o)
            # One reverse step with conditioning
            xt_h, xt_o = self.reverse_step(extra_step_kwargs, scheduler, t, torch.stack([xt_h, xt_o], 0),
                                           torch.stack([x_t_input_h, x_t_input_o], 0), **kwargs)  # (B, N, D)
            if return_all_outputs and (i % return_sample_every_n_steps == 0 or i == last_step):
                all_outputs.append(torch.cat([self._denormalize_ho(xt_h, xt_o, kwargs), segm_mask], -1))

        # Convert the output back into a point cloud in the H+O frame (normalization already undone)
        x_t = torch.cat([self._denormalize_ho(xt_h, xt_o, kwargs), segm_mask], -1)
        output = self.tensor_to_point_cloud(x_t, denormalize=False, unscale=False)
        if return_all_outputs:
            all_outputs = torch.stack(all_outputs, dim=1)  # (B, sample_steps, 2N, 4)
            all_outputs = [self.tensor_to_point_cloud(o, denormalize=False, unscale=False) for o in all_outputs]
        return (output, all_outputs) if return_all_outputs else output

    def get_reverse_timesteps(self, scheduler, interm_steps: int):
        """
        Return the timesteps visited by the reverse process, in visiting order.

        :param scheduler: its `timesteps` are used when interm_steps <= 0.
        :param interm_steps: if > 0, start from this intermediate step and visit every integer
            timestep interm_steps-1, ..., 0.
        :return: 1D tensor of timesteps on self.device.
        """
        if interm_steps > 0:
            return torch.arange(interm_steps - 1, -1, -1, dtype=torch.long, device=self.device)
        return scheduler.timesteps.to(self.device)

    def pack_norm_params(self, kwargs: dict, scale=True):
        """
        Stack the human and object normalization parameters.

        :param kwargs: must contain 'cent_hum' and 'cent_obj' (B, 3), and 'radius_hum' and
            'radius_obj' (B, 1).
        :param scale: if True, centers are multiplied by self.scale_factor. Radii are never scaled.
        :return: (2, B, 4) tensor. Index 0 is the human and 1 the object; each row is
            [center(3), radius(1)].
        """
        scale_factor = self.scale_factor if scale else 1.0
        hum = torch.cat([kwargs.get('cent_hum') * scale_factor, kwargs.get('radius_hum')], -1)
        obj = torch.cat([kwargs.get('cent_obj') * scale_factor, kwargs.get('radius_obj')], -1)
        return torch.stack([hum, obj], 0)

    def reverse_step(self, extra_step_kwargs, scheduler, t, x_t, x_t_input, **kwargs):
        """
        Run one denoising step for both branches.

        :param extra_step_kwargs: extra keyword arguments for scheduler.step.
        :param scheduler: the diffusion scheduler.
        :param t: current timestep, a single-element tensor broadcast to the batch.
        :param x_t: (2, B, N, D) stacked noisy human/object points.
        :param x_t_input: (2, B, N, C) stacked conditioned network inputs.
        :param kwargs: must provide the normalization parameters read by pack_norm_params.
        :return: (xt_h, xt_o), the points after the step, each (B, N, D).
        """
        norm_parms = self.pack_norm_params(kwargs)  # (2, B, 4)
        B = x_t.shape[1]
        noise_pred_h, noise_pred_o = self.point_cloud_model(x_t_input[0], x_t_input[1], t.reshape(1).expand(B),
                                                            norm_parms)
        if self.consistent_center:
            assert self.dm_pred_type != 'sample', 'incompatible dm predition type!'
            # Keep the predicted noise zero-mean over the points
            noise_pred_h = noise_pred_h - torch.mean(noise_pred_h, dim=1, keepdim=True)
            noise_pred_o = noise_pred_o - torch.mean(noise_pred_o, dim=1, keepdim=True)

        xt_h = scheduler.step(noise_pred_h, t, x_t[0], **extra_step_kwargs).prev_sample
        xt_o = scheduler.step(noise_pred_o, t, x_t[1], **extra_step_kwargs).prev_sample

        if self.consistent_center:
            # Re-center each point cloud after the step
            xt_h = xt_h - torch.mean(xt_h, dim=1, keepdim=True)
            xt_o = xt_o - torch.mean(xt_o, dim=1, keepdim=True)
        return xt_h, xt_o

    def denormalize_pclouds(self, x: Tensor, cent, radius, unscale: bool = True):
        """
        Map normalized points back to the shared H+O frame.

        :param x: (B, N, C) tensor; only the first 3 channels (xyz) are used.
        :param cent: (B, 3) center of the frame the points are normalized in.
        :param radius: (B, 1) radius of that frame.
        :param unscale: first divide by self.scale_factor, undoing the scaling applied at normalization.
        :return: (B, N, 3) points computed as xyz * 2 * radius + cent.
        """
        points = x[:, :, :3] / (self.scale_factor if unscale else 1)
        return points * 2 * radius.unsqueeze(-1) + cent.unsqueeze(1)

    def tensor_to_point_cloud(self, x: Tensor, /, denormalize: bool = False, unscale: bool = False):
        """
        Convert a (B, N, C) tensor into a Pointclouds object, taking the binary human label into account.

        If self.predict_color, channels 3: are used as colors (optionally denormalized). Otherwise x
        must have exactly 4 channels, the last being a human/object label. Points with label > 0.5 are
        colored (0.05, 1.0, 1.0) (human, light blue); all others get (0.5, 1.0, 0.0) (object, light green).

        :param x: (B, N, C) tensor of xyz + extra channels.
        :param denormalize: denormalize the colors (only used when self.predict_color).
        :param unscale: divide the xyz coordinates by self.scale_factor.
        :return: Pointclouds with per-point colors as features.
        """
        points = x[:, :, :3] / (self.scale_factor if unscale else 1)
        if self.predict_color:
            colors = self.denormalize(x[:, :, 3:]) if denormalize else x[:, :, 3:]
            return Pointclouds(points=points, features=colors)

        assert x.shape[2] == 4
        is_hum = x[:, :, 3] > 0.5
        # Build the palette once instead of once per sample
        obj_color = torch.tensor([0.5, 1.0, 0]).to(x.device)
        hum_color = torch.tensor([0.05, 1.0, 1.0]).to(x.device)
        features = []
        for hum_mask in is_hum:
            color = torch.zeros_like(x[0, :, :3]) + obj_color
            color[hum_mask, :] = hum_color
            features.append(color)
        return Pointclouds(points=points, features=features)