"""
Demo for template-free reconstruction
python demo.py model=ho-attn run.image_path=/BS/xxie-2/work/HDM/outputs/000000017450/k1.color.jpg run.job=sample model.predict_binary=True dataset.std_coverage=3.0
"""
import sys, os
import os.path as osp
from glob import glob

import hydra
import torch
import numpy as np
import imageio
from tqdm import tqdm
from torch.utils.data import DataLoader
from pytorch3d.datasets import collate_batched_meshes
from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform
from pytorch3d.io import IO
import torchvision.transforms.functional as TVF
from huggingface_hub import hf_hub_download

sys.path.append(os.getcwd())  # make the local packages importable when run from the repo root
import training_utils
from configs.structured import ProjectConfig
from dataset.demo_dataset import DemoDataset
from model import CrossAttenHODiffusionModel, ConditionalPCDiffusionSeparateSegm
from render.pyt3d_wrapper import PcloudRenderer


class DemoRunner:
    def __init__(self, cfg: ProjectConfig):
        # Stage 1 jointly diffuses human + object points and predicts a binary
        # human/object label for each point
        cfg.model.model_name, cfg.model.predict_binary = 'pc2-diff-ho-sepsegm', True
        model_stage1 = ConditionalPCDiffusionSeparateSegm(**cfg.model)
        # Stage 2 refines the two clouds with cross attention; it does not predict segmentation
        cfg.model.model_name, cfg.model.predict_binary = 'diff-ho-attn', False
        model_stage2 = CrossAttenHODiffusionModel(**cfg.model)

        # Load checkpoints from the Hugging Face hub
        ckpt_file1 = hf_hub_download("xiexh20/HDM-models", f'{cfg.run.stage1_name}.pth')
        self.load_checkpoint(ckpt_file1, model_stage1)
        ckpt_file2 = hf_hub_download("xiexh20/HDM-models", f'{cfg.run.stage2_name}.pth')
        self.load_checkpoint(ckpt_file2, model_stage2)

        self.model_stage1, self.model_stage2 = model_stage1, model_stage2
        self.model_stage1.eval()
        self.model_stage2.eval()
        self.model_stage1.to('cuda')
        self.model_stage2.to('cuda')

        self.cfg = cfg
        self.io_pc = IO()

        # For visualization
        self.renderer = PcloudRenderer(image_size=cfg.dataset.image_size, radius=0.0075)
        self.rend_size = cfg.dataset.image_size
        self.device = 'cuda'

    def load_checkpoint(self, ckpt_file, model, device='cpu'):
        """Load a checkpoint file into the given model, tolerating DataParallel prefixes."""
        checkpoint = torch.load(ckpt_file, map_location=device)
        state_dict = checkpoint['model']
        if any(k.startswith('module.') for k in state_dict.keys()):
            # The checkpoint was saved from a (Distributed)DataParallel-wrapped model
            state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
            print('Removed "module." from checkpoint state dict')
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        print(f'Loaded model checkpoint from {ckpt_file}')
        if len(missing_keys):
            print(f' - Missing_keys: {missing_keys}')
        if len(unexpected_keys):
            print(f' - Unexpected_keys: {unexpected_keys}')

    def reload_checkpoint(self, cat_name):
        """Load checkpoints of the models fine-tuned on a specific category."""
        ckpt_file1 = hf_hub_download("xiexh20/HDM-models", f'{self.cfg.run.stage1_name}-{cat_name}.pth')
        self.load_checkpoint(ckpt_file1, self.model_stage1, device=self.device)
        ckpt_file2 = hf_hub_download("xiexh20/HDM-models", f'{self.cfg.run.stage2_name}-{cat_name}.pth')
        self.load_checkpoint(ckpt_file2, self.model_stage2, device=self.device)
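
    # Example (hypothetical category name; the valid names are determined by the
    # fine-tuned checkpoints released in the xiexh20/HDM-models hub repo):
    #   runner.reload_checkpoint('backpack')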

    @torch.no_grad()
    def run(self):
        """Run the demo on the given images and save the results."""
        # Set random seed
        training_utils.set_seed(self.cfg.run.seed)

        outdir = osp.join(self.cfg.run.code_dir_abs, 'outputs/demo')
        os.makedirs(outdir, exist_ok=True)
        cfg = self.cfg

        # Init data
        image_files = sorted(glob(cfg.run.image_path))
        data = DemoDataset(image_files,
                           (cfg.dataset.image_size, cfg.dataset.image_size),
                           cfg.dataset.std_coverage)
        dataloader = DataLoader(data, batch_size=cfg.dataloader.batch_size,
                                collate_fn=collate_batched_meshes,
                                num_workers=1, shuffle=False)
        progress_bar = tqdm(dataloader)

        for batch_idx, batch in enumerate(progress_bar):
            progress_bar.set_description(f'Processing batch {batch_idx:4d} / {len(dataloader):4d}')
            out_stage1, out_stage2 = self.forward_batch(batch, cfg)

            bs = len(out_stage1)
            camera_full = PerspectiveCameras(
                R=torch.stack(batch['R']),
                T=torch.stack(batch['T']),
                K=torch.stack(batch['K']),
                device='cuda',
                in_ndc=True)

            # Save output
            for i in range(bs):
                image_path = str(batch['image_path'][i])
                folder, fname = osp.basename(osp.dirname(image_path)), osp.splitext(osp.basename(image_path))[0]
                out_i = osp.join(outdir, folder)
                os.makedirs(out_i, exist_ok=True)
                self.io_pc.save_pointcloud(data=out_stage1[i],
                                           path=osp.join(out_i, f'{fname}_stage1.ply'))
                self.io_pc.save_pointcloud(data=out_stage2[i],
                                           path=osp.join(out_i, f'{fname}_stage2.ply'))
                TVF.to_pil_image(batch['images'][i]).save(osp.join(out_i, f'{fname}_input.png'))

                # Save metadata as well
                metadata = dict(index=i,
                                camera=camera_full[i],
                                image_size_hw=batch['image_size_hw'][i],
                                image_path=batch['image_path'][i])
                torch.save(metadata, osp.join(out_i, f'{fname}_meta.pth'))

                # Visualize: input image next to the stage 1 and stage 2 predictions
                pc_comb = Pointclouds([out_stage1[i].points_packed(), out_stage2[i].points_packed()],
                                      features=[out_stage1[i].features_packed(), out_stage2[i].features_packed()])
                video_file = osp.join(out_i, f'{fname}_360view.mp4')
                video_writer = imageio.get_writer(video_file, format='FFMPEG', mode='I', fps=1)

                # First render the front view
                rend_stage1, _ = self.renderer.render(out_stage1[i], camera_full[i], mode='mask')
                rend_stage2, _ = self.renderer.render(out_stage2[i], camera_full[i], mode='mask')
                comb = np.concatenate([batch['images'][i].permute(1, 2, 0).cpu().numpy(), rend_stage1, rend_stage2], 1)
                video_writer.append_data((comb * 255).astype(np.uint8))

                # Then rotate around the predictions in 30-degree steps
                for azim in range(180, 180 + 360, 30):
                    # Distance 1.7, elevation 0; the up vector is flipped to match
                    # the OpenCV-style y-down camera convention used by the data
                    R, T = look_at_view_transform(1.7, 0, azim, up=((0, -1, 0),))
                    # Batch of 2 cameras: pc_comb holds the stage 1 and stage 2 clouds
                    side_camera = PerspectiveCameras(image_size=((self.rend_size, self.rend_size),),
                                                     device=self.device,
                                                     R=R.repeat(2, 1, 1), T=T.repeat(2, 1),
                                                     focal_length=self.rend_size * 1.5,
                                                     principal_point=((self.rend_size / 2., self.rend_size / 2.),),
                                                     in_ndc=False)
                    rend, _ = self.renderer.render(pc_comb, side_camera, mode='mask')
                    imgs = [batch['images'][i].permute(1, 2, 0).cpu().numpy(), rend[0], rend[1]]
                    video_writer.append_data((np.concatenate(imgs, 1) * 255).astype(np.uint8))
                video_writer.close()  # finalize the mp4 file
                print(f"Visualization saved to {out_i}")

    @torch.no_grad()
    def forward_batch(self, batch, cfg):
        """
        Run the two-stage model on one batch.
        :param batch: a batch collated from DemoDataset
        :param cfg: the project config
        :return: predicted point clouds of stage 1 and stage 2
        """
        camera_full = PerspectiveCameras(
            R=torch.stack(batch['R']),
            T=torch.stack(batch['T']),
            K=torch.stack(batch['K']),
            device='cuda',
            in_ndc=True)

        # Stage 1: sample a combined human + object point cloud
        out_stage1 = self.model_stage1.forward_sample(num_points=cfg.dataset.max_points,
                                                      camera=camera_full,
                                                      image_rgb=torch.stack(batch['images']).to('cuda'),
                                                      mask=torch.stack(batch['masks']).to('cuda'),
                                                      scheduler=cfg.run.diffusion_scheduler,
                                                      num_inference_steps=cfg.run.num_inference_steps,
                                                      eta=cfg.model.ddim_eta)

        # Segment and normalize human/object
        bs = len(out_stage1)
        pred_hum, pred_obj = [], []  # predicted human/object points
        cent_hum_pred, cent_obj_pred = [], []
        radius_hum_pred, radius_obj_pred = [], []
        T_hum, T_obj = [], []
        num_samples = int(cfg.dataset.max_points / 2)
        for i in range(bs):
            pc: Pointclouds = out_stage1[i]
            # Stage 1 encodes the segmentation in the point colors: human is light
            # blue [0.1, 1.0, 1.0], object light green [0.5, 1.0, 0], so
            # thresholding the blue channel separates the two
            vc = pc.features_packed().cpu()  # (P, 3)
            points = pc.points_packed().cpu()  # (P, 3)
            mask_hum = vc[:, 2] > 0.5
            pc_hum, pc_obj = points[mask_hum], points[~mask_hum]

            # Up/down-sample each cloud to a fixed number of points
            pc_obj = self.upsample_predicted_pc(num_samples, pc_obj)
            pc_hum = self.upsample_predicted_pc(num_samples, pc_hum)

            # Normalize: center each cloud and divide by twice its maximum radius,
            # so all points fall inside a sphere of radius 0.5 around the origin
            cent_hum, cent_obj = torch.mean(pc_hum, 0, keepdim=True), torch.mean(pc_obj, 0, keepdim=True)
            scale_hum = torch.sqrt(torch.sum((pc_hum - cent_hum) ** 2, -1).max())
            scale_obj = torch.sqrt(torch.sum((pc_obj - cent_obj) ** 2, -1).max())
            pc_hum = (pc_hum - cent_hum) / (2 * scale_hum)
            pc_obj = (pc_obj - cent_obj) / (2 * scale_obj)

            # Update the camera translations accordingly for the separate human/object branches
            T_hum_scaled = (batch['T_ho'][i] + cent_hum.squeeze(0)) / (2 * scale_hum)
            T_obj_scaled = (batch['T_ho'][i] + cent_obj.squeeze(0)) / (2 * scale_obj)

            pred_hum.append(pc_hum)
            pred_obj.append(pc_obj)
            cent_hum_pred.append(cent_hum.squeeze(0))
            cent_obj_pred.append(cent_obj.squeeze(0))
            # Apply the OpenCV to PyTorch3D transform: flip x and y
            T_hum.append(T_hum_scaled * torch.tensor([-1, -1, 1]))
            T_obj.append(T_obj_scaled * torch.tensor([-1, -1, 1]))
            radius_hum_pred.append(scale_hum)
            radius_obj_pred.append(scale_obj)

        # Pack data into a new batch dict; the cameras must be human/object specific
        camera_hum = PerspectiveCameras(
            R=torch.stack(batch['R']),
            T=torch.stack(T_hum),
            K=torch.stack(batch['K_hum']),
            device='cuda',
            in_ndc=True)
        camera_obj = PerspectiveCameras(
            R=torch.stack(batch['R']),
            T=torch.stack(T_obj),
            K=torch.stack(batch['K_obj']),
            device='cuda',
            in_ndc=True)

        # Use the point clouds, centers and radii predicted by stage 1
        pc_hum = Pointclouds([x.to('cuda') for x in pred_hum])
        pc_obj = Pointclouds([x.to('cuda') for x in pred_obj])
        cent_hum = torch.stack(cent_hum_pred, 0).to('cuda')  # (B, 3)
        cent_obj = torch.stack(cent_obj_pred, 0).to('cuda')
        radius_hum = torch.stack(radius_hum_pred, 0).to('cuda')  # (B,)
        radius_obj = torch.stack(radius_obj_pred, 0).to('cuda')

        # Stage 2: refine both clouds with cross attention, starting the diffusion
        # from a partially noised version of the stage 1 prediction
        out_stage2: Pointclouds = self.model_stage2.forward_sample(
            num_points=num_samples,
            camera=camera_hum,
            image_rgb=torch.stack(batch['images_hum'], 0).to('cuda'),
            mask=torch.stack(batch['masks_hum'], 0).to('cuda'),
            gt_pc=pc_hum,
            rgb_obj=torch.stack(batch['images_obj'], 0).to('cuda'),
            mask_obj=torch.stack(batch['masks_obj'], 0).to('cuda'),
            pc_obj=pc_obj,
            camera_obj=camera_obj,
            cent_hum=cent_hum,
            cent_obj=cent_obj,
            radius_hum=radius_hum.unsqueeze(-1),  # (B, 1)
            radius_obj=radius_obj.unsqueeze(-1),
            sample_from_interm=True,
            noise_step=cfg.run.sample_noise_step,
            scheduler=cfg.run.diffusion_scheduler,
            num_inference_steps=cfg.run.num_inference_steps,
            eta=cfg.model.ddim_eta)

        return out_stage1, out_stage2

    def upsample_predicted_pc(self, num_samples, pc_obj):
        """
        Up/downsample the points to the given number.
        :param num_samples: the target number of points
        :param pc_obj: (N, 3)
        :return: (num_samples, 3)
        """
        if len(pc_obj) > num_samples:
            # Downsample: take a random subset (without replacement, to keep unique points)
            ind_obj = np.random.choice(len(pc_obj), num_samples, replace=False)
        else:
            # Upsample: keep all points and pad with randomly duplicated ones
            ind_obj = np.concatenate([np.arange(len(pc_obj)),
                                      np.random.choice(len(pc_obj), num_samples - len(pc_obj))])
        return pc_obj.clone()[torch.from_numpy(ind_obj).long().to(pc_obj.device)]
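
    # Example (hypothetical tensors, for illustration only):
    #   pts = torch.rand(1000, 3)
    #   runner.upsample_predicted_pc(500, pts).shape   # torch.Size([500, 3])  (random subset)
    #   runner.upsample_predicted_pc(2048, pts).shape  # torch.Size([2048, 3]) (padded with duplicates)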


@hydra.main(config_path='configs', config_name='configs', version_base='1.1')
def main(cfg: ProjectConfig):
    runner = DemoRunner(cfg)
    runner.run()


if __name__ == '__main__':
    main()
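

# Programmatic usage without the CLI (a sketch, assuming the same Hydra config
# tree that @hydra.main uses above):
#   from hydra import initialize, compose
#   with initialize(config_path='configs', version_base='1.1'):
#       cfg = compose(config_name='configs')
#       DemoRunner(cfg).run()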