diff --git a/README.md b/README.md index 3bad120e9b1a85cbafc28dab9df455f1e14135ec..25e18bc3de3c81966967698aed70d38d45415622 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,16 @@ ---- -title: HDM Interaction Recon -emoji: 🌍 -colorFrom: yellow -colorTo: green -sdk: gradio -sdk_version: 4.20.1 -app_file: app.py -pinned: false -license: cc-by-nc-4.0 ---- +# HDM +Official implementation of the Hierarchical Diffusion Model (HDM) from the CVPR'24 paper "Template Free Reconstruction of Human-object Interaction with Procedural Interaction Generation". -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +[Project Page](https://virtualhumans.mpi-inf.mpg.de/procigen-hdm/)|[Code](https://github.com/xiexh20/HDM)|[Dataset](https://edmond.mpg.de/dataset.xhtml?persistentId=doi:10.17617/3.2VUEUS )|[Paper](https://virtualhumans.mpi-inf.mpg.de/procigen-hdm/paper-lowreso.pdf) + + +## Citation +``` +@inproceedings{xie2023template_free, + title = {Template Free Reconstruction of Human-object Interaction with Procedural Interaction Generation}, + author = {Xie, Xianghui and Bhatnagar, Bharat Lal and Lenssen, Jan Eric and Pons-Moll, Gerard}, + booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + month = {June}, + year = {2024}, +} +``` diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..707d9b3ea3e375daa5bb47a44d7bcd2e97f476ef --- /dev/null +++ b/app.py @@ -0,0 +1,177 @@ +""" +Demo built with gradio +""" +import pickle as pkl +import sys, os +import os.path as osp +from typing import Iterable, Optional +from functools import partial + +import trimesh +from torch.utils.data import DataLoader +import cv2 +from accelerate import Accelerator +from tqdm import tqdm +from glob import glob + +sys.path.append(os.getcwd()) +import hydra +import torch +import numpy as np +import imageio +import gradio as gr +import plotly.graph_objs as go +import training_utils + +from configs.structured import ProjectConfig +from demo import DemoRunner +from dataset.demo_dataset import DemoDataset + + +md_description=""" +# HDM Interaction Reconstruction Demo +### Official Implementation of the paper \"Template Free Reconstruction of Human Object Interaction\", CVPR'24. +[Project Page](https://virtualhumans.mpi-inf.mpg.de/procigen-hdm/)|[Code](https://github.com/xiexh20/HDM)|[Dataset](https://edmond.mpg.de/dataset.xhtml?persistentId=doi:10.17617/3.2VUEUS )|[Paper](https://virtualhumans.mpi-inf.mpg.de/procigen-hdm/paper-lowreso.pdf) + + +Upload your own human-object interaction image and get a full 3D reconstruction!
+ +## Citation +``` +@inproceedings{xie2023template_free, + title = {Template Free Reconstruction of Human-object Interaction with Procedural Interaction Generation}, + author = {Xie, Xianghui and Bhatnagar, Bharat Lal and Lenssen, Jan Eric and Pons-Moll, Gerard}, + booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + month = {June}, + year = {2024}, +} +``` +""" + +def plot_points(colors, coords): + """ + use plotly to visualize 3D point with colors + """ + trace = go.Scatter3d(x=coords[:, 0], y=coords[:, 1], z=coords[:, 2], mode='markers', + marker=dict( + size=2, + color=colors + )) + layout = go.Layout( + scene=dict( + xaxis=dict( + title="", + showgrid=False, + zeroline=False, + showline=False, + ticks='', + showticklabels=False + ), + yaxis=dict( + title="", + showgrid=False, + zeroline=False, + showline=False, + ticks='', + showticklabels=False + ), + zaxis=dict( + title="", + showgrid=False, + zeroline=False, + showline=False, + ticks='', + showticklabels=False + ), + ), + margin=dict(l=0, r=0, b=0, t=0), + showlegend=False + ) + fig = go.Figure(data=[trace], layout=layout) + return fig + + +def inference(runner: DemoRunner, cfg: ProjectConfig, rgb, mask_hum, mask_obj, std_coverage, input_seed): + """ + given user input, run inference + :param runner: + :param cfg: + :param rgb: (h, w, 3), np array + :param mask_hum: (h, w, 3), np array + :param mask_obj: (h, w, 3), np array + :param std_coverage: float value, used to estimate camera translation + :param input_seed: random seed + :return: path to the 3D reconstruction, and an interactive 3D figure for visualizing the point cloud + """ + # Set random seed + training_utils.set_seed(int(input_seed)) + + data = DemoDataset([], (cfg.dataset.image_size, cfg.dataset.image_size), + std_coverage) + batch = data.image2batch(rgb, mask_hum, mask_obj) + + out_stage1, out_stage2 = runner.forward_batch(batch, cfg) + points = out_stage2.points_packed().cpu().numpy() + colors = out_stage2.features_packed().cpu().numpy() + fig = plot_points(colors, points) + # save tmp point cloud + outdir = './results' + os.makedirs(outdir, exist_ok=True) + trimesh.PointCloud(points, colors).export(outdir + f"/pred_std{std_coverage}_seed{input_seed}_stage2.ply") + trimesh.PointCloud(out_stage1.points_packed().cpu().numpy(), + out_stage1.features_packed().cpu().numpy()).export(outdir + f"/pred_std{std_coverage}_seed{input_seed}_stage1.ply") + return fig, outdir + f"/pred_std{std_coverage}_seed{input_seed}_stage2.ply" + + +@hydra.main(config_path='configs', config_name='configs', version_base='1.1') +def main(cfg: ProjectConfig): + # Setup model + runner = DemoRunner(cfg) + + # Setup interface + demo = gr.Blocks(title="HDM Interaction Reconstruction Demo") + with demo: + gr.Markdown(md_description) + gr.HTML("""

<h1 style="text-align: center">HDM Demo</h1>

""") + gr.HTML("""

<p style="text-align: center">Instruction: upload an RGB image together with the human and object masks, then click Start Reconstruction.</p>

""") + + # Input data + with gr.Row(): + input_rgb = gr.Image(label='Input RGB', type='numpy') + input_mask_hum = gr.Image(label='Human mask', type='numpy') + with gr.Row(): + input_mask_obj = gr.Image(label='Object mask', type='numpy') + with gr.Column(): + # TODO: add hint for this value here + input_std = gr.Number(label='Gaussian std coverage', value=3.5) + input_seed = gr.Number(label='Random seed', value=42) + # Output visualization + with gr.Row(): + pc_plot = gr.Plot(label="Reconstructed point cloud") + out_pc_download = gr.File(label="3D reconstruction for download") # this allows downloading + + gr.HTML("""
""") + # Control + with gr.Row(): + button_recon = gr.Button("Start Reconstruction", interactive=True, variant='secondary') + button_recon.click(fn=partial(inference, runner, cfg), + inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed], + outputs=[pc_plot, out_pc_download]) + gr.HTML("""
""") + # Example input + example_dir = cfg.run.code_dir_abs+"/examples" + rgb, ps, obj = 'k1.color.jpg', 'k1.person_mask.png', 'k1.obj_rend_mask.png' + example_images = gr.Examples([ + [f"{example_dir}/017450/{rgb}", f"{example_dir}/017450/{ps}", f"{example_dir}/017450/{obj}", 3.0, 42], + [f"{example_dir}/002446/{rgb}", f"{example_dir}/002446/{ps}", f"{example_dir}/002446/{obj}", 3.0, 42], + [f"{example_dir}/053431/{rgb}", f"{example_dir}/053431/{ps}", f"{example_dir}/053431/{obj}", 3.8, 42], + [f"{example_dir}/158107/{rgb}", f"{example_dir}/158107/{ps}", f"{example_dir}/158107/{obj}", 3.8, 42], + + ], inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed],) + + # demo.launch(share=True) + # Enabling queue for runtime>60s, see: https://github.com/tloen/alpaca-lora/issues/60#issuecomment-1510006062 + demo.queue(concurrency_count=3).launch(share=True) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/configs/__init__.py b/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/structured.py b/configs/structured.py new file mode 100644 index 0000000000000000000000000000000000000000..4725105375b147cc2445ae932d8555da065ad8f9 --- /dev/null +++ b/configs/structured.py @@ -0,0 +1,416 @@ +import os +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Iterable +import os.path as osp + +from hydra.core.config_store import ConfigStore +from hydra.conf import RunDir + + +@dataclass +class CustomHydraRunDir(RunDir): + dir: str = './outputs/${run.name}/single' + + +@dataclass +class RunConfig: + name: str = 'debug' + job: str = 'train' + mixed_precision: str = 'fp16' # 'no' + cpu: bool = False + seed: int = 42 + val_before_training: bool = True + vis_before_training: bool = True + limit_train_batches: Optional[int] = None + limit_val_batches: Optional[int] = None + max_steps: int = 100_000 + checkpoint_freq: int = 1_000 + val_freq: int = 5_000 + vis_freq: int = 5_000 + # vis_freq: int = 10_000 + log_step_freq: int = 20 + print_step_freq: int = 100 + + # config to run demo + stage1_name: str = 'stage1' # experiment name to the stage 1 model + stage2_name: str = 'stage2' # experiment name to the stage 2 model + image_path: str = '' # the path to the images for running demo, can be a single file or a glob pattern + + # abs path to working dir + code_dir_abs: str = osp.dirname(osp.dirname(osp.abspath(__file__))) + + # Inference configs + num_inference_steps: int = 1000 + diffusion_scheduler: Optional[str] = 'ddpm' + num_samples: int = 1 + # num_sample_batches: Optional[int] = None + num_sample_batches: Optional[int] = 2000 # XH: change to 2 + sample_from_ema: bool = False + sample_save_evolutions: bool = False # temporarily set by default + save_name: str = 'sample' # XH: additional save name + redo: bool = False + + # for parallel sampling in slurm + batch_start: int = 0 + batch_end: Optional[int] = None + + # Training configs + freeze_feature_model: bool = True + + # Coloring training configs + coloring_training_noise_std: float = 0.0 + coloring_sample_dir: Optional[str] = None + + sample_mode: str = 'sample' # whether from noise or from some intermediate steps + sample_noise_step: int = 500 # add noise to GT up to some steps, and then denoise + sample_save_gt: bool = True + + +@dataclass +class LoggingConfig: + wandb: bool = True + wandb_project: str = 'pc2' + + + +@dataclass +class PointCloudProjectionModelConfig: + # Feature 
extraction arguments + image_size: int = '${dataset.image_size}' + image_feature_model: str = 'vit_base_patch16_224_mae' # or 'vit_small_patch16_224_msn' or 'identity' + use_local_colors: bool = True + use_local_features: bool = True + use_global_features: bool = False + use_mask: bool = True + use_distance_transform: bool = True + + # Point cloud data arguments. Note these are here because the processing happens + # inside the model, rather than inside the dataset. + scale_factor: float = "${dataset.scale_factor}" + colors_mean: float = 0.5 + colors_std: float = 0.5 + color_channels: int = 3 + predict_shape: bool = True + predict_color: bool = False + + # added by XH + load_sample_init: bool = False # load init samples from file + sample_init_scale: float = 1.0 # scale the initial pc samples + test_init_with_gtpc: bool = False # test time init samples with GT samples + consistent_center: bool = True # use consistent center prediction by CCD-3DR + voxel_resolution_multiplier: float = 1 # increase network voxel resolution + + # predict binary segmentation + predict_binary: bool = False # True for stage 1 model, False for others + lw_binary: float = 3.0 # to have roughly the same magnitude of the binary segmentation loss + # for separate model + binary_training_noise_std: float = 0.1 # from github doc for predicting color + self_conditioning: bool = False + +@dataclass +class PVCNNAEModelConfig(PointCloudProjectionModelConfig): + "my own model config, must inherit parent class" + model_name: str = 'pvcnn-ae' + latent_dim: int = 1024 + num_dec_blocks: int = 6 + block_dims: List[int] = field(default_factory=lambda: [512, 256]) + num_points: int = 1500 + bottleneck_dim: int = -1 # the input dim to the last MLP layer + +@dataclass +class PointCloudDiffusionModelConfig(PointCloudProjectionModelConfig): + model_name: str = 'pc2-diff-ho' # default as behave + + # Diffusion arguments + beta_start: float = 1e-5 # 0.00085 + beta_end: float = 8e-3 # 0.012 + beta_schedule: str = 'linear' # 'custom' + dm_pred_type: str = 'epsilon' # diffusion model prediction type, sample (x0) or noise + + # Point cloud model arguments + point_cloud_model: str = 'pvcnn' + point_cloud_model_embed_dim: int = 64 + + dataset_type: str = '${dataset.type}' + +@dataclass +class CrossAttnHOModelConfig(PointCloudDiffusionModelConfig): + model_name: str = 'diff-ho-attn' + + attn_type: str = 'coord3d+posenc-learnable' + attn_weight: float = 1.0 + point_visible_test: str = 'combine' # To compute point visibility: use all points or only human/object points + + +@dataclass +class DirectTransModelConfig(PointCloudProjectionModelConfig): + model_name: str = 'direct-transl-ho' + + pooling: str = "avg" + act: str = 'gelu' + out_act: str = 'relu' + # feat_dims_transl: Iterable[Any] = (384, 256, 128, 6) # cannot use List[int] https://github.com/facebookresearch/hydra/issues/1752#issuecomment-893174197 + # feat_dims_scale: Iterable[Any] = (384, 128, 64, 2) + feat_dims_transl: List[int] = field(default_factory=lambda: [384, 256, 128, 6]) + feat_dims_scale: List[int] = field(default_factory=lambda: [384, 128, 64, 2]) + lw_transl: float = 10000.0 + lw_scale: float = 10000.0 + + +@dataclass +class PointCloudColoringModelConfig(PointCloudProjectionModelConfig): + # Projection arguments + predict_shape: bool = False + predict_color: bool = True + + # Point cloud model arguments + point_cloud_model: str = 'pvcnn' + point_cloud_model_layers: int = 1 + point_cloud_model_embed_dim: int = 64 + + +@dataclass +class DatasetConfig: + type: str + + 
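(Editor's aside, not part of the diff) Several model fields above are declared with OmegaConf string interpolations, e.g. `image_size: int = '${dataset.image_size}'` and `scale_factor: float = "${dataset.scale_factor}"`; they stay unresolved until Hydra assembles the full `ProjectConfig` tree registered at the bottom of this file, so a command-line override of `dataset.image_size` automatically propagates into the model config. A minimal sketch of that behaviour, with illustrative values:
```
# Editor's sketch only -- not part of configs/structured.py.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "dataset": {"image_size": 224, "scale_factor": 7.0},   # illustrative values
    "model": {
        "image_size": "${dataset.image_size}",
        "scale_factor": "${dataset.scale_factor}",
    },
})
print(cfg.model.image_size)    # 224  -- resolved from dataset.image_size
print(cfg.model.scale_factor)  # 7.0  -- resolved from dataset.scale_factor
```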
+@dataclass +class PointCloudDatasetConfig(DatasetConfig): + eval_split: str = 'val' + max_points: int = 16_384 + image_size: int = 224 + scale_factor: float = 1.0 + restrict_model_ids: Optional[List] = None # for only running on a subset of data points + + +@dataclass +class CO3DConfig(PointCloudDatasetConfig): + type: str = 'co3dv2' + # root: str = os.getenv('CO3DV2_DATASET_ROOT') + root: str = "/BS/xxie-2/work/co3d/hydrant" + category: str = 'hydrant' + subset_name: str = 'fewview_dev' + mask_images: bool = '${model.use_mask}' + + +@dataclass +class ShapeNetR2N2Config(PointCloudDatasetConfig): + # added by XH + fix_sample: bool = True + category: str = 'chair' + + type: str = 'shapenet_r2n2' + root: str = "/BS/chiban2/work/data_shapenet/ShapeNetCore.v1" + r2n2_dir: str = "/BS/databases20/3d-r2n2" + shapenet_dir: str = "/BS/chiban2/work/data_shapenet/ShapeNetCore.v1" + preprocessed_r2n2_dir: str = "${dataset.root}/r2n2_preprocessed_renders" + splits_file: str = "${dataset.root}/r2n2_standard_splits_from_ShapeNet_taxonomy.json" + # splits_file: str = "${dataset.root}/pix2mesh_splits_val05.json" # <-- incorrect + scale_factor: float = 7.0 + point_cloud_filename: str = 'pointcloud_r2n2.npz' # should use 'pointcloud_mesh.npz' + + + +@dataclass +class BehaveDatasetConfig(PointCloudDatasetConfig): + # added by XH + type: str = 'behave' + + fix_sample: bool = True + behave_dir: str = "/BS/xxie-5/static00/behave_release/sequences/" + split_file: str = "" # specify you dataset split file here + scale_factor: float = 7.0 # use the same as shapenet + sample_ratio_hum: float = 0.5 + image_size: int = 224 + + normalize_type: str = 'comb' + smpl_type: str = 'gt' # use which SMPL mesh to obtain normalization parameters + test_transl_type: str = 'norm' + + load_corr_points: bool = False # load autoencoder points for object and SMPL + uniform_obj_sample: bool = False + + # configs for direct translation prediction + bkg_type: str = 'none' + bbox_params: str = 'none' + ho_segm_pred_path: Optional[str] = None + use_gt_transl: bool = False + + cam_noise_std: float = 0. # add noise to the camera pose + sep_same_crop: bool = False # use same input image crop to separate models + aug_blur: float = 0. 
# blur augmentation + + std_coverage: float=3.5 # a heuristic value to estimate translation + + v2v_path: str = '' # object v2v corr path + +@dataclass +class ShapeDatasetConfig(BehaveDatasetConfig): + "the dataset to train AE for aligned shapes" + type: str = 'shape' + fix_sample: bool = False + split_file: str = "/BS/xxie-2/work/pc2-diff/experiments/splits/shapes-chair.pkl" + + +# TODO +@dataclass +class ShapeNetNMRConfig(PointCloudDatasetConfig): + type: str = 'shapenet_nmr' + shapenet_nmr_dir: str = "/work/lukemk/machine-learning-datasets/3d-reconstruction/ShapeNet_NMR/NMR_Dataset" + synset_names: str = 'chair' # comma-separated or 'all' + augmentation: str = 'all' + scale_factor: float = 7.0 + + +@dataclass +class AugmentationConfig: + # need to specify the variable type in order to define it properly + max_radius: int = 0 # generate a random square to mask object, this is the radius for the square in pixel size, zero means no occlusion + + +@dataclass +class DataloaderConfig: + # batch_size: int = 8 # 2 for debug + batch_size: int = 16 + num_workers: int = 14 # 0 for debug # suggested by accelerator for gpu20 + + +@dataclass +class LossConfig: + diffusion_weight: float = 1.0 + rgb_weight: float = 1.0 + consistency_weight: float = 1.0 + + +@dataclass +class CheckpointConfig: + resume: Optional[str] = "test" + resume_training: bool = True + resume_training_optimizer: bool = True + resume_training_scheduler: bool = True + resume_training_state: bool = True + + +@dataclass +class ExponentialMovingAverageConfig: + use_ema: bool = False + # # From Diffusers EMA (should probably switch) + # ema_inv_gamma: float = 1.0 + # ema_power: float = 0.75 + # ema_max_decay: float = 0.9999 + decay: float = 0.999 + update_every: int = 20 + + +@dataclass +class OptimizerConfig: + type: str + name: str + lr: float = 3e-4 + weight_decay: float = 0.0 + scale_learning_rate_with_batch_size: bool = False + gradient_accumulation_steps: int = 1 + clip_grad_norm: Optional[float] = 50.0 # 5.0 + kwargs: Dict = field(default_factory=lambda: dict()) + + +@dataclass +class AdadeltaOptimizerConfig(OptimizerConfig): + type: str = 'torch' + name: str = 'Adadelta' + kwargs: Dict = field(default_factory=lambda: dict( + weight_decay=1e-6, + )) + + +@dataclass +class AdamOptimizerConfig(OptimizerConfig): + type: str = 'torch' + name: str = 'AdamW' + weight_decay: float = 1e-6 + kwargs: Dict = field(default_factory=lambda: dict(betas=(0.95, 0.999))) + + +@dataclass +class SchedulerConfig: + type: str + kwargs: Dict = field(default_factory=lambda: dict()) + + +@dataclass +class LinearSchedulerConfig(SchedulerConfig): + type: str = 'transformers' + kwargs: Dict = field(default_factory=lambda: dict( + name='linear', + num_warmup_steps=0, + num_training_steps="${run.max_steps}", + )) + + +@dataclass +class CosineSchedulerConfig(SchedulerConfig): + type: str = 'transformers' + kwargs: Dict = field(default_factory=lambda: dict( + name='cosine', + num_warmup_steps=2000, # 0 + num_training_steps="${run.max_steps}", + )) + + +@dataclass +class ProjectConfig: + run: RunConfig + logging: LoggingConfig + dataset: PointCloudDatasetConfig + augmentations: AugmentationConfig + dataloader: DataloaderConfig + loss: LossConfig + model: PointCloudProjectionModelConfig + ema: ExponentialMovingAverageConfig + checkpoint: CheckpointConfig + optimizer: OptimizerConfig + scheduler: SchedulerConfig + + defaults: List[Any] = field(default_factory=lambda: [ + 'custom_hydra_run_dir', + {'run': 'default'}, + {'logging': 'default'}, + {'model': 
'ho-attn'}, + # {'dataset': 'co3d'}, + {'dataset': 'behave'}, + {'augmentations': 'default'}, + {'dataloader': 'default'}, + {'ema': 'default'}, + {'loss': 'default'}, + {'checkpoint': 'default'}, + {'optimizer': 'adam'}, # default adamw + {'scheduler': 'linear'}, + # {'scheduler': 'cosine'}, + ]) + + +cs = ConfigStore.instance() +cs.store(name='custom_hydra_run_dir', node=CustomHydraRunDir, package="hydra.run") +cs.store(group='run', name='default', node=RunConfig) +cs.store(group='logging', name='default', node=LoggingConfig) +cs.store(group='model', name='diffrec', node=PointCloudDiffusionModelConfig) +cs.store(group='model', name='coloring_model', node=PointCloudColoringModelConfig) +cs.store(group='model', name='direct-transl', node=DirectTransModelConfig) +cs.store(group='model', name='ho-attn', node=CrossAttnHOModelConfig) +cs.store(group='model', name='pvcnn-ae', node=PVCNNAEModelConfig) +cs.store(group='dataset', name='co3d', node=CO3DConfig) +# TODO +cs.store(group='dataset', name='shapenet_r2n2', node=ShapeNetR2N2Config) +cs.store(group='dataset', name='behave', node=BehaveDatasetConfig) +cs.store(group='dataset', name='shape', node=ShapeDatasetConfig) +# cs.store(group='dataset', name='shapenet_nmr', node=ShapeNetNMRConfig) +cs.store(group='augmentations', name='default', node=AugmentationConfig) +cs.store(group='dataloader', name='default', node=DataloaderConfig) +cs.store(group='loss', name='default', node=LossConfig) +cs.store(group='ema', name='default', node=ExponentialMovingAverageConfig) +cs.store(group='checkpoint', name='default', node=CheckpointConfig) +cs.store(group='optimizer', name='adadelta', node=AdadeltaOptimizerConfig) +cs.store(group='optimizer', name='adam', node=AdamOptimizerConfig) +cs.store(group='scheduler', name='linear', node=LinearSchedulerConfig) +cs.store(group='scheduler', name='cosine', node=CosineSchedulerConfig) +cs.store(name='configs', node=ProjectConfig) diff --git a/dataset/__init__.py b/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a63cea161fe8d6feb7a5c09e458523b15935817a --- /dev/null +++ b/dataset/__init__.py @@ -0,0 +1,301 @@ +from pathlib import Path +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +import pytorch3d +import torch +from torch.utils.data import SequentialSampler +from omegaconf import DictConfig +from pytorch3d.implicitron.dataset.data_loader_map_provider import \ + SequenceDataLoaderMapProvider +from pytorch3d.implicitron.dataset.dataset_base import FrameData +from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset +from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import ( + JsonIndexDatasetMapProviderV2, registry) +from pytorch3d.implicitron.tools.config import expand_args_fields +from pytorch3d.renderer.cameras import CamerasBase +from torch.utils.data import DataLoader + +from configs.structured import CO3DConfig, DataloaderConfig, ProjectConfig, Optional +from .exclude_sequence import EXCLUDE_SEQUENCE, LOW_QUALITY_SEQUENCE +from .utils import DatasetMap +from .r2n2_my import R2N2Sample, collate_batched_meshes + + +def get_dataset(cfg: ProjectConfig): + + if cfg.dataset.type == 'co3dv2': + dataset_cfg: CO3DConfig = cfg.dataset + dataloader_cfg: DataloaderConfig = cfg.dataloader + + # Exclude bad and low-quality sequences, XH: why this is needed? 
+ exclude_sequence = [] + exclude_sequence.extend(EXCLUDE_SEQUENCE.get(dataset_cfg.category, [])) + exclude_sequence.extend(LOW_QUALITY_SEQUENCE.get(dataset_cfg.category, [])) + + # Whether to load pointclouds + kwargs = dict( + remove_empty_masks=True, + n_frames_per_sequence=1, + load_point_clouds=True, + max_points=dataset_cfg.max_points, + image_height=dataset_cfg.image_size, + image_width=dataset_cfg.image_size, + mask_images=dataset_cfg.mask_images, + exclude_sequence=exclude_sequence, + pick_sequence=() if dataset_cfg.restrict_model_ids is None else dataset_cfg.restrict_model_ids, + ) + + # Get dataset mapper + dataset_map_provider_type = registry.get(JsonIndexDatasetMapProviderV2, "JsonIndexDatasetMapProviderV2") + expand_args_fields(dataset_map_provider_type) + dataset_map_provider = dataset_map_provider_type( + category=dataset_cfg.category, + subset_name=dataset_cfg.subset_name, + dataset_root=dataset_cfg.root, + test_on_train=False, + only_test_set=False, + load_eval_batches=True, + dataset_JsonIndexDataset_args=DictConfig(kwargs), + ) + + # Get datasets + datasets = dataset_map_provider.get_dataset_map() # how to select specific frames?? + + # PATCH BUG WITH POINT CLOUD LOCATIONS! + for dataset in (datasets["train"], datasets["val"]): + # print(dataset.seq_annots.items()) + for key, ann in dataset.seq_annots.items(): + correct_point_cloud_path = Path(dataset.dataset_root) / Path(*Path(ann.point_cloud.path).parts[-3:]) + assert correct_point_cloud_path.is_file(), correct_point_cloud_path + ann.point_cloud.path = str(correct_point_cloud_path) + + # Get dataloader mapper + data_loader_map_provider_type = registry.get(SequenceDataLoaderMapProvider, "SequenceDataLoaderMapProvider") + expand_args_fields(data_loader_map_provider_type) + data_loader_map_provider = data_loader_map_provider_type( + batch_size=dataloader_cfg.batch_size, + num_workers=dataloader_cfg.num_workers, + ) + + # QUICK HACK: Patch the train dataset because it is not used but it throws an error + if (len(datasets['train']) == 0 and len(datasets[dataset_cfg.eval_split]) > 0 and + dataset_cfg.restrict_model_ids is not None and cfg.run.job == 'sample'): + datasets = DatasetMap(train=datasets[dataset_cfg.eval_split], val=datasets[dataset_cfg.eval_split], + test=datasets[dataset_cfg.eval_split]) + # XH: why all eval split? + print('Note: You used restrict_model_ids and there were no ids in the train set.') + + # Get dataloaders + dataloaders = data_loader_map_provider.get_data_loader_map(datasets) + dataloader_train = dataloaders['train'] + dataloader_val = dataloader_vis = dataloaders[dataset_cfg.eval_split] + + # Replace validation dataloader sampler with SequentialSampler + # seems to be randomly sampled? with a fixed random seed? but one cannot control which image is being sampled?? 
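+        # Editor's note (not in the original diff): the line below swaps in a SequentialSampler so the validation loader always iterates frames in the same dataset order, making evaluation and visualization batches reproducible across runs.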
+ dataloader_val.batch_sampler.sampler = SequentialSampler(dataloader_val.batch_sampler.sampler.data_source) + + # Modify for accelerate + dataloader_train.batch_sampler.drop_last = True + dataloader_val.batch_sampler.drop_last = False + elif cfg.dataset.type == 'shapenet_r2n2': + # from ..configs.structured import ShapeNetR2N2Config + dataset_cfg: ShapeNetR2N2Config = cfg.dataset + # for k in dataset_cfg: + # print(k) + datasets = [R2N2Sample(dataset_cfg.max_points, dataset_cfg.fix_sample, + dataset_cfg.image_size, cfg.augmentations, + s, dataset_cfg.shapenet_dir, + dataset_cfg.r2n2_dir, dataset_cfg.splits_file, + load_textures=False, return_all_views=True) for s in ['train', 'val', 'test']] + dataloader_train = DataLoader(datasets[0], batch_size=cfg.dataloader.batch_size, + collate_fn=collate_batched_meshes, + num_workers=cfg.dataloader.num_workers, shuffle=True) + dataloader_val = DataLoader(datasets[1], batch_size=cfg.dataloader.batch_size, + collate_fn=collate_batched_meshes, + num_workers=cfg.dataloader.num_workers, shuffle=False) + dataloader_vis = DataLoader(datasets[2], batch_size=cfg.dataloader.batch_size, + collate_fn=collate_batched_meshes, + num_workers=cfg.dataloader.num_workers, shuffle=False) + + elif cfg.dataset.type in ['behave', 'behave-objonly', 'behave-humonly', 'behave-dtransl', + 'behave-objonly-segm', 'behave-humonly-segm', 'behave-attn', + 'behave-test', 'behave-attn-test', 'behave-hum-pe', 'behave-hum-noscale', + 'behave-hum-surf', 'behave-objv2v']: + from .behave_dataset import BehaveDataset, NTUDataset, BehaveObjOnly, BehaveHumanOnly, BehaveHumanOnlyPosEnc + from .behave_dataset import BehaveHumanOnlySegmInput, BehaveObjOnlySegmInput, BehaveTestOnly, BehaveHumNoscale + from .behave_dataset import BehaveHumanOnlySurfSample + from .dtransl_dataset import DirectTranslDataset + from .behave_paths import DataPaths + from configs.structured import BehaveDatasetConfig + from .behave_crossattn import BehaveCrossAttnDataset, BehaveCrossAttnTest + from .behave_dataset import BehaveObjOnlyV2V + + dataset_cfg: BehaveDatasetConfig = cfg.dataset + # print(dataset_cfg.behave_dir) + train_paths, val_paths = DataPaths.load_splits(dataset_cfg.split_file, dataset_cfg.behave_dir) + # exit(0) + + # split validation paths to only consider the selected batches + bs = cfg.dataloader.batch_size + num_batches_total = int(np.ceil(len(val_paths)/cfg.dataloader.batch_size)) + end_idx = cfg.run.batch_end if cfg.run.batch_end is not None else num_batches_total + # print(cfg.run.batch_end, cfg.run.batch_start, end_idx) + val_paths = val_paths[cfg.run.batch_start*bs:end_idx*bs] + + if cfg.dataset.type == 'behave': + train_type = BehaveDataset + val_datatype = BehaveDataset if 'ntu' not in dataset_cfg.split_file else NTUDataset + elif cfg.dataset.type == 'behave-test': + train_type = BehaveDataset + val_datatype = BehaveTestOnly + elif cfg.dataset.type == 'behave-objonly': + train_type = BehaveObjOnly + val_datatype = BehaveObjOnly + assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!' + elif cfg.dataset.type == 'behave-humonly': + train_type = BehaveHumanOnly + val_datatype = BehaveHumanOnly + assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!' 
+ elif cfg.dataset.type == 'behave-hum-noscale': + train_type = BehaveHumNoscale + val_datatype = BehaveHumNoscale + elif cfg.dataset.type == 'behave-hum-pe': + train_type = BehaveHumanOnlyPosEnc + val_datatype = BehaveHumanOnlyPosEnc + elif cfg.dataset.type == 'behave-hum-surf': + train_type = BehaveHumanOnlySurfSample + val_datatype = BehaveHumanOnlySurfSample + elif cfg.dataset.type == 'behave-humonly-segm': + assert cfg.dataset.ho_segm_pred_path is not None, 'please specify predicted HO segmentation!' + train_type = BehaveHumanOnly + val_datatype = BehaveHumanOnlySegmInput + assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!' + elif cfg.dataset.type == 'behave-objonly-segm': + assert cfg.dataset.ho_segm_pred_path is not None, 'please specify predicted HO segmentation!' + train_type = BehaveObjOnly + val_datatype = BehaveObjOnlySegmInput + assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!' + elif cfg.dataset.type == 'behave-dtransl': + train_type = DirectTranslDataset + val_datatype = DirectTranslDataset + elif cfg.dataset.type == 'behave-attn': + train_type = BehaveCrossAttnDataset + val_datatype = BehaveCrossAttnDataset + elif cfg.dataset.type == 'behave-attn-test': + train_type = BehaveCrossAttnDataset + val_datatype = BehaveCrossAttnTest + elif cfg.dataset.type == 'behave-objv2v': + train_type = BehaveObjOnlyV2V + val_datatype = BehaveObjOnlyV2V + else: + raise NotImplementedError + + dataset_train = train_type(train_paths, dataset_cfg.max_points, dataset_cfg.fix_sample, + (dataset_cfg.image_size, dataset_cfg.image_size), + split='train', sample_ratio_hum=dataset_cfg.sample_ratio_hum, + normalize_type=dataset_cfg.normalize_type, smpl_type='gt', + load_corr_points=dataset_cfg.load_corr_points, + uniform_obj_sample=dataset_cfg.uniform_obj_sample, + bkg_type=dataset_cfg.bkg_type, + bbox_params=dataset_cfg.bbox_params, + pred_binary=cfg.model.predict_binary, + ho_segm_pred_path=cfg.dataset.ho_segm_pred_path, + compute_closest_points=cfg.model.model_name=='pc2-diff-ho-tune-newloss', + use_gt_transl=cfg.dataset.use_gt_transl, + cam_noise_std=cfg.dataset.cam_noise_std, + sep_same_crop=cfg.dataset.sep_same_crop, + aug_blur=cfg.dataset.aug_blur, + std_coverage=cfg.dataset.std_coverage, + v2v_path=cfg.dataset.v2v_path) + + dataset_val = val_datatype(val_paths, dataset_cfg.max_points, dataset_cfg.fix_sample, + (dataset_cfg.image_size, dataset_cfg.image_size), + split='val', sample_ratio_hum=dataset_cfg.sample_ratio_hum, + normalize_type=dataset_cfg.normalize_type, smpl_type=dataset_cfg.smpl_type, + load_corr_points=dataset_cfg.load_corr_points, + test_transl_type=dataset_cfg.test_transl_type, + uniform_obj_sample=dataset_cfg.uniform_obj_sample, + bkg_type=dataset_cfg.bkg_type, + bbox_params=dataset_cfg.bbox_params, + pred_binary=cfg.model.predict_binary, + ho_segm_pred_path=cfg.dataset.ho_segm_pred_path, + compute_closest_points=cfg.model.model_name=='pc2-diff-ho-tune-newloss', + use_gt_transl=cfg.dataset.use_gt_transl, + sep_same_crop=cfg.dataset.sep_same_crop, + std_coverage=cfg.dataset.std_coverage, + v2v_path=cfg.dataset.v2v_path) + # dataset_test = val_datatype(val_paths, dataset_cfg.max_points, dataset_cfg.fix_sample, + # (dataset_cfg.image_size, dataset_cfg.image_size), + # split='test', sample_ratio_hum=dataset_cfg.sample_ratio_hum, + # normalize_type=dataset_cfg.normalize_type, smpl_type=dataset_cfg.smpl_type, + # load_corr_points=dataset_cfg.load_corr_points, + # test_transl_type=dataset_cfg.test_transl_type, + # 
uniform_obj_sample=dataset_cfg.uniform_obj_sample, + # bkg_type=dataset_cfg.bkg_type, + # bbox_params=dataset_cfg.bbox_params, + # pred_binary=cfg.model.predict_binary, + # ho_segm_pred_path=cfg.dataset.ho_segm_pred_path, + # compute_closest_points=cfg.model.model_name=='pc2-diff-ho-tune-newloss', + # use_gt_transl=cfg.dataset.use_gt_transl, + # sep_same_crop=cfg.dataset.sep_same_crop) + dataloader_train = DataLoader(dataset_train, batch_size=cfg.dataloader.batch_size, + collate_fn=collate_batched_meshes, + num_workers=cfg.dataloader.num_workers, shuffle=True) + shuffle = cfg.run.job == 'train' + dataloader_val = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size, + collate_fn=collate_batched_meshes, + num_workers=cfg.dataloader.num_workers, shuffle=shuffle) + dataloader_vis = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size, + collate_fn=collate_batched_meshes, + num_workers=cfg.dataloader.num_workers, shuffle=shuffle) + + # datasets = [BehaveDataset(p, dataset_cfg.max_points, dataset_cfg.fix_sample, + # (dataset_cfg.image_size, dataset_cfg.image_size), + # split=s, sample_ratio_hum=dataset_cfg.sample_ratio_hum, + # normalize_type=dataset_cfg.normalize_type) for p, s in zip([train_paths, val_paths, val_paths], + # ['train', 'val', 'test'])] + # dataloader_train = DataLoader(datasets[0], batch_size=cfg.dataloader.batch_size, + # collate_fn=collate_batched_meshes, + # num_workers=cfg.dataloader.num_workers, shuffle=True) + # dataloader_val = DataLoader(datasets[1], batch_size=cfg.dataloader.batch_size, + # collate_fn=collate_batched_meshes, + # num_workers=cfg.dataloader.num_workers, shuffle=False) + # dataloader_vis = DataLoader(datasets[2], batch_size=cfg.dataloader.batch_size, + # collate_fn=collate_batched_meshes, + # num_workers=cfg.dataloader.num_workers, shuffle=False) + elif cfg.dataset.type in ['shape']: + from .shape_dataset import ShapeDataset + from .behave_paths import DataPaths + from configs.structured import ShapeDatasetConfig + dataset_cfg: ShapeDatasetConfig = cfg.dataset + + train_paths, _ = DataPaths.load_splits(dataset_cfg.split_file, dataset_cfg.behave_dir) + val_paths = train_paths # same as training, this is for overfitting + # split validation paths to only consider the selected batches + bs = cfg.dataloader.batch_size + num_batches_total = int(np.ceil(len(val_paths) / cfg.dataloader.batch_size)) + end_idx = cfg.run.batch_end if cfg.run.batch_end is not None else num_batches_total + # print(cfg.run.batch_end, cfg.run.batch_start, end_idx) + val_paths = val_paths[cfg.run.batch_start * bs:end_idx * bs] + + dataset_train = ShapeDataset(train_paths, dataset_cfg.max_points, dataset_cfg.fix_sample, + (dataset_cfg.image_size, dataset_cfg.image_size), + split='train', ) + dataset_val = ShapeDataset(val_paths, dataset_cfg.max_points, dataset_cfg.fix_sample, + (dataset_cfg.image_size, dataset_cfg.image_size), + split='train', ) + dataloader_train = DataLoader(dataset_train, batch_size=cfg.dataloader.batch_size, + collate_fn=collate_batched_meshes, + num_workers=cfg.dataloader.num_workers, shuffle=True) + shuffle = cfg.run.job == 'train' + dataloader_val = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size, + collate_fn=collate_batched_meshes, + num_workers=cfg.dataloader.num_workers, shuffle=shuffle) + dataloader_vis = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size, + collate_fn=collate_batched_meshes, + num_workers=cfg.dataloader.num_workers, shuffle=shuffle) + else: + raise NotImplementedError(cfg.dataset.type) + + return 
dataloader_train, dataloader_val, dataloader_vis diff --git a/dataset/base_data.py b/dataset/base_data.py new file mode 100644 index 0000000000000000000000000000000000000000..43d250afd08963e5d3c88b456c5e673c4487f7dc --- /dev/null +++ b/dataset/base_data.py @@ -0,0 +1,110 @@ +from os import path as osp + +import cv2 +import numpy as np +from torch.utils.data import Dataset + +from dataset.img_utils import masks2bbox, resize, crop + + +class BaseDataset(Dataset): + def __init__(self, data_paths, input_size=(224, 224)): + self.data_paths = data_paths # RGB image files + self.input_size = input_size + opencv2py3d = np.eye(4) + opencv2py3d[0, 0] = opencv2py3d[1, 1] = -1 + self.opencv2py3d = opencv2py3d + + def __len__(self): + return len(self.data_paths) + + def load_masks(self, rgb_file): + person_mask_file = rgb_file.replace('.color.jpg', ".person_mask.png") + if not osp.isfile(person_mask_file): + person_mask_file = rgb_file.replace('.color.jpg', ".person_mask.jpg") + obj_mask_file = None + for pat in [".obj_rend_mask.png", ".obj_rend_mask.jpg", ".obj_mask.png", ".obj_mask.jpg", ".object_rend.png"]: + obj_mask_file = rgb_file.replace('.color.jpg', pat) + if osp.isfile(obj_mask_file): + break + person_mask = cv2.imread(person_mask_file, cv2.IMREAD_GRAYSCALE) + obj_mask = cv2.imread(obj_mask_file, cv2.IMREAD_GRAYSCALE) + + return person_mask, obj_mask + + def get_crop_params(self, mask_hum, mask_obj, bbox_exp=1.0): + "compute bounding box based on masks" + bmin, bmax = masks2bbox([mask_hum, mask_obj]) + crop_center = (bmin + bmax) // 2 + # crop_size = np.max(bmax - bmin) + crop_size = int(np.max(bmax - bmin) * bbox_exp) + if crop_size % 2 == 1: + crop_size += 1 # make sure it is an even number + return bmax, bmin, crop_center, crop_size + + def is_behave_dataset(self, image_width): + assert image_width in [2048, 1920, 1024, 960], f'unknwon image width {image_width}!' + if image_width in [2048, 1024]: + is_behave = True + else: + is_behave = False + return is_behave + + def compute_K_roi(self, bbox_square, + image_width=2048, + image_height=1536, + fx=979.7844, fy=979.840, + cx=1018.952, cy=779.486): + "return results in ndc coordinate, this is correct!!!" + x, y, b, w = bbox_square + assert b == w + is_behave = self.is_behave_dataset(image_width) + + if is_behave: + assert image_height / image_width == 0.75, f"invalid image aspect ratio: width={image_width}, height={image_height}" + # the image might be rendered at different size + ratio = image_width/2048. + fx, fy = 979.7844*ratio, 979.840*ratio + cx, cy = 1018.952*ratio, 779.486*ratio + else: + assert image_height / image_width == 9/16, f"invalid image aspect ratio: width={image_width}, height={image_height}" + # intercap camera + ratio = image_width/1920 + fx, fy = 918.457763671875*ratio, 918.4373779296875*ratio + cx, cy = 956.9661865234375*ratio, 555.944580078125*ratio + + cx, cy = cx - x, cy - y + scale = b/2. + # in ndc + cx_ = (scale - cx)/scale + cy_ = (scale - cy)/scale + fx_ = fx/scale + fy_ = fy/scale + + K_roi = np.array([ + [fx_, 0, cx_, 0], + [0., fy_, cy_, 0, ], + [0, 0, 0, 1.], + [0, 0, 1, 0] + ]) + return K_roi + + def crop_full_image(self, mask_hum, mask_obj, rgb_full, crop_masks, bbox_exp=1.0): + """ + crop the image based on the given masks + :param mask_hum: + :param mask_obj: + :param rgb_full: + :param crop_masks: a list of masks used to do the crop + :return: Kroi, cropped human, object mask and RGB images (background masked out). 
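+        Note (editor's addition, not in the original diff): bbox_exp expands the tight bounding box computed from crop_masks before cropping, e.g. 1.0 keeps the tight box while 1.5 crops a 50% larger region.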
+ """ + bmax, bmin, crop_center, crop_size = self.get_crop_params(*crop_masks, bbox_exp) + rgb = resize(crop(rgb_full, crop_center, crop_size), self.input_size) / 255. + person_mask = resize(crop(mask_hum, crop_center, crop_size), self.input_size) / 255. + obj_mask = resize(crop(mask_obj, crop_center, crop_size), self.input_size) / 255. + xywh = np.concatenate([crop_center - crop_size // 2, np.array([crop_size, crop_size])]) + Kroi = self.compute_K_roi(xywh, rgb_full.shape[1], rgb_full.shape[0]) + # mask bkg out + mask_comb = (person_mask > 0.5) | (obj_mask > 0.5) + rgb = rgb * np.expand_dims(mask_comb, -1) + return Kroi, obj_mask, person_mask, rgb diff --git a/dataset/behave_paths.py b/dataset/behave_paths.py new file mode 100644 index 0000000000000000000000000000000000000000..26304a2c937f391308de1363d9d13cd90c73cf6d --- /dev/null +++ b/dataset/behave_paths.py @@ -0,0 +1,228 @@ +import glob +import os, re +import pickle as pkl +from os.path import join, basename, dirname, isfile +import os.path as osp + +import cv2, json +import numpy as np + +# PROCESSED_PATH = paths['PROCESSED_PATH'] +BEHAVE_PATH = "/BS/xxie-5/static00/behave_release/sequences/" +RECON_PATH = "/BS/xxie-5/static00/behave-train" + +class DataPaths: + """ + class to handle path operations based on BEHAVE dataset structure + """ + def __init__(self): + pass + + @staticmethod + def load_splits(split_file, dataset_path=None): + assert os.path.exists(dataset_path), f'the given dataset path {dataset_path} does not exist, please check if your training data are placed over there!' + train, val = DataPaths.get_train_test_from_pkl(split_file) + return train, val + # print(train[:5], val[:5]) + if isinstance(train[0], list): + # video data + train_full = [[join(dataset_path, seq[x]) for x in range(len(seq))] for seq in train] + val_full = [[join(dataset_path, seq[x]) for x in range(len(seq))] for seq in val] + else: + train_full = [join(dataset_path, x) for x in train] # full path to the training data + val_full = [join(dataset_path, x) for x in val] # full path to the validation data files + # print(train_full[:5], val_full[:5]) + return train_full, val_full + + @staticmethod + def load_splits_online(split_file, dataset_path=BEHAVE_PATH): + "load rgb file, smpl and object mesh paths" + keys = ['rgb', 'smpl', 'obj'] + types = ['train', 'val'] + splits = {} + data = pkl.load(open(split_file, 'rb')) + for type in types: + for key in keys: + k = f'{type}_{key}' + splits[k] = [join(dataset_path, x) for x in data[k]] + return splits + + @staticmethod + def get_train_test_from_pkl(pkl_file): + data = pkl.load(open(pkl_file, 'rb')) + return data['train'], data['test'] + + @staticmethod + def get_image_paths_seq(seq, tid=1, check_occlusion=False, pat='t*.000'): + """ + find all image paths in one sequence + :param seq: path to one behave sequence + :param tid: test on images from which camera + :param check_occlusion: whether to load full object mask and check occlusion ratio + :return: a list of paths to test image files + """ + image_files = sorted(glob.glob(seq + f"/{pat}/k{tid}.color.jpg")) + # print(image_files, seq + f"/{pat}/k{tid}.color.jpg") + if not check_occlusion: + return image_files + # check object occlusion ratio + valid_files = [] + count = 0 + for img_file in image_files: + mask_file = img_file.replace('.color.jpg', '.obj_rend_mask.png') + if not os.path.isfile(mask_file): + mask_file = img_file.replace('.color.jpg', '.obj_rend_mask.jpg') + full_mask_file = img_file.replace('.color.jpg', '.obj_rend_full.png') + if not 
os.path.isfile(full_mask_file): + full_mask_file = img_file.replace('.color.jpg', '.obj_rend_full.jpg') + if not isfile(mask_file) or not isfile(full_mask_file): + continue + + mask = np.sum(cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE) > 127) + mask_full = np.sum(cv2.imread(full_mask_file, cv2.IMREAD_GRAYSCALE) > 127) + if mask_full == 0: + count += 1 + continue + + ratio = mask / mask_full + if ratio > 0.3: + valid_files.append(img_file) + else: + count += 1 + print(f'{mask_file} occluded by {1 - ratio}!') + return valid_files + + @staticmethod + def get_kinect_id(rgb_file): + "extract kinect id from the rgb file" + filename = osp.basename(rgb_file) + try: + kid = int(filename.split('.')[0][1]) + assert kid in [0, 1, 2, 3, 4, 5], f'found invalid kinect id {kid} for file {rgb_file}' + return kid + except Exception as e: + print(rgb_file) + raise ValueError() + + @staticmethod + def get_seq_date(rgb_file): + "date for the sequence" + seq_name = str(rgb_file).split(os.sep)[-3] + date = seq_name.split('_')[0] + assert date in ['Date01', 'Date02', 'Date03', 'Date04', 'Date05', 'Date06', 'Date07', + "ICapS01", "ICapS02", "ICapS03", "Date08", "Date09"], f"invalid date for {rgb_file}" + return date + + @staticmethod + def rgb2obj_path(rgb_file:str, save_name='fit01-smooth'): + "convert an rgb file to a obj mesh file" + ss = rgb_file.split(os.sep) + seq_name = ss[-3] + obj_name = seq_name.split('_')[2] + real_name = obj_name + if 'chair' in obj_name: + real_name = 'chair' + if 'ball' in obj_name: + real_name = 'sports ball' + + frame_folder = osp.dirname(rgb_file) + mesh_file = osp.join(frame_folder, real_name, save_name, f'{real_name}_fit.ply') + + if not osp.isfile(mesh_file): + # synthetic data + mesh_file = osp.join(frame_folder, obj_name, save_name, f'{obj_name}_fit.ply') + return mesh_file + + @staticmethod + def rgb2smpl_path(rgb_file:str, save_name='fit03'): + frame_folder = osp.dirname(rgb_file) + real_name = 'person' + mesh_file = osp.join(frame_folder, real_name, save_name, f'{real_name}_fit.ply') + return mesh_file + + @staticmethod + def rgb2seq_frame(rgb_file:str): + "rgb file to seq_name, frame time" + ss = rgb_file.split(os.sep) + return ss[-3], ss[-2] + + @staticmethod + def rgb2recon_folder(rgb_file, save_name, recon_path): + "convert rgb file to the subfolder" + dataset_path = osp.dirname(osp.dirname(osp.dirname(rgb_file))) + recon_folder = osp.join(osp.dirname(rgb_file.replace(dataset_path, recon_path)), save_name) + return recon_folder + + @staticmethod + def get_seq_name(rgb_file): + return osp.basename(osp.dirname(osp.dirname(rgb_file))) + + @staticmethod + def rgb2template_path(rgb_file): + "return the path to the object template" + from recon.opt_utils import get_template_path + # seq_name = DataPaths.get_seq_name(rgb_file) + # obj_name = seq_name.split('_')[2] + obj_name = DataPaths.rgb2object_name(rgb_file) + path = get_template_path(BEHAVE_PATH+"/../objects", obj_name) + return path + + @staticmethod + def rgb2object_name(rgb_file): + seq_name = DataPaths.get_seq_name(rgb_file) + obj_name = seq_name.split('_')[2] + return obj_name + + @staticmethod + def rgb2recon_frame(rgb_file, recon_path=RECON_PATH): + "return the frame folder in recon path" + ss = rgb_file.split(os.sep) + seq_name, frame = ss[-3], ss[-2] + return osp.join(recon_path, seq_name, frame) + + @staticmethod + def rgb2gender(rgb_file): + "find the gender of this image" + seq_name = str(rgb_file).split(os.sep)[-3] + sub = seq_name.split('_')[1] + return _sub_gender[sub] + + @staticmethod + def 
get_dataset_root(rgb_file): + "return the root path to all sequences" + from pathlib import Path + path = Path(rgb_file) + return str(path.parents[2]) + + @staticmethod + def seqname2gender(seq_name:str): + sub = seq_name.split('_')[1] + return _sub_gender[sub] + +ICAP_PATH = "/BS/xxie-6/static00/InterCap" # assume same root folder +date_seqs = { + "Date01": BEHAVE_PATH + "/Date01_Sub01_backpack_back", + "Date02": BEHAVE_PATH + "/Date02_Sub02_backpack_back", + "Date03": BEHAVE_PATH + "/Date03_Sub03_backpack_back", + "Date04": BEHAVE_PATH + "/Date04_Sub05_backpack", + "Date05": BEHAVE_PATH + "/Date05_Sub05_backpack", + "Date06": BEHAVE_PATH + "/Date06_Sub07_backpack_back", + "Date07": BEHAVE_PATH + "/Date07_Sub04_backpack_back", + # "Date08": "/BS/xxie-6/static00/synthesize/Date08_Subxx_chairwood_synzv2-02", + "Date08": "/BS/xxie-6/static00/synz-backup/Date08_Subxx_chairwood_synzv2-02", + "Date09": "/BS/xxie-6/static00/synthesize/Date09_Subxx_obj01_icap", # InterCap sequence synz + "ICapS01": ICAP_PATH + "/ICapS01_sub01_obj01_Seg_0", + "ICapS02": ICAP_PATH + "/ICapS02_sub01_obj08_Seg_0", + "ICapS03": ICAP_PATH + "/ICapS03_sub07_obj05_Seg_0", +} + +_sub_gender = { +"Sub01": 'male', +"Sub02": 'male', +"Sub03": 'male', +"Sub04": 'male', +"Sub05": 'male', +"Sub06": 'female', +"Sub07": 'female', +"Sub08": 'female', +} \ No newline at end of file diff --git a/dataset/demo_dataset.py b/dataset/demo_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8407baf146b6925fe34d36b5ff60dec97d30d7a6 --- /dev/null +++ b/dataset/demo_dataset.py @@ -0,0 +1,198 @@ +import os +import numpy as np +import cv2 +import torch + +from .base_data import BaseDataset +from .behave_paths import DataPaths +from .img_utils import compute_translation, masks2bbox, crop + + +def padTo_4x3(rgb, person_mask, obj_mask, aspect_ratio=0.75): + """ + pad images to have 4:3 aspect ratio + :param rgb: (H, W, 3) + :param person_mask: + :param obj_mask: + :return: all images at the given aspect ratio + """ + h, w = rgb.shape[:2] + if w > h * 1/aspect_ratio: + # pad top + h_4x3 = int(w * aspect_ratio) + pad_top = h_4x3 - h + rgb_pad = np.pad(rgb, ((pad_top, 0), (0, 0), (0, 0))) + person_mask = np.pad(person_mask, ((pad_top, 0), (0, 0))) if person_mask is not None else None + obj_mask = np.pad(obj_mask, ((pad_top, 0), (0, 0))) if obj_mask is not None else None + else: + # pad two side + w_new = np.lcm.reduce([h * 2, 16]) # least common multiplier + h_4x3 = int(w_new * aspect_ratio) + pad_top = h_4x3 - h + pad_left = (w_new - w) // 2 + pad_right = w_new - w - pad_left + rgb_pad = np.pad(rgb, ((pad_top, 0), (pad_left, pad_right), (0, 0))) + obj_mask = np.pad(obj_mask, ((pad_top, 0), (pad_left, pad_right))) if obj_mask is not None else None + person_mask = np.pad(person_mask, ((pad_top, 0), (pad_left, pad_right))) if person_mask is not None else None + return rgb_pad, obj_mask, person_mask + + +def recrop_input(rgb, person_mask, obj_mask, dataset_name='behave'): + "recrop input images" + exp_ratio = 1.42 + if dataset_name == 'behave': + mean_center = np.array([1008, 995]) # mean RGB image crop center + behave_size = (2048, 1536) + new_size = (int(750 * exp_ratio), int(exp_ratio * 750)) + else: + mean_center = np.array([904, 668]) # mean RGB image crop center for bottle sequences of ICAP + behave_size = (1920, 1080) + new_size = (int(593.925 * exp_ratio), int(exp_ratio * 593.925)) # mean width of bottle sequences + aspect_ratio = behave_size[1] / behave_size[0] + pad_top = mean_center[1] - new_size[0] // 2 + 
pad_bottom = behave_size[1] - (mean_center[1] + new_size[0] // 2) + pad_left = mean_center[0] - new_size[0] // 2 + pad_right = behave_size[0] - (mean_center[0] + new_size[0] // 2) + + # First resize to the same aspect ratio + if rgb.shape[0] / rgb.shape[1] != aspect_ratio: + rgb, obj_mask, person_mask = padTo_4x3(rgb, person_mask, obj_mask, aspect_ratio) + + # Resize to the same size as behave image, to have a comparable pixel size + rgb = cv2.resize(rgb, behave_size) + mask_ps = cv2.resize(person_mask, behave_size) + mask_obj = cv2.resize(obj_mask, behave_size) + + # Crop and resize the human + object patch + bmin, bmax = masks2bbox([mask_ps, mask_obj]) + center = (bmin + bmax) // 2 + crop_size = int(np.max(bmax - bmin) * exp_ratio) # larger crop to have background + img_crop = cv2.resize(crop(rgb, center, crop_size), new_size) + mask_ps = cv2.resize(crop(mask_ps, center, crop_size), new_size) + mask_obj = cv2.resize(crop(mask_obj, center, crop_size), new_size) + + # Pad back to have same shape as behave image + img_full = np.pad(img_crop, [[pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]) + mask_ps_full = np.pad(mask_ps, [[pad_top, pad_bottom], [pad_left, pad_right]]) + mask_obj_full = np.pad(mask_obj, [[pad_top, pad_bottom], [pad_left, pad_right]]) + + # Make sure the image shape is the same + if img_full.shape[:2] != behave_size[::-1]: + img_full = cv2.resize(img_full, behave_size) + mask_ps_full = cv2.resize(mask_ps_full, behave_size) + mask_obj_full = cv2.resize(mask_obj_full, behave_size) + return img_full, mask_ps_full, mask_obj_full + + +class DemoDataset(BaseDataset): + def __init__(self, data_paths, input_size=(224, 224), + std_coverage=3.5, # used to estimate camera translation + ): + super().__init__(data_paths, input_size) + self.std_coverage = std_coverage + + def __len__(self): + return len(self.data_paths) + + def __getitem__(self, idx): + rgb_file = self.data_paths[idx] + mask_hum, mask_obj = self.load_masks(rgb_file) + rgb_full = cv2.imread(rgb_file)[:, :, ::-1] + + return self.image2dict(mask_hum, mask_obj, rgb_full, rgb_file) + + def image2dict(self, mask_hum, mask_obj, rgb_full, rgb_file=None): + "do all the necessary preprocessing for images" + if rgb_full.shape[:2] != mask_obj.shape[:2]: + raise ValueError(f"The given object mask shape {mask_obj.shape[:2]} does not match the RGB image shape {rgb_full.shape[:2]}") + if rgb_full.shape[:2] != mask_hum.shape[:2]: + raise ValueError(f"The given human mask shape {mask_hum.shape[:2]} does not match the RGB image shape {rgb_full.shape[:2]}") + + if rgb_full.shape[:2] not in [(1080, 1920), (1536, 2048)]: + # crop and resize the image to behave image size + print(f"Recropping the input image and masks for {rgb_file}") + rgb_full, mask_hum, mask_obj = recrop_input(rgb_full, mask_hum, mask_obj) + color_h, color_w = rgb_full.shape[:2] + # Input to the first stage model: human + object crop + Kroi, objmask_fullcrop, psmask_fullcrop, rgb_fullcrop = self.crop_full_image(mask_hum.copy(), + mask_obj.copy(), + rgb_full.copy(), + [mask_hum, mask_obj], + 1.00) + # Input to the second stage model: human and object crops + Kroi_h, masko_hum, maskh_hum, rgb_hum = self.crop_full_image(mask_hum.copy(), + mask_obj.copy(), + rgb_full.copy(), + [mask_hum, mask_hum], 1.05) + Kroi_o, masko_obj, maskh_obj, rgb_obj = self.crop_full_image(mask_hum.copy(), + mask_obj.copy(), + rgb_full.copy(), + [mask_obj, mask_obj], 1.5) + # Estimate camera translation + cent_transform = np.eye(4) # the transform applied to the mesh that moves it back to kinect 
camera frame + bmin_ho, bmax_ho = masks2bbox([mask_hum, mask_obj]) + crop_size_ho = int(np.max(bmax_ho - bmin_ho) * 1.0) + if crop_size_ho % 2 == 1: + crop_size_ho += 1 # make sure it is an even number + is_behave = self.is_behave_dataset(rgb_full.shape[1]) + if rgb_full.shape[1] not in [2048, 1920]: + raise ValueError('the image is not normalized to BEHAVE or ICAP size!') + indices = np.indices(rgb_full.shape[:2]) + if np.sum(mask_obj > 127) < 5: + raise ValueError(f'not enough object mask found for {rgb_file}') + pts_h = np.stack([indices[1][mask_hum > 127], indices[0][mask_hum > 127]], -1) + pts_o = np.stack([indices[1][mask_obj > 127], indices[0][mask_obj > 127]], -1) + proj_cent_est = (np.mean(pts_h, 0) + np.mean(pts_o, 0)) / 2. # heuristic to obtain 2d projection center + transl_estimate = compute_translation(proj_cent_est, crop_size_ho, is_behave, self.std_coverage) + cent_transform[:3, 3] = transl_estimate / 7.0 + radius = 0.5 # don't do normalization anymore + cent = transl_estimate / 7.0 + comb = np.matmul(self.opencv2py3d, cent_transform) + R = torch.from_numpy(comb[:3, :3]).float() + T = torch.from_numpy(comb[:3, 3]).float() / (radius * 2) + data_dict = { + "R": R, + "T": T, + "K": torch.from_numpy(Kroi).float(), + "T_ho": torch.from_numpy(cent).float(), # translation for H+O + "image_path": rgb_file, + "image_size_hw": torch.tensor(self.input_size), + "images": torch.from_numpy(rgb_fullcrop).float().permute(2, 0, 1), + "masks": torch.from_numpy(np.stack([psmask_fullcrop, objmask_fullcrop], 0)).float(), + 'orig_image_size': torch.tensor([color_h, color_w]), + + # Human input to stage 2 + "images_hum": torch.from_numpy(rgb_hum).float().permute(2, 0, 1), + "masks_hum": torch.from_numpy(np.stack([maskh_hum, masko_hum], 0)).float(), + "K_hum": torch.from_numpy(Kroi_h).float(), + + # Object input to stage 2 + "images_obj": torch.from_numpy(rgb_obj).float().permute(2, 0, 1), + "masks_obj": torch.from_numpy(np.stack([maskh_obj, masko_obj], 0)).float(), + "K_obj": torch.from_numpy(Kroi_o).float(), + + # some normalization parameters + "gt_trans": cent, + 'radius': radius, + "estimated_trans": transl_estimate, + } + return data_dict + + def image2batch(self, rgb, mask_hum, mask_obj): + """ + given input image, convert it into a batch object ready for model inference + :param rgb: (h, w, 3), np array + :param mask_hum: (h, w, 3), np array + :param mask_obj: (h, w, 3), np array + :return: + """ + mask_hum = np.mean(mask_hum, -1) + mask_obj = np.mean(mask_obj, -1) + + data_dict = self.image2dict(mask_hum, mask_obj, rgb, 'input image') + # convert dict to list + new_dict = {k:[v] for k, v in data_dict.items()} + + return new_dict + + diff --git a/dataset/img_utils.py b/dataset/img_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eb54b295e620e2487bad6dccc0aa5a9680d3b6b2 --- /dev/null +++ b/dataset/img_utils.py @@ -0,0 +1,149 @@ +""" +common functions for image operations +""" + +import cv2 +import numpy as np + + +def crop(img, center, crop_size): + """ + crop image around the given center, pad zeros for borders + :param img: + :param center: np array + :param crop_size: np array or a float size of the resulting crop + :return: a square crop around the center + """ + assert isinstance(img, np.ndarray) + h, w = img.shape[:2] + topleft = np.round(center - crop_size / 2).astype(int) + bottom_right = np.round(center + crop_size / 2).astype(int) + + x1 = max(0, topleft[0]) + y1 = max(0, topleft[1]) + x2 = min(w - 1, bottom_right[0]) + y2 = min(h - 1, bottom_right[1]) 
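+    # Editor's note (not in the original diff): the clamped window above can be smaller than the requested crop when `center` lies near an image border; the paddings p1..p4 computed below re-add the missing rows/columns as zeros, so the result is always a roughly crop_size x crop_size patch centered on `center`.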
+ cropped = img[y1:y2, x1:x2] + + p1 = max(0, -topleft[0]) # padding in x, top + p2 = max(0, -topleft[1]) # padding in y, top + p3 = max(0, bottom_right[0] - w + 1) # padding in x, bottom + p4 = max(0, bottom_right[1] - h + 1) # padding in y, bottom + + dim = len(img.shape) + if dim == 3: + padded = np.pad(cropped, [[p2, p4], [p1, p3], [0, 0]]) + elif dim == 2: + padded = np.pad(cropped, [[p2, p4], [p1, p3]]) + else: + raise NotImplemented + return padded + + +def resize(img, img_size, mode=cv2.INTER_LINEAR): + """ + resize image to the input + :param img: + :param img_size: (width, height) of the target image size + :param mode: + :return: + """ + h, w = img.shape[:2] + load_ratio = 1.0 * w / h + netin_ratio = 1.0 * img_size[0] / img_size[1] + assert load_ratio == netin_ratio, "image aspect ration not matching, given image: {}, net input: {}".format( + img.shape, img_size) + resized = cv2.resize(img, img_size, interpolation=mode) + return resized + + +def masks2bbox(masks, threshold=127): + """ + + :param masks: + :param threshold: + :return: bounding box corner coordinate + """ + mask_comb = np.zeros_like(masks[0], dtype=bool) + for m in masks: + mask_comb = mask_comb | (m > threshold) + + yid, xid = np.where(mask_comb) + bmin = np.array([xid.min(), yid.min()]) + bmax = np.array([xid.max(), yid.max()]) + return bmin, bmax + + +def compute_translation(crop_center, crop_size, is_behave=True, std_coverage=3.5): + """ + solve for an optimal translation that project gaussian in origin to the crop + Parameters + ---------- + crop_center: (x, y) of the crop center + crop_size: float, the size of the square crop + std_coverage: which edge point should be projected back to the edge of the 2d crop + + Returns + ------- + the estimated translation + + """ + x0, y0 = crop_center + x1, y1 = x0 + crop_size/2, y0 + x2, y2 = x0 - crop_size/2, y0 + x3, y3 = x0, y0 + crop_size/2. + # predefined kinect intrinsics + if is_behave: + fx = 979.7844 + fy = 979.840 + cx = 1018.952 + cy = 779.486 + else: + # intercap camera + fx, fy = 918.457763671875, 918.4373779296875 + cx, cy = 956.9661865234375, 555.944580078125 + + # construct the matrix + # A = np.array([ + # [fx, 0, cx-x0, cx-x0, 0, 0], + # [0, fy, cy-y0, cy-y0, 0, 0], + # [fx, 0, cx-x1, 0, cx-x1, 0], + # [0, fy, cy-y1, 0, cy-y1, 0], + # [fx, 0, cx-x2, 0, 0, cx-x2], + # [0, fy, cy-y2, 0, 0, cy-y2] + # ]) # this matrix is low-rank because columns are linearly dependent: col3 - col4 = col5 + col6 + # # find linearly dependent rows + # lambdas, V = np.linalg.eig(A) + # # print() + # # The linearly dependent row vectors + # print(lambdas == 0, np.linalg.det(A), A[lambdas == 0, :]) # some have determinant zero, some don't?? + # print(np.linalg.inv(A)) + + # A = np.array([ + # [fx, 0, cx - x0, cx - x0, 0, 0], + # [0, fy, cy - y0, cy - y0, 0, 0], + # [fx, 0, cx - x1, 0, cx - x1, 0], + # [0, fy, cy - y1, 0, cy - y1, 0], + # [fx, 0, cx - x3, 0, 0, cx - x3], + # [0, fy, cy - y3, 0, 0, cy - y3] + # ]) # this is also low rank! + # b = np.array([0, 0, -3*fx, 0, 0, -3*fy]).reshape((-1, 1)) + # print("rank of the coefficient matrix:", np.linalg.matrix_rank(A)) # rank is 5! underconstrained matrix! 
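+    # Additional notes on the 4x4 system assembled below: the four unknowns are tx, ty and two
+    # depth terms, and the first three components of the solution are returned as the translation.
+    # A[2, 2] is written as fx - x1; judging from the pattern of the other entries it was
+    # presumably intended to be cx - x1. Under that reading, for a crop centered on the principal
+    # point the solution reduces to tx = ty = 0 and tz = 2 * std_coverage * fx / crop_size,
+    # i.e. larger crops (subject filling more of the image) map to smaller estimated depths.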
+ # x = np.matmul(np.linalg.inv(A), b) + + # fix z0 as 0, then A is a full-rank matrix + # first two equations: origin (0, 0, 0) is projected to the crop center + # last two equations: edge point (3.5, 0, z) is projected to the edge of crop + A = np.array([ + [fx, 0, cx-x0, cx-x0], + [0, fy, cy-y0, cy-y0], + [fx, 0, fx-x1, 0], + [0, fy, cy-y1, 0] + ]) + # b = np.array([0, 0, -3.5*fx, 0]).reshape((-1, 1)) # 3.5->half of 7.0 + b = np.array([0, 0, -std_coverage * fx, 0]).reshape((-1, 1)) # 3.5->half of 7.0 + x = np.matmul(np.linalg.inv(A), b) # use 4 or 5 does not really matter, same results + + # A is always a full-rank matrix + + return x.flatten()[:3] diff --git a/demo.py b/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..9170002b248a364b696aaef6e2a2b3f0b667bda4 --- /dev/null +++ b/demo.py @@ -0,0 +1,280 @@ +""" +Demo for template-free reconstruction + +python demo.py model=ho-attn run.image_path=/BS/xxie-2/work/HDM/outputs/000000017450/k1.color.jpg run.job=sample model.predict_binary=True dataset.std_coverage=3.0 +""" +import pickle as pkl +import sys, os +import os.path as osp +from typing import Iterable, Optional + +import cv2 +from accelerate import Accelerator +from tqdm import tqdm +from glob import glob + +sys.path.append(os.getcwd()) +import hydra +import torch +import numpy as np +import imageio +from torch.utils.data import DataLoader +from pytorch3d.datasets import R2N2, collate_batched_meshes +from pytorch3d.structures import Pointclouds +from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform +from pytorch3d.io import IO +import torchvision.transforms.functional as TVF +from huggingface_hub import hf_hub_download + +import training_utils +from configs.structured import ProjectConfig +from dataset.demo_dataset import DemoDataset +from model import CrossAttenHODiffusionModel, ConditionalPCDiffusionSeparateSegm +from render.pyt3d_wrapper import PcloudRenderer + + +class DemoRunner: + def __init__(self, cfg: ProjectConfig): + cfg.model.model_name, cfg.model.predict_binary = 'pc2-diff-ho-sepsegm', True + model_stage1 = ConditionalPCDiffusionSeparateSegm(**cfg.model) + cfg.model.model_name, cfg.model.predict_binary = 'diff-ho-attn', False # stage 2 does not predict segmentation + model_stage2 = CrossAttenHODiffusionModel(**cfg.model) + + # Load from checkpoint + # ckpt_file1 = os.path.join(cfg.run.code_dir_abs, f'outputs/{cfg.run.stage1_name}/single/checkpoint-latest.pth') + # self.load_checkpoint(ckpt_file1, model_stage1) + # ckpt_file2 = os.path.join(cfg.run.code_dir_abs, f'outputs/{cfg.run.stage2_name}/single/checkpoint-latest.pth') + # self.load_checkpoint(ckpt_file2, model_stage2) + # Load ckpt from hf + ckpt_file1 = hf_hub_download("xiexh20/HDM-models", f'{cfg.run.stage1_name}.pth') + self.load_checkpoint(ckpt_file1, model_stage1) + ckpt_file2 = hf_hub_download("xiexh20/HDM-models", f'{cfg.run.stage2_name}.pth') + self.load_checkpoint(ckpt_file2, model_stage2) + + self.model_stage1, self.model_stage2 = model_stage1, model_stage2 + self.model_stage1.eval() + self.model_stage2.eval() + self.model_stage1.to('cuda') + self.model_stage2.to('cuda') + + self.cfg = cfg + self.io_pc = IO() + + # For visualization + self.renderer = PcloudRenderer(image_size=cfg.dataset.image_size, radius=0.0075) + self.rend_size = cfg.dataset.image_size + self.device = 'cuda' + + def load_checkpoint(self, ckpt_file1, model_stage1): + checkpoint = torch.load(ckpt_file1, map_location='cpu') + state_dict, key = checkpoint['model'], 'model' + if 
any(k.startswith('module.') for k in state_dict.keys()): + state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} + print('Removed "module." from checkpoint state dict') + missing_keys, unexpected_keys = model_stage1.load_state_dict(state_dict, strict=False) + print(f'Loaded model checkpoint {key} from {ckpt_file1}') + if len(missing_keys): + print(f' - Missing_keys: {missing_keys}') + if len(unexpected_keys): + print(f' - Unexpected_keys: {unexpected_keys}') + + @torch.no_grad() + def run(self): + "simply run the demo on given images, and save the results" + # Set random seed + training_utils.set_seed(self.cfg.run.seed) + + outdir = osp.join(self.cfg.run.code_dir_abs, 'outputs/demo') + os.makedirs(outdir, exist_ok=True) + cfg = self.cfg + + # Init data + image_files = sorted(glob(cfg.run.image_path)) + data = DemoDataset(image_files, + (cfg.dataset.image_size, cfg.dataset.image_size), + cfg.dataset.std_coverage) + dataloader = DataLoader(data, batch_size=cfg.dataloader.batch_size, + collate_fn=collate_batched_meshes, + num_workers=1, shuffle=False) + dataloader = dataloader + progress_bar = tqdm(dataloader) + for batch_idx, batch in enumerate(progress_bar): + progress_bar.set_description(f'Processing batch {batch_idx:4d} / {len(dataloader):4d}') + + out_stage1, out_stage2 = self.forward_batch(batch, cfg) + + bs = len(out_stage1) + camera_full = PerspectiveCameras( + R=torch.stack(batch['R']), + T=torch.stack(batch['T']), + K=torch.stack(batch['K']), + device='cuda', + in_ndc=True) + + # save output + for i in range(bs): + image_path = str(batch['image_path']) + folder, fname = osp.basename(osp.dirname(image_path)), osp.splitext(osp.basename(image_path))[0] + out_i = osp.join(outdir, folder) + os.makedirs(out_i, exist_ok=True) + self.io_pc.save_pointcloud(data=out_stage1[i], + path=osp.join(out_i, f'{fname}_stage1.ply')) + self.io_pc.save_pointcloud(data=out_stage2[i], + path=osp.join(out_i, f'{fname}_stage2.ply')) + TVF.to_pil_image(batch['images'][i]).save(osp.join(out_i, f'{fname}_input.png')) + + # Save metadata as well + metadata = dict(index=i, + camera=camera_full[i], + image_size_hw=batch['image_size_hw'][i], + image_path=batch['image_path'][i]) + torch.save(metadata, osp.join(out_i, f'{fname}_meta.pth')) + + # Visualize + # front_camera = camera_full[i] + pc_comb = Pointclouds([out_stage1[i].points_packed(), out_stage2[i].points_packed()], + features=[out_stage1[i].features_packed(), out_stage2[i].features_packed()]) + video_file = osp.join(out_i, f'{fname}_360view.mp4') + video_writer = imageio.get_writer(video_file, format='FFMPEG', mode='I', fps=1) + + # first render front view + rend_stage1, _ = self.renderer.render(out_stage1[i], camera_full[i], mode='mask') + rend_stage2, _ = self.renderer.render(out_stage2[i], camera_full[i], mode='mask') + comb = np.concatenate([batch['images'][i].permute(1, 2, 0).cpu().numpy(), rend_stage1, rend_stage2], 1) + video_writer.append_data((comb*255).astype(np.uint8)) + + for azim in range(180, 180+360, 30): + R, T = look_at_view_transform(1.7, 0, azim, up=((0, -1, 0),), ) + side_camera = PerspectiveCameras(image_size=((self.rend_size, self.rend_size),), + device=self.device, + R=R.repeat(2, 1, 1), T=T.repeat(2, 1), + focal_length=self.rend_size * 1.5, + principal_point=((self.rend_size / 2., self.rend_size / 2.),), + in_ndc=False) + rend, mask = self.renderer.render(pc_comb, side_camera, mode='mask') + + imgs = [batch['images'][i].permute(1, 2, 0).cpu().numpy()] + imgs.extend([rend[0], rend[1]]) + 
video_writer.append_data((np.concatenate(imgs, 1)*255).astype(np.uint8)) + print(f"Visualization saved to {out_i}") + + @torch.no_grad() + def forward_batch(self, batch, cfg): + """ + forward one batch + :param batch: + :param cfg: + :return: predicted point clouds of stage 1 and 2 + """ + camera_full = PerspectiveCameras( + R=torch.stack(batch['R']), + T=torch.stack(batch['T']), + K=torch.stack(batch['K']), + device='cuda', + in_ndc=True) + out_stage1 = self.model_stage1.forward_sample(num_points=cfg.dataset.max_points, + camera=camera_full, + image_rgb=torch.stack(batch['images']).to('cuda'), + mask=torch.stack(batch['masks']).to('cuda'), + scheduler=cfg.run.diffusion_scheduler, + num_inference_steps=cfg.run.num_inference_steps, + ) + # segment and normalize human/object + bs = len(out_stage1) + pred_hum, pred_obj = [], [] # predicted human/object points + cent_hum_pred, cent_obj_pred = [], [] + radius_hum_pred, radius_obj_pred = [], [] + T_hum, T_obj = [], [] + num_samples = int(cfg.dataset.max_points / 2) + for i in range(bs): + pc: Pointclouds = out_stage1[i] + vc = pc.features_packed().cpu() # (P, 3), human is light blue [0.1, 1.0, 1.0], object light green [0.5, 1.0, 0] + points = pc.points_packed().cpu() # (P, 3) + mask_hum = vc[:, 2] > 0.5 + pc_hum, pc_obj = points[mask_hum], points[~mask_hum] + # Up/Down-sample the points + pc_obj = self.upsample_predicted_pc(num_samples, pc_obj) + pc_hum = self.upsample_predicted_pc(num_samples, pc_hum) + + # Normalize + cent_hum, cent_obj = torch.mean(pc_hum, 0, keepdim=True), torch.mean(pc_obj, 0, keepdim=True) + scale_hum = torch.sqrt(torch.sum((pc_hum - cent_hum) ** 2, -1).max()) + scale_obj = torch.sqrt(torch.sum((pc_obj - cent_obj) ** 2, -1).max()) + pc_hum = (pc_hum - cent_hum) / (2 * scale_hum) + pc_obj = (pc_obj - cent_obj) / (2 * scale_obj) + # Also update camera parameters for separate human + object + T_hum_scaled = (batch['T_ho'][i] + cent_hum.squeeze(0)) / (2 * scale_hum) + T_obj_scaled = (batch['T_ho'][i] + cent_obj.squeeze(0)) / (2 * scale_obj) + + pred_hum.append(pc_hum) + pred_obj.append(pc_obj) + cent_hum_pred.append(cent_hum.squeeze(0)) + cent_obj_pred.append(cent_obj.squeeze(0)) + T_hum.append(T_hum_scaled * torch.tensor([-1, -1, 1])) # apply opencv to pytorch3d transform: flip x and y + T_obj.append(T_obj_scaled * torch.tensor([-1, -1, 1])) + radius_hum_pred.append(scale_hum) + radius_obj_pred.append(scale_obj) + # Pack data into a new batch dict + camera_hum = PerspectiveCameras( + R=torch.stack(batch['R']), + T=torch.stack(T_hum), + K=torch.stack(batch['K_hum']), + device='cuda', + in_ndc=True + ) + camera_obj = PerspectiveCameras( + R=torch.stack(batch['R']), + T=torch.stack(T_obj), + K=torch.stack(batch['K_obj']), # the camera should be human/object specific!!! 
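+            # R is shared between the two branches, while T and K are entity-specific: T comes
+            # from the stage-1 predicted centers/scales computed above, and K_hum / K_obj are the
+            # per-crop intrinsics from the dataset, so stage 2 receives pixel features that are
+            # aligned with each normalized point cloud.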
+ device='cuda', + in_ndc=True + ) + # use pc from predicted + pc_hum = Pointclouds([x.to('cuda') for x in pred_hum]) + pc_obj = Pointclouds([x.to('cuda') for x in pred_obj]) + # use center and radius from predicted + cent_hum = torch.stack(cent_hum_pred, 0).to('cuda') + cent_obj = torch.stack(cent_obj_pred, 0).to('cuda') # B, 3 + radius_hum = torch.stack(radius_hum_pred, 0).to('cuda') # B, 1 + radius_obj = torch.stack(radius_obj_pred, 0).to('cuda') + out_stage2: Pointclouds = self.model_stage2.forward_sample( + num_points=num_samples, + camera=camera_hum, + image_rgb=torch.stack(batch['images_hum'], 0).to('cuda'), + mask=torch.stack(batch['masks_hum'], 0).to('cuda'), + gt_pc=pc_hum, + rgb_obj=torch.stack(batch['images_obj'], 0).to('cuda'), + mask_obj=torch.stack(batch['masks_obj'], 0).to('cuda'), + pc_obj=pc_obj, + camera_obj=camera_obj, + cent_hum=cent_hum, + cent_obj=cent_obj, + radius_hum=radius_hum.unsqueeze(-1), + radius_obj=radius_obj.unsqueeze(-1), + sample_from_interm=True, + noise_step=cfg.run.sample_noise_step) + return out_stage1, out_stage2 + + def upsample_predicted_pc(self, num_samples, pc_obj): + """ + Up/Downsample the points to given number + :param num_samples: the target number + :param pc_obj: (N, 3) + :return: (num_samples, 3) + """ + if len(pc_obj) > num_samples: + ind_obj = np.random.choice(len(pc_obj), num_samples) + else: + ind_obj = np.concatenate([np.arange(len(pc_obj)), np.random.choice(len(pc_obj), num_samples - len(pc_obj))]) + pc_obj = pc_obj.clone()[torch.from_numpy(ind_obj).long().to(pc_obj.device)] + return pc_obj + + +@hydra.main(config_path='configs', config_name='configs', version_base='1.1') +def main(cfg: ProjectConfig): + runner = DemoRunner(cfg) + runner.run() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/diffusion_utils.py b/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5de761c18576109cfcb8fd36243424b19c78862a --- /dev/null +++ b/diffusion_utils.py @@ -0,0 +1,313 @@ +import math +from typing import List, Optional, Sequence, Union + +import imageio +import logging +import numpy as np +import torch +import torch.utils.data +from PIL import Image +from torch.distributions import Normal +from torchvision.transforms.functional import to_pil_image +from torchvision.utils import make_grid +from tqdm import tqdm, trange +from pytorch3d.renderer import ( + AlphaCompositor, + NormWeightedCompositor, + OrthographicCameras, + PointsRasterizationSettings, + PointsRasterizer, + PointsRenderer, + look_at_view_transform) +from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.structures import Pointclouds +from pytorch3d.structures.pointclouds import join_pointclouds_as_batch + + +# Disable unnecessary imageio logging +logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR) + + +def rotation_matrix(axis, theta): + """ + Return the rotation matrix associated with counterclockwise rotation about + the given axis by theta radians. 
+ """ + axis = np.asarray(axis) + axis = axis / np.sqrt(np.dot(axis, axis)) + a = np.cos(theta / 2.0) + b, c, d = -axis * np.sin(theta / 2.0) + aa, bb, cc, dd = a * a, b * b, c * c, d * d + bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d + return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], + [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], + [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) + + +def rotate(vertices, faces): + ''' + vertices: [numpoints, 3] + ''' + M = rotation_matrix([0, 1, 0], np.pi / 2).transpose() + N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose() + K = rotation_matrix([0, 0, 1], np.pi).transpose() + + v, f = vertices[:, [1, 2, 0]].dot(M).dot(N).dot(K), faces[:, [1, 2, 0]] + return v, f + + +def norm(v, f): + v = (v - v.min()) / (v.max() - v.min()) - 0.5 + + return v, f + + +def getGradNorm(net): + pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters())) + gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters())) + return pNorm, gradNorm + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1 and m.weight is not None: + torch.nn.init.xavier_normal_(m.weight) + elif classname.find('BatchNorm') != -1: + m.weight.data.normal_() + m.bias.data.fill_(0) + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + # Assumes data is integers [0, 1] + assert x.shape == means.shape == log_scales.shape + px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales)) + + centered_x = x - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 0.5) + cdf_plus = px0.cdf(plus_in) + min_in = inv_stdv * (centered_x - .5) + cdf_min = px0.cdf(min_in) + log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * 1e-12)) + log_one_minus_cdf_min = torch.log(torch.max(1. 
- cdf_min, torch.ones_like(cdf_min) * 1e-12)) + cdf_delta = cdf_plus - cdf_min + + log_probs = torch.where( + x < 0.001, log_cdf_plus, + torch.where(x > 0.999, log_one_minus_cdf_min, + torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * 1e-12)))) + assert log_probs.shape == x.shape + return log_probs + + +def fig2img(fig): + """Convert a Matplotlib figure to a PIL Image and return it""" + import io + buf = io.BytesIO() + fig.savefig(buf) + buf.seek(0) + img = Image.open(buf) + return img + + +@torch.no_grad() +def visualize_distance_transform( + path_stem: str, + images: torch.Tensor, +) -> str: + output_file_image = f'{path_stem}.png' + if images.shape[3] in [1, 3]: # convert to (B, C, H, W) + images = images.permute(0, 3, 1, 2) + images = images[:, -1:] # (B, 1, H, W) # get only distances (not vectors for now, for simplicity) + image_grid = make_grid(images, nrow=int(math.sqrt(len(images))), pad_value=1, normalize=True) + to_pil_image(image_grid).save(output_file_image) + return output_file_image + + +@torch.no_grad() +def visualize_image( + path_stem: str, + images: torch.Tensor, + mean: Union[torch.Tensor, float] = 0.5, + std: Union[torch.Tensor, float] = 0.5, +) -> str: + output_file_image = f'{path_stem}.png' + if images.shape[3] in [1, 3, 4]: # convert to (B, C, H, W) + images = images.permute(0, 3, 1, 2) + if images.shape[1] in [3, 4]: # normalize (single-channel images are not normalized) + images[:, :3] = images[:, :3] * std + mean # denormalize (color channels only, not alpha channel) + if images.shape[1] == 4: # normalize (single-channel images are not normalized) + image_alpha = images[:, 3:] # (B, 1, H, W) + bg_color = torch.tensor([230, 220, 250], device=images.device).reshape(1, 3, 1, 1) / 255 + images = images[:, :3] * image_alpha + bg_color * (1 - image_alpha) # (B, 3, H, W) + image_grid = make_grid(images, nrow=int(math.sqrt(len(images))), pad_value=1) + to_pil_image(image_grid).save(output_file_image) + return output_file_image + + +def ensure_point_cloud_has_colors(pointcloud: Pointclouds): + if pointcloud.features_padded() is None: + pointcloud = type(pointcloud)(points=pointcloud.points_padded(), + normals=pointcloud.normals_padded(), features=torch.zeros_like(pointcloud.points_padded())) + return pointcloud + + +@torch.no_grad() +def render_pointcloud_batch_pytorch3d( + cameras: CamerasBase, + pointclouds: Pointclouds, + image_size: int = 224, + radius: float = 0.01, + points_per_pixel: int = 10, + background_color: Sequence[float] = (0.78431373, 0.78431373, 0.78431373), + compositor: str = 'norm_weighted' +): + # Define the settings for rasterization and shading. Here we set the output image to be of size + # 512x512. As we are rendering images for visualization purposes only we will set faces_per_pixel=1 + # and blur_radius=0.0. Refer to rasterize_points.py for explanations of these parameters. 
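+    # The 512x512 / faces_per_pixel wording above appears to be inherited from the PyTorch3D
+    # mesh-rendering tutorial; for this point renderer the output resolution is actually the
+    # `image_size` argument (224 by default here), and the splatting is controlled by `radius`
+    # and `points_per_pixel`.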
+ raster_settings = PointsRasterizationSettings( + image_size=image_size, + radius=radius, + points_per_pixel=points_per_pixel, + ) + + # Rasterizer + rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings) + + # Compositor + if compositor == 'alpha': + compositor = AlphaCompositor(background_color=background_color) + elif compositor == 'norm_weighted': + compositor = NormWeightedCompositor(background_color=background_color) + else: + raise ValueError(compositor) + + # Create a points renderer by compositing points using an weighted compositor (3D points are + # weighted according to their distance to a pixel and accumulated using a weighted sum) + renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor) + + # We cannot render a point cloud without colors, so add them if the pointcloud does + # not already have them + pointclouds = ensure_point_cloud_has_colors(pointclouds) + + # Render batch of image + images = renderer(pointclouds) + + return images + + +@torch.no_grad() +def visualize_pointcloud_batch_pytorch3d( + pointclouds: Pointclouds, + output_file_video: Optional[str] = None, + output_file_image: Optional[str] = None, + cameras: Optional[CamerasBase] = None, # if None, we rotate + scale_factor: float = 1.0, + num_frames: int = 1, # note that it takes a while with 30 * batch_size frames + elev: int = 30, +): + """Saves a video and a single image of a point cloud""" + assert 360 % num_frames == 0, 'please select a better number of frames' + + # Sizes + B, N, C, F = *(pointclouds.points_padded().shape), num_frames + device = pointclouds.device + + # If a camera has not been provided, we render from a rotating view around an image + if cameras is None: + + # Create view transforms - R is (F, 3, 3) and T is (F, 3) + R, T = look_at_view_transform(dist=10.0, elev=elev, azim=list(range(0, 360, 360 // F)), degrees=True, device=device) + + # Repeat + R = R.repeat_interleave(B, dim=0) # (F * B, 3, 3) + T = T.repeat_interleave(B, dim=0) # (F * B, 3) + points = pointclouds.points_padded().tile(F, 1, 1) # (F * B, num_points, 3) + colors = (torch.zeros_like(points) if pointclouds.features_padded() is None else + pointclouds.features_padded().tile(F, 1, 1)) # (F * B, num_points, 3) + + # Initialize batch of cameras + cameras = OrthographicCameras(focal_length=(0.25 * scale_factor), device=device, R=R, T=T) + + # Wrap in Pointclouds (with color, even if the original point cloud had no color) + pointclouds = Pointclouds(points=points, features=colors).to(device) + + # Render image + images = render_pointcloud_batch_pytorch3d(cameras, pointclouds) + + # Convert images into grid + image_grids = [] + images_for_grids = images.reshape(F, B, *images.shape[1:]).permute(0, 1, 4, 2, 3) + for image_for_grids in images_for_grids: + image_grid = make_grid(image_for_grids, nrow=int(math.sqrt(B)), pad_value=1) + image_grids.append(image_grid) + image_grids = torch.stack(image_grids, dim=0) + image_grids = image_grids.detach().cpu() + + # Save image + if output_file_image is not None: + to_pil_image(image_grids[0]).save(output_file_image) + + # Save video + if output_file_video: + video = (image_grids * 255).permute(0, 2, 3, 1).to(torch.uint8).numpy() + imageio.mimwrite(output_file_video, video, fps=10) + + +@torch.no_grad() +def visualize_pointcloud_evolution_pytorch3d( + pointclouds: Pointclouds, + output_file_video: str, + camera: Optional[CamerasBase] = None, # if None, we rotate + scale_factor: float = 1.0, +): + + # Device + B, device = len(pointclouds), 
pointclouds.device + + # Cameras + if camera is None: + R, T = look_at_view_transform(dist=10.0, elev=30, azim=0, device=device) + camera = OrthographicCameras(focal_length=(0.25 * scale_factor), device=device, R=R, T=T) + + # Render + frames = render_pointcloud_batch_pytorch3d(camera, pointclouds) + + # Save video + video = (frames.detach().cpu() * 255).to(torch.uint8).numpy() + imageio.mimwrite(output_file_video, video, fps=10) + + +def get_camera_index(cameras: CamerasBase, index: Optional[int] = None): + if index is None: + return cameras + kwargs = dict( + R=cameras.R[index].unsqueeze(0), + T=cameras.T[index].unsqueeze(0), + K=cameras.K[index].unsqueeze(0) if cameras.K is not None else None, + ) + if hasattr(cameras, 'focal_length'): + kwargs['focal_length'] = cameras.focal_length[index].unsqueeze(0) + if hasattr(cameras, 'principal_point'): + kwargs['principal_point'] = cameras.principal_point[index].unsqueeze(0) + return type(cameras)(**kwargs).to(cameras.device) + + +def get_metadata(item) -> str: + s = '-------------\n' + for key in item.keys(): + value = item[key] + if torch.is_tensor(value) and value.numel() < 25: + value_str = value + elif torch.is_tensor(value): + value_str = value.shape + elif isinstance(value, str): + value_str = value + elif isinstance(value, list) and 0 < len(value) and len(value) < 25 and isinstance(value[0], str): + value_str = value + elif isinstance(value, dict): + value_str = str({k: type(v) for k, v in value.items()}) + else: + value_str = type(value) + s += f"{key:<30} {value_str}\n" + return s diff --git a/examples/017450/k1.color.jpg b/examples/017450/k1.color.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8163a625027f9a49109a862400073d977b84d0a1 Binary files /dev/null and b/examples/017450/k1.color.jpg differ diff --git a/examples/017450/k1.obj_rend_mask.png b/examples/017450/k1.obj_rend_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..b96e85199b2c498cc8f9edd5b994fc24a98f2edd Binary files /dev/null and b/examples/017450/k1.obj_rend_mask.png differ diff --git a/examples/017450/k1.person_mask.png b/examples/017450/k1.person_mask.png new file mode 100644 index 0000000000000000000000000000000000000000..3a30e94d90552416dcec779d9fac76b89a839f16 Binary files /dev/null and b/examples/017450/k1.person_mask.png differ diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c2cfef041c03fbb5378b65dce56988386b407d59 --- /dev/null +++ b/model/__init__.py @@ -0,0 +1,28 @@ +from configs.structured import ProjectConfig +from .model import ConditionalPointCloudDiffusionModel +from .model_coloring import PointCloudColoringModel +from .model_utils import set_requires_grad +from .model_diff_data import ConditionalPCDiffusionSeparateSegm +from .model_hoattn import CrossAttenHODiffusionModel + +def get_model(cfg: ProjectConfig): + if cfg.model.model_name == 'pc2-diff': + model = ConditionalPointCloudDiffusionModel(**cfg.model) + elif cfg.model.model_name == 'pc2-diff-ho-sepsegm': + model = ConditionalPCDiffusionSeparateSegm(**cfg.model) + print("Using a separate model to predict segmentation label") + elif cfg.model.model_name == 'diff-ho-attn': + model = CrossAttenHODiffusionModel(**cfg.model) + print("Using separate model for human + object with cross attention.") + else: + raise NotImplementedError + if cfg.run.freeze_feature_model: + set_requires_grad(model.feature_model, False) + return model + + +def get_coloring_model(cfg: ProjectConfig): + 
model = PointCloudColoringModel(**cfg.model) + if cfg.run.freeze_feature_model: + set_requires_grad(model.feature_model, False) + return model diff --git a/model/feature_model.py b/model/feature_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5ea6795bfdc36e080fbef3b7e1ebb00eed1609c5 --- /dev/null +++ b/model/feature_model.py @@ -0,0 +1,160 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers import ModelMixin +from timm.models.vision_transformer import VisionTransformer, resize_pos_embed +from torch import Tensor +from torchvision.transforms import functional as TVF + + +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) + +MODEL_URLS = { + 'vit_base_patch16_224_mae': 'https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth', + 'vit_small_patch16_224_msn': 'https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar', + 'vit_large_patch7_224_msn': 'https://dl.fbaipublicfiles.com/msn/vitl7_200ep.pth.tar', +} + +NORMALIZATION = { + 'vit_base_patch16_224_mae': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), + 'vit_small_patch16_224_msn': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), + 'vit_large_patch7_224_msn': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), +} + +MODEL_KWARGS = { + 'vit_base_patch16_224_mae': dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, + ), + 'vit_small_patch16_224_msn': dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, + ), + 'vit_large_patch7_224_msn': dict( + patch_size=7, embed_dim=1024, depth=24, num_heads=16, + ) +} + + +class FeatureModel(ModelMixin, ConfigMixin): + + @register_to_config + def __init__( + self, + image_size: int = 224, + model_name: str = 'vit_small_patch16_224_mae', + global_pool: str = '', # '' or 'token' + ) -> None: + super().__init__() + self.model_name = model_name + + # Identity + if self.model_name == 'identity': + return + + # Create model + self.model = VisionTransformer( + img_size=image_size, num_classes=0, global_pool=global_pool, + **MODEL_KWARGS[model_name]) + + # Model properties + self.feature_dim = self.model.embed_dim + self.mean, self.std = NORMALIZATION[model_name] + + # # Modify MSN model with output head from training + # if model_name.endswith('msn'): + # use_bn = True + # emb_dim = (192 if 'tiny' in model_name else 384 if 'small' in model_name else + # 768 if 'base' in model_name else 1024 if 'large' in model_name else 1280) + # hidden_dim = 2048 + # output_dim = 256 + # self.model.fc = None + # fc = OrderedDict([]) + # fc['fc1'] = torch.nn.Linear(emb_dim, hidden_dim) + # if use_bn: + # fc['bn1'] = torch.nn.BatchNorm1d(hidden_dim) + # fc['gelu1'] = torch.nn.GELU() + # fc['fc2'] = torch.nn.Linear(hidden_dim, hidden_dim) + # if use_bn: + # fc['bn2'] = torch.nn.BatchNorm1d(hidden_dim) + # fc['gelu2'] = torch.nn.GELU() + # fc['fc3'] = torch.nn.Linear(hidden_dim, output_dim) + # self.model.fc = torch.nn.Sequential(fc) + + # Load pretrained checkpoint + checkpoint = torch.hub.load_state_dict_from_url(MODEL_URLS[model_name]) + if 'model' in checkpoint: + state_dict = checkpoint['model'] + elif 'target_encoder' in checkpoint: + state_dict = checkpoint['target_encoder'] + state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} + # NOTE: Comment the line below if using the projection head, uncomment if not using it + # See 
https://github.com/facebookresearch/msn/blob/81cb855006f41cd993fbaad4b6a6efbb486488e6/src/msn_train.py#L490-L502 + # for more info about the projection head + state_dict = {k: v for k, v in state_dict.items() if not k.startswith('fc.')} + else: + raise NotImplementedError() + state_dict['pos_embed'] = resize_pos_embed(state_dict['pos_embed'], self.model.pos_embed) + self.model.load_state_dict(state_dict) + self.model.eval() + + # # Modify MSN model with output head from training + # if model_name.endswith('msn'): + # self.fc = self.model.fc + # del self.model.fc + # else: + # self.fc = nn.Identity() + + # NOTE: I've disabled the whole projection head stuff for simplicity for now + self.fc = nn.Identity() + + def denormalize(self, img: Tensor): + img = TVF.normalize(img, mean=[-m/s for m, s in zip(self.mean, self.std)], std=[1/s for s in self.std]) + return torch.clip(img, 0, 1) + + def normalize(self, img: Tensor): + return TVF.normalize(img, mean=self.mean, std=self.std) + + def forward( + self, + x: Tensor, + return_type: str = 'features', + return_upscaled_features: bool = True, + return_projection_head_output: bool = False, + ): + """Normalizes the input `x` and runs it through `model` to obtain features""" + assert return_type in {'cls_token', 'features', 'all'} + + # Identity + if self.model_name == 'identity': + return x + + # Normalize and forward + B, C, H, W = x.shape + x = self.normalize(x) + feats = self.model(x) + + # Reshape to image-like size + if return_type in {'features', 'all'}: + B, T, D = feats.shape + assert math.sqrt(T - 1).is_integer() + HW_down = int(math.sqrt(T - 1)) # subtract one for CLS token + output_feats: Tensor = feats[:, 1:, :].reshape(B, HW_down, HW_down, D).permute(0, 3, 1, 2) # (B, D, H_down, W_down) + if return_upscaled_features: + output_feats = F.interpolate(output_feats, size=(H, W), mode='bilinear', + align_corners=False) # (B, D, H_orig, W_orig) + + # Head for MSN + output_cls = feats[:, 0] + if return_projection_head_output and return_type in {'cls_token', 'all'}: + output_cls = self.fc(output_cls) + + # Return + if return_type == 'cls_token': + return output_cls + elif return_type == 'features': + return output_feats + else: + return output_cls, output_feats diff --git a/model/model.py b/model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..d536ce36d8814359e8ef6fb574f1a7c55b38f741 --- /dev/null +++ b/model/model.py @@ -0,0 +1,303 @@ +import inspect +import random +from typing import Optional + +import torch +import torch.nn.functional as F +from diffusers.schedulers.scheduling_ddpm import DDPMScheduler +from diffusers.schedulers.scheduling_ddim import DDIMScheduler +from diffusers.schedulers.scheduling_pndm import PNDMScheduler +from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData +from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.structures import Pointclouds +from torch import Tensor +from tqdm import tqdm + +from .model_utils import get_num_points, get_custom_betas +from .point_cloud_model import PointCloudModel +from .projection_model import PointCloudProjectionModel + + +class ConditionalPointCloudDiffusionModel(PointCloudProjectionModel): + + def __init__( + self, + beta_start: float, + beta_end: float, + beta_schedule: str, + point_cloud_model: str, + point_cloud_model_embed_dim: int, + **kwargs, # projection arguments + ): + super().__init__(**kwargs) + + # Checks + if not self.predict_shape: + raise NotImplementedError('Must predict shape if performing diffusion.') + 
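+        # Scheduler setup below: three diffusers schedulers (DDPM, DDIM, PNDM) are built from
+        # either a named beta schedule or custom betas, with prediction_type 'epsilon' (predict
+        # the noise) or 'sample' (predict x_0 directly). Training always uses the DDPM instance;
+        # at sampling time forward_sample can switch to e.g. 'ddim' with fewer inference steps
+        # (run.diffusion_scheduler / run.num_inference_steps in the demo config).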
+ # Create diffusion model schedulers which define the sampling timesteps + self.dm_pred_type = kwargs.get('dm_pred_type', "epsilon") + assert self.dm_pred_type in ['epsilon','sample'] + scheduler_kwargs = {"prediction_type": self.dm_pred_type} + if beta_schedule == 'custom': + scheduler_kwargs.update(dict(trained_betas=get_custom_betas(beta_start=beta_start, beta_end=beta_end))) + else: + scheduler_kwargs.update(dict(beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule)) + self.schedulers_map = { + 'ddpm': DDPMScheduler(**scheduler_kwargs, clip_sample=False), + 'ddim': DDIMScheduler(**scheduler_kwargs, clip_sample=False), + 'pndm': PNDMScheduler(**scheduler_kwargs), + } + self.scheduler = self.schedulers_map['ddpm'] # this can be changed for inference + + # Create point cloud model for processing point cloud at each diffusion step + self.init_pcloud_model(kwargs, point_cloud_model, point_cloud_model_embed_dim) + + self.load_sample_init = kwargs.get('load_sample_init', False) + self.sample_init_scale = kwargs.get('sample_init_scale', 1.0) + self.test_init_with_gtpc = kwargs.get('test_init_with_gtpc', False) + + self.consistent_center = kwargs.get('consistent_center', False) + self.cam_noise_std = kwargs.get('cam_noise_std', 0.0) # add noise to camera based on timestamps + + def init_pcloud_model(self, kwargs, point_cloud_model, point_cloud_model_embed_dim): + self.point_cloud_model = PointCloudModel( + model_type=point_cloud_model, + embed_dim=point_cloud_model_embed_dim, + in_channels=self.in_channels, + out_channels=self.out_channels, # voxel resolution multiplier is 1. + voxel_resolution_multiplier=kwargs.get('voxel_resolution_multiplier', 1) + ) + + def forward_train( + self, + pc: Pointclouds, + camera: Optional[CamerasBase], + image_rgb: Optional[Tensor], + mask: Optional[Tensor], + return_intermediate_steps: bool = False, + **kwargs + ): + + # Normalize colors and convert to tensor + x_0 = self.point_cloud_to_tensor(pc, normalize=True, scale=True) # this will not pack the point colors + B, N, D = x_0.shape + + # Sample random noise + noise = torch.randn_like(x_0) + if self.consistent_center: + # modification suggested by https://arxiv.org/pdf/2308.07837.pdf + noise = noise - torch.mean(noise, dim=1, keepdim=True) + + # Sample random timesteps for each point_cloud + timestep = torch.randint(0, self.scheduler.num_train_timesteps, (B,), + device=self.device, dtype=torch.long) + + # Add noise to points + x_t = self.scheduler.add_noise(x_0, noise, timestep) # diffusion noisy adding, only add to the coordinate, not features + + # add noise to the camera pose, based on timestamps + if self.cam_noise_std > 0.000001: + # the noise is very different + camera = camera.clone() + camT = camera.T # (B, 3) + dist = torch.sqrt(torch.sum(camT**2, -1, keepdim=True)) + nratio = timestep[:, None] / self.scheduler.num_train_timesteps # time-dependent noise + tnoise = torch.randn(B, 3).to(dist.device)/3. 
* dist * self.cam_noise_std * nratio + camera.T = camera.T + tnoise + + # Conditioning, the pixel-aligned feature is based on points with noise (new points) + x_t_input = self.get_diffu_input(camera, image_rgb, mask, timestep, x_t, **kwargs) + + # Forward + loss, noise_pred = self.compute_loss(noise, timestep, x_0, x_t_input) + + # Whether to return intermediate steps + if return_intermediate_steps: + return loss, (x_0, x_t, noise, noise_pred) + + return loss + + def compute_loss(self, noise, timestep, x_0, x_t_input): + x_pred = torch.zeros_like(x_0) + if self.self_conditioning: + # self conditioning, from https://openreview.net/pdf?id=3itjR9QxFw + if random.uniform(0, 1.) > 0.5: + with torch.no_grad(): + x_pred = self.point_cloud_model(torch.cat([x_t_input, x_pred], -1), timestep) + noise_pred = self.point_cloud_model(torch.cat([x_t_input, x_pred], -1), timestep) + else: + noise_pred = self.point_cloud_model(x_t_input, timestep) + # Check + if not noise_pred.shape == noise.shape: + raise ValueError(f'{noise_pred.shape=} and {noise.shape=}') + # Loss + if self.dm_pred_type == 'epsilon': + loss = F.mse_loss(noise_pred, noise) + elif self.dm_pred_type == 'sample': + loss = F.mse_loss(noise_pred, x_0) # predicting sample + else: + raise NotImplementedError + return loss, noise_pred + + def get_diffu_input(self, camera, image_rgb, mask, timestep, x_t, **kwargs): + "return: (B, N, D), the exact input to the diffusion model, x_t: (B, N, 3)" + x_t_input = self.get_input_with_conditioning(x_t, camera=camera, + image_rgb=image_rgb, mask=mask, t=timestep) + return x_t_input + + @torch.no_grad() + def forward_sample( + self, + num_points: int, + camera: Optional[CamerasBase], + image_rgb: Optional[Tensor], + mask: Optional[Tensor], + # Optional overrides + scheduler: Optional[str] = 'ddpm', + # Inference parameters + num_inference_steps: Optional[int] = 1000, + eta: Optional[float] = 0.0, # for DDIM + # Whether to return all the intermediate steps in generation + return_sample_every_n_steps: int = -1, + # Whether to disable tqdm + disable_tqdm: bool = False, + gt_pc: Pointclouds = None, + **kwargs + ): + + # Get scheduler from mapping, or use self.scheduler if None + scheduler = self.scheduler if scheduler is None else self.schedulers_map[scheduler] + + # Get the size of the noise + N = num_points + B = 1 if image_rgb is None else image_rgb.shape[0] + D = self.get_x_T_channel() + device = self.device if image_rgb is None else image_rgb.device + + sample_from_interm = kwargs.get('sample_from_interm', False) + interm_steps = kwargs.get('noise_step') if sample_from_interm else -1 + x_t = self.initialize_x_T(device, gt_pc, (B, N, D), interm_steps, scheduler) + x_pred = torch.zeros_like(x_t) + + # Set timesteps + extra_step_kwargs = self.setup_reverse_process(eta, num_inference_steps, scheduler) + + # Loop over timesteps + all_outputs = [] + return_all_outputs = (return_sample_every_n_steps > 0) + progress_bar = tqdm(scheduler.timesteps.to(device), desc=f'Sampling ({x_t.shape})', disable=disable_tqdm) + + for i, t in enumerate(progress_bar): + add_interm_output = (return_all_outputs and ( + i % return_sample_every_n_steps == 0 or i == len(scheduler.timesteps) - 1)) + # Conditioning + x_t_input = self.get_diffu_input(camera, image_rgb, mask, t, x_t, **kwargs) + if self.self_conditioning: + x_t_input = torch.cat([x_t_input, x_pred], -1) # add self-conditioning + inference_binary = (i == len(progress_bar) - 1) | add_interm_output + # One reverse step with conditioning + x_t = 
self.reverse_step(extra_step_kwargs, scheduler, t, x_t, x_t_input, + inference_binary=inference_binary) # (B, N, D), D=3 or 4 + x_pred = x_t # for next iteration self conditioning + + # Append to output list if desired + if add_interm_output: + all_outputs.append(x_t) + + # Convert output back into a point cloud, undoing normalization and scaling + output = self.tensor_to_point_cloud(x_t, denormalize=True, unscale=True) # this convert the points back to original scale + if return_all_outputs: + all_outputs = torch.stack(all_outputs, dim=1) # (B, sample_steps, N, D) + all_outputs = [self.tensor_to_point_cloud(o, denormalize=True, unscale=True) for o in all_outputs] + + return (output, all_outputs) if return_all_outputs else output + + def get_x_T_channel(self): + D = 3 + (self.color_channels if self.predict_color else 0) + return D + + def initialize_x_T(self, device, gt_pc, shape, interm_steps:int=-1, scheduler=None): + B, N, D = shape + # Sample noise initialization + if interm_steps > 0: + # Sample from some intermediate steps + x_0 = self.point_cloud_to_tensor(gt_pc, normalize=True, scale=True) + noise = torch.randn(B, N, D, device=device) + + # always make sure the noise does not change the pc center, this is important to reduce 0.1cm CD! + noise = noise - torch.mean(noise, dim=1, keepdim=True) + + x_t = scheduler.add_noise(x_0, noise, torch.tensor([interm_steps - 1] * B).long().to(device)) # Add noise + else: + # Sample from random Gaussian + x_t = torch.randn(B, N, D, device=device) + + x_t = x_t * self.sample_init_scale # for test + if self.consistent_center: + x_t = x_t - torch.mean(x_t, dim=1, keepdim=True) + return x_t + + def reverse_step(self, extra_step_kwargs, scheduler, t, x_t, x_t_input, **kwargs): + """ + run one reverse step to compute x_t + :param extra_step_kwargs: + :param scheduler: + :param t: [1], diffusion time step + :param x_t: (B, N, 3) + :param x_t_input: conditional features (B, N, F) + :param kwargs: other configurations to run diffusion step + :return: denoised x_t + """ + B = x_t.shape[0] + # Forward + noise_pred = self.point_cloud_model(x_t_input, t.reshape(1).expand(B)) + if self.consistent_center: + assert self.dm_pred_type != 'sample', 'incompatible dm predition type for CCD!' + # suggested by the CCD-3DR paper + noise_pred = noise_pred - torch.mean(noise_pred, dim=1, keepdim=True) + # Step + x_t = scheduler.step(noise_pred, t, x_t, **extra_step_kwargs).prev_sample + if self.consistent_center: + x_t = x_t - torch.mean(x_t, dim=1, keepdim=True) + return x_t + + def setup_reverse_process(self, eta, num_inference_steps, scheduler): + """ + setup diffusion chain, and others. + """ + accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {"offset": 1} if accepts_offset else {} + scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + # Prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
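+        # (not all diffusers schedulers share the same signature, hence the inspect-based checks:
+        #  'offset' and 'eta' are only forwarded when the selected scheduler actually accepts them)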
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys()) + extra_step_kwargs = {"eta": eta} if accepts_eta else {} + return extra_step_kwargs + + def forward(self, batch: FrameData, mode: str = 'train', **kwargs): + """ + A wrapper around the forward method for training and inference + """ + if isinstance(batch, dict): # fixes a bug with multiprocessing where batch becomes a dict + batch = FrameData(**batch) # it really makes no sense, I do not understand it + + if mode == 'train': + return self.forward_train( + pc=batch.sequence_point_cloud, + camera=batch.camera, + image_rgb=batch.image_rgb, + mask=batch.fg_probability, + **kwargs) + elif mode == 'sample': + num_points = kwargs.pop('num_points', get_num_points(batch.sequence_point_cloud)) + return self.forward_sample( + num_points=num_points, + camera=batch.camera, + image_rgb=batch.image_rgb, + mask=batch.fg_probability, + **kwargs) + else: + raise NotImplementedError() \ No newline at end of file diff --git a/model/model_coloring.py b/model/model_coloring.py new file mode 100644 index 0000000000000000000000000000000000000000..3d7ba33202a3dc876f816080539bdf7da50023df --- /dev/null +++ b/model/model_coloring.py @@ -0,0 +1,84 @@ +from typing import Optional + +import torch +import torch.nn.functional as F +from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData +from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.structures import Pointclouds +from torch import Tensor + +from .point_cloud_transformer_model import PointCloudTransformerModel +from .projection_model import PointCloudProjectionModel + +class PointCloudColoringModel(PointCloudProjectionModel): + + def __init__( + self, + point_cloud_model: str, + point_cloud_model_layers: int, + point_cloud_model_embed_dim: int, + **kwargs, # projection arguments + ): + super().__init__(**kwargs) + + # Checks + if self.predict_shape or not self.predict_color: + raise NotImplementedError('Must predict color, not shape, for coloring') + + # Create point cloud model for processing point cloud + self.point_cloud_model = PointCloudTransformerModel( + num_layers=point_cloud_model_layers, + model_type=point_cloud_model, + embed_dim=point_cloud_model_embed_dim, + in_channels=self.in_channels, + out_channels=self.out_channels, + ) # why use transformer instead??? + + def _forward( + self, + pc: Pointclouds, + camera: Optional[CamerasBase], + image_rgb: Optional[Tensor], + mask: Optional[Tensor], + return_point_cloud: bool = False, + noise_std: float = 0.0, + ): + + # Normalize colors and convert to tensor + x = self.point_cloud_to_tensor(pc, normalize=True, scale=True) + x_points, x_colors = x[:, :, :3], x[:, :, 3:] + + # Add noise to points. TODO: Add to configs. + x_input = x_points + torch.randn_like(x_points) * noise_std # simulate noise of the predicted pc? 
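+        # The Gaussian jitter above is presumably meant to mimic the imperfect geometry the
+        # coloring model will receive at test time (diffusion-predicted rather than ground-truth
+        # points); noise_std is passed per call, e.g. _forward(..., noise_std=0.01) during
+        # training and 0.0 at inference (hypothetical values, not taken from the configs).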
+ + # Conditioning + # x_input = self.get_input_with_conditioning(x_input, camera=camera, + # image_rgb=image_rgb, mask=mask) + # XH: edit to run + x_input = self.get_input_with_conditioning(x_input, camera=camera, + image_rgb=image_rgb, mask=mask, t=None) + + # Forward + pred_colors = self.point_cloud_model(x_input) + + # During inference, we return the point cloud with the predicted colors + if return_point_cloud: + pred_pointcloud = self.tensor_to_point_cloud( + torch.cat((x_points, pred_colors), dim=2), denormalize=True, unscale=True) + return pred_pointcloud + + # During training, we have ground truth colors and return the loss + loss = F.mse_loss(pred_colors, x_colors) + return loss + + def forward(self, batch: FrameData, **kwargs): + """A wrapper around the forward method""" + if isinstance(batch, dict): # fixes a bug with multiprocessing where batch becomes a dict + batch = FrameData(**batch) # it really makes no sense, I do not understand it + return self._forward( + pc=batch.sequence_point_cloud, + camera=batch.camera, + image_rgb=batch.image_rgb, + mask=batch.fg_probability, + **kwargs, + ) \ No newline at end of file diff --git a/model/model_diff_data.py b/model/model_diff_data.py new file mode 100644 index 0000000000000000000000000000000000000000..80726006ad7ae6097d164f29cf2745742948c6fc --- /dev/null +++ b/model/model_diff_data.py @@ -0,0 +1,238 @@ +""" +model to deal with shapenet inputs and other datasets such as Behave and ProciGen +the model takes a different data dictionary in forward function +""" +import inspect +from typing import Optional +import numpy as np + +import torch +import torch.nn.functional as F +from diffusers.schedulers.scheduling_ddpm import DDPMScheduler +from diffusers.schedulers.scheduling_ddim import DDIMScheduler +from diffusers.schedulers.scheduling_pndm import PNDMScheduler +from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData +from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.structures import Pointclouds +from torch import Tensor +from tqdm import tqdm +from pytorch3d.renderer import PerspectiveCameras +from pytorch3d.datasets.r2n2.utils import BlenderCamera + + +from .model import ConditionalPointCloudDiffusionModel +from .model_utils import get_num_points + + +class ConditionalPCDiffusionShapenet(ConditionalPointCloudDiffusionModel): + def forward(self, batch, mode: str = 'train', **kwargs): + """ + take a batch of data from ShapeNet + """ + images = torch.stack(batch['images'], 0).to('cuda') + masks = torch.stack(batch['masks'], 0).to('cuda') + pc = Pointclouds([x.to('cuda') for x in batch['pclouds']]) + camera = BlenderCamera( + torch.stack(batch['R']), + torch.stack(batch['T']), + torch.stack(batch['K']), device='cuda' + ) + + if mode == 'train': + return self.forward_train( + pc=pc, + camera=camera, + image_rgb=images, + mask=masks, + + **kwargs) + elif mode == 'sample': + num_points = kwargs.pop('num_points', get_num_points(pc)) + return self.forward_sample( + num_points=num_points, + camera=camera, + image_rgb=images, + mask=masks, + gt_pc=pc, + **kwargs) + else: + raise NotImplementedError() + + +class ConditionalPCDiffusionBehave(ConditionalPointCloudDiffusionModel): + "diffusion model for Behave dataset" + def forward(self, batch, mode: str = 'train', **kwargs): + images = torch.stack(batch['images'], 0).to('cuda') + masks = torch.stack(batch['masks'], 0).to('cuda') + pc = self.get_input_pc(batch) + camera = PerspectiveCameras( + R=torch.stack(batch['R']), + T=torch.stack(batch['T']), + 
K=torch.stack(batch['K']), + device='cuda', + in_ndc=True + ) + grid_df = torch.stack(batch['grid_df'], 0).to('cuda') if 'grid_df' in batch else None + num_points = kwargs.pop('num_points', get_num_points(pc)) + if mode == 'train': + return self.forward_train( + pc=pc, + camera=camera, + image_rgb=images, + mask=masks, + grid_df=grid_df, + **kwargs) + elif mode == 'sample': + return self.forward_sample( + num_points=num_points, + camera=camera, + image_rgb=images, + mask=masks, + gt_pc=pc, + **kwargs) + else: + raise NotImplementedError() + + def get_input_pc(self, batch): + pc = Pointclouds([x.to('cuda') for x in batch['pclouds']]) + return pc + + +class ConditionalPCDiffusionSeparateSegm(ConditionalPCDiffusionBehave): + "a separate model to predict binary labels, the final segmentation model" + def __init__(self, + beta_start: float, + beta_end: float, + beta_schedule: str, + point_cloud_model: str, + point_cloud_model_embed_dim: int, + **kwargs, # projection arguments + ): + super(ConditionalPCDiffusionSeparateSegm, self).__init__(beta_start, beta_end, beta_schedule, + point_cloud_model, + point_cloud_model_embed_dim, **kwargs) + # add a separate model to predict binary label + from .point_cloud_transformer_model import PointCloudTransformerModel, PointCloudModel + + self.binary_model = PointCloudTransformerModel( + num_layers=1, # XH: use the default color model number of layers + model_type=point_cloud_model, # pvcnn + embed_dim=point_cloud_model_embed_dim, # save as pc shape model + in_channels=self.in_channels, + out_channels=1, + ) + self.binary_training_noise_std = kwargs.get("binary_training_noise_std", 0.1) + + # re-initialize point cloud model + assert self.predict_binary + self.point_cloud_model = PointCloudModel( + model_type=point_cloud_model, + embed_dim=point_cloud_model_embed_dim, + in_channels=self.in_channels, + out_channels=self.out_channels - 1, # not predicting binary from this anymore + voxel_resolution_multiplier=kwargs.get('voxel_resolution_multiplier', 1) + ) + + def forward_train( + self, + pc: Pointclouds, + camera: Optional[CamerasBase], + image_rgb: Optional[Tensor], + mask: Optional[Tensor], + return_intermediate_steps: bool = False, + **kwargs + ): + # first run shape forward, then binary label forward + assert not return_intermediate_steps + assert self.predict_binary + loss_shape = super(ConditionalPCDiffusionSeparateSegm, self).forward_train(pc, + camera, + image_rgb, + mask, + return_intermediate_steps, + **kwargs) + + # binary label forward + x_0 = self.point_cloud_to_tensor(pc, normalize=True, scale=True) + x_points, x_colors = x_0[:, :, :3], x_0[:, :, 3:] + + # Add noise to points. + x_input = x_points + torch.randn_like(x_points) * self.binary_training_noise_std # std=0.1 + x_input = self.get_input_with_conditioning(x_input, camera=camera, + image_rgb=image_rgb, mask=mask, t=None) + + # Forward + pred_segm = self.binary_model(x_input) + + # use compressed bits + df_grid = kwargs.get('grid_df', None).unsqueeze(1) # (B, 1, resz, resy, resx) + points = x_points.clone().detach() / self.scale_factor * 2 # , normalize to [-1, 1] + points[:, :, 0], points[:, :, 2] = points[:, :, 2].clone(), points[:, :,0].clone() # swap, make sure clone is used! 
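+        # What follows is the label lookup: the clean GT points (the binary head itself sees the
+        # jittered copy x_input) are rescaled to the grid's normalized [-1, 1] range with x/z
+        # swapped to match the stored volume layout, then F.grid_sample reads the precomputed
+        # per-example grid `grid_df` at every point. Values above 0.5 are taken as human
+        # (matching how tensor_to_point_cloud later colours label > 0.5 as human), and the
+        # sigmoid output of the separate binary head is regressed to these labels with an MSE
+        # loss weighted by lw_binary.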
+ points = points.unsqueeze(1).unsqueeze(1) # (B,1, 1, N, 3) + with torch.no_grad(): + df_interp = F.grid_sample(df_grid, points, padding_mode='border', align_corners=True).squeeze(1).squeeze(1) # (B, 1, 1, 1, N) + binary_label = df_interp[:, 0] > 0.5 # (B, 1, N) + + binary_pred = torch.sigmoid(pred_segm.squeeze(-1)) # add a sigmoid layer + loss_binary = F.mse_loss(binary_pred, binary_label.float().squeeze(1).squeeze(1)) * self.lw_binary + loss = loss_shape + loss_binary + + return loss, torch.tensor([loss_shape, loss_binary]) + + def reverse_step(self, extra_step_kwargs, scheduler, t, x_t, x_t_input, **kwargs): + "return (B, N, 4), the 4-th channel is binary label" + B = x_t.shape[0] + # Forward + noise_pred = self.point_cloud_model(x_t_input, t.reshape(1).expand(B)) + if self.consistent_center: + assert self.dm_pred_type != 'sample', 'incompatible dm predition type!' + # suggested by the CCD-3DR paper + noise_pred = noise_pred - torch.mean(noise_pred, dim=1, keepdim=True) + # Step: make sure only update the shape (first 3 channels) + x_t = scheduler.step(noise_pred, t, x_t[:, :, :3], **extra_step_kwargs).prev_sample + if self.consistent_center: + x_t = x_t - torch.mean(x_t, dim=1, keepdim=True) + + # also add binary prediction + if kwargs.get('inference_binary', False): + pred_segm = self.binary_model(x_t_input) + else: + pred_segm = torch.zeros_like(x_t[:, :, 0:1]) + + x_t = torch.cat([x_t, torch.sigmoid(pred_segm)], -1) + + return x_t + + def get_coord_feature(self, x_t): + x_t_input = [x_t[:, :, :3]] + return x_t_input + + def tensor_to_point_cloud(self, x: Tensor, /, denormalize: bool = False, unscale: bool = False): + """ + take binary label into account + :param self: + :param x: (B, N, 4), the 4th channel is the binary segmentation, 1-human, 0-object + :param denormalize: denormalize the per-point colors, from pc2 + :param unscale: undo point scaling, from pc2 + :return: pc with point colors if predict binary label or per-point color + """ + points = x[:, :, :3] / (self.scale_factor if unscale else 1) + if self.predict_color: + colors = self.denormalize(x[:, :, 3:]) if denormalize else x[:, :, 3:] + return Pointclouds(points=points, features=colors) + else: + if self.predict_binary: + assert x.shape[2] == 4 + # add color to predicted binary labels + is_hum = x[:, :, 3] > 0.5 + features = [] + for mask in is_hum: + color = torch.zeros_like(x[0, :, :3]) + torch.tensor([0.5, 1.0, 0]).to(x.device) + color[mask, :] = torch.tensor([0.05, 1.0, 1.0]).to(x.device) # human is light blue, object light green + features.append(color) + else: + assert x.shape[2] == 3 + features = None + return Pointclouds(points=points, features=features) + + diff --git a/model/model_hoattn.py b/model/model_hoattn.py new file mode 100644 index 0000000000000000000000000000000000000000..cf7574f24a60dfc626a9f574924cc6bf94382f51 --- /dev/null +++ b/model/model_hoattn.py @@ -0,0 +1,457 @@ +""" +model that use cross attention to predict human + object +""" + +import inspect +import random +from typing import Optional +from torch import Tensor +import torch +import numpy as np + +from pytorch3d.structures import Pointclouds +from pytorch3d.renderer import CamerasBase +from .model_diff_data import ConditionalPCDiffusionBehave +from .pvcnn.pvcnn_ho import PVCNN2HumObj +import torch.nn.functional as F +from pytorch3d.renderer import PerspectiveCameras +from .model_utils import get_num_points +from tqdm import tqdm + + +class CrossAttenHODiffusionModel(ConditionalPCDiffusionBehave): + def init_pcloud_model(self, kwargs, 
point_cloud_model, point_cloud_model_embed_dim): + """use cross attention model""" + if point_cloud_model == 'pvcnn': + self.point_cloud_model = PVCNN2HumObj(embed_dim=point_cloud_model_embed_dim, + num_classes=self.out_channels, + extra_feature_channels=(self.in_channels - 3), + voxel_resolution_multiplier=kwargs.get('voxel_resolution_multiplier', 1), + attn_type=kwargs.get('attn_type', 'simple-cross'), + attn_weight=kwargs.get("attn_weight", 1.0) + ) + else: + raise ValueError(f"Unknown point cloud model {point_cloud_model}!") + self.point_visible_test = kwargs.get("point_visible_test", 'single') # when doing point visibility test, use only human points or human + object? + assert self.point_visible_test in ['single', 'combine'], f'invalide point visible test option {self.point_visible_test}' + # print(f"Point visibility test is based on {self.point_visible_test} point clouds!") + + def forward_train( + self, + pc: Pointclouds, + camera: Optional[CamerasBase], + image_rgb: Optional[Tensor], + mask: Optional[Tensor], + return_intermediate_steps: bool = False, + **kwargs + ): + "additional input (RGB, mask, camera, and pc) for object is read from kwargs" + # assert not self.consistent_center + assert not self.self_conditioning + + # Normalize colors and convert to tensor + x0_h = self.point_cloud_to_tensor(pc, normalize=True, scale=True) # this will not pack the point colors + x0_o = self.point_cloud_to_tensor(kwargs.get('pc_obj'), normalize=True, scale=True) + B, N, D = x0_h.shape + + # Sample random noise + noise = torch.randn_like(x0_h) + if self.consistent_center: + # modification suggested by https://arxiv.org/pdf/2308.07837.pdf + noise = noise - torch.mean(noise, dim=1, keepdim=True) + + # Sample random timesteps for each point_cloud + timestep = torch.randint(0, self.scheduler.num_train_timesteps, (B,), + device=self.device, dtype=torch.long) + # timestep = torch.randint(0, 1, (B,), + # device=self.device, dtype=torch.long) + + # Add noise to points + xt_h = self.scheduler.add_noise(x0_h, noise, timestep) + xt_o = self.scheduler.add_noise(x0_o, noise, timestep) + norm_parms = self.pack_norm_params(kwargs) # (2, B, 4) + + # get input conditioning + x_t_input_h, x_t_input_o = self.get_image_conditioning(camera, image_rgb, kwargs, mask, norm_parms, timestep, + xt_h, xt_o) + + # Diffusion prediction + noise_pred_h, noise_pred_o = self.point_cloud_model(x_t_input_h, x_t_input_o, timestep, norm_parms) + + # Check + if not noise_pred_h.shape == noise.shape: + raise ValueError(f'{noise_pred_h.shape=} and {noise.shape=}') + if not noise_pred_o.shape == noise.shape: + raise ValueError(f'{noise_pred_o.shape=} and {noise.shape=}') + + # Loss + loss_h = F.mse_loss(noise_pred_h, noise) + loss_o = F.mse_loss(noise_pred_o, noise) + + loss = loss_h + loss_o + + # Whether to return intermediate steps + if return_intermediate_steps: + return loss, (x0_h, xt_h, noise, noise_pred_h) + + return loss, torch.tensor([loss_h, loss_o]) + + def get_image_conditioning(self, camera, image_rgb, kwargs, mask, norm_parms, timestep, xt_h, xt_o): + """ + compute image features for each point + :param camera: + :param image_rgb: + :param kwargs: + :param mask: + :param norm_parms: + :param timestep: + :param xt_h: + :param xt_o: + :return: + """ + if self.point_visible_test == 'single': + # Visibility test is down independently for human and object + x_t_input_h = self.get_input_with_conditioning(xt_h, camera=camera, + image_rgb=image_rgb, mask=mask, t=timestep) + x_t_input_o = self.get_input_with_conditioning(xt_o, 
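# --- Added sketch (not the original code) ------------------------------------
# Training, condensed: the human and object branches share one noise sample and
# one timestep, and the two epsilon-prediction losses are summed. `model` is
# assumed to behave like PVCNN2HumObj (two conditioned point clouds in, two
# noise estimates out) and `scheduler` like a diffusers DDPMScheduler.
import torch
import torch.nn.functional as F

def joint_denoising_loss(model, scheduler, x0_h, x0_o, cond_h, cond_o, norm_params):
    B = x0_h.shape[0]
    noise = torch.randn_like(x0_h)                      # shared by both branches
    t = torch.randint(0, scheduler.num_train_timesteps, (B,), device=x0_h.device)
    xt_h = scheduler.add_noise(x0_h, noise, t)
    xt_o = scheduler.add_noise(x0_o, noise, t)
    eps_h, eps_o = model(torch.cat([xt_h, cond_h], -1),  # per-point conditioning is
                         torch.cat([xt_o, cond_o], -1),  # concatenated channel-wise
                         t, norm_params)
    return F.mse_loss(eps_h, noise) + F.mse_loss(eps_o, noise)
# -----------------------------------------------------------------------------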
camera=kwargs.get('camera_obj'), + image_rgb=kwargs.get('rgb_obj'), + mask=kwargs.get('mask_obj'), t=timestep) + elif self.point_visible_test == 'combine': + # Combine human + object points to do visibility test and obtain features + B, N = xt_h.shape[:2] # (B, N, 3) + # for human: transform object points first to H+O space, then to human space + xt_o_in_ho = xt_o * 2 * norm_parms[1, :, 3:].unsqueeze(1) + norm_parms[1, :, :3].unsqueeze(1) + xt_o_in_hum = (xt_o_in_ho - norm_parms[0, :, :3].unsqueeze(1)) / (2 * norm_parms[0, :, 3:].unsqueeze(1)) + # compute features for all points, take only first half feature for human + x_t_input_h = self.get_input_with_conditioning(torch.cat([xt_h, xt_o_in_hum], 1), camera=camera, + image_rgb=image_rgb, mask=mask, t=timestep)[:,:N] + # for object: transform human points to H+O space, then to object space + xt_h_in_ho = xt_h * 2 * norm_parms[0, :, 3:].unsqueeze(1) + norm_parms[0, :, :3].unsqueeze(1) + xt_h_in_obj = (xt_h_in_ho - norm_parms[1, :, :3].unsqueeze(1)) / (2 * norm_parms[1, :, 3:].unsqueeze(1)) + x_t_input_o = self.get_input_with_conditioning(torch.cat([xt_o, xt_h_in_obj], 1), + camera=kwargs.get('camera_obj'), + image_rgb=kwargs.get('rgb_obj'), + mask=kwargs.get('mask_obj'), t=timestep)[:, :N] + else: + raise NotImplementedError + return x_t_input_h, x_t_input_o + + def forward(self, batch, mode: str = 'train', **kwargs): + """""" + images = torch.stack(batch['images'], 0).to('cuda') + masks = torch.stack(batch['masks'], 0).to('cuda') + pc = self.get_input_pc(batch) + camera = PerspectiveCameras( + R=torch.stack(batch['R']), + T=torch.stack(batch['T_hum']), + K=torch.stack(batch['K_hum']), + device='cuda', + in_ndc=True + ) + grid_df = torch.stack(batch['grid_df'], 0).to('cuda') if 'grid_df' in batch else None + num_points = kwargs.pop('num_points', get_num_points(pc)) + + rgb_obj = torch.stack(batch['images_obj'], 0).to('cuda') + masks_obj = torch.stack(batch['masks_obj'], 0).to('cuda') + pc_obj = Pointclouds([x.to('cuda') for x in batch['pclouds_obj']]) + camera_obj = PerspectiveCameras( + R=torch.stack(batch['R']), + T=torch.stack(batch['T_obj']), + K=torch.stack(batch['K_obj']), + device='cuda', + in_ndc=True + ) + + # normalization parameters + cent_hum = torch.stack(batch['cent_hum'], 0).to('cuda') + cent_obj = torch.stack(batch['cent_obj'], 0).to('cuda') # B, 3 + radius_hum = torch.stack(batch['radius_hum'], 0).to('cuda') # B, 1 + radius_obj = torch.stack(batch['radius_obj'], 0).to('cuda') + + # print(batch['image_path']) + + if mode == 'train': + return self.forward_train( + pc=pc, + camera=camera, + image_rgb=images, + mask=masks, + grid_df=grid_df, + rgb_obj=rgb_obj, + mask_obj=masks_obj, + pc_obj=pc_obj, + camera_obj=camera_obj, + cent_hum=cent_hum, + cent_obj=cent_obj, + radius_hum=radius_hum, + radius_obj=radius_obj, + ) + elif mode == 'sample': + # this use GT centers to do projection + return self.forward_sample( + num_points=num_points, + camera=camera, + image_rgb=images, + mask=masks, + gt_pc=pc, + rgb_obj=rgb_obj, + mask_obj=masks_obj, + pc_obj=pc_obj, + camera_obj=camera_obj, + cent_hum=cent_hum, + cent_obj=cent_obj, + radius_hum=radius_hum, + radius_obj=radius_obj, + **kwargs) + elif mode == 'interm-gt': + return self.forward_sample( + num_points=num_points, + camera=camera, + image_rgb=images, + mask=masks, + gt_pc=pc, + rgb_obj=rgb_obj, + mask_obj=masks_obj, + pc_obj=pc_obj, + camera_obj=camera_obj, + cent_hum=cent_hum, + cent_obj=cent_obj, + radius_hum=radius_hum, + radius_obj=radius_obj, + sample_from_interm=True, + 
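# --- Added helper sketch (for clarity; not in the original file) -------------
# The 'combine' visibility branch above repeatedly converts points between the
# per-instance normalized frames and the joint human+object (H+O) frame using
# the packed (center, radius) parameters: x_ho = x_local * 2 * radius + cent,
# and the inverse to go back.
import torch

def local_to_joint(x_local: torch.Tensor, cent: torch.Tensor, radius: torch.Tensor):
    """x_local: (B, N, 3), cent: (B, 3), radius: (B, 1) -> points in the joint frame."""
    return x_local * 2 * radius.unsqueeze(1) + cent.unsqueeze(1)

def joint_to_local(x_joint: torch.Tensor, cent: torch.Tensor, radius: torch.Tensor):
    return (x_joint - cent.unsqueeze(1)) / (2 * radius.unsqueeze(1))

# e.g. expressing object-branch points in the human frame before projection:
#   xt_o_in_hum = joint_to_local(local_to_joint(xt_o, cent_obj, radius_obj),
#                                cent_hum, radius_hum)
# -----------------------------------------------------------------------------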
**kwargs) + elif mode == 'interm-pred': + # use camera from predicted + camera = PerspectiveCameras( + R=torch.stack(batch['R']), + T=torch.stack(batch['T_hum_scaled']), + K=torch.stack(batch['K_hum']), + device='cuda', + in_ndc=True + ) + camera_obj = PerspectiveCameras( + R=torch.stack(batch['R']), + T=torch.stack(batch['T_obj_scaled']), + K=torch.stack(batch['K_obj']), # the camera should be human/object specific!!! + device='cuda', + in_ndc=True + ) + # use pc from predicted + pc = Pointclouds([x.to('cuda') for x in batch['pred_hum']]) + pc_obj = Pointclouds([x.to('cuda') for x in batch['pred_obj']]) + # use center and radius from predicted + cent_hum = torch.stack(batch['cent_hum_pred'], 0).to('cuda') + cent_obj = torch.stack(batch['cent_obj_pred'], 0).to('cuda') # B, 3 + radius_hum = torch.stack(batch['radius_hum_pred'], 0).to('cuda') # B, 1 + radius_obj = torch.stack(batch['radius_obj_pred'], 0).to('cuda') + + return self.forward_sample( + num_points=num_points, + camera=camera, + image_rgb=images, + mask=masks, + gt_pc=pc, + rgb_obj=rgb_obj, + mask_obj=masks_obj, + pc_obj=pc_obj, + camera_obj=camera_obj, + cent_hum=cent_hum, + cent_obj=cent_obj, + radius_hum=radius_hum, + radius_obj=radius_obj, + sample_from_interm=True, + **kwargs) + elif mode == 'interm-pred-ts': + # use only estimate translation and scale, but sample from gaussian + # this works, the camera is GT!!! + pc = Pointclouds([x.to('cuda') for x in batch['pred_hum']]) + pc_obj = Pointclouds([x.to('cuda') for x in batch['pred_obj']]) + # use center and radius from predicted + cent_hum = torch.stack(batch['cent_hum_pred'], 0).to('cuda') + cent_obj = torch.stack(batch['cent_obj_pred'], 0).to('cuda') # B, 3 + radius_hum = torch.stack(batch['radius_hum_pred'], 0).to('cuda') # B, 1 + radius_obj = torch.stack(batch['radius_obj_pred'], 0).to('cuda') + # print(cent_hum[0], radius_hum[0], cent_obj[0], radius_obj[0]) + + return self.forward_sample( + num_points=num_points, + camera=camera, + image_rgb=images, + mask=masks, + gt_pc=pc, + rgb_obj=rgb_obj, + mask_obj=masks_obj, + pc_obj=pc_obj, + camera_obj=camera_obj, + cent_hum=cent_hum, + cent_obj=cent_obj, + radius_hum=radius_hum, + radius_obj=radius_obj, + sample_from_interm=False, + **kwargs) + else: + raise NotImplementedError + + def forward_sample( + self, + num_points: int, + camera: Optional[CamerasBase], + image_rgb: Optional[Tensor], + mask: Optional[Tensor], + # Optional overrides + scheduler: Optional[str] = 'ddpm', + # Inference parameters + num_inference_steps: Optional[int] = 1000, + eta: Optional[float] = 0.0, # for DDIM + # Whether to return all the intermediate steps in generation + return_sample_every_n_steps: int = -1, + # Whether to disable tqdm + disable_tqdm: bool = False, + gt_pc: Pointclouds = None, + **kwargs + ): + "use two models to run diffusion forward, and also use translation and scale to put them back" + assert not self.self_conditioning + # Get scheduler from mapping, or use self.scheduler if None + scheduler = self.scheduler if scheduler is None else self.schedulers_map[scheduler] + + # Get the size of the noise + N = num_points + B = 1 if image_rgb is None else image_rgb.shape[0] + D = self.get_x_T_channel() + device = self.device if image_rgb is None else image_rgb.device + + # sample from full steps or only a few steps + sample_from_interm = kwargs.get('sample_from_interm', False) + interm_steps = kwargs.get('noise_step') if sample_from_interm else -1 + + xt_h = self.initialize_x_T(device, gt_pc, (B, N, D), interm_steps, scheduler) + xt_o = 
self.initialize_x_T(device, kwargs.get('pc_obj', None), (B, N, D), interm_steps, scheduler) + + # the segmentation mask + segm_mask = torch.zeros(B, 2*N, 1).to(device) + segm_mask[:, :N] = 1.0 + + # Set timesteps + extra_step_kwargs = self.setup_reverse_process(eta, num_inference_steps, scheduler) + + # Loop over timesteps + all_outputs = [] + return_all_outputs = (return_sample_every_n_steps > 0) + progress_bar = tqdm(self.get_reverse_timesteps(scheduler, interm_steps), + desc=f'Sampling ({xt_h.shape})', disable=disable_tqdm) + + # print("Camera T:", camera.T[0], camera.R[0]) + # print("Camera_obj T:", kwargs.get('camera_obj').T[0], kwargs.get('camera_obj').R[0]) + + norm_parms = self.pack_norm_params(kwargs) + for i, t in enumerate(progress_bar): + x_t_input_h, x_t_input_o = self.get_image_conditioning(camera, image_rgb, + kwargs, mask, + norm_parms, + t, + xt_h, xt_o) + + # One reverse step with conditioning + xt_h, xt_o = self.reverse_step(extra_step_kwargs, scheduler, t, torch.stack([xt_h, xt_o], 0), + torch.stack([x_t_input_h, x_t_input_o], 0), **kwargs) # (B, N, D), D=3 + + if (return_all_outputs and (i % return_sample_every_n_steps == 0 or i == len(scheduler.timesteps) - 1)): + # print(xt_h.shape, kwargs.get('cent_hum').shape, kwargs.get('radius_hum').shape) + x_t = torch.cat([self.denormalize_pclouds(xt_h, kwargs.get('cent_hum'), kwargs.get('radius_hum')), + self.denormalize_pclouds(xt_o, kwargs.get('cent_obj'), kwargs.get('radius_obj'))], 1) + # print(x_t.shape, xt_o.shape) + all_outputs.append(torch.cat([x_t, segm_mask], -1)) + # print("Updating intermediate...") + + # Convert output back into a point cloud, undoing normalization and scaling + x_t = torch.cat([self.denormalize_pclouds(xt_h, kwargs.get('cent_hum'), kwargs.get('radius_hum')), + self.denormalize_pclouds(xt_o, kwargs.get('cent_obj'), kwargs.get('radius_obj'))], 1) + x_t = torch.cat([x_t, segm_mask], -1) + output = self.tensor_to_point_cloud(x_t, denormalize=False, unscale=False) # this convert the points back to original scale + if return_all_outputs: + all_outputs = torch.stack(all_outputs, dim=1) # (B, sample_steps, N, D) + all_outputs = [self.tensor_to_point_cloud(o, denormalize=False, unscale=False) for o in all_outputs] + + return (output, all_outputs) if return_all_outputs else output + + def get_reverse_timesteps(self, scheduler, interm_steps:int): + """ + + :param scheduler: + :param interm_steps: start from some intermediate steps + :return: + """ + if interm_steps > 0: + timesteps = torch.from_numpy(np.arange(0, interm_steps)[::-1].copy()).to(self.device) + else: + timesteps = scheduler.timesteps.to(self.device) + return timesteps + + def pack_norm_params(self, kwargs:dict, scale=True): + scale_factor = self.scale_factor if scale else 1.0 + hum = torch.cat([kwargs.get('cent_hum')*scale_factor, kwargs.get('radius_hum')], -1) + obj = torch.cat([kwargs.get('cent_obj')*scale_factor, kwargs.get('radius_obj')], -1) + return torch.stack([hum, obj], 0) # (2, B, 4) + + def reverse_step(self, extra_step_kwargs, scheduler, t, x_t, x_t_input, **kwargs): + "x_t: (2, B, D, N), x_t_input: (2, B, D, N)" + norm_parms = self.pack_norm_params(kwargs) # (2, B, 4) + B = x_t.shape[1] + # print(f"Step {t} Norm params:", norm_parms[:, 0, :]) + noise_pred_h, noise_pred_o = self.point_cloud_model(x_t_input[0], x_t_input[1], t.reshape(1).expand(B), + norm_parms) + if self.consistent_center: + assert self.dm_pred_type != 'sample', 'incompatible dm predition type!' 
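# --- Added note (hedged sketch, not the original implementation) -------------
# When sample_from_interm is set, forward_sample above does not start from pure
# Gaussian noise: a coarse prediction is re-noised up to `noise_step` and the
# reverse loop only covers those last steps (get_reverse_timesteps above then
# returns noise_step-1, ..., 0). Roughly:
import torch

def start_from_intermediate(scheduler, x0_pred, noise_step):
    """x0_pred: (B, N, 3) coarse points, already normalized and scaled."""
    t = torch.full((x0_pred.shape[0],), noise_step - 1,
                   dtype=torch.long, device=x0_pred.device)
    return scheduler.add_noise(x0_pred, torch.randn_like(x0_pred), t)
# -----------------------------------------------------------------------------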
+ noise_pred_h = noise_pred_h - torch.mean(noise_pred_h, dim=1, keepdim=True) + noise_pred_o = noise_pred_o - torch.mean(noise_pred_o, dim=1, keepdim=True) + + xt_h = scheduler.step(noise_pred_h, t, x_t[0], **extra_step_kwargs).prev_sample + xt_o = scheduler.step(noise_pred_o, t, x_t[1], **extra_step_kwargs).prev_sample + + if self.consistent_center: + xt_h = xt_h - torch.mean(xt_h, dim=1, keepdim=True) + xt_o = xt_o - torch.mean(xt_o, dim=1, keepdim=True) + + return xt_h, xt_o + + def denormalize_pclouds(self, x: Tensor, cent, radius, unscale: bool = True): + """ + first denormalize, then apply center and scale to original H+O coordinate + :param x: + :param cent: (B, 3) + :param radius: (B, 1) + :param unscale: + :return: + """ + # denormalize: scale down. + points = x[:, :, :3] / (self.scale_factor if unscale else 1) + # translation and scale back to H+O coordinate + points = points * 2 * radius.unsqueeze(-1) + cent.unsqueeze(1) + return points + + def tensor_to_point_cloud(self, x: Tensor, /, denormalize: bool = False, unscale: bool = False): + """ + take binary into account + :param self: + :param x: (B, N, 4) + :param denormalize: + :param unscale: + :return: + """ + points = x[:, :, :3] / (self.scale_factor if unscale else 1) + if self.predict_color: + colors = self.denormalize(x[:, :, 3:]) if denormalize else x[:, :, 3:] + return Pointclouds(points=points, features=colors) + else: + assert x.shape[2] == 4 + # add color to predicted binary labels + is_hum = x[:, :, 3] > 0.5 + features = [] + for mask in is_hum: + color = torch.zeros_like(x[0, :, :3]) + torch.tensor([0.5, 1.0, 0]).to(x.device) + color[mask, :] = torch.tensor([0.05, 1.0, 1.0]).to(x.device) # human is light blue, object light green + features.append(color) + return Pointclouds(points=points, features=features) + + diff --git a/model/model_utils.py b/model/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..416e6a9928e6cd6f1dc39becbf40b0538f706fb8 --- /dev/null +++ b/model/model_utils.py @@ -0,0 +1,58 @@ +import cv2 +import numpy as np +import torch +import torch.nn as nn +from pytorch3d.structures import Pointclouds + + +def set_requires_grad(module: nn.Module, requires_grad: bool): + for p in module.parameters(): + p.requires_grad_(requires_grad) + + +def compute_distance_transform(mask: torch.Tensor): + """ + + Parameters + ---------- + mask (B, 1, H, W) or (B, 2, H, W) true for foreground + + Returns + ------- + the vector to the closest foreground pixel, zero if inside mask + + """ + C = mask.shape[1] + assert C in [1, 2], f'invalid mask shape {mask.shape} found!' 
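# --- Added toy example (illustrative; the real implementation continues below) ---
# What compute_distance_transform produces per mask channel: for every
# background pixel, the distance to the nearest foreground pixel, divided by
# (image_size / 2) and clipped to [0, 1]; foreground pixels are 0.
import cv2
import numpy as np
import torch

toy = np.zeros((224, 224), dtype=np.uint8)
toy[80:140, 90:150] = 1                                   # a square "foreground" mask
dt = cv2.distanceTransform(1 - toy, distanceType=cv2.DIST_L2, maskSize=cv2.DIST_MASK_3)
dt = torch.from_numpy(dt / (224 / 2)).clamp(0, 1)         # (224, 224), values in [0, 1]
# ---------------------------------------------------------------------------------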
+ + image_size = mask.shape[-1] + + dts = [] + for i in range(C): + distance_transform = torch.stack([ + torch.from_numpy(cv2.distanceTransform( + (1 - m), distanceType=cv2.DIST_L2, maskSize=cv2.DIST_MASK_3 + ) / (image_size / 2)) + for m in mask[:, i:i+1].squeeze(1).detach().cpu().numpy().astype(np.uint8) + ]).unsqueeze(1).clip(0, 1).to(mask.device) + dts.append(distance_transform) + return torch.cat(dts, 1) + + +def default(x, d): + return d if x is None else x + + +def get_num_points(x: Pointclouds, /): + return x.points_padded().shape[1] + + +def get_custom_betas(beta_start: float, beta_end: float, warmup_frac: float = 0.3, num_train_timesteps: int = 1000): + """Custom beta schedule""" + betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + warmup_frac = 0.3 + warmup_time = int(num_train_timesteps * warmup_frac) + warmup_steps = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + warmup_time = min(warmup_time, num_train_timesteps) + betas[:warmup_time] = warmup_steps[:warmup_time] + return betas diff --git a/model/point_cloud_model.py b/model/point_cloud_model.py new file mode 100644 index 0000000000000000000000000000000000000000..8e9617ced94e167ebbe3937cc16c207c87b5e7e9 --- /dev/null +++ b/model/point_cloud_model.py @@ -0,0 +1,67 @@ +from contextlib import nullcontext + +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers import ModelMixin +from torch import Tensor + +from .pvcnn.pvcnn import PVCNN2 +from .pvcnn.pvcnn_plus_plus import PVCNN2PlusPlus +from .simple.simple_model import SimplePointModel + + +class PointCloudModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + model_type: str = 'pvcnn', + in_channels: int = 3, + out_channels: int = 3, + embed_dim: int = 64, + dropout: float = 0.1, + width_multiplier: int = 1, + voxel_resolution_multiplier: int = 1, + ): + super().__init__() + self.model_type = model_type + if self.model_type == 'pvcnn': + self.autocast_context = torch.autocast('cuda', dtype=torch.float32) + self.model = PVCNN2( + embed_dim=embed_dim, + num_classes=out_channels, + extra_feature_channels=(in_channels - 3), + dropout=dropout, width_multiplier=width_multiplier, + voxel_resolution_multiplier=voxel_resolution_multiplier + ) + self.model.classifier[-1].bias.data.normal_(0, 1e-6) + self.model.classifier[-1].weight.data.normal_(0, 1e-6) + elif self.model_type == 'pvcnnplusplus': + self.autocast_context = torch.autocast('cuda', dtype=torch.float32) + self.model = PVCNN2PlusPlus( + embed_dim=embed_dim, + num_classes=out_channels, + extra_feature_channels=(in_channels - 3), + ) + self.model.output_projection[-1].bias.data.normal_(0, 1e-6) + self.model.output_projection[-1].weight.data.normal_(0, 1e-6) + elif self.model_type == 'simple': + self.autocast_context = nullcontext() + self.model = SimplePointModel( + embed_dim=embed_dim, + num_classes=out_channels, + extra_feature_channels=(in_channels - 3), + ) + self.model.output_projection.bias.data.normal_(0, 1e-6) + self.model.output_projection.weight.data.normal_(0, 1e-6) + else: + raise NotImplementedError() + + def forward(self, inputs: Tensor, t: Tensor, ret_feats=False) -> Tensor: + """ Receives input of shape (B, N, in_channels) and returns output + of shape (B, N, out_channels) """ + with self.autocast_context: + if not ret_feats: + return self.model(inputs.transpose(1, 2), t, ret_feats=False).transpose(1, 2) + else: + pred, feats = self.model(inputs.transpose(1, 2), t, ret_feats=True) + 
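# --- Added illustration (standalone; not in the source) ----------------------
# The warmup schedule built by get_custom_betas above: the first ~30% of the
# entries ramp linearly from beta_start all the way to beta_end, after which
# the values fall back to the plain linear schedule and continue from there.
# Note that the function body re-assigns warmup_frac = 0.3, so the warmup_frac
# argument is effectively fixed as written.
import numpy as np

beta_start, beta_end, T = 1e-5, 8e-3, 1000     # example values, not the project defaults
betas = np.linspace(beta_start, beta_end, T, dtype=np.float32)
warmup_time = int(T * 0.3)
betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
# -----------------------------------------------------------------------------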
return pred.transpose(1, 2), feats \ No newline at end of file diff --git a/model/point_cloud_transformer_model.py b/model/point_cloud_transformer_model.py new file mode 100644 index 0000000000000000000000000000000000000000..54621629f7ee38e31823c1b5d23ca8d3c0fed2e0 --- /dev/null +++ b/model/point_cloud_transformer_model.py @@ -0,0 +1,80 @@ +from typing import Optional + +import torch +import torch.nn as nn +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers import ModelMixin +from torch import Tensor +from timm.models.vision_transformer import Attention, LayerScale, DropPath, Mlp + +from .point_cloud_model import PointCloudModel + + +class PointCloudModelBlock(nn.Module): + + def __init__( + self, + *, + # Point cloud model + dim: int, + model_type: str = 'pvcnn', + dropout: float = 0.1, + width_multiplier: int = 1, + voxel_resolution_multiplier: int = 1, + # Transformer model + num_heads=6, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_attn=False + ): + super().__init__() + + # Point cloud model + self.norm0 = norm_layer(dim) + self.point_cloud_model = PointCloudModel(model_type=model_type, + in_channels=dim, out_channels=dim, embed_dim=dim, dropout=dropout, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier) + self.ls0 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path0 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + # Attention + self.use_attn = use_attn + if self.use_attn: + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + # MLP + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def apply_point_cloud_model(self, x: Tensor, t: Optional[Tensor] = None) -> Tensor: + t = t if t is not None else torch.zeros(len(x), device=x.device, dtype=torch.long) + return self.point_cloud_model(x, t) + + def forward(self, x: Tensor): + x = x + self.drop_path0(self.ls0(self.apply_point_cloud_model(self.norm0(x)))) + if self.use_attn: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class PointCloudTransformerModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__(self, num_layers: int, in_channels: int = 3, out_channels: int = 3, embed_dim: int = 64, **kwargs): + super().__init__() + self.num_layers = num_layers + self.input_projection = nn.Linear(in_channels, embed_dim) + self.blocks = nn.Sequential(*[PointCloudModelBlock(dim=embed_dim, **kwargs) for i in range(self.num_layers)]) + self.norm = nn.LayerNorm(embed_dim) + self.output_projection = nn.Linear(embed_dim, out_channels) + + def forward(self, inputs: Tensor) -> Tensor: + """ Receives input of shape (B, N, in_channels) and returns output + of shape (B, N, out_channels) """ + x = self.input_projection(inputs) + x = self.blocks(x) + x = self.output_projection(x) + return x diff --git a/model/projection_model.py b/model/projection_model.py new file mode 100644 index 0000000000000000000000000000000000000000..78450d69c034f61de60f0cb0b3185d175035b69e --- /dev/null +++ b/model/projection_model.py @@ -0,0 +1,273 @@ +from typing import Optional, Union + +import torch +from diffusers.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler +from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler +from diffusers import ModelMixin +from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData +from pytorch3d.renderer import PointsRasterizationSettings, PointsRasterizer +from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.structures import Pointclouds +from torch import Tensor + +from .feature_model import FeatureModel +from .model_utils import compute_distance_transform + +SchedulerClass = Union[DDPMScheduler, DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + + +class PointCloudProjectionModel(ModelMixin): + + def __init__( + self, + image_size: int, + image_feature_model: str, + use_local_colors: bool = True, + use_local_features: bool = True, + use_global_features: bool = False, + use_mask: bool = True, + use_distance_transform: bool = True, + predict_shape: bool = True, + predict_color: bool = False, + process_color: bool = False, + image_color_channels: int = 3, # for the input image, not the points + color_channels: int = 3, # for the points, not the input image + colors_mean: float = 0.5, + colors_std: float = 0.5, + scale_factor: float = 1.0, + # Rasterization settings + raster_point_radius: float = 0.0075, # point size + raster_points_per_pixel: int = 1, # a single point per pixel, for now + bin_size: int = 0, + model_name=None, + # additional arguments added by XH + load_sample_init=False, + sample_init_scale=1.0, + test_init_with_gtpc=False, + consistent_center=False, # from https://arxiv.org/pdf/2308.07837.pdf + voxel_resolution_multiplier: int=1, + predict_binary: bool=False, # predict a binary class label + lw_binary: float=1.0, + binary_training_noise_std: float=0.1, + dm_pred_type: str='epsilon', # diffusion prediction type + self_conditioning=False, + **kwargs, + + ): + super().__init__() + self.image_size = image_size + self.scale_factor 
= scale_factor + self.use_local_colors = use_local_colors + self.use_local_features = use_local_features + self.use_global_features = use_global_features + self.use_mask = use_mask + self.use_distance_transform = use_distance_transform + self.predict_shape = predict_shape # default False + self.predict_color = predict_color # default True + self.process_color = process_color + self.image_color_channels = image_color_channels + self.color_channels = color_channels + self.colors_mean = colors_mean + self.colors_std = colors_std + self.model_name = model_name + print("PointCloud Model scale factor:", self.scale_factor, 'Model name:', self.model_name) + self.predict_binary = predict_binary + self.lw_binary = lw_binary + self.self_conditioning = self_conditioning + + # Types of conditioning that are used + self.use_local_conditioning = self.use_local_colors or self.use_local_features or self.use_mask + self.use_global_conditioning = self.use_global_features + self.kwargs = kwargs + + # Create feature model + self.feature_model = FeatureModel(image_size, image_feature_model) + + # Input size + self.in_channels = 3 # 3 for 3D point positions + if self.use_local_colors: # whether color should be an input + self.in_channels += self.image_color_channels + if self.use_local_features: + self.in_channels += self.feature_model.feature_dim + if self.use_global_features: + self.in_channels += self.feature_model.feature_dim + if self.use_mask: + self.in_channels += 2 if self.use_distance_transform else 1 + if self.process_color: + self.in_channels += self.color_channels # point color added to input or not, default False + if self.self_conditioning: + self.in_channels += 3 # add self conditioning + + self.in_channels = self.add_extra_input_chennels(self.in_channels) + + if self.model_name in ['pc2-diff-ho-sepsegm', 'diff-ho-attn']: + self.in_channels += 2 if self.use_distance_transform else 1 + + # Output size + self.out_channels = 0 + if self.predict_shape: + self.out_channels += 3 + if self.predict_color: + self.out_channels += self.color_channels + if self.predict_binary: + print("Output binary classification score!") + self.out_channels += 1 + + # Save rasterization settings + self.raster_settings = PointsRasterizationSettings( + image_size=(image_size, image_size), + radius=raster_point_radius, + points_per_pixel=raster_points_per_pixel, + bin_size=bin_size, + ) + + def add_extra_input_chennels(self, input_channels): + return input_channels + + def denormalize(self, x: Tensor, /, clamp: bool = True): + x = x * self.colors_std + self.colors_mean + return torch.clamp(x, 0, 1) if clamp else x + + def normalize(self, x: Tensor, /): + x = (x - self.colors_mean) / self.colors_std + return x + + def get_global_conditioning(self, image_rgb: Tensor): + global_conditioning = [] + if self.use_global_features: + global_conditioning.append(self.feature_model(image_rgb, + return_cls_token_only=True)) # (B, D) + global_conditioning = torch.cat(global_conditioning, dim=1) # (B, D_cond) + return global_conditioning + + def get_local_conditioning(self, image_rgb: Tensor, mask: Tensor): + """ + compute per-point conditioning + Parameters + ---------- + image_rgb: (B, 3, 224, 224), values normalized to 0-1, background is masked by the given mask + mask: (B, 1, 224, 224), or (B, 2, 224, 224) for h+o + """ + local_conditioning = [] + # import pdb; pdb.set_trace() + + if self.use_local_colors: # XH: default True + local_conditioning.append(self.normalize(image_rgb)) + if self.use_local_features: # XH: default True + 
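# --- Added bookkeeping check (assumed numbers, not from the source) -----------
# Per-point input width implied by the default flags above, assuming a 384-dim
# feature from the ViT-S/16 backbone ('vit_small_patch16_224_mae'):
in_channels_example = 3 + 3 + 384 + 2   # xyz + local RGB + ViT feature + (mask, distance transform)
# -> 392 for the single-branch model; the separate-segmentation and attention
#    variants append 2 more mask / distance-transform channels on top of that.
# ------------------------------------------------------------------------------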
local_conditioning.append(self.feature_model(image_rgb)) # I guess no mask here? feature model: 'vit_small_patch16_224_mae' + if self.use_mask: # default True + local_conditioning.append(mask.float()) + if self.use_distance_transform: # default True + if not self.use_mask: + raise ValueError('No mask for distance transform?') + if mask.is_floating_point(): + mask = mask > 0.5 + local_conditioning.append(compute_distance_transform(mask)) + local_conditioning = torch.cat(local_conditioning, dim=1) # (B, D_cond, H, W) + return local_conditioning + + @torch.autocast('cuda', dtype=torch.float32) + def surface_projection( + self, points: Tensor, camera: CamerasBase, local_features: Tensor, + ): + B, C, H, W, device = *local_features.shape, local_features.device + R = self.raster_settings.points_per_pixel + N = points.shape[1] + + # Scale camera by scaling T. ASSUMES CAMERA IS LOOKING AT ORIGIN! + camera = camera.clone() + camera.T = camera.T * self.scale_factor + + # Create rasterizer + rasterizer = PointsRasterizer(cameras=camera, raster_settings=self.raster_settings) + + # Associate points with features via rasterization + fragments = rasterizer(Pointclouds(points)) # (B, H, W, R) + fragments_idx: Tensor = fragments.idx.long() + visible_pixels = (fragments_idx > -1) # (B, H, W, R) + points_to_visible_pixels = fragments_idx[visible_pixels] + + # Reshape local features to (B, H, W, R, C) + local_features = local_features.permute(0, 2, 3, 1).unsqueeze(-2).expand(-1, -1, -1, R, -1) # (B, H, W, R, C) + + # Get local features corresponding to visible points + local_features_proj = torch.zeros(B * N, C, device=device) + # local feature includes: raw RGB color, image features, mask, distance transform + local_features_proj[points_to_visible_pixels] = local_features[visible_pixels] + local_features_proj = local_features_proj.reshape(B, N, C) + + return local_features_proj + + def point_cloud_to_tensor(self, pc: Pointclouds, /, normalize: bool = False, scale: bool = False): + """Converts a point cloud to a tensor, with color if and only if self.predict_color""" + points = pc.points_padded() * (self.scale_factor if scale else 1) + if self.predict_color and pc.features_padded() is not None: # normalize color, not point locations + colors = self.normalize(pc.features_padded()) if normalize else pc.features_padded() + return torch.cat((points, colors), dim=2) + else: + return points + + def tensor_to_point_cloud(self, x: Tensor, /, denormalize: bool = False, unscale: bool = False): + points = x[:, :, :3] / (self.scale_factor if unscale else 1) + if self.predict_color: + colors = self.denormalize(x[:, :, 3:]) if denormalize else x[:, :, 3:] + return Pointclouds(points=points, features=colors) + else: + assert x.shape[2] == 3 + return Pointclouds(points=points) + + def get_input_with_conditioning( + self, + x_t: Tensor, + camera: Optional[CamerasBase], + image_rgb: Optional[Tensor], + mask: Optional[Tensor], + t: Optional[Tensor], + ): + """ Extracts local features from the input image and projects them onto the points + in the point cloud to obtain the input to the model. Then extracts global + features, replicates them across points, and concats them to the input. + image_rgb: masked background + XH: why there is no positional encoding as described by the supp?? 
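# --- Added sanity check (extra, hedged; not in the source) --------------------
# The camera-scaling trick in surface_projection above: the diffusion model
# works on points multiplied by scale_factor, and instead of unscaling the
# points before rasterization the camera translation is multiplied by the same
# factor. For a camera looking at the origin, projecting s*X with translation
# s*T gives the same NDC xy as projecting X with T:
import torch
from pytorch3d.renderer import PerspectiveCameras

s = 3.0
pts = torch.rand(1, 64, 3) - 0.5
cam = PerspectiveCameras(T=torch.tensor([[0.0, 0.0, 2.0]]))
cam_scaled = PerspectiveCameras(T=torch.tensor([[0.0, 0.0, 2.0]]) * s)
xy = cam.transform_points(pts)[..., :2]
xy_scaled = cam_scaled.transform_points(pts * s)[..., :2]
assert torch.allclose(xy, xy_scaled, atol=1e-5)
# ------------------------------------------------------------------------------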
+ """ + B, N = x_t.shape[:2] + + # Initial input is the point locations (and colors if and only if predicting color) + x_t_input = self.get_coord_feature(x_t) + + # Local conditioning + if self.use_local_conditioning: + + # Get local features and check that they are the same size as the input image + local_features = self.get_local_conditioning(image_rgb=image_rgb, mask=mask) # concatenate RGB + mask + RGB feature + distance transform + if local_features.shape[-2:] != image_rgb.shape[-2:]: + raise ValueError(f'{local_features.shape=} and {image_rgb.shape=}') + + # Project local features. Here that we only need the point locations, not colors + local_features_proj = self.surface_projection(points=x_t[:, :, :3], + camera=camera, local_features=local_features) # (B, N, D_local) + + x_t_input.append(local_features_proj) + + # Global conditioning + if self.use_global_conditioning: # False + + # Get and repeat global features + global_features = self.get_global_conditioning(image_rgb=image_rgb) # (B, D_global) + global_features = global_features.unsqueeze(1).expand(-1, N, -1) # (B, D_global, N) + + x_t_input.append(global_features) + + # Concatenate together all the pointwise features + x_t_input = torch.cat(x_t_input, dim=2) # (B, N, D) + + return x_t_input + + def get_coord_feature(self, x_t): + """get coordinate feature, for model that uses separate model to predict binary, we use first 3 channels only""" + x_t_input = [x_t] + return x_t_input + + def forward(self, batch: FrameData, mode: str = 'train', **kwargs): + """ The forward method may be defined differently for different models. """ + raise NotImplementedError() diff --git a/model/pvcnn/__init__.py b/model/pvcnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/pvcnn/modules/__init__.py b/model/pvcnn/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3975207ca20b36c9fcd5f4c4c137eae4852d6b9a --- /dev/null +++ b/model/pvcnn/modules/__init__.py @@ -0,0 +1,8 @@ +from .ball_query import BallQuery, BallQueryHO +from .frustum import FrustumPointNetLoss +from .loss import KLLoss +from .pointnet import PointNetAModule, PointNetSAModule, PointNetFPModule +from .pvconv import PVConv, Attention, Swish, PVConvReLU +from .se import SE3d +from .shared_mlp import SharedMLP +from .voxelization import Voxelization diff --git a/model/pvcnn/modules/ball_query.py b/model/pvcnn/modules/ball_query.py new file mode 100644 index 0000000000000000000000000000000000000000..4ace370968cc273c2f310d8c6200d2b9a0a8bea4 --- /dev/null +++ b/model/pvcnn/modules/ball_query.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn + +from . 
import functional as F + +__all__ = ['BallQuery'] + + +class BallQuery(nn.Module): + def __init__(self, radius, num_neighbors, include_coordinates=True): + super().__init__() + self.radius = radius + self.num_neighbors = num_neighbors + self.include_coordinates = include_coordinates + + def forward(self, points_coords, centers_coords, temb, points_features=None): + points_coords = points_coords.contiguous() + centers_coords = centers_coords.contiguous() + neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors) + neighbor_coordinates = F.grouping(points_coords, neighbor_indices) + neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1) + + if points_features is None: + assert self.include_coordinates, 'No Features For Grouping' + neighbor_features = neighbor_coordinates + else: + neighbor_features = F.grouping(points_features, neighbor_indices) # return [B, C, M, U] C=feat dim, M=# centers, U=# neighbours + if self.include_coordinates: + neighbor_features = torch.cat([neighbor_coordinates, neighbor_features], dim=1) + return neighbor_features, F.grouping(temb, neighbor_indices) + + def extra_repr(self): + return 'radius={}, num_neighbors={}{}'.format( + self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '') + + +class BallQueryHO(nn.Module): + "no point feature, but only relative and abs coordinate" + def __init__(self, radius, num_neighbors, include_relative=False): + super().__init__() + self.radius = radius + self.num_neighbors = num_neighbors + self.include_relative = include_relative + + def forward(self, points_coords, centers_coords, points_features=None): + """ + if not enough points inside the given radius, the entries will be zero + if too many points inside the radius, the order is random??? (not sure) + :param points_coords: (B, 3, N) + :param centers_coords: (B, 3, M) + :param points_features: None + :return: + """ + points_coords = points_coords.contiguous() + centers_coords = centers_coords.contiguous() + neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors) + neighbor_coordinates = F.grouping(points_coords, neighbor_indices) # (B, 3, M, U) + if self.include_relative: + neighbor_coordinates_rela = neighbor_coordinates - centers_coords.unsqueeze(-1) + neighbor_coordinates = torch.cat([neighbor_coordinates, neighbor_coordinates_rela], 1) # (B, 6, M, U) + # flatten the coordinate + neighbor_coordinates = neighbor_coordinates.permute(0, 1, 3, 2) # (B, 3/6, U, M) + neighbor_coordinates = torch.flatten(neighbor_coordinates, 1, 2) # (B, 3*U, M) + return neighbor_coordinates + + def extra_repr(self): + return 'radius={}, num_neighbors={}{}'.format( + self.radius, self.num_neighbors, ', include relative' if self.include_relative else '') + diff --git a/model/pvcnn/modules/frustum.py b/model/pvcnn/modules/frustum.py new file mode 100644 index 0000000000000000000000000000000000000000..72c1ba870e024079081e308cad7a8521de952bc3 --- /dev/null +++ b/model/pvcnn/modules/frustum.py @@ -0,0 +1,138 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from . 
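# --- Added reference (slow, illustration only; not the repository kernel) -----
# Plain-PyTorch equivalent of what the CUDA ball_query op used above computes:
# for every center, the indices of up to `num_neighbors` points within
# `radius`; unfilled slots are padded with the first neighbour found (or stay 0
# if none is found).
import torch

def ball_query_ref(centers, points, radius, num_neighbors):
    """centers: (B, 3, M), points: (B, 3, N) -> LongTensor (B, M, num_neighbors)."""
    d2 = torch.cdist(centers.transpose(1, 2), points.transpose(1, 2)) ** 2   # (B, M, N)
    within = d2 < radius ** 2
    B, M = within.shape[:2]
    idx = torch.zeros(B, M, num_neighbors, dtype=torch.long, device=centers.device)
    for b in range(B):
        for m in range(M):
            cand = within[b, m].nonzero(as_tuple=False).flatten()[:num_neighbors]
            if cand.numel() > 0:
                idx[b, m, :cand.numel()] = cand
                idx[b, m, cand.numel():] = cand[0]
    return idx
# ------------------------------------------------------------------------------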
import functional as F + +__all__ = ['FrustumPointNetLoss', 'get_box_corners_3d'] + + +class FrustumPointNetLoss(nn.Module): + def __init__(self, num_heading_angle_bins, num_size_templates, size_templates, box_loss_weight=1.0, + corners_loss_weight=10.0, heading_residual_loss_weight=20.0, size_residual_loss_weight=20.0): + super().__init__() + self.box_loss_weight = box_loss_weight + self.corners_loss_weight = corners_loss_weight + self.heading_residual_loss_weight = heading_residual_loss_weight + self.size_residual_loss_weight = size_residual_loss_weight + + self.num_heading_angle_bins = num_heading_angle_bins + self.num_size_templates = num_size_templates + self.register_buffer('size_templates', size_templates.view(self.num_size_templates, 3)) + self.register_buffer( + 'heading_angle_bin_centers', torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins) + ) + + def forward(self, inputs, targets): + mask_logits = inputs['mask_logits'] # (B, 2, N) + center_reg = inputs['center_reg'] # (B, 3) + center = inputs['center'] # (B, 3) + heading_scores = inputs['heading_scores'] # (B, NH) + heading_residuals_normalized = inputs['heading_residuals_normalized'] # (B, NH) + heading_residuals = inputs['heading_residuals'] # (B, NH) + size_scores = inputs['size_scores'] # (B, NS) + size_residuals_normalized = inputs['size_residuals_normalized'] # (B, NS, 3) + size_residuals = inputs['size_residuals'] # (B, NS, 3) + + mask_logits_target = targets['mask_logits'] # (B, N) + center_target = targets['center'] # (B, 3) + heading_bin_id_target = targets['heading_bin_id'] # (B, ) + heading_residual_target = targets['heading_residual'] # (B, ) + size_template_id_target = targets['size_template_id'] # (B, ) + size_residual_target = targets['size_residual'] # (B, 3) + + batch_size = center.size(0) + batch_id = torch.arange(batch_size, device=center.device) + + # Basic Classification and Regression losses + mask_loss = F.cross_entropy(mask_logits, mask_logits_target) + heading_loss = F.cross_entropy(heading_scores, heading_bin_id_target) + size_loss = F.cross_entropy(size_scores, size_template_id_target) + center_loss = PF.huber_loss(torch.norm(center_target - center, dim=-1), delta=2.0) + center_reg_loss = PF.huber_loss(torch.norm(center_target - center_reg, dim=-1), delta=1.0) + + # Refinement losses for size/heading + heading_residuals_normalized = heading_residuals_normalized[batch_id, heading_bin_id_target] # (B, ) + heading_residual_normalized_target = heading_residual_target / (np.pi / self.num_heading_angle_bins) + heading_residual_normalized_loss = PF.huber_loss( + heading_residuals_normalized - heading_residual_normalized_target, delta=1.0 + ) + size_residuals_normalized = size_residuals_normalized[batch_id, size_template_id_target] # (B, 3) + size_residual_normalized_target = size_residual_target / self.size_templates[size_template_id_target] + size_residual_normalized_loss = PF.huber_loss( + torch.norm(size_residual_normalized_target - size_residuals_normalized, dim=-1), delta=1.0 + ) + + # Bounding box losses + heading = (heading_residuals[batch_id, heading_bin_id_target] + + self.heading_angle_bin_centers[heading_bin_id_target]) # (B, ) + # Warning: in origin code, size_residuals are added twice (issue #43 and #49 in charlesq34/frustum-pointnets) + size = (size_residuals[batch_id, size_template_id_target] + + self.size_templates[size_template_id_target]) # (B, 3) + corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False) # (B, 3, 8) + heading_target = 
self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target # (B, ) + size_target = self.size_templates[size_template_id_target] + size_residual_target # (B, 3) + corners_target, corners_target_flip = get_box_corners_3d(centers=center_target, headings=heading_target, + sizes=size_target, with_flip=True) # (B, 3, 8) + corners_loss = PF.huber_loss(torch.min( + torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1) + ), delta=1.0) + # Summing up + loss = mask_loss + self.box_loss_weight * ( + center_loss + center_reg_loss + heading_loss + size_loss + + self.heading_residual_loss_weight * heading_residual_normalized_loss + + self.size_residual_loss_weight * size_residual_normalized_loss + + self.corners_loss_weight * corners_loss + ) + + return loss + + +def get_box_corners_3d(centers, headings, sizes, with_flip=False): + """ + :param centers: coords of box centers, FloatTensor[N, 3] + :param headings: heading angles, FloatTensor[N, ] + :param sizes: box sizes, FloatTensor[N, 3] + :param with_flip: bool, whether to return flipped box (headings + np.pi) + :return: + coords of box corners, FloatTensor[N, 3, 8] + NOTE: corner points are in counter clockwise order, e.g., + 2--1 + 3--0 5 + 7--4 + """ + l = sizes[:, 0] # (N,) + w = sizes[:, 1] # (N,) + h = sizes[:, 2] # (N,) + x_corners = torch.stack([l/2, l/2, -l/2, -l/2, l/2, l/2, -l/2, -l/2], dim=1) # (N, 8) + y_corners = torch.stack([h/2, h/2, h/2, h/2, -h/2, -h/2, -h/2, -h/2], dim=1) # (N, 8) + z_corners = torch.stack([w/2, -w/2, -w/2, w/2, w/2, -w/2, -w/2, w/2], dim=1) # (N, 8) + + c = torch.cos(headings) # (N,) + s = torch.sin(headings) # (N,) + o = torch.ones_like(headings) # (N,) + z = torch.zeros_like(headings) # (N,) + + centers = centers.unsqueeze(-1) # (B, 3, 1) + corners = torch.stack([x_corners, y_corners, z_corners], dim=1) # (N, 3, 8) + R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3) # roty matrix: (N, 3, 3) + if with_flip: + R_flip = torch.stack([-c, z, -s, z, o, z, s, z, -c], dim=1).view(-1, 3, 3) + return torch.matmul(R, corners) + centers, torch.matmul(R_flip, corners) + centers + else: + return torch.matmul(R, corners) + centers + + # centers = centers.unsqueeze(1) # (B, 1, 3) + # corners = torch.stack([x_corners, y_corners, z_corners], dim=-1) # (N, 8, 3) + # RT = torch.stack([c, z, -s, z, o, z, s, z, c], dim=1).view(-1, 3, 3) # (N, 3, 3) + # if with_flip: + # RT_flip = torch.stack([-c, z, s, z, o, z, -s, z, -c], dim=1).view(-1, 3, 3) # (N, 3, 3) + # return torch.matmul(corners, RT) + centers, torch.matmul(corners, RT_flip) + centers # (N, 8, 3) + # else: + # return torch.matmul(corners, RT) + centers # (N, 8, 3) + + # corners = torch.stack([x_corners, y_corners, z_corners], dim=1) # (N, 3, 8) + # R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3) # (N, 3, 3) + # corners = torch.matmul(R, corners) + centers.unsqueeze(2) # (N, 3, 8) + # corners = corners.transpose(1, 2) # (N, 8, 3) diff --git a/model/pvcnn/modules/functional/__init__.py b/model/pvcnn/modules/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b945c5dd1521e7446e5344714e85af1f89f24510 --- /dev/null +++ b/model/pvcnn/modules/functional/__init__.py @@ -0,0 +1,7 @@ +from .ball_query import ball_query +from .devoxelization import trilinear_devoxelize +from .grouping import grouping +from .interpolatation import nearest_neighbor_interpolate +from .loss import kl_loss, huber_loss +from .sampling import gather, furthest_point_sample, 
logits_mask +from .voxelization import avg_voxelize diff --git a/model/pvcnn/modules/functional/backend.py b/model/pvcnn/modules/functional/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..c7bb555291b5784cf0e416d23f12b9c397538c83 --- /dev/null +++ b/model/pvcnn/modules/functional/backend.py @@ -0,0 +1,33 @@ +import os +from pathlib import Path + +from torch.utils.cpp_extension import load + + +gcc_path = os.getenv('CC', default='/usr/bin/gcc') +if not Path(gcc_path).is_file(): + raise ValueError('Could not find your gcc, please replace it here.') + +_src_path = os.path.dirname(os.path.abspath(__file__)) +_backend = load( + name='_pvcnn_backend', + extra_cflags=['-O3', '-std=c++17'], + extra_cuda_cflags=[f'--compiler-bindir={gcc_path}'], + sources=[os.path.join(_src_path,'src', f) for f in [ + 'ball_query/ball_query.cpp', + 'ball_query/ball_query.cu', + 'grouping/grouping.cpp', + 'grouping/grouping.cu', + 'interpolate/neighbor_interpolate.cpp', + 'interpolate/neighbor_interpolate.cu', + 'interpolate/trilinear_devox.cpp', + 'interpolate/trilinear_devox.cu', + 'sampling/sampling.cpp', + 'sampling/sampling.cu', + 'voxelization/vox.cpp', + 'voxelization/vox.cu', + 'bindings.cpp', + ]] +) + +__all__ = ['_backend'] diff --git a/model/pvcnn/modules/functional/ball_query.py b/model/pvcnn/modules/functional/ball_query.py new file mode 100644 index 0000000000000000000000000000000000000000..4daef84d172a4c79d4657a7905aca6f3a8b02cc1 --- /dev/null +++ b/model/pvcnn/modules/functional/ball_query.py @@ -0,0 +1,19 @@ +from torch.autograd import Function + +from .backend import _backend + +__all__ = ['ball_query'] + + +def ball_query(centers_coords, points_coords, radius, num_neighbors): + """ + :param centers_coords: coordinates of centers, FloatTensor[B, 3, M] + :param points_coords: coordinates of points, FloatTensor[B, 3, N] + :param radius: float, radius of ball query + :param num_neighbors: int, maximum number of neighbors + :return: + neighbor_indices: indices of neighbors, IntTensor[B, M, U] + """ + centers_coords = centers_coords.contiguous() + points_coords = points_coords.contiguous() + return _backend.ball_query(centers_coords, points_coords, radius, num_neighbors) diff --git a/model/pvcnn/modules/functional/devoxelization.py b/model/pvcnn/modules/functional/devoxelization.py new file mode 100644 index 0000000000000000000000000000000000000000..e9dd014e393a468b8b5b40b8654860fc7ab87eea --- /dev/null +++ b/model/pvcnn/modules/functional/devoxelization.py @@ -0,0 +1,42 @@ +from torch.autograd import Function + +from .backend import _backend + +__all__ = ['trilinear_devoxelize'] + + +class TrilinearDevoxelization(Function): + @staticmethod + def forward(ctx, features, coords, resolution, is_training=True): + """ + :param ctx: + :param coords: the coordinates of points, FloatTensor[B, 3, N] + :param features: FloatTensor[B, C, R, R, R] + :param resolution: int, the voxel resolution + :param is_training: bool, training mode + :return: + FloatTensor[B, C, N] + """ + B, C = features.shape[:2] + features = features.contiguous().view(B, C, -1) + coords = coords.contiguous() + outs, inds, wgts = _backend.trilinear_devoxelize_forward(resolution, is_training, coords, features) + if is_training: + ctx.save_for_backward(inds, wgts) + ctx.r = resolution + return outs + + @staticmethod + def backward(ctx, grad_output): + """ + :param ctx: + :param grad_output: gradient of outputs, FloatTensor[B, C, N] + :return: + gradient of inputs, FloatTensor[B, C, R, R, R] + """ + inds, wgts = 
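# --- Added usage note (assumption about a typical setup, not from the source) --
# The functional backend above is JIT-compiled with torch.utils.cpp_extension.load
# on first import and picks the host compiler from the CC environment variable,
# falling back to /usr/bin/gcc. Pointing CC at a gcc supported by your CUDA
# toolkit before the first import avoids the "Could not find your gcc" error:
import os

os.environ.setdefault('CC', '/usr/bin/gcc')   # adjust to a CUDA-compatible gcc
# importing any functional op then triggers the one-time build, e.g.:
# from model.pvcnn.modules.functional import ball_query
# --------------------------------------------------------------------------------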
ctx.saved_tensors + grad_inputs = _backend.trilinear_devoxelize_backward(grad_output.contiguous(), inds, wgts, ctx.r) + return grad_inputs.view(grad_output.size(0), grad_output.size(1), ctx.r, ctx.r, ctx.r), None, None, None + + +trilinear_devoxelize = TrilinearDevoxelization.apply diff --git a/model/pvcnn/modules/functional/grouping.py b/model/pvcnn/modules/functional/grouping.py new file mode 100644 index 0000000000000000000000000000000000000000..db19ef7dba8333499e5deaffdbaf3d6fb860f71f --- /dev/null +++ b/model/pvcnn/modules/functional/grouping.py @@ -0,0 +1,32 @@ +from torch.autograd import Function + +from .backend import _backend + +__all__ = ['grouping'] + + +class Grouping(Function): + @staticmethod + def forward(ctx, features, indices): + """ + :param ctx: + :param features: features of points, FloatTensor[B, C, N] + :param indices: neighbor indices of centers, IntTensor[B, M, U], M is #centers, U is #neighbors + :return: + grouped_features: grouped features, FloatTensor[B, C, M, U] + """ + features = features.contiguous() + indices = indices.contiguous() + ctx.save_for_backward(indices) + ctx.num_points = features.size(-1) + # print(features.dtype, features.shape) + return _backend.grouping_forward(features, indices) + + @staticmethod + def backward(ctx, grad_output): + indices, = ctx.saved_tensors + grad_features = _backend.grouping_backward(grad_output.contiguous(), indices, ctx.num_points) + return grad_features, None + + +grouping = Grouping.apply diff --git a/model/pvcnn/modules/functional/interpolatation.py b/model/pvcnn/modules/functional/interpolatation.py new file mode 100644 index 0000000000000000000000000000000000000000..a83378d0f18901d2ced6ba3acb1691a1d1bc20b8 --- /dev/null +++ b/model/pvcnn/modules/functional/interpolatation.py @@ -0,0 +1,38 @@ +from torch.autograd import Function + +from .backend import _backend + +__all__ = ['nearest_neighbor_interpolate'] + + +class NeighborInterpolation(Function): + @staticmethod + def forward(ctx, points_coords, centers_coords, centers_features): + """ + :param ctx: + :param points_coords: coordinates of points, FloatTensor[B, 3, N] + :param centers_coords: coordinates of centers, FloatTensor[B, 3, M] + :param centers_features: features of centers, FloatTensor[B, C, M] + :return: + points_features: features of points, FloatTensor[B, C, N] + """ + centers_coords = centers_coords.contiguous() + points_coords = points_coords.contiguous() + centers_features = centers_features.contiguous() + points_features, indices, weights = _backend.three_nearest_neighbors_interpolate_forward( + points_coords, centers_coords, centers_features + ) + ctx.save_for_backward(indices, weights) + ctx.num_centers = centers_coords.size(-1) + return points_features + + @staticmethod + def backward(ctx, grad_output): + indices, weights = ctx.saved_tensors + grad_centers_features = _backend.three_nearest_neighbors_interpolate_backward( + grad_output.contiguous(), indices, weights, ctx.num_centers + ) + return None, None, grad_centers_features + + +nearest_neighbor_interpolate = NeighborInterpolation.apply diff --git a/model/pvcnn/modules/functional/loss.py b/model/pvcnn/modules/functional/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..41112b3271c6f340047e8eb874b70ee630597b32 --- /dev/null +++ b/model/pvcnn/modules/functional/loss.py @@ -0,0 +1,17 @@ +import torch +import torch.nn.functional as F + +__all__ = ['kl_loss', 'huber_loss'] + + +def kl_loss(x, y): + x = F.softmax(x.detach(), dim=1) + y = F.log_softmax(y, dim=1) + return 
torch.mean(torch.sum(x * (torch.log(x) - y), dim=1)) + + +def huber_loss(error, delta): + abs_error = torch.abs(error) + quadratic = torch.min(abs_error, torch.full_like(abs_error, fill_value=delta)) + losses = 0.5 * (quadratic ** 2) + delta * (abs_error - quadratic) + return torch.mean(losses) diff --git a/model/pvcnn/modules/functional/sampling.py b/model/pvcnn/modules/functional/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..b08ec63446fad34dc947bd966de12e6c4161733b --- /dev/null +++ b/model/pvcnn/modules/functional/sampling.py @@ -0,0 +1,84 @@ +import numpy as np +import torch +from torch.autograd import Function + +from .backend import _backend + +__all__ = ['gather', 'furthest_point_sample', 'logits_mask'] + + +class Gather(Function): + @staticmethod + def forward(ctx, features, indices): + """ + Gather + :param ctx: + :param features: features of points, FloatTensor[B, C, N] + :param indices: centers' indices in points, IntTensor[b, m] + :return: + centers_coords: coordinates of sampled centers, FloatTensor[B, C, M] + """ + features = features.contiguous() + indices = indices.int().contiguous() + ctx.save_for_backward(indices) + ctx.num_points = features.size(-1) + return _backend.gather_features_forward(features, indices) + + @staticmethod + def backward(ctx, grad_output): + indices, = ctx.saved_tensors + grad_features = _backend.gather_features_backward(grad_output.contiguous(), indices, ctx.num_points) + return grad_features, None + + +gather = Gather.apply + + +def furthest_point_sample(coords, num_samples): + """ + Uses iterative furthest point sampling to select a set of npoint features that have the largest + minimum distance to the sampled point set + :param coords: coordinates of points, FloatTensor[B, 3, N] + :param num_samples: int, M + :return: + centers_coords: coordinates of sampled centers, FloatTensor[B, 3, M] + """ + coords = coords.contiguous() + indices = _backend.furthest_point_sampling(coords, num_samples) + return gather(coords, indices) + + +def logits_mask(coords, logits, num_points_per_object): + """ + Use logits to sample points + :param coords: coords of points, FloatTensor[B, 3, N] + :param logits: binary classification logits, FloatTensor[B, 2, N] + :param num_points_per_object: M, #points per object after masking, int + :return: + selected_coords: FloatTensor[B, 3, M] + masked_coords_mean: mean coords of selected points, FloatTensor[B, 3] + mask: mask to select points, BoolTensor[B, N] + """ + batch_size, _, num_points = coords.shape + mask = torch.lt(logits[:, 0, :], logits[:, 1, :]) # [B, N] + num_candidates = torch.sum(mask, dim=-1, keepdim=True) # [B, 1] + masked_coords = coords * mask.view(batch_size, 1, num_points) # [B, C, N] + masked_coords_mean = torch.sum(masked_coords, dim=-1) / torch.max(num_candidates, + torch.ones_like(num_candidates)).float() # [B, C] + selected_indices = torch.zeros((batch_size, num_points_per_object), device=coords.device, dtype=torch.int32) + for i in range(batch_size): + current_mask = mask[i] # [N] + current_candidates = current_mask.nonzero().view(-1) + current_num_candidates = current_candidates.numel() + if current_num_candidates >= num_points_per_object: + choices = np.random.choice(current_num_candidates, num_points_per_object, replace=False) + selected_indices[i] = current_candidates[choices] + elif current_num_candidates > 0: + choices = np.concatenate([ + np.arange(current_num_candidates).repeat(num_points_per_object // current_num_candidates), + 
np.random.choice(current_num_candidates, num_points_per_object % current_num_candidates, replace=False) + ]) + np.random.shuffle(choices) + selected_indices[i] = current_candidates[choices] + selected_coords = gather(masked_coords - masked_coords_mean.view(batch_size, -1, 1), selected_indices) + return selected_coords, masked_coords_mean, mask diff --git a/model/pvcnn/modules/functional/src/ball_query/ball_query.cpp b/model/pvcnn/modules/functional/src/ball_query/ball_query.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5ae1fb6f6ada981c75c5dc8e986ae2d903bf61c2 --- /dev/null +++ b/model/pvcnn/modules/functional/src/ball_query/ball_query.cpp @@ -0,0 +1,30 @@ +#include "ball_query.hpp" +#include "ball_query.cuh" + +#include "../utils.hpp" + +at::Tensor ball_query_forward(at::Tensor centers_coords, + at::Tensor points_coords, const float radius, + const int num_neighbors) { + CHECK_CUDA(centers_coords); + CHECK_CUDA(points_coords); + CHECK_CONTIGUOUS(centers_coords); + CHECK_CONTIGUOUS(points_coords); + CHECK_IS_FLOAT(centers_coords); + CHECK_IS_FLOAT(points_coords); + + int b = centers_coords.size(0); + int m = centers_coords.size(2); + int n = points_coords.size(2); + + at::Tensor neighbors_indices = torch::zeros( + {b, m, num_neighbors}, + at::device(centers_coords.device()).dtype(at::ScalarType::Int)); + + ball_query(b, n, m, radius * radius, num_neighbors, + centers_coords.data_ptr(), + points_coords.data_ptr(), + neighbors_indices.data_ptr()); + + return neighbors_indices; +} diff --git a/model/pvcnn/modules/functional/src/ball_query/ball_query.cu b/model/pvcnn/modules/functional/src/ball_query/ball_query.cu new file mode 100644 index 0000000000000000000000000000000000000000..079e3cb86cb2ee51bb1326b1c7ecb1e70c8a4ae8 --- /dev/null +++ b/model/pvcnn/modules/functional/src/ball_query/ball_query.cu @@ -0,0 +1,59 @@ +#include +#include +#include + +#include "../cuda_utils.cuh" + +/* + Function: ball query + Args: + b : batch size + n : number of points in point clouds + m : number of query centers + r2 : ball query radius ** 2 + u : maximum number of neighbors + centers_coords: coordinates of centers, FloatTensor[b, 3, m] + points_coords : coordinates of points, FloatTensor[b, 3, n] + neighbors_indices : neighbor indices in points, IntTensor[b, m, u] +*/ +__global__ void ball_query_kernel(int b, int n, int m, float r2, int u, + const float *__restrict__ centers_coords, + const float *__restrict__ points_coords, + int *__restrict__ neighbors_indices) { + int batch_index = blockIdx.x; + int index = threadIdx.x; + int stride = blockDim.x; + points_coords += batch_index * n * 3; + centers_coords += batch_index * m * 3; + neighbors_indices += batch_index * m * u; + + for (int j = index; j < m; j += stride) { + float center_x = centers_coords[j]; + float center_y = centers_coords[j + m]; + float center_z = centers_coords[j + m + m]; + for (int k = 0, cnt = 0; k < n && cnt < u; ++k) { + float dx = center_x - points_coords[k]; + float dy = center_y - points_coords[k + n]; + float dz = center_z - points_coords[k + n + n]; + float d2 = dx * dx + dy * dy + dz * dz; + if (d2 < r2) { + if (cnt == 0) { + for (int v = 0; v < u; ++v) { + neighbors_indices[j * u + v] = k; + } + } + neighbors_indices[j * u + cnt] = k; + ++cnt; + } + } + } +} + +void ball_query(int b, int n, int m, float r2, int u, + const float *centers_coords, const float *points_coords, + int *neighbors_indices) { + ball_query_kernel<<>>( + b, n, m, r2, u, centers_coords, points_coords, neighbors_indices); + 
CUDA_CHECK_ERRORS(); +} diff --git a/model/pvcnn/modules/functional/src/ball_query/ball_query.cuh b/model/pvcnn/modules/functional/src/ball_query/ball_query.cuh new file mode 100644 index 0000000000000000000000000000000000000000..ba32492f17733007d4f0cbfef327c74cdd35dd0b --- /dev/null +++ b/model/pvcnn/modules/functional/src/ball_query/ball_query.cuh @@ -0,0 +1,8 @@ +#ifndef _BALL_QUERY_CUH +#define _BALL_QUERY_CUH + +void ball_query(int b, int n, int m, float r2, int u, + const float *centers_coords, const float *points_coords, + int *neighbors_indices); + +#endif diff --git a/model/pvcnn/modules/functional/src/ball_query/ball_query.hpp b/model/pvcnn/modules/functional/src/ball_query/ball_query.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d87bbd93258a8a4ee08ae20b03a544a2d79ebb68 --- /dev/null +++ b/model/pvcnn/modules/functional/src/ball_query/ball_query.hpp @@ -0,0 +1,10 @@ +#ifndef _BALL_QUERY_HPP +#define _BALL_QUERY_HPP + +#include + +at::Tensor ball_query_forward(at::Tensor centers_coords, + at::Tensor points_coords, const float radius, + const int num_neighbors); + +#endif diff --git a/model/pvcnn/modules/functional/src/bindings.cpp b/model/pvcnn/modules/functional/src/bindings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..994e01b5861fe20269e26ae49c2e00cadbef2b01 --- /dev/null +++ b/model/pvcnn/modules/functional/src/bindings.cpp @@ -0,0 +1,37 @@ +#include + +#include "ball_query/ball_query.hpp" +#include "grouping/grouping.hpp" +#include "interpolate/neighbor_interpolate.hpp" +#include "interpolate/trilinear_devox.hpp" +#include "sampling/sampling.hpp" +#include "voxelization/vox.hpp" + +PYBIND11_MODULE(_pvcnn_backend, m) { + m.def("gather_features_forward", &gather_features_forward, + "Gather Centers' Features forward (CUDA)"); + m.def("gather_features_backward", &gather_features_backward, + "Gather Centers' Features backward (CUDA)"); + m.def("furthest_point_sampling", &furthest_point_sampling_forward, + "Furthest Point Sampling (CUDA)"); + m.def("ball_query", &ball_query_forward, "Ball Query (CUDA)"); + m.def("grouping_forward", &grouping_forward, + "Grouping Features forward (CUDA)"); + m.def("grouping_backward", &grouping_backward, + "Grouping Features backward (CUDA)"); + m.def("three_nearest_neighbors_interpolate_forward", + &three_nearest_neighbors_interpolate_forward, + "3 Nearest Neighbors Interpolate forward (CUDA)"); + m.def("three_nearest_neighbors_interpolate_backward", + &three_nearest_neighbors_interpolate_backward, + "3 Nearest Neighbors Interpolate backward (CUDA)"); + + m.def("trilinear_devoxelize_forward", &trilinear_devoxelize_forward, + "Trilinear Devoxelization forward (CUDA)"); + m.def("trilinear_devoxelize_backward", &trilinear_devoxelize_backward, + "Trilinear Devoxelization backward (CUDA)"); + m.def("avg_voxelize_forward", &avg_voxelize_forward, + "Voxelization forward with average pooling (CUDA)"); + m.def("avg_voxelize_backward", &avg_voxelize_backward, + "Voxelization backward (CUDA)"); +} diff --git a/model/pvcnn/modules/functional/src/cuda_utils.cuh b/model/pvcnn/modules/functional/src/cuda_utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..01bf5512914af1a0265697990a21eb760c14663c --- /dev/null +++ b/model/pvcnn/modules/functional/src/cuda_utils.cuh @@ -0,0 +1,39 @@ +#ifndef _CUDA_UTILS_H +#define _CUDA_UTILS_H + +#include +#include +#include + +#include +#include + +#include + +#define MAXIMUM_THREADS 512 + +inline int optimal_num_threads(int work_size) { + const int 
pow_2 = std::log2(static_cast(work_size)); + return max(min(1 << pow_2, MAXIMUM_THREADS), 1); +} + +inline dim3 optimal_block_config(int x, int y) { + const int x_threads = optimal_num_threads(x); + const int y_threads = + max(min(optimal_num_threads(y), MAXIMUM_THREADS / x_threads), 1); + dim3 block_config(x_threads, y_threads, 1); + return block_config; +} + +#define CUDA_CHECK_ERRORS() \ + { \ + cudaError_t err = cudaGetLastError(); \ + if (cudaSuccess != err) { \ + fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ + cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ + __FILE__); \ + exit(-1); \ + } \ + } + +#endif diff --git a/model/pvcnn/modules/functional/src/grouping/grouping.cpp b/model/pvcnn/modules/functional/src/grouping/grouping.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4f97650069f3961946e16b68722893e4bb65cb26 --- /dev/null +++ b/model/pvcnn/modules/functional/src/grouping/grouping.cpp @@ -0,0 +1,44 @@ +#include "grouping.hpp" +#include "grouping.cuh" + +#include "../utils.hpp" + +at::Tensor grouping_forward(at::Tensor features, at::Tensor indices) { + CHECK_CUDA(features); + CHECK_CUDA(indices); + CHECK_CONTIGUOUS(features); + CHECK_CONTIGUOUS(indices); + CHECK_IS_FLOAT(features); + CHECK_IS_INT(indices); + + int b = features.size(0); + int c = features.size(1); + int n = features.size(2); + int m = indices.size(1); + int u = indices.size(2); + at::Tensor output = torch::zeros( + {b, c, m, u}, at::device(features.device()).dtype(at::ScalarType::Float)); + grouping(b, c, n, m, u, features.data_ptr(), indices.data_ptr(), + output.data_ptr()); + return output; +} + +at::Tensor grouping_backward(at::Tensor grad_y, at::Tensor indices, + const int n) { + CHECK_CUDA(grad_y); + CHECK_CUDA(indices); + CHECK_CONTIGUOUS(grad_y); + CHECK_CONTIGUOUS(indices); + CHECK_IS_FLOAT(grad_y); + CHECK_IS_INT(indices); + + int b = grad_y.size(0); + int c = grad_y.size(1); + int m = indices.size(1); + int u = indices.size(2); + at::Tensor grad_x = torch::zeros( + {b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float)); + grouping_grad(b, c, n, m, u, grad_y.data_ptr(), + indices.data_ptr(), grad_x.data_ptr()); + return grad_x; +} diff --git a/model/pvcnn/modules/functional/src/grouping/grouping.cu b/model/pvcnn/modules/functional/src/grouping/grouping.cu new file mode 100644 index 0000000000000000000000000000000000000000..0cf561a1e94874d968f575e0f63170069c05fa80 --- /dev/null +++ b/model/pvcnn/modules/functional/src/grouping/grouping.cu @@ -0,0 +1,85 @@ +#include +#include + +#include "../cuda_utils.cuh" + +/* + Function: grouping features of neighbors (forward) + Args: + b : batch size + c : #channles of features + n : number of points in point clouds + m : number of query centers + u : maximum number of neighbors + features: points' features, FloatTensor[b, c, n] + indices : neighbor indices in points, IntTensor[b, m, u] + out : gathered features, FloatTensor[b, c, m, u] +*/ +__global__ void grouping_kernel(int b, int c, int n, int m, int u, + const float *__restrict__ features, + const int *__restrict__ indices, + float *__restrict__ out) { + int batch_index = blockIdx.x; + features += batch_index * n * c; + indices += batch_index * m * u; + out += batch_index * m * u * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * m; i += stride) { + const int l = i / m; + const int j = i % m; + for (int k = 0; k < u; ++k) { + out[(l * m + j) * u + k] = 
features[l * n + indices[j * u + k]]; + } + } +} + +void grouping(int b, int c, int n, int m, int u, const float *features, + const int *indices, float *out) { + grouping_kernel<<>>(b, c, n, m, u, features, + indices, out); + CUDA_CHECK_ERRORS(); +} + +/* + Function: grouping features of neighbors (backward) + Args: + b : batch size + c : #channles of features + n : number of points in point clouds + m : number of query centers + u : maximum number of neighbors + grad_y : grad of gathered features, FloatTensor[b, c, m, u] + indices : neighbor indices in points, IntTensor[b, m, u] + grad_x: grad of points' features, FloatTensor[b, c, n] +*/ +__global__ void grouping_grad_kernel(int b, int c, int n, int m, int u, + const float *__restrict__ grad_y, + const int *__restrict__ indices, + float *__restrict__ grad_x) { + int batch_index = blockIdx.x; + grad_y += batch_index * m * u * c; + indices += batch_index * m * u; + grad_x += batch_index * n * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * m; i += stride) { + const int l = i / m; + const int j = i % m; + for (int k = 0; k < u; ++k) { + atomicAdd(grad_x + l * n + indices[j * u + k], + grad_y[(l * m + j) * u + k]); + } + } +} + +void grouping_grad(int b, int c, int n, int m, int u, const float *grad_y, + const int *indices, float *grad_x) { + grouping_grad_kernel<<>>( + b, c, n, m, u, grad_y, indices, grad_x); + CUDA_CHECK_ERRORS(); +} diff --git a/model/pvcnn/modules/functional/src/grouping/grouping.cuh b/model/pvcnn/modules/functional/src/grouping/grouping.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c8a114feda3e63bd0cbdbd271bb0c476181b65b1 --- /dev/null +++ b/model/pvcnn/modules/functional/src/grouping/grouping.cuh @@ -0,0 +1,9 @@ +#ifndef _GROUPING_CUH +#define _GROUPING_CUH + +void grouping(int b, int c, int n, int m, int u, const float *features, + const int *indices, float *out); +void grouping_grad(int b, int c, int n, int m, int u, const float *grad_y, + const int *indices, float *grad_x); + +#endif \ No newline at end of file diff --git a/model/pvcnn/modules/functional/src/grouping/grouping.hpp b/model/pvcnn/modules/functional/src/grouping/grouping.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3f5733d2f53bd7a02b2b16cb1f0303af84cd6c7e --- /dev/null +++ b/model/pvcnn/modules/functional/src/grouping/grouping.hpp @@ -0,0 +1,10 @@ +#ifndef _GROUPING_HPP +#define _GROUPING_HPP + +#include + +at::Tensor grouping_forward(at::Tensor features, at::Tensor indices); +at::Tensor grouping_backward(at::Tensor grad_y, at::Tensor indices, + const int n); + +#endif diff --git a/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cpp b/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fc73c43cda3eea7a9bb4635851c9023359d05079 --- /dev/null +++ b/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cpp @@ -0,0 +1,65 @@ +#include "neighbor_interpolate.hpp" +#include "neighbor_interpolate.cuh" + +#include "../utils.hpp" + +std::vector +three_nearest_neighbors_interpolate_forward(at::Tensor points_coords, + at::Tensor centers_coords, + at::Tensor centers_features) { + CHECK_CUDA(points_coords); + CHECK_CUDA(centers_coords); + CHECK_CUDA(centers_features); + CHECK_CONTIGUOUS(points_coords); + CHECK_CONTIGUOUS(centers_coords); + CHECK_CONTIGUOUS(centers_features); + 
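+  // expected layouts: points_coords [b, 3, n], centers_coords [b, 3, m],
+  // centers_features [b, c, m]; all float32, contiguous, CUDA tensors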
CHECK_IS_FLOAT(points_coords); + CHECK_IS_FLOAT(centers_coords); + CHECK_IS_FLOAT(centers_features); + + int b = centers_features.size(0); + int c = centers_features.size(1); + int m = centers_features.size(2); + int n = points_coords.size(2); + + at::Tensor indices = torch::zeros( + {b, 3, n}, at::device(points_coords.device()).dtype(at::ScalarType::Int)); + at::Tensor weights = torch::zeros( + {b, 3, n}, + at::device(points_coords.device()).dtype(at::ScalarType::Float)); + at::Tensor output = torch::zeros( + {b, c, n}, + at::device(centers_features.device()).dtype(at::ScalarType::Float)); + + three_nearest_neighbors_interpolate( + b, c, m, n, points_coords.data_ptr(), + centers_coords.data_ptr(), centers_features.data_ptr(), + indices.data_ptr(), weights.data_ptr(), + output.data_ptr()); + return {output, indices, weights}; +} + +at::Tensor three_nearest_neighbors_interpolate_backward(at::Tensor grad_y, + at::Tensor indices, + at::Tensor weights, + const int m) { + CHECK_CUDA(grad_y); + CHECK_CUDA(indices); + CHECK_CUDA(weights); + CHECK_CONTIGUOUS(grad_y); + CHECK_CONTIGUOUS(indices); + CHECK_CONTIGUOUS(weights); + CHECK_IS_FLOAT(grad_y); + CHECK_IS_INT(indices); + CHECK_IS_FLOAT(weights); + + int b = grad_y.size(0); + int c = grad_y.size(1); + int n = grad_y.size(2); + at::Tensor grad_x = torch::zeros( + {b, c, m}, at::device(grad_y.device()).dtype(at::ScalarType::Float)); + three_nearest_neighbors_interpolate_grad( + b, c, n, m, grad_y.data_ptr(), indices.data_ptr(), + weights.data_ptr(), grad_x.data_ptr()); + return grad_x; +} diff --git a/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cu b/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cu new file mode 100644 index 0000000000000000000000000000000000000000..8168507aacc04f2e24b56b7e8d1fc635f169afd6 --- /dev/null +++ b/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cu @@ -0,0 +1,181 @@ +#include +#include +#include + +#include "../cuda_utils.cuh" + +/* + Function: three nearest neighbors + Args: + b : batch size + n : number of points in point clouds + m : number of query centers + points_coords : coordinates of points, FloatTensor[b, 3, n] + centers_coords: coordinates of centers, FloatTensor[b, 3, m] + weights : weights of nearest 3 centers to the point, + FloatTensor[b, 3, n] + indices : indices of nearest 3 centers to the point, + IntTensor[b, 3, n] +*/ +__global__ void three_nearest_neighbors_kernel( + int b, int n, int m, const float *__restrict__ points_coords, + const float *__restrict__ centers_coords, float *__restrict__ weights, + int *__restrict__ indices) { + int batch_index = blockIdx.x; + int index = threadIdx.x; + int stride = blockDim.x; + points_coords += batch_index * 3 * n; + weights += batch_index * 3 * n; + indices += batch_index * 3 * n; + centers_coords += batch_index * 3 * m; + + for (int j = index; j < n; j += stride) { + float ux = points_coords[j]; + float uy = points_coords[j + n]; + float uz = points_coords[j + n + n]; + + double best0 = 1e40, best1 = 1e40, best2 = 1e40; + int besti0 = 0, besti1 = 0, besti2 = 0; + for (int k = 0; k < m; ++k) { + float x = centers_coords[k]; + float y = centers_coords[k + m]; + float z = centers_coords[k + m + m]; + float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); + if (d < best2) { + best2 = d; + besti2 = k; + if (d < best1) { + best2 = best1; + besti2 = besti1; + best1 = d; + besti1 = k; + if (d < best0) { + best1 = best0; + besti1 = besti0; + best0 = d; + besti0 = k; + } + } + } + } 
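+    // clamp the three squared distances, then form normalized inverse-distance weights:
+    // w_i = (1/d_i) / (1/d_0 + 1/d_1 + 1/d_2), computed via products of the other two distances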
+ best0 = max(min(1e10f, best0), 1e-10f); + best1 = max(min(1e10f, best1), 1e-10f); + best2 = max(min(1e10f, best2), 1e-10f); + float d0d1 = best0 * best1; + float d0d2 = best0 * best2; + float d1d2 = best1 * best2; + float d0d1d2 = 1.0f / (d0d1 + d0d2 + d1d2); + weights[j] = d1d2 * d0d1d2; + indices[j] = besti0; + weights[j + n] = d0d2 * d0d1d2; + indices[j + n] = besti1; + weights[j + n + n] = d0d1 * d0d1d2; + indices[j + n + n] = besti2; + } +} + +/* + Function: interpolate three nearest neighbors (forward) + Args: + b : batch size + c : #channels of features + m : number of query centers + n : number of points in point clouds + centers_features: features of centers, FloatTensor[b, c, m] + indices : indices of nearest 3 centers to the point, + IntTensor[b, 3, n] + weights : weights for interpolation, FloatTensor[b, 3, n] + out : features of points, FloatTensor[b, c, n] +*/ +__global__ void three_nearest_neighbors_interpolate_kernel( + int b, int c, int m, int n, const float *__restrict__ centers_features, + const int *__restrict__ indices, const float *__restrict__ weights, + float *__restrict__ out) { + int batch_index = blockIdx.x; + centers_features += batch_index * m * c; + indices += batch_index * n * 3; + weights += batch_index * n * 3; + out += batch_index * n * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * n; i += stride) { + const int l = i / n; + const int j = i % n; + float w1 = weights[j]; + float w2 = weights[j + n]; + float w3 = weights[j + n + n]; + int i1 = indices[j]; + int i2 = indices[j + n]; + int i3 = indices[j + n + n]; + + out[i] = centers_features[l * m + i1] * w1 + + centers_features[l * m + i2] * w2 + + centers_features[l * m + i3] * w3; + } +} + +void three_nearest_neighbors_interpolate(int b, int c, int m, int n, + const float *points_coords, + const float *centers_coords, + const float *centers_features, + int *indices, float *weights, + float *out) { + three_nearest_neighbors_kernel<<>>( + b, n, m, points_coords, centers_coords, weights, indices); + three_nearest_neighbors_interpolate_kernel<<< + b, optimal_block_config(n, c), 0, at::cuda::getCurrentCUDAStream()>>>( + b, c, m, n, centers_features, indices, weights, out); + CUDA_CHECK_ERRORS(); +} + +/* + Function: interpolate three nearest neighbors (backward) + Args: + b : batch size + c : #channels of features + m : number of query centers + n : number of points in point clouds + grad_y : grad of features of points, FloatTensor[b, c, n] + indices : indices of nearest 3 centers to the point, IntTensor[b, 3, n] + weights : weights for interpolation, FloatTensor[b, 3, n] + grad_x : grad of features of centers, FloatTensor[b, c, m] +*/ +__global__ void three_nearest_neighbors_interpolate_grad_kernel( + int b, int c, int n, int m, const float *__restrict__ grad_y, + const int *__restrict__ indices, const float *__restrict__ weights, + float *__restrict__ grad_x) { + int batch_index = blockIdx.x; + grad_y += batch_index * n * c; + indices += batch_index * n * 3; + weights += batch_index * n * 3; + grad_x += batch_index * m * c; + + const int index = threadIdx.y * blockDim.x + threadIdx.x; + const int stride = blockDim.y * blockDim.x; + for (int i = index; i < c * n; i += stride) { + const int l = i / n; + const int j = i % n; + float w1 = weights[j]; + float w2 = weights[j + n]; + float w3 = weights[j + n + n]; + int i1 = indices[j]; + int i2 = indices[j + n]; + int i3 = indices[j + n + n]; + atomicAdd(grad_x + l * m 
+ i1, grad_y[i] * w1); + atomicAdd(grad_x + l * m + i2, grad_y[i] * w2); + atomicAdd(grad_x + l * m + i3, grad_y[i] * w3); + } +} + +void three_nearest_neighbors_interpolate_grad(int b, int c, int n, int m, + const float *grad_y, + const int *indices, + const float *weights, + float *grad_x) { + three_nearest_neighbors_interpolate_grad_kernel<<< + b, optimal_block_config(n, c), 0, at::cuda::getCurrentCUDAStream()>>>( + b, c, n, m, grad_y, indices, weights, grad_x); + CUDA_CHECK_ERRORS(); +} diff --git a/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cuh b/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cuh new file mode 100644 index 0000000000000000000000000000000000000000..a15f37e40f3d45c6cddbb092fecba2efe7012ec6 --- /dev/null +++ b/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cuh @@ -0,0 +1,16 @@ +#ifndef _NEIGHBOR_INTERPOLATE_CUH +#define _NEIGHBOR_INTERPOLATE_CUH + +void three_nearest_neighbors_interpolate(int b, int c, int m, int n, + const float *points_coords, + const float *centers_coords, + const float *centers_features, + int *indices, float *weights, + float *out); +void three_nearest_neighbors_interpolate_grad(int b, int c, int n, int m, + const float *grad_y, + const int *indices, + const float *weights, + float *grad_x); + +#endif diff --git a/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.hpp b/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cdc7835d44c1b3b4b3c49051bbc958234fea3454 --- /dev/null +++ b/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.hpp @@ -0,0 +1,16 @@ +#ifndef _NEIGHBOR_INTERPOLATE_HPP +#define _NEIGHBOR_INTERPOLATE_HPP + +#include +#include + +std::vector +three_nearest_neighbors_interpolate_forward(at::Tensor points_coords, + at::Tensor centers_coords, + at::Tensor centers_features); +at::Tensor three_nearest_neighbors_interpolate_backward(at::Tensor grad_y, + at::Tensor indices, + at::Tensor weights, + const int m); + +#endif diff --git a/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.cpp b/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a8ff4fc74f6975cef6811eceb4b1f42d1116e1cf --- /dev/null +++ b/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.cpp @@ -0,0 +1,91 @@ +#include "trilinear_devox.hpp" +#include "trilinear_devox.cuh" + +#include "../utils.hpp" + +/* + Function: trilinear devoxelization (forward) + Args: + r : voxel resolution + trainig : whether is training mode + coords : the coordinates of points, FloatTensor[b, 3, n] + features : features, FloatTensor[b, c, s], s = r ** 3 + Return: + outs : outputs, FloatTensor[b, c, n] + inds : the voxel coordinates of point cube, IntTensor[b, 8, n] + wgts : weight for trilinear interpolation, FloatTensor[b, 8, n] +*/ +std::vector +trilinear_devoxelize_forward(const int r, const bool is_training, + const at::Tensor coords, + const at::Tensor features) { + CHECK_CUDA(features); + CHECK_CUDA(coords); + CHECK_CONTIGUOUS(features); + CHECK_CONTIGUOUS(coords); + CHECK_IS_FLOAT(features); + CHECK_IS_FLOAT(coords); + + int b = features.size(0); + int c = features.size(1); + int n = coords.size(2); + int r2 = r * r; + int r3 = r2 * r; + at::Tensor outs = torch::zeros( + {b, c, n}, at::device(features.device()).dtype(at::ScalarType::Float)); + if (is_training) { + at::Tensor inds = torch::zeros( + {b, 8, 
n}, at::device(features.device()).dtype(at::ScalarType::Int)); + at::Tensor wgts = torch::zeros( + {b, 8, n}, at::device(features.device()).dtype(at::ScalarType::Float)); + trilinear_devoxelize(b, c, n, r, r2, r3, true, coords.data_ptr(), + features.data_ptr(), inds.data_ptr(), + wgts.data_ptr(), outs.data_ptr()); + return {outs, inds, wgts}; + } else { + at::Tensor inds = torch::zeros( + {1}, at::device(features.device()).dtype(at::ScalarType::Int)); + at::Tensor wgts = torch::zeros( + {1}, at::device(features.device()).dtype(at::ScalarType::Float)); + trilinear_devoxelize(b, c, n, r, r2, r3, false, coords.data_ptr(), + features.data_ptr(), inds.data_ptr(), + wgts.data_ptr(), outs.data_ptr()); + return {outs, inds, wgts}; + } +} + +/* + Function: trilinear devoxelization (backward) + Args: + grad_y : grad outputs, FloatTensor[b, c, n] + indices : the voxel coordinates of point cube, IntTensor[b, 8, n] + weights : weight for trilinear interpolation, FloatTensor[b, 8, n] + r : voxel resolution + Return: + grad_x : grad inputs, FloatTensor[b, c, s], s = r ** 3 +*/ +at::Tensor trilinear_devoxelize_backward(const at::Tensor grad_y, + const at::Tensor indices, + const at::Tensor weights, + const int r) { + CHECK_CUDA(grad_y); + CHECK_CUDA(weights); + CHECK_CUDA(indices); + CHECK_CONTIGUOUS(grad_y); + CHECK_CONTIGUOUS(weights); + CHECK_CONTIGUOUS(indices); + CHECK_IS_FLOAT(grad_y); + CHECK_IS_FLOAT(weights); + CHECK_IS_INT(indices); + + int b = grad_y.size(0); + int c = grad_y.size(1); + int n = grad_y.size(2); + int r3 = r * r * r; + at::Tensor grad_x = torch::zeros( + {b, c, r3}, at::device(grad_y.device()).dtype(at::ScalarType::Float)); + trilinear_devoxelize_grad(b, c, n, r3, indices.data_ptr(), + weights.data_ptr(), grad_y.data_ptr(), + grad_x.data_ptr()); + return grad_x; +} diff --git a/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.cu b/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.cu new file mode 100644 index 0000000000000000000000000000000000000000..4e1e50c0b170508c90d3ec392cc43dbf771e276b --- /dev/null +++ b/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.cu @@ -0,0 +1,178 @@ +#include +#include + +#include "../cuda_utils.cuh" + +/* + Function: trilinear devoxlization (forward) + Args: + b : batch size + c : #channels + n : number of points + r : voxel resolution + r2 : r ** 2 + r3 : r ** 3 + coords : the coordinates of points, FloatTensor[b, 3, n] + feat : features, FloatTensor[b, c, r3] + inds : the voxel indices of point cube, IntTensor[b, 8, n] + wgts : weight for trilinear interpolation, FloatTensor[b, 8, n] + outs : outputs, FloatTensor[b, c, n] +*/ +__global__ void trilinear_devoxelize_kernel(int b, int c, int n, int r, int r2, + int r3, bool is_training, + const float *__restrict__ coords, + const float *__restrict__ feat, + int *__restrict__ inds, + float *__restrict__ wgts, + float *__restrict__ outs) { + int batch_index = blockIdx.x; + int stride = blockDim.x; + int index = threadIdx.x; + coords += batch_index * n * 3; + inds += batch_index * n * 8; + wgts += batch_index * n * 8; + feat += batch_index * c * r3; + outs += batch_index * c * n; + + for (int i = index; i < n; i += stride) { + float x = coords[i]; + float y = coords[i + n]; + float z = coords[i + n + n]; + float x_lo_f = floorf(x); + float y_lo_f = floorf(y); + float z_lo_f = floorf(z); + + float x_d_1 = x - x_lo_f; // / (x_hi_f - x_lo_f + 1e-8f) + float y_d_1 = y - y_lo_f; + float z_d_1 = z - z_lo_f; + float x_d_0 = 1.0f - x_d_1; + float y_d_0 = 1.0f - y_d_1; + 
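+    // x_d_1 / y_d_1 / z_d_1 are the fractional offsets inside the voxel and the *_d_0
+    // values their complements; the eight corner weights below are products of one factor per axis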
float z_d_0 = 1.0f - z_d_1; + + float wgt000 = x_d_0 * y_d_0 * z_d_0; + float wgt001 = x_d_0 * y_d_0 * z_d_1; + float wgt010 = x_d_0 * y_d_1 * z_d_0; + float wgt011 = x_d_0 * y_d_1 * z_d_1; + float wgt100 = x_d_1 * y_d_0 * z_d_0; + float wgt101 = x_d_1 * y_d_0 * z_d_1; + float wgt110 = x_d_1 * y_d_1 * z_d_0; + float wgt111 = x_d_1 * y_d_1 * z_d_1; + + int x_lo = static_cast(x_lo_f); + int y_lo = static_cast(y_lo_f); + int z_lo = static_cast(z_lo_f); + int x_hi = (x_d_1 > 0) ? -1 : 0; + int y_hi = (y_d_1 > 0) ? -1 : 0; + int z_hi = (z_d_1 > 0) ? 1 : 0; + + int idx000 = x_lo * r2 + y_lo * r + z_lo; + int idx001 = idx000 + z_hi; // x_lo * r2 + y_lo * r + z_hi; + int idx010 = idx000 + (y_hi & r); // x_lo * r2 + y_hi * r + z_lo; + int idx011 = idx010 + z_hi; // x_lo * r2 + y_hi * r + z_hi; + int idx100 = idx000 + (x_hi & r2); // x_hi * r2 + y_lo * r + z_lo; + int idx101 = idx100 + z_hi; // x_hi * r2 + y_lo * r + z_hi; + int idx110 = idx100 + (y_hi & r); // x_hi * r2 + y_hi * r + z_lo; + int idx111 = idx110 + z_hi; // x_hi * r2 + y_hi * r + z_hi; + + if (is_training) { + wgts[i] = wgt000; + wgts[i + n] = wgt001; + wgts[i + n * 2] = wgt010; + wgts[i + n * 3] = wgt011; + wgts[i + n * 4] = wgt100; + wgts[i + n * 5] = wgt101; + wgts[i + n * 6] = wgt110; + wgts[i + n * 7] = wgt111; + inds[i] = idx000; + inds[i + n] = idx001; + inds[i + n * 2] = idx010; + inds[i + n * 3] = idx011; + inds[i + n * 4] = idx100; + inds[i + n * 5] = idx101; + inds[i + n * 6] = idx110; + inds[i + n * 7] = idx111; + } + + for (int j = 0; j < c; j++) { + int jr3 = j * r3; + outs[j * n + i] = + wgt000 * feat[jr3 + idx000] + wgt001 * feat[jr3 + idx001] + + wgt010 * feat[jr3 + idx010] + wgt011 * feat[jr3 + idx011] + + wgt100 * feat[jr3 + idx100] + wgt101 * feat[jr3 + idx101] + + wgt110 * feat[jr3 + idx110] + wgt111 * feat[jr3 + idx111]; + } + } +} + +/* + Function: trilinear devoxlization (backward) + Args: + b : batch size + c : #channels + n : number of points + r3 : voxel cube size = voxel resolution ** 3 + inds : the voxel indices of point cube, IntTensor[b, 8, n] + wgts : weight for trilinear interpolation, FloatTensor[b, 8, n] + grad_y : grad outputs, FloatTensor[b, c, n] + grad_x : grad inputs, FloatTensor[b, c, r3] +*/ +__global__ void trilinear_devoxelize_grad_kernel( + int b, int c, int n, int r3, const int *__restrict__ inds, + const float *__restrict__ wgts, const float *__restrict__ grad_y, + float *__restrict__ grad_x) { + int batch_index = blockIdx.x; + int stride = blockDim.x; + int index = threadIdx.x; + inds += batch_index * n * 8; + wgts += batch_index * n * 8; + grad_x += batch_index * c * r3; + grad_y += batch_index * c * n; + + for (int i = index; i < n; i += stride) { + int idx000 = inds[i]; + int idx001 = inds[i + n]; + int idx010 = inds[i + n * 2]; + int idx011 = inds[i + n * 3]; + int idx100 = inds[i + n * 4]; + int idx101 = inds[i + n * 5]; + int idx110 = inds[i + n * 6]; + int idx111 = inds[i + n * 7]; + float wgt000 = wgts[i]; + float wgt001 = wgts[i + n]; + float wgt010 = wgts[i + n * 2]; + float wgt011 = wgts[i + n * 3]; + float wgt100 = wgts[i + n * 4]; + float wgt101 = wgts[i + n * 5]; + float wgt110 = wgts[i + n * 6]; + float wgt111 = wgts[i + n * 7]; + + for (int j = 0; j < c; j++) { + int jr3 = j * r3; + float g = grad_y[j * n + i]; + atomicAdd(grad_x + jr3 + idx000, wgt000 * g); + atomicAdd(grad_x + jr3 + idx001, wgt001 * g); + atomicAdd(grad_x + jr3 + idx010, wgt010 * g); + atomicAdd(grad_x + jr3 + idx011, wgt011 * g); + atomicAdd(grad_x + jr3 + idx100, wgt100 * g); + atomicAdd(grad_x + jr3 
+ idx101, wgt101 * g); + atomicAdd(grad_x + jr3 + idx110, wgt110 * g); + atomicAdd(grad_x + jr3 + idx111, wgt111 * g); + } + } +} + +void trilinear_devoxelize(int b, int c, int n, int r, int r2, int r3, + bool training, const float *coords, const float *feat, + int *inds, float *wgts, float *outs) { + trilinear_devoxelize_kernel<<>>( + b, c, n, r, r2, r3, training, coords, feat, inds, wgts, outs); + CUDA_CHECK_ERRORS(); +} + +void trilinear_devoxelize_grad(int b, int c, int n, int r3, const int *inds, + const float *wgts, const float *grad_y, + float *grad_x) { + trilinear_devoxelize_grad_kernel<<>>( + b, c, n, r3, inds, wgts, grad_y, grad_x); + CUDA_CHECK_ERRORS(); +} diff --git a/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.cuh b/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.cuh new file mode 100644 index 0000000000000000000000000000000000000000..8aadbaf34f4ad539199e21cb35c041dfcbe766e8 --- /dev/null +++ b/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.cuh @@ -0,0 +1,13 @@ +#ifndef _TRILINEAR_DEVOX_CUH +#define _TRILINEAR_DEVOX_CUH + +// CUDA function declarations +void trilinear_devoxelize(int b, int c, int n, int r, int r2, int r3, + bool is_training, const float *coords, + const float *feat, int *inds, float *wgts, + float *outs); +void trilinear_devoxelize_grad(int b, int c, int n, int r3, const int *inds, + const float *wgts, const float *grad_y, + float *grad_x); + +#endif diff --git a/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.hpp b/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a9d67957b79dba20cf6891c073da0f9dd5653aac --- /dev/null +++ b/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.hpp @@ -0,0 +1,16 @@ +#ifndef _TRILINEAR_DEVOX_HPP +#define _TRILINEAR_DEVOX_HPP + +#include +#include + +std::vector trilinear_devoxelize_forward(const int r, + const bool is_training, + const at::Tensor coords, + const at::Tensor features); + +at::Tensor trilinear_devoxelize_backward(const at::Tensor grad_y, + const at::Tensor indices, + const at::Tensor weights, const int r); + +#endif diff --git a/model/pvcnn/modules/functional/src/sampling/sampling.cpp b/model/pvcnn/modules/functional/src/sampling/sampling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9b8ca6ef63144ad4112dffa4d71391329578e25e --- /dev/null +++ b/model/pvcnn/modules/functional/src/sampling/sampling.cpp @@ -0,0 +1,58 @@ +#include "sampling.hpp" +#include "sampling.cuh" + +#include "../utils.hpp" + +at::Tensor gather_features_forward(at::Tensor features, at::Tensor indices) { + CHECK_CUDA(features); + CHECK_CUDA(indices); + CHECK_CONTIGUOUS(features); + CHECK_CONTIGUOUS(indices); + CHECK_IS_FLOAT(features); + CHECK_IS_INT(indices); + + int b = features.size(0); + int c = features.size(1); + int n = features.size(2); + int m = indices.size(1); + at::Tensor output = torch::zeros( + {b, c, m}, at::device(features.device()).dtype(at::ScalarType::Float)); + gather_features(b, c, n, m, features.data_ptr(), + indices.data_ptr(), output.data_ptr()); + return output; +} + +at::Tensor gather_features_backward(at::Tensor grad_y, at::Tensor indices, + const int n) { + CHECK_CUDA(grad_y); + CHECK_CUDA(indices); + CHECK_CONTIGUOUS(grad_y); + CHECK_CONTIGUOUS(indices); + CHECK_IS_FLOAT(grad_y); + CHECK_IS_INT(indices); + + int b = grad_y.size(0); + int c = grad_y.size(1); + at::Tensor grad_x = torch::zeros( + {b, c, n}, 
at::device(grad_y.device()).dtype(at::ScalarType::Float)); + gather_features_grad(b, c, n, indices.size(1), grad_y.data_ptr(), + indices.data_ptr(), grad_x.data_ptr()); + return grad_x; +} + +at::Tensor furthest_point_sampling_forward(at::Tensor coords, + const int num_samples) { + CHECK_CUDA(coords); + CHECK_CONTIGUOUS(coords); + CHECK_IS_FLOAT(coords); + + int b = coords.size(0); + int n = coords.size(2); + at::Tensor indices = torch::zeros( + {b, num_samples}, at::device(coords.device()).dtype(at::ScalarType::Int)); + at::Tensor distances = torch::full( + {b, n}, 1e38f, at::device(coords.device()).dtype(at::ScalarType::Float)); + furthest_point_sampling(b, n, num_samples, coords.data_ptr(), + distances.data_ptr(), indices.data_ptr()); + return indices; +} diff --git a/model/pvcnn/modules/functional/src/sampling/sampling.cu b/model/pvcnn/modules/functional/src/sampling/sampling.cu new file mode 100644 index 0000000000000000000000000000000000000000..06bc0ee7240ef3d86c12eb95492d9e621b27b1bb --- /dev/null +++ b/model/pvcnn/modules/functional/src/sampling/sampling.cu @@ -0,0 +1,174 @@ +#include +#include + +#include "../cuda_utils.cuh" + +/* + Function: gather centers' features (forward) + Args: + b : batch size + c : #channles of features + n : number of points in point clouds + m : number of query/sampled centers + features: points' features, FloatTensor[b, c, n] + indices : centers' indices in points, IntTensor[b, m] + out : gathered features, FloatTensor[b, c, m] +*/ +__global__ void gather_features_kernel(int b, int c, int n, int m, + const float *__restrict__ features, + const int *__restrict__ indices, + float *__restrict__ out) { + int batch_index = blockIdx.x; + int channel_index = blockIdx.y; + int temp_index = batch_index * c + channel_index; + features += temp_index * n; + indices += batch_index * m; + out += temp_index * m; + + for (int j = threadIdx.x; j < m; j += blockDim.x) { + out[j] = features[indices[j]]; + } +} + +void gather_features(int b, int c, int n, int m, const float *features, + const int *indices, float *out) { + gather_features_kernel<<>>( + b, c, n, m, features, indices, out); + CUDA_CHECK_ERRORS(); +} + +/* + Function: gather centers' features (backward) + Args: + b : batch size + c : #channles of features + n : number of points in point clouds + m : number of query/sampled centers + grad_y : grad of gathered features, FloatTensor[b, c, m] + indices : centers' indices in points, IntTensor[b, m] + grad_x : grad of points' features, FloatTensor[b, c, n] +*/ +__global__ void gather_features_grad_kernel(int b, int c, int n, int m, + const float *__restrict__ grad_y, + const int *__restrict__ indices, + float *__restrict__ grad_x) { + int batch_index = blockIdx.x; + int channel_index = blockIdx.y; + int temp_index = batch_index * c + channel_index; + grad_y += temp_index * m; + indices += batch_index * m; + grad_x += temp_index * n; + + for (int j = threadIdx.x; j < m; j += blockDim.x) { + atomicAdd(grad_x + indices[j], grad_y[j]); + } +} + +void gather_features_grad(int b, int c, int n, int m, const float *grad_y, + const int *indices, float *grad_x) { + gather_features_grad_kernel<<>>( + b, c, n, m, grad_y, indices, grad_x); + CUDA_CHECK_ERRORS(); +} + +/* + Function: furthest point sampling + Args: + b : batch size + n : number of points in point clouds + m : number of query/sampled centers + coords : points' coords, FloatTensor[b, 3, n] + distances : minimum distance of a point to the set, IntTensor[b, n] + indices : sampled centers' indices in points, 
IntTensor[b, m] +*/ +__global__ void furthest_point_sampling_kernel(int b, int n, int m, + const float *__restrict__ coords, + float *__restrict__ distances, + int *__restrict__ indices) { + if (m <= 0) + return; + int batch_index = blockIdx.x; + coords += batch_index * n * 3; + distances += batch_index * n; + indices += batch_index * m; + + const int BlockSize = 512; + __shared__ float dists[BlockSize]; + __shared__ int dists_i[BlockSize]; + const int BufferSize = 3072; + __shared__ float buf[BufferSize * 3]; + + int old = 0; + if (threadIdx.x == 0) + indices[0] = old; + + for (int j = threadIdx.x; j < min(BufferSize, n); j += blockDim.x) { + buf[j] = coords[j]; + buf[j + BufferSize] = coords[j + n]; + buf[j + BufferSize + BufferSize] = coords[j + n + n]; + } + __syncthreads(); + + for (int j = 1; j < m; j++) { + int besti = 0; // best index + float best = -1; // farthest distance + // calculating the distance with the latest sampled point + float x1 = coords[old]; + float y1 = coords[old + n]; + float z1 = coords[old + n + n]; + for (int k = threadIdx.x; k < n; k += blockDim.x) { + // fetch distance at block n, thread k + float td = distances[k]; + float x2, y2, z2; + if (k < BufferSize) { + x2 = buf[k]; + y2 = buf[k + BufferSize]; + z2 = buf[k + BufferSize + BufferSize]; + } else { + x2 = coords[k]; + y2 = coords[k + n]; + z2 = coords[k + n + n]; + } + float d = + (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); + float d2 = min(d, td); + // update "point-to-set" distance + if (d2 != td) + distances[k] = d2; + // update the farthest distance at sample step j + if (d2 > best) { + best = d2; + besti = k; + } + } + + dists[threadIdx.x] = best; + dists_i[threadIdx.x] = besti; + for (int u = 0; (1 << u) < blockDim.x; u++) { + __syncthreads(); + if (threadIdx.x < (blockDim.x >> (u + 1))) { + int i1 = (threadIdx.x * 2) << u; + int i2 = (threadIdx.x * 2 + 1) << u; + if (dists[i1] < dists[i2]) { + dists[i1] = dists[i2]; + dists_i[i1] = dists_i[i2]; + } + } + } + __syncthreads(); + + // finish sample step j; old is the sampled index + old = dists_i[0]; + if (threadIdx.x == 0) + indices[j] = old; + } +} + +void furthest_point_sampling(int b, int n, int m, const float *coords, + float *distances, int *indices) { + furthest_point_sampling_kernel<<>>(b, n, m, coords, distances, + indices); + CUDA_CHECK_ERRORS(); +} diff --git a/model/pvcnn/modules/functional/src/sampling/sampling.cuh b/model/pvcnn/modules/functional/src/sampling/sampling.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e68358ffa59e10a4d00fceaf383cdef403b849ce --- /dev/null +++ b/model/pvcnn/modules/functional/src/sampling/sampling.cuh @@ -0,0 +1,11 @@ +#ifndef _SAMPLING_CUH +#define _SAMPLING_CUH + +void gather_features(int b, int c, int n, int m, const float *features, + const int *indices, float *out); +void gather_features_grad(int b, int c, int n, int m, const float *grad_y, + const int *indices, float *grad_x); +void furthest_point_sampling(int b, int n, int m, const float *coords, + float *distances, int *indices); + +#endif diff --git a/model/pvcnn/modules/functional/src/sampling/sampling.hpp b/model/pvcnn/modules/functional/src/sampling/sampling.hpp new file mode 100644 index 0000000000000000000000000000000000000000..db2a5c84aa6a7eec46c59373c23b43be3c585499 --- /dev/null +++ b/model/pvcnn/modules/functional/src/sampling/sampling.hpp @@ -0,0 +1,12 @@ +#ifndef _SAMPLING_HPP +#define _SAMPLING_HPP + +#include + +at::Tensor gather_features_forward(at::Tensor features, at::Tensor 
indices); +at::Tensor gather_features_backward(at::Tensor grad_y, at::Tensor indices, + const int n); +at::Tensor furthest_point_sampling_forward(at::Tensor coords, + const int num_samples); + +#endif diff --git a/model/pvcnn/modules/functional/src/utils.hpp b/model/pvcnn/modules/functional/src/utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f4f21a07ec174f15aa2aa0974a5903488b0d7be3 --- /dev/null +++ b/model/pvcnn/modules/functional/src/utils.hpp @@ -0,0 +1,20 @@ +#ifndef _UTILS_HPP +#define _UTILS_HPP + +#include +#include + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") + +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") + +#define CHECK_IS_INT(x) \ + TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \ + #x " must be an int tensor") + +#define CHECK_IS_FLOAT(x) \ + TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \ + #x " must be a float tensor") + +#endif diff --git a/model/pvcnn/modules/functional/src/voxelization/vox.cpp b/model/pvcnn/modules/functional/src/voxelization/vox.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6a84594f25e78d1bad74131744c27402885b1dad --- /dev/null +++ b/model/pvcnn/modules/functional/src/voxelization/vox.cpp @@ -0,0 +1,76 @@ +#include "vox.hpp" +#include "vox.cuh" + +#include "../utils.hpp" + +/* + Function: average pool voxelization (forward) + Args: + features: features, FloatTensor[b, c, n] + coords : coords of each point, IntTensor[b, 3, n] + resolution : voxel resolution + Return: + out : outputs, FloatTensor[b, c, s], s = r ** 3 + ind : voxel index of each point, IntTensor[b, n] + cnt : #points in each voxel index, IntTensor[b, s] +*/ +std::vector avg_voxelize_forward(const at::Tensor features, + const at::Tensor coords, + const int resolution) { + CHECK_CUDA(features); + CHECK_CUDA(coords); + CHECK_CONTIGUOUS(features); + CHECK_CONTIGUOUS(coords); + CHECK_IS_FLOAT(features); + CHECK_IS_INT(coords); + + int b = features.size(0); + int c = features.size(1); + int n = features.size(2); + int r = resolution; + int r2 = r * r; + int r3 = r2 * r; + at::Tensor ind = torch::zeros( + {b, n}, at::device(features.device()).dtype(at::ScalarType::Int)); + at::Tensor out = torch::zeros( + {b, c, r3}, at::device(features.device()).dtype(at::ScalarType::Float)); + at::Tensor cnt = torch::zeros( + {b, r3}, at::device(features.device()).dtype(at::ScalarType::Int)); + avg_voxelize(b, c, n, r, r2, r3, coords.data_ptr(), + features.data_ptr(), ind.data_ptr(), + cnt.data_ptr(), out.data_ptr()); + return {out, ind, cnt}; +} + +/* + Function: average pool voxelization (backward) + Args: + grad_y : grad outputs, FloatTensor[b, c, s] + indices: voxel index of each point, IntTensor[b, n] + cnt : #points in each voxel index, IntTensor[b, s] + Return: + grad_x : grad inputs, FloatTensor[b, c, n] +*/ +at::Tensor avg_voxelize_backward(const at::Tensor grad_y, + const at::Tensor indices, + const at::Tensor cnt) { + CHECK_CUDA(grad_y); + CHECK_CUDA(indices); + CHECK_CUDA(cnt); + CHECK_CONTIGUOUS(grad_y); + CHECK_CONTIGUOUS(indices); + CHECK_CONTIGUOUS(cnt); + CHECK_IS_FLOAT(grad_y); + CHECK_IS_INT(indices); + CHECK_IS_INT(cnt); + + int b = grad_y.size(0); + int c = grad_y.size(1); + int s = grad_y.size(2); + int n = indices.size(1); + at::Tensor grad_x = torch::zeros( + {b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float)); + avg_voxelize_grad(b, c, n, s, indices.data_ptr(), cnt.data_ptr(), + grad_y.data_ptr(), 
grad_x.data_ptr()); + return grad_x; +} diff --git a/model/pvcnn/modules/functional/src/voxelization/vox.cu b/model/pvcnn/modules/functional/src/voxelization/vox.cu new file mode 100644 index 0000000000000000000000000000000000000000..1c1a2c92a646dde44886506dd8728cff21dbe533 --- /dev/null +++ b/model/pvcnn/modules/functional/src/voxelization/vox.cu @@ -0,0 +1,126 @@ +#include +#include + +#include "../cuda_utils.cuh" + +/* + Function: get how many points in each voxel grid + Args: + b : batch size + n : number of points + r : voxel resolution + r2 : = r * r + r3 : s, voxel cube size = r ** 3 + coords : coords of each point, IntTensor[b, 3, n] + ind : voxel index of each point, IntTensor[b, n] + cnt : #points in each voxel index, IntTensor[b, s] +*/ +__global__ void grid_stats_kernel(int b, int n, int r, int r2, int r3, + const int *__restrict__ coords, + int *__restrict__ ind, int *cnt) { + int batch_index = blockIdx.x; + int stride = blockDim.x; + int index = threadIdx.x; + coords += batch_index * n * 3; + ind += batch_index * n; + cnt += batch_index * r3; + + for (int i = index; i < n; i += stride) { + // if (ind[i] == -1) + // continue; + ind[i] = coords[i] * r2 + coords[i + n] * r + coords[i + n + n]; + atomicAdd(cnt + ind[i], 1); + } +} + +/* + Function: average pool voxelization (forward) + Args: + b : batch size + c : #channels + n : number of points + s : voxel cube size = voxel resolution ** 3 + ind : voxel index of each point, IntTensor[b, n] + cnt : #points in each voxel index, IntTensor[b, s] + feat: features, FloatTensor[b, c, n] + out : outputs, FloatTensor[b, c, s] +*/ +__global__ void avg_voxelize_kernel(int b, int c, int n, int s, + const int *__restrict__ ind, + const int *__restrict__ cnt, + const float *__restrict__ feat, + float *__restrict__ out) { + int batch_index = blockIdx.x; + int stride = blockDim.x; + int index = threadIdx.x; + ind += batch_index * n; + feat += batch_index * c * n; + out += batch_index * c * s; + cnt += batch_index * s; + for (int i = index; i < n; i += stride) { + int pos = ind[i]; + // if (pos == -1) + // continue; + int cur_cnt = cnt[pos]; + if (cur_cnt > 0) { + float div_cur_cnt = 1.0 / static_cast(cur_cnt); + for (int j = 0; j < c; j++) { + atomicAdd(out + j * s + pos, feat[j * n + i] * div_cur_cnt); + } + } + } +} + +/* + Function: average pool voxelization (backward) + Args: + b : batch size + c : #channels + n : number of points + r3 : voxel cube size = voxel resolution ** 3 + ind : voxel index of each point, IntTensor[b, n] + cnt : #points in each voxel index, IntTensor[b, s] + grad_y : grad outputs, FloatTensor[b, c, s] + grad_x : grad inputs, FloatTensor[b, c, n] +*/ +__global__ void avg_voxelize_grad_kernel(int b, int c, int n, int r3, + const int *__restrict__ ind, + const int *__restrict__ cnt, + const float *__restrict__ grad_y, + float *__restrict__ grad_x) { + int batch_index = blockIdx.x; + int stride = blockDim.x; + int index = threadIdx.x; + ind += batch_index * n; + grad_x += batch_index * c * n; + grad_y += batch_index * c * r3; + cnt += batch_index * r3; + for (int i = index; i < n; i += stride) { + int pos = ind[i]; + // if (pos == -1) + // continue; + int cur_cnt = cnt[pos]; + if (cur_cnt > 0) { + float div_cur_cnt = 1.0 / static_cast(cur_cnt); + for (int j = 0; j < c; j++) { + atomicAdd(grad_x + j * n + i, grad_y[j * r3 + pos] * div_cur_cnt); + } + } + } +} + +void avg_voxelize(int b, int c, int n, int r, int r2, int r3, const int *coords, + const float *feat, int *ind, int *cnt, float *out) { + grid_stats_kernel<<>>(b, 
n, r, r2, r3, coords, ind, + cnt); + avg_voxelize_kernel<<>>(b, c, n, r3, ind, cnt, + feat, out); + CUDA_CHECK_ERRORS(); +} + +void avg_voxelize_grad(int b, int c, int n, int s, const int *ind, + const int *cnt, const float *grad_y, float *grad_x) { + avg_voxelize_grad_kernel<<>>(b, c, n, s, ind, cnt, + grad_y, grad_x); + CUDA_CHECK_ERRORS(); +} diff --git a/model/pvcnn/modules/functional/src/voxelization/vox.cuh b/model/pvcnn/modules/functional/src/voxelization/vox.cuh new file mode 100644 index 0000000000000000000000000000000000000000..9adb0fdabce3352a2ed3f64e053f2aa8fafd17e7 --- /dev/null +++ b/model/pvcnn/modules/functional/src/voxelization/vox.cuh @@ -0,0 +1,10 @@ +#ifndef _VOX_CUH +#define _VOX_CUH + +// CUDA function declarations +void avg_voxelize(int b, int c, int n, int r, int r2, int r3, const int *coords, + const float *feat, int *ind, int *cnt, float *out); +void avg_voxelize_grad(int b, int c, int n, int s, const int *idx, + const int *cnt, const float *grad_y, float *grad_x); + +#endif diff --git a/model/pvcnn/modules/functional/src/voxelization/vox.hpp b/model/pvcnn/modules/functional/src/voxelization/vox.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6e62bc39efccc29da9e6d3e6d149baeee7d022e3 --- /dev/null +++ b/model/pvcnn/modules/functional/src/voxelization/vox.hpp @@ -0,0 +1,15 @@ +#ifndef _VOX_HPP +#define _VOX_HPP + +#include +#include + +std::vector avg_voxelize_forward(const at::Tensor features, + const at::Tensor coords, + const int resolution); + +at::Tensor avg_voxelize_backward(const at::Tensor grad_y, + const at::Tensor indices, + const at::Tensor cnt); + +#endif diff --git a/model/pvcnn/modules/functional/voxelization.py b/model/pvcnn/modules/functional/voxelization.py new file mode 100644 index 0000000000000000000000000000000000000000..ca4656df0c5ddd418405c3eba50b7f401f204938 --- /dev/null +++ b/model/pvcnn/modules/functional/voxelization.py @@ -0,0 +1,40 @@ +from torch.autograd import Function + +from .backend import _backend + +__all__ = ['avg_voxelize'] + + +class AvgVoxelization(Function): + @staticmethod + def forward(ctx, features, coords, resolution): + """ + :param ctx: + :param features: Features of the point cloud, FloatTensor[B, C, N] + :param coords: Voxelized Coordinates of each point, IntTensor[B, 3, N] + :param resolution: Voxel resolution + :return: + Voxelized Features, FloatTensor[B, C, R, R, R] + """ + features = features.contiguous() + coords = coords.int().contiguous() + b, c, _ = features.shape + out, indices, counts = _backend.avg_voxelize_forward(features, coords, resolution) + ctx.save_for_backward(indices, counts) + return out.view(b, c, resolution, resolution, resolution) + + @staticmethod + def backward(ctx, grad_output): + """ + :param ctx: + :param grad_output: gradient of output, FloatTensor[B, C, R, R, R] + :return: + gradient of inputs, FloatTensor[B, C, N] + """ + b, c = grad_output.shape[:2] + indices, counts = ctx.saved_tensors + grad_features = _backend.avg_voxelize_backward(grad_output.contiguous().view(b, c, -1), indices, counts) + return grad_features, None, None + + +avg_voxelize = AvgVoxelization.apply diff --git a/model/pvcnn/modules/loss.py b/model/pvcnn/modules/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a35cdd8a0fe83c8ca6b1d7040b66d142e76471df --- /dev/null +++ b/model/pvcnn/modules/loss.py @@ -0,0 +1,10 @@ +import torch.nn as nn + +from . 
import functional as F + +__all__ = ['KLLoss'] + + +class KLLoss(nn.Module): + def forward(self, x, y): + return F.kl_loss(x, y) diff --git a/model/pvcnn/modules/pointnet.py b/model/pvcnn/modules/pointnet.py new file mode 100644 index 0000000000000000000000000000000000000000..7672ae30c4379fe8791efecc1a21abf6b2bb3b32 --- /dev/null +++ b/model/pvcnn/modules/pointnet.py @@ -0,0 +1,118 @@ +import torch +import torch.nn as nn + +from . import functional as F +from .ball_query import BallQuery +from .shared_mlp import SharedMLP + +__all__ = ['PointNetAModule', 'PointNetSAModule', 'PointNetFPModule'] + + +class PointNetAModule(nn.Module): + def __init__(self, in_channels, out_channels, include_coordinates=True): + super().__init__() + if not isinstance(out_channels, (list, tuple)): + out_channels = [[out_channels]] + elif not isinstance(out_channels[0], (list, tuple)): + out_channels = [out_channels] + + mlps = [] + total_out_channels = 0 + for _out_channels in out_channels: + mlps.append( + SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0), + out_channels=_out_channels, dim=1) + ) + total_out_channels += _out_channels[-1] + + self.include_coordinates = include_coordinates + self.out_channels = total_out_channels + self.mlps = nn.ModuleList(mlps) + + def forward(self, inputs): + features, coords = inputs + if self.include_coordinates: + features = torch.cat([features, coords], dim=1) + coords = torch.zeros((coords.size(0), 3, 1), device=coords.device) + if len(self.mlps) > 1: + features_list = [] + for mlp in self.mlps: + features_list.append(mlp(features).max(dim=-1, keepdim=True).values) + return torch.cat(features_list, dim=1), coords + else: + return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords + + def extra_repr(self): + return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}' + + +class PointNetSAModule(nn.Module): + def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True): + super().__init__() + # print(f"PointNet module, in={in_channels}, out={out_channels}") + if not isinstance(radius, (list, tuple)): + radius = [radius] + if not isinstance(num_neighbors, (list, tuple)): + num_neighbors = [num_neighbors] * len(radius) + assert len(radius) == len(num_neighbors) + if not isinstance(out_channels, (list, tuple)): + out_channels = [[out_channels]] * len(radius) + elif not isinstance(out_channels[0], (list, tuple)): + out_channels = [out_channels] * len(radius) + assert len(radius) == len(out_channels) + + groupers, mlps = [], [] + total_out_channels = 0 + for _radius, _out_channels, _num_neighbors in zip(radius, out_channels, num_neighbors): + groupers.append( + BallQuery(radius=_radius, num_neighbors=_num_neighbors, include_coordinates=include_coordinates) + ) + mlps.append( + SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0), + out_channels=_out_channels, dim=2) + ) + total_out_channels += _out_channels[-1] + + self.num_centers = num_centers + self.out_channels = total_out_channels + self.groupers = nn.ModuleList(groupers) + self.mlps = nn.ModuleList(mlps) + + def forward(self, inputs): + features, coords, temb = inputs + centers_coords = F.furthest_point_sample(coords, self.num_centers) # use this to reduce the number of points to next layer + features_list = [] + # print("Pointnet input shape:", features.shape) + for grouper, mlp in zip(self.groupers, self.mlps): + features, temb = mlp(grouper(coords, centers_coords, temb, features)) + 
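+            # the grouper returns per-center neighborhoods of shape (B, C, M, U);
+            # max-pool over the last (neighbor) dimension so each sampled center keeps one feature vector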
features_list.append(features.max(dim=-1).values) + # print("Point net output shape:", features.shape) + if len(features_list) > 1: + return features_list[0], centers_coords, temb.max(dim=-1).values if temb.shape[1] > 0 else temb + else: + return features_list[0], centers_coords, temb.max(dim=-1).values if temb.shape[1] > 0 else temb + + def extra_repr(self): + return f'num_centers={self.num_centers}, out_channels={self.out_channels}' + + +class PointNetFPModule(nn.Module): + def __init__(self, in_channels, out_channels): + # print(f"IN channels={in_channels}, out channels={out_channels}") + super().__init__() + self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1) + + def forward(self, inputs): + # print(inputs.shape) + if len(inputs) == 3: + points_coords, centers_coords, centers_features, temb = inputs + points_features = None + else: + points_coords, centers_coords, centers_features, points_features, temb = inputs + interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features) + interpolated_temb = F.nearest_neighbor_interpolate(points_coords, centers_coords, temb) + if points_features is not None: + interpolated_features = torch.cat( + [interpolated_features, points_features], dim=1 + ) # concate interpolated, with original point features (394, N) + return self.mlp(interpolated_features), points_coords, interpolated_temb diff --git a/model/pvcnn/modules/pvconv.py b/model/pvcnn/modules/pvconv.py new file mode 100644 index 0000000000000000000000000000000000000000..6bc915a28f1179171ef88ab5efffc96781aacf44 --- /dev/null +++ b/model/pvcnn/modules/pvconv.py @@ -0,0 +1,140 @@ +import torch.nn as nn +import torch + +from .voxelization import Voxelization +from .shared_mlp import SharedMLP +from .se import SE3d +from . import functional as F + +__all__ = ['PVConv', 'Attention', 'Swish', 'PVConvReLU'] + + +class Swish(nn.Module): + def forward(self,x): + return x * torch.sigmoid(x) + + +class Attention(nn.Module): + def __init__(self, in_ch, num_groups, D=3): + super(Attention, self).__init__() + assert in_ch % num_groups == 0 + # it also has some learnable parameters + if D == 3: + self.q = nn.Conv3d(in_ch, in_ch, 1) + self.k = nn.Conv3d(in_ch, in_ch, 1) + self.v = nn.Conv3d(in_ch, in_ch, 1) + + self.out = nn.Conv3d(in_ch, in_ch, 1) + elif D == 1: + self.q = nn.Conv1d(in_ch, in_ch, 1) + self.k = nn.Conv1d(in_ch, in_ch, 1) + self.v = nn.Conv1d(in_ch, in_ch, 1) + + self.out = nn.Conv1d(in_ch, in_ch, 1) + + self.norm = nn.GroupNorm(num_groups, in_ch) + self.nonlin = Swish() + + self.sm = nn.Softmax(-1) + + + def forward(self, x): + """ + self attention + reso32: Attention layer, x=torch.Size([16, 64, 16, 16, 16]), q=torch.Size([16, 64, 4096]), k=torch.Size([16, 64, 4096]), v=torch.Size([16, 64, 4096]) + reso48: Attention layer, x=torch.Size([16, 64, 24, 24, 24]), q=torch.Size([16, 64, 13824]), k=torch.Size([16, 64, 13824]), v=torch.Size([16, 64, 13824]) + # this can cause OOM! + + :param x: (B, C, reso, reso, reso)? 
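+        For D=1 the input is (B, C, N). In both cases q/k/v are flattened to (B, C, L)
+        and a full L x L attention matrix is formed, which is what can cause the OOM noted above.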
+ :return: + """ + B, C = x.shape[:2] + h = x + + q = self.q(h).reshape(B,C,-1) + k = self.k(h).reshape(B,C,-1) + v = self.v(h).reshape(B,C,-1) + + qk = torch.matmul(q.permute(0, 2, 1), k) #* (int(C) ** (-0.5)) + + w = self.sm(qk) + + h = torch.matmul(v, w.permute(0, 2, 1)).reshape(B,C,*x.shape[2:]) + + h = self.out(h) + + x = h + x + + x = self.nonlin(self.norm(x)) # group norm + swish + + return x + +class PVConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False, + dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.resolution = resolution + + self.voxelization = Voxelization(resolution, normalize=normalize, eps=eps) + voxel_layers = [ + nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2), + nn.GroupNorm(num_groups=8, num_channels=out_channels), + Swish() + ] + voxel_layers += [nn.Dropout(dropout)] if dropout is not None else [] + voxel_layers += [ + nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2), + nn.GroupNorm(num_groups=8, num_channels=out_channels), + Attention(out_channels, 8) if attention else Swish() + ] + if with_se: + voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu)) + self.voxel_layers = nn.Sequential(*voxel_layers) + self.point_features = SharedMLP(in_channels, out_channels) # this is basically an MLP + + def forward(self, inputs): + features, coords, temb = inputs # features: (B, F, N), temb: sinusoidal embedding of diffusion timestaps + voxel_features, voxel_coords = self.voxelization(features, coords) + voxel_features = self.voxel_layers(voxel_features) + voxel_features = F.trilinear_devoxelize(voxel_features, voxel_coords, self.resolution, self.training) + fused_features = voxel_features + self.point_features(features) + return fused_features, coords, temb # coords is not changed, and also temb + + + +class PVConvReLU(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False, leak=0.2, + dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.resolution = resolution + + self.voxelization = Voxelization(resolution, normalize=normalize, eps=eps) + voxel_layers = [ + nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2), + nn.BatchNorm3d(out_channels), + nn.LeakyReLU(leak, True) + ] + voxel_layers += [nn.Dropout(dropout)] if dropout is not None else [] + voxel_layers += [ + nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2), + nn.BatchNorm3d(out_channels), + Attention(out_channels, 8) if attention else nn.LeakyReLU(leak, True) + ] + if with_se: + voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu)) + self.voxel_layers = nn.Sequential(*voxel_layers) + self.point_features = SharedMLP(in_channels, out_channels) + + def forward(self, inputs): + features, coords, temb = inputs + voxel_features, voxel_coords = self.voxelization(features, coords) + voxel_features = self.voxel_layers(voxel_features) + voxel_features = F.trilinear_devoxelize(voxel_features, voxel_coords, self.resolution, self.training) + fused_features = voxel_features + self.point_features(features) + return fused_features, coords, temb diff --git a/model/pvcnn/modules/se.py 
b/model/pvcnn/modules/se.py new file mode 100644 index 0000000000000000000000000000000000000000..c34eef769cebbb0f9df44d573288b02169ddfbac --- /dev/null +++ b/model/pvcnn/modules/se.py @@ -0,0 +1,19 @@ +import torch.nn as nn +import torch +__all__ = ['SE3d'] + +class Swish(nn.Module): + def forward(self,x): + return x * torch.sigmoid(x) +class SE3d(nn.Module): + def __init__(self, channel, reduction=8, use_relu=False): + super().__init__() + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(True) if use_relu else Swish() , + nn.Linear(channel // reduction, channel, bias=False), + nn.Sigmoid() + ) + + def forward(self, inputs): + return inputs * self.fc(inputs.mean(-1).mean(-1).mean(-1)).view(inputs.shape[0], inputs.shape[1], 1, 1, 1) diff --git a/model/pvcnn/modules/shared_mlp.py b/model/pvcnn/modules/shared_mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..71fd6350024d5fbaa8a97adfd304013f6df5234e --- /dev/null +++ b/model/pvcnn/modules/shared_mlp.py @@ -0,0 +1,38 @@ +import torch.nn as nn +import torch + +__all__ = ['SharedMLP'] + + +class Swish(nn.Module): + def forward(self,x): + return x * torch.sigmoid(x) + +class SharedMLP(nn.Module): + def __init__(self, in_channels, out_channels, dim=1): + super().__init__() + if dim == 1: # default value + conv = nn.Conv1d + bn = nn.GroupNorm + elif dim == 2: + conv = nn.Conv2d + bn = nn.GroupNorm + else: + raise ValueError + if not isinstance(out_channels, (list, tuple)): + out_channels = [out_channels] + layers = [] + for oc in out_channels: + layers.extend([ + conv(in_channels, oc, 1), + bn(8, oc), + Swish(), + ]) + in_channels = oc + self.layers = nn.Sequential(*layers) + + def forward(self, inputs): + if isinstance(inputs, (list, tuple)): + return (self.layers(inputs[0]), *inputs[1:]) + else: + return self.layers(inputs) diff --git a/model/pvcnn/modules/voxelization.py b/model/pvcnn/modules/voxelization.py new file mode 100644 index 0000000000000000000000000000000000000000..d19597ed23d6753a670613f42d56755e7ffd7b41 --- /dev/null +++ b/model/pvcnn/modules/voxelization.py @@ -0,0 +1,28 @@ +import torch +import torch.nn as nn + +from . 
import functional as F + +__all__ = ['Voxelization'] + + +class Voxelization(nn.Module): + def __init__(self, resolution, normalize=True, eps=0): + super().__init__() + self.r = int(resolution) + self.normalize = normalize + self.eps = eps + + def forward(self, features, coords): + coords = coords.detach() + norm_coords = coords - coords.mean(2, keepdim=True) + if self.normalize: + norm_coords = norm_coords / (norm_coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + self.eps) + 0.5 # within a unit cube of size 1x1x1 + else: + norm_coords = (norm_coords + 1) / 2.0 + norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1) + vox_coords = torch.round(norm_coords).to(torch.int32) + return F.avg_voxelize(features, vox_coords, self.r), norm_coords + + def extra_repr(self): + return 'resolution={}{}'.format(self.r, ', normalized eps = {}'.format(self.eps) if self.normalize else '') diff --git a/model/pvcnn/pos_enc.py b/model/pvcnn/pos_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..67e6070297775ff0f8f6ec2b4f3075af1f338c9e --- /dev/null +++ b/model/pvcnn/pos_enc.py @@ -0,0 +1,88 @@ +""" +positional encoding + + + +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +class Embedder: + "adapted from https://github.com/yenchenlin/nerf-pytorch/blob/master/run_nerf_helpers.py#L48" + def __init__(self, **kwargs): + """ + default config: + + :param kwargs: + """ + self.kwargs = kwargs + self.create_embedding_fn() + + def create_embedding_fn(self): + embed_fns = [] + d = self.kwargs['input_dims'] + out_dim = 0 + if self.kwargs['include_input']: + embed_fns.append(lambda x: x) + out_dim += d + + max_freq = self.kwargs['max_freq_log2'] + N_freqs = self.kwargs['num_freqs'] + + if self.kwargs['log_sampling']: + freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs) + else: + freq_bands = torch.linspace(2. ** 0., 2. 
** max_freq, steps=N_freqs) + + for freq in freq_bands: + for p_fn in self.kwargs['periodic_fns']: + embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) + out_dim += d + + self.embed_fns = embed_fns + self.out_dim = out_dim + + def embed(self, inputs): + """ + + :param inputs: (N_rays, N_samples, 3) + :return: (N_rays, N_samples, D) + """ + return torch.cat([fn(inputs) for fn in self.embed_fns], -1) + + +def get_embedder(multires, i=0, input_dims=3): + if i == -1: + return nn.Identity(), 3 + + embed_kwargs = { + 'include_input': True, + 'input_dims': input_dims, + 'max_freq_log2': multires - 1, + 'num_freqs': multires, + 'log_sampling': True, + 'periodic_fns': [torch.sin, torch.cos], + } + + embedder_obj = Embedder(**embed_kwargs) + embed = lambda x, eo=embedder_obj: eo.embed(x) + return embed, embedder_obj.out_dim + + +def test(): + "" + # x = torch.randn(10, 50, 3) + # embed, _ = get_embedder(10) + # enc = embed(x) + # print(enc.shape) # torch.Size([10, 50, 63]) + # print(x[0, :2]) + # print(enc[0, :2]) # this encoding already includes the input coordinates + + embed, _ = get_embedder(15, input_dims=1) + enc = embed(torch.randn(1, 1, 1)) + print(enc.shape) # (1, 1, 31) 2*multires + input_dims + +if __name__ == '__main__': + test() \ No newline at end of file diff --git a/model/pvcnn/pvcnn.py b/model/pvcnn/pvcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..7c2f18de2f7cad44217a4b5738c717b8338e92c6 --- /dev/null +++ b/model/pvcnn/pvcnn.py @@ -0,0 +1,174 @@ +import numpy as np +import torch +import torch.nn as nn + +from model.pvcnn.modules import Attention +from model.pvcnn.pvcnn_utils import create_mlp_components, create_pointnet2_sa_components, create_pointnet2_fp_modules +from model.pvcnn.pvcnn_utils import get_timestep_embedding + + +class PVCNN2Base(nn.Module): + def __init__( + self, + num_classes: int, + embed_dim: int, + use_att: bool = True, + dropout: float = 0.1, + extra_feature_channels: int = 3, + width_multiplier: int = 1, + voxel_resolution_multiplier: int = 1 + ): + super().__init__() + assert extra_feature_channels >= 0 + self.embed_dim = embed_dim + self.dropout = dropout + self.width_multiplier = width_multiplier + + self.in_channels = extra_feature_channels + 3 + + # Create PointNet-2 model + sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components( + sa_blocks_config=self.sa_blocks, + extra_feature_channels=extra_feature_channels, + with_se=True, + embed_dim=embed_dim, + use_att=use_att, + dropout=dropout, + width_multiplier=width_multiplier, + voxel_resolution_multiplier=voxel_resolution_multiplier + ) + self.sa_layers = nn.ModuleList(sa_layers) + + # Additional global attention module, default true + self.global_att = None if not use_att else Attention(channels_sa_features, 8, D=1) + + # Only use extra features in the last fp module + sa_in_channels[0] = extra_feature_channels + fp_layers, channels_fp_features = create_pointnet2_fp_modules( + fp_blocks=self.fp_blocks, + in_channels=channels_sa_features, + sa_in_channels=sa_in_channels, + with_se=True, + embed_dim=embed_dim, + use_att=use_att, + dropout=dropout, + width_multiplier=width_multiplier, + voxel_resolution_multiplier=voxel_resolution_multiplier + ) + self.fp_layers = nn.ModuleList(fp_layers) + + # Create MLP layers + self.channels_fp_features = channels_fp_features + layers, _ = create_mlp_components( + in_channels=channels_fp_features, + out_channels=[128, dropout, num_classes], # was 0.5 + classifier=True, + dim=2, + 
width_multiplier=width_multiplier + ) + self.classifier = nn.Sequential(*layers) # applied to point features directly + + # Time embedding function + self.embedf = nn.Sequential( + nn.Linear(embed_dim, embed_dim), + nn.LeakyReLU(0.1, inplace=True), + nn.Linear(embed_dim, embed_dim), + ) + + def forward(self, inputs: torch.Tensor, t: torch.Tensor, ret_feats=False): + """ + The inputs have size (B, 3 + S, N), where S is the number of additional + feature channels and N is the number of points. The timesteps t can be either + continuous or discrete. This model has a sort of U-Net-like structure I think, + which is why it first goes down and then up in terms of resolution (?) + + torch.Size([16, 394, 16384]) + Downscaling step 0 feature shape: torch.Size([16, 64, 1024]) + Downscaling step 1 feature shape: torch.Size([16, 128, 256]) + Downscaling step 2 feature shape: torch.Size([16, 256, 64]) + Downscaling step 3 feature shape: torch.Size([16, 512, 16]) + Upscaling step 0 feature shape: torch.Size([16, 256, 64]) + Upscaling step 1 feature shape: torch.Size([16, 256, 256]) + Upscaling step 2 feature shape: torch.Size([16, 128, 1024]) + Upscaling step 3 feature shape: torch.Size([16, 64, 16384]) + + """ + + # Embed timesteps, sinusoidal encoding + t_emb = get_timestep_embedding(self.embed_dim, t, inputs.device).float() + t_emb = self.embedf(t_emb)[:, :, None].expand(-1, -1, inputs.shape[-1]) + + # Separate input coordinates and features + coords = inputs[:, :3, :].contiguous() # (B, 3, N) range (-3.5, 3.5) + features = inputs # (B, 3 + S, N) + + # Downscaling layers + coords_list = [] + in_features_list = [] + for i, sa_blocks in enumerate(self.sa_layers): + in_features_list.append(features) + coords_list.append(coords) + if i == 0: + features, coords, t_emb = sa_blocks((features, coords, t_emb)) + else: + features, coords, t_emb = sa_blocks((torch.cat([features, t_emb], dim=1), coords, t_emb)) + + # Replace the input features + in_features_list[0] = inputs[:, 3:, :].contiguous() + + # Apply global attention layer + if self.global_att is not None: + features = self.global_att(features) + + # Upscaling layers + feats_list = [] # save intermediate features from the decoder layers + for fp_idx, fp_blocks in enumerate(self.fp_layers): + features, coords, t_emb = fp_blocks( + ( # this is a tuple because of nn.Sequential + coords_list[-1 - fp_idx], # reverse coords list from above + coords, # original point coordinates + torch.cat([features, t_emb], dim=1), # keep concatenating upsampled features with timesteps + in_features_list[-1 - fp_idx], # reverse features list from above + t_emb # original timestep embedding + ) # this is where point voxel convolution is carried out, the point feature network preserves the order. 
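Editor's note: the recurring pattern in this forward pass is to broadcast the per-sample timestep embedding to every point and concatenate it along the channel axis before the next SA/FP block. A minimal, self-contained sketch of that bookkeeping (shapes are illustrative placeholders, not the repository's modules):
```python
import torch

B, D_emb, N = 2, 64, 1024          # batch, embedding dim, number of points
features = torch.randn(B, 128, N)  # per-point features (B, C, N)
t_emb = torch.randn(B, D_emb)      # one sinusoidal embedding per sample

# broadcast the timestep embedding to every point: (B, D_emb) -> (B, D_emb, N)
t_emb_per_point = t_emb[:, :, None].expand(-1, -1, N)

# concatenate along the channel axis, as done before each SA/FP block above
x = torch.cat([features, t_emb_per_point], dim=1)
assert x.shape == (B, 128 + D_emb, N)
```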
+ ) + feats_list.append((features, coords)) # t_emb is always the same + + # exit(0) + # Output MLP layers + output = self.classifier(features) + + if ret_feats: + return output, feats_list # return intermediate features + + return output + + +class PVCNN2(PVCNN2Base): + # exact same configuration from PVD: https://github.com/alexzhou907/PVD/blob/9747265a5f141e5546fd4f862bfa66aa59f1bd33/train_completion.py#L375 + # conv_configs, sa_configs + # conv_configs: (out_ch, num_blocks, voxel_reso), sa_configs: (num_centers, radius, num_neighbors, out_channels) + sa_blocks = [ + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), # the first is out_channels, num_blocks, voxel_resolution + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 256))), + (None, (16, 0.8, 32, (256, 256, 512))), + ] + fp_blocks = [ + ((256, 256), (256, 3, 8)), + ((256, 256), (256, 3, 8)), + ((256, 128), (128, 2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__(self, num_classes, embed_dim, use_att=True, dropout=0.1, extra_feature_channels=3, + width_multiplier=1, voxel_resolution_multiplier=1): + super().__init__( + num_classes=num_classes, embed_dim=embed_dim, use_att=use_att, + dropout=dropout, extra_feature_channels=extra_feature_channels, + width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier + ) + + + + diff --git a/model/pvcnn/pvcnn_ho.py b/model/pvcnn/pvcnn_ho.py new file mode 100644 index 0000000000000000000000000000000000000000..00aa0e1e3947cff1379a230c8429d5cf2a271a0d --- /dev/null +++ b/model/pvcnn/pvcnn_ho.py @@ -0,0 +1,416 @@ +""" +two separate diffusion models for human+object, with cross attention to communicate between them +""" +import numpy as np +import torch +import torch.nn as nn +from torch import Tensor +from typing import Optional, Tuple +import math + +from model.pvcnn.modules import Attention, PVConv, BallQueryHO +from model.pvcnn.pvcnn_utils import create_mlp_components, create_pointnet2_sa_components, create_pointnet2_fp_modules +from model.pvcnn.pvcnn_utils import get_timestep_embedding +import torch.nn.functional as F +from .pos_enc import get_embedder + + +def _scaled_dot_product_attention( + q: Tensor, + k: Tensor, + v: Tensor, + attn_mask: Optional[Tensor] = None, + dropout_p: float = 0.0, +) -> Tuple[Tensor, Tensor]: + r""" + Computes scaled dot product attention on query, key and value tensors, using + an optional attention mask if passed, and applying dropout if a probability + greater than 0.0 is specified. + Returns a tensor pair containing attended values and attention weights. + + Args: + q, k, v: query, key and value tensors. See Shape section for shape details. + attn_mask: optional tensor containing mask values to be added to calculated + attention. May be 2D or 3D; see Shape section for details. + dropout_p: dropout probability. If greater than 0.0, dropout is applied. + + Shape: + - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length, + and E is embedding dimension. + - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length, + and E is embedding dimension. + - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length, + and E is embedding dimension. + - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of + shape :math:`(Nt, Ns)`. 
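Editor's note: for reference, a stand-alone version of the attention described by this docstring (queries of length Nt attending over Ns keys/values). The repository's `_scaled_dot_product_attention` below computes the same thing with an optional additive mask and dropout; all sizes here are arbitrary placeholders:
```python
import math
import torch
import torch.nn.functional as F

B, Nt, Ns, E = 2, 16, 32, 64
q, k, v = torch.randn(B, Nt, E), torch.randn(B, Ns, E), torch.randn(B, Ns, E)

attn = torch.bmm(q / math.sqrt(E), k.transpose(-2, -1))  # (B, Nt, Ns)
attn = F.softmax(attn, dim=-1)
out = torch.bmm(attn, v)                                  # (B, Nt, E)
assert out.shape == (B, Nt, E) and attn.shape == (B, Nt, Ns)
```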
+ + - Output: attention values have shape :math:`(B, Nt, E)`; attention weights + have shape :math:`(B, Nt, Ns)` + """ + B, Nt, E = q.shape + q = q / math.sqrt(E) + # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns) + if attn_mask is not None: + attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1)) + else: + attn = torch.bmm(q, k.transpose(-2, -1)) + + attn = F.softmax(attn, dim=-1) + if dropout_p > 0.0: + attn = F.dropout(attn, p=dropout_p) # this is only for training? + # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E) + output = torch.bmm(attn, v) + return output, attn + + +class PVCNN2HumObj(nn.Module): + sa_blocks = [ + # (out_channel, num_blocks, voxel_reso), (num_centers, radius, num_neighbors, out_channels) + ((32, 2, 32), (1024, 0.1, 32, (32, 64))), + ((64, 3, 16), (256, 0.2, 32, (64, 128))), + ((128, 3, 8), (64, 0.4, 32, (128, 256))), + (None, (16, 0.8, 32, (256, 256, 512))), + ] + fp_blocks = [ + # (out, in_channels), (out_channels, num_blocks, voxel_resolution) + ((256, 256), (256, 3, 8)), + ((256, 256), (256, 3, 8)), + ((256, 128), (128, 2, 16)), + ((128, 128, 64), (64, 2, 32)), + ] + + def __init__( + self, + num_classes: int, + embed_dim: int, + use_att: bool = True, + dropout: float = 0.1, + extra_feature_channels: int = 3, + width_multiplier: int = 1, + voxel_resolution_multiplier: int = 1, + attn_type: str = 'simple-cross', # + attn_weight: float=1.0, # attention feature weight + multires: int = 10, # positional encoding resolution + num_neighbours: int = 32 # ball query neighbours + ): + super(PVCNN2HumObj, self).__init__() + assert extra_feature_channels >= 0 + self.embed_dim = embed_dim + self.dropout = dropout + self.width_multiplier = width_multiplier + self.num_neighbours = num_neighbours + self.in_channels = extra_feature_channels + 3 + + self.attn_type = attn_type # how to compute attention + self.attn_weight = attn_weight + + # separate human/object model + classifier, embedf, fp_layers, global_att, sa_layers = self.make_modules(dropout, embed_dim, + extra_feature_channels, num_classes, + use_att, voxel_resolution_multiplier, + width_multiplier) + + self.sa_layers_hum = sa_layers + self.global_att_hum = global_att + self.fp_layers_hum = fp_layers + self.classifier_hum = classifier + self.embedf_hum = embedf + self.posi_encoder, _ = get_embedder(multires) + + classifier, embedf, fp_layers, global_att, sa_layers = self.make_modules(dropout, embed_dim, + extra_feature_channels, num_classes, + use_att, voxel_resolution_multiplier, + width_multiplier) + + self.sa_layers_obj = sa_layers + self.global_att_obj = global_att + self.fp_layers_obj = fp_layers + self.classifier_obj = classifier + self.embedf_obj = embedf + + self.make_coord_attn() + assert self.attn_type == 'coord3d+posenc-learnable', f'unknown attention type {self.attn_type}' + + def make_modules(self, dropout, embed_dim, extra_feature_channels, num_classes, use_att, + voxel_resolution_multiplier, width_multiplier): + """ + make module for human/object + :param dropout: + :param embed_dim: + :param extra_feature_channels: + :param num_classes: + :param use_att: + :param voxel_resolution_multiplier: + :param width_multiplier: + :return: + """ + in_ch_multiplier = 1 + extra_in_channel = 63 # the segmentation+positional feature is projected to dim 63 + + # Create PointNet-2 model + sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components( + sa_blocks_config=self.sa_blocks, + extra_feature_channels=extra_feature_channels, + with_se=True, + embed_dim=embed_dim, + use_att=use_att, + dropout=dropout, + 
width_multiplier=width_multiplier, + voxel_resolution_multiplier=voxel_resolution_multiplier, + in_ch_multiplier=in_ch_multiplier, + extra_in_channel=extra_in_channel + ) + sa_layers = nn.ModuleList(sa_layers) + # Additional global attention module, default true + if self.attn_type == 'coord3d+posenc+rgb': + # reduce channel number, only for the global attention layer, the decoders remain unchanged + global_att = None if not use_att else Attention(channels_sa_features//2, 8, D=1) + else: + global_att = None if not use_att else Attention(channels_sa_features, 8, D=1) + + # Only use extra features in the last fp module + sa_in_channels[0] = extra_feature_channels + fp_layers, channels_fp_features = create_pointnet2_fp_modules( + fp_blocks=self.fp_blocks, + in_channels=channels_sa_features, + sa_in_channels=sa_in_channels, + with_se=True, + embed_dim=embed_dim, + use_att=use_att, + dropout=dropout, + width_multiplier=width_multiplier, + voxel_resolution_multiplier=voxel_resolution_multiplier, + in_ch_multiplier=in_ch_multiplier, + extra_in_channel=extra_in_channel + ) + fp_layers = nn.ModuleList(fp_layers) + + # Create MLP layers for output prediction + layers, _ = create_mlp_components( + in_channels=channels_fp_features, + out_channels=[128, dropout, num_classes], # was 0.5 + classifier=True, + dim=2, + width_multiplier=width_multiplier + ) + classifier = nn.Sequential(*layers) # applied to point features directly + # Time embedding function + embedf = nn.Sequential( + nn.Linear(embed_dim, embed_dim), + nn.LeakyReLU(0.1, inplace=True), + nn.Linear(embed_dim, embed_dim), + ) + return classifier, embedf, fp_layers, global_att, sa_layers + + def make_coord_attn(self): + "learnable attention only on point coordinate + positional encoding " + pvconv_encoders = [] + for i, (conv_configs, sa_configs) in enumerate(self.sa_blocks): + # should use point net out channel + out_channel = 63 + layer = nn.MultiheadAttention(out_channel, 1, batch_first=True, kdim=out_channel, vdim=out_channel+2) + pvconv_encoders.append(layer) # only one block for conv + pvconv_decoders = [] + for fp_configs, conv_configs in self.fp_blocks: + out_channel = 63 + layer = nn.MultiheadAttention(out_channel, 1, batch_first=True, kdim=out_channel, vdim=out_channel + 2) + pvconv_decoders.append(layer) + self.cross_conv_encoders = nn.ModuleList(pvconv_encoders) + self.cross_conv_decoders = nn.ModuleList(pvconv_decoders) + + def forward(self, inputs_hum: torch.Tensor, inputs_obj: torch.Tensor, t: torch.Tensor, norm_params=None): + """ + + :param inputs: (B, N, D), N is the number of points, D is the conditional feature dimension + :param t: (B, ) timestamps + :param norm_params: (2, B, 4), transformation parameters that move points back to H+O joint space, first 3 values are cent, the last is radius/scale + :return: (B, N, D_out) x2 + """ + inputs_hum = inputs_hum.transpose(1, 2) + inputs_obj = inputs_obj.transpose(1, 2) + + # Embed timesteps, sinusoidal encoding + t_emb_init = get_timestep_embedding(self.embed_dim, t, inputs_hum.device).float() + t_emb_hum = self.embedf_hum(t_emb_init)[:, :, None].expand(-1, -1, inputs_hum.shape[-1]).float() + t_emb_obj = self.embedf_obj(t_emb_init)[:, :, None].expand(-1, -1, inputs_obj.shape[-1]).float() + + # Separate input coordinates and features + coords_hum, coords_obj = inputs_hum[:, :3, :].contiguous(), inputs_obj[:, :3, :].contiguous() # (B, 3, N) range (-3.5, 3.5) + features_hum, features_obj = inputs_hum, inputs_obj # (B, 3 + S, N) + + DEBUG = False + + # Encoder: Downscaling layers + 
coords_list_hum, coords_list_obj = [], [] + in_features_list_hum, in_features_list_obj = [], [] + for i, (sa_blocks_h, sa_blocks_o) in enumerate(zip(self.sa_layers_hum, self.sa_layers_obj)): + in_features_list_hum.append(features_hum) + coords_list_hum.append(coords_hum) + in_features_list_obj.append(features_obj) + coords_list_obj.append(coords_obj) + if i == 0: + # First step no timestamp embedding + features_hum, coords_hum, t_emb_hum = sa_blocks_h((features_hum, coords_hum, t_emb_hum)) + features_obj, coords_obj, t_emb_obj = sa_blocks_o((features_obj, coords_obj, t_emb_obj)) + else: + features_hum, coords_hum, t_emb_hum = sa_blocks_h((torch.cat([features_hum, t_emb_hum], dim=1), coords_hum, t_emb_hum)) + features_obj, coords_obj, t_emb_obj = sa_blocks_o((torch.cat([features_obj, t_emb_obj], dim=1), coords_obj, t_emb_obj)) + + if i < len(self.sa_layers_hum)-1: + features_hum, features_obj = self.add_attn_feature(features_hum, features_obj, + self.transform_coords(coords_hum, norm_params, 0), + self.transform_coords(coords_obj, norm_params, 1), + self.cross_conv_encoders[i], + temb_hum=t_emb_hum, + temb_obj=t_emb_obj) + + # for debug: save some point clouds + if DEBUG: + for i, (ch, co) in enumerate(zip(coords_list_hum, coords_list_obj)): + import trimesh + ch_ho = self.transform_coords(ch, norm_params, 0) + co_ho = self.transform_coords(co, norm_params, 1) + points = torch.cat([ch_ho, co_ho], -1).transpose(1, 2) + L = ch_ho.shape[-1] + vc = np.concatenate( + [np.zeros((L, 3)) + np.array([0.5, 1.0, 0]), + np.zeros((L, 3)) + np.array([0.05, 1.0, 1.0])] + ) + trimesh.PointCloud(points[0].cpu().numpy(), colors=vc).export( + f'/BS/xxie-2/work/pc2-diff/experiments/debug/meshes/encoder_step{i:02d}.ply') + + # Replace the input features + in_features_list_hum[0] = inputs_hum[:, 3:, :].contiguous() + in_features_list_obj[0] = inputs_obj[:, 3:, :].contiguous() + + # Apply global attention layer + if self.global_att_hum is not None: + features_hum = self.global_att_hum(features_hum) + + if self.global_att_obj is not None: + features_obj = self.global_att_obj(features_obj) + # Do cross attention after self-attention + if self.attn_type in ['coord3d+posenc-learnable']: + features_hum, features_obj = self.add_attn_feature(features_hum, features_obj, + self.transform_coords(coords_hum, norm_params, 0), + self.transform_coords(coords_obj, norm_params, 1), + self.cross_conv_encoders[-1] if self.attn_type in [ + 'coord3d+posenc-learnable'] else None, + temb_hum=t_emb_hum, + temb_obj=t_emb_obj) + + # Upscaling layers + for fp_idx, (fp_blocks_h, fp_blocks_o) in enumerate(zip(self.fp_layers_hum, self.fp_layers_obj)): + features_hum, coords_hum, t_emb_hum = fp_blocks_h( + ( # this is a tuple because of nn.Sequential + coords_list_hum[-1 - fp_idx], # reverse coords list from above + coords_hum, # original point coordinates + torch.cat([features_hum, t_emb_hum], dim=1), # keep concatenating upsampled features with timesteps + in_features_list_hum[-1 - fp_idx], # reverse features list from above + t_emb_hum # original timestep embedding + ) + # this is where point voxel convolution is carried out, the point feature network preserves the order. 
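Editor's note: the cross-attention above is computed on coordinates mapped back into the joint human+object frame via `transform_coords` (defined further below). A minimal sketch of that de-normalization, assuming the documented `norm_params` layout of (2, B, 4) where the first three values are the centroid and the last is the radius/scale; `to_joint_frame` is a hypothetical stand-in name and the tensors are random placeholders:
```python
import torch

B, N = 2, 512
coords = torch.randn(B, 3, N)        # per-branch, normalized coordinates
norm_params = torch.randn(2, B, 4)   # [human, object] x (cx, cy, cz, radius)

def to_joint_frame(coords, norm_params, target_ind):
    scale = norm_params[target_ind, :, 3:].unsqueeze(1)   # (B, 1, 1)
    cent = norm_params[target_ind, :, :3].unsqueeze(-1)   # (B, 3, 1)
    return coords * 2 * scale + cent                      # undo the per-branch normalization

coords_ho = to_joint_frame(coords, norm_params, target_ind=0)
assert coords_ho.shape == (B, 3, N)
```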
+ ) + features_obj, coords_obj, t_emb_obj = fp_blocks_o( + ( # this is a tuple because of nn.Sequential + coords_list_obj[-1 - fp_idx], # reverse coords list from above + coords_obj, # original point coordinates + torch.cat([features_obj, t_emb_obj], dim=1), # keep concatenating upsampled features with timesteps + in_features_list_obj[-1 - fp_idx], # reverse features list from above + t_emb_obj # original timestep embedding + ) + # this is where point voxel convolution is carried out, the point feature network preserves the order. + ) + + # these features are reused as input for next layer + # add attention except for the last layer + if fp_idx < len(self.fp_layers_hum) - 1: + # Perform cross attention between human and object branches + features_hum, features_obj = self.add_attn_feature(features_hum, features_obj, + self.transform_coords(coords_hum, norm_params, 0), + self.transform_coords(coords_obj, norm_params, 1), + self.cross_conv_decoders[fp_idx] if self.attn_type in ['coord3d+posenc-learnable'] else None, + temb_hum=t_emb_hum, + temb_obj=t_emb_obj + ) + + if DEBUG: + import trimesh + ch_ho = self.transform_coords(coords_hum, norm_params, 0) + co_ho = self.transform_coords(coords_obj, norm_params, 1) + points = torch.cat([ch_ho, co_ho], -1).transpose(1, 2) + L = ch_ho.shape[-1] + vc = np.concatenate( + [np.zeros((L, 3)) + np.array([0.5, 1.0, 0]), + np.zeros((L, 3)) + np.array([0.05, 1.0, 1.0])] + ) + trimesh.PointCloud(points[0].cpu().numpy(), colors=vc).export( + f'/BS/xxie-2/work/pc2-diff/experiments/debug/meshes/decoder_step{fp_idx:02d}.ply') + + if DEBUG: + exit(0) + # Output MLP layers + output_hum = self.classifier_hum(features_hum).transpose(1, 2) # convert back to (B, N, D) format + output_obj = self.classifier_obj(features_obj).transpose(1, 2) + + return output_hum, output_obj + + def transform_coords(self, coords, norm_params, target_ind): + """ + transform coordinates such that the points align back to H+O interaction space + :param coords: (B, 3, N) + :param norm_params: (2, B, 4) + :param target_ind: 0 or 1 + :return: + """ + scale = norm_params[target_ind, :, 3:].unsqueeze(1) + cent = norm_params[target_ind, :, :3].unsqueeze(-1) + coords_ho = coords * 2 * scale + cent + return coords_ho + + def add_attn_feature(self, features_hum, features_obj, + coords_hum=None, coords_obj=None, + attn_module=None, + temb_hum=None, temb_obj=None): + """ + compute cross attention between human and object points + :param features_hum: (B, D, N) + :param features_obj: (B, D, N) + :param coords_hum: (B, 3, N), human points in the H+O frame + :param coords_obj: (B, 3, N), object points in the H+O frame + :param temb: time embedding + :return: cross attended human object features. 
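Editor's note: a minimal sketch of the cross-attention pattern implemented by `add_attn_feature` below, using the same `nn.MultiheadAttention` configuration as `make_coord_attn` (query/key dim 63 from the positional encoding, value dim 65 = encoding plus the 2-dim human/object one-hot). Tensor contents and point counts are random placeholders:
```python
import torch
import torch.nn as nn

B, N_hum, N_obj, D_pos = 2, 256, 256, 63
pos_hum = torch.randn(B, N_hum, D_pos)           # positional encoding of human points
pos_obj = torch.randn(B, N_obj, D_pos)           # positional encoding of object points
onehot_hum = torch.zeros(B, N_hum, 2)
onehot_hum[..., 0] = 1.0                         # "this is a human point"
feat_hum = torch.cat([pos_hum, onehot_hum], -1)  # (B, N_hum, 65)

attn = nn.MultiheadAttention(D_pos, num_heads=1, batch_first=True,
                             kdim=D_pos, vdim=D_pos + 2)

# object points query the human branch; output takes the query length and embed_dim
attn_h2o, _ = attn(query=pos_obj, key=pos_hum, value=feat_hum)
assert attn_h2o.shape == (B, N_obj, D_pos)
```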
+ """ + B, D, N = features_hum.shape + # the attn_module is learnable, only difference is the number of output feature dimension + onehot_hum, onehot_obj = self.get_onehot_feat(features_hum) + pos_hum = self.posi_encoder(coords_hum.permute(0, 2, 1)).permute(0, 2, 1) + pos_obj = self.posi_encoder(coords_obj.permute(0, 2, 1)).permute(0, 2, 1) + feat_hum = torch.cat([pos_hum, onehot_hum], 1) + feat_obj = torch.cat([pos_obj, onehot_obj], 1) # (B, 65, N) + + attn_h2o = attn_module(pos_obj.permute(0, 2, 1), + pos_hum.permute(0, 2, 1), + feat_hum.permute(0, 2, 1))[0].permute(0, 2, 1) + attn_o2h = attn_module(pos_hum.permute(0, 2, 1), + pos_obj.permute(0, 2, 1), + feat_obj.permute(0, 2, 1))[0].permute(0, 2, 1) + features_hum = torch.cat([features_hum, attn_o2h * self.attn_weight], 1) + features_obj = torch.cat([features_obj, attn_h2o * self.attn_weight], 1) + + return features_hum, features_obj + + def get_onehot_feat(self, features_hum): + """ + compute a onehot feature vector to identify this is human or object + :param features_hum: + :return: (B, 2, N) x2 for human and object + """ + B, D, N = features_hum.shape + onehot_hum = torch.zeros(B, 2, N).to(features_hum.device) + onehot_hum[:, 0] = 1. + onehot_obj = torch.zeros(B, 2, N).to(features_hum.device) + onehot_obj[:, 1] = 1.0 + return onehot_hum, onehot_obj + + diff --git a/model/pvcnn/pvcnn_plus_plus.py b/model/pvcnn/pvcnn_plus_plus.py new file mode 100644 index 0000000000000000000000000000000000000000..8bdc235f29c43f41f0a43ef47abc0ba683a1a46a --- /dev/null +++ b/model/pvcnn/pvcnn_plus_plus.py @@ -0,0 +1,42 @@ +import torch +import torch.nn as nn + +from model.pvcnn.pvcnn import PVCNN2 +from model.pvcnn.pvcnn_utils import create_mlp_components +from model.simple.simple_model import SimplePointModel + + +class PVCNN2PlusPlus(nn.Module): + def __init__( + self, + *, + embed_dim, + num_classes, + extra_feature_channels, + ): + super().__init__() + + # Create models + self.simple_point_model = SimplePointModel(num_classes=embed_dim, embed_dim=embed_dim, + extra_feature_channels=extra_feature_channels, num_layers=3) + self.pvcnn = PVCNN2(num_classes=embed_dim, embed_dim=embed_dim, + extra_feature_channels=(embed_dim - 3)) + + # Tie timestep embeddings + self.pvcnn.embedf = self.simple_point_model.timestep_projection + + # # Remove output projections + # self.pvcnn.classifier = nn.Identity() + # self.simple_point_model.output_projection = nn.Identity() + + # Create new output projection + layers, _ = create_mlp_components( + in_channels=embed_dim, out_channels=[128, self.pvcnn.dropout, num_classes], + classifier=True, dim=2, width_multiplier=self.pvcnn.width_multiplier) + self.output_projection = nn.Sequential(*layers) + + def forward(self, inputs: torch.Tensor, t: torch.Tensor): + x = self.simple_point_model(inputs, t) # (B, D_emb, N) + x = x + self.pvcnn(x, t) # (B, D_emb, N) + x = self.output_projection(x) # (B, D_out, N) + return x diff --git a/model/pvcnn/pvcnn_utils.py b/model/pvcnn/pvcnn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..81533821fba921f8f8e6a1bedeed62cc88963cb0 --- /dev/null +++ b/model/pvcnn/pvcnn_utils.py @@ -0,0 +1,216 @@ +import functools +import torch +import torch.nn as nn +import numpy as np + +from model.pvcnn.modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Swish + + +def _linear_gn_relu(in_channels, out_channels): + return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish()) + + +def 
create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1): + r = width_multiplier + + if dim == 1: + block = _linear_gn_relu + else: + block = SharedMLP + if not isinstance(out_channels, (list, tuple)): + out_channels = [out_channels] + if len(out_channels) == 0 or (len(out_channels) == 1 and out_channels[0] is None): + return nn.Sequential(), in_channels, in_channels + + layers = [] + for oc in out_channels[:-1]: + if oc < 1: + layers.append(nn.Dropout(oc)) + else: + oc = int(r * oc) + layers.append(block(in_channels, oc)) + in_channels = oc + if dim == 1: + if classifier: + layers.append(nn.Linear(in_channels, out_channels[-1])) + else: + layers.append(_linear_gn_relu(in_channels, int(r * out_channels[-1]))) + else: + if classifier: + layers.append(nn.Conv1d(in_channels, out_channels[-1], 1)) + else: + layers.append(SharedMLP(in_channels, int(r * out_channels[-1]))) + return layers, out_channels[-1] if classifier else int(r * out_channels[-1]) + + +def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0, + width_multiplier=1, voxel_resolution_multiplier=1): + r, vr = width_multiplier, voxel_resolution_multiplier + + layers, concat_channels = [], 0 + c = 0 + for k, (out_channels, num_blocks, voxel_resolution) in enumerate(blocks): + out_channels = int(r * out_channels) + for p in range(num_blocks): + attention = k % 2 == 0 and k > 0 and p == 0 + if voxel_resolution is None: + block = SharedMLP + else: + block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, + with_se=with_se, normalize=normalize, eps=eps) + + if c == 0: + layers.append(block(in_channels, out_channels)) + else: + layers.append(block(in_channels+embed_dim, out_channels)) + in_channels = out_channels + concat_channels += out_channels + c += 1 + return layers, in_channels, concat_channels + + +def create_pointnet2_sa_components(sa_blocks_config, extra_feature_channels, embed_dim=64, use_att=False, + dropout=0.1, with_se=False, normalize=True, eps=0, + width_multiplier=1, voxel_resolution_multiplier=1, + in_ch_multiplier=1, + extra_in_channel=0): + "use_att is True by default, in_ch_multiplier: increase the input channel dimension" + r, vr = width_multiplier, voxel_resolution_multiplier + in_channels = extra_feature_channels + 3 + + sa_layers, sa_in_channels = [], [] + block_count = 0 + for conv_configs, sa_configs in sa_blocks_config: + k = 0 + sa_in_channels.append(in_channels) + sa_blocks = [] + + if conv_configs is not None: + out_channels, num_blocks, voxel_resolution = conv_configs + out_channels = int(r * out_channels) + for p in range(num_blocks): # pconv is repeated + attention = (block_count+1) % 2 == 0 and use_att and p == 0 + if voxel_resolution is None: + block = SharedMLP + else: + block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, + dropout=dropout, + with_se=with_se, with_se_relu=True, + normalize=normalize, eps=eps) + + if block_count == 0: + sa_blocks.append(block(in_channels, out_channels)) + elif k ==0: + sa_blocks.append(block(in_channels+embed_dim, out_channels)) + in_channels = out_channels + k += 1 + extra_feature_channels = in_channels + num_centers, radius, num_neighbors, out_channels = sa_configs + _out_channels = [] + for oc in out_channels: + if isinstance(oc, (list, tuple)): + _out_channels.append([int(r * _oc) for _oc in oc]) + else: + _out_channels.append(int(r * oc)) + out_channels = _out_channels + if 
num_centers is None: + block = PointNetAModule # always not-none + else: + block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius, + num_neighbors=num_neighbors) + sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim if k==0 else 0 ), out_channels=out_channels, + include_coordinates=True)) + block_count += 1 + # XH: double the channel for concat, or add additional channel for cross attention + if block_count < len(sa_blocks_config): + in_channels = extra_feature_channels = int(sa_blocks[-1].out_channels * in_ch_multiplier + extra_in_channel) + else: + # no cross attention before the self attention module + in_channels = extra_feature_channels = int(sa_blocks[-1].out_channels * in_ch_multiplier) + if len(sa_blocks) == 1: + sa_layers.append(sa_blocks[0]) # first pconv is repeated ? + else: + sa_layers.append(nn.Sequential(*sa_blocks)) + + return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers + + +def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False, + dropout=0.1, + with_se=False, normalize=True, eps=0, + width_multiplier=1, voxel_resolution_multiplier=1, + in_ch_multiplier=1, extra_in_channel=0): + """ + + :param fp_blocks: + :param in_channels: + :param sa_in_channels: + :param embed_dim: + :param use_att: + :param dropout: + :param with_se: + :param normalize: + :param eps: + :param width_multiplier: + :param voxel_resolution_multiplier: + :param in_ch_multiplier: increase the input channel dimension + :return: + """ + r, vr = width_multiplier, voxel_resolution_multiplier + + fp_layers = [] + c = 0 + for fp_idx, (fp_configs, conv_configs) in enumerate(fp_blocks): + fp_blocks = [] + out_channels = tuple(int(r * oc) for oc in fp_configs) + if fp_idx > 0: + # to handle additional channel from concatenating human + object features + sa_in_concat = int(in_channels*in_ch_multiplier + extra_in_channel) + else: + sa_in_concat = in_channels + extra_in_channel # this is for simple-coord3d, where the decoder first layer also has cross attention + fp_blocks.append( + PointNetFPModule(in_channels=sa_in_concat + sa_in_channels[-1 - fp_idx] + embed_dim, + out_channels=out_channels) + ) # interpolate + Conv1d, does not change number of points + in_channels = out_channels[-1] + + if conv_configs is not None: + out_channels, num_blocks, voxel_resolution = conv_configs + out_channels = int(r * out_channels) + for p in range(num_blocks): + attention = (c+1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0 + if voxel_resolution is None: + block = SharedMLP + else: + block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, + dropout=dropout, + with_se=with_se, with_se_relu=True, + normalize=normalize, eps=eps) + + fp_blocks.append(block(in_channels, out_channels)) + in_channels = out_channels # this should not change! + if len(fp_blocks) == 1: + fp_layers.append(fp_blocks[0]) # this is the last block, no PVConv layer + else: + fp_layers.append(nn.Sequential(*fp_blocks)) + + c += 1 + + return fp_layers, in_channels + + +def get_timestep_embedding(embed_dim, timesteps, device): + """ + Timestep embedding function. Not that this should work just as well for + continuous values as for discrete values. 
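Editor's note: a small usage sketch of this helper (assuming the repository root is on `PYTHONPATH`). Because the embedding is just sin/cos of `t` scaled by a geometric frequency ladder, float timesteps work as well as integer ones:
```python
import torch
from model.pvcnn.pvcnn_utils import get_timestep_embedding

t = torch.tensor([0.0, 10.0, 999.0])               # discrete or continuous timesteps
emb = get_timestep_embedding(64, t, device='cpu')  # (3, 64): first half sin, second half cos
assert emb.shape == (3, 64)
```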
+ """ + assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 + half_dim = embed_dim // 2 + emb = np.log(10000) / (half_dim - 1) + emb = torch.from_numpy(np.exp(np.arange(0, half_dim) * -emb)).float().to(device) + emb = timesteps[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embed_dim % 2 == 1: # zero pad + emb = nn.functional.pad(emb, (0, 1), "constant", 0) + assert emb.shape == torch.Size([timesteps.shape[0], embed_dim]) + return emb diff --git a/model/simple/__init__.py b/model/simple/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/model/simple/simple_model.py b/model/simple/simple_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ceb25a5cc0dae72298bceed7739050891810ce35 --- /dev/null +++ b/model/simple/simple_model.py @@ -0,0 +1,80 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import nn + +from .simple_model_utils import FeedForward, BasePointModel + + +class SimplePointModel(BasePointModel): + """ + A simple model that processes a point cloud by applying a series of MLPs to each point + individually, along with some pooled global features. + """ + + def get_layers(self): + return nn.ModuleList([FeedForward( + d_in=(3 * self.dim), d_hidden=(4 * self.dim), d_out=self.dim, + activation=nn.SiLU(), is_gated=True, bias1=False, bias2=False, bias_gate=False, use_layernorm=True + ) for _ in range(self.num_layers)]) + + def forward(self, inputs: torch.Tensor, t: torch.Tensor): + + # Prepare inputs + x, coords = self.prepare_inputs(inputs, t) + + # Model + for layer in self.layers: + x_pool_max, x_pool_std = self.get_global_tensors(x) + x_input = torch.cat((x, x_pool_max, x_pool_std), dim=-1) # (B, N, 3 * D) + x = x + layer(x_input) # (B, N, D_model) + + # Project + x = self.output_projection(x) # (B, N, D_out) + x = torch.transpose(x, -2, -1) # -> (B, D_out, N) + + return x + + +class SimpleNearestNeighborsPointModel(BasePointModel): + """ + A simple model that processes a point cloud by applying a series of MLPs to each point + individually, along with some pooled global features, and the features of its nearest + neighbors. + """ + + def __init__(self, num_neighbors: int = 4, **kwargs): + self.num_neighbors = num_neighbors + super().__init__(**kwargs) + from pytorch3d.ops import knn_points + self.knn_points = knn_points + + def get_layers(self): + return nn.ModuleList([FeedForward( + d_in=((3 + self.num_neighbors) * self.dim), d_hidden=(4 * self.dim), d_out=self.dim, + activation=nn.SiLU(), is_gated=True, bias1=False, bias2=False, bias_gate=False, use_layernorm=True + ) for _ in range(self.num_layers)]) + + def forward(self, inputs: torch.Tensor, t: torch.Tensor): + + # Prepare inputs + x, coords = self.prepare_inputs(inputs, t) # (B, N, D), (B, N, 3) + + # Get nearest neighbors. 
Note that the first neighbor is the identity, which is convenient + _dists, indices, _neighbors = self.knn_points( + p1=coords, p2=coords, K=(self.num_neighbors + 1), + return_nn=False) # (B, N, K), (B, N, K) + (B, N, D), (_B, _N, K) = x.shape, indices.shape + + # Model + for layer in self.layers: + x_neighbor = torch.stack([x_i[idx] for x_i, idx in zip(x, indices.reshape(B, N * K))]).reshape(B, N, K * D) + x_pool_max, x_pool_std = self.get_global_tensors(x) + x_input = torch.cat((x_neighbor, x_pool_max, x_pool_std), dim=-1) # (B, N, (3+K)*D) + x = x + layer(x_input) # (B, N, D_model) + + # Project + x = self.output_projection(x) # (B, N, D_out) + x = torch.transpose(x, -2, -1) # -> (B, D_out, N) + + return x diff --git a/model/simple/simple_model_utils.py b/model/simple/simple_model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bd84cd1f2f2a235715b98a74c428dcf6d2e36104 --- /dev/null +++ b/model/simple/simple_model_utils.py @@ -0,0 +1,282 @@ +from typing import Any, Callable, Iterable, List, Optional, Union + +import torch +import torch.jit as jit +import torch.nn as nn +import torch.nn.functional as F +from torch import Size, Tensor, nn +from torch.nn import LayerNorm + +from model.pvcnn.pvcnn_utils import get_timestep_embedding + + +def sample_b(size: Size, sigma: float) -> Tensor: + """Sample b matrix for fourier features + + Arguments: + size (Size): b matrix size + sigma (float): std of the gaussian + + Returns: + b (Tensor): b matrix + """ + return torch.randn(size) * sigma + + +@jit.script +def map_positional_encoding(v: Tensor, freq_bands: Tensor) -> Tensor: + """Map v to positional encoding representation phi(v) + + Arguments: + v (Tensor): input features (B, IFeatures) + freq_bands (Tensor): frequency bands (N_freqs, ) + + Returns: + phi(v) (Tensor): fourrier features (B, 3 + (2 * N_freqs) * 3) + """ + pe = [v] + for freq in freq_bands: + fv = freq * v + pe += [torch.sin(fv), torch.cos(fv)] + return torch.cat(pe, dim=-1) + + +@jit.script +def map_fourier_features(v: Tensor, b: Tensor) -> Tensor: + """Map v to fourier features representation phi(v) + + Arguments: + v (Tensor): input features (B, IFeatures) + b (Tensor): b matrix (OFeatures, IFeatures) + + Returns: + phi(v) (Tensor): fourrier features (B, 2 * Features) + """ + PI = 3.141592653589793 + a = 2 * PI * v @ b.T + return torch.cat((torch.sin(a), torch.cos(a)), dim=-1) + + +class FeatureMapping(nn.Module): + """FeatureMapping nn.Module + + Maps v to features following transformation phi(v) + + Arguments: + i_dim (int): input dimensions + o_dim (int): output dimensions + """ + + def __init__(self, i_dim: int, o_dim: int) -> None: + super().__init__() + self.i_dim = i_dim + self.o_dim = o_dim + + def forward(self, v: Tensor) -> Tensor: + """FeratureMapping forward pass + + Arguments: + v (Tensor): input features (B, IFeatures) + + Returns: + phi(v) (Tensor): mapped features (B, OFeatures) + """ + raise NotImplementedError("Forward pass not implemented yet!") + + +class PositionalEncoding(FeatureMapping): + """PositionalEncoding module + + Maps v to positional encoding representation phi(v) + + Arguments: + i_dim (int): input dimension for v + N_freqs (int): #frequency to sample (default: 10) + """ + + def __init__( + self, + i_dim: int, + N_freqs: int = 10, + ) -> None: + super().__init__(i_dim, 3 + (2 * N_freqs) * 3) + self.N_freqs = N_freqs + + a, b = 1, self.N_freqs - 1 + freq_bands = 2 ** torch.linspace(a, b, self.N_freqs) + self.register_buffer("freq_bands", freq_bands) + + def 
forward(self, v: Tensor) -> Tensor: + """Map v to positional encoding representation phi(v) + + Arguments: + v (Tensor): input features (B, IFeatures) + + Returns: + phi(v) (Tensor): fourrier features (B, 3 + (2 * N_freqs) * 3) + """ + return map_positional_encoding(v, self.freq_bands) + + +class FourierFeatures(FeatureMapping): + + """Fourier Features module + + Maps v to fourier features representation phi(v) + + Arguments: + i_dim (int): input dimension for v + features (int): output dimension (default: 256) + sigma (float): std of the gaussian (default: 26.) + """ + + def __init__( + self, + i_dim: int, + features: int = 256, + sigma: float = 26., + ) -> None: + super().__init__(i_dim, 2 * features) + self.features = features + self.sigma = sigma + + self.size = Size((self.features, self.i_dim)) + self.register_buffer("b", sample_b(self.size, self.sigma)) + + def forward(self, v: Tensor) -> Tensor: + """Map v to fourier features representation phi(v) + + Arguments: + v (Tensor): input features (B, IFeatures) + + Returns: + phi(v) (Tensor): fourrier features (B, 2 * Features) + """ + return map_fourier_features(v, self.b) + + +class FeedForward(nn.Module): + """ Adapted from the FeedForward layer from labmlai """ + + def __init__( + self, + d_in: int, + d_hidden: int, + d_out: int, + activation: Callable = nn.ReLU(), + is_gated: bool = False, + bias1: bool = True, + bias2: bool = True, + bias_gate: bool = True, + dropout: float = 0.1, + use_layernorm: bool = False, + ): + super().__init__() + # Layer one parameterized by weight $W_1$ and bias $b_1$ + self.layer1 = nn.Linear(d_in, d_hidden, bias=bias1) + # Layer one parameterized by weight $W_1$ and bias $b_1$ + self.layer2 = nn.Linear(d_hidden, d_out, bias=bias2) + # Hidden layer dropout + self.dropout = nn.Dropout(dropout) + # Activation function $f$ + self.activation = activation + # Whether there is a gate + self.is_gated = is_gated + if is_gated: + # If there is a gate the linear layer to transform inputs to + # be multiplied by the gate, parameterized by weight $V$ and bias $c$ + self.linear_v = nn.Linear(d_in, d_hidden, bias=bias_gate) + # Whether to add a layernorm layer + self.use_layernorm = use_layernorm + if use_layernorm: + self.layernorm = LayerNorm(d_in) + + def forward(self, x: Tensor, coords: Tensor = None) -> Tensor: + """Applies a simple feed forward layer""" + x = self.layernorm(x) if self.use_layernorm else x + g = self.activation(self.layer1(x)) + x = (g * self.linear_v(x)) if self.is_gated else g + x = self.dropout(x) + x = self.layer2(x) + return x + + +class BasePointModel(nn.Module): + """ A base class providing useful methods for point cloud processing. 
""" + + def __init__( + self, + *, + num_classes, + embed_dim, + extra_feature_channels, + dim: int = 128, + num_layers: int = 6 + ): + + super().__init__() + self.extra_feature_channels = extra_feature_channels + self.timestep_embed_dim = embed_dim + self.output_dim = num_classes + self.dim = dim + self.num_layers = num_layers + + # Time embedding function + self.timestep_projection = nn.Sequential( + nn.Linear(embed_dim, embed_dim), + nn.LeakyReLU(0.1, inplace=True), + nn.Linear(embed_dim, embed_dim), + ) + + # Positional encoding + self.positional_encoding = PositionalEncoding(i_dim=3, N_freqs=10) + positional_encoding_d_out = 3 + (2 * 10) * 3 + + # Input projection (point coords, point coord encodings, other features, and timestep embeddings) + self.input_projection = nn.Linear( + in_features=(3 + positional_encoding_d_out + extra_feature_channels + self.timestep_embed_dim), + out_features=self.dim + ) + + # Transformer layers + self.layers = self.get_layers() + + # Output projection + self.output_projection = nn.Linear(self.dim, self.output_dim) + + def get_layers(self): + raise NotImplementedError('This method should be implemented by subclasses') + + def prepare_inputs(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """ + The inputs have size (B, 3 + S, N), where S is the number of additional + feature channels and N is the number of points. The timesteps t can be either + continuous or discrete. This model has a sort of U-Net-like structure I think, + which is why it first goes down and then up in terms of resolution (?) + """ + + # Embed and project timesteps + t_emb = get_timestep_embedding(self.timestep_embed_dim, t, inputs.device) + t_emb = self.timestep_projection(t_emb)[:, None, :].expand(-1, inputs.shape[-1], -1) # (B, N, D_t_emb) + + # Separate input coordinates and features + x = torch.transpose(inputs, -2, -1) # -> (B, N, 3 + S) + coords = x[:, :, :3] # (B, N, 3), point coordinates + + # Positional encoding of point coords + coords_posenc = self.positional_encoding(coords) # (B, N, D_p_enc) + + # Project + x = torch.cat((x, coords_posenc, t_emb), dim=2) # (B, N, 3 + S + D_p_enc + D_t_emb) + x = self.input_projection(x) # (B, N, D_model) + + return x, coords + + def get_global_tensors(self, x: Tensor): + B, N, D = x.shape + x_pool_max = torch.max(x, dim=1, keepdim=True).values.repeat(1, N, 1) # (B, 1, D) + x_pool_std = torch.std(x, dim=1, keepdim=True).repeat(1, N, 1) # (B, 1, D) + return x_pool_max, x_pool_std + + def forward(self, inputs: torch.Tensor, t: torch.Tensor): + raise NotImplementedError('This method should be implemented by subclasses') \ No newline at end of file diff --git a/render/__init__.py b/render/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/render/pyt3d_wrapper.py b/render/pyt3d_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..ba2c025ee5383cfe1bde3412b0e051569dfffd80 --- /dev/null +++ b/render/pyt3d_wrapper.py @@ -0,0 +1,302 @@ +""" +a simple wrapper for pytorch3d rendering +Cite: BEHAVE: Dataset and Method for Tracking Human Object Interaction +""" +import numpy as np +import torch +from copy import deepcopy +# Data structures and functions for rendering +from pytorch3d.renderer import ( + PointLights, + RasterizationSettings, + MeshRenderer, + MeshRasterizer, + SoftPhongShader, + TexturesVertex, + PerspectiveCameras, + PointsRasterizer, + AlphaCompositor, + PointsRasterizationSettings, +) +from pytorch3d.structures 
import Meshes, join_meshes_as_scene, Pointclouds + +SMPL_OBJ_COLOR_LIST = [ + [0.65098039, 0.74117647, 0.85882353], # SMPL + [251 / 255.0, 128 / 255.0, 114 / 255.0], # object + ] + + +class MeshRendererWrapper: + "a simple wrapper for the pytorch3d mesh renderer" + def __init__(self, image_size=1200, + faces_per_pixel=1, + device='cuda:0', + blur_radius=0, lights=None, + materials=None, max_faces_per_bin=50000): + self.image_size = image_size + self.faces_per_pixel=faces_per_pixel + self.max_faces_per_bin=max_faces_per_bin # prevent overflow, see https://github.com/facebookresearch/pytorch3d/issues/348 + self.blur_radius = blur_radius + self.device = device + self.lights=lights if lights is not None else PointLights( + ((0.5, 0.5, 0.5),), ((0.5, 0.5, 0.5),), ((0.05, 0.05, 0.05),), ((0, -2, 0),), device + ) + self.materials = materials + self.renderer = self.setup_renderer() + + def setup_renderer(self): + # for sillhouette rendering + sigma = 1e-4 + raster_settings = RasterizationSettings( + image_size=self.image_size, + blur_radius=self.blur_radius, + # blur_radius=np.log(1. / 1e-4 - 1.) * sigma, # this will create large sphere for each face + faces_per_pixel=self.faces_per_pixel, + clip_barycentric_coords=False, + max_faces_per_bin=self.max_faces_per_bin + ) + shader = SoftPhongShader( + device=self.device, + lights=self.lights, + materials=self.materials) + renderer = MeshRenderer( + rasterizer=MeshRasterizer( + raster_settings=raster_settings), + shader=shader + ) + return renderer + + def render(self, meshes, cameras, ret_mask=False, mode='rgb'): + assert len(meshes.faces_list()) == 1, 'currently only support batch size =1 rendering!' + images = self.renderer(meshes, cameras=cameras) + # print(images.shape) + if ret_mask or mode=='mask': + mask = images[0, ..., 3].cpu().detach().numpy() + return images[0, ..., :3].cpu().detach().numpy(), mask > 0 + return images[0, ..., :3].cpu().detach().numpy() + + +def get_kinect_camera(device='cuda:0', kid=1): + R, T = torch.eye(3), torch.zeros(3) + R[0, 0] = R[1, 1] = -1 # pytorch3d y-axis up, need to rotate to kinect coordinate + R = R.unsqueeze(0) + T = T.unsqueeze(0) + assert kid in [0, 1, 2, 3], f'invalid kinect index {kid}!' 
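Editor's note: the sign flip above bridges the two camera conventions: the Kinect (OpenCV-style) frame has +x right and +y down, while PyTorch3D uses +x left and +y up, with +z pointing into the scene in both. A tiny illustration of what the rotation does to a point (values are arbitrary):
```python
import torch

R = torch.eye(3)
R[0, 0] = R[1, 1] = -1.0                  # negate x and y, keep z

p_kinect = torch.tensor([0.2, 0.5, 1.0])  # right/down/forward in Kinect coordinates
p_pyt3d = R @ p_kinect                    # -> left/up/forward in PyTorch3D coordinates
assert torch.allclose(p_pyt3d, torch.tensor([-0.2, -0.5, 1.0]))
```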
+ if kid == 0: + fx, fy = 976.212, 976.047 + cx, cy = 1017.958, 787.313 + elif kid == 1: + fx, fy = 979.784, 979.840 # for original kinect coordinate system + cx, cy = 1018.952, 779.486 + elif kid == 2: + fx, fy = 974.899, 974.337 + cx, cy = 1018.747, 786.176 + else: + fx, fy = 972.873, 972.790 + cx, cy = 1022.0565, 770.397 + color_w, color_h = 2048, 1536 # kinect color image size + cam_center = torch.tensor((cx, cy), dtype=torch.float32).unsqueeze(0) + focal_length = torch.tensor((fx, fy), dtype=torch.float32).unsqueeze(0) + cam = PerspectiveCameras(focal_length=focal_length, principal_point=cam_center, + image_size=((color_w, color_h),), + device=device, + R=R, T=T) + return cam + + +class PcloudRenderer: + "a simple wrapper for pytorch3d point cloud renderer" + def __init__(self, image_size=1024, radius=0.005, points_per_pixel=10, + device='cuda:0', bin_size=128, batch_size=1, ret_depth=False): + camera_centers = [] + focal_lengths = [] + for i in range(batch_size): + camera_centers.append(torch.Tensor([image_size / 2., image_size / 2.]).to(device)) + focal_lengths.append(torch.Tensor([image_size / 2., image_size / 2.]).to(device)) + self.image_size = image_size + self.device = device + self.camera_center = torch.stack(camera_centers) + self.focal_length = torch.stack(focal_lengths) + self.ret_depth = ret_depth # return depth map or not + self.renderer = self.setup_renderer(radius, points_per_pixel, bin_size) + + def render(self, pc, cameras, mode='image'): + # TODO: support batch rendering + """ + render the point cloud, compute the world coordinate of each pixel based on zbuf + image: (H, W, 3) + xyz_world: (H, W, 3), the third dimension is the xyz coordinate in world space + """ + # assert cameras.R.shape[0]==1, "batch rendering is not supported for now!" + images, fragments = self.renderer(pc, cameras=cameras) + if mode=='image': + if images.shape[0] == 1: + img = images[0, ..., :3].cpu().numpy().copy() + else: + img = images[..., :3].cpu().numpy().copy() + return img + elif mode=='mask': + zbuf = torch.mean(fragments.zbuf, -1) # (B, H, W) + masks = zbuf >= 0 + if images.shape[0] == 1: + img = images[0, ..., :3].cpu().numpy() + masks = masks[0].cpu().numpy().astype(bool) + else: + img = images[..., :3].cpu().numpy() + masks = masks.cpu().numpy().astype(bool) + + return img, masks + + def get_xy_ndc(self): + """ + return (H, W, 2), each pixel is the x,y coordinate in NDC space + """ + py, px = torch.meshgrid(torch.linspace(0, self.image_size-1, self.image_size), + torch.linspace(0, self.image_size-1, self.image_size)) + x_ndc = 1 - 2*px/(self.image_size - 1) + y_ndc = 1 - 2*py/(self.image_size - 1) + xy_ndc = torch.stack([x_ndc, y_ndc], axis=-1).to(self.device) + return xy_ndc.squeeze(0).unsqueeze(0) + + def setup_renderer(self, radius, points_per_pixel, bin_size): + raster_settings = PointsRasterizationSettings( + image_size=self.image_size, + # radius=0.003, + radius=radius, + points_per_pixel=points_per_pixel, + bin_size=bin_size, + max_points_per_bin=500000 + ) + # Create a points renderer by compositing points using an alpha compositor (nearer points + # are weighted more heavily). See [1] for an explanation. 
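Editor's note: a quick check of the pixel-to-NDC mapping used by `get_xy_ndc` above. Under PyTorch3D's convention (+x left, +y up), pixel column 0 maps to +1 and column `image_size - 1` maps to -1; the image size here is a toy value:
```python
import torch

image_size = 4
px = torch.arange(image_size, dtype=torch.float32)  # pixel columns 0..3
x_ndc = 1 - 2 * px / (image_size - 1)               # same formula as get_xy_ndc
print(x_ndc)  # tensor([ 1.0000,  0.3333, -0.3333, -1.0000])
```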
+ rasterizer = PointsRasterizer(raster_settings=raster_settings) + renderer = PointsRendererWithFragments( + rasterizer=rasterizer, + compositor=AlphaCompositor() + ) + return renderer + + +class PointsRendererWithFragments(torch.nn.Module): + def __init__(self, rasterizer, compositor): + super().__init__() + self.rasterizer = rasterizer + self.compositor = compositor + + def forward(self, point_clouds, **kwargs) -> (torch.Tensor, torch.Tensor): + fragments = self.rasterizer(point_clouds, **kwargs) + # Construct weights based on the distance of a point to the true point. + # However, this could be done differently: e.g. predicted as opposed + # to a function of the weights. + r = self.rasterizer.raster_settings.radius + + dists2 = fragments.dists.permute(0, 3, 1, 2) + weights = 1 - dists2 / (r * r) + images = self.compositor( + fragments.idx.long().permute(0, 3, 1, 2), + weights, + point_clouds.features_packed().permute(1, 0), + **kwargs, + ) + + # permute so image comes at the end + images = images.permute(0, 2, 3, 1) + + return images, fragments + +# class PcloudsRenderer + + +class DepthRasterizer(torch.nn.Module): + """ + simply rasterize a mesh or point cloud to depth image + """ + def __init__(self, image_size, dtype='pc', + radius=0.005, points_per_pixel=1, + bin_size=128, + blur_radius=0, + max_faces_per_bin=50000, + faces_per_pixel=1,): + """ + image_size: (height, width) + """ + super(DepthRasterizer, self).__init__() + if dtype == 'pc': + raster_settings = PointsRasterizationSettings( + image_size=image_size, + radius=radius, + points_per_pixel=points_per_pixel, + bin_size=bin_size + ) + self.rasterizer = PointsRasterizer(raster_settings=raster_settings) + elif dtype == 'mesh': + raster_settings = RasterizationSettings( + image_size=image_size, + blur_radius=blur_radius, + # blur_radius=np.log(1. / 1e-4 - 1.) * sigma, # this will create large sphere for each face + faces_per_pixel=faces_per_pixel, + clip_barycentric_coords=False, + max_faces_per_bin=max_faces_per_bin + ) + self.rasterizer=MeshRasterizer(raster_settings=raster_settings) + else: + raise NotImplemented + + def forward(self, data, to_np=True, **kwargs): + fragments = self.rasterizer(data, **kwargs) + if to_np: + zbuf = fragments.zbuf # (B, H, W, points_per_pixel) + return zbuf[0, ..., 0].cpu().numpy() + return fragments.zbuf + + +def test_depth_rasterizer(): + from psbody.mesh import Mesh + import cv2 + m = Mesh() + m.load_from_file("/BS/xxie-4/work/kindata/Sep29_shuo_chairwood_hand/t0003.000/person/person.ply") + device = 'cuda:0' + pc = Pointclouds([torch.from_numpy(m.v).float().to(device)], + features=[torch.from_numpy(m.vc).float().to(device)]) + rasterizer = DepthRasterizer(image_size=(480, 640)) + camera = get_kinect_camera(device) + + depth = rasterizer(pc, cameras=camera) + std = torch.std(depth, -1) + print('max std', torch.max(std)) # maximum std is up to 1.7m, too much! 
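Editor's note: the tests in this file store depth as 16-bit PNGs in millimeters after zeroing invalid (negative) z-buffer values. A self-contained sketch of that encode/decode round trip with a synthetic depth map; the output path is an arbitrary choice:
```python
import cv2
import numpy as np

dmap = np.random.uniform(0.5, 3.0, size=(480, 640)).astype(np.float32)  # depth in meters
dmap[100:120, :] = -1.0                  # e.g. background pixels from the zbuf

dmap_mm = dmap.copy()
dmap_mm[dmap_mm < 0] = 0                 # zero out invalid depths
cv2.imwrite('/tmp/depth.png', (dmap_mm * 1000).astype(np.uint16))

restored = cv2.imread('/tmp/depth.png', cv2.IMREAD_UNCHANGED).astype(np.float32) / 1000.0
assert np.all(np.abs(restored - dmap_mm) < 2e-3)  # millimeter quantization error only
```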
+ print('min std', torch.min(std)) + + print(depth.shape) + dmap = depth[0, ..., 0].cpu().numpy() + dmap[dmap<0] = 0 + cv2.imwrite('debug/depth.png', (dmap*1000).astype(np.uint16)) + +def test_mesh_rasterizer(): + from psbody.mesh import Mesh + import cv2 + m = Mesh() + m.load_from_file("/BS/xxie-4/work/kindata/Sep29_shuo_chairwood_hand/t0003.000/person/fit02/person_fit.ply") + device = 'cuda:0' + mesh = Meshes([torch.from_numpy(m.v).float().to(device)], + [torch.from_numpy(m.f.astype(int)).to(device)]) + rasterizer = DepthRasterizer(image_size=(480, 640), dtype='mesh') + camera = get_kinect_camera(device) + + depth = rasterizer(mesh, to_np=False, cameras=camera) + + print(depth.shape) + dmap = depth[0, ..., 0].cpu().numpy() + dmap[dmap < 0] = 0 + cv2.imwrite('debug/depth_mesh.png', (dmap * 1000).astype(np.uint16)) + + + +if __name__ == '__main__': + # test_depth_rasterizer() + test_mesh_rasterizer() + + + + + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..42632834d92da7c3585976b86e3071f3a9fe8aca --- /dev/null +++ b/requirements.txt @@ -0,0 +1,15 @@ +accelerate +diffusers +hydra-core +imageio +ninja +opencv-python-headless +plotly +rich +scipy +timm +torch-ema +tqdm +transformers +wandb +trimesh \ No newline at end of file diff --git a/training_utils.py b/training_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fa9d6035ef9ca2ee8457f335e2283b6ea28ac546 --- /dev/null +++ b/training_utils.py @@ -0,0 +1,443 @@ +""" +Misc functions, including distributed helpers, mostly from torchvision +""" +import glob +import math +import os +import time +import datetime +import random +from dataclasses import dataclass +from collections import defaultdict, deque +from typing import Callable, Optional +from PIL import Image +import numpy as np +import torch +import torch.distributed as dist +import torchvision +from accelerate import Accelerator +from omegaconf import DictConfig + +from configs.structured import ProjectConfig + + + +@dataclass +class TrainState: + epoch: int = 0 + step: int = 0 + best_val: Optional[float] = None + + +def get_optimizer(cfg: ProjectConfig, model: torch.nn.Module, accelerator: Accelerator) -> torch.optim.Optimizer: + """Gets optimizer from configs""" + + # Determine the learning rate + if cfg.optimizer.scale_learning_rate_with_batch_size: + lr = accelerator.state.num_processes * cfg.dataloader.batch_size * cfg.optimizer.lr + print('lr = {ws} (num gpus) * {bs} (batch_size) * {blr} (base learning rate) = {lr}'.format( + ws=accelerator.state.num_processes, bs=cfg.dataloader.batch_size, blr=cfg.optimizer.lr, lr=lr)) + else: # scale base learning rate by batch size + lr = cfg.optimizer.lr + print('lr = {lr} (absolute learning rate)'.format(lr=lr)) + + # Get optimizer parameters, excluding certain parameters from weight decay + no_decay = ["bias", "LayerNorm.weight"] + parameters = [ + { + "params": [p for n, p in model.named_parameters() if p.requires_grad and not any(nd in n for nd in no_decay)], + "weight_decay": cfg.optimizer.weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if p.requires_grad and any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + # Construct optimizer + if cfg.optimizer.type == 'torch': + Optimizer: torch.optim.Optimizer = getattr(torch.optim, cfg.optimizer.name) + optimizer = Optimizer(parameters, lr=lr, **cfg.optimizer.kwargs) + elif cfg.optimizer.type == 'timm': + from timm.optim import create_optimizer_v2 + 
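+        # create_optimizer_v2 accepts either an nn.Module or, as here, a list of
+        # parameter groups via its model_or_params argument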
optimizer = create_optimizer_v2(model_or_params=parameters, lr=lr, **cfg.optimizer.kwargs) + elif cfg.optimizer.type == 'transformers': + import transformers + Optimizer: torch.optim.Optimizer = getattr(transformers, cfg.optimizer.name) + optimizer = Optimizer(parameters, lr=lr, **cfg.optimizer.kwargs) + else: + raise NotImplementedError(f'Invalid optimizer configs: {cfg.optimizer}') + + return optimizer + + +def get_scheduler(cfg: ProjectConfig, optimizer: torch.optim.Optimizer) -> Callable: + """Gets scheduler from configs""" + + # Get scheduler + if cfg.scheduler.type == 'torch': + Scheduler: torch.optim.lr_scheduler._LRScheduler = getattr(torch.optim.lr_scheduler, cfg.scheduler.type) + scheduler = Scheduler(optimizer=optimizer, **cfg.scheduler.kwargs) + if cfg.scheduler.get('warmup', 0): + from warmup_scheduler import GradualWarmupScheduler + scheduler = GradualWarmupScheduler(optimizer, multiplier=1, + total_epoch=cfg.scheduler.warmup, after_scheduler=scheduler) + elif cfg.scheduler.type == 'timm': + from timm.scheduler import create_scheduler + scheduler, _ = create_scheduler(optimizer=optimizer, args=cfg.scheduler.kwargs) + elif cfg.scheduler.type == 'transformers': + from transformers import get_scheduler # default: linear scheduler without warm up and linear decay + scheduler = get_scheduler(optimizer=optimizer, **cfg.scheduler.kwargs) + else: + raise NotImplementedError(f'invalid scheduler configs: {cfg.scheduler}') + + return scheduler + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self, device='cuda'): + """ + Warning: does not synchronize the deque! 
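+        Only the running count and total are all-reduced across processes; the
+        window of recent values (median/avg/max) stays local to each process.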
+ """ + if not using_distributed(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device=device) + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / max(self.count, 1) + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) if len(self.deque) > 0 else "" + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + n = kwargs.pop('n', 1) + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v, n=n) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self, device='cuda'): + for meter in self.meters.values(): + meter.synchronize_between_processes(device=device) + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + log_msg = [ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ] + if torch.cuda.is_available(): + log_msg.append('max mem: {memory:.0f}') + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +class NormalizeInverse(torchvision.transforms.Normalize): + """ + Undoes the normalization and returns the reconstructed images in the input domain. 
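+    Internally this constructs Normalize(mean=-mean/std, std=1/std) (with a small
+    epsilon added to std), which inverts a preceding Normalize(mean, std).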
+    """
+
+    def __init__(self, mean, std):
+        mean = torch.as_tensor(mean)
+        std = torch.as_tensor(std)
+        std_inv = 1 / (std + 1e-7)
+        mean_inv = -mean * std_inv
+        super().__init__(mean=mean_inv, std=std_inv)
+
+    def __call__(self, tensor):
+        return super().__call__(tensor.clone())
+
+
+def resume_from_checkpoint(cfg: ProjectConfig, model, optimizer=None, scheduler=None, model_ema=None):
+
+    # Check if resuming training from a checkpoint
+    if not cfg.checkpoint.resume:
+        print('Starting training from scratch')
+        return TrainState()
+
+    # XH: find checkpoint path automatically
+    if not os.path.isfile(cfg.checkpoint.resume):
+        print(f"The given checkpoint path {cfg.checkpoint.resume} does not exist, trying to find one...")
+        # print(os.getcwd())
+        ckpt_file = os.path.join(cfg.run.code_dir_abs, f'outputs/{cfg.run.name}/single/checkpoint-latest.pth')
+        if not os.path.isfile(ckpt_file):
+            # just take the first dir, for backward compatibility
+            folders = sorted(glob.glob(os.path.join(cfg.run.code_dir_abs, f'outputs/{cfg.run.name}/2023-*')))
+            assert len(folders) <= 1
+            if len(folders) > 0:
+                ckpt_file = os.path.join(folders[0], 'checkpoint-latest.pth')
+
+        if os.path.isfile(ckpt_file):
+            print(f"Found checkpoint at {ckpt_file}!")
+            cfg.checkpoint.resume = ckpt_file
+        else:
+            print(f"No checkpoint found in outputs/{cfg.run.name}/single/!")
+            return TrainState()
+
+    # If resuming, load model state dict
+    print(f'Loading checkpoint ({datetime.datetime.now()})')
+    checkpoint = torch.load(cfg.checkpoint.resume, map_location='cpu')
+    if 'model' in checkpoint:
+        state_dict, key = checkpoint['model'], 'model'
+    else:
+        print("Warning: no model found in checkpoint!")
+        state_dict, key = checkpoint, 'N/A'
+    if any(k.startswith('module.') for k in state_dict.keys()):
+        state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
+        print('Removed "module." from checkpoint state dict')
+    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+    print(f'Loaded model checkpoint key {key} from {cfg.checkpoint.resume}')
+    if len(missing_keys):
+        print(f' - Missing_keys: {missing_keys}')
+    if len(unexpected_keys):
+        print(f' - Unexpected_keys: {unexpected_keys}')
+    # 298 missing, 328 unexpected! total 448 modules.
+    print(f"{len(missing_keys)} missing, {len(unexpected_keys)} unexpected! total {len(model.state_dict().keys())} modules.")
+    # print("First 10 keys:")
+    # for i in range(10):
+    #     print(missing_keys[i], unexpected_keys[i])
+    # exit(0)
+    if 'step' in checkpoint:
+        print("Number of trained steps:", checkpoint['step'])
+
+    # TODO: implement better loading for fine tuning
+    # Resume model ema
+    if cfg.ema.use_ema:
+        if checkpoint['model_ema']:
+            model_ema.load_state_dict(checkpoint['model_ema'])
+            print('Loaded model ema from checkpoint')
+        else:
+            model_ema.load_state_dict(model.parameters())
+            print('No model ema in checkpoint; loaded current parameters into model')
+    else:
+        if 'model_ema' in checkpoint and checkpoint['model_ema']:
+            print('Not using model ema, but model_ema found in checkpoint (you probably want to resume it!)')
+        else:
+            print('Not using model ema, and no model_ema found in checkpoint.')
+
+    # Resume optimizer and/or training state
+    train_state = TrainState()
+    if 'train' in cfg.run.job:
+        if cfg.checkpoint.resume_training:
+            assert (
+                cfg.checkpoint.resume_training_optimizer
+                or cfg.checkpoint.resume_training_scheduler
+                or cfg.checkpoint.resume_training_state
+                or cfg.checkpoint.resume_training
+            ), f'Invalid configs: {cfg.checkpoint}'
+            if cfg.checkpoint.resume_training_optimizer:
+                if 'optimizer' not in checkpoint:
+                    assert 'tune' in cfg.run.name, f'please check the checkpoint for run {cfg.run.name}'
+                    print("Warning: not loading optimizer!")
+                else:
+                    assert 'optimizer' in checkpoint, f'Value not in {checkpoint.keys()}'
+                    optimizer.load_state_dict(checkpoint['optimizer'])
+                    print(f'Loaded optimizer from checkpoint')
+            else:
+                print(f'Did not load optimizer from checkpoint')
+            if cfg.checkpoint.resume_training_scheduler:
+                if 'scheduler' not in checkpoint:
+                    assert 'tune' in cfg.run.name, f'please check the checkpoint for run {cfg.run.name}'
+                    print("Warning: not loading scheduler!")
+                else:
+                    assert 'scheduler' in checkpoint, f'Value not in {checkpoint.keys()}'
+                    scheduler.load_state_dict(checkpoint['scheduler'])
+                    print(f'Loaded scheduler from checkpoint')
+            else:
+                print(f'Did not load scheduler from checkpoint')
+            if cfg.checkpoint.resume_training_state:
+                if 'steps' in checkpoint and 'step' not in checkpoint:  # fixes an old typo
+                    checkpoint['step'] = checkpoint.pop('steps')
+                assert {'epoch', 'step', 'best_val'}.issubset(set(checkpoint.keys()))
+                epoch, step, best_val = checkpoint['epoch'] + 1, checkpoint['step'], checkpoint['best_val']
+                train_state = TrainState(epoch=epoch, step=step, best_val=best_val)
+                print(f'Resumed state from checkpoint: step {step}, epoch {epoch}, best_val {best_val}')
+            else:
+                print(f'Did not load train state from checkpoint')
+        else:
+            print('Did not resume optimizer, scheduler, or epoch from checkpoint')
+
+    print(f'Finished loading checkpoint ({datetime.datetime.now()})')
+
+    return train_state
+
+
+def setup_distributed_print(is_master):
+    """
+    This function disables printing when not in the master process
+    """
+    import builtins as __builtin__
+    from rich import print as __richprint__
+    builtin_print = __richprint__  # __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop('force', False)
+        if is_master or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+
+def using_distributed():
+    return dist.is_available() and dist.is_initialized()
+
+
+def get_rank():
+    return dist.get_rank() if using_distributed() else 0
+
+
+def set_seed(seed):
+    rank = get_rank()
+    seed = seed + rank
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    np.random.seed(seed)
+    # seed Python's own RNG below as well; note that cudnn.benchmark=True trades
+    # exact run-to-run determinism for faster convolution algorithm selection
random.seed(seed) + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + if using_distributed(): + print(f'Seeding node {rank} with seed {seed}', force=True) + else: + print(f'Seeding node {rank} with seed {seed}') + + +def compute_grad_norm(parameters): + # total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2).item() + total_norm = 0 + for p in parameters: + if p.grad is not None and p.requires_grad: + param_norm = p.grad.detach().data.norm(2) + total_norm += param_norm.item() ** 2 + total_norm = total_norm ** 0.5 + return total_norm + + +class dotdict(dict): + """dot.notation access to dictionary attributes""" + __getattr__ = dict.get + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__
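+# Example: cfg = dotdict({'batch_size': 8}); cfg.batch_size -> 8. Note that a
+# missing key returns None (via dict.get) instead of raising AttributeError.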