diff --git a/README.md b/README.md
index 3bad120e9b1a85cbafc28dab9df455f1e14135ec..25e18bc3de3c81966967698aed70d38d45415622 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,16 @@
----
-title: HDM Interaction Recon
-emoji: 🌍
-colorFrom: yellow
-colorTo: green
-sdk: gradio
-sdk_version: 4.20.1
-app_file: app.py
-pinned: false
-license: cc-by-nc-4.0
----
+# HDM
+Official implementation of the Hierarchical Diffusion Model (HDM) from the CVPR'24 paper "Template Free Reconstruction of Human-object Interaction with Procedural Interaction Generation".
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+[Project Page](https://virtualhumans.mpi-inf.mpg.de/procigen-hdm/) | [Code](https://github.com/xiexh20/HDM) | [Dataset](https://edmond.mpg.de/dataset.xhtml?persistentId=doi:10.17617/3.2VUEUS) | [Paper](https://virtualhumans.mpi-inf.mpg.de/procigen-hdm/paper-lowreso.pdf)
+
+
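+## Run demo locally
+A minimal sketch for launching the Gradio demo in `app.py`; the checkpoint naming via `run.stage1_name`/`run.stage2_name` is an assumption, see `configs/structured.py` for all options:
+```
+python app.py run.stage1_name=stage1 run.stage2_name=stage2
+```
+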
+## Citation
+```
+@inproceedings{xie2023template_free,
+ title = {Template Free Reconstruction of Human-object Interaction with Procedural Interaction Generation},
+ author = {Xie, Xianghui and Bhatnagar, Bharat Lal and Lenssen, Jan Eric and Pons-Moll, Gerard},
+ booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+ month = {June},
+ year = {2024},
+}
+```
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..707d9b3ea3e375daa5bb47a44d7bcd2e97f476ef
--- /dev/null
+++ b/app.py
@@ -0,0 +1,177 @@
+"""
+Demo built with gradio
+"""
+import pickle as pkl
+import sys, os
+import os.path as osp
+from typing import Iterable, Optional
+from functools import partial
+
+import trimesh
+from torch.utils.data import DataLoader
+import cv2
+from accelerate import Accelerator
+from tqdm import tqdm
+from glob import glob
+
+sys.path.append(os.getcwd())
+import hydra
+import torch
+import numpy as np
+import imageio
+import gradio as gr
+import plotly.graph_objs as go
+import training_utils
+
+from configs.structured import ProjectConfig
+from demo import DemoRunner
+from dataset.demo_dataset import DemoDataset
+
+
+md_description = """
+# HDM Interaction Reconstruction Demo
+### Official implementation of the paper \"Template Free Reconstruction of Human-object Interaction with Procedural Interaction Generation\", CVPR'24.
+[Project Page](https://virtualhumans.mpi-inf.mpg.de/procigen-hdm/) | [Code](https://github.com/xiexh20/HDM) | [Dataset](https://edmond.mpg.de/dataset.xhtml?persistentId=doi:10.17617/3.2VUEUS) | [Paper](https://virtualhumans.mpi-inf.mpg.de/procigen-hdm/paper-lowreso.pdf)
+
+
+Upload your own human-object interaction image and get a full 3D reconstruction!
+
+## Citation
+```
+@inproceedings{xie2023template_free,
+ title = {Template Free Reconstruction of Human-object Interaction with Procedural Interaction Generation},
+ author = {Xie, Xianghui and Bhatnagar, Bharat Lal and Lenssen, Jan Eric and Pons-Moll, Gerard},
+ booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+ month = {June},
+ year = {2024},
+}
+```
+"""
+
+def plot_points(colors, coords):
+ """
+ use plotly to visualize 3D point with colors
+ """
+ trace = go.Scatter3d(x=coords[:, 0], y=coords[:, 1], z=coords[:, 2], mode='markers',
+ marker=dict(
+ size=2,
+ color=colors
+ ))
+ layout = go.Layout(
+ scene=dict(
+ xaxis=dict(
+ title="",
+ showgrid=False,
+ zeroline=False,
+ showline=False,
+ ticks='',
+ showticklabels=False
+ ),
+ yaxis=dict(
+ title="",
+ showgrid=False,
+ zeroline=False,
+ showline=False,
+ ticks='',
+ showticklabels=False
+ ),
+ zaxis=dict(
+ title="",
+ showgrid=False,
+ zeroline=False,
+ showline=False,
+ ticks='',
+ showticklabels=False
+ ),
+ ),
+ margin=dict(l=0, r=0, b=0, t=0),
+ showlegend=False
+ )
+ fig = go.Figure(data=[trace], layout=layout)
+ return fig
+
+
+def inference(runner: DemoRunner, cfg: ProjectConfig, rgb, mask_hum, mask_obj, std_coverage, input_seed):
+ """
+ given user input, run inference
+ :param runner:
+ :param cfg:
+ :param rgb: (h, w, 3), np array
+ :param mask_hum: (h, w, 3), np array
+ :param mask_obj: (h, w, 3), np array
+ :param std_coverage: float value, used to estimate camera translation
+ :param input_seed: random seed
+    :return: an interactive 3D figure of the reconstructed point cloud, and the path to the saved point cloud file
+ """
+ # Set random seed
+ training_utils.set_seed(int(input_seed))
+
+ data = DemoDataset([], (cfg.dataset.image_size, cfg.dataset.image_size),
+ std_coverage)
+ batch = data.image2batch(rgb, mask_hum, mask_obj)
+
+ out_stage1, out_stage2 = runner.forward_batch(batch, cfg)
+ points = out_stage2.points_packed().cpu().numpy()
+ colors = out_stage2.features_packed().cpu().numpy()
+ fig = plot_points(colors, points)
+ # save tmp point cloud
+ outdir = './results'
+ os.makedirs(outdir, exist_ok=True)
+    file_stage2 = osp.join(outdir, f"pred_std{std_coverage}_seed{input_seed}_stage2.ply")
+    file_stage1 = osp.join(outdir, f"pred_std{std_coverage}_seed{input_seed}_stage1.ply")
+    trimesh.PointCloud(points, colors).export(file_stage2)
+    trimesh.PointCloud(out_stage1.points_packed().cpu().numpy(),
+                       out_stage1.features_packed().cpu().numpy()).export(file_stage1)
+    return fig, file_stage2
+
+
+@hydra.main(config_path='configs', config_name='configs', version_base='1.1')
+def main(cfg: ProjectConfig):
+ # Setup model
+ runner = DemoRunner(cfg)
+
+ # Setup interface
+ demo = gr.Blocks(title="HDM Interaction Reconstruction Demo")
+ with demo:
+ gr.Markdown(md_description)
+        gr.HTML("""
+        HDM Demo
+        """)
+        gr.HTML("""Instructions: upload an RGB image, a human mask and an object mask, then click Start Reconstruction.""")
+
+ # Input data
+ with gr.Row():
+ input_rgb = gr.Image(label='Input RGB', type='numpy')
+ input_mask_hum = gr.Image(label='Human mask', type='numpy')
+ with gr.Row():
+ input_mask_obj = gr.Image(label='Object mask', type='numpy')
+ with gr.Column():
+ # TODO: add hint for this value here
+ input_std = gr.Number(label='Gaussian std coverage', value=3.5)
+ input_seed = gr.Number(label='Random seed', value=42)
+ # Output visualization
+ with gr.Row():
+ pc_plot = gr.Plot(label="Reconstructed point cloud")
+ out_pc_download = gr.File(label="3D reconstruction for download") # this allows downloading
+
+        gr.HTML("""
+        """)
+ # Control
+ with gr.Row():
+ button_recon = gr.Button("Start Reconstruction", interactive=True, variant='secondary')
+ button_recon.click(fn=partial(inference, runner, cfg),
+ inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed],
+ outputs=[pc_plot, out_pc_download])
+        gr.HTML("""
+        """)
+ # Example input
+ example_dir = cfg.run.code_dir_abs+"/examples"
+ rgb, ps, obj = 'k1.color.jpg', 'k1.person_mask.png', 'k1.obj_rend_mask.png'
+ example_images = gr.Examples([
+ [f"{example_dir}/017450/{rgb}", f"{example_dir}/017450/{ps}", f"{example_dir}/017450/{obj}", 3.0, 42],
+ [f"{example_dir}/002446/{rgb}", f"{example_dir}/002446/{ps}", f"{example_dir}/002446/{obj}", 3.0, 42],
+ [f"{example_dir}/053431/{rgb}", f"{example_dir}/053431/{ps}", f"{example_dir}/053431/{obj}", 3.8, 42],
+ [f"{example_dir}/158107/{rgb}", f"{example_dir}/158107/{ps}", f"{example_dir}/158107/{obj}", 3.8, 42],
+
+ ], inputs=[input_rgb, input_mask_hum, input_mask_obj, input_std, input_seed],)
+
+ # demo.launch(share=True)
+ # Enabling queue for runtime>60s, see: https://github.com/tloen/alpaca-lora/issues/60#issuecomment-1510006062
+ demo.queue(concurrency_count=3).launch(share=True)
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/configs/__init__.py b/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/configs/structured.py b/configs/structured.py
new file mode 100644
index 0000000000000000000000000000000000000000..4725105375b147cc2445ae932d8555da065ad8f9
--- /dev/null
+++ b/configs/structured.py
@@ -0,0 +1,416 @@
+import os
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Iterable
+import os.path as osp
+
+from hydra.core.config_store import ConfigStore
+from hydra.conf import RunDir
+
+
+@dataclass
+class CustomHydraRunDir(RunDir):
+ dir: str = './outputs/${run.name}/single'
+
+
+@dataclass
+class RunConfig:
+ name: str = 'debug'
+ job: str = 'train'
+ mixed_precision: str = 'fp16' # 'no'
+ cpu: bool = False
+ seed: int = 42
+ val_before_training: bool = True
+ vis_before_training: bool = True
+ limit_train_batches: Optional[int] = None
+ limit_val_batches: Optional[int] = None
+ max_steps: int = 100_000
+ checkpoint_freq: int = 1_000
+ val_freq: int = 5_000
+ vis_freq: int = 5_000
+ # vis_freq: int = 10_000
+ log_step_freq: int = 20
+ print_step_freq: int = 100
+
+ # config to run demo
+    stage1_name: str = 'stage1' # experiment name of the stage 1 model
+    stage2_name: str = 'stage2' # experiment name of the stage 2 model
+    image_path: str = '' # path to the demo images, can be a single file or a glob pattern
+
+ # abs path to working dir
+ code_dir_abs: str = osp.dirname(osp.dirname(osp.abspath(__file__)))
+
+ # Inference configs
+ num_inference_steps: int = 1000
+ diffusion_scheduler: Optional[str] = 'ddpm'
+ num_samples: int = 1
+ # num_sample_batches: Optional[int] = None
+ num_sample_batches: Optional[int] = 2000 # XH: change to 2
+ sample_from_ema: bool = False
+ sample_save_evolutions: bool = False # temporarily set by default
+ save_name: str = 'sample' # XH: additional save name
+ redo: bool = False
+
+ # for parallel sampling in slurm
+ batch_start: int = 0
+ batch_end: Optional[int] = None
+
+ # Training configs
+ freeze_feature_model: bool = True
+
+ # Coloring training configs
+ coloring_training_noise_std: float = 0.0
+ coloring_sample_dir: Optional[str] = None
+
+ sample_mode: str = 'sample' # whether from noise or from some intermediate steps
+ sample_noise_step: int = 500 # add noise to GT up to some steps, and then denoise
+ sample_save_gt: bool = True
+
+
+@dataclass
+class LoggingConfig:
+ wandb: bool = True
+ wandb_project: str = 'pc2'
+
+
+
+@dataclass
+class PointCloudProjectionModelConfig:
+ # Feature extraction arguments
+ image_size: int = '${dataset.image_size}'
+ image_feature_model: str = 'vit_base_patch16_224_mae' # or 'vit_small_patch16_224_msn' or 'identity'
+ use_local_colors: bool = True
+ use_local_features: bool = True
+ use_global_features: bool = False
+ use_mask: bool = True
+ use_distance_transform: bool = True
+
+ # Point cloud data arguments. Note these are here because the processing happens
+ # inside the model, rather than inside the dataset.
+ scale_factor: float = "${dataset.scale_factor}"
+ colors_mean: float = 0.5
+ colors_std: float = 0.5
+ color_channels: int = 3
+ predict_shape: bool = True
+ predict_color: bool = False
+
+ # added by XH
+ load_sample_init: bool = False # load init samples from file
+ sample_init_scale: float = 1.0 # scale the initial pc samples
+ test_init_with_gtpc: bool = False # test time init samples with GT samples
+ consistent_center: bool = True # use consistent center prediction by CCD-3DR
+ voxel_resolution_multiplier: float = 1 # increase network voxel resolution
+
+ # predict binary segmentation
+ predict_binary: bool = False # True for stage 1 model, False for others
+ lw_binary: float = 3.0 # to have roughly the same magnitude of the binary segmentation loss
+ # for separate model
+ binary_training_noise_std: float = 0.1 # from github doc for predicting color
+ self_conditioning: bool = False
+
+@dataclass
+class PVCNNAEModelConfig(PointCloudProjectionModelConfig):
+ "my own model config, must inherit parent class"
+ model_name: str = 'pvcnn-ae'
+ latent_dim: int = 1024
+ num_dec_blocks: int = 6
+ block_dims: List[int] = field(default_factory=lambda: [512, 256])
+ num_points: int = 1500
+ bottleneck_dim: int = -1 # the input dim to the last MLP layer
+
+@dataclass
+class PointCloudDiffusionModelConfig(PointCloudProjectionModelConfig):
+ model_name: str = 'pc2-diff-ho' # default as behave
+
+ # Diffusion arguments
+ beta_start: float = 1e-5 # 0.00085
+ beta_end: float = 8e-3 # 0.012
+ beta_schedule: str = 'linear' # 'custom'
+ dm_pred_type: str = 'epsilon' # diffusion model prediction type, sample (x0) or noise
+
+ # Point cloud model arguments
+ point_cloud_model: str = 'pvcnn'
+ point_cloud_model_embed_dim: int = 64
+
+ dataset_type: str = '${dataset.type}'
+
+@dataclass
+class CrossAttnHOModelConfig(PointCloudDiffusionModelConfig):
+ model_name: str = 'diff-ho-attn'
+
+ attn_type: str = 'coord3d+posenc-learnable'
+ attn_weight: float = 1.0
+ point_visible_test: str = 'combine' # To compute point visibility: use all points or only human/object points
+
+
+@dataclass
+class DirectTransModelConfig(PointCloudProjectionModelConfig):
+ model_name: str = 'direct-transl-ho'
+
+ pooling: str = "avg"
+ act: str = 'gelu'
+ out_act: str = 'relu'
+ # feat_dims_transl: Iterable[Any] = (384, 256, 128, 6) # cannot use List[int] https://github.com/facebookresearch/hydra/issues/1752#issuecomment-893174197
+ # feat_dims_scale: Iterable[Any] = (384, 128, 64, 2)
+ feat_dims_transl: List[int] = field(default_factory=lambda: [384, 256, 128, 6])
+ feat_dims_scale: List[int] = field(default_factory=lambda: [384, 128, 64, 2])
+ lw_transl: float = 10000.0
+ lw_scale: float = 10000.0
+
+
+@dataclass
+class PointCloudColoringModelConfig(PointCloudProjectionModelConfig):
+ # Projection arguments
+ predict_shape: bool = False
+ predict_color: bool = True
+
+ # Point cloud model arguments
+ point_cloud_model: str = 'pvcnn'
+ point_cloud_model_layers: int = 1
+ point_cloud_model_embed_dim: int = 64
+
+
+@dataclass
+class DatasetConfig:
+ type: str
+
+
+@dataclass
+class PointCloudDatasetConfig(DatasetConfig):
+ eval_split: str = 'val'
+ max_points: int = 16_384
+ image_size: int = 224
+ scale_factor: float = 1.0
+ restrict_model_ids: Optional[List] = None # for only running on a subset of data points
+
+
+@dataclass
+class CO3DConfig(PointCloudDatasetConfig):
+ type: str = 'co3dv2'
+ # root: str = os.getenv('CO3DV2_DATASET_ROOT')
+ root: str = "/BS/xxie-2/work/co3d/hydrant"
+ category: str = 'hydrant'
+ subset_name: str = 'fewview_dev'
+ mask_images: bool = '${model.use_mask}'
+
+
+@dataclass
+class ShapeNetR2N2Config(PointCloudDatasetConfig):
+ # added by XH
+ fix_sample: bool = True
+ category: str = 'chair'
+
+ type: str = 'shapenet_r2n2'
+ root: str = "/BS/chiban2/work/data_shapenet/ShapeNetCore.v1"
+ r2n2_dir: str = "/BS/databases20/3d-r2n2"
+ shapenet_dir: str = "/BS/chiban2/work/data_shapenet/ShapeNetCore.v1"
+ preprocessed_r2n2_dir: str = "${dataset.root}/r2n2_preprocessed_renders"
+ splits_file: str = "${dataset.root}/r2n2_standard_splits_from_ShapeNet_taxonomy.json"
+ # splits_file: str = "${dataset.root}/pix2mesh_splits_val05.json" # <-- incorrect
+ scale_factor: float = 7.0
+ point_cloud_filename: str = 'pointcloud_r2n2.npz' # should use 'pointcloud_mesh.npz'
+
+
+
+@dataclass
+class BehaveDatasetConfig(PointCloudDatasetConfig):
+ # added by XH
+ type: str = 'behave'
+
+ fix_sample: bool = True
+ behave_dir: str = "/BS/xxie-5/static00/behave_release/sequences/"
+    split_file: str = "" # specify your dataset split file here
+ scale_factor: float = 7.0 # use the same as shapenet
+ sample_ratio_hum: float = 0.5
+ image_size: int = 224
+
+ normalize_type: str = 'comb'
+ smpl_type: str = 'gt' # use which SMPL mesh to obtain normalization parameters
+ test_transl_type: str = 'norm'
+
+ load_corr_points: bool = False # load autoencoder points for object and SMPL
+ uniform_obj_sample: bool = False
+
+ # configs for direct translation prediction
+ bkg_type: str = 'none'
+ bbox_params: str = 'none'
+ ho_segm_pred_path: Optional[str] = None
+ use_gt_transl: bool = False
+
+ cam_noise_std: float = 0. # add noise to the camera pose
+ sep_same_crop: bool = False # use same input image crop to separate models
+ aug_blur: float = 0. # blur augmentation
+
+ std_coverage: float=3.5 # a heuristic value to estimate translation
+
+ v2v_path: str = '' # object v2v corr path
+
+@dataclass
+class ShapeDatasetConfig(BehaveDatasetConfig):
+ "the dataset to train AE for aligned shapes"
+ type: str = 'shape'
+ fix_sample: bool = False
+ split_file: str = "/BS/xxie-2/work/pc2-diff/experiments/splits/shapes-chair.pkl"
+
+
+# TODO
+@dataclass
+class ShapeNetNMRConfig(PointCloudDatasetConfig):
+ type: str = 'shapenet_nmr'
+ shapenet_nmr_dir: str = "/work/lukemk/machine-learning-datasets/3d-reconstruction/ShapeNet_NMR/NMR_Dataset"
+ synset_names: str = 'chair' # comma-separated or 'all'
+ augmentation: str = 'all'
+ scale_factor: float = 7.0
+
+
+@dataclass
+class AugmentationConfig:
+ # need to specify the variable type in order to define it properly
+ max_radius: int = 0 # generate a random square to mask object, this is the radius for the square in pixel size, zero means no occlusion
+
+
+@dataclass
+class DataloaderConfig:
+ # batch_size: int = 8 # 2 for debug
+ batch_size: int = 16
+ num_workers: int = 14 # 0 for debug # suggested by accelerator for gpu20
+
+
+@dataclass
+class LossConfig:
+ diffusion_weight: float = 1.0
+ rgb_weight: float = 1.0
+ consistency_weight: float = 1.0
+
+
+@dataclass
+class CheckpointConfig:
+ resume: Optional[str] = "test"
+ resume_training: bool = True
+ resume_training_optimizer: bool = True
+ resume_training_scheduler: bool = True
+ resume_training_state: bool = True
+
+
+@dataclass
+class ExponentialMovingAverageConfig:
+ use_ema: bool = False
+ # # From Diffusers EMA (should probably switch)
+ # ema_inv_gamma: float = 1.0
+ # ema_power: float = 0.75
+ # ema_max_decay: float = 0.9999
+ decay: float = 0.999
+ update_every: int = 20
+
+
+@dataclass
+class OptimizerConfig:
+ type: str
+ name: str
+ lr: float = 3e-4
+ weight_decay: float = 0.0
+ scale_learning_rate_with_batch_size: bool = False
+ gradient_accumulation_steps: int = 1
+ clip_grad_norm: Optional[float] = 50.0 # 5.0
+ kwargs: Dict = field(default_factory=lambda: dict())
+
+
+@dataclass
+class AdadeltaOptimizerConfig(OptimizerConfig):
+ type: str = 'torch'
+ name: str = 'Adadelta'
+ kwargs: Dict = field(default_factory=lambda: dict(
+ weight_decay=1e-6,
+ ))
+
+
+@dataclass
+class AdamOptimizerConfig(OptimizerConfig):
+ type: str = 'torch'
+ name: str = 'AdamW'
+ weight_decay: float = 1e-6
+ kwargs: Dict = field(default_factory=lambda: dict(betas=(0.95, 0.999)))
+
+
+@dataclass
+class SchedulerConfig:
+ type: str
+ kwargs: Dict = field(default_factory=lambda: dict())
+
+
+@dataclass
+class LinearSchedulerConfig(SchedulerConfig):
+ type: str = 'transformers'
+ kwargs: Dict = field(default_factory=lambda: dict(
+ name='linear',
+ num_warmup_steps=0,
+ num_training_steps="${run.max_steps}",
+ ))
+
+
+@dataclass
+class CosineSchedulerConfig(SchedulerConfig):
+ type: str = 'transformers'
+ kwargs: Dict = field(default_factory=lambda: dict(
+ name='cosine',
+ num_warmup_steps=2000, # 0
+ num_training_steps="${run.max_steps}",
+ ))
+
+
+@dataclass
+class ProjectConfig:
+ run: RunConfig
+ logging: LoggingConfig
+ dataset: PointCloudDatasetConfig
+ augmentations: AugmentationConfig
+ dataloader: DataloaderConfig
+ loss: LossConfig
+ model: PointCloudProjectionModelConfig
+ ema: ExponentialMovingAverageConfig
+ checkpoint: CheckpointConfig
+ optimizer: OptimizerConfig
+ scheduler: SchedulerConfig
+
+ defaults: List[Any] = field(default_factory=lambda: [
+ 'custom_hydra_run_dir',
+ {'run': 'default'},
+ {'logging': 'default'},
+ {'model': 'ho-attn'},
+ # {'dataset': 'co3d'},
+ {'dataset': 'behave'},
+ {'augmentations': 'default'},
+ {'dataloader': 'default'},
+ {'ema': 'default'},
+ {'loss': 'default'},
+ {'checkpoint': 'default'},
+ {'optimizer': 'adam'}, # default adamw
+ {'scheduler': 'linear'},
+ # {'scheduler': 'cosine'},
+ ])
+
+
+cs = ConfigStore.instance()
+cs.store(name='custom_hydra_run_dir', node=CustomHydraRunDir, package="hydra.run")
+cs.store(group='run', name='default', node=RunConfig)
+cs.store(group='logging', name='default', node=LoggingConfig)
+cs.store(group='model', name='diffrec', node=PointCloudDiffusionModelConfig)
+cs.store(group='model', name='coloring_model', node=PointCloudColoringModelConfig)
+cs.store(group='model', name='direct-transl', node=DirectTransModelConfig)
+cs.store(group='model', name='ho-attn', node=CrossAttnHOModelConfig)
+cs.store(group='model', name='pvcnn-ae', node=PVCNNAEModelConfig)
+cs.store(group='dataset', name='co3d', node=CO3DConfig)
+# TODO
+cs.store(group='dataset', name='shapenet_r2n2', node=ShapeNetR2N2Config)
+cs.store(group='dataset', name='behave', node=BehaveDatasetConfig)
+cs.store(group='dataset', name='shape', node=ShapeDatasetConfig)
+# cs.store(group='dataset', name='shapenet_nmr', node=ShapeNetNMRConfig)
+cs.store(group='augmentations', name='default', node=AugmentationConfig)
+cs.store(group='dataloader', name='default', node=DataloaderConfig)
+cs.store(group='loss', name='default', node=LossConfig)
+cs.store(group='ema', name='default', node=ExponentialMovingAverageConfig)
+cs.store(group='checkpoint', name='default', node=CheckpointConfig)
+cs.store(group='optimizer', name='adadelta', node=AdadeltaOptimizerConfig)
+cs.store(group='optimizer', name='adam', node=AdamOptimizerConfig)
+cs.store(group='scheduler', name='linear', node=LinearSchedulerConfig)
+cs.store(group='scheduler', name='cosine', node=CosineSchedulerConfig)
+cs.store(name='configs', node=ProjectConfig)
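+
+# Example Hydra usage (a sketch; 'main.py' is a hypothetical entry point decorated with @hydra.main):
+# the config groups registered above can be selected or overridden from the command line, e.g.
+#   python main.py model=ho-attn dataset=behave optimizer=adam run.max_steps=100000
+# Interpolated values such as ${dataset.image_size} are resolved after Hydra composes the config.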
diff --git a/dataset/__init__.py b/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a63cea161fe8d6feb7a5c09e458523b15935817a
--- /dev/null
+++ b/dataset/__init__.py
@@ -0,0 +1,301 @@
+from pathlib import Path
+from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
+
+import numpy as np
+import pytorch3d
+import torch
+from torch.utils.data import SequentialSampler
+from omegaconf import DictConfig
+from pytorch3d.implicitron.dataset.data_loader_map_provider import \
+ SequenceDataLoaderMapProvider
+from pytorch3d.implicitron.dataset.dataset_base import FrameData
+from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
+from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
+ JsonIndexDatasetMapProviderV2, registry)
+from pytorch3d.implicitron.tools.config import expand_args_fields
+from pytorch3d.renderer.cameras import CamerasBase
+from torch.utils.data import DataLoader
+
+from configs.structured import CO3DConfig, DataloaderConfig, ProjectConfig, Optional
+from .exclude_sequence import EXCLUDE_SEQUENCE, LOW_QUALITY_SEQUENCE
+from .utils import DatasetMap
+from .r2n2_my import R2N2Sample, collate_batched_meshes
+
+
+def get_dataset(cfg: ProjectConfig):
+
+ if cfg.dataset.type == 'co3dv2':
+ dataset_cfg: CO3DConfig = cfg.dataset
+ dataloader_cfg: DataloaderConfig = cfg.dataloader
+
+ # Exclude bad and low-quality sequences, XH: why this is needed?
+ exclude_sequence = []
+ exclude_sequence.extend(EXCLUDE_SEQUENCE.get(dataset_cfg.category, []))
+ exclude_sequence.extend(LOW_QUALITY_SEQUENCE.get(dataset_cfg.category, []))
+
+ # Whether to load pointclouds
+ kwargs = dict(
+ remove_empty_masks=True,
+ n_frames_per_sequence=1,
+ load_point_clouds=True,
+ max_points=dataset_cfg.max_points,
+ image_height=dataset_cfg.image_size,
+ image_width=dataset_cfg.image_size,
+ mask_images=dataset_cfg.mask_images,
+ exclude_sequence=exclude_sequence,
+ pick_sequence=() if dataset_cfg.restrict_model_ids is None else dataset_cfg.restrict_model_ids,
+ )
+
+ # Get dataset mapper
+ dataset_map_provider_type = registry.get(JsonIndexDatasetMapProviderV2, "JsonIndexDatasetMapProviderV2")
+ expand_args_fields(dataset_map_provider_type)
+ dataset_map_provider = dataset_map_provider_type(
+ category=dataset_cfg.category,
+ subset_name=dataset_cfg.subset_name,
+ dataset_root=dataset_cfg.root,
+ test_on_train=False,
+ only_test_set=False,
+ load_eval_batches=True,
+ dataset_JsonIndexDataset_args=DictConfig(kwargs),
+ )
+
+ # Get datasets
+ datasets = dataset_map_provider.get_dataset_map() # how to select specific frames??
+
+ # PATCH BUG WITH POINT CLOUD LOCATIONS!
+ for dataset in (datasets["train"], datasets["val"]):
+ # print(dataset.seq_annots.items())
+ for key, ann in dataset.seq_annots.items():
+ correct_point_cloud_path = Path(dataset.dataset_root) / Path(*Path(ann.point_cloud.path).parts[-3:])
+ assert correct_point_cloud_path.is_file(), correct_point_cloud_path
+ ann.point_cloud.path = str(correct_point_cloud_path)
+
+ # Get dataloader mapper
+ data_loader_map_provider_type = registry.get(SequenceDataLoaderMapProvider, "SequenceDataLoaderMapProvider")
+ expand_args_fields(data_loader_map_provider_type)
+ data_loader_map_provider = data_loader_map_provider_type(
+ batch_size=dataloader_cfg.batch_size,
+ num_workers=dataloader_cfg.num_workers,
+ )
+
+ # QUICK HACK: Patch the train dataset because it is not used but it throws an error
+ if (len(datasets['train']) == 0 and len(datasets[dataset_cfg.eval_split]) > 0 and
+ dataset_cfg.restrict_model_ids is not None and cfg.run.job == 'sample'):
+ datasets = DatasetMap(train=datasets[dataset_cfg.eval_split], val=datasets[dataset_cfg.eval_split],
+ test=datasets[dataset_cfg.eval_split])
+ # XH: why all eval split?
+ print('Note: You used restrict_model_ids and there were no ids in the train set.')
+
+ # Get dataloaders
+ dataloaders = data_loader_map_provider.get_data_loader_map(datasets)
+ dataloader_train = dataloaders['train']
+ dataloader_val = dataloader_vis = dataloaders[dataset_cfg.eval_split]
+
+ # Replace validation dataloader sampler with SequentialSampler
+ # seems to be randomly sampled? with a fixed random seed? but one cannot control which image is being sampled??
+ dataloader_val.batch_sampler.sampler = SequentialSampler(dataloader_val.batch_sampler.sampler.data_source)
+
+ # Modify for accelerate
+ dataloader_train.batch_sampler.drop_last = True
+ dataloader_val.batch_sampler.drop_last = False
+ elif cfg.dataset.type == 'shapenet_r2n2':
+ # from ..configs.structured import ShapeNetR2N2Config
+ dataset_cfg: ShapeNetR2N2Config = cfg.dataset
+ # for k in dataset_cfg:
+ # print(k)
+ datasets = [R2N2Sample(dataset_cfg.max_points, dataset_cfg.fix_sample,
+ dataset_cfg.image_size, cfg.augmentations,
+ s, dataset_cfg.shapenet_dir,
+ dataset_cfg.r2n2_dir, dataset_cfg.splits_file,
+ load_textures=False, return_all_views=True) for s in ['train', 'val', 'test']]
+ dataloader_train = DataLoader(datasets[0], batch_size=cfg.dataloader.batch_size,
+ collate_fn=collate_batched_meshes,
+ num_workers=cfg.dataloader.num_workers, shuffle=True)
+ dataloader_val = DataLoader(datasets[1], batch_size=cfg.dataloader.batch_size,
+ collate_fn=collate_batched_meshes,
+ num_workers=cfg.dataloader.num_workers, shuffle=False)
+ dataloader_vis = DataLoader(datasets[2], batch_size=cfg.dataloader.batch_size,
+ collate_fn=collate_batched_meshes,
+ num_workers=cfg.dataloader.num_workers, shuffle=False)
+
+ elif cfg.dataset.type in ['behave', 'behave-objonly', 'behave-humonly', 'behave-dtransl',
+ 'behave-objonly-segm', 'behave-humonly-segm', 'behave-attn',
+ 'behave-test', 'behave-attn-test', 'behave-hum-pe', 'behave-hum-noscale',
+ 'behave-hum-surf', 'behave-objv2v']:
+ from .behave_dataset import BehaveDataset, NTUDataset, BehaveObjOnly, BehaveHumanOnly, BehaveHumanOnlyPosEnc
+ from .behave_dataset import BehaveHumanOnlySegmInput, BehaveObjOnlySegmInput, BehaveTestOnly, BehaveHumNoscale
+ from .behave_dataset import BehaveHumanOnlySurfSample
+ from .dtransl_dataset import DirectTranslDataset
+ from .behave_paths import DataPaths
+ from configs.structured import BehaveDatasetConfig
+ from .behave_crossattn import BehaveCrossAttnDataset, BehaveCrossAttnTest
+ from .behave_dataset import BehaveObjOnlyV2V
+
+ dataset_cfg: BehaveDatasetConfig = cfg.dataset
+ # print(dataset_cfg.behave_dir)
+ train_paths, val_paths = DataPaths.load_splits(dataset_cfg.split_file, dataset_cfg.behave_dir)
+ # exit(0)
+
+ # split validation paths to only consider the selected batches
+ bs = cfg.dataloader.batch_size
+ num_batches_total = int(np.ceil(len(val_paths)/cfg.dataloader.batch_size))
+ end_idx = cfg.run.batch_end if cfg.run.batch_end is not None else num_batches_total
+ # print(cfg.run.batch_end, cfg.run.batch_start, end_idx)
+ val_paths = val_paths[cfg.run.batch_start*bs:end_idx*bs]
+
+ if cfg.dataset.type == 'behave':
+ train_type = BehaveDataset
+ val_datatype = BehaveDataset if 'ntu' not in dataset_cfg.split_file else NTUDataset
+ elif cfg.dataset.type == 'behave-test':
+ train_type = BehaveDataset
+ val_datatype = BehaveTestOnly
+ elif cfg.dataset.type == 'behave-objonly':
+ train_type = BehaveObjOnly
+ val_datatype = BehaveObjOnly
+ assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
+ elif cfg.dataset.type == 'behave-humonly':
+ train_type = BehaveHumanOnly
+ val_datatype = BehaveHumanOnly
+ assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
+ elif cfg.dataset.type == 'behave-hum-noscale':
+ train_type = BehaveHumNoscale
+ val_datatype = BehaveHumNoscale
+ elif cfg.dataset.type == 'behave-hum-pe':
+ train_type = BehaveHumanOnlyPosEnc
+ val_datatype = BehaveHumanOnlyPosEnc
+ elif cfg.dataset.type == 'behave-hum-surf':
+ train_type = BehaveHumanOnlySurfSample
+ val_datatype = BehaveHumanOnlySurfSample
+ elif cfg.dataset.type == 'behave-humonly-segm':
+ assert cfg.dataset.ho_segm_pred_path is not None, 'please specify predicted HO segmentation!'
+ train_type = BehaveHumanOnly
+ val_datatype = BehaveHumanOnlySegmInput
+ assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
+ elif cfg.dataset.type == 'behave-objonly-segm':
+ assert cfg.dataset.ho_segm_pred_path is not None, 'please specify predicted HO segmentation!'
+ train_type = BehaveObjOnly
+ val_datatype = BehaveObjOnlySegmInput
+ assert 'ntu' not in dataset_cfg.split_file, 'ntu not implemented!'
+ elif cfg.dataset.type == 'behave-dtransl':
+ train_type = DirectTranslDataset
+ val_datatype = DirectTranslDataset
+ elif cfg.dataset.type == 'behave-attn':
+ train_type = BehaveCrossAttnDataset
+ val_datatype = BehaveCrossAttnDataset
+ elif cfg.dataset.type == 'behave-attn-test':
+ train_type = BehaveCrossAttnDataset
+ val_datatype = BehaveCrossAttnTest
+ elif cfg.dataset.type == 'behave-objv2v':
+ train_type = BehaveObjOnlyV2V
+ val_datatype = BehaveObjOnlyV2V
+ else:
+ raise NotImplementedError
+
+ dataset_train = train_type(train_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
+ (dataset_cfg.image_size, dataset_cfg.image_size),
+ split='train', sample_ratio_hum=dataset_cfg.sample_ratio_hum,
+ normalize_type=dataset_cfg.normalize_type, smpl_type='gt',
+ load_corr_points=dataset_cfg.load_corr_points,
+ uniform_obj_sample=dataset_cfg.uniform_obj_sample,
+ bkg_type=dataset_cfg.bkg_type,
+ bbox_params=dataset_cfg.bbox_params,
+ pred_binary=cfg.model.predict_binary,
+ ho_segm_pred_path=cfg.dataset.ho_segm_pred_path,
+ compute_closest_points=cfg.model.model_name=='pc2-diff-ho-tune-newloss',
+ use_gt_transl=cfg.dataset.use_gt_transl,
+ cam_noise_std=cfg.dataset.cam_noise_std,
+ sep_same_crop=cfg.dataset.sep_same_crop,
+ aug_blur=cfg.dataset.aug_blur,
+ std_coverage=cfg.dataset.std_coverage,
+ v2v_path=cfg.dataset.v2v_path)
+
+ dataset_val = val_datatype(val_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
+ (dataset_cfg.image_size, dataset_cfg.image_size),
+ split='val', sample_ratio_hum=dataset_cfg.sample_ratio_hum,
+ normalize_type=dataset_cfg.normalize_type, smpl_type=dataset_cfg.smpl_type,
+ load_corr_points=dataset_cfg.load_corr_points,
+ test_transl_type=dataset_cfg.test_transl_type,
+ uniform_obj_sample=dataset_cfg.uniform_obj_sample,
+ bkg_type=dataset_cfg.bkg_type,
+ bbox_params=dataset_cfg.bbox_params,
+ pred_binary=cfg.model.predict_binary,
+ ho_segm_pred_path=cfg.dataset.ho_segm_pred_path,
+ compute_closest_points=cfg.model.model_name=='pc2-diff-ho-tune-newloss',
+ use_gt_transl=cfg.dataset.use_gt_transl,
+ sep_same_crop=cfg.dataset.sep_same_crop,
+ std_coverage=cfg.dataset.std_coverage,
+ v2v_path=cfg.dataset.v2v_path)
+ # dataset_test = val_datatype(val_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
+ # (dataset_cfg.image_size, dataset_cfg.image_size),
+ # split='test', sample_ratio_hum=dataset_cfg.sample_ratio_hum,
+ # normalize_type=dataset_cfg.normalize_type, smpl_type=dataset_cfg.smpl_type,
+ # load_corr_points=dataset_cfg.load_corr_points,
+ # test_transl_type=dataset_cfg.test_transl_type,
+ # uniform_obj_sample=dataset_cfg.uniform_obj_sample,
+ # bkg_type=dataset_cfg.bkg_type,
+ # bbox_params=dataset_cfg.bbox_params,
+ # pred_binary=cfg.model.predict_binary,
+ # ho_segm_pred_path=cfg.dataset.ho_segm_pred_path,
+ # compute_closest_points=cfg.model.model_name=='pc2-diff-ho-tune-newloss',
+ # use_gt_transl=cfg.dataset.use_gt_transl,
+ # sep_same_crop=cfg.dataset.sep_same_crop)
+ dataloader_train = DataLoader(dataset_train, batch_size=cfg.dataloader.batch_size,
+ collate_fn=collate_batched_meshes,
+ num_workers=cfg.dataloader.num_workers, shuffle=True)
+ shuffle = cfg.run.job == 'train'
+ dataloader_val = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
+ collate_fn=collate_batched_meshes,
+ num_workers=cfg.dataloader.num_workers, shuffle=shuffle)
+ dataloader_vis = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
+ collate_fn=collate_batched_meshes,
+ num_workers=cfg.dataloader.num_workers, shuffle=shuffle)
+
+ # datasets = [BehaveDataset(p, dataset_cfg.max_points, dataset_cfg.fix_sample,
+ # (dataset_cfg.image_size, dataset_cfg.image_size),
+ # split=s, sample_ratio_hum=dataset_cfg.sample_ratio_hum,
+ # normalize_type=dataset_cfg.normalize_type) for p, s in zip([train_paths, val_paths, val_paths],
+ # ['train', 'val', 'test'])]
+ # dataloader_train = DataLoader(datasets[0], batch_size=cfg.dataloader.batch_size,
+ # collate_fn=collate_batched_meshes,
+ # num_workers=cfg.dataloader.num_workers, shuffle=True)
+ # dataloader_val = DataLoader(datasets[1], batch_size=cfg.dataloader.batch_size,
+ # collate_fn=collate_batched_meshes,
+ # num_workers=cfg.dataloader.num_workers, shuffle=False)
+ # dataloader_vis = DataLoader(datasets[2], batch_size=cfg.dataloader.batch_size,
+ # collate_fn=collate_batched_meshes,
+ # num_workers=cfg.dataloader.num_workers, shuffle=False)
+ elif cfg.dataset.type in ['shape']:
+ from .shape_dataset import ShapeDataset
+ from .behave_paths import DataPaths
+ from configs.structured import ShapeDatasetConfig
+ dataset_cfg: ShapeDatasetConfig = cfg.dataset
+
+ train_paths, _ = DataPaths.load_splits(dataset_cfg.split_file, dataset_cfg.behave_dir)
+ val_paths = train_paths # same as training, this is for overfitting
+ # split validation paths to only consider the selected batches
+ bs = cfg.dataloader.batch_size
+ num_batches_total = int(np.ceil(len(val_paths) / cfg.dataloader.batch_size))
+ end_idx = cfg.run.batch_end if cfg.run.batch_end is not None else num_batches_total
+ # print(cfg.run.batch_end, cfg.run.batch_start, end_idx)
+ val_paths = val_paths[cfg.run.batch_start * bs:end_idx * bs]
+
+ dataset_train = ShapeDataset(train_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
+ (dataset_cfg.image_size, dataset_cfg.image_size),
+ split='train', )
+ dataset_val = ShapeDataset(val_paths, dataset_cfg.max_points, dataset_cfg.fix_sample,
+ (dataset_cfg.image_size, dataset_cfg.image_size),
+ split='train', )
+ dataloader_train = DataLoader(dataset_train, batch_size=cfg.dataloader.batch_size,
+ collate_fn=collate_batched_meshes,
+ num_workers=cfg.dataloader.num_workers, shuffle=True)
+ shuffle = cfg.run.job == 'train'
+ dataloader_val = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
+ collate_fn=collate_batched_meshes,
+ num_workers=cfg.dataloader.num_workers, shuffle=shuffle)
+ dataloader_vis = DataLoader(dataset_val, batch_size=cfg.dataloader.batch_size,
+ collate_fn=collate_batched_meshes,
+ num_workers=cfg.dataloader.num_workers, shuffle=shuffle)
+ else:
+ raise NotImplementedError(cfg.dataset.type)
+
+ return dataloader_train, dataloader_val, dataloader_vis
diff --git a/dataset/base_data.py b/dataset/base_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..43d250afd08963e5d3c88b456c5e673c4487f7dc
--- /dev/null
+++ b/dataset/base_data.py
@@ -0,0 +1,110 @@
+from os import path as osp
+
+import cv2
+import numpy as np
+from torch.utils.data import Dataset
+
+from dataset.img_utils import masks2bbox, resize, crop
+
+
+class BaseDataset(Dataset):
+ def __init__(self, data_paths, input_size=(224, 224)):
+ self.data_paths = data_paths # RGB image files
+ self.input_size = input_size
+ opencv2py3d = np.eye(4)
+ opencv2py3d[0, 0] = opencv2py3d[1, 1] = -1
+ self.opencv2py3d = opencv2py3d
+
+ def __len__(self):
+ return len(self.data_paths)
+
+ def load_masks(self, rgb_file):
+ person_mask_file = rgb_file.replace('.color.jpg', ".person_mask.png")
+ if not osp.isfile(person_mask_file):
+ person_mask_file = rgb_file.replace('.color.jpg', ".person_mask.jpg")
+ obj_mask_file = None
+ for pat in [".obj_rend_mask.png", ".obj_rend_mask.jpg", ".obj_mask.png", ".obj_mask.jpg", ".object_rend.png"]:
+ obj_mask_file = rgb_file.replace('.color.jpg', pat)
+ if osp.isfile(obj_mask_file):
+ break
+ person_mask = cv2.imread(person_mask_file, cv2.IMREAD_GRAYSCALE)
+ obj_mask = cv2.imread(obj_mask_file, cv2.IMREAD_GRAYSCALE)
+
+ return person_mask, obj_mask
+
+ def get_crop_params(self, mask_hum, mask_obj, bbox_exp=1.0):
+ "compute bounding box based on masks"
+ bmin, bmax = masks2bbox([mask_hum, mask_obj])
+ crop_center = (bmin + bmax) // 2
+ # crop_size = np.max(bmax - bmin)
+ crop_size = int(np.max(bmax - bmin) * bbox_exp)
+ if crop_size % 2 == 1:
+ crop_size += 1 # make sure it is an even number
+ return bmax, bmin, crop_center, crop_size
+
+ def is_behave_dataset(self, image_width):
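+        # assumed conventions: BEHAVE kinect frames are 2048x1536 (or half size 1024x768),
+        # InterCap frames are 1920x1080 (or half size 960x540)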
+        assert image_width in [2048, 1920, 1024, 960], f'unknown image width {image_width}!'
+ if image_width in [2048, 1024]:
+ is_behave = True
+ else:
+ is_behave = False
+ return is_behave
+
+ def compute_K_roi(self, bbox_square,
+ image_width=2048,
+ image_height=1536,
+ fx=979.7844, fy=979.840,
+ cx=1018.952, cy=779.486):
+        "compute the camera intrinsic matrix for the given square crop, in NDC coordinates"
+ x, y, b, w = bbox_square
+ assert b == w
+ is_behave = self.is_behave_dataset(image_width)
+
+ if is_behave:
+ assert image_height / image_width == 0.75, f"invalid image aspect ratio: width={image_width}, height={image_height}"
+ # the image might be rendered at different size
+ ratio = image_width/2048.
+ fx, fy = 979.7844*ratio, 979.840*ratio
+ cx, cy = 1018.952*ratio, 779.486*ratio
+ else:
+ assert image_height / image_width == 9/16, f"invalid image aspect ratio: width={image_width}, height={image_height}"
+ # intercap camera
+ ratio = image_width/1920
+ fx, fy = 918.457763671875*ratio, 918.4373779296875*ratio
+ cx, cy = 956.9661865234375*ratio, 555.944580078125*ratio
+
+ cx, cy = cx - x, cy - y
+ scale = b/2.
+ # in ndc
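+        # (assumed PyTorch3D-style NDC: the square crop [0, b] maps to [-1, 1], so the focal
+        #  length becomes f/scale and the principal point (scale - c)/scale, with scale = b/2)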
+ cx_ = (scale - cx)/scale
+ cy_ = (scale - cy)/scale
+ fx_ = fx/scale
+ fy_ = fy/scale
+
+ K_roi = np.array([
+ [fx_, 0, cx_, 0],
+ [0., fy_, cy_, 0, ],
+ [0, 0, 0, 1.],
+ [0, 0, 1, 0]
+ ])
+ return K_roi
+
+ def crop_full_image(self, mask_hum, mask_obj, rgb_full, crop_masks, bbox_exp=1.0):
+        """
+        crop the image based on the given masks
+        :param mask_hum: full-size human mask
+        :param mask_obj: full-size object mask
+        :param rgb_full: full-size RGB image
+        :param crop_masks: a list of masks used to compute the crop
+        :param bbox_exp: expansion ratio applied to the tight bounding box
+        :return: Kroi, cropped object mask, human mask and RGB image (background masked out).
+        """
+ bmax, bmin, crop_center, crop_size = self.get_crop_params(*crop_masks, bbox_exp)
+ rgb = resize(crop(rgb_full, crop_center, crop_size), self.input_size) / 255.
+ person_mask = resize(crop(mask_hum, crop_center, crop_size), self.input_size) / 255.
+ obj_mask = resize(crop(mask_obj, crop_center, crop_size), self.input_size) / 255.
+ xywh = np.concatenate([crop_center - crop_size // 2, np.array([crop_size, crop_size])])
+ Kroi = self.compute_K_roi(xywh, rgb_full.shape[1], rgb_full.shape[0])
+ # mask bkg out
+ mask_comb = (person_mask > 0.5) | (obj_mask > 0.5)
+ rgb = rgb * np.expand_dims(mask_comb, -1)
+ return Kroi, obj_mask, person_mask, rgb
diff --git a/dataset/behave_paths.py b/dataset/behave_paths.py
new file mode 100644
index 0000000000000000000000000000000000000000..26304a2c937f391308de1363d9d13cd90c73cf6d
--- /dev/null
+++ b/dataset/behave_paths.py
@@ -0,0 +1,228 @@
+import glob
+import os, re
+import pickle as pkl
+from os.path import join, basename, dirname, isfile
+import os.path as osp
+
+import cv2, json
+import numpy as np
+
+# PROCESSED_PATH = paths['PROCESSED_PATH']
+BEHAVE_PATH = "/BS/xxie-5/static00/behave_release/sequences/"
+RECON_PATH = "/BS/xxie-5/static00/behave-train"
+
+class DataPaths:
+ """
+ class to handle path operations based on BEHAVE dataset structure
+ """
+ def __init__(self):
+ pass
+
+ @staticmethod
+ def load_splits(split_file, dataset_path=None):
+        assert os.path.exists(dataset_path), f'the given dataset path {dataset_path} does not exist, please check that your training data is placed there!'
+ train, val = DataPaths.get_train_test_from_pkl(split_file)
+ return train, val
+ # print(train[:5], val[:5])
+ if isinstance(train[0], list):
+ # video data
+ train_full = [[join(dataset_path, seq[x]) for x in range(len(seq))] for seq in train]
+ val_full = [[join(dataset_path, seq[x]) for x in range(len(seq))] for seq in val]
+ else:
+ train_full = [join(dataset_path, x) for x in train] # full path to the training data
+ val_full = [join(dataset_path, x) for x in val] # full path to the validation data files
+ # print(train_full[:5], val_full[:5])
+ return train_full, val_full
+
+ @staticmethod
+ def load_splits_online(split_file, dataset_path=BEHAVE_PATH):
+ "load rgb file, smpl and object mesh paths"
+ keys = ['rgb', 'smpl', 'obj']
+ types = ['train', 'val']
+ splits = {}
+ data = pkl.load(open(split_file, 'rb'))
+ for type in types:
+ for key in keys:
+ k = f'{type}_{key}'
+ splits[k] = [join(dataset_path, x) for x in data[k]]
+ return splits
+
+ @staticmethod
+ def get_train_test_from_pkl(pkl_file):
+ data = pkl.load(open(pkl_file, 'rb'))
+ return data['train'], data['test']
+
+ @staticmethod
+ def get_image_paths_seq(seq, tid=1, check_occlusion=False, pat='t*.000'):
+ """
+ find all image paths in one sequence
+ :param seq: path to one behave sequence
+ :param tid: test on images from which camera
+ :param check_occlusion: whether to load full object mask and check occlusion ratio
+ :return: a list of paths to test image files
+ """
+ image_files = sorted(glob.glob(seq + f"/{pat}/k{tid}.color.jpg"))
+ # print(image_files, seq + f"/{pat}/k{tid}.color.jpg")
+ if not check_occlusion:
+ return image_files
+ # check object occlusion ratio
+ valid_files = []
+ count = 0
+ for img_file in image_files:
+ mask_file = img_file.replace('.color.jpg', '.obj_rend_mask.png')
+ if not os.path.isfile(mask_file):
+ mask_file = img_file.replace('.color.jpg', '.obj_rend_mask.jpg')
+ full_mask_file = img_file.replace('.color.jpg', '.obj_rend_full.png')
+ if not os.path.isfile(full_mask_file):
+ full_mask_file = img_file.replace('.color.jpg', '.obj_rend_full.jpg')
+ if not isfile(mask_file) or not isfile(full_mask_file):
+ continue
+
+ mask = np.sum(cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE) > 127)
+ mask_full = np.sum(cv2.imread(full_mask_file, cv2.IMREAD_GRAYSCALE) > 127)
+ if mask_full == 0:
+ count += 1
+ continue
+
+ ratio = mask / mask_full
+ if ratio > 0.3:
+ valid_files.append(img_file)
+ else:
+ count += 1
+ print(f'{mask_file} occluded by {1 - ratio}!')
+ return valid_files
+
+ @staticmethod
+ def get_kinect_id(rgb_file):
+ "extract kinect id from the rgb file"
+ filename = osp.basename(rgb_file)
+ try:
+ kid = int(filename.split('.')[0][1])
+ assert kid in [0, 1, 2, 3, 4, 5], f'found invalid kinect id {kid} for file {rgb_file}'
+ return kid
+ except Exception as e:
+ print(rgb_file)
+ raise ValueError()
+
+ @staticmethod
+ def get_seq_date(rgb_file):
+ "date for the sequence"
+ seq_name = str(rgb_file).split(os.sep)[-3]
+ date = seq_name.split('_')[0]
+ assert date in ['Date01', 'Date02', 'Date03', 'Date04', 'Date05', 'Date06', 'Date07',
+ "ICapS01", "ICapS02", "ICapS03", "Date08", "Date09"], f"invalid date for {rgb_file}"
+ return date
+
+ @staticmethod
+ def rgb2obj_path(rgb_file:str, save_name='fit01-smooth'):
+        "convert an rgb file to an object mesh file"
+ ss = rgb_file.split(os.sep)
+ seq_name = ss[-3]
+ obj_name = seq_name.split('_')[2]
+ real_name = obj_name
+ if 'chair' in obj_name:
+ real_name = 'chair'
+ if 'ball' in obj_name:
+ real_name = 'sports ball'
+
+ frame_folder = osp.dirname(rgb_file)
+ mesh_file = osp.join(frame_folder, real_name, save_name, f'{real_name}_fit.ply')
+
+ if not osp.isfile(mesh_file):
+ # synthetic data
+ mesh_file = osp.join(frame_folder, obj_name, save_name, f'{obj_name}_fit.ply')
+ return mesh_file
+
+ @staticmethod
+ def rgb2smpl_path(rgb_file:str, save_name='fit03'):
+ frame_folder = osp.dirname(rgb_file)
+ real_name = 'person'
+ mesh_file = osp.join(frame_folder, real_name, save_name, f'{real_name}_fit.ply')
+ return mesh_file
+
+ @staticmethod
+ def rgb2seq_frame(rgb_file:str):
+ "rgb file to seq_name, frame time"
+ ss = rgb_file.split(os.sep)
+ return ss[-3], ss[-2]
+
+ @staticmethod
+ def rgb2recon_folder(rgb_file, save_name, recon_path):
+ "convert rgb file to the subfolder"
+ dataset_path = osp.dirname(osp.dirname(osp.dirname(rgb_file)))
+ recon_folder = osp.join(osp.dirname(rgb_file.replace(dataset_path, recon_path)), save_name)
+ return recon_folder
+
+ @staticmethod
+ def get_seq_name(rgb_file):
+ return osp.basename(osp.dirname(osp.dirname(rgb_file)))
+
+ @staticmethod
+ def rgb2template_path(rgb_file):
+ "return the path to the object template"
+ from recon.opt_utils import get_template_path
+ # seq_name = DataPaths.get_seq_name(rgb_file)
+ # obj_name = seq_name.split('_')[2]
+ obj_name = DataPaths.rgb2object_name(rgb_file)
+ path = get_template_path(BEHAVE_PATH+"/../objects", obj_name)
+ return path
+
+ @staticmethod
+ def rgb2object_name(rgb_file):
+ seq_name = DataPaths.get_seq_name(rgb_file)
+ obj_name = seq_name.split('_')[2]
+ return obj_name
+
+ @staticmethod
+ def rgb2recon_frame(rgb_file, recon_path=RECON_PATH):
+ "return the frame folder in recon path"
+ ss = rgb_file.split(os.sep)
+ seq_name, frame = ss[-3], ss[-2]
+ return osp.join(recon_path, seq_name, frame)
+
+ @staticmethod
+ def rgb2gender(rgb_file):
+ "find the gender of this image"
+ seq_name = str(rgb_file).split(os.sep)[-3]
+ sub = seq_name.split('_')[1]
+ return _sub_gender[sub]
+
+ @staticmethod
+ def get_dataset_root(rgb_file):
+ "return the root path to all sequences"
+ from pathlib import Path
+ path = Path(rgb_file)
+ return str(path.parents[2])
+
+ @staticmethod
+ def seqname2gender(seq_name:str):
+ sub = seq_name.split('_')[1]
+ return _sub_gender[sub]
+
+ICAP_PATH = "/BS/xxie-6/static00/InterCap" # assume same root folder
+date_seqs = {
+ "Date01": BEHAVE_PATH + "/Date01_Sub01_backpack_back",
+ "Date02": BEHAVE_PATH + "/Date02_Sub02_backpack_back",
+ "Date03": BEHAVE_PATH + "/Date03_Sub03_backpack_back",
+ "Date04": BEHAVE_PATH + "/Date04_Sub05_backpack",
+ "Date05": BEHAVE_PATH + "/Date05_Sub05_backpack",
+ "Date06": BEHAVE_PATH + "/Date06_Sub07_backpack_back",
+ "Date07": BEHAVE_PATH + "/Date07_Sub04_backpack_back",
+ # "Date08": "/BS/xxie-6/static00/synthesize/Date08_Subxx_chairwood_synzv2-02",
+ "Date08": "/BS/xxie-6/static00/synz-backup/Date08_Subxx_chairwood_synzv2-02",
+ "Date09": "/BS/xxie-6/static00/synthesize/Date09_Subxx_obj01_icap", # InterCap sequence synz
+ "ICapS01": ICAP_PATH + "/ICapS01_sub01_obj01_Seg_0",
+ "ICapS02": ICAP_PATH + "/ICapS02_sub01_obj08_Seg_0",
+ "ICapS03": ICAP_PATH + "/ICapS03_sub07_obj05_Seg_0",
+}
+
+_sub_gender = {
+"Sub01": 'male',
+"Sub02": 'male',
+"Sub03": 'male',
+"Sub04": 'male',
+"Sub05": 'male',
+"Sub06": 'female',
+"Sub07": 'female',
+"Sub08": 'female',
+}
\ No newline at end of file
diff --git a/dataset/demo_dataset.py b/dataset/demo_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..8407baf146b6925fe34d36b5ff60dec97d30d7a6
--- /dev/null
+++ b/dataset/demo_dataset.py
@@ -0,0 +1,198 @@
+import os
+import numpy as np
+import cv2
+import torch
+
+from .base_data import BaseDataset
+from .behave_paths import DataPaths
+from .img_utils import compute_translation, masks2bbox, crop
+
+
+def padTo_4x3(rgb, person_mask, obj_mask, aspect_ratio=0.75):
+    """
+    pad images to the target aspect ratio (height/width, 0.75 i.e. 4:3 by default)
+    :param rgb: (H, W, 3)
+    :param person_mask: (H, W) or None
+    :param obj_mask: (H, W) or None
+    :param aspect_ratio: target height/width ratio
+    :return: all images padded to the given aspect ratio
+    """
+ h, w = rgb.shape[:2]
+ if w > h * 1/aspect_ratio:
+ # pad top
+ h_4x3 = int(w * aspect_ratio)
+ pad_top = h_4x3 - h
+ rgb_pad = np.pad(rgb, ((pad_top, 0), (0, 0), (0, 0)))
+ person_mask = np.pad(person_mask, ((pad_top, 0), (0, 0))) if person_mask is not None else None
+ obj_mask = np.pad(obj_mask, ((pad_top, 0), (0, 0))) if obj_mask is not None else None
+ else:
+        # pad both sides
+        w_new = np.lcm.reduce([h * 2, 16])  # least common multiple
+ h_4x3 = int(w_new * aspect_ratio)
+ pad_top = h_4x3 - h
+ pad_left = (w_new - w) // 2
+ pad_right = w_new - w - pad_left
+ rgb_pad = np.pad(rgb, ((pad_top, 0), (pad_left, pad_right), (0, 0)))
+ obj_mask = np.pad(obj_mask, ((pad_top, 0), (pad_left, pad_right))) if obj_mask is not None else None
+ person_mask = np.pad(person_mask, ((pad_top, 0), (pad_left, pad_right))) if person_mask is not None else None
+ return rgb_pad, obj_mask, person_mask
+
+
+def recrop_input(rgb, person_mask, obj_mask, dataset_name='behave'):
+    "crop and resize arbitrary input images and masks to match the BEHAVE (or InterCap) camera format"
+ exp_ratio = 1.42
+ if dataset_name == 'behave':
+ mean_center = np.array([1008, 995]) # mean RGB image crop center
+ behave_size = (2048, 1536)
+ new_size = (int(750 * exp_ratio), int(exp_ratio * 750))
+ else:
+ mean_center = np.array([904, 668]) # mean RGB image crop center for bottle sequences of ICAP
+ behave_size = (1920, 1080)
+ new_size = (int(593.925 * exp_ratio), int(exp_ratio * 593.925)) # mean width of bottle sequences
+ aspect_ratio = behave_size[1] / behave_size[0]
+ pad_top = mean_center[1] - new_size[0] // 2
+ pad_bottom = behave_size[1] - (mean_center[1] + new_size[0] // 2)
+ pad_left = mean_center[0] - new_size[0] // 2
+ pad_right = behave_size[0] - (mean_center[0] + new_size[0] // 2)
+
+ # First resize to the same aspect ratio
+ if rgb.shape[0] / rgb.shape[1] != aspect_ratio:
+ rgb, obj_mask, person_mask = padTo_4x3(rgb, person_mask, obj_mask, aspect_ratio)
+
+ # Resize to the same size as behave image, to have a comparable pixel size
+ rgb = cv2.resize(rgb, behave_size)
+ mask_ps = cv2.resize(person_mask, behave_size)
+ mask_obj = cv2.resize(obj_mask, behave_size)
+
+ # Crop and resize the human + object patch
+ bmin, bmax = masks2bbox([mask_ps, mask_obj])
+ center = (bmin + bmax) // 2
+ crop_size = int(np.max(bmax - bmin) * exp_ratio) # larger crop to have background
+ img_crop = cv2.resize(crop(rgb, center, crop_size), new_size)
+ mask_ps = cv2.resize(crop(mask_ps, center, crop_size), new_size)
+ mask_obj = cv2.resize(crop(mask_obj, center, crop_size), new_size)
+
+ # Pad back to have same shape as behave image
+ img_full = np.pad(img_crop, [[pad_top, pad_bottom], [pad_left, pad_right], [0, 0]])
+ mask_ps_full = np.pad(mask_ps, [[pad_top, pad_bottom], [pad_left, pad_right]])
+ mask_obj_full = np.pad(mask_obj, [[pad_top, pad_bottom], [pad_left, pad_right]])
+
+ # Make sure the image shape is the same
+ if img_full.shape[:2] != behave_size[::-1]:
+ img_full = cv2.resize(img_full, behave_size)
+ mask_ps_full = cv2.resize(mask_ps_full, behave_size)
+ mask_obj_full = cv2.resize(mask_obj_full, behave_size)
+ return img_full, mask_ps_full, mask_obj_full
+
+
+class DemoDataset(BaseDataset):
+ def __init__(self, data_paths, input_size=(224, 224),
+ std_coverage=3.5, # used to estimate camera translation
+ ):
+ super().__init__(data_paths, input_size)
+ self.std_coverage = std_coverage
+
+ def __len__(self):
+ return len(self.data_paths)
+
+ def __getitem__(self, idx):
+ rgb_file = self.data_paths[idx]
+ mask_hum, mask_obj = self.load_masks(rgb_file)
+ rgb_full = cv2.imread(rgb_file)[:, :, ::-1]
+
+ return self.image2dict(mask_hum, mask_obj, rgb_full, rgb_file)
+
+ def image2dict(self, mask_hum, mask_obj, rgb_full, rgb_file=None):
+ "do all the necessary preprocessing for images"
+ if rgb_full.shape[:2] != mask_obj.shape[:2]:
+ raise ValueError(f"The given object mask shape {mask_obj.shape[:2]} does not match the RGB image shape {rgb_full.shape[:2]}")
+ if rgb_full.shape[:2] != mask_hum.shape[:2]:
+ raise ValueError(f"The given human mask shape {mask_hum.shape[:2]} does not match the RGB image shape {rgb_full.shape[:2]}")
+
+ if rgb_full.shape[:2] not in [(1080, 1920), (1536, 2048)]:
+ # crop and resize the image to behave image size
+ print(f"Recropping the input image and masks for {rgb_file}")
+ rgb_full, mask_hum, mask_obj = recrop_input(rgb_full, mask_hum, mask_obj)
+ color_h, color_w = rgb_full.shape[:2]
+ # Input to the first stage model: human + object crop
+ Kroi, objmask_fullcrop, psmask_fullcrop, rgb_fullcrop = self.crop_full_image(mask_hum.copy(),
+ mask_obj.copy(),
+ rgb_full.copy(),
+ [mask_hum, mask_obj],
+ 1.00)
+ # Input to the second stage model: human and object crops
+ Kroi_h, masko_hum, maskh_hum, rgb_hum = self.crop_full_image(mask_hum.copy(),
+ mask_obj.copy(),
+ rgb_full.copy(),
+ [mask_hum, mask_hum], 1.05)
+ Kroi_o, masko_obj, maskh_obj, rgb_obj = self.crop_full_image(mask_hum.copy(),
+ mask_obj.copy(),
+ rgb_full.copy(),
+ [mask_obj, mask_obj], 1.5)
+ # Estimate camera translation
+ cent_transform = np.eye(4) # the transform applied to the mesh that moves it back to kinect camera frame
+ bmin_ho, bmax_ho = masks2bbox([mask_hum, mask_obj])
+ crop_size_ho = int(np.max(bmax_ho - bmin_ho) * 1.0)
+ if crop_size_ho % 2 == 1:
+ crop_size_ho += 1 # make sure it is an even number
+ is_behave = self.is_behave_dataset(rgb_full.shape[1])
+ if rgb_full.shape[1] not in [2048, 1920]:
+ raise ValueError('the image is not normalized to BEHAVE or ICAP size!')
+ indices = np.indices(rgb_full.shape[:2])
+ if np.sum(mask_obj > 127) < 5:
+ raise ValueError(f'not enough object mask found for {rgb_file}')
+ pts_h = np.stack([indices[1][mask_hum > 127], indices[0][mask_hum > 127]], -1)
+ pts_o = np.stack([indices[1][mask_obj > 127], indices[0][mask_obj > 127]], -1)
+ proj_cent_est = (np.mean(pts_h, 0) + np.mean(pts_o, 0)) / 2. # heuristic to obtain 2d projection center
+ transl_estimate = compute_translation(proj_cent_est, crop_size_ho, is_behave, self.std_coverage)
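+        # 7.0 is the BEHAVE scale_factor (see BehaveDatasetConfig), presumably the normalization
+        # used during training, so the estimated translation is expressed in that normalized space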
+ cent_transform[:3, 3] = transl_estimate / 7.0
+ radius = 0.5 # don't do normalization anymore
+ cent = transl_estimate / 7.0
+ comb = np.matmul(self.opencv2py3d, cent_transform)
+ R = torch.from_numpy(comb[:3, :3]).float()
+ T = torch.from_numpy(comb[:3, 3]).float() / (radius * 2)
+ data_dict = {
+ "R": R,
+ "T": T,
+ "K": torch.from_numpy(Kroi).float(),
+ "T_ho": torch.from_numpy(cent).float(), # translation for H+O
+ "image_path": rgb_file,
+ "image_size_hw": torch.tensor(self.input_size),
+ "images": torch.from_numpy(rgb_fullcrop).float().permute(2, 0, 1),
+ "masks": torch.from_numpy(np.stack([psmask_fullcrop, objmask_fullcrop], 0)).float(),
+ 'orig_image_size': torch.tensor([color_h, color_w]),
+
+ # Human input to stage 2
+ "images_hum": torch.from_numpy(rgb_hum).float().permute(2, 0, 1),
+ "masks_hum": torch.from_numpy(np.stack([maskh_hum, masko_hum], 0)).float(),
+ "K_hum": torch.from_numpy(Kroi_h).float(),
+
+ # Object input to stage 2
+ "images_obj": torch.from_numpy(rgb_obj).float().permute(2, 0, 1),
+ "masks_obj": torch.from_numpy(np.stack([maskh_obj, masko_obj], 0)).float(),
+ "K_obj": torch.from_numpy(Kroi_o).float(),
+
+ # some normalization parameters
+ "gt_trans": cent,
+ 'radius': radius,
+ "estimated_trans": transl_estimate,
+ }
+ return data_dict
+
+ def image2batch(self, rgb, mask_hum, mask_obj):
+ """
+ given input image, convert it into a batch object ready for model inference
+ :param rgb: (h, w, 3), np array
+ :param mask_hum: (h, w, 3), np array
+ :param mask_obj: (h, w, 3), np array
+        :return: a dict of lists (batch size 1) ready for model inference
+ """
+ mask_hum = np.mean(mask_hum, -1)
+ mask_obj = np.mean(mask_obj, -1)
+
+ data_dict = self.image2dict(mask_hum, mask_obj, rgb, 'input image')
+ # convert dict to list
+ new_dict = {k:[v] for k, v in data_dict.items()}
+
+ return new_dict
+
+
diff --git a/dataset/img_utils.py b/dataset/img_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb54b295e620e2487bad6dccc0aa5a9680d3b6b2
--- /dev/null
+++ b/dataset/img_utils.py
@@ -0,0 +1,149 @@
+"""
+common functions for image operations
+"""
+
+import cv2
+import numpy as np
+
+
+def crop(img, center, crop_size):
+ """
+ crop image around the given center, pad zeros for borders
+ :param img:
+ :param center: np array
+ :param crop_size: np array or a float size of the resulting crop
+ :return: a square crop around the center
+ """
+ assert isinstance(img, np.ndarray)
+ h, w = img.shape[:2]
+ topleft = np.round(center - crop_size / 2).astype(int)
+ bottom_right = np.round(center + crop_size / 2).astype(int)
+
+ x1 = max(0, topleft[0])
+ y1 = max(0, topleft[1])
+ x2 = min(w - 1, bottom_right[0])
+ y2 = min(h - 1, bottom_right[1])
+ cropped = img[y1:y2, x1:x2]
+
+ p1 = max(0, -topleft[0]) # padding in x (left)
+ p2 = max(0, -topleft[1]) # padding in y (top)
+ p3 = max(0, bottom_right[0] - w + 1) # padding in x (right)
+ p4 = max(0, bottom_right[1] - h + 1) # padding in y (bottom)
+
+ dim = len(img.shape)
+ if dim == 3:
+ padded = np.pad(cropped, [[p2, p4], [p1, p3], [0, 0]])
+ elif dim == 2:
+ padded = np.pad(cropped, [[p2, p4], [p1, p3]])
+ else:
+ raise NotImplementedError(f'unsupported image dimensions: {img.shape}')
+ return padded
+
+
+def resize(img, img_size, mode=cv2.INTER_LINEAR):
+ """
+ resize image to the input
+ :param img:
+ :param img_size: (width, height) of the target image size
+ :param mode:
+ :return:
+ """
+ h, w = img.shape[:2]
+ load_ratio = 1.0 * w / h
+ netin_ratio = 1.0 * img_size[0] / img_size[1]
+ assert load_ratio == netin_ratio, "image aspect ratio does not match, given image: {}, net input: {}".format(
+ img.shape, img_size)
+ resized = cv2.resize(img, img_size, interpolation=mode)
+ return resized
+
+
+def masks2bbox(masks, threshold=127):
+ """
+
+ :param masks:
+ :param threshold:
+ :return: bounding box corner coordinate
+ """
+ mask_comb = np.zeros_like(masks[0], dtype=bool)
+ for m in masks:
+ mask_comb = mask_comb | (m > threshold)
+
+ yid, xid = np.where(mask_comb)
+ bmin = np.array([xid.min(), yid.min()])
+ bmax = np.array([xid.max(), yid.max()])
+ return bmin, bmax
+
+
+def compute_translation(crop_center, crop_size, is_behave=True, std_coverage=3.5):
+ """
+ solve for an optimal translation that projects a Gaussian at the origin into the crop
+ Parameters
+ ----------
+ crop_center: (x, y) of the crop center
+ crop_size: float, the size of the square crop
+ is_behave: if True, use the BEHAVE kinect intrinsics, otherwise the InterCap intrinsics
+ std_coverage: which edge point should be projected back to the edge of the 2d crop
+
+ Returns
+ -------
+ the estimated translation
+
+ """
+ x0, y0 = crop_center
+ x1, y1 = x0 + crop_size/2, y0
+ x2, y2 = x0 - crop_size/2, y0
+ x3, y3 = x0, y0 + crop_size/2.
+ # predefined kinect intrinsics
+ if is_behave:
+ fx = 979.7844
+ fy = 979.840
+ cx = 1018.952
+ cy = 779.486
+ else:
+ # intercap camera
+ fx, fy = 918.457763671875, 918.4373779296875
+ cx, cy = 956.9661865234375, 555.944580078125
+
+ # construct the matrix
+ # A = np.array([
+ # [fx, 0, cx-x0, cx-x0, 0, 0],
+ # [0, fy, cy-y0, cy-y0, 0, 0],
+ # [fx, 0, cx-x1, 0, cx-x1, 0],
+ # [0, fy, cy-y1, 0, cy-y1, 0],
+ # [fx, 0, cx-x2, 0, 0, cx-x2],
+ # [0, fy, cy-y2, 0, 0, cy-y2]
+ # ]) # this matrix is low-rank because columns are linearly dependent: col3 - col4 = col5 + col6
+ # # find linearly dependent rows
+ # lambdas, V = np.linalg.eig(A)
+ # # print()
+ # # The linearly dependent row vectors
+ # print(lambdas == 0, np.linalg.det(A), A[lambdas == 0, :]) # some have determinant zero, some don't??
+ # print(np.linalg.inv(A))
+
+ # A = np.array([
+ # [fx, 0, cx - x0, cx - x0, 0, 0],
+ # [0, fy, cy - y0, cy - y0, 0, 0],
+ # [fx, 0, cx - x1, 0, cx - x1, 0],
+ # [0, fy, cy - y1, 0, cy - y1, 0],
+ # [fx, 0, cx - x3, 0, 0, cx - x3],
+ # [0, fy, cy - y3, 0, 0, cy - y3]
+ # ]) # this is also low rank!
+ # b = np.array([0, 0, -3*fx, 0, 0, -3*fy]).reshape((-1, 1))
+ # print("rank of the coefficient matrix:", np.linalg.matrix_rank(A)) # rank is 5! underconstrained matrix!
+ # x = np.matmul(np.linalg.inv(A), b)
+
+ # fix z0 as 0, then A is a full-rank matrix
+ # first two equations: origin (0, 0, 0) is projected to the crop center
+ # last two equations: edge point (3.5, 0, z) is projected to the edge of crop
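+ # one way to read the 4x4 system below (a sketch; unknowns roughly [tx, ty, z, dz]):
+ # rows 1-2: the pinhole projection of the translated origin lands on the crop center, e.g.
+ # fx*tx + (cx - x0)*(z + dz) = 0
+ # rows 3-4: the edge point (std_coverage, 0, z) lands on the crop border (x1, y1), e.g.
+ # fx*(tx + std_coverage) + (cx - x1)*z = 0, i.e. fx*tx + (cx - x1)*z = -std_coverage*fx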
+ A = np.array([
+ [fx, 0, cx-x0, cx-x0],
+ [0, fy, cy-y0, cy-y0],
+ [fx, 0, cx-x1, 0],
+ [0, fy, cy-y1, 0]
+ ])
+ # b = np.array([0, 0, -3.5*fx, 0]).reshape((-1, 1)) # 3.5->half of 7.0
+ b = np.array([0, 0, -std_coverage * fx, 0]).reshape((-1, 1)) # 3.5->half of 7.0
+ x = np.matmul(np.linalg.inv(A), b) # use 4 or 5 does not really matter, same results
+
+ # A is always a full-rank matrix
+
+ return x.flatten()[:3]
diff --git a/demo.py b/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..9170002b248a364b696aaef6e2a2b3f0b667bda4
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,280 @@
+"""
+Demo for template-free reconstruction
+
+python demo.py model=ho-attn run.image_path=/BS/xxie-2/work/HDM/outputs/000000017450/k1.color.jpg run.job=sample model.predict_binary=True dataset.std_coverage=3.0
+"""
+import pickle as pkl
+import sys, os
+import os.path as osp
+from typing import Iterable, Optional
+
+import cv2
+from accelerate import Accelerator
+from tqdm import tqdm
+from glob import glob
+
+sys.path.append(os.getcwd())
+import hydra
+import torch
+import numpy as np
+import imageio
+from torch.utils.data import DataLoader
+from pytorch3d.datasets import R2N2, collate_batched_meshes
+from pytorch3d.structures import Pointclouds
+from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform
+from pytorch3d.io import IO
+import torchvision.transforms.functional as TVF
+from huggingface_hub import hf_hub_download
+
+import training_utils
+from configs.structured import ProjectConfig
+from dataset.demo_dataset import DemoDataset
+from model import CrossAttenHODiffusionModel, ConditionalPCDiffusionSeparateSegm
+from render.pyt3d_wrapper import PcloudRenderer
+
+
+class DemoRunner:
+ def __init__(self, cfg: ProjectConfig):
+ cfg.model.model_name, cfg.model.predict_binary = 'pc2-diff-ho-sepsegm', True
+ model_stage1 = ConditionalPCDiffusionSeparateSegm(**cfg.model)
+ cfg.model.model_name, cfg.model.predict_binary = 'diff-ho-attn', False # stage 2 does not predict segmentation
+ model_stage2 = CrossAttenHODiffusionModel(**cfg.model)
+
+ # Load from checkpoint
+ # ckpt_file1 = os.path.join(cfg.run.code_dir_abs, f'outputs/{cfg.run.stage1_name}/single/checkpoint-latest.pth')
+ # self.load_checkpoint(ckpt_file1, model_stage1)
+ # ckpt_file2 = os.path.join(cfg.run.code_dir_abs, f'outputs/{cfg.run.stage2_name}/single/checkpoint-latest.pth')
+ # self.load_checkpoint(ckpt_file2, model_stage2)
+ # Load ckpt from hf
+ ckpt_file1 = hf_hub_download("xiexh20/HDM-models", f'{cfg.run.stage1_name}.pth')
+ self.load_checkpoint(ckpt_file1, model_stage1)
+ ckpt_file2 = hf_hub_download("xiexh20/HDM-models", f'{cfg.run.stage2_name}.pth')
+ self.load_checkpoint(ckpt_file2, model_stage2)
+
+ self.model_stage1, self.model_stage2 = model_stage1, model_stage2
+ self.model_stage1.eval()
+ self.model_stage2.eval()
+ self.model_stage1.to('cuda')
+ self.model_stage2.to('cuda')
+
+ self.cfg = cfg
+ self.io_pc = IO()
+
+ # For visualization
+ self.renderer = PcloudRenderer(image_size=cfg.dataset.image_size, radius=0.0075)
+ self.rend_size = cfg.dataset.image_size
+ self.device = 'cuda'
+
+ def load_checkpoint(self, ckpt_file1, model_stage1):
+ checkpoint = torch.load(ckpt_file1, map_location='cpu')
+ state_dict, key = checkpoint['model'], 'model'
+ if any(k.startswith('module.') for k in state_dict.keys()):
+ state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
+ print('Removed "module." from checkpoint state dict')
+ missing_keys, unexpected_keys = model_stage1.load_state_dict(state_dict, strict=False)
+ print(f'Loaded model checkpoint {key} from {ckpt_file1}')
+ if len(missing_keys):
+ print(f' - Missing_keys: {missing_keys}')
+ if len(unexpected_keys):
+ print(f' - Unexpected_keys: {unexpected_keys}')
+
+ @torch.no_grad()
+ def run(self):
+ "simply run the demo on given images, and save the results"
+ # Set random seed
+ training_utils.set_seed(self.cfg.run.seed)
+
+ outdir = osp.join(self.cfg.run.code_dir_abs, 'outputs/demo')
+ os.makedirs(outdir, exist_ok=True)
+ cfg = self.cfg
+
+ # Init data
+ image_files = sorted(glob(cfg.run.image_path))
+ data = DemoDataset(image_files,
+ (cfg.dataset.image_size, cfg.dataset.image_size),
+ cfg.dataset.std_coverage)
+ dataloader = DataLoader(data, batch_size=cfg.dataloader.batch_size,
+ collate_fn=collate_batched_meshes,
+ num_workers=1, shuffle=False)
+ progress_bar = tqdm(dataloader)
+ for batch_idx, batch in enumerate(progress_bar):
+ progress_bar.set_description(f'Processing batch {batch_idx:4d} / {len(dataloader):4d}')
+
+ out_stage1, out_stage2 = self.forward_batch(batch, cfg)
+
+ bs = len(out_stage1)
+ camera_full = PerspectiveCameras(
+ R=torch.stack(batch['R']),
+ T=torch.stack(batch['T']),
+ K=torch.stack(batch['K']),
+ device='cuda',
+ in_ndc=True)
+
+ # save output
+ for i in range(bs):
+ image_path = str(batch['image_path'][i])
+ folder, fname = osp.basename(osp.dirname(image_path)), osp.splitext(osp.basename(image_path))[0]
+ out_i = osp.join(outdir, folder)
+ os.makedirs(out_i, exist_ok=True)
+ self.io_pc.save_pointcloud(data=out_stage1[i],
+ path=osp.join(out_i, f'{fname}_stage1.ply'))
+ self.io_pc.save_pointcloud(data=out_stage2[i],
+ path=osp.join(out_i, f'{fname}_stage2.ply'))
+ TVF.to_pil_image(batch['images'][i]).save(osp.join(out_i, f'{fname}_input.png'))
+
+ # Save metadata as well
+ metadata = dict(index=i,
+ camera=camera_full[i],
+ image_size_hw=batch['image_size_hw'][i],
+ image_path=batch['image_path'][i])
+ torch.save(metadata, osp.join(out_i, f'{fname}_meta.pth'))
+
+ # Visualize
+ # front_camera = camera_full[i]
+ pc_comb = Pointclouds([out_stage1[i].points_packed(), out_stage2[i].points_packed()],
+ features=[out_stage1[i].features_packed(), out_stage2[i].features_packed()])
+ video_file = osp.join(out_i, f'{fname}_360view.mp4')
+ video_writer = imageio.get_writer(video_file, format='FFMPEG', mode='I', fps=1)
+
+ # first render front view
+ rend_stage1, _ = self.renderer.render(out_stage1[i], camera_full[i], mode='mask')
+ rend_stage2, _ = self.renderer.render(out_stage2[i], camera_full[i], mode='mask')
+ comb = np.concatenate([batch['images'][i].permute(1, 2, 0).cpu().numpy(), rend_stage1, rend_stage2], 1)
+ video_writer.append_data((comb*255).astype(np.uint8))
+
+ for azim in range(180, 180+360, 30):
+ R, T = look_at_view_transform(1.7, 0, azim, up=((0, -1, 0),), )
+ side_camera = PerspectiveCameras(image_size=((self.rend_size, self.rend_size),),
+ device=self.device,
+ R=R.repeat(2, 1, 1), T=T.repeat(2, 1),
+ focal_length=self.rend_size * 1.5,
+ principal_point=((self.rend_size / 2., self.rend_size / 2.),),
+ in_ndc=False)
+ rend, mask = self.renderer.render(pc_comb, side_camera, mode='mask')
+
+ imgs = [batch['images'][i].permute(1, 2, 0).cpu().numpy()]
+ imgs.extend([rend[0], rend[1]])
+ video_writer.append_data((np.concatenate(imgs, 1)*255).astype(np.uint8))
+ print(f"Visualization saved to {out_i}")
+
+ @torch.no_grad()
+ def forward_batch(self, batch, cfg):
+ """
+ forward one batch
+ :param batch:
+ :param cfg:
+ :return: predicted point clouds of stage 1 and 2
+ """
+ camera_full = PerspectiveCameras(
+ R=torch.stack(batch['R']),
+ T=torch.stack(batch['T']),
+ K=torch.stack(batch['K']),
+ device='cuda',
+ in_ndc=True)
+ out_stage1 = self.model_stage1.forward_sample(num_points=cfg.dataset.max_points,
+ camera=camera_full,
+ image_rgb=torch.stack(batch['images']).to('cuda'),
+ mask=torch.stack(batch['masks']).to('cuda'),
+ scheduler=cfg.run.diffusion_scheduler,
+ num_inference_steps=cfg.run.num_inference_steps,
+ )
+ # segment and normalize human/object
+ bs = len(out_stage1)
+ pred_hum, pred_obj = [], [] # predicted human/object points
+ cent_hum_pred, cent_obj_pred = [], []
+ radius_hum_pred, radius_obj_pred = [], []
+ T_hum, T_obj = [], []
+ num_samples = int(cfg.dataset.max_points / 2)
+ for i in range(bs):
+ pc: Pointclouds = out_stage1[i]
+ vc = pc.features_packed().cpu() # (P, 3), human is light blue [0.1, 1.0, 1.0], object light green [0.5, 1.0, 0]
+ points = pc.points_packed().cpu() # (P, 3)
+ mask_hum = vc[:, 2] > 0.5
+ pc_hum, pc_obj = points[mask_hum], points[~mask_hum]
+ # Up/Down-sample the points
+ pc_obj = self.upsample_predicted_pc(num_samples, pc_obj)
+ pc_hum = self.upsample_predicted_pc(num_samples, pc_hum)
+
+ # Normalize
+ cent_hum, cent_obj = torch.mean(pc_hum, 0, keepdim=True), torch.mean(pc_obj, 0, keepdim=True)
+ scale_hum = torch.sqrt(torch.sum((pc_hum - cent_hum) ** 2, -1).max())
+ scale_obj = torch.sqrt(torch.sum((pc_obj - cent_obj) ** 2, -1).max())
+ pc_hum = (pc_hum - cent_hum) / (2 * scale_hum)
+ pc_obj = (pc_obj - cent_obj) / (2 * scale_obj)
+ # Also update camera parameters for separate human + object
+ T_hum_scaled = (batch['T_ho'][i] + cent_hum.squeeze(0)) / (2 * scale_hum)
+ T_obj_scaled = (batch['T_ho'][i] + cent_obj.squeeze(0)) / (2 * scale_obj)
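+ # (sketch) this keeps the projection consistent after normalization: with points mapped to
+ # p' = (p - cent) / (2 * scale), the camera translation becomes T' = (T_ho + cent) / (2 * scale),
+ # so that p' + T' = (p + T_ho) / (2 * scale)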
+
+ pred_hum.append(pc_hum)
+ pred_obj.append(pc_obj)
+ cent_hum_pred.append(cent_hum.squeeze(0))
+ cent_obj_pred.append(cent_obj.squeeze(0))
+ T_hum.append(T_hum_scaled * torch.tensor([-1, -1, 1])) # apply opencv to pytorch3d transform: flip x and y
+ T_obj.append(T_obj_scaled * torch.tensor([-1, -1, 1]))
+ radius_hum_pred.append(scale_hum)
+ radius_obj_pred.append(scale_obj)
+ # Pack data into a new batch dict
+ camera_hum = PerspectiveCameras(
+ R=torch.stack(batch['R']),
+ T=torch.stack(T_hum),
+ K=torch.stack(batch['K_hum']),
+ device='cuda',
+ in_ndc=True
+ )
+ camera_obj = PerspectiveCameras(
+ R=torch.stack(batch['R']),
+ T=torch.stack(T_obj),
+ K=torch.stack(batch['K_obj']), # the camera should be human/object specific!!!
+ device='cuda',
+ in_ndc=True
+ )
+ # use pc from predicted
+ pc_hum = Pointclouds([x.to('cuda') for x in pred_hum])
+ pc_obj = Pointclouds([x.to('cuda') for x in pred_obj])
+ # use center and radius from predicted
+ cent_hum = torch.stack(cent_hum_pred, 0).to('cuda')
+ cent_obj = torch.stack(cent_obj_pred, 0).to('cuda') # B, 3
+ radius_hum = torch.stack(radius_hum_pred, 0).to('cuda') # B, 1
+ radius_obj = torch.stack(radius_obj_pred, 0).to('cuda')
+ out_stage2: Pointclouds = self.model_stage2.forward_sample(
+ num_points=num_samples,
+ camera=camera_hum,
+ image_rgb=torch.stack(batch['images_hum'], 0).to('cuda'),
+ mask=torch.stack(batch['masks_hum'], 0).to('cuda'),
+ gt_pc=pc_hum,
+ rgb_obj=torch.stack(batch['images_obj'], 0).to('cuda'),
+ mask_obj=torch.stack(batch['masks_obj'], 0).to('cuda'),
+ pc_obj=pc_obj,
+ camera_obj=camera_obj,
+ cent_hum=cent_hum,
+ cent_obj=cent_obj,
+ radius_hum=radius_hum.unsqueeze(-1),
+ radius_obj=radius_obj.unsqueeze(-1),
+ sample_from_interm=True,
+ noise_step=cfg.run.sample_noise_step)
+ return out_stage1, out_stage2
+
+ def upsample_predicted_pc(self, num_samples, pc_obj):
+ """
+ Up/Downsample the points to given number
+ :param num_samples: the target number
+ :param pc_obj: (N, 3)
+ :return: (num_samples, 3)
+ """
+ if len(pc_obj) > num_samples:
+ ind_obj = np.random.choice(len(pc_obj), num_samples)
+ else:
+ ind_obj = np.concatenate([np.arange(len(pc_obj)), np.random.choice(len(pc_obj), num_samples - len(pc_obj))])
+ pc_obj = pc_obj.clone()[torch.from_numpy(ind_obj).long().to(pc_obj.device)]
+ return pc_obj
+
+
+@hydra.main(config_path='configs', config_name='configs', version_base='1.1')
+def main(cfg: ProjectConfig):
+ runner = DemoRunner(cfg)
+ runner.run()
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/diffusion_utils.py b/diffusion_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5de761c18576109cfcb8fd36243424b19c78862a
--- /dev/null
+++ b/diffusion_utils.py
@@ -0,0 +1,313 @@
+import math
+from typing import List, Optional, Sequence, Union
+
+import imageio
+import logging
+import numpy as np
+import torch
+import torch.utils.data
+from PIL import Image
+from torch.distributions import Normal
+from torchvision.transforms.functional import to_pil_image
+from torchvision.utils import make_grid
+from tqdm import tqdm, trange
+from pytorch3d.renderer import (
+ AlphaCompositor,
+ NormWeightedCompositor,
+ OrthographicCameras,
+ PointsRasterizationSettings,
+ PointsRasterizer,
+ PointsRenderer,
+ look_at_view_transform)
+from pytorch3d.renderer.cameras import CamerasBase
+from pytorch3d.structures import Pointclouds
+from pytorch3d.structures.pointclouds import join_pointclouds_as_batch
+
+
+# Disable unnecessary imageio logging
+logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR)
+
+
+def rotation_matrix(axis, theta):
+ """
+ Return the rotation matrix associated with counterclockwise rotation about
+ the given axis by theta radians.
+ """
+ axis = np.asarray(axis)
+ axis = axis / np.sqrt(np.dot(axis, axis))
+ a = np.cos(theta / 2.0)
+ b, c, d = -axis * np.sin(theta / 2.0)
+ aa, bb, cc, dd = a * a, b * b, c * c, d * d
+ bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
+ return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
+ [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
+ [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])
+
+
+def rotate(vertices, faces):
+ '''
+ vertices: [numpoints, 3]
+ '''
+ M = rotation_matrix([0, 1, 0], np.pi / 2).transpose()
+ N = rotation_matrix([1, 0, 0], -np.pi / 4).transpose()
+ K = rotation_matrix([0, 0, 1], np.pi).transpose()
+
+ v, f = vertices[:, [1, 2, 0]].dot(M).dot(N).dot(K), faces[:, [1, 2, 0]]
+ return v, f
+
+
+def norm(v, f):
+ v = (v - v.min()) / (v.max() - v.min()) - 0.5
+
+ return v, f
+
+
+def getGradNorm(net):
+ pNorm = torch.sqrt(sum(torch.sum(p ** 2) for p in net.parameters()))
+ gradNorm = torch.sqrt(sum(torch.sum(p.grad ** 2) for p in net.parameters()))
+ return pNorm, gradNorm
+
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1 and m.weight is not None:
+ torch.nn.init.xavier_normal_(m.weight)
+ elif classname.find('BatchNorm') != -1:
+ m.weight.data.normal_()
+ m.bias.data.fill_(0)
+
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+ # Assumes data values lie in [0, 1]
+ assert x.shape == means.shape == log_scales.shape
+ px0 = Normal(torch.zeros_like(means), torch.ones_like(log_scales))
+
+ centered_x = x - means
+ inv_stdv = torch.exp(-log_scales)
+ plus_in = inv_stdv * (centered_x + 0.5)
+ cdf_plus = px0.cdf(plus_in)
+ min_in = inv_stdv * (centered_x - .5)
+ cdf_min = px0.cdf(min_in)
+ log_cdf_plus = torch.log(torch.max(cdf_plus, torch.ones_like(cdf_plus) * 1e-12))
+ log_one_minus_cdf_min = torch.log(torch.max(1. - cdf_min, torch.ones_like(cdf_min) * 1e-12))
+ cdf_delta = cdf_plus - cdf_min
+
+ log_probs = torch.where(
+ x < 0.001, log_cdf_plus,
+ torch.where(x > 0.999, log_one_minus_cdf_min,
+ torch.log(torch.max(cdf_delta, torch.ones_like(cdf_delta) * 1e-12))))
+ assert log_probs.shape == x.shape
+ return log_probs
+
+
+def fig2img(fig):
+ """Convert a Matplotlib figure to a PIL Image and return it"""
+ import io
+ buf = io.BytesIO()
+ fig.savefig(buf)
+ buf.seek(0)
+ img = Image.open(buf)
+ return img
+
+
+@torch.no_grad()
+def visualize_distance_transform(
+ path_stem: str,
+ images: torch.Tensor,
+) -> str:
+ output_file_image = f'{path_stem}.png'
+ if images.shape[3] in [1, 3]: # convert to (B, C, H, W)
+ images = images.permute(0, 3, 1, 2)
+ images = images[:, -1:] # (B, 1, H, W) # get only distances (not vectors for now, for simplicity)
+ image_grid = make_grid(images, nrow=int(math.sqrt(len(images))), pad_value=1, normalize=True)
+ to_pil_image(image_grid).save(output_file_image)
+ return output_file_image
+
+
+@torch.no_grad()
+def visualize_image(
+ path_stem: str,
+ images: torch.Tensor,
+ mean: Union[torch.Tensor, float] = 0.5,
+ std: Union[torch.Tensor, float] = 0.5,
+) -> str:
+ output_file_image = f'{path_stem}.png'
+ if images.shape[3] in [1, 3, 4]: # convert to (B, C, H, W)
+ images = images.permute(0, 3, 1, 2)
+ if images.shape[1] in [3, 4]: # normalize (single-channel images are not normalized)
+ images[:, :3] = images[:, :3] * std + mean # denormalize (color channels only, not alpha channel)
+ if images.shape[1] == 4: # normalize (single-channel images are not normalized)
+ image_alpha = images[:, 3:] # (B, 1, H, W)
+ bg_color = torch.tensor([230, 220, 250], device=images.device).reshape(1, 3, 1, 1) / 255
+ images = images[:, :3] * image_alpha + bg_color * (1 - image_alpha) # (B, 3, H, W)
+ image_grid = make_grid(images, nrow=int(math.sqrt(len(images))), pad_value=1)
+ to_pil_image(image_grid).save(output_file_image)
+ return output_file_image
+
+
+def ensure_point_cloud_has_colors(pointcloud: Pointclouds):
+ if pointcloud.features_padded() is None:
+ pointcloud = type(pointcloud)(points=pointcloud.points_padded(),
+ normals=pointcloud.normals_padded(), features=torch.zeros_like(pointcloud.points_padded()))
+ return pointcloud
+
+
+@torch.no_grad()
+def render_pointcloud_batch_pytorch3d(
+ cameras: CamerasBase,
+ pointclouds: Pointclouds,
+ image_size: int = 224,
+ radius: float = 0.01,
+ points_per_pixel: int = 10,
+ background_color: Sequence[float] = (0.78431373, 0.78431373, 0.78431373),
+ compositor: str = 'norm_weighted'
+):
+ # Define the settings for rasterization. The output image size is controlled by `image_size`,
+ # and `radius` / `points_per_pixel` control how points are splatted. Refer to rasterize_points.py
+ # for explanations of these parameters.
+ raster_settings = PointsRasterizationSettings(
+ image_size=image_size,
+ radius=radius,
+ points_per_pixel=points_per_pixel,
+ )
+
+ # Rasterizer
+ rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
+
+ # Compositor
+ if compositor == 'alpha':
+ compositor = AlphaCompositor(background_color=background_color)
+ elif compositor == 'norm_weighted':
+ compositor = NormWeightedCompositor(background_color=background_color)
+ else:
+ raise ValueError(compositor)
+
+ # Create a points renderer by compositing points using a weighted compositor (3D points are
+ # weighted according to their distance to a pixel and accumulated using a weighted sum)
+ renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)
+
+ # We cannot render a point cloud without colors, so add them if the pointcloud does
+ # not already have them
+ pointclouds = ensure_point_cloud_has_colors(pointclouds)
+
+ # Render batch of image
+ images = renderer(pointclouds)
+
+ return images
+
+
+@torch.no_grad()
+def visualize_pointcloud_batch_pytorch3d(
+ pointclouds: Pointclouds,
+ output_file_video: Optional[str] = None,
+ output_file_image: Optional[str] = None,
+ cameras: Optional[CamerasBase] = None, # if None, we rotate
+ scale_factor: float = 1.0,
+ num_frames: int = 1, # note that it takes a while with 30 * batch_size frames
+ elev: int = 30,
+):
+ """Saves a video and a single image of a point cloud"""
+ assert 360 % num_frames == 0, 'please select a better number of frames'
+
+ # Sizes
+ B, N, C, F = *(pointclouds.points_padded().shape), num_frames
+ device = pointclouds.device
+
+ # If a camera has not been provided, we render from a rotating view around an image
+ if cameras is None:
+
+ # Create view transforms - R is (F, 3, 3) and T is (F, 3)
+ R, T = look_at_view_transform(dist=10.0, elev=elev, azim=list(range(0, 360, 360 // F)), degrees=True, device=device)
+
+ # Repeat
+ R = R.repeat_interleave(B, dim=0) # (F * B, 3, 3)
+ T = T.repeat_interleave(B, dim=0) # (F * B, 3)
+ points = pointclouds.points_padded().tile(F, 1, 1) # (F * B, num_points, 3)
+ colors = (torch.zeros_like(points) if pointclouds.features_padded() is None else
+ pointclouds.features_padded().tile(F, 1, 1)) # (F * B, num_points, 3)
+
+ # Initialize batch of cameras
+ cameras = OrthographicCameras(focal_length=(0.25 * scale_factor), device=device, R=R, T=T)
+
+ # Wrap in Pointclouds (with color, even if the original point cloud had no color)
+ pointclouds = Pointclouds(points=points, features=colors).to(device)
+
+ # Render image
+ images = render_pointcloud_batch_pytorch3d(cameras, pointclouds)
+
+ # Convert images into grid
+ image_grids = []
+ images_for_grids = images.reshape(F, B, *images.shape[1:]).permute(0, 1, 4, 2, 3)
+ for image_for_grids in images_for_grids:
+ image_grid = make_grid(image_for_grids, nrow=int(math.sqrt(B)), pad_value=1)
+ image_grids.append(image_grid)
+ image_grids = torch.stack(image_grids, dim=0)
+ image_grids = image_grids.detach().cpu()
+
+ # Save image
+ if output_file_image is not None:
+ to_pil_image(image_grids[0]).save(output_file_image)
+
+ # Save video
+ if output_file_video:
+ video = (image_grids * 255).permute(0, 2, 3, 1).to(torch.uint8).numpy()
+ imageio.mimwrite(output_file_video, video, fps=10)
+
+
+@torch.no_grad()
+def visualize_pointcloud_evolution_pytorch3d(
+ pointclouds: Pointclouds,
+ output_file_video: str,
+ camera: Optional[CamerasBase] = None, # if None, we rotate
+ scale_factor: float = 1.0,
+):
+
+ # Device
+ B, device = len(pointclouds), pointclouds.device
+
+ # Cameras
+ if camera is None:
+ R, T = look_at_view_transform(dist=10.0, elev=30, azim=0, device=device)
+ camera = OrthographicCameras(focal_length=(0.25 * scale_factor), device=device, R=R, T=T)
+
+ # Render
+ frames = render_pointcloud_batch_pytorch3d(camera, pointclouds)
+
+ # Save video
+ video = (frames.detach().cpu() * 255).to(torch.uint8).numpy()
+ imageio.mimwrite(output_file_video, video, fps=10)
+
+
+def get_camera_index(cameras: CamerasBase, index: Optional[int] = None):
+ if index is None:
+ return cameras
+ kwargs = dict(
+ R=cameras.R[index].unsqueeze(0),
+ T=cameras.T[index].unsqueeze(0),
+ K=cameras.K[index].unsqueeze(0) if cameras.K is not None else None,
+ )
+ if hasattr(cameras, 'focal_length'):
+ kwargs['focal_length'] = cameras.focal_length[index].unsqueeze(0)
+ if hasattr(cameras, 'principal_point'):
+ kwargs['principal_point'] = cameras.principal_point[index].unsqueeze(0)
+ return type(cameras)(**kwargs).to(cameras.device)
+
+
+def get_metadata(item) -> str:
+ s = '-------------\n'
+ for key in item.keys():
+ value = item[key]
+ if torch.is_tensor(value) and value.numel() < 25:
+ value_str = value
+ elif torch.is_tensor(value):
+ value_str = value.shape
+ elif isinstance(value, str):
+ value_str = value
+ elif isinstance(value, list) and 0 < len(value) < 25 and isinstance(value[0], str):
+ value_str = value
+ elif isinstance(value, dict):
+ value_str = str({k: type(v) for k, v in value.items()})
+ else:
+ value_str = type(value)
+ s += f"{key:<30} {value_str}\n"
+ return s
diff --git a/examples/017450/k1.color.jpg b/examples/017450/k1.color.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8163a625027f9a49109a862400073d977b84d0a1
Binary files /dev/null and b/examples/017450/k1.color.jpg differ
diff --git a/examples/017450/k1.obj_rend_mask.png b/examples/017450/k1.obj_rend_mask.png
new file mode 100644
index 0000000000000000000000000000000000000000..b96e85199b2c498cc8f9edd5b994fc24a98f2edd
Binary files /dev/null and b/examples/017450/k1.obj_rend_mask.png differ
diff --git a/examples/017450/k1.person_mask.png b/examples/017450/k1.person_mask.png
new file mode 100644
index 0000000000000000000000000000000000000000..3a30e94d90552416dcec779d9fac76b89a839f16
Binary files /dev/null and b/examples/017450/k1.person_mask.png differ
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2cfef041c03fbb5378b65dce56988386b407d59
--- /dev/null
+++ b/model/__init__.py
@@ -0,0 +1,28 @@
+from configs.structured import ProjectConfig
+from .model import ConditionalPointCloudDiffusionModel
+from .model_coloring import PointCloudColoringModel
+from .model_utils import set_requires_grad
+from .model_diff_data import ConditionalPCDiffusionSeparateSegm
+from .model_hoattn import CrossAttenHODiffusionModel
+
+def get_model(cfg: ProjectConfig):
+ if cfg.model.model_name == 'pc2-diff':
+ model = ConditionalPointCloudDiffusionModel(**cfg.model)
+ elif cfg.model.model_name == 'pc2-diff-ho-sepsegm':
+ model = ConditionalPCDiffusionSeparateSegm(**cfg.model)
+ print("Using a separate model to predict segmentation label")
+ elif cfg.model.model_name == 'diff-ho-attn':
+ model = CrossAttenHODiffusionModel(**cfg.model)
+ print("Using separate model for human + object with cross attention.")
+ else:
+ raise NotImplementedError
+ if cfg.run.freeze_feature_model:
+ set_requires_grad(model.feature_model, False)
+ return model
+
+
+def get_coloring_model(cfg: ProjectConfig):
+ model = PointCloudColoringModel(**cfg.model)
+ if cfg.run.freeze_feature_model:
+ set_requires_grad(model.feature_model, False)
+ return model
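+
+
+# usage sketch (assuming a hydra-parsed ProjectConfig `cfg`, as in demo.py):
+# model = get_model(cfg) # e.g. cfg.model.model_name = 'diff-ho-attn' -> CrossAttenHODiffusionModel
+# coloring_model = get_coloring_model(cfg) # optional point-cloud coloring model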
diff --git a/model/feature_model.py b/model/feature_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ea6795bfdc36e080fbef3b7e1ebb00eed1609c5
--- /dev/null
+++ b/model/feature_model.py
@@ -0,0 +1,160 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers import ModelMixin
+from timm.models.vision_transformer import VisionTransformer, resize_pos_embed
+from torch import Tensor
+from torchvision.transforms import functional as TVF
+
+
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+
+MODEL_URLS = {
+ 'vit_base_patch16_224_mae': 'https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth',
+ 'vit_small_patch16_224_msn': 'https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar',
+ 'vit_large_patch7_224_msn': 'https://dl.fbaipublicfiles.com/msn/vitl7_200ep.pth.tar',
+}
+
+NORMALIZATION = {
+ 'vit_base_patch16_224_mae': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
+ 'vit_small_patch16_224_msn': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
+ 'vit_large_patch7_224_msn': (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
+}
+
+MODEL_KWARGS = {
+ 'vit_base_patch16_224_mae': dict(
+ patch_size=16, embed_dim=768, depth=12, num_heads=12,
+ ),
+ 'vit_small_patch16_224_msn': dict(
+ patch_size=16, embed_dim=384, depth=12, num_heads=6,
+ ),
+ 'vit_large_patch7_224_msn': dict(
+ patch_size=7, embed_dim=1024, depth=24, num_heads=16,
+ )
+}
+
+
+class FeatureModel(ModelMixin, ConfigMixin):
+
+ @register_to_config
+ def __init__(
+ self,
+ image_size: int = 224,
+ model_name: str = 'vit_small_patch16_224_mae',
+ global_pool: str = '', # '' or 'token'
+ ) -> None:
+ super().__init__()
+ self.model_name = model_name
+
+ # Identity
+ if self.model_name == 'identity':
+ return
+
+ # Create model
+ self.model = VisionTransformer(
+ img_size=image_size, num_classes=0, global_pool=global_pool,
+ **MODEL_KWARGS[model_name])
+
+ # Model properties
+ self.feature_dim = self.model.embed_dim
+ self.mean, self.std = NORMALIZATION[model_name]
+
+ # # Modify MSN model with output head from training
+ # if model_name.endswith('msn'):
+ # use_bn = True
+ # emb_dim = (192 if 'tiny' in model_name else 384 if 'small' in model_name else
+ # 768 if 'base' in model_name else 1024 if 'large' in model_name else 1280)
+ # hidden_dim = 2048
+ # output_dim = 256
+ # self.model.fc = None
+ # fc = OrderedDict([])
+ # fc['fc1'] = torch.nn.Linear(emb_dim, hidden_dim)
+ # if use_bn:
+ # fc['bn1'] = torch.nn.BatchNorm1d(hidden_dim)
+ # fc['gelu1'] = torch.nn.GELU()
+ # fc['fc2'] = torch.nn.Linear(hidden_dim, hidden_dim)
+ # if use_bn:
+ # fc['bn2'] = torch.nn.BatchNorm1d(hidden_dim)
+ # fc['gelu2'] = torch.nn.GELU()
+ # fc['fc3'] = torch.nn.Linear(hidden_dim, output_dim)
+ # self.model.fc = torch.nn.Sequential(fc)
+
+ # Load pretrained checkpoint
+ checkpoint = torch.hub.load_state_dict_from_url(MODEL_URLS[model_name])
+ if 'model' in checkpoint:
+ state_dict = checkpoint['model']
+ elif 'target_encoder' in checkpoint:
+ state_dict = checkpoint['target_encoder']
+ state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
+ # NOTE: Comment the line below if using the projection head, uncomment if not using it
+ # See https://github.com/facebookresearch/msn/blob/81cb855006f41cd993fbaad4b6a6efbb486488e6/src/msn_train.py#L490-L502
+ # for more info about the projection head
+ state_dict = {k: v for k, v in state_dict.items() if not k.startswith('fc.')}
+ else:
+ raise NotImplementedError()
+ state_dict['pos_embed'] = resize_pos_embed(state_dict['pos_embed'], self.model.pos_embed)
+ self.model.load_state_dict(state_dict)
+ self.model.eval()
+
+ # # Modify MSN model with output head from training
+ # if model_name.endswith('msn'):
+ # self.fc = self.model.fc
+ # del self.model.fc
+ # else:
+ # self.fc = nn.Identity()
+
+ # NOTE: I've disabled the whole projection head stuff for simplicity for now
+ self.fc = nn.Identity()
+
+ def denormalize(self, img: Tensor):
+ img = TVF.normalize(img, mean=[-m/s for m, s in zip(self.mean, self.std)], std=[1/s for s in self.std])
+ return torch.clip(img, 0, 1)
+
+ def normalize(self, img: Tensor):
+ return TVF.normalize(img, mean=self.mean, std=self.std)
+
+ def forward(
+ self,
+ x: Tensor,
+ return_type: str = 'features',
+ return_upscaled_features: bool = True,
+ return_projection_head_output: bool = False,
+ ):
+ """Normalizes the input `x` and runs it through `model` to obtain features"""
+ assert return_type in {'cls_token', 'features', 'all'}
+
+ # Identity
+ if self.model_name == 'identity':
+ return x
+
+ # Normalize and forward
+ B, C, H, W = x.shape
+ x = self.normalize(x)
+ feats = self.model(x)
+
+ # Reshape to image-like size
+ if return_type in {'features', 'all'}:
+ B, T, D = feats.shape
+ assert math.sqrt(T - 1).is_integer()
+ HW_down = int(math.sqrt(T - 1)) # subtract one for CLS token
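+ # e.g. a 224x224 input with 16x16 patches gives T = 14*14 + 1 = 197 tokens, so HW_down = 14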
+ output_feats: Tensor = feats[:, 1:, :].reshape(B, HW_down, HW_down, D).permute(0, 3, 1, 2) # (B, D, H_down, W_down)
+ if return_upscaled_features:
+ output_feats = F.interpolate(output_feats, size=(H, W), mode='bilinear',
+ align_corners=False) # (B, D, H_orig, W_orig)
+
+ # Head for MSN
+ output_cls = feats[:, 0]
+ if return_projection_head_output and return_type in {'cls_token', 'all'}:
+ output_cls = self.fc(output_cls)
+
+ # Return
+ if return_type == 'cls_token':
+ return output_cls
+ elif return_type == 'features':
+ return output_feats
+ else:
+ return output_cls, output_feats
diff --git a/model/model.py b/model/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d536ce36d8814359e8ef6fb574f1a7c55b38f741
--- /dev/null
+++ b/model/model.py
@@ -0,0 +1,303 @@
+import inspect
+import random
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+from diffusers.schedulers.scheduling_ddim import DDIMScheduler
+from diffusers.schedulers.scheduling_pndm import PNDMScheduler
+from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
+from pytorch3d.renderer.cameras import CamerasBase
+from pytorch3d.structures import Pointclouds
+from torch import Tensor
+from tqdm import tqdm
+
+from .model_utils import get_num_points, get_custom_betas
+from .point_cloud_model import PointCloudModel
+from .projection_model import PointCloudProjectionModel
+
+
+class ConditionalPointCloudDiffusionModel(PointCloudProjectionModel):
+
+ def __init__(
+ self,
+ beta_start: float,
+ beta_end: float,
+ beta_schedule: str,
+ point_cloud_model: str,
+ point_cloud_model_embed_dim: int,
+ **kwargs, # projection arguments
+ ):
+ super().__init__(**kwargs)
+
+ # Checks
+ if not self.predict_shape:
+ raise NotImplementedError('Must predict shape if performing diffusion.')
+
+ # Create diffusion model schedulers which define the sampling timesteps
+ self.dm_pred_type = kwargs.get('dm_pred_type', "epsilon")
+ assert self.dm_pred_type in ['epsilon','sample']
+ scheduler_kwargs = {"prediction_type": self.dm_pred_type}
+ if beta_schedule == 'custom':
+ scheduler_kwargs.update(dict(trained_betas=get_custom_betas(beta_start=beta_start, beta_end=beta_end)))
+ else:
+ scheduler_kwargs.update(dict(beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule))
+ self.schedulers_map = {
+ 'ddpm': DDPMScheduler(**scheduler_kwargs, clip_sample=False),
+ 'ddim': DDIMScheduler(**scheduler_kwargs, clip_sample=False),
+ 'pndm': PNDMScheduler(**scheduler_kwargs),
+ }
+ self.scheduler = self.schedulers_map['ddpm'] # this can be changed for inference
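+ # (sketch) at inference the scheduler is selected by name via forward_sample, e.g.
+ # model.forward_sample(..., scheduler='ddim', num_inference_steps=50)
+ # which looks the name up in self.schedulers_map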
+
+ # Create point cloud model for processing point cloud at each diffusion step
+ self.init_pcloud_model(kwargs, point_cloud_model, point_cloud_model_embed_dim)
+
+ self.load_sample_init = kwargs.get('load_sample_init', False)
+ self.sample_init_scale = kwargs.get('sample_init_scale', 1.0)
+ self.test_init_with_gtpc = kwargs.get('test_init_with_gtpc', False)
+
+ self.consistent_center = kwargs.get('consistent_center', False)
+ self.cam_noise_std = kwargs.get('cam_noise_std', 0.0) # add noise to camera based on timestamps
+
+ def init_pcloud_model(self, kwargs, point_cloud_model, point_cloud_model_embed_dim):
+ self.point_cloud_model = PointCloudModel(
+ model_type=point_cloud_model,
+ embed_dim=point_cloud_model_embed_dim,
+ in_channels=self.in_channels,
+ out_channels=self.out_channels, # voxel resolution multiplier is 1.
+ voxel_resolution_multiplier=kwargs.get('voxel_resolution_multiplier', 1)
+ )
+
+ def forward_train(
+ self,
+ pc: Pointclouds,
+ camera: Optional[CamerasBase],
+ image_rgb: Optional[Tensor],
+ mask: Optional[Tensor],
+ return_intermediate_steps: bool = False,
+ **kwargs
+ ):
+
+ # Normalize colors and convert to tensor
+ x_0 = self.point_cloud_to_tensor(pc, normalize=True, scale=True) # this will not pack the point colors
+ B, N, D = x_0.shape
+
+ # Sample random noise
+ noise = torch.randn_like(x_0)
+ if self.consistent_center:
+ # modification suggested by https://arxiv.org/pdf/2308.07837.pdf
+ noise = noise - torch.mean(noise, dim=1, keepdim=True)
+
+ # Sample random timesteps for each point_cloud
+ timestep = torch.randint(0, self.scheduler.num_train_timesteps, (B,),
+ device=self.device, dtype=torch.long)
+
+ # Add noise to points
+ x_t = self.scheduler.add_noise(x_0, noise, timestep) # diffusion noisy adding, only add to the coordinate, not features
+
+ # add noise to the camera pose, based on timestamps
+ if self.cam_noise_std > 0.000001:
+ # the noise is very different
+ camera = camera.clone()
+ camT = camera.T # (B, 3)
+ dist = torch.sqrt(torch.sum(camT**2, -1, keepdim=True))
+ nratio = timestep[:, None] / self.scheduler.num_train_timesteps # time-dependent noise
+ tnoise = torch.randn(B, 3).to(dist.device)/3. * dist * self.cam_noise_std * nratio
+ camera.T = camera.T + tnoise
+
+ # Conditioning, the pixel-aligned feature is based on points with noise (new points)
+ x_t_input = self.get_diffu_input(camera, image_rgb, mask, timestep, x_t, **kwargs)
+
+ # Forward
+ loss, noise_pred = self.compute_loss(noise, timestep, x_0, x_t_input)
+
+ # Whether to return intermediate steps
+ if return_intermediate_steps:
+ return loss, (x_0, x_t, noise, noise_pred)
+
+ return loss
+
+ def compute_loss(self, noise, timestep, x_0, x_t_input):
+ x_pred = torch.zeros_like(x_0)
+ if self.self_conditioning:
+ # self conditioning, from https://openreview.net/pdf?id=3itjR9QxFw
+ if random.uniform(0, 1.) > 0.5:
+ with torch.no_grad():
+ x_pred = self.point_cloud_model(torch.cat([x_t_input, x_pred], -1), timestep)
+ noise_pred = self.point_cloud_model(torch.cat([x_t_input, x_pred], -1), timestep)
+ else:
+ noise_pred = self.point_cloud_model(x_t_input, timestep)
+ # Check
+ if not noise_pred.shape == noise.shape:
+ raise ValueError(f'{noise_pred.shape=} and {noise.shape=}')
+ # Loss
+ if self.dm_pred_type == 'epsilon':
+ loss = F.mse_loss(noise_pred, noise)
+ elif self.dm_pred_type == 'sample':
+ loss = F.mse_loss(noise_pred, x_0) # predicting sample
+ else:
+ raise NotImplementedError
+ return loss, noise_pred
+
+ def get_diffu_input(self, camera, image_rgb, mask, timestep, x_t, **kwargs):
+ "return: (B, N, D), the exact input to the diffusion model, x_t: (B, N, 3)"
+ x_t_input = self.get_input_with_conditioning(x_t, camera=camera,
+ image_rgb=image_rgb, mask=mask, t=timestep)
+ return x_t_input
+
+ @torch.no_grad()
+ def forward_sample(
+ self,
+ num_points: int,
+ camera: Optional[CamerasBase],
+ image_rgb: Optional[Tensor],
+ mask: Optional[Tensor],
+ # Optional overrides
+ scheduler: Optional[str] = 'ddpm',
+ # Inference parameters
+ num_inference_steps: Optional[int] = 1000,
+ eta: Optional[float] = 0.0, # for DDIM
+ # Whether to return all the intermediate steps in generation
+ return_sample_every_n_steps: int = -1,
+ # Whether to disable tqdm
+ disable_tqdm: bool = False,
+ gt_pc: Pointclouds = None,
+ **kwargs
+ ):
+
+ # Get scheduler from mapping, or use self.scheduler if None
+ scheduler = self.scheduler if scheduler is None else self.schedulers_map[scheduler]
+
+ # Get the size of the noise
+ N = num_points
+ B = 1 if image_rgb is None else image_rgb.shape[0]
+ D = self.get_x_T_channel()
+ device = self.device if image_rgb is None else image_rgb.device
+
+ sample_from_interm = kwargs.get('sample_from_interm', False)
+ interm_steps = kwargs.get('noise_step') if sample_from_interm else -1
+ x_t = self.initialize_x_T(device, gt_pc, (B, N, D), interm_steps, scheduler)
+ x_pred = torch.zeros_like(x_t)
+
+ # Set timesteps
+ extra_step_kwargs = self.setup_reverse_process(eta, num_inference_steps, scheduler)
+
+ # Loop over timesteps
+ all_outputs = []
+ return_all_outputs = (return_sample_every_n_steps > 0)
+ progress_bar = tqdm(scheduler.timesteps.to(device), desc=f'Sampling ({x_t.shape})', disable=disable_tqdm)
+
+ for i, t in enumerate(progress_bar):
+ add_interm_output = (return_all_outputs and (
+ i % return_sample_every_n_steps == 0 or i == len(scheduler.timesteps) - 1))
+ # Conditioning
+ x_t_input = self.get_diffu_input(camera, image_rgb, mask, t, x_t, **kwargs)
+ if self.self_conditioning:
+ x_t_input = torch.cat([x_t_input, x_pred], -1) # add self-conditioning
+ inference_binary = (i == len(progress_bar) - 1) | add_interm_output
+ # One reverse step with conditioning
+ x_t = self.reverse_step(extra_step_kwargs, scheduler, t, x_t, x_t_input,
+ inference_binary=inference_binary) # (B, N, D), D=3 or 4
+ x_pred = x_t # for next iteration self conditioning
+
+ # Append to output list if desired
+ if add_interm_output:
+ all_outputs.append(x_t)
+
+ # Convert output back into a point cloud, undoing normalization and scaling
+ output = self.tensor_to_point_cloud(x_t, denormalize=True, unscale=True) # this convert the points back to original scale
+ if return_all_outputs:
+ all_outputs = torch.stack(all_outputs, dim=1) # (B, sample_steps, N, D)
+ all_outputs = [self.tensor_to_point_cloud(o, denormalize=True, unscale=True) for o in all_outputs]
+
+ return (output, all_outputs) if return_all_outputs else output
+
+ def get_x_T_channel(self):
+ D = 3 + (self.color_channels if self.predict_color else 0)
+ return D
+
+ def initialize_x_T(self, device, gt_pc, shape, interm_steps:int=-1, scheduler=None):
+ B, N, D = shape
+ # Sample noise initialization
+ if interm_steps > 0:
+ # Sample from some intermediate steps
+ x_0 = self.point_cloud_to_tensor(gt_pc, normalize=True, scale=True)
+ noise = torch.randn(B, N, D, device=device)
+
+ # always make sure the noise does not change the pc center, this is important to reduce 0.1cm CD!
+ noise = noise - torch.mean(noise, dim=1, keepdim=True)
+
+ x_t = scheduler.add_noise(x_0, noise, torch.tensor([interm_steps - 1] * B).long().to(device)) # Add noise
+ else:
+ # Sample from random Gaussian
+ x_t = torch.randn(B, N, D, device=device)
+
+ x_t = x_t * self.sample_init_scale # for test
+ if self.consistent_center:
+ x_t = x_t - torch.mean(x_t, dim=1, keepdim=True)
+ return x_t
+
+ def reverse_step(self, extra_step_kwargs, scheduler, t, x_t, x_t_input, **kwargs):
+ """
+ run one reverse step to compute x_t
+ :param extra_step_kwargs:
+ :param scheduler:
+ :param t: [1], diffusion time step
+ :param x_t: (B, N, 3)
+ :param x_t_input: conditional features (B, N, F)
+ :param kwargs: other configurations to run diffusion step
+ :return: denoised x_t
+ """
+ B = x_t.shape[0]
+ # Forward
+ noise_pred = self.point_cloud_model(x_t_input, t.reshape(1).expand(B))
+ if self.consistent_center:
+ assert self.dm_pred_type != 'sample', 'incompatible dm prediction type for CCD!'
+ # suggested by the CCD-3DR paper
+ noise_pred = noise_pred - torch.mean(noise_pred, dim=1, keepdim=True)
+ # Step
+ x_t = scheduler.step(noise_pred, t, x_t, **extra_step_kwargs).prev_sample
+ if self.consistent_center:
+ x_t = x_t - torch.mean(x_t, dim=1, keepdim=True)
+ return x_t
+
+ def setup_reverse_process(self, eta, num_inference_steps, scheduler):
+ """
+ setup diffusion chain, and others.
+ """
+ accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {"offset": 1} if accepts_offset else {}
+ scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+ # Prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
+ extra_step_kwargs = {"eta": eta} if accepts_eta else {}
+ return extra_step_kwargs
+
+ def forward(self, batch: FrameData, mode: str = 'train', **kwargs):
+ """
+ A wrapper around the forward method for training and inference
+ """
+ if isinstance(batch, dict): # fixes a bug with multiprocessing where batch becomes a dict
+ batch = FrameData(**batch) # unclear why this happens, but rebuilding the FrameData fixes it
+
+ if mode == 'train':
+ return self.forward_train(
+ pc=batch.sequence_point_cloud,
+ camera=batch.camera,
+ image_rgb=batch.image_rgb,
+ mask=batch.fg_probability,
+ **kwargs)
+ elif mode == 'sample':
+ num_points = kwargs.pop('num_points', get_num_points(batch.sequence_point_cloud))
+ return self.forward_sample(
+ num_points=num_points,
+ camera=batch.camera,
+ image_rgb=batch.image_rgb,
+ mask=batch.fg_probability,
+ **kwargs)
+ else:
+ raise NotImplementedError()
\ No newline at end of file
diff --git a/model/model_coloring.py b/model/model_coloring.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d7ba33202a3dc876f816080539bdf7da50023df
--- /dev/null
+++ b/model/model_coloring.py
@@ -0,0 +1,84 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
+from pytorch3d.renderer.cameras import CamerasBase
+from pytorch3d.structures import Pointclouds
+from torch import Tensor
+
+from .point_cloud_transformer_model import PointCloudTransformerModel
+from .projection_model import PointCloudProjectionModel
+
+class PointCloudColoringModel(PointCloudProjectionModel):
+
+ def __init__(
+ self,
+ point_cloud_model: str,
+ point_cloud_model_layers: int,
+ point_cloud_model_embed_dim: int,
+ **kwargs, # projection arguments
+ ):
+ super().__init__(**kwargs)
+
+ # Checks
+ if self.predict_shape or not self.predict_color:
+ raise NotImplementedError('Must predict color, not shape, for coloring')
+
+ # Create point cloud model for processing point cloud
+ self.point_cloud_model = PointCloudTransformerModel(
+ num_layers=point_cloud_model_layers,
+ model_type=point_cloud_model,
+ embed_dim=point_cloud_model_embed_dim,
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ ) # why use transformer instead???
+
+ def _forward(
+ self,
+ pc: Pointclouds,
+ camera: Optional[CamerasBase],
+ image_rgb: Optional[Tensor],
+ mask: Optional[Tensor],
+ return_point_cloud: bool = False,
+ noise_std: float = 0.0,
+ ):
+
+ # Normalize colors and convert to tensor
+ x = self.point_cloud_to_tensor(pc, normalize=True, scale=True)
+ x_points, x_colors = x[:, :, :3], x[:, :, 3:]
+
+ # Add noise to points. TODO: Add to configs.
+ x_input = x_points + torch.randn_like(x_points) * noise_std # simulate noise of the predicted pc?
+
+ # Conditioning
+ # x_input = self.get_input_with_conditioning(x_input, camera=camera,
+ # image_rgb=image_rgb, mask=mask)
+ # XH: edit to run
+ x_input = self.get_input_with_conditioning(x_input, camera=camera,
+ image_rgb=image_rgb, mask=mask, t=None)
+
+ # Forward
+ pred_colors = self.point_cloud_model(x_input)
+
+ # During inference, we return the point cloud with the predicted colors
+ if return_point_cloud:
+ pred_pointcloud = self.tensor_to_point_cloud(
+ torch.cat((x_points, pred_colors), dim=2), denormalize=True, unscale=True)
+ return pred_pointcloud
+
+ # During training, we have ground truth colors and return the loss
+ loss = F.mse_loss(pred_colors, x_colors)
+ return loss
+
+ def forward(self, batch: FrameData, **kwargs):
+ """A wrapper around the forward method"""
+ if isinstance(batch, dict): # fixes a bug with multiprocessing where batch becomes a dict
+ batch = FrameData(**batch) # unclear why this happens, but rebuilding the FrameData fixes it
+ return self._forward(
+ pc=batch.sequence_point_cloud,
+ camera=batch.camera,
+ image_rgb=batch.image_rgb,
+ mask=batch.fg_probability,
+ **kwargs,
+ )
\ No newline at end of file
diff --git a/model/model_diff_data.py b/model/model_diff_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..80726006ad7ae6097d164f29cf2745742948c6fc
--- /dev/null
+++ b/model/model_diff_data.py
@@ -0,0 +1,238 @@
+"""
+models that handle ShapeNet inputs as well as other datasets such as BEHAVE and ProciGen;
+these models take a different data dictionary in the forward function
+"""
+import inspect
+from typing import Optional
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+from diffusers.schedulers.scheduling_ddim import DDIMScheduler
+from diffusers.schedulers.scheduling_pndm import PNDMScheduler
+from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
+from pytorch3d.renderer.cameras import CamerasBase
+from pytorch3d.structures import Pointclouds
+from torch import Tensor
+from tqdm import tqdm
+from pytorch3d.renderer import PerspectiveCameras
+from pytorch3d.datasets.r2n2.utils import BlenderCamera
+
+
+from .model import ConditionalPointCloudDiffusionModel
+from .model_utils import get_num_points
+
+
+class ConditionalPCDiffusionShapenet(ConditionalPointCloudDiffusionModel):
+ def forward(self, batch, mode: str = 'train', **kwargs):
+ """
+ take a batch of data from ShapeNet
+ """
+ images = torch.stack(batch['images'], 0).to('cuda')
+ masks = torch.stack(batch['masks'], 0).to('cuda')
+ pc = Pointclouds([x.to('cuda') for x in batch['pclouds']])
+ camera = BlenderCamera(
+ torch.stack(batch['R']),
+ torch.stack(batch['T']),
+ torch.stack(batch['K']), device='cuda'
+ )
+
+ if mode == 'train':
+ return self.forward_train(
+ pc=pc,
+ camera=camera,
+ image_rgb=images,
+ mask=masks,
+
+ **kwargs)
+ elif mode == 'sample':
+ num_points = kwargs.pop('num_points', get_num_points(pc))
+ return self.forward_sample(
+ num_points=num_points,
+ camera=camera,
+ image_rgb=images,
+ mask=masks,
+ gt_pc=pc,
+ **kwargs)
+ else:
+ raise NotImplementedError()
+
+
+class ConditionalPCDiffusionBehave(ConditionalPointCloudDiffusionModel):
+ "diffusion model for Behave dataset"
+ def forward(self, batch, mode: str = 'train', **kwargs):
+ images = torch.stack(batch['images'], 0).to('cuda')
+ masks = torch.stack(batch['masks'], 0).to('cuda')
+ pc = self.get_input_pc(batch)
+ camera = PerspectiveCameras(
+ R=torch.stack(batch['R']),
+ T=torch.stack(batch['T']),
+ K=torch.stack(batch['K']),
+ device='cuda',
+ in_ndc=True
+ )
+ grid_df = torch.stack(batch['grid_df'], 0).to('cuda') if 'grid_df' in batch else None
+ num_points = kwargs.pop('num_points', get_num_points(pc))
+ if mode == 'train':
+ return self.forward_train(
+ pc=pc,
+ camera=camera,
+ image_rgb=images,
+ mask=masks,
+ grid_df=grid_df,
+ **kwargs)
+ elif mode == 'sample':
+ return self.forward_sample(
+ num_points=num_points,
+ camera=camera,
+ image_rgb=images,
+ mask=masks,
+ gt_pc=pc,
+ **kwargs)
+ else:
+ raise NotImplementedError()
+
+ def get_input_pc(self, batch):
+ pc = Pointclouds([x.to('cuda') for x in batch['pclouds']])
+ return pc
+
+
+class ConditionalPCDiffusionSeparateSegm(ConditionalPCDiffusionBehave):
+ "a separate model to predict binary labels, the final segmentation model"
+ def __init__(self,
+ beta_start: float,
+ beta_end: float,
+ beta_schedule: str,
+ point_cloud_model: str,
+ point_cloud_model_embed_dim: int,
+ **kwargs, # projection arguments
+ ):
+ super(ConditionalPCDiffusionSeparateSegm, self).__init__(beta_start, beta_end, beta_schedule,
+ point_cloud_model,
+ point_cloud_model_embed_dim, **kwargs)
+ # add a separate model to predict binary label
+ from .point_cloud_transformer_model import PointCloudTransformerModel, PointCloudModel
+
+ self.binary_model = PointCloudTransformerModel(
+ num_layers=1, # XH: use the default color model number of layers
+ model_type=point_cloud_model, # pvcnn
+ embed_dim=point_cloud_model_embed_dim, # same as the pc shape model
+ in_channels=self.in_channels,
+ out_channels=1,
+ )
+ self.binary_training_noise_std = kwargs.get("binary_training_noise_std", 0.1)
+
+ # re-initialize point cloud model
+ assert self.predict_binary
+ self.point_cloud_model = PointCloudModel(
+ model_type=point_cloud_model,
+ embed_dim=point_cloud_model_embed_dim,
+ in_channels=self.in_channels,
+ out_channels=self.out_channels - 1, # not predicting binary from this anymore
+ voxel_resolution_multiplier=kwargs.get('voxel_resolution_multiplier', 1)
+ )
+
+ def forward_train(
+ self,
+ pc: Pointclouds,
+ camera: Optional[CamerasBase],
+ image_rgb: Optional[Tensor],
+ mask: Optional[Tensor],
+ return_intermediate_steps: bool = False,
+ **kwargs
+ ):
+ # first run shape forward, then binary label forward
+ assert not return_intermediate_steps
+ assert self.predict_binary
+ loss_shape = super(ConditionalPCDiffusionSeparateSegm, self).forward_train(pc,
+ camera,
+ image_rgb,
+ mask,
+ return_intermediate_steps,
+ **kwargs)
+
+ # binary label forward
+ x_0 = self.point_cloud_to_tensor(pc, normalize=True, scale=True)
+ x_points, x_colors = x_0[:, :, :3], x_0[:, :, 3:]
+
+ # Add noise to points.
+ x_input = x_points + torch.randn_like(x_points) * self.binary_training_noise_std # std=0.1
+ x_input = self.get_input_with_conditioning(x_input, camera=camera,
+ image_rgb=image_rgb, mask=mask, t=None)
+
+ # Forward
+ pred_segm = self.binary_model(x_input)
+
+ # use compressed bits
+ df_grid = kwargs.get('grid_df', None).unsqueeze(1) # (B, 1, resz, resy, resx)
+ points = x_points.clone().detach() / self.scale_factor * 2 # normalize to [-1, 1]
+ points[:, :, 0], points[:, :, 2] = points[:, :, 2].clone(), points[:, :,0].clone() # swap, make sure clone is used!
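+ # note: F.grid_sample reads the last grid dimension as (x, y, z) over the (W, H, D) axes of df_grid;
+ # the swap above presumably matches the point order to how the distance-field grid is stored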
+ points = points.unsqueeze(1).unsqueeze(1) # (B,1, 1, N, 3)
+ with torch.no_grad():
+ df_interp = F.grid_sample(df_grid, points, padding_mode='border', align_corners=True).squeeze(1).squeeze(1) # (B, 1, 1, 1, N)
+ binary_label = df_interp[:, 0] > 0.5 # (B, 1, N)
+
+ binary_pred = torch.sigmoid(pred_segm.squeeze(-1)) # add a sigmoid layer
+ loss_binary = F.mse_loss(binary_pred, binary_label.float().squeeze(1).squeeze(1)) * self.lw_binary
+ loss = loss_shape + loss_binary
+
+ return loss, torch.tensor([loss_shape, loss_binary])
+
+ def reverse_step(self, extra_step_kwargs, scheduler, t, x_t, x_t_input, **kwargs):
+ "return (B, N, 4), the 4-th channel is binary label"
+ B = x_t.shape[0]
+ # Forward
+ noise_pred = self.point_cloud_model(x_t_input, t.reshape(1).expand(B))
+ if self.consistent_center:
+ assert self.dm_pred_type != 'sample', 'incompatible dm prediction type!'
+ # suggested by the CCD-3DR paper
+ noise_pred = noise_pred - torch.mean(noise_pred, dim=1, keepdim=True)
+ # Step: make sure only update the shape (first 3 channels)
+ x_t = scheduler.step(noise_pred, t, x_t[:, :, :3], **extra_step_kwargs).prev_sample
+ if self.consistent_center:
+ x_t = x_t - torch.mean(x_t, dim=1, keepdim=True)
+
+ # also add binary prediction
+ if kwargs.get('inference_binary', False):
+ pred_segm = self.binary_model(x_t_input)
+ else:
+ pred_segm = torch.zeros_like(x_t[:, :, 0:1])
+
+ x_t = torch.cat([x_t, torch.sigmoid(pred_segm)], -1)
+
+ return x_t
+
+ def get_coord_feature(self, x_t):
+ x_t_input = [x_t[:, :, :3]]
+ return x_t_input
+
+ def tensor_to_point_cloud(self, x: Tensor, /, denormalize: bool = False, unscale: bool = False):
+ """
+ take binary label into account
+ :param self:
+ :param x: (B, N, 4), the 4th channel is the binary segmentation, 1-human, 0-object
+ :param denormalize: denormalize the per-point colors, from pc2
+ :param unscale: undo point scaling, from pc2
+ :return: pc with point colors if predict binary label or per-point color
+ """
+ points = x[:, :, :3] / (self.scale_factor if unscale else 1)
+ if self.predict_color:
+ colors = self.denormalize(x[:, :, 3:]) if denormalize else x[:, :, 3:]
+ return Pointclouds(points=points, features=colors)
+ else:
+ if self.predict_binary:
+ assert x.shape[2] == 4
+ # add color to predicted binary labels
+ is_hum = x[:, :, 3] > 0.5
+ features = []
+ for mask in is_hum:
+ color = torch.zeros_like(x[0, :, :3]) + torch.tensor([0.5, 1.0, 0]).to(x.device)
+ color[mask, :] = torch.tensor([0.05, 1.0, 1.0]).to(x.device) # human is light blue, object light green
+ features.append(color)
+ else:
+ assert x.shape[2] == 3
+ features = None
+ return Pointclouds(points=points, features=features)
+
+
diff --git a/model/model_hoattn.py b/model/model_hoattn.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf7574f24a60dfc626a9f574924cc6bf94382f51
--- /dev/null
+++ b/model/model_hoattn.py
@@ -0,0 +1,457 @@
+"""
+model that uses cross attention to predict human + object
+"""
+
+import inspect
+import random
+from typing import Optional
+from torch import Tensor
+import torch
+import numpy as np
+
+from pytorch3d.structures import Pointclouds
+from pytorch3d.renderer import CamerasBase
+from .model_diff_data import ConditionalPCDiffusionBehave
+from .pvcnn.pvcnn_ho import PVCNN2HumObj
+import torch.nn.functional as F
+from pytorch3d.renderer import PerspectiveCameras
+from .model_utils import get_num_points
+from tqdm import tqdm
+
+
+class CrossAttenHODiffusionModel(ConditionalPCDiffusionBehave):
+ def init_pcloud_model(self, kwargs, point_cloud_model, point_cloud_model_embed_dim):
+ """use cross attention model"""
+ if point_cloud_model == 'pvcnn':
+ self.point_cloud_model = PVCNN2HumObj(embed_dim=point_cloud_model_embed_dim,
+ num_classes=self.out_channels,
+ extra_feature_channels=(self.in_channels - 3),
+ voxel_resolution_multiplier=kwargs.get('voxel_resolution_multiplier', 1),
+ attn_type=kwargs.get('attn_type', 'simple-cross'),
+ attn_weight=kwargs.get("attn_weight", 1.0)
+ )
+ else:
+ raise ValueError(f"Unknown point cloud model {point_cloud_model}!")
+ self.point_visible_test = kwargs.get("point_visible_test", 'single') # for the point visibility test, use only human points or human + object points?
+ assert self.point_visible_test in ['single', 'combine'], f'invalid point visible test option {self.point_visible_test}'
+ # print(f"Point visibility test is based on {self.point_visible_test} point clouds!")
+
+ def forward_train(
+ self,
+ pc: Pointclouds,
+ camera: Optional[CamerasBase],
+ image_rgb: Optional[Tensor],
+ mask: Optional[Tensor],
+ return_intermediate_steps: bool = False,
+ **kwargs
+ ):
+ "additional inputs (RGB, mask, camera, and pc) for the object are read from kwargs"
+ # assert not self.consistent_center
+ assert not self.self_conditioning
+
+ # Normalize colors and convert to tensor
+ x0_h = self.point_cloud_to_tensor(pc, normalize=True, scale=True) # this will not pack the point colors
+ x0_o = self.point_cloud_to_tensor(kwargs.get('pc_obj'), normalize=True, scale=True)
+ B, N, D = x0_h.shape
+
+ # Sample random noise
+ noise = torch.randn_like(x0_h)
+ if self.consistent_center:
+ # modification suggested by https://arxiv.org/pdf/2308.07837.pdf
+ noise = noise - torch.mean(noise, dim=1, keepdim=True)
+
+ # Sample random timesteps for each point_cloud
+ timestep = torch.randint(0, self.scheduler.num_train_timesteps, (B,),
+ device=self.device, dtype=torch.long)
+ # timestep = torch.randint(0, 1, (B,),
+ # device=self.device, dtype=torch.long)
+
+ # Add noise to points
+ xt_h = self.scheduler.add_noise(x0_h, noise, timestep)
+ xt_o = self.scheduler.add_noise(x0_o, noise, timestep)
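+ # NOTE (editor): the same noise sample is shared by the human and the object branch.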
+ norm_parms = self.pack_norm_params(kwargs) # (2, B, 4)
+
+ # get input conditioning
+ x_t_input_h, x_t_input_o = self.get_image_conditioning(camera, image_rgb, kwargs, mask, norm_parms, timestep,
+ xt_h, xt_o)
+
+ # Diffusion prediction
+ noise_pred_h, noise_pred_o = self.point_cloud_model(x_t_input_h, x_t_input_o, timestep, norm_parms)
+
+ # Check
+ if not noise_pred_h.shape == noise.shape:
+ raise ValueError(f'{noise_pred_h.shape=} and {noise.shape=}')
+ if not noise_pred_o.shape == noise.shape:
+ raise ValueError(f'{noise_pred_o.shape=} and {noise.shape=}')
+
+ # Loss
+ loss_h = F.mse_loss(noise_pred_h, noise)
+ loss_o = F.mse_loss(noise_pred_o, noise)
+
+ loss = loss_h + loss_o
+
+ # Whether to return intermediate steps
+ if return_intermediate_steps:
+ return loss, (x0_h, xt_h, noise, noise_pred_h)
+
+ return loss, torch.tensor([loss_h, loss_o])
+
+ def get_image_conditioning(self, camera, image_rgb, kwargs, mask, norm_parms, timestep, xt_h, xt_o):
+ """
+ compute per-point image features for the human and the object branch
+ :param camera: camera of the human crop
+ :param image_rgb: RGB image of the human crop
+ :param kwargs: holds the object-branch inputs (camera_obj, rgb_obj, mask_obj)
+ :param mask: mask of the human crop
+ :param norm_parms: (2, B, 4) normalization parameters (center, radius) for human and object
+ :param timestep: diffusion timestep(s)
+ :param xt_h: noisy human points (B, N, 3)
+ :param xt_o: noisy object points (B, N, 3)
+ :return: per-point conditioning features for the human and the object branch
+ """
+ if self.point_visible_test == 'single':
+ # Visibility test is done independently for human and object
+ x_t_input_h = self.get_input_with_conditioning(xt_h, camera=camera,
+ image_rgb=image_rgb, mask=mask, t=timestep)
+ x_t_input_o = self.get_input_with_conditioning(xt_o, camera=kwargs.get('camera_obj'),
+ image_rgb=kwargs.get('rgb_obj'),
+ mask=kwargs.get('mask_obj'), t=timestep)
+ elif self.point_visible_test == 'combine':
+ # Combine human + object points to do visibility test and obtain features
+ B, N = xt_h.shape[:2] # (B, N, 3)
+ # for human: transform object points first to H+O space, then to human space
+ xt_o_in_ho = xt_o * 2 * norm_parms[1, :, 3:].unsqueeze(1) + norm_parms[1, :, :3].unsqueeze(1)
+ xt_o_in_hum = (xt_o_in_ho - norm_parms[0, :, :3].unsqueeze(1)) / (2 * norm_parms[0, :, 3:].unsqueeze(1))
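+ # NOTE (editor): each branch is normalized as x_norm = (x_ho - cent) / (2 * radius), so the two
+ # lines above chain the inverse map (object frame -> joint H+O frame) with the forward map
+ # (joint H+O frame -> human frame).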
+ # compute features for all points, take only first half feature for human
+ x_t_input_h = self.get_input_with_conditioning(torch.cat([xt_h, xt_o_in_hum], 1), camera=camera,
+ image_rgb=image_rgb, mask=mask, t=timestep)[:,:N]
+ # for object: transform human points to H+O space, then to object space
+ xt_h_in_ho = xt_h * 2 * norm_parms[0, :, 3:].unsqueeze(1) + norm_parms[0, :, :3].unsqueeze(1)
+ xt_h_in_obj = (xt_h_in_ho - norm_parms[1, :, :3].unsqueeze(1)) / (2 * norm_parms[1, :, 3:].unsqueeze(1))
+ x_t_input_o = self.get_input_with_conditioning(torch.cat([xt_o, xt_h_in_obj], 1),
+ camera=kwargs.get('camera_obj'),
+ image_rgb=kwargs.get('rgb_obj'),
+ mask=kwargs.get('mask_obj'), t=timestep)[:, :N]
+ else:
+ raise NotImplementedError
+ return x_t_input_h, x_t_input_o
+
+ def forward(self, batch, mode: str = 'train', **kwargs):
+ """"""
+ images = torch.stack(batch['images'], 0).to('cuda')
+ masks = torch.stack(batch['masks'], 0).to('cuda')
+ pc = self.get_input_pc(batch)
+ camera = PerspectiveCameras(
+ R=torch.stack(batch['R']),
+ T=torch.stack(batch['T_hum']),
+ K=torch.stack(batch['K_hum']),
+ device='cuda',
+ in_ndc=True
+ )
+ grid_df = torch.stack(batch['grid_df'], 0).to('cuda') if 'grid_df' in batch else None
+ num_points = kwargs.pop('num_points', get_num_points(pc))
+
+ rgb_obj = torch.stack(batch['images_obj'], 0).to('cuda')
+ masks_obj = torch.stack(batch['masks_obj'], 0).to('cuda')
+ pc_obj = Pointclouds([x.to('cuda') for x in batch['pclouds_obj']])
+ camera_obj = PerspectiveCameras(
+ R=torch.stack(batch['R']),
+ T=torch.stack(batch['T_obj']),
+ K=torch.stack(batch['K_obj']),
+ device='cuda',
+ in_ndc=True
+ )
+
+ # normalization parameters
+ cent_hum = torch.stack(batch['cent_hum'], 0).to('cuda')
+ cent_obj = torch.stack(batch['cent_obj'], 0).to('cuda') # B, 3
+ radius_hum = torch.stack(batch['radius_hum'], 0).to('cuda') # B, 1
+ radius_obj = torch.stack(batch['radius_obj'], 0).to('cuda')
+
+ # print(batch['image_path'])
+
+ if mode == 'train':
+ return self.forward_train(
+ pc=pc,
+ camera=camera,
+ image_rgb=images,
+ mask=masks,
+ grid_df=grid_df,
+ rgb_obj=rgb_obj,
+ mask_obj=masks_obj,
+ pc_obj=pc_obj,
+ camera_obj=camera_obj,
+ cent_hum=cent_hum,
+ cent_obj=cent_obj,
+ radius_hum=radius_hum,
+ radius_obj=radius_obj,
+ )
+ elif mode == 'sample':
+ # this uses GT centers for the projection
+ return self.forward_sample(
+ num_points=num_points,
+ camera=camera,
+ image_rgb=images,
+ mask=masks,
+ gt_pc=pc,
+ rgb_obj=rgb_obj,
+ mask_obj=masks_obj,
+ pc_obj=pc_obj,
+ camera_obj=camera_obj,
+ cent_hum=cent_hum,
+ cent_obj=cent_obj,
+ radius_hum=radius_hum,
+ radius_obj=radius_obj,
+ **kwargs)
+ elif mode == 'interm-gt':
+ return self.forward_sample(
+ num_points=num_points,
+ camera=camera,
+ image_rgb=images,
+ mask=masks,
+ gt_pc=pc,
+ rgb_obj=rgb_obj,
+ mask_obj=masks_obj,
+ pc_obj=pc_obj,
+ camera_obj=camera_obj,
+ cent_hum=cent_hum,
+ cent_obj=cent_obj,
+ radius_hum=radius_hum,
+ radius_obj=radius_obj,
+ sample_from_interm=True,
+ **kwargs)
+ elif mode == 'interm-pred':
+ # use camera from predicted
+ camera = PerspectiveCameras(
+ R=torch.stack(batch['R']),
+ T=torch.stack(batch['T_hum_scaled']),
+ K=torch.stack(batch['K_hum']),
+ device='cuda',
+ in_ndc=True
+ )
+ camera_obj = PerspectiveCameras(
+ R=torch.stack(batch['R']),
+ T=torch.stack(batch['T_obj_scaled']),
+ K=torch.stack(batch['K_obj']), # the camera should be human/object specific!!!
+ device='cuda',
+ in_ndc=True
+ )
+ # use pc from predicted
+ pc = Pointclouds([x.to('cuda') for x in batch['pred_hum']])
+ pc_obj = Pointclouds([x.to('cuda') for x in batch['pred_obj']])
+ # use center and radius from predicted
+ cent_hum = torch.stack(batch['cent_hum_pred'], 0).to('cuda')
+ cent_obj = torch.stack(batch['cent_obj_pred'], 0).to('cuda') # B, 3
+ radius_hum = torch.stack(batch['radius_hum_pred'], 0).to('cuda') # B, 1
+ radius_obj = torch.stack(batch['radius_obj_pred'], 0).to('cuda')
+
+ return self.forward_sample(
+ num_points=num_points,
+ camera=camera,
+ image_rgb=images,
+ mask=masks,
+ gt_pc=pc,
+ rgb_obj=rgb_obj,
+ mask_obj=masks_obj,
+ pc_obj=pc_obj,
+ camera_obj=camera_obj,
+ cent_hum=cent_hum,
+ cent_obj=cent_obj,
+ radius_hum=radius_hum,
+ radius_obj=radius_obj,
+ sample_from_interm=True,
+ **kwargs)
+ elif mode == 'interm-pred-ts':
+ # use only the estimated translation and scale, but sample from Gaussian noise
+ # this works, the camera is GT!!!
+ pc = Pointclouds([x.to('cuda') for x in batch['pred_hum']])
+ pc_obj = Pointclouds([x.to('cuda') for x in batch['pred_obj']])
+ # use center and radius from predicted
+ cent_hum = torch.stack(batch['cent_hum_pred'], 0).to('cuda')
+ cent_obj = torch.stack(batch['cent_obj_pred'], 0).to('cuda') # B, 3
+ radius_hum = torch.stack(batch['radius_hum_pred'], 0).to('cuda') # B, 1
+ radius_obj = torch.stack(batch['radius_obj_pred'], 0).to('cuda')
+ # print(cent_hum[0], radius_hum[0], cent_obj[0], radius_obj[0])
+
+ return self.forward_sample(
+ num_points=num_points,
+ camera=camera,
+ image_rgb=images,
+ mask=masks,
+ gt_pc=pc,
+ rgb_obj=rgb_obj,
+ mask_obj=masks_obj,
+ pc_obj=pc_obj,
+ camera_obj=camera_obj,
+ cent_hum=cent_hum,
+ cent_obj=cent_obj,
+ radius_hum=radius_hum,
+ radius_obj=radius_obj,
+ sample_from_interm=False,
+ **kwargs)
+ else:
+ raise NotImplementedError
+
+ def forward_sample(
+ self,
+ num_points: int,
+ camera: Optional[CamerasBase],
+ image_rgb: Optional[Tensor],
+ mask: Optional[Tensor],
+ # Optional overrides
+ scheduler: Optional[str] = 'ddpm',
+ # Inference parameters
+ num_inference_steps: Optional[int] = 1000,
+ eta: Optional[float] = 0.0, # for DDIM
+ # Whether to return all the intermediate steps in generation
+ return_sample_every_n_steps: int = -1,
+ # Whether to disable tqdm
+ disable_tqdm: bool = False,
+ gt_pc: Pointclouds = None,
+ **kwargs
+ ):
+ "run reverse diffusion for the human and the object branch jointly, then use translation and scale to put them back into the joint H+O frame"
+ assert not self.self_conditioning
+ # Get scheduler from mapping, or use self.scheduler if None
+ scheduler = self.scheduler if scheduler is None else self.schedulers_map[scheduler]
+
+ # Get the size of the noise
+ N = num_points
+ B = 1 if image_rgb is None else image_rgb.shape[0]
+ D = self.get_x_T_channel()
+ device = self.device if image_rgb is None else image_rgb.device
+
+ # sample from full steps or only a few steps
+ sample_from_interm = kwargs.get('sample_from_interm', False)
+ interm_steps = kwargs.get('noise_step') if sample_from_interm else -1
+
+ xt_h = self.initialize_x_T(device, gt_pc, (B, N, D), interm_steps, scheduler)
+ xt_o = self.initialize_x_T(device, kwargs.get('pc_obj', None), (B, N, D), interm_steps, scheduler)
+
+ # the segmentation mask: the first N points are human (1), the last N are object (0)
+ segm_mask = torch.zeros(B, 2*N, 1).to(device)
+ segm_mask[:, :N] = 1.0
+
+ # Set timesteps
+ extra_step_kwargs = self.setup_reverse_process(eta, num_inference_steps, scheduler)
+
+ # Loop over timesteps
+ all_outputs = []
+ return_all_outputs = (return_sample_every_n_steps > 0)
+ progress_bar = tqdm(self.get_reverse_timesteps(scheduler, interm_steps),
+ desc=f'Sampling ({xt_h.shape})', disable=disable_tqdm)
+
+ # print("Camera T:", camera.T[0], camera.R[0])
+ # print("Camera_obj T:", kwargs.get('camera_obj').T[0], kwargs.get('camera_obj').R[0])
+
+ norm_parms = self.pack_norm_params(kwargs)
+ for i, t in enumerate(progress_bar):
+ x_t_input_h, x_t_input_o = self.get_image_conditioning(camera, image_rgb,
+ kwargs, mask,
+ norm_parms,
+ t,
+ xt_h, xt_o)
+
+ # One reverse step with conditioning
+ xt_h, xt_o = self.reverse_step(extra_step_kwargs, scheduler, t, torch.stack([xt_h, xt_o], 0),
+ torch.stack([x_t_input_h, x_t_input_o], 0), **kwargs) # (B, N, D), D=3
+
+ if (return_all_outputs and (i % return_sample_every_n_steps == 0 or i == len(scheduler.timesteps) - 1)):
+ # print(xt_h.shape, kwargs.get('cent_hum').shape, kwargs.get('radius_hum').shape)
+ x_t = torch.cat([self.denormalize_pclouds(xt_h, kwargs.get('cent_hum'), kwargs.get('radius_hum')),
+ self.denormalize_pclouds(xt_o, kwargs.get('cent_obj'), kwargs.get('radius_obj'))], 1)
+ # print(x_t.shape, xt_o.shape)
+ all_outputs.append(torch.cat([x_t, segm_mask], -1))
+ # print("Updating intermediate...")
+
+ # Convert output back into a point cloud, undoing normalization and scaling
+ x_t = torch.cat([self.denormalize_pclouds(xt_h, kwargs.get('cent_hum'), kwargs.get('radius_hum')),
+ self.denormalize_pclouds(xt_o, kwargs.get('cent_obj'), kwargs.get('radius_obj'))], 1)
+ x_t = torch.cat([x_t, segm_mask], -1)
+ output = self.tensor_to_point_cloud(x_t, denormalize=False, unscale=False) # points are already back in the original frame (see denormalize_pclouds)
+ if return_all_outputs:
+ all_outputs = torch.stack(all_outputs, dim=1) # (B, sample_steps, N, D)
+ all_outputs = [self.tensor_to_point_cloud(o, denormalize=False, unscale=False) for o in all_outputs]
+
+ return (output, all_outputs) if return_all_outputs else output
+
+ def get_reverse_timesteps(self, scheduler, interm_steps:int):
+ """
+ :param scheduler: the diffusion scheduler
+ :param interm_steps: if > 0, start the reverse process from this intermediate step
+ :return: the timesteps to iterate over, in reverse order
+ """
+ if interm_steps > 0:
+ timesteps = torch.from_numpy(np.arange(0, interm_steps)[::-1].copy()).to(self.device)
+ else:
+ timesteps = scheduler.timesteps.to(self.device)
+ return timesteps
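+ # NOTE (editor): e.g. interm_steps=100 yields timesteps [99, 98, ..., 0], i.e. denoising starts
+ # from a partially noised input instead of pure Gaussian noise.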
+
+ def pack_norm_params(self, kwargs:dict, scale=True):
+ scale_factor = self.scale_factor if scale else 1.0
+ hum = torch.cat([kwargs.get('cent_hum')*scale_factor, kwargs.get('radius_hum')], -1)
+ obj = torch.cat([kwargs.get('cent_obj')*scale_factor, kwargs.get('radius_obj')], -1)
+ return torch.stack([hum, obj], 0) # (2, B, 4)
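+ # NOTE (editor): norm_parms[0] holds the human (center * scale_factor, radius), norm_parms[1] the object's.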
+
+ def reverse_step(self, extra_step_kwargs, scheduler, t, x_t, x_t_input, **kwargs):
+ "x_t: (2, B, N, 3), x_t_input: (2, B, N, D_in)"
+ norm_parms = self.pack_norm_params(kwargs) # (2, B, 4)
+ B = x_t.shape[1]
+ # print(f"Step {t} Norm params:", norm_parms[:, 0, :])
+ noise_pred_h, noise_pred_o = self.point_cloud_model(x_t_input[0], x_t_input[1], t.reshape(1).expand(B),
+ norm_parms)
+ if self.consistent_center:
+ assert self.dm_pred_type != 'sample', 'incompatible dm prediction type!'
+ noise_pred_h = noise_pred_h - torch.mean(noise_pred_h, dim=1, keepdim=True)
+ noise_pred_o = noise_pred_o - torch.mean(noise_pred_o, dim=1, keepdim=True)
+
+ xt_h = scheduler.step(noise_pred_h, t, x_t[0], **extra_step_kwargs).prev_sample
+ xt_o = scheduler.step(noise_pred_o, t, x_t[1], **extra_step_kwargs).prev_sample
+
+ if self.consistent_center:
+ xt_h = xt_h - torch.mean(xt_h, dim=1, keepdim=True)
+ xt_o = xt_o - torch.mean(xt_o, dim=1, keepdim=True)
+
+ return xt_h, xt_o
+
+ def denormalize_pclouds(self, x: Tensor, cent, radius, unscale: bool = True):
+ """
+ first denormalize, then apply center and scale to map back to the original H+O coordinate frame
+ :param x: (B, N, 3+) normalized points
+ :param cent: (B, 3)
+ :param radius: (B, 1)
+ :param unscale: undo the model scale factor
+ :return: (B, N, 3) points in the original H+O frame
+ """
+ # denormalize: scale down.
+ points = x[:, :, :3] / (self.scale_factor if unscale else 1)
+ # translation and scale back to H+O coordinate
+ points = points * 2 * radius.unsqueeze(-1) + cent.unsqueeze(1)
+ return points
+
+ def tensor_to_point_cloud(self, x: Tensor, /, denormalize: bool = False, unscale: bool = False):
+ """
+ take the binary label into account
+ :param self:
+ :param x: (B, N, 4), the 4th channel is the binary segmentation, 1-human, 0-object
+ :param denormalize: denormalize the per-point colors
+ :param unscale: undo point scaling
+ :return: pc with per-point colors indicating human (light blue) vs. object (light green)
+ """
+ points = x[:, :, :3] / (self.scale_factor if unscale else 1)
+ if self.predict_color:
+ colors = self.denormalize(x[:, :, 3:]) if denormalize else x[:, :, 3:]
+ return Pointclouds(points=points, features=colors)
+ else:
+ assert x.shape[2] == 4
+ # add color to predicted binary labels
+ is_hum = x[:, :, 3] > 0.5
+ features = []
+ for mask in is_hum:
+ color = torch.zeros_like(x[0, :, :3]) + torch.tensor([0.5, 1.0, 0]).to(x.device)
+ color[mask, :] = torch.tensor([0.05, 1.0, 1.0]).to(x.device) # human is light blue, object light green
+ features.append(color)
+ return Pointclouds(points=points, features=features)
+
+
diff --git a/model/model_utils.py b/model/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..416e6a9928e6cd6f1dc39becbf40b0538f706fb8
--- /dev/null
+++ b/model/model_utils.py
@@ -0,0 +1,58 @@
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+from pytorch3d.structures import Pointclouds
+
+
+def set_requires_grad(module: nn.Module, requires_grad: bool):
+ for p in module.parameters():
+ p.requires_grad_(requires_grad)
+
+
+def compute_distance_transform(mask: torch.Tensor):
+ """
+
+ Parameters
+ ----------
+ mask (B, 1, H, W) or (B, 2, H, W) true for foreground
+
+ Returns
+ -------
+ the distance to the closest foreground pixel, normalized by half the image size and clipped to [0, 1]; zero inside the mask
+
+ """
+ C = mask.shape[1]
+ assert C in [1, 2], f'invalid mask shape {mask.shape} found!'
+
+ image_size = mask.shape[-1]
+
+ dts = []
+ for i in range(C):
+ distance_transform = torch.stack([
+ torch.from_numpy(cv2.distanceTransform(
+ (1 - m), distanceType=cv2.DIST_L2, maskSize=cv2.DIST_MASK_3
+ ) / (image_size / 2))
+ for m in mask[:, i:i+1].squeeze(1).detach().cpu().numpy().astype(np.uint8)
+ ]).unsqueeze(1).clip(0, 1).to(mask.device)
+ dts.append(distance_transform)
+ return torch.cat(dts, 1)
+
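+# Usage sketch (editor note, a minimal example assuming a binary foreground mask):
+#   mask = torch.rand(2, 1, 224, 224) > 0.5 # (B, 1, H, W), True = foreground
+#   dt = compute_distance_transform(mask) # (B, 1, H, W), values in [0, 1]
+# dt is zero on foreground pixels and grows with the distance to the nearest foreground pixel.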
+
+def default(x, d):
+ return d if x is None else x
+
+
+def get_num_points(x: Pointclouds, /):
+ return x.points_padded().shape[1]
+
+
+def get_custom_betas(beta_start: float, beta_end: float, warmup_frac: float = 0.3, num_train_timesteps: int = 1000):
+ """Custom beta schedule: linear schedule with a warmup over the first `warmup_frac` of the timesteps"""
+ betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+ warmup_time = int(num_train_timesteps * warmup_frac)
+ warmup_steps = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
+ warmup_time = min(warmup_time, num_train_timesteps)
+ betas[:warmup_time] = warmup_steps[:warmup_time]
+ return betas
diff --git a/model/point_cloud_model.py b/model/point_cloud_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e9617ced94e167ebbe3937cc16c207c87b5e7e9
--- /dev/null
+++ b/model/point_cloud_model.py
@@ -0,0 +1,67 @@
+from contextlib import nullcontext
+
+import torch
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers import ModelMixin
+from torch import Tensor
+
+from .pvcnn.pvcnn import PVCNN2
+from .pvcnn.pvcnn_plus_plus import PVCNN2PlusPlus
+from .simple.simple_model import SimplePointModel
+
+
+class PointCloudModel(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ model_type: str = 'pvcnn',
+ in_channels: int = 3,
+ out_channels: int = 3,
+ embed_dim: int = 64,
+ dropout: float = 0.1,
+ width_multiplier: int = 1,
+ voxel_resolution_multiplier: int = 1,
+ ):
+ super().__init__()
+ self.model_type = model_type
+ if self.model_type == 'pvcnn':
+ self.autocast_context = torch.autocast('cuda', dtype=torch.float32)
+ self.model = PVCNN2(
+ embed_dim=embed_dim,
+ num_classes=out_channels,
+ extra_feature_channels=(in_channels - 3),
+ dropout=dropout, width_multiplier=width_multiplier,
+ voxel_resolution_multiplier=voxel_resolution_multiplier
+ )
+ self.model.classifier[-1].bias.data.normal_(0, 1e-6)
+ self.model.classifier[-1].weight.data.normal_(0, 1e-6)
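+ # NOTE (editor): the final classifier layer is initialized near zero so the freshly initialized
+ # network predicts (almost) zero noise, which tends to stabilize early diffusion training.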
+ elif self.model_type == 'pvcnnplusplus':
+ self.autocast_context = torch.autocast('cuda', dtype=torch.float32)
+ self.model = PVCNN2PlusPlus(
+ embed_dim=embed_dim,
+ num_classes=out_channels,
+ extra_feature_channels=(in_channels - 3),
+ )
+ self.model.output_projection[-1].bias.data.normal_(0, 1e-6)
+ self.model.output_projection[-1].weight.data.normal_(0, 1e-6)
+ elif self.model_type == 'simple':
+ self.autocast_context = nullcontext()
+ self.model = SimplePointModel(
+ embed_dim=embed_dim,
+ num_classes=out_channels,
+ extra_feature_channels=(in_channels - 3),
+ )
+ self.model.output_projection.bias.data.normal_(0, 1e-6)
+ self.model.output_projection.weight.data.normal_(0, 1e-6)
+ else:
+ raise NotImplementedError()
+
+ def forward(self, inputs: Tensor, t: Tensor, ret_feats=False) -> Tensor:
+ """ Receives input of shape (B, N, in_channels) and returns output
+ of shape (B, N, out_channels) """
+ with self.autocast_context:
+ if not ret_feats:
+ return self.model(inputs.transpose(1, 2), t, ret_feats=False).transpose(1, 2)
+ else:
+ pred, feats = self.model(inputs.transpose(1, 2), t, ret_feats=True)
+ return pred.transpose(1, 2), feats
\ No newline at end of file
diff --git a/model/point_cloud_transformer_model.py b/model/point_cloud_transformer_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..54621629f7ee38e31823c1b5d23ca8d3c0fed2e0
--- /dev/null
+++ b/model/point_cloud_transformer_model.py
@@ -0,0 +1,80 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers import ModelMixin
+from torch import Tensor
+from timm.models.vision_transformer import Attention, LayerScale, DropPath, Mlp
+
+from .point_cloud_model import PointCloudModel
+
+
+class PointCloudModelBlock(nn.Module):
+
+ def __init__(
+ self,
+ *,
+ # Point cloud model
+ dim: int,
+ model_type: str = 'pvcnn',
+ dropout: float = 0.1,
+ width_multiplier: int = 1,
+ voxel_resolution_multiplier: int = 1,
+ # Transformer model
+ num_heads=6, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_attn=False
+ ):
+ super().__init__()
+
+ # Point cloud model
+ self.norm0 = norm_layer(dim)
+ self.point_cloud_model = PointCloudModel(model_type=model_type,
+ in_channels=dim, out_channels=dim, embed_dim=dim, dropout=dropout,
+ width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier)
+ self.ls0 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path0 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ # Attention
+ self.use_attn = use_attn
+ if self.use_attn:
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ # MLP
+ self.norm2 = norm_layer(dim)
+ self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def apply_point_cloud_model(self, x: Tensor, t: Optional[Tensor] = None) -> Tensor:
+ t = t if t is not None else torch.zeros(len(x), device=x.device, dtype=torch.long)
+ return self.point_cloud_model(x, t)
+
+ def forward(self, x: Tensor):
+ x = x + self.drop_path0(self.ls0(self.apply_point_cloud_model(self.norm0(x))))
+ if self.use_attn:
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+ return x
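+ # NOTE (editor): a pre-norm residual block in the timm ViT style: PVCNN -> (optional)
+ # self-attention -> MLP, each branch wrapped in LayerScale and DropPath.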
+
+
+class PointCloudTransformerModel(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(self, num_layers: int, in_channels: int = 3, out_channels: int = 3, embed_dim: int = 64, **kwargs):
+ super().__init__()
+ self.num_layers = num_layers
+ self.input_projection = nn.Linear(in_channels, embed_dim)
+ self.blocks = nn.Sequential(*[PointCloudModelBlock(dim=embed_dim, **kwargs) for i in range(self.num_layers)])
+ self.norm = nn.LayerNorm(embed_dim)
+ self.output_projection = nn.Linear(embed_dim, out_channels)
+
+ def forward(self, inputs: Tensor) -> Tensor:
+ """ Receives input of shape (B, N, in_channels) and returns output
+ of shape (B, N, out_channels) """
+ x = self.input_projection(inputs)
+ x = self.blocks(x)
+ x = self.output_projection(x)
+ return x
diff --git a/model/projection_model.py b/model/projection_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..78450d69c034f61de60f0cb0b3185d175035b69e
--- /dev/null
+++ b/model/projection_model.py
@@ -0,0 +1,273 @@
+from typing import Optional, Union
+
+import torch
+from diffusers.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler
+from diffusers.schedulers.scheduling_lms_discrete import LMSDiscreteScheduler
+from diffusers import ModelMixin
+from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
+from pytorch3d.renderer import PointsRasterizationSettings, PointsRasterizer
+from pytorch3d.renderer.cameras import CamerasBase
+from pytorch3d.structures import Pointclouds
+from torch import Tensor
+
+from .feature_model import FeatureModel
+from .model_utils import compute_distance_transform
+
+SchedulerClass = Union[DDPMScheduler, DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+
+
+class PointCloudProjectionModel(ModelMixin):
+
+ def __init__(
+ self,
+ image_size: int,
+ image_feature_model: str,
+ use_local_colors: bool = True,
+ use_local_features: bool = True,
+ use_global_features: bool = False,
+ use_mask: bool = True,
+ use_distance_transform: bool = True,
+ predict_shape: bool = True,
+ predict_color: bool = False,
+ process_color: bool = False,
+ image_color_channels: int = 3, # for the input image, not the points
+ color_channels: int = 3, # for the points, not the input image
+ colors_mean: float = 0.5,
+ colors_std: float = 0.5,
+ scale_factor: float = 1.0,
+ # Rasterization settings
+ raster_point_radius: float = 0.0075, # point size
+ raster_points_per_pixel: int = 1, # a single point per pixel, for now
+ bin_size: int = 0,
+ model_name=None,
+ # additional arguments added by XH
+ load_sample_init=False,
+ sample_init_scale=1.0,
+ test_init_with_gtpc=False,
+ consistent_center=False, # from https://arxiv.org/pdf/2308.07837.pdf
+ voxel_resolution_multiplier: int=1,
+ predict_binary: bool=False, # predict a binary class label
+ lw_binary: float=1.0,
+ binary_training_noise_std: float=0.1,
+ dm_pred_type: str='epsilon', # diffusion prediction type
+ self_conditioning=False,
+ **kwargs,
+
+ ):
+ super().__init__()
+ self.image_size = image_size
+ self.scale_factor = scale_factor
+ self.use_local_colors = use_local_colors
+ self.use_local_features = use_local_features
+ self.use_global_features = use_global_features
+ self.use_mask = use_mask
+ self.use_distance_transform = use_distance_transform
+ self.predict_shape = predict_shape # default True
+ self.predict_color = predict_color # default False
+ self.process_color = process_color
+ self.image_color_channels = image_color_channels
+ self.color_channels = color_channels
+ self.colors_mean = colors_mean
+ self.colors_std = colors_std
+ self.model_name = model_name
+ print("PointCloud Model scale factor:", self.scale_factor, 'Model name:', self.model_name)
+ self.predict_binary = predict_binary
+ self.lw_binary = lw_binary
+ self.self_conditioning = self_conditioning
+
+ # Types of conditioning that are used
+ self.use_local_conditioning = self.use_local_colors or self.use_local_features or self.use_mask
+ self.use_global_conditioning = self.use_global_features
+ self.kwargs = kwargs
+
+ # Create feature model
+ self.feature_model = FeatureModel(image_size, image_feature_model)
+
+ # Input size
+ self.in_channels = 3 # 3 for 3D point positions
+ if self.use_local_colors: # whether color should be an input
+ self.in_channels += self.image_color_channels
+ if self.use_local_features:
+ self.in_channels += self.feature_model.feature_dim
+ if self.use_global_features:
+ self.in_channels += self.feature_model.feature_dim
+ if self.use_mask:
+ self.in_channels += 2 if self.use_distance_transform else 1
+ if self.process_color:
+ self.in_channels += self.color_channels # point color added to input or not, default False
+ if self.self_conditioning:
+ self.in_channels += 3 # add self conditioning
+
+ self.in_channels = self.add_extra_input_chennels(self.in_channels)
+
+ if self.model_name in ['pc2-diff-ho-sepsegm', 'diff-ho-attn']:
+ self.in_channels += 2 if self.use_distance_transform else 1
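+ # NOTE (editor): example channel count under the default flags, assuming a 384-dim image
+ # feature model (e.g. a ViT-small backbone): 3 (xyz) + 3 (local colors) + 384 (local features)
+ # + 2 (mask + distance transform) = 392, plus 2 more for the two model variants above.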
+
+ # Output size
+ self.out_channels = 0
+ if self.predict_shape:
+ self.out_channels += 3
+ if self.predict_color:
+ self.out_channels += self.color_channels
+ if self.predict_binary:
+ print("Output binary classification score!")
+ self.out_channels += 1
+
+ # Save rasterization settings
+ self.raster_settings = PointsRasterizationSettings(
+ image_size=(image_size, image_size),
+ radius=raster_point_radius,
+ points_per_pixel=raster_points_per_pixel,
+ bin_size=bin_size,
+ )
+
+ def add_extra_input_chennels(self, input_channels):
+ return input_channels
+
+ def denormalize(self, x: Tensor, /, clamp: bool = True):
+ x = x * self.colors_std + self.colors_mean
+ return torch.clamp(x, 0, 1) if clamp else x
+
+ def normalize(self, x: Tensor, /):
+ x = (x - self.colors_mean) / self.colors_std
+ return x
+
+ def get_global_conditioning(self, image_rgb: Tensor):
+ global_conditioning = []
+ if self.use_global_features:
+ global_conditioning.append(self.feature_model(image_rgb,
+ return_cls_token_only=True)) # (B, D)
+ global_conditioning = torch.cat(global_conditioning, dim=1) # (B, D_cond)
+ return global_conditioning
+
+ def get_local_conditioning(self, image_rgb: Tensor, mask: Tensor):
+ """
+ compute per-point conditioning
+ Parameters
+ ----------
+ image_rgb: (B, 3, 224, 224), values normalized to 0-1, background is masked by the given mask
+ mask: (B, 1, 224, 224), or (B, 2, 224, 224) for h+o
+ """
+ local_conditioning = []
+ # import pdb; pdb.set_trace()
+
+ if self.use_local_colors: # XH: default True
+ local_conditioning.append(self.normalize(image_rgb))
+ if self.use_local_features: # XH: default True
+ local_conditioning.append(self.feature_model(image_rgb)) # I guess no mask here? feature model: 'vit_small_patch16_224_mae'
+ if self.use_mask: # default True
+ local_conditioning.append(mask.float())
+ if self.use_distance_transform: # default True
+ if not self.use_mask:
+ raise ValueError('No mask for distance transform?')
+ if mask.is_floating_point():
+ mask = mask > 0.5
+ local_conditioning.append(compute_distance_transform(mask))
+ local_conditioning = torch.cat(local_conditioning, dim=1) # (B, D_cond, H, W)
+ return local_conditioning
+
+ @torch.autocast('cuda', dtype=torch.float32)
+ def surface_projection(
+ self, points: Tensor, camera: CamerasBase, local_features: Tensor,
+ ):
+ B, C, H, W, device = *local_features.shape, local_features.device
+ R = self.raster_settings.points_per_pixel
+ N = points.shape[1]
+
+ # Scale camera by scaling T. ASSUMES CAMERA IS LOOKING AT ORIGIN!
+ camera = camera.clone()
+ camera.T = camera.T * self.scale_factor
+
+ # Create rasterizer
+ rasterizer = PointsRasterizer(cameras=camera, raster_settings=self.raster_settings)
+
+ # Associate points with features via rasterization
+ fragments = rasterizer(Pointclouds(points)) # (B, H, W, R)
+ fragments_idx: Tensor = fragments.idx.long()
+ visible_pixels = (fragments_idx > -1) # (B, H, W, R)
+ points_to_visible_pixels = fragments_idx[visible_pixels]
+
+ # Reshape local features to (B, H, W, R, C)
+ local_features = local_features.permute(0, 2, 3, 1).unsqueeze(-2).expand(-1, -1, -1, R, -1) # (B, H, W, R, C)
+
+ # Get local features corresponding to visible points
+ local_features_proj = torch.zeros(B * N, C, device=device)
+ # local feature includes: raw RGB color, image features, mask, distance transform
+ local_features_proj[points_to_visible_pixels] = local_features[visible_pixels]
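+ # NOTE (editor): points that do not rasterize to any visible pixel keep all-zero features.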
+ local_features_proj = local_features_proj.reshape(B, N, C)
+
+ return local_features_proj
+
+ def point_cloud_to_tensor(self, pc: Pointclouds, /, normalize: bool = False, scale: bool = False):
+ """Converts a point cloud to a tensor, with color if and only if self.predict_color"""
+ points = pc.points_padded() * (self.scale_factor if scale else 1)
+ if self.predict_color and pc.features_padded() is not None: # normalize color, not point locations
+ colors = self.normalize(pc.features_padded()) if normalize else pc.features_padded()
+ return torch.cat((points, colors), dim=2)
+ else:
+ return points
+
+ def tensor_to_point_cloud(self, x: Tensor, /, denormalize: bool = False, unscale: bool = False):
+ points = x[:, :, :3] / (self.scale_factor if unscale else 1)
+ if self.predict_color:
+ colors = self.denormalize(x[:, :, 3:]) if denormalize else x[:, :, 3:]
+ return Pointclouds(points=points, features=colors)
+ else:
+ assert x.shape[2] == 3
+ return Pointclouds(points=points)
+
+ def get_input_with_conditioning(
+ self,
+ x_t: Tensor,
+ camera: Optional[CamerasBase],
+ image_rgb: Optional[Tensor],
+ mask: Optional[Tensor],
+ t: Optional[Tensor],
+ ):
+ """ Extracts local features from the input image and projects them onto the points
+ in the point cloud to obtain the input to the model. Then extracts global
+ features, replicates them across points, and concats them to the input.
+ image_rgb: image with the background masked out
+ XH: why is there no positional encoding as described in the supp??
+ """
+ B, N = x_t.shape[:2]
+
+ # Initial input is the point locations (and colors if and only if predicting color)
+ x_t_input = self.get_coord_feature(x_t)
+
+ # Local conditioning
+ if self.use_local_conditioning:
+
+ # Get local features and check that they are the same size as the input image
+ local_features = self.get_local_conditioning(image_rgb=image_rgb, mask=mask) # concatenate RGB + mask + RGB feature + distance transform
+ if local_features.shape[-2:] != image_rgb.shape[-2:]:
+ raise ValueError(f'{local_features.shape=} and {image_rgb.shape=}')
+
+ # Project local features. Note that we only need the point locations here, not colors
+ local_features_proj = self.surface_projection(points=x_t[:, :, :3],
+ camera=camera, local_features=local_features) # (B, N, D_local)
+
+ x_t_input.append(local_features_proj)
+
+ # Global conditioning
+ if self.use_global_conditioning: # False
+
+ # Get and repeat global features
+ global_features = self.get_global_conditioning(image_rgb=image_rgb) # (B, D_global)
+ global_features = global_features.unsqueeze(1).expand(-1, N, -1) # (B, N, D_global)
+
+ x_t_input.append(global_features)
+
+ # Concatenate together all the pointwise features
+ x_t_input = torch.cat(x_t_input, dim=2) # (B, N, D)
+
+ return x_t_input
+
+ def get_coord_feature(self, x_t):
+ """get the coordinate feature; for models that use a separate network to predict the binary label, only the first 3 channels are used"""
+ x_t_input = [x_t]
+ return x_t_input
+
+ def forward(self, batch: FrameData, mode: str = 'train', **kwargs):
+ """ The forward method may be defined differently for different models. """
+ raise NotImplementedError()
diff --git a/model/pvcnn/__init__.py b/model/pvcnn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/pvcnn/modules/__init__.py b/model/pvcnn/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3975207ca20b36c9fcd5f4c4c137eae4852d6b9a
--- /dev/null
+++ b/model/pvcnn/modules/__init__.py
@@ -0,0 +1,8 @@
+from .ball_query import BallQuery, BallQueryHO
+from .frustum import FrustumPointNetLoss
+from .loss import KLLoss
+from .pointnet import PointNetAModule, PointNetSAModule, PointNetFPModule
+from .pvconv import PVConv, Attention, Swish, PVConvReLU
+from .se import SE3d
+from .shared_mlp import SharedMLP
+from .voxelization import Voxelization
diff --git a/model/pvcnn/modules/ball_query.py b/model/pvcnn/modules/ball_query.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ace370968cc273c2f310d8c6200d2b9a0a8bea4
--- /dev/null
+++ b/model/pvcnn/modules/ball_query.py
@@ -0,0 +1,69 @@
+import torch
+import torch.nn as nn
+
+from . import functional as F
+
+__all__ = ['BallQuery']
+
+
+class BallQuery(nn.Module):
+ def __init__(self, radius, num_neighbors, include_coordinates=True):
+ super().__init__()
+ self.radius = radius
+ self.num_neighbors = num_neighbors
+ self.include_coordinates = include_coordinates
+
+ def forward(self, points_coords, centers_coords, temb, points_features=None):
+ points_coords = points_coords.contiguous()
+ centers_coords = centers_coords.contiguous()
+ neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors)
+ neighbor_coordinates = F.grouping(points_coords, neighbor_indices)
+ neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1)
+
+ if points_features is None:
+ assert self.include_coordinates, 'No Features For Grouping'
+ neighbor_features = neighbor_coordinates
+ else:
+ neighbor_features = F.grouping(points_features, neighbor_indices) # return [B, C, M, U] C=feat dim, M=# centers, U=# neighbours
+ if self.include_coordinates:
+ neighbor_features = torch.cat([neighbor_coordinates, neighbor_features], dim=1)
+ return neighbor_features, F.grouping(temb, neighbor_indices)
+
+ def extra_repr(self):
+ return 'radius={}, num_neighbors={}{}'.format(
+ self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '')
+
+
+class BallQueryHO(nn.Module):
+ "no point features, only absolute and (optionally) relative coordinates"
+ def __init__(self, radius, num_neighbors, include_relative=False):
+ super().__init__()
+ self.radius = radius
+ self.num_neighbors = num_neighbors
+ self.include_relative = include_relative
+
+ def forward(self, points_coords, centers_coords, points_features=None):
+ """
+ if there are not enough points inside the given radius, the remaining entries are zero
+ if there are too many points inside the radius, the selection order is undefined (not sure)
+ :param points_coords: (B, 3, N)
+ :param centers_coords: (B, 3, M)
+ :param points_features: None
+ :return:
+ """
+ points_coords = points_coords.contiguous()
+ centers_coords = centers_coords.contiguous()
+ neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors)
+ neighbor_coordinates = F.grouping(points_coords, neighbor_indices) # (B, 3, M, U)
+ if self.include_relative:
+ neighbor_coordinates_rela = neighbor_coordinates - centers_coords.unsqueeze(-1)
+ neighbor_coordinates = torch.cat([neighbor_coordinates, neighbor_coordinates_rela], 1) # (B, 6, M, U)
+ # flatten the neighbour dimension into the channel dimension
+ neighbor_coordinates = neighbor_coordinates.permute(0, 1, 3, 2) # (B, 3/6, U, M)
+ neighbor_coordinates = torch.flatten(neighbor_coordinates, 1, 2) # (B, 3*U or 6*U, M)
+ return neighbor_coordinates
+
+ def extra_repr(self):
+ return 'radius={}, num_neighbors={}{}'.format(
+ self.radius, self.num_neighbors, ', include relative' if self.include_relative else '')
+
diff --git a/model/pvcnn/modules/frustum.py b/model/pvcnn/modules/frustum.py
new file mode 100644
index 0000000000000000000000000000000000000000..72c1ba870e024079081e308cad7a8521de952bc3
--- /dev/null
+++ b/model/pvcnn/modules/frustum.py
@@ -0,0 +1,138 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from . import functional as PF
+
+__all__ = ['FrustumPointNetLoss', 'get_box_corners_3d']
+
+
+class FrustumPointNetLoss(nn.Module):
+ def __init__(self, num_heading_angle_bins, num_size_templates, size_templates, box_loss_weight=1.0,
+ corners_loss_weight=10.0, heading_residual_loss_weight=20.0, size_residual_loss_weight=20.0):
+ super().__init__()
+ self.box_loss_weight = box_loss_weight
+ self.corners_loss_weight = corners_loss_weight
+ self.heading_residual_loss_weight = heading_residual_loss_weight
+ self.size_residual_loss_weight = size_residual_loss_weight
+
+ self.num_heading_angle_bins = num_heading_angle_bins
+ self.num_size_templates = num_size_templates
+ self.register_buffer('size_templates', size_templates.view(self.num_size_templates, 3))
+ self.register_buffer(
+ 'heading_angle_bin_centers', torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins)
+ )
+
+ def forward(self, inputs, targets):
+ mask_logits = inputs['mask_logits'] # (B, 2, N)
+ center_reg = inputs['center_reg'] # (B, 3)
+ center = inputs['center'] # (B, 3)
+ heading_scores = inputs['heading_scores'] # (B, NH)
+ heading_residuals_normalized = inputs['heading_residuals_normalized'] # (B, NH)
+ heading_residuals = inputs['heading_residuals'] # (B, NH)
+ size_scores = inputs['size_scores'] # (B, NS)
+ size_residuals_normalized = inputs['size_residuals_normalized'] # (B, NS, 3)
+ size_residuals = inputs['size_residuals'] # (B, NS, 3)
+
+ mask_logits_target = targets['mask_logits'] # (B, N)
+ center_target = targets['center'] # (B, 3)
+ heading_bin_id_target = targets['heading_bin_id'] # (B, )
+ heading_residual_target = targets['heading_residual'] # (B, )
+ size_template_id_target = targets['size_template_id'] # (B, )
+ size_residual_target = targets['size_residual'] # (B, 3)
+
+ batch_size = center.size(0)
+ batch_id = torch.arange(batch_size, device=center.device)
+
+ # Basic Classification and Regression losses
+ mask_loss = F.cross_entropy(mask_logits, mask_logits_target)
+ heading_loss = F.cross_entropy(heading_scores, heading_bin_id_target)
+ size_loss = F.cross_entropy(size_scores, size_template_id_target)
+ center_loss = PF.huber_loss(torch.norm(center_target - center, dim=-1), delta=2.0)
+ center_reg_loss = PF.huber_loss(torch.norm(center_target - center_reg, dim=-1), delta=1.0)
+
+ # Refinement losses for size/heading
+ heading_residuals_normalized = heading_residuals_normalized[batch_id, heading_bin_id_target] # (B, )
+ heading_residual_normalized_target = heading_residual_target / (np.pi / self.num_heading_angle_bins)
+ heading_residual_normalized_loss = PF.huber_loss(
+ heading_residuals_normalized - heading_residual_normalized_target, delta=1.0
+ )
+ size_residuals_normalized = size_residuals_normalized[batch_id, size_template_id_target] # (B, 3)
+ size_residual_normalized_target = size_residual_target / self.size_templates[size_template_id_target]
+ size_residual_normalized_loss = PF.huber_loss(
+ torch.norm(size_residual_normalized_target - size_residuals_normalized, dim=-1), delta=1.0
+ )
+
+ # Bounding box losses
+ heading = (heading_residuals[batch_id, heading_bin_id_target]
+ + self.heading_angle_bin_centers[heading_bin_id_target]) # (B, )
+ # Warning: in the original code, size_residuals are added twice (issues #43 and #49 in charlesq34/frustum-pointnets)
+ size = (size_residuals[batch_id, size_template_id_target]
+ + self.size_templates[size_template_id_target]) # (B, 3)
+ corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False) # (B, 3, 8)
+ heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target # (B, )
+ size_target = self.size_templates[size_template_id_target] + size_residual_target # (B, 3)
+ corners_target, corners_target_flip = get_box_corners_3d(centers=center_target, headings=heading_target,
+ sizes=size_target, with_flip=True) # (B, 3, 8)
+ corners_loss = PF.huber_loss(torch.min(
+ torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1)
+ ), delta=1.0)
+ # Summing up
+ loss = mask_loss + self.box_loss_weight * (
+ center_loss + center_reg_loss + heading_loss + size_loss
+ + self.heading_residual_loss_weight * heading_residual_normalized_loss
+ + self.size_residual_loss_weight * size_residual_normalized_loss
+ + self.corners_loss_weight * corners_loss
+ )
+
+ return loss
+
+
+def get_box_corners_3d(centers, headings, sizes, with_flip=False):
+ """
+ :param centers: coords of box centers, FloatTensor[N, 3]
+ :param headings: heading angles, FloatTensor[N, ]
+ :param sizes: box sizes, FloatTensor[N, 3]
+ :param with_flip: bool, whether to return flipped box (headings + np.pi)
+ :return:
+ coords of box corners, FloatTensor[N, 3, 8]
+ NOTE: corner points are in counter clockwise order, e.g.,
+ 2--1
+ 3--0 5
+ 7--4
+ """
+ l = sizes[:, 0] # (N,)
+ w = sizes[:, 1] # (N,)
+ h = sizes[:, 2] # (N,)
+ x_corners = torch.stack([l/2, l/2, -l/2, -l/2, l/2, l/2, -l/2, -l/2], dim=1) # (N, 8)
+ y_corners = torch.stack([h/2, h/2, h/2, h/2, -h/2, -h/2, -h/2, -h/2], dim=1) # (N, 8)
+ z_corners = torch.stack([w/2, -w/2, -w/2, w/2, w/2, -w/2, -w/2, w/2], dim=1) # (N, 8)
+
+ c = torch.cos(headings) # (N,)
+ s = torch.sin(headings) # (N,)
+ o = torch.ones_like(headings) # (N,)
+ z = torch.zeros_like(headings) # (N,)
+
+ centers = centers.unsqueeze(-1) # (B, 3, 1)
+ corners = torch.stack([x_corners, y_corners, z_corners], dim=1) # (N, 3, 8)
+ R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3) # roty matrix: (N, 3, 3)
+ if with_flip:
+ R_flip = torch.stack([-c, z, -s, z, o, z, s, z, -c], dim=1).view(-1, 3, 3)
+ return torch.matmul(R, corners) + centers, torch.matmul(R_flip, corners) + centers
+ else:
+ return torch.matmul(R, corners) + centers
+
+ # centers = centers.unsqueeze(1) # (B, 1, 3)
+ # corners = torch.stack([x_corners, y_corners, z_corners], dim=-1) # (N, 8, 3)
+ # RT = torch.stack([c, z, -s, z, o, z, s, z, c], dim=1).view(-1, 3, 3) # (N, 3, 3)
+ # if with_flip:
+ # RT_flip = torch.stack([-c, z, s, z, o, z, -s, z, -c], dim=1).view(-1, 3, 3) # (N, 3, 3)
+ # return torch.matmul(corners, RT) + centers, torch.matmul(corners, RT_flip) + centers # (N, 8, 3)
+ # else:
+ # return torch.matmul(corners, RT) + centers # (N, 8, 3)
+
+ # corners = torch.stack([x_corners, y_corners, z_corners], dim=1) # (N, 3, 8)
+ # R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3) # (N, 3, 3)
+ # corners = torch.matmul(R, corners) + centers.unsqueeze(2) # (N, 3, 8)
+ # corners = corners.transpose(1, 2) # (N, 8, 3)
diff --git a/model/pvcnn/modules/functional/__init__.py b/model/pvcnn/modules/functional/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b945c5dd1521e7446e5344714e85af1f89f24510
--- /dev/null
+++ b/model/pvcnn/modules/functional/__init__.py
@@ -0,0 +1,7 @@
+from .ball_query import ball_query
+from .devoxelization import trilinear_devoxelize
+from .grouping import grouping
+from .interpolatation import nearest_neighbor_interpolate
+from .loss import kl_loss, huber_loss
+from .sampling import gather, furthest_point_sample, logits_mask
+from .voxelization import avg_voxelize
diff --git a/model/pvcnn/modules/functional/backend.py b/model/pvcnn/modules/functional/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7bb555291b5784cf0e416d23f12b9c397538c83
--- /dev/null
+++ b/model/pvcnn/modules/functional/backend.py
@@ -0,0 +1,33 @@
+import os
+from pathlib import Path
+
+from torch.utils.cpp_extension import load
+
+
+gcc_path = os.getenv('CC', default='/usr/bin/gcc')
+if not Path(gcc_path).is_file():
+ raise ValueError('Could not find your gcc, please replace it here.')
+
+_src_path = os.path.dirname(os.path.abspath(__file__))
+_backend = load(
+ name='_pvcnn_backend',
+ extra_cflags=['-O3', '-std=c++17'],
+ extra_cuda_cflags=[f'--compiler-bindir={gcc_path}'],
+ sources=[os.path.join(_src_path,'src', f) for f in [
+ 'ball_query/ball_query.cpp',
+ 'ball_query/ball_query.cu',
+ 'grouping/grouping.cpp',
+ 'grouping/grouping.cu',
+ 'interpolate/neighbor_interpolate.cpp',
+ 'interpolate/neighbor_interpolate.cu',
+ 'interpolate/trilinear_devox.cpp',
+ 'interpolate/trilinear_devox.cu',
+ 'sampling/sampling.cpp',
+ 'sampling/sampling.cu',
+ 'voxelization/vox.cpp',
+ 'voxelization/vox.cu',
+ 'bindings.cpp',
+ ]]
+)
+
+__all__ = ['_backend']
diff --git a/model/pvcnn/modules/functional/ball_query.py b/model/pvcnn/modules/functional/ball_query.py
new file mode 100644
index 0000000000000000000000000000000000000000..4daef84d172a4c79d4657a7905aca6f3a8b02cc1
--- /dev/null
+++ b/model/pvcnn/modules/functional/ball_query.py
@@ -0,0 +1,19 @@
+from torch.autograd import Function
+
+from .backend import _backend
+
+__all__ = ['ball_query']
+
+
+def ball_query(centers_coords, points_coords, radius, num_neighbors):
+ """
+ :param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
+ :param points_coords: coordinates of points, FloatTensor[B, 3, N]
+ :param radius: float, radius of ball query
+ :param num_neighbors: int, maximum number of neighbors
+ :return:
+ neighbor_indices: indices of neighbors, IntTensor[B, M, U]
+ """
+ centers_coords = centers_coords.contiguous()
+ points_coords = points_coords.contiguous()
+ return _backend.ball_query(centers_coords, points_coords, radius, num_neighbors)
diff --git a/model/pvcnn/modules/functional/devoxelization.py b/model/pvcnn/modules/functional/devoxelization.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9dd014e393a468b8b5b40b8654860fc7ab87eea
--- /dev/null
+++ b/model/pvcnn/modules/functional/devoxelization.py
@@ -0,0 +1,42 @@
+from torch.autograd import Function
+
+from .backend import _backend
+
+__all__ = ['trilinear_devoxelize']
+
+
+class TrilinearDevoxelization(Function):
+ @staticmethod
+ def forward(ctx, features, coords, resolution, is_training=True):
+ """
+ :param ctx:
+ :param coords: the coordinates of points, FloatTensor[B, 3, N]
+ :param features: FloatTensor[B, C, R, R, R]
+ :param resolution: int, the voxel resolution
+ :param is_training: bool, training mode
+ :return:
+ FloatTensor[B, C, N]
+ """
+ B, C = features.shape[:2]
+ features = features.contiguous().view(B, C, -1)
+ coords = coords.contiguous()
+ outs, inds, wgts = _backend.trilinear_devoxelize_forward(resolution, is_training, coords, features)
+ if is_training:
+ ctx.save_for_backward(inds, wgts)
+ ctx.r = resolution
+ return outs
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ """
+ :param ctx:
+ :param grad_output: gradient of outputs, FloatTensor[B, C, N]
+ :return:
+ gradient of inputs, FloatTensor[B, C, R, R, R]
+ """
+ inds, wgts = ctx.saved_tensors
+ grad_inputs = _backend.trilinear_devoxelize_backward(grad_output.contiguous(), inds, wgts, ctx.r)
+ return grad_inputs.view(grad_output.size(0), grad_output.size(1), ctx.r, ctx.r, ctx.r), None, None, None
+
+
+trilinear_devoxelize = TrilinearDevoxelization.apply
diff --git a/model/pvcnn/modules/functional/grouping.py b/model/pvcnn/modules/functional/grouping.py
new file mode 100644
index 0000000000000000000000000000000000000000..db19ef7dba8333499e5deaffdbaf3d6fb860f71f
--- /dev/null
+++ b/model/pvcnn/modules/functional/grouping.py
@@ -0,0 +1,32 @@
+from torch.autograd import Function
+
+from .backend import _backend
+
+__all__ = ['grouping']
+
+
+class Grouping(Function):
+ @staticmethod
+ def forward(ctx, features, indices):
+ """
+ :param ctx:
+ :param features: features of points, FloatTensor[B, C, N]
+ :param indices: neighbor indices of centers, IntTensor[B, M, U], M is #centers, U is #neighbors
+ :return:
+ grouped_features: grouped features, FloatTensor[B, C, M, U]
+ """
+ features = features.contiguous()
+ indices = indices.contiguous()
+ ctx.save_for_backward(indices)
+ ctx.num_points = features.size(-1)
+ # print(features.dtype, features.shape)
+ return _backend.grouping_forward(features, indices)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ indices, = ctx.saved_tensors
+ grad_features = _backend.grouping_backward(grad_output.contiguous(), indices, ctx.num_points)
+ return grad_features, None
+
+
+grouping = Grouping.apply
diff --git a/model/pvcnn/modules/functional/interpolatation.py b/model/pvcnn/modules/functional/interpolatation.py
new file mode 100644
index 0000000000000000000000000000000000000000..a83378d0f18901d2ced6ba3acb1691a1d1bc20b8
--- /dev/null
+++ b/model/pvcnn/modules/functional/interpolatation.py
@@ -0,0 +1,38 @@
+from torch.autograd import Function
+
+from .backend import _backend
+
+__all__ = ['nearest_neighbor_interpolate']
+
+
+class NeighborInterpolation(Function):
+ @staticmethod
+ def forward(ctx, points_coords, centers_coords, centers_features):
+ """
+ :param ctx:
+ :param points_coords: coordinates of points, FloatTensor[B, 3, N]
+ :param centers_coords: coordinates of centers, FloatTensor[B, 3, M]
+ :param centers_features: features of centers, FloatTensor[B, C, M]
+ :return:
+ points_features: features of points, FloatTensor[B, C, N]
+ """
+ centers_coords = centers_coords.contiguous()
+ points_coords = points_coords.contiguous()
+ centers_features = centers_features.contiguous()
+ points_features, indices, weights = _backend.three_nearest_neighbors_interpolate_forward(
+ points_coords, centers_coords, centers_features
+ )
+ ctx.save_for_backward(indices, weights)
+ ctx.num_centers = centers_coords.size(-1)
+ return points_features
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ indices, weights = ctx.saved_tensors
+ grad_centers_features = _backend.three_nearest_neighbors_interpolate_backward(
+ grad_output.contiguous(), indices, weights, ctx.num_centers
+ )
+ return None, None, grad_centers_features
+
+
+nearest_neighbor_interpolate = NeighborInterpolation.apply
diff --git a/model/pvcnn/modules/functional/loss.py b/model/pvcnn/modules/functional/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..41112b3271c6f340047e8eb874b70ee630597b32
--- /dev/null
+++ b/model/pvcnn/modules/functional/loss.py
@@ -0,0 +1,17 @@
+import torch
+import torch.nn.functional as F
+
+__all__ = ['kl_loss', 'huber_loss']
+
+
+def kl_loss(x, y):
+ x = F.softmax(x.detach(), dim=1)
+ y = F.log_softmax(y, dim=1)
+ return torch.mean(torch.sum(x * (torch.log(x) - y), dim=1))
+
+
+def huber_loss(error, delta):
+ abs_error = torch.abs(error)
+ quadratic = torch.min(abs_error, torch.full_like(abs_error, fill_value=delta))
+ losses = 0.5 * (quadratic ** 2) + delta * (abs_error - quadratic)
+ return torch.mean(losses)
diff --git a/model/pvcnn/modules/functional/sampling.py b/model/pvcnn/modules/functional/sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..b08ec63446fad34dc947bd966de12e6c4161733b
--- /dev/null
+++ b/model/pvcnn/modules/functional/sampling.py
@@ -0,0 +1,84 @@
+import numpy as np
+import torch
+from torch.autograd import Function
+
+from .backend import _backend
+
+__all__ = ['gather', 'furthest_point_sample', 'logits_mask']
+
+
+class Gather(Function):
+ @staticmethod
+ def forward(ctx, features, indices):
+ """
+ Gather
+ :param ctx:
+ :param features: features of points, FloatTensor[B, C, N]
+ :param indices: centers' indices in points, IntTensor[b, m]
+ :return:
+ centers_coords: coordinates of sampled centers, FloatTensor[B, C, M]
+ """
+ features = features.contiguous()
+ indices = indices.int().contiguous()
+ ctx.save_for_backward(indices)
+ ctx.num_points = features.size(-1)
+ return _backend.gather_features_forward(features, indices)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ indices, = ctx.saved_tensors
+ grad_features = _backend.gather_features_backward(grad_output.contiguous(), indices, ctx.num_points)
+ return grad_features, None
+
+
+gather = Gather.apply
+
+
+def furthest_point_sample(coords, num_samples):
+ """
+ Uses iterative furthest point sampling to select num_samples points that have the largest
+ minimum distance to the already selected point set
+ :param coords: coordinates of points, FloatTensor[B, 3, N]
+ :param num_samples: int, M
+ :return:
+ centers_coords: coordinates of sampled centers, FloatTensor[B, 3, M]
+ """
+ coords = coords.contiguous()
+ indices = _backend.furthest_point_sampling(coords, num_samples)
+ return gather(coords, indices)
+
+
+def logits_mask(coords, logits, num_points_per_object):
+ """
+ Use logits to sample points
+ :param coords: coords of points, FloatTensor[B, 3, N]
+ :param logits: binary classification logits, FloatTensor[B, 2, N]
+ :param num_points_per_object: M, #points per object after masking, int
+ :return:
+ selected_coords: FloatTensor[B, 3, M]
+ masked_coords_mean: mean coords of selected points, FloatTensor[B, 3]
+ mask: mask to select points, BoolTensor[B, N]
+ """
+ batch_size, _, num_points = coords.shape
+ mask = torch.lt(logits[:, 0, :], logits[:, 1, :]) # [B, N]
+ num_candidates = torch.sum(mask, dim=-1, keepdim=True) # [B, 1]
+ masked_coords = coords * mask.view(batch_size, 1, num_points) # [B, C, N]
+ masked_coords_mean = torch.sum(masked_coords, dim=-1) / torch.max(num_candidates,
+ torch.ones_like(num_candidates)).float() # [B, C]
+ selected_indices = torch.zeros((batch_size, num_points_per_object), device=coords.device, dtype=torch.int32)
+ for i in range(batch_size):
+ current_mask = mask[i] # [N]
+ current_candidates = current_mask.nonzero().view(-1)
+ current_num_candidates = current_candidates.numel()
+ if current_num_candidates >= num_points_per_object:
+ choices = np.random.choice(current_num_candidates, num_points_per_object, replace=False)
+ selected_indices[i] = current_candidates[choices]
+ elif current_num_candidates > 0:
+ choices = np.concatenate([
+ np.arange(current_num_candidates).repeat(num_points_per_object // current_num_candidates),
+ np.random.choice(current_num_candidates, num_points_per_object % current_num_candidates, replace=False)
+ ])
+ np.random.shuffle(choices)
+ selected_indices[i] = current_candidates[choices]
+ selected_coords = gather(masked_coords - masked_coords_mean.view(batch_size, -1, 1), selected_indices)
+ return selected_coords, masked_coords_mean, mask
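+
+
+if __name__ == '__main__':
+    # Usage sketch (illustrative only; assumes a CUDA device and a compiled
+    # _pvcnn_backend extension). furthest_point_sample picks num_samples
+    # well-spread points from the cloud; gather collects coordinates or
+    # features by index.
+    coords = torch.rand(2, 3, 4096).cuda()        # [B, 3, N]
+    centers = furthest_point_sample(coords, 512)  # [B, 3, 512]
+    print(centers.shape)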
diff --git a/model/pvcnn/modules/functional/src/ball_query/ball_query.cpp b/model/pvcnn/modules/functional/src/ball_query/ball_query.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5ae1fb6f6ada981c75c5dc8e986ae2d903bf61c2
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/ball_query/ball_query.cpp
@@ -0,0 +1,30 @@
+#include "ball_query.hpp"
+#include "ball_query.cuh"
+
+#include "../utils.hpp"
+
+at::Tensor ball_query_forward(at::Tensor centers_coords,
+ at::Tensor points_coords, const float radius,
+ const int num_neighbors) {
+ CHECK_CUDA(centers_coords);
+ CHECK_CUDA(points_coords);
+ CHECK_CONTIGUOUS(centers_coords);
+ CHECK_CONTIGUOUS(points_coords);
+ CHECK_IS_FLOAT(centers_coords);
+ CHECK_IS_FLOAT(points_coords);
+
+ int b = centers_coords.size(0);
+ int m = centers_coords.size(2);
+ int n = points_coords.size(2);
+
+ at::Tensor neighbors_indices = torch::zeros(
+ {b, m, num_neighbors},
+ at::device(centers_coords.device()).dtype(at::ScalarType::Int));
+
+ ball_query(b, n, m, radius * radius, num_neighbors,
+ centers_coords.data_ptr<float>(),
+ points_coords.data_ptr<float>(),
+ neighbors_indices.data_ptr<int>());
+
+ return neighbors_indices;
+}
diff --git a/model/pvcnn/modules/functional/src/ball_query/ball_query.cu b/model/pvcnn/modules/functional/src/ball_query/ball_query.cu
new file mode 100644
index 0000000000000000000000000000000000000000..079e3cb86cb2ee51bb1326b1c7ecb1e70c8a4ae8
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/ball_query/ball_query.cu
@@ -0,0 +1,59 @@
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "../cuda_utils.cuh"
+
+/*
+ Function: ball query
+ Args:
+ b : batch size
+ n : number of points in point clouds
+ m : number of query centers
+ r2 : ball query radius ** 2
+ u : maximum number of neighbors
+ centers_coords: coordinates of centers, FloatTensor[b, 3, m]
+ points_coords : coordinates of points, FloatTensor[b, 3, n]
+ neighbors_indices : neighbor indices in points, IntTensor[b, m, u]
+*/
+__global__ void ball_query_kernel(int b, int n, int m, float r2, int u,
+ const float *__restrict__ centers_coords,
+ const float *__restrict__ points_coords,
+ int *__restrict__ neighbors_indices) {
+ int batch_index = blockIdx.x;
+ int index = threadIdx.x;
+ int stride = blockDim.x;
+ points_coords += batch_index * n * 3;
+ centers_coords += batch_index * m * 3;
+ neighbors_indices += batch_index * m * u;
+
+ for (int j = index; j < m; j += stride) {
+ float center_x = centers_coords[j];
+ float center_y = centers_coords[j + m];
+ float center_z = centers_coords[j + m + m];
+ for (int k = 0, cnt = 0; k < n && cnt < u; ++k) {
+ float dx = center_x - points_coords[k];
+ float dy = center_y - points_coords[k + n];
+ float dz = center_z - points_coords[k + n + n];
+ float d2 = dx * dx + dy * dy + dz * dz;
+ if (d2 < r2) {
+ if (cnt == 0) {
+ for (int v = 0; v < u; ++v) {
+ neighbors_indices[j * u + v] = k;
+ }
+ }
+ neighbors_indices[j * u + cnt] = k;
+ ++cnt;
+ }
+ }
+ }
+}
+
+void ball_query(int b, int n, int m, float r2, int u,
+ const float *centers_coords, const float *points_coords,
+ int *neighbors_indices) {
+ ball_query_kernel<<<
+ b, optimal_num_threads(m), 0, at::cuda::getCurrentCUDAStream()>>>(
+ b, n, m, r2, u, centers_coords, points_coords, neighbors_indices);
+ CUDA_CHECK_ERRORS();
+}
diff --git a/model/pvcnn/modules/functional/src/ball_query/ball_query.cuh b/model/pvcnn/modules/functional/src/ball_query/ball_query.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..ba32492f17733007d4f0cbfef327c74cdd35dd0b
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/ball_query/ball_query.cuh
@@ -0,0 +1,8 @@
+#ifndef _BALL_QUERY_CUH
+#define _BALL_QUERY_CUH
+
+void ball_query(int b, int n, int m, float r2, int u,
+ const float *centers_coords, const float *points_coords,
+ int *neighbors_indices);
+
+#endif
diff --git a/model/pvcnn/modules/functional/src/ball_query/ball_query.hpp b/model/pvcnn/modules/functional/src/ball_query/ball_query.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..d87bbd93258a8a4ee08ae20b03a544a2d79ebb68
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/ball_query/ball_query.hpp
@@ -0,0 +1,10 @@
+#ifndef _BALL_QUERY_HPP
+#define _BALL_QUERY_HPP
+
+#include <torch/extension.h>
+
+at::Tensor ball_query_forward(at::Tensor centers_coords,
+ at::Tensor points_coords, const float radius,
+ const int num_neighbors);
+
+#endif
diff --git a/model/pvcnn/modules/functional/src/bindings.cpp b/model/pvcnn/modules/functional/src/bindings.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..994e01b5861fe20269e26ae49c2e00cadbef2b01
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/bindings.cpp
@@ -0,0 +1,37 @@
+#include <pybind11/pybind11.h>
+
+#include "ball_query/ball_query.hpp"
+#include "grouping/grouping.hpp"
+#include "interpolate/neighbor_interpolate.hpp"
+#include "interpolate/trilinear_devox.hpp"
+#include "sampling/sampling.hpp"
+#include "voxelization/vox.hpp"
+
+PYBIND11_MODULE(_pvcnn_backend, m) {
+ m.def("gather_features_forward", &gather_features_forward,
+ "Gather Centers' Features forward (CUDA)");
+ m.def("gather_features_backward", &gather_features_backward,
+ "Gather Centers' Features backward (CUDA)");
+ m.def("furthest_point_sampling", &furthest_point_sampling_forward,
+ "Furthest Point Sampling (CUDA)");
+ m.def("ball_query", &ball_query_forward, "Ball Query (CUDA)");
+ m.def("grouping_forward", &grouping_forward,
+ "Grouping Features forward (CUDA)");
+ m.def("grouping_backward", &grouping_backward,
+ "Grouping Features backward (CUDA)");
+ m.def("three_nearest_neighbors_interpolate_forward",
+ &three_nearest_neighbors_interpolate_forward,
+ "3 Nearest Neighbors Interpolate forward (CUDA)");
+ m.def("three_nearest_neighbors_interpolate_backward",
+ &three_nearest_neighbors_interpolate_backward,
+ "3 Nearest Neighbors Interpolate backward (CUDA)");
+
+ m.def("trilinear_devoxelize_forward", &trilinear_devoxelize_forward,
+ "Trilinear Devoxelization forward (CUDA)");
+ m.def("trilinear_devoxelize_backward", &trilinear_devoxelize_backward,
+ "Trilinear Devoxelization backward (CUDA)");
+ m.def("avg_voxelize_forward", &avg_voxelize_forward,
+ "Voxelization forward with average pooling (CUDA)");
+ m.def("avg_voxelize_backward", &avg_voxelize_backward,
+ "Voxelization backward (CUDA)");
+}
diff --git a/model/pvcnn/modules/functional/src/cuda_utils.cuh b/model/pvcnn/modules/functional/src/cuda_utils.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..01bf5512914af1a0265697990a21eb760c14663c
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/cuda_utils.cuh
@@ -0,0 +1,39 @@
+#ifndef _CUDA_UTILS_H
+#define _CUDA_UTILS_H
+
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <vector>
+
+#define MAXIMUM_THREADS 512
+
+inline int optimal_num_threads(int work_size) {
+ const int pow_2 = std::log2(static_cast<double>(work_size));
+ return max(min(1 << pow_2, MAXIMUM_THREADS), 1);
+}
+
+inline dim3 optimal_block_config(int x, int y) {
+ const int x_threads = optimal_num_threads(x);
+ const int y_threads =
+ max(min(optimal_num_threads(y), MAXIMUM_THREADS / x_threads), 1);
+ dim3 block_config(x_threads, y_threads, 1);
+ return block_config;
+}
+
+#define CUDA_CHECK_ERRORS() \
+ { \
+ cudaError_t err = cudaGetLastError(); \
+ if (cudaSuccess != err) { \
+ fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
+ cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
+ __FILE__); \
+ exit(-1); \
+ } \
+ }
+
+#endif
diff --git a/model/pvcnn/modules/functional/src/grouping/grouping.cpp b/model/pvcnn/modules/functional/src/grouping/grouping.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4f97650069f3961946e16b68722893e4bb65cb26
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/grouping/grouping.cpp
@@ -0,0 +1,44 @@
+#include "grouping.hpp"
+#include "grouping.cuh"
+
+#include "../utils.hpp"
+
+at::Tensor grouping_forward(at::Tensor features, at::Tensor indices) {
+ CHECK_CUDA(features);
+ CHECK_CUDA(indices);
+ CHECK_CONTIGUOUS(features);
+ CHECK_CONTIGUOUS(indices);
+ CHECK_IS_FLOAT(features);
+ CHECK_IS_INT(indices);
+
+ int b = features.size(0);
+ int c = features.size(1);
+ int n = features.size(2);
+ int m = indices.size(1);
+ int u = indices.size(2);
+ at::Tensor output = torch::zeros(
+ {b, c, m, u}, at::device(features.device()).dtype(at::ScalarType::Float));
+ grouping(b, c, n, m, u, features.data_ptr<float>(), indices.data_ptr<int>(),
+ output.data_ptr<float>());
+ return output;
+}
+
+at::Tensor grouping_backward(at::Tensor grad_y, at::Tensor indices,
+ const int n) {
+ CHECK_CUDA(grad_y);
+ CHECK_CUDA(indices);
+ CHECK_CONTIGUOUS(grad_y);
+ CHECK_CONTIGUOUS(indices);
+ CHECK_IS_FLOAT(grad_y);
+ CHECK_IS_INT(indices);
+
+ int b = grad_y.size(0);
+ int c = grad_y.size(1);
+ int m = indices.size(1);
+ int u = indices.size(2);
+ at::Tensor grad_x = torch::zeros(
+ {b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
+ grouping_grad(b, c, n, m, u, grad_y.data_ptr<float>(),
+ indices.data_ptr<int>(), grad_x.data_ptr<float>());
+ return grad_x;
+}
diff --git a/model/pvcnn/modules/functional/src/grouping/grouping.cu b/model/pvcnn/modules/functional/src/grouping/grouping.cu
new file mode 100644
index 0000000000000000000000000000000000000000..0cf561a1e94874d968f575e0f63170069c05fa80
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/grouping/grouping.cu
@@ -0,0 +1,85 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "../cuda_utils.cuh"
+
+/*
+ Function: grouping features of neighbors (forward)
+ Args:
+ b : batch size
+ c : #channels of features
+ n : number of points in point clouds
+ m : number of query centers
+ u : maximum number of neighbors
+ features: points' features, FloatTensor[b, c, n]
+ indices : neighbor indices in points, IntTensor[b, m, u]
+ out : gathered features, FloatTensor[b, c, m, u]
+*/
+__global__ void grouping_kernel(int b, int c, int n, int m, int u,
+ const float *__restrict__ features,
+ const int *__restrict__ indices,
+ float *__restrict__ out) {
+ int batch_index = blockIdx.x;
+ features += batch_index * n * c;
+ indices += batch_index * m * u;
+ out += batch_index * m * u * c;
+
+ const int index = threadIdx.y * blockDim.x + threadIdx.x;
+ const int stride = blockDim.y * blockDim.x;
+ for (int i = index; i < c * m; i += stride) {
+ const int l = i / m;
+ const int j = i % m;
+ for (int k = 0; k < u; ++k) {
+ out[(l * m + j) * u + k] = features[l * n + indices[j * u + k]];
+ }
+ }
+}
+
+void grouping(int b, int c, int n, int m, int u, const float *features,
+ const int *indices, float *out) {
+ grouping_kernel<<<
+ b, optimal_block_config(m, c), 0, at::cuda::getCurrentCUDAStream()>>>(
+ b, c, n, m, u, features, indices, out);
+ CUDA_CHECK_ERRORS();
+}
+
+/*
+ Function: grouping features of neighbors (backward)
+ Args:
+ b : batch size
+ c : #channels of features
+ n : number of points in point clouds
+ m : number of query centers
+ u : maximum number of neighbors
+ grad_y : grad of gathered features, FloatTensor[b, c, m, u]
+ indices : neighbor indices in points, IntTensor[b, m, u]
+ grad_x: grad of points' features, FloatTensor[b, c, n]
+*/
+__global__ void grouping_grad_kernel(int b, int c, int n, int m, int u,
+ const float *__restrict__ grad_y,
+ const int *__restrict__ indices,
+ float *__restrict__ grad_x) {
+ int batch_index = blockIdx.x;
+ grad_y += batch_index * m * u * c;
+ indices += batch_index * m * u;
+ grad_x += batch_index * n * c;
+
+ const int index = threadIdx.y * blockDim.x + threadIdx.x;
+ const int stride = blockDim.y * blockDim.x;
+ for (int i = index; i < c * m; i += stride) {
+ const int l = i / m;
+ const int j = i % m;
+ for (int k = 0; k < u; ++k) {
+ atomicAdd(grad_x + l * n + indices[j * u + k],
+ grad_y[(l * m + j) * u + k]);
+ }
+ }
+}
+
+void grouping_grad(int b, int c, int n, int m, int u, const float *grad_y,
+ const int *indices, float *grad_x) {
+ grouping_grad_kernel<<>>(
+ b, c, n, m, u, grad_y, indices, grad_x);
+ CUDA_CHECK_ERRORS();
+}
diff --git a/model/pvcnn/modules/functional/src/grouping/grouping.cuh b/model/pvcnn/modules/functional/src/grouping/grouping.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..c8a114feda3e63bd0cbdbd271bb0c476181b65b1
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/grouping/grouping.cuh
@@ -0,0 +1,9 @@
+#ifndef _GROUPING_CUH
+#define _GROUPING_CUH
+
+void grouping(int b, int c, int n, int m, int u, const float *features,
+ const int *indices, float *out);
+void grouping_grad(int b, int c, int n, int m, int u, const float *grad_y,
+ const int *indices, float *grad_x);
+
+#endif
\ No newline at end of file
diff --git a/model/pvcnn/modules/functional/src/grouping/grouping.hpp b/model/pvcnn/modules/functional/src/grouping/grouping.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..3f5733d2f53bd7a02b2b16cb1f0303af84cd6c7e
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/grouping/grouping.hpp
@@ -0,0 +1,10 @@
+#ifndef _GROUPING_HPP
+#define _GROUPING_HPP
+
+#include <torch/extension.h>
+
+at::Tensor grouping_forward(at::Tensor features, at::Tensor indices);
+at::Tensor grouping_backward(at::Tensor grad_y, at::Tensor indices,
+ const int n);
+
+#endif
diff --git a/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cpp b/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..fc73c43cda3eea7a9bb4635851c9023359d05079
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cpp
@@ -0,0 +1,65 @@
+#include "neighbor_interpolate.hpp"
+#include "neighbor_interpolate.cuh"
+
+#include "../utils.hpp"
+
+std::vector<at::Tensor>
+three_nearest_neighbors_interpolate_forward(at::Tensor points_coords,
+ at::Tensor centers_coords,
+ at::Tensor centers_features) {
+ CHECK_CUDA(points_coords);
+ CHECK_CUDA(centers_coords);
+ CHECK_CUDA(centers_features);
+ CHECK_CONTIGUOUS(points_coords);
+ CHECK_CONTIGUOUS(centers_coords);
+ CHECK_CONTIGUOUS(centers_features);
+ CHECK_IS_FLOAT(points_coords);
+ CHECK_IS_FLOAT(centers_coords);
+ CHECK_IS_FLOAT(centers_features);
+
+ int b = centers_features.size(0);
+ int c = centers_features.size(1);
+ int m = centers_features.size(2);
+ int n = points_coords.size(2);
+
+ at::Tensor indices = torch::zeros(
+ {b, 3, n}, at::device(points_coords.device()).dtype(at::ScalarType::Int));
+ at::Tensor weights = torch::zeros(
+ {b, 3, n},
+ at::device(points_coords.device()).dtype(at::ScalarType::Float));
+ at::Tensor output = torch::zeros(
+ {b, c, n},
+ at::device(centers_features.device()).dtype(at::ScalarType::Float));
+
+ three_nearest_neighbors_interpolate(
+ b, c, m, n, points_coords.data_ptr<float>(),
+ centers_coords.data_ptr<float>(), centers_features.data_ptr<float>(),
+ indices.data_ptr<int>(), weights.data_ptr<float>(),
+ output.data_ptr<float>());
+ return {output, indices, weights};
+}
+
+at::Tensor three_nearest_neighbors_interpolate_backward(at::Tensor grad_y,
+ at::Tensor indices,
+ at::Tensor weights,
+ const int m) {
+ CHECK_CUDA(grad_y);
+ CHECK_CUDA(indices);
+ CHECK_CUDA(weights);
+ CHECK_CONTIGUOUS(grad_y);
+ CHECK_CONTIGUOUS(indices);
+ CHECK_CONTIGUOUS(weights);
+ CHECK_IS_FLOAT(grad_y);
+ CHECK_IS_INT(indices);
+ CHECK_IS_FLOAT(weights);
+
+ int b = grad_y.size(0);
+ int c = grad_y.size(1);
+ int n = grad_y.size(2);
+ at::Tensor grad_x = torch::zeros(
+ {b, c, m}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
+ three_nearest_neighbors_interpolate_grad(
+ b, c, n, m, grad_y.data_ptr<float>(), indices.data_ptr<int>(),
+ weights.data_ptr<float>(), grad_x.data_ptr<float>());
+ return grad_x;
+}
diff --git a/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cu b/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8168507aacc04f2e24b56b7e8d1fc635f169afd6
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cu
@@ -0,0 +1,181 @@
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "../cuda_utils.cuh"
+
+/*
+ Function: three nearest neighbors
+ Args:
+ b : batch size
+ n : number of points in point clouds
+ m : number of query centers
+ points_coords : coordinates of points, FloatTensor[b, 3, n]
+ centers_coords: coordinates of centers, FloatTensor[b, 3, m]
+ weights : weights of nearest 3 centers to the point,
+ FloatTensor[b, 3, n]
+ indices : indices of nearest 3 centers to the point,
+ IntTensor[b, 3, n]
+*/
+__global__ void three_nearest_neighbors_kernel(
+ int b, int n, int m, const float *__restrict__ points_coords,
+ const float *__restrict__ centers_coords, float *__restrict__ weights,
+ int *__restrict__ indices) {
+ int batch_index = blockIdx.x;
+ int index = threadIdx.x;
+ int stride = blockDim.x;
+ points_coords += batch_index * 3 * n;
+ weights += batch_index * 3 * n;
+ indices += batch_index * 3 * n;
+ centers_coords += batch_index * 3 * m;
+
+ for (int j = index; j < n; j += stride) {
+ float ux = points_coords[j];
+ float uy = points_coords[j + n];
+ float uz = points_coords[j + n + n];
+
+ double best0 = 1e40, best1 = 1e40, best2 = 1e40;
+ int besti0 = 0, besti1 = 0, besti2 = 0;
+ for (int k = 0; k < m; ++k) {
+ float x = centers_coords[k];
+ float y = centers_coords[k + m];
+ float z = centers_coords[k + m + m];
+ float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
+ if (d < best2) {
+ best2 = d;
+ besti2 = k;
+ if (d < best1) {
+ best2 = best1;
+ besti2 = besti1;
+ best1 = d;
+ besti1 = k;
+ if (d < best0) {
+ best1 = best0;
+ besti1 = besti0;
+ best0 = d;
+ besti0 = k;
+ }
+ }
+ }
+ }
+ best0 = max(min(1e10f, best0), 1e-10f);
+ best1 = max(min(1e10f, best1), 1e-10f);
+ best2 = max(min(1e10f, best2), 1e-10f);
+ float d0d1 = best0 * best1;
+ float d0d2 = best0 * best2;
+ float d1d2 = best1 * best2;
+ float d0d1d2 = 1.0f / (d0d1 + d0d2 + d1d2);
+ weights[j] = d1d2 * d0d1d2;
+ indices[j] = besti0;
+ weights[j + n] = d0d2 * d0d1d2;
+ indices[j + n] = besti1;
+ weights[j + n + n] = d0d1 * d0d1d2;
+ indices[j + n + n] = besti2;
+ }
+}
+
+/*
+ Function: interpolate three nearest neighbors (forward)
+ Args:
+ b : batch size
+ c : #channels of features
+ m : number of query centers
+ n : number of points in point clouds
+ centers_features: features of centers, FloatTensor[b, c, m]
+ indices : indices of nearest 3 centers to the point,
+ IntTensor[b, 3, n]
+ weights : weights for interpolation, FloatTensor[b, 3, n]
+ out : features of points, FloatTensor[b, c, n]
+*/
+__global__ void three_nearest_neighbors_interpolate_kernel(
+ int b, int c, int m, int n, const float *__restrict__ centers_features,
+ const int *__restrict__ indices, const float *__restrict__ weights,
+ float *__restrict__ out) {
+ int batch_index = blockIdx.x;
+ centers_features += batch_index * m * c;
+ indices += batch_index * n * 3;
+ weights += batch_index * n * 3;
+ out += batch_index * n * c;
+
+ const int index = threadIdx.y * blockDim.x + threadIdx.x;
+ const int stride = blockDim.y * blockDim.x;
+ for (int i = index; i < c * n; i += stride) {
+ const int l = i / n;
+ const int j = i % n;
+ float w1 = weights[j];
+ float w2 = weights[j + n];
+ float w3 = weights[j + n + n];
+ int i1 = indices[j];
+ int i2 = indices[j + n];
+ int i3 = indices[j + n + n];
+
+ out[i] = centers_features[l * m + i1] * w1 +
+ centers_features[l * m + i2] * w2 +
+ centers_features[l * m + i3] * w3;
+ }
+}
+
+void three_nearest_neighbors_interpolate(int b, int c, int m, int n,
+ const float *points_coords,
+ const float *centers_coords,
+ const float *centers_features,
+ int *indices, float *weights,
+ float *out) {
+ three_nearest_neighbors_kernel<<<
+ b, optimal_num_threads(n), 0, at::cuda::getCurrentCUDAStream()>>>(
+ b, n, m, points_coords, centers_coords, weights, indices);
+ three_nearest_neighbors_interpolate_kernel<<<
+ b, optimal_block_config(n, c), 0, at::cuda::getCurrentCUDAStream()>>>(
+ b, c, m, n, centers_features, indices, weights, out);
+ CUDA_CHECK_ERRORS();
+}
+
+/*
+ Function: interpolate three nearest neighbors (backward)
+ Args:
+ b : batch size
+ c : #channels of features
+ m : number of query centers
+ n : number of points in point clouds
+ grad_y : grad of features of points, FloatTensor[b, c, n]
+ indices : indices of nearest 3 centers to the point, IntTensor[b, 3, n]
+ weights : weights for interpolation, FloatTensor[b, 3, n]
+ grad_x : grad of features of centers, FloatTensor[b, c, m]
+*/
+__global__ void three_nearest_neighbors_interpolate_grad_kernel(
+ int b, int c, int n, int m, const float *__restrict__ grad_y,
+ const int *__restrict__ indices, const float *__restrict__ weights,
+ float *__restrict__ grad_x) {
+ int batch_index = blockIdx.x;
+ grad_y += batch_index * n * c;
+ indices += batch_index * n * 3;
+ weights += batch_index * n * 3;
+ grad_x += batch_index * m * c;
+
+ const int index = threadIdx.y * blockDim.x + threadIdx.x;
+ const int stride = blockDim.y * blockDim.x;
+ for (int i = index; i < c * n; i += stride) {
+ const int l = i / n;
+ const int j = i % n;
+ float w1 = weights[j];
+ float w2 = weights[j + n];
+ float w3 = weights[j + n + n];
+ int i1 = indices[j];
+ int i2 = indices[j + n];
+ int i3 = indices[j + n + n];
+ atomicAdd(grad_x + l * m + i1, grad_y[i] * w1);
+ atomicAdd(grad_x + l * m + i2, grad_y[i] * w2);
+ atomicAdd(grad_x + l * m + i3, grad_y[i] * w3);
+ }
+}
+
+void three_nearest_neighbors_interpolate_grad(int b, int c, int n, int m,
+ const float *grad_y,
+ const int *indices,
+ const float *weights,
+ float *grad_x) {
+ three_nearest_neighbors_interpolate_grad_kernel<<<
+ b, optimal_block_config(n, c), 0, at::cuda::getCurrentCUDAStream()>>>(
+ b, c, n, m, grad_y, indices, weights, grad_x);
+ CUDA_CHECK_ERRORS();
+}
diff --git a/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cuh b/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..a15f37e40f3d45c6cddbb092fecba2efe7012ec6
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.cuh
@@ -0,0 +1,16 @@
+#ifndef _NEIGHBOR_INTERPOLATE_CUH
+#define _NEIGHBOR_INTERPOLATE_CUH
+
+void three_nearest_neighbors_interpolate(int b, int c, int m, int n,
+ const float *points_coords,
+ const float *centers_coords,
+ const float *centers_features,
+ int *indices, float *weights,
+ float *out);
+void three_nearest_neighbors_interpolate_grad(int b, int c, int n, int m,
+ const float *grad_y,
+ const int *indices,
+ const float *weights,
+ float *grad_x);
+
+#endif
diff --git a/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.hpp b/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..cdc7835d44c1b3b4b3c49051bbc958234fea3454
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/interpolate/neighbor_interpolate.hpp
@@ -0,0 +1,16 @@
+#ifndef _NEIGHBOR_INTERPOLATE_HPP
+#define _NEIGHBOR_INTERPOLATE_HPP
+
+#include <torch/extension.h>
+#include <vector>
+
+std::vector
+three_nearest_neighbors_interpolate_forward(at::Tensor points_coords,
+ at::Tensor centers_coords,
+ at::Tensor centers_features);
+at::Tensor three_nearest_neighbors_interpolate_backward(at::Tensor grad_y,
+ at::Tensor indices,
+ at::Tensor weights,
+ const int m);
+
+#endif
diff --git a/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.cpp b/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a8ff4fc74f6975cef6811eceb4b1f42d1116e1cf
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.cpp
@@ -0,0 +1,91 @@
+#include "trilinear_devox.hpp"
+#include "trilinear_devox.cuh"
+
+#include "../utils.hpp"
+
+/*
+ Function: trilinear devoxelization (forward)
+ Args:
+ r : voxel resolution
+ is_training : whether in training mode
+ coords : the coordinates of points, FloatTensor[b, 3, n]
+ features : features, FloatTensor[b, c, s], s = r ** 3
+ Return:
+ outs : outputs, FloatTensor[b, c, n]
+ inds : the voxel coordinates of point cube, IntTensor[b, 8, n]
+ wgts : weight for trilinear interpolation, FloatTensor[b, 8, n]
+*/
+std::vector<at::Tensor>
+trilinear_devoxelize_forward(const int r, const bool is_training,
+ const at::Tensor coords,
+ const at::Tensor features) {
+ CHECK_CUDA(features);
+ CHECK_CUDA(coords);
+ CHECK_CONTIGUOUS(features);
+ CHECK_CONTIGUOUS(coords);
+ CHECK_IS_FLOAT(features);
+ CHECK_IS_FLOAT(coords);
+
+ int b = features.size(0);
+ int c = features.size(1);
+ int n = coords.size(2);
+ int r2 = r * r;
+ int r3 = r2 * r;
+ at::Tensor outs = torch::zeros(
+ {b, c, n}, at::device(features.device()).dtype(at::ScalarType::Float));
+ if (is_training) {
+ at::Tensor inds = torch::zeros(
+ {b, 8, n}, at::device(features.device()).dtype(at::ScalarType::Int));
+ at::Tensor wgts = torch::zeros(
+ {b, 8, n}, at::device(features.device()).dtype(at::ScalarType::Float));
+ trilinear_devoxelize(b, c, n, r, r2, r3, true, coords.data_ptr<float>(),
+ features.data_ptr<float>(), inds.data_ptr<int>(),
+ wgts.data_ptr<float>(), outs.data_ptr<float>());
+ return {outs, inds, wgts};
+ } else {
+ at::Tensor inds = torch::zeros(
+ {1}, at::device(features.device()).dtype(at::ScalarType::Int));
+ at::Tensor wgts = torch::zeros(
+ {1}, at::device(features.device()).dtype(at::ScalarType::Float));
+ trilinear_devoxelize(b, c, n, r, r2, r3, false, coords.data_ptr<float>(),
+ features.data_ptr<float>(), inds.data_ptr<int>(),
+ wgts.data_ptr<float>(), outs.data_ptr<float>());
+ return {outs, inds, wgts};
+ }
+}
+
+/*
+ Function: trilinear devoxelization (backward)
+ Args:
+ grad_y : grad outputs, FloatTensor[b, c, n]
+ indices : the voxel coordinates of point cube, IntTensor[b, 8, n]
+ weights : weight for trilinear interpolation, FloatTensor[b, 8, n]
+ r : voxel resolution
+ Return:
+ grad_x : grad inputs, FloatTensor[b, c, s], s = r ** 3
+*/
+at::Tensor trilinear_devoxelize_backward(const at::Tensor grad_y,
+ const at::Tensor indices,
+ const at::Tensor weights,
+ const int r) {
+ CHECK_CUDA(grad_y);
+ CHECK_CUDA(weights);
+ CHECK_CUDA(indices);
+ CHECK_CONTIGUOUS(grad_y);
+ CHECK_CONTIGUOUS(weights);
+ CHECK_CONTIGUOUS(indices);
+ CHECK_IS_FLOAT(grad_y);
+ CHECK_IS_FLOAT(weights);
+ CHECK_IS_INT(indices);
+
+ int b = grad_y.size(0);
+ int c = grad_y.size(1);
+ int n = grad_y.size(2);
+ int r3 = r * r * r;
+ at::Tensor grad_x = torch::zeros(
+ {b, c, r3}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
+ trilinear_devoxelize_grad(b, c, n, r3, indices.data_ptr<int>(),
+ weights.data_ptr<float>(), grad_y.data_ptr<float>(),
+ grad_x.data_ptr<float>());
+ return grad_x;
+}
diff --git a/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.cu b/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.cu
new file mode 100644
index 0000000000000000000000000000000000000000..4e1e50c0b170508c90d3ec392cc43dbf771e276b
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.cu
@@ -0,0 +1,178 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "../cuda_utils.cuh"
+
+/*
+ Function: trilinear devoxelization (forward)
+ Args:
+ b : batch size
+ c : #channels
+ n : number of points
+ r : voxel resolution
+ r2 : r ** 2
+ r3 : r ** 3
+ coords : the coordinates of points, FloatTensor[b, 3, n]
+ feat : features, FloatTensor[b, c, r3]
+ inds : the voxel indices of point cube, IntTensor[b, 8, n]
+ wgts : weight for trilinear interpolation, FloatTensor[b, 8, n]
+ outs : outputs, FloatTensor[b, c, n]
+*/
+__global__ void trilinear_devoxelize_kernel(int b, int c, int n, int r, int r2,
+ int r3, bool is_training,
+ const float *__restrict__ coords,
+ const float *__restrict__ feat,
+ int *__restrict__ inds,
+ float *__restrict__ wgts,
+ float *__restrict__ outs) {
+ int batch_index = blockIdx.x;
+ int stride = blockDim.x;
+ int index = threadIdx.x;
+ coords += batch_index * n * 3;
+ inds += batch_index * n * 8;
+ wgts += batch_index * n * 8;
+ feat += batch_index * c * r3;
+ outs += batch_index * c * n;
+
+ for (int i = index; i < n; i += stride) {
+ float x = coords[i];
+ float y = coords[i + n];
+ float z = coords[i + n + n];
+ float x_lo_f = floorf(x);
+ float y_lo_f = floorf(y);
+ float z_lo_f = floorf(z);
+
+ float x_d_1 = x - x_lo_f; // / (x_hi_f - x_lo_f + 1e-8f)
+ float y_d_1 = y - y_lo_f;
+ float z_d_1 = z - z_lo_f;
+ float x_d_0 = 1.0f - x_d_1;
+ float y_d_0 = 1.0f - y_d_1;
+ float z_d_0 = 1.0f - z_d_1;
+
+ float wgt000 = x_d_0 * y_d_0 * z_d_0;
+ float wgt001 = x_d_0 * y_d_0 * z_d_1;
+ float wgt010 = x_d_0 * y_d_1 * z_d_0;
+ float wgt011 = x_d_0 * y_d_1 * z_d_1;
+ float wgt100 = x_d_1 * y_d_0 * z_d_0;
+ float wgt101 = x_d_1 * y_d_0 * z_d_1;
+ float wgt110 = x_d_1 * y_d_1 * z_d_0;
+ float wgt111 = x_d_1 * y_d_1 * z_d_1;
+
+ int x_lo = static_cast<int>(x_lo_f);
+ int y_lo = static_cast<int>(y_lo_f);
+ int z_lo = static_cast<int>(z_lo_f);
+ int x_hi = (x_d_1 > 0) ? -1 : 0;
+ int y_hi = (y_d_1 > 0) ? -1 : 0;
+ int z_hi = (z_d_1 > 0) ? 1 : 0;
+
+ int idx000 = x_lo * r2 + y_lo * r + z_lo;
+ int idx001 = idx000 + z_hi; // x_lo * r2 + y_lo * r + z_hi;
+ int idx010 = idx000 + (y_hi & r); // x_lo * r2 + y_hi * r + z_lo;
+ int idx011 = idx010 + z_hi; // x_lo * r2 + y_hi * r + z_hi;
+ int idx100 = idx000 + (x_hi & r2); // x_hi * r2 + y_lo * r + z_lo;
+ int idx101 = idx100 + z_hi; // x_hi * r2 + y_lo * r + z_hi;
+ int idx110 = idx100 + (y_hi & r); // x_hi * r2 + y_hi * r + z_lo;
+ int idx111 = idx110 + z_hi; // x_hi * r2 + y_hi * r + z_hi;
+
+ if (is_training) {
+ wgts[i] = wgt000;
+ wgts[i + n] = wgt001;
+ wgts[i + n * 2] = wgt010;
+ wgts[i + n * 3] = wgt011;
+ wgts[i + n * 4] = wgt100;
+ wgts[i + n * 5] = wgt101;
+ wgts[i + n * 6] = wgt110;
+ wgts[i + n * 7] = wgt111;
+ inds[i] = idx000;
+ inds[i + n] = idx001;
+ inds[i + n * 2] = idx010;
+ inds[i + n * 3] = idx011;
+ inds[i + n * 4] = idx100;
+ inds[i + n * 5] = idx101;
+ inds[i + n * 6] = idx110;
+ inds[i + n * 7] = idx111;
+ }
+
+ for (int j = 0; j < c; j++) {
+ int jr3 = j * r3;
+ outs[j * n + i] =
+ wgt000 * feat[jr3 + idx000] + wgt001 * feat[jr3 + idx001] +
+ wgt010 * feat[jr3 + idx010] + wgt011 * feat[jr3 + idx011] +
+ wgt100 * feat[jr3 + idx100] + wgt101 * feat[jr3 + idx101] +
+ wgt110 * feat[jr3 + idx110] + wgt111 * feat[jr3 + idx111];
+ }
+ }
+}
+
+/*
+ Function: trilinear devoxelization (backward)
+ Args:
+ b : batch size
+ c : #channels
+ n : number of points
+ r3 : voxel cube size = voxel resolution ** 3
+ inds : the voxel indices of point cube, IntTensor[b, 8, n]
+ wgts : weight for trilinear interpolation, FloatTensor[b, 8, n]
+ grad_y : grad outputs, FloatTensor[b, c, n]
+ grad_x : grad inputs, FloatTensor[b, c, r3]
+*/
+__global__ void trilinear_devoxelize_grad_kernel(
+ int b, int c, int n, int r3, const int *__restrict__ inds,
+ const float *__restrict__ wgts, const float *__restrict__ grad_y,
+ float *__restrict__ grad_x) {
+ int batch_index = blockIdx.x;
+ int stride = blockDim.x;
+ int index = threadIdx.x;
+ inds += batch_index * n * 8;
+ wgts += batch_index * n * 8;
+ grad_x += batch_index * c * r3;
+ grad_y += batch_index * c * n;
+
+ for (int i = index; i < n; i += stride) {
+ int idx000 = inds[i];
+ int idx001 = inds[i + n];
+ int idx010 = inds[i + n * 2];
+ int idx011 = inds[i + n * 3];
+ int idx100 = inds[i + n * 4];
+ int idx101 = inds[i + n * 5];
+ int idx110 = inds[i + n * 6];
+ int idx111 = inds[i + n * 7];
+ float wgt000 = wgts[i];
+ float wgt001 = wgts[i + n];
+ float wgt010 = wgts[i + n * 2];
+ float wgt011 = wgts[i + n * 3];
+ float wgt100 = wgts[i + n * 4];
+ float wgt101 = wgts[i + n * 5];
+ float wgt110 = wgts[i + n * 6];
+ float wgt111 = wgts[i + n * 7];
+
+ for (int j = 0; j < c; j++) {
+ int jr3 = j * r3;
+ float g = grad_y[j * n + i];
+ atomicAdd(grad_x + jr3 + idx000, wgt000 * g);
+ atomicAdd(grad_x + jr3 + idx001, wgt001 * g);
+ atomicAdd(grad_x + jr3 + idx010, wgt010 * g);
+ atomicAdd(grad_x + jr3 + idx011, wgt011 * g);
+ atomicAdd(grad_x + jr3 + idx100, wgt100 * g);
+ atomicAdd(grad_x + jr3 + idx101, wgt101 * g);
+ atomicAdd(grad_x + jr3 + idx110, wgt110 * g);
+ atomicAdd(grad_x + jr3 + idx111, wgt111 * g);
+ }
+ }
+}
+
+void trilinear_devoxelize(int b, int c, int n, int r, int r2, int r3,
+ bool training, const float *coords, const float *feat,
+ int *inds, float *wgts, float *outs) {
+ trilinear_devoxelize_kernel<<<
+ b, optimal_num_threads(n), 0, at::cuda::getCurrentCUDAStream()>>>(
+ b, c, n, r, r2, r3, training, coords, feat, inds, wgts, outs);
+ CUDA_CHECK_ERRORS();
+}
+
+void trilinear_devoxelize_grad(int b, int c, int n, int r3, const int *inds,
+ const float *wgts, const float *grad_y,
+ float *grad_x) {
+ trilinear_devoxelize_grad_kernel<<<
+ b, optimal_num_threads(n), 0, at::cuda::getCurrentCUDAStream()>>>(
+ b, c, n, r3, inds, wgts, grad_y, grad_x);
+ CUDA_CHECK_ERRORS();
+}
diff --git a/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.cuh b/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..8aadbaf34f4ad539199e21cb35c041dfcbe766e8
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.cuh
@@ -0,0 +1,13 @@
+#ifndef _TRILINEAR_DEVOX_CUH
+#define _TRILINEAR_DEVOX_CUH
+
+// CUDA function declarations
+void trilinear_devoxelize(int b, int c, int n, int r, int r2, int r3,
+ bool is_training, const float *coords,
+ const float *feat, int *inds, float *wgts,
+ float *outs);
+void trilinear_devoxelize_grad(int b, int c, int n, int r3, const int *inds,
+ const float *wgts, const float *grad_y,
+ float *grad_x);
+
+#endif
diff --git a/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.hpp b/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..a9d67957b79dba20cf6891c073da0f9dd5653aac
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/interpolate/trilinear_devox.hpp
@@ -0,0 +1,16 @@
+#ifndef _TRILINEAR_DEVOX_HPP
+#define _TRILINEAR_DEVOX_HPP
+
+#include <torch/extension.h>
+#include <vector>
+
+std::vector<at::Tensor> trilinear_devoxelize_forward(const int r,
+ const bool is_training,
+ const at::Tensor coords,
+ const at::Tensor features);
+
+at::Tensor trilinear_devoxelize_backward(const at::Tensor grad_y,
+ const at::Tensor indices,
+ const at::Tensor weights, const int r);
+
+#endif
diff --git a/model/pvcnn/modules/functional/src/sampling/sampling.cpp b/model/pvcnn/modules/functional/src/sampling/sampling.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..9b8ca6ef63144ad4112dffa4d71391329578e25e
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/sampling/sampling.cpp
@@ -0,0 +1,58 @@
+#include "sampling.hpp"
+#include "sampling.cuh"
+
+#include "../utils.hpp"
+
+at::Tensor gather_features_forward(at::Tensor features, at::Tensor indices) {
+ CHECK_CUDA(features);
+ CHECK_CUDA(indices);
+ CHECK_CONTIGUOUS(features);
+ CHECK_CONTIGUOUS(indices);
+ CHECK_IS_FLOAT(features);
+ CHECK_IS_INT(indices);
+
+ int b = features.size(0);
+ int c = features.size(1);
+ int n = features.size(2);
+ int m = indices.size(1);
+ at::Tensor output = torch::zeros(
+ {b, c, m}, at::device(features.device()).dtype(at::ScalarType::Float));
+ gather_features(b, c, n, m, features.data_ptr<float>(),
+ indices.data_ptr<int>(), output.data_ptr<float>());
+ return output;
+}
+
+at::Tensor gather_features_backward(at::Tensor grad_y, at::Tensor indices,
+ const int n) {
+ CHECK_CUDA(grad_y);
+ CHECK_CUDA(indices);
+ CHECK_CONTIGUOUS(grad_y);
+ CHECK_CONTIGUOUS(indices);
+ CHECK_IS_FLOAT(grad_y);
+ CHECK_IS_INT(indices);
+
+ int b = grad_y.size(0);
+ int c = grad_y.size(1);
+ at::Tensor grad_x = torch::zeros(
+ {b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
+ gather_features_grad(b, c, n, indices.size(1), grad_y.data_ptr<float>(),
+ indices.data_ptr<int>(), grad_x.data_ptr<float>());
+ return grad_x;
+}
+
+at::Tensor furthest_point_sampling_forward(at::Tensor coords,
+ const int num_samples) {
+ CHECK_CUDA(coords);
+ CHECK_CONTIGUOUS(coords);
+ CHECK_IS_FLOAT(coords);
+
+ int b = coords.size(0);
+ int n = coords.size(2);
+ at::Tensor indices = torch::zeros(
+ {b, num_samples}, at::device(coords.device()).dtype(at::ScalarType::Int));
+ at::Tensor distances = torch::full(
+ {b, n}, 1e38f, at::device(coords.device()).dtype(at::ScalarType::Float));
+ furthest_point_sampling(b, n, num_samples, coords.data_ptr<float>(),
+ distances.data_ptr<float>(), indices.data_ptr<int>());
+ return indices;
+}
diff --git a/model/pvcnn/modules/functional/src/sampling/sampling.cu b/model/pvcnn/modules/functional/src/sampling/sampling.cu
new file mode 100644
index 0000000000000000000000000000000000000000..06bc0ee7240ef3d86c12eb95492d9e621b27b1bb
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/sampling/sampling.cu
@@ -0,0 +1,174 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "../cuda_utils.cuh"
+
+/*
+ Function: gather centers' features (forward)
+ Args:
+ b : batch size
+ c : #channels of features
+ n : number of points in point clouds
+ m : number of query/sampled centers
+ features: points' features, FloatTensor[b, c, n]
+ indices : centers' indices in points, IntTensor[b, m]
+ out : gathered features, FloatTensor[b, c, m]
+*/
+__global__ void gather_features_kernel(int b, int c, int n, int m,
+ const float *__restrict__ features,
+ const int *__restrict__ indices,
+ float *__restrict__ out) {
+ int batch_index = blockIdx.x;
+ int channel_index = blockIdx.y;
+ int temp_index = batch_index * c + channel_index;
+ features += temp_index * n;
+ indices += batch_index * m;
+ out += temp_index * m;
+
+ for (int j = threadIdx.x; j < m; j += blockDim.x) {
+ out[j] = features[indices[j]];
+ }
+}
+
+void gather_features(int b, int c, int n, int m, const float *features,
+ const int *indices, float *out) {
+ gather_features_kernel<<<
+ dim3(b, c, 1), optimal_num_threads(m), 0, at::cuda::getCurrentCUDAStream()>>>(
+ b, c, n, m, features, indices, out);
+ CUDA_CHECK_ERRORS();
+}
+
+/*
+ Function: gather centers' features (backward)
+ Args:
+ b : batch size
+ c : #channels of features
+ n : number of points in point clouds
+ m : number of query/sampled centers
+ grad_y : grad of gathered features, FloatTensor[b, c, m]
+ indices : centers' indices in points, IntTensor[b, m]
+ grad_x : grad of points' features, FloatTensor[b, c, n]
+*/
+__global__ void gather_features_grad_kernel(int b, int c, int n, int m,
+ const float *__restrict__ grad_y,
+ const int *__restrict__ indices,
+ float *__restrict__ grad_x) {
+ int batch_index = blockIdx.x;
+ int channel_index = blockIdx.y;
+ int temp_index = batch_index * c + channel_index;
+ grad_y += temp_index * m;
+ indices += batch_index * m;
+ grad_x += temp_index * n;
+
+ for (int j = threadIdx.x; j < m; j += blockDim.x) {
+ atomicAdd(grad_x + indices[j], grad_y[j]);
+ }
+}
+
+void gather_features_grad(int b, int c, int n, int m, const float *grad_y,
+ const int *indices, float *grad_x) {
+ gather_features_grad_kernel<<<
+ dim3(b, c, 1), optimal_num_threads(m), 0, at::cuda::getCurrentCUDAStream()>>>(
+ b, c, n, m, grad_y, indices, grad_x);
+ CUDA_CHECK_ERRORS();
+}
+
+/*
+ Function: furthest point sampling
+ Args:
+ b : batch size
+ n : number of points in point clouds
+ m : number of query/sampled centers
+ coords : points' coords, FloatTensor[b, 3, n]
+ distances : minimum distance of a point to the sampled set, FloatTensor[b, n]
+ indices : sampled centers' indices in points, IntTensor[b, m]
+*/
+__global__ void furthest_point_sampling_kernel(int b, int n, int m,
+ const float *__restrict__ coords,
+ float *__restrict__ distances,
+ int *__restrict__ indices) {
+ if (m <= 0)
+ return;
+ int batch_index = blockIdx.x;
+ coords += batch_index * n * 3;
+ distances += batch_index * n;
+ indices += batch_index * m;
+
+ const int BlockSize = 512;
+ __shared__ float dists[BlockSize];
+ __shared__ int dists_i[BlockSize];
+ const int BufferSize = 3072;
+ __shared__ float buf[BufferSize * 3];
+
+ int old = 0;
+ if (threadIdx.x == 0)
+ indices[0] = old;
+
+ for (int j = threadIdx.x; j < min(BufferSize, n); j += blockDim.x) {
+ buf[j] = coords[j];
+ buf[j + BufferSize] = coords[j + n];
+ buf[j + BufferSize + BufferSize] = coords[j + n + n];
+ }
+ __syncthreads();
+
+ for (int j = 1; j < m; j++) {
+ int besti = 0; // best index
+ float best = -1; // farthest distance
+ // calculating the distance with the latest sampled point
+ float x1 = coords[old];
+ float y1 = coords[old + n];
+ float z1 = coords[old + n + n];
+ for (int k = threadIdx.x; k < n; k += blockDim.x) {
+ // fetch distance at block n, thread k
+ float td = distances[k];
+ float x2, y2, z2;
+ if (k < BufferSize) {
+ x2 = buf[k];
+ y2 = buf[k + BufferSize];
+ z2 = buf[k + BufferSize + BufferSize];
+ } else {
+ x2 = coords[k];
+ y2 = coords[k + n];
+ z2 = coords[k + n + n];
+ }
+ float d =
+ (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
+ float d2 = min(d, td);
+ // update "point-to-set" distance
+ if (d2 != td)
+ distances[k] = d2;
+ // update the farthest distance at sample step j
+ if (d2 > best) {
+ best = d2;
+ besti = k;
+ }
+ }
+
+ dists[threadIdx.x] = best;
+ dists_i[threadIdx.x] = besti;
+ for (int u = 0; (1 << u) < blockDim.x; u++) {
+ __syncthreads();
+ if (threadIdx.x < (blockDim.x >> (u + 1))) {
+ int i1 = (threadIdx.x * 2) << u;
+ int i2 = (threadIdx.x * 2 + 1) << u;
+ if (dists[i1] < dists[i2]) {
+ dists[i1] = dists[i2];
+ dists_i[i1] = dists_i[i2];
+ }
+ }
+ }
+ __syncthreads();
+
+ // finish sample step j; old is the sampled index
+ old = dists_i[0];
+ if (threadIdx.x == 0)
+ indices[j] = old;
+ }
+}
+
+void furthest_point_sampling(int b, int n, int m, const float *coords,
+ float *distances, int *indices) {
+ furthest_point_sampling_kernel<<<
+ b, 512, 0, at::cuda::getCurrentCUDAStream()>>>(
+ b, n, m, coords, distances, indices);
+ CUDA_CHECK_ERRORS();
+}
diff --git a/model/pvcnn/modules/functional/src/sampling/sampling.cuh b/model/pvcnn/modules/functional/src/sampling/sampling.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..e68358ffa59e10a4d00fceaf383cdef403b849ce
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/sampling/sampling.cuh
@@ -0,0 +1,11 @@
+#ifndef _SAMPLING_CUH
+#define _SAMPLING_CUH
+
+void gather_features(int b, int c, int n, int m, const float *features,
+ const int *indices, float *out);
+void gather_features_grad(int b, int c, int n, int m, const float *grad_y,
+ const int *indices, float *grad_x);
+void furthest_point_sampling(int b, int n, int m, const float *coords,
+ float *distances, int *indices);
+
+#endif
diff --git a/model/pvcnn/modules/functional/src/sampling/sampling.hpp b/model/pvcnn/modules/functional/src/sampling/sampling.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..db2a5c84aa6a7eec46c59373c23b43be3c585499
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/sampling/sampling.hpp
@@ -0,0 +1,12 @@
+#ifndef _SAMPLING_HPP
+#define _SAMPLING_HPP
+
+#include <torch/extension.h>
+
+at::Tensor gather_features_forward(at::Tensor features, at::Tensor indices);
+at::Tensor gather_features_backward(at::Tensor grad_y, at::Tensor indices,
+ const int n);
+at::Tensor furthest_point_sampling_forward(at::Tensor coords,
+ const int num_samples);
+
+#endif
diff --git a/model/pvcnn/modules/functional/src/utils.hpp b/model/pvcnn/modules/functional/src/utils.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..f4f21a07ec174f15aa2aa0974a5903488b0d7be3
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/utils.hpp
@@ -0,0 +1,20 @@
+#ifndef _UTILS_HPP
+#define _UTILS_HPP
+
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
+
+#define CHECK_CONTIGUOUS(x) \
+ TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
+
+#define CHECK_IS_INT(x) \
+ TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \
+ #x " must be an int tensor")
+
+#define CHECK_IS_FLOAT(x) \
+ TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \
+ #x " must be a float tensor")
+
+#endif
diff --git a/model/pvcnn/modules/functional/src/voxelization/vox.cpp b/model/pvcnn/modules/functional/src/voxelization/vox.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6a84594f25e78d1bad74131744c27402885b1dad
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/voxelization/vox.cpp
@@ -0,0 +1,76 @@
+#include "vox.hpp"
+#include "vox.cuh"
+
+#include "../utils.hpp"
+
+/*
+ Function: average pool voxelization (forward)
+ Args:
+ features: features, FloatTensor[b, c, n]
+ coords : coords of each point, IntTensor[b, 3, n]
+ resolution : voxel resolution
+ Return:
+ out : outputs, FloatTensor[b, c, s], s = r ** 3
+ ind : voxel index of each point, IntTensor[b, n]
+ cnt : #points in each voxel index, IntTensor[b, s]
+*/
+std::vector<at::Tensor> avg_voxelize_forward(const at::Tensor features,
+ const at::Tensor coords,
+ const int resolution) {
+ CHECK_CUDA(features);
+ CHECK_CUDA(coords);
+ CHECK_CONTIGUOUS(features);
+ CHECK_CONTIGUOUS(coords);
+ CHECK_IS_FLOAT(features);
+ CHECK_IS_INT(coords);
+
+ int b = features.size(0);
+ int c = features.size(1);
+ int n = features.size(2);
+ int r = resolution;
+ int r2 = r * r;
+ int r3 = r2 * r;
+ at::Tensor ind = torch::zeros(
+ {b, n}, at::device(features.device()).dtype(at::ScalarType::Int));
+ at::Tensor out = torch::zeros(
+ {b, c, r3}, at::device(features.device()).dtype(at::ScalarType::Float));
+ at::Tensor cnt = torch::zeros(
+ {b, r3}, at::device(features.device()).dtype(at::ScalarType::Int));
+ avg_voxelize(b, c, n, r, r2, r3, coords.data_ptr<int>(),
+ features.data_ptr<float>(), ind.data_ptr<int>(),
+ cnt.data_ptr<int>(), out.data_ptr<float>());
+ return {out, ind, cnt};
+}
+
+/*
+ Function: average pool voxelization (backward)
+ Args:
+ grad_y : grad outputs, FloatTensor[b, c, s]
+ indices: voxel index of each point, IntTensor[b, n]
+ cnt : #points in each voxel index, IntTensor[b, s]
+ Return:
+ grad_x : grad inputs, FloatTensor[b, c, n]
+*/
+at::Tensor avg_voxelize_backward(const at::Tensor grad_y,
+ const at::Tensor indices,
+ const at::Tensor cnt) {
+ CHECK_CUDA(grad_y);
+ CHECK_CUDA(indices);
+ CHECK_CUDA(cnt);
+ CHECK_CONTIGUOUS(grad_y);
+ CHECK_CONTIGUOUS(indices);
+ CHECK_CONTIGUOUS(cnt);
+ CHECK_IS_FLOAT(grad_y);
+ CHECK_IS_INT(indices);
+ CHECK_IS_INT(cnt);
+
+ int b = grad_y.size(0);
+ int c = grad_y.size(1);
+ int s = grad_y.size(2);
+ int n = indices.size(1);
+ at::Tensor grad_x = torch::zeros(
+ {b, c, n}, at::device(grad_y.device()).dtype(at::ScalarType::Float));
+ avg_voxelize_grad(b, c, n, s, indices.data_ptr<int>(), cnt.data_ptr<int>(),
+ grad_y.data_ptr<float>(), grad_x.data_ptr<float>());
+ return grad_x;
+}
diff --git a/model/pvcnn/modules/functional/src/voxelization/vox.cu b/model/pvcnn/modules/functional/src/voxelization/vox.cu
new file mode 100644
index 0000000000000000000000000000000000000000..1c1a2c92a646dde44886506dd8728cff21dbe533
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/voxelization/vox.cu
@@ -0,0 +1,126 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "../cuda_utils.cuh"
+
+/*
+ Function: get how many points in each voxel grid
+ Args:
+ b : batch size
+ n : number of points
+ r : voxel resolution
+ r2 : = r * r
+ r3 : s, voxel cube size = r ** 3
+ coords : coords of each point, IntTensor[b, 3, n]
+ ind : voxel index of each point, IntTensor[b, n]
+ cnt : #points in each voxel index, IntTensor[b, s]
+*/
+__global__ void grid_stats_kernel(int b, int n, int r, int r2, int r3,
+ const int *__restrict__ coords,
+ int *__restrict__ ind, int *cnt) {
+ int batch_index = blockIdx.x;
+ int stride = blockDim.x;
+ int index = threadIdx.x;
+ coords += batch_index * n * 3;
+ ind += batch_index * n;
+ cnt += batch_index * r3;
+
+ for (int i = index; i < n; i += stride) {
+ // if (ind[i] == -1)
+ // continue;
+ ind[i] = coords[i] * r2 + coords[i + n] * r + coords[i + n + n];
+ atomicAdd(cnt + ind[i], 1);
+ }
+}
+
+/*
+ Function: average pool voxelization (forward)
+ Args:
+ b : batch size
+ c : #channels
+ n : number of points
+ s : voxel cube size = voxel resolution ** 3
+ ind : voxel index of each point, IntTensor[b, n]
+ cnt : #points in each voxel index, IntTensor[b, s]
+ feat: features, FloatTensor[b, c, n]
+ out : outputs, FloatTensor[b, c, s]
+*/
+__global__ void avg_voxelize_kernel(int b, int c, int n, int s,
+ const int *__restrict__ ind,
+ const int *__restrict__ cnt,
+ const float *__restrict__ feat,
+ float *__restrict__ out) {
+ int batch_index = blockIdx.x;
+ int stride = blockDim.x;
+ int index = threadIdx.x;
+ ind += batch_index * n;
+ feat += batch_index * c * n;
+ out += batch_index * c * s;
+ cnt += batch_index * s;
+ for (int i = index; i < n; i += stride) {
+ int pos = ind[i];
+ // if (pos == -1)
+ // continue;
+ int cur_cnt = cnt[pos];
+ if (cur_cnt > 0) {
+ float div_cur_cnt = 1.0 / static_cast<float>(cur_cnt);
+ for (int j = 0; j < c; j++) {
+ atomicAdd(out + j * s + pos, feat[j * n + i] * div_cur_cnt);
+ }
+ }
+ }
+}
+
+/*
+ Function: average pool voxelization (backward)
+ Args:
+ b : batch size
+ c : #channels
+ n : number of points
+ r3 : voxel cube size = voxel resolution ** 3
+ ind : voxel index of each point, IntTensor[b, n]
+ cnt : #points in each voxel index, IntTensor[b, s]
+ grad_y : grad outputs, FloatTensor[b, c, s]
+ grad_x : grad inputs, FloatTensor[b, c, n]
+*/
+__global__ void avg_voxelize_grad_kernel(int b, int c, int n, int r3,
+ const int *__restrict__ ind,
+ const int *__restrict__ cnt,
+ const float *__restrict__ grad_y,
+ float *__restrict__ grad_x) {
+ int batch_index = blockIdx.x;
+ int stride = blockDim.x;
+ int index = threadIdx.x;
+ ind += batch_index * n;
+ grad_x += batch_index * c * n;
+ grad_y += batch_index * c * r3;
+ cnt += batch_index * r3;
+ for (int i = index; i < n; i += stride) {
+ int pos = ind[i];
+ // if (pos == -1)
+ // continue;
+ int cur_cnt = cnt[pos];
+ if (cur_cnt > 0) {
+ float div_cur_cnt = 1.0 / static_cast<float>(cur_cnt);
+ for (int j = 0; j < c; j++) {
+ atomicAdd(grad_x + j * n + i, grad_y[j * r3 + pos] * div_cur_cnt);
+ }
+ }
+ }
+}
+
+void avg_voxelize(int b, int c, int n, int r, int r2, int r3, const int *coords,
+ const float *feat, int *ind, int *cnt, float *out) {
+ grid_stats_kernel<<<
+ b, optimal_num_threads(n), 0, at::cuda::getCurrentCUDAStream()>>>(
+ b, n, r, r2, r3, coords, ind, cnt);
+ avg_voxelize_kernel<<<
+ b, optimal_num_threads(n), 0, at::cuda::getCurrentCUDAStream()>>>(
+ b, c, n, r3, ind, cnt, feat, out);
+ CUDA_CHECK_ERRORS();
+}
+
+void avg_voxelize_grad(int b, int c, int n, int s, const int *ind,
+ const int *cnt, const float *grad_y, float *grad_x) {
+ avg_voxelize_grad_kernel<<<
+ b, optimal_num_threads(n), 0, at::cuda::getCurrentCUDAStream()>>>(
+ b, c, n, s, ind, cnt, grad_y, grad_x);
+ CUDA_CHECK_ERRORS();
+}
diff --git a/model/pvcnn/modules/functional/src/voxelization/vox.cuh b/model/pvcnn/modules/functional/src/voxelization/vox.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..9adb0fdabce3352a2ed3f64e053f2aa8fafd17e7
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/voxelization/vox.cuh
@@ -0,0 +1,10 @@
+#ifndef _VOX_CUH
+#define _VOX_CUH
+
+// CUDA function declarations
+void avg_voxelize(int b, int c, int n, int r, int r2, int r3, const int *coords,
+ const float *feat, int *ind, int *cnt, float *out);
+void avg_voxelize_grad(int b, int c, int n, int s, const int *idx,
+ const int *cnt, const float *grad_y, float *grad_x);
+
+#endif
diff --git a/model/pvcnn/modules/functional/src/voxelization/vox.hpp b/model/pvcnn/modules/functional/src/voxelization/vox.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..6e62bc39efccc29da9e6d3e6d149baeee7d022e3
--- /dev/null
+++ b/model/pvcnn/modules/functional/src/voxelization/vox.hpp
@@ -0,0 +1,15 @@
+#ifndef _VOX_HPP
+#define _VOX_HPP
+
+#include <torch/extension.h>
+#include <vector>
+
+std::vector<at::Tensor> avg_voxelize_forward(const at::Tensor features,
+ const at::Tensor coords,
+ const int resolution);
+
+at::Tensor avg_voxelize_backward(const at::Tensor grad_y,
+ const at::Tensor indices,
+ const at::Tensor cnt);
+
+#endif
diff --git a/model/pvcnn/modules/functional/voxelization.py b/model/pvcnn/modules/functional/voxelization.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca4656df0c5ddd418405c3eba50b7f401f204938
--- /dev/null
+++ b/model/pvcnn/modules/functional/voxelization.py
@@ -0,0 +1,40 @@
+from torch.autograd import Function
+
+from .backend import _backend
+
+__all__ = ['avg_voxelize']
+
+
+class AvgVoxelization(Function):
+ @staticmethod
+ def forward(ctx, features, coords, resolution):
+ """
+ :param ctx:
+ :param features: Features of the point cloud, FloatTensor[B, C, N]
+ :param coords: Voxelized Coordinates of each point, IntTensor[B, 3, N]
+ :param resolution: Voxel resolution
+ :return:
+ Voxelized Features, FloatTensor[B, C, R, R, R]
+ """
+ features = features.contiguous()
+ coords = coords.int().contiguous()
+ b, c, _ = features.shape
+ out, indices, counts = _backend.avg_voxelize_forward(features, coords, resolution)
+ ctx.save_for_backward(indices, counts)
+ return out.view(b, c, resolution, resolution, resolution)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ """
+ :param ctx:
+ :param grad_output: gradient of output, FloatTensor[B, C, R, R, R]
+ :return:
+ gradient of inputs, FloatTensor[B, C, N]
+ """
+ b, c = grad_output.shape[:2]
+ indices, counts = ctx.saved_tensors
+ grad_features = _backend.avg_voxelize_backward(grad_output.contiguous().view(b, c, -1), indices, counts)
+ return grad_features, None, None
+
+
+avg_voxelize = AvgVoxelization.apply
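+
+
+if __name__ == '__main__':
+    # Usage sketch (illustrative only; assumes a CUDA device and a compiled
+    # _pvcnn_backend extension). `coords` holds integer voxel indices in
+    # [0, R-1] per axis; features of points that fall into the same voxel are
+    # averaged into a dense [B, C, R, R, R] grid.
+    import torch
+    R = 16
+    feats = torch.rand(2, 32, 2048).cuda()                 # [B, C, N]
+    vox_coords = torch.randint(0, R, (2, 3, 2048)).cuda()  # [B, 3, N]
+    voxels = avg_voxelize(feats, vox_coords, R)
+    print(voxels.shape)  # expected: torch.Size([2, 32, 16, 16, 16])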
diff --git a/model/pvcnn/modules/loss.py b/model/pvcnn/modules/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..a35cdd8a0fe83c8ca6b1d7040b66d142e76471df
--- /dev/null
+++ b/model/pvcnn/modules/loss.py
@@ -0,0 +1,10 @@
+import torch.nn as nn
+
+from . import functional as F
+
+__all__ = ['KLLoss']
+
+
+class KLLoss(nn.Module):
+ def forward(self, x, y):
+ return F.kl_loss(x, y)
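+
+
+if __name__ == '__main__':
+    # Usage sketch (illustrative): KLLoss is a thin nn.Module wrapper around
+    # functional.kl_loss, so it can be registered alongside other criteria.
+    # Note that importing the functional package triggers the CUDA backend build.
+    import torch
+    criterion = KLLoss()
+    print(criterion(torch.randn(4, 10), torch.randn(4, 10)))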
diff --git a/model/pvcnn/modules/pointnet.py b/model/pvcnn/modules/pointnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7672ae30c4379fe8791efecc1a21abf6b2bb3b32
--- /dev/null
+++ b/model/pvcnn/modules/pointnet.py
@@ -0,0 +1,118 @@
+import torch
+import torch.nn as nn
+
+from . import functional as F
+from .ball_query import BallQuery
+from .shared_mlp import SharedMLP
+
+__all__ = ['PointNetAModule', 'PointNetSAModule', 'PointNetFPModule']
+
+
+class PointNetAModule(nn.Module):
+ def __init__(self, in_channels, out_channels, include_coordinates=True):
+ super().__init__()
+ if not isinstance(out_channels, (list, tuple)):
+ out_channels = [[out_channels]]
+ elif not isinstance(out_channels[0], (list, tuple)):
+ out_channels = [out_channels]
+
+ mlps = []
+ total_out_channels = 0
+ for _out_channels in out_channels:
+ mlps.append(
+ SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
+ out_channels=_out_channels, dim=1)
+ )
+ total_out_channels += _out_channels[-1]
+
+ self.include_coordinates = include_coordinates
+ self.out_channels = total_out_channels
+ self.mlps = nn.ModuleList(mlps)
+
+ def forward(self, inputs):
+ features, coords = inputs
+ if self.include_coordinates:
+ features = torch.cat([features, coords], dim=1)
+ coords = torch.zeros((coords.size(0), 3, 1), device=coords.device)
+ if len(self.mlps) > 1:
+ features_list = []
+ for mlp in self.mlps:
+ features_list.append(mlp(features).max(dim=-1, keepdim=True).values)
+ return torch.cat(features_list, dim=1), coords
+ else:
+ return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords
+
+ def extra_repr(self):
+ return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}'
+
+
+class PointNetSAModule(nn.Module):
+ def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True):
+ super().__init__()
+ # print(f"PointNet module, in={in_channels}, out={out_channels}")
+ if not isinstance(radius, (list, tuple)):
+ radius = [radius]
+ if not isinstance(num_neighbors, (list, tuple)):
+ num_neighbors = [num_neighbors] * len(radius)
+ assert len(radius) == len(num_neighbors)
+ if not isinstance(out_channels, (list, tuple)):
+ out_channels = [[out_channels]] * len(radius)
+ elif not isinstance(out_channels[0], (list, tuple)):
+ out_channels = [out_channels] * len(radius)
+ assert len(radius) == len(out_channels)
+
+ groupers, mlps = [], []
+ total_out_channels = 0
+ for _radius, _out_channels, _num_neighbors in zip(radius, out_channels, num_neighbors):
+ groupers.append(
+ BallQuery(radius=_radius, num_neighbors=_num_neighbors, include_coordinates=include_coordinates)
+ )
+ mlps.append(
+ SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
+ out_channels=_out_channels, dim=2)
+ )
+ total_out_channels += _out_channels[-1]
+
+ self.num_centers = num_centers
+ self.out_channels = total_out_channels
+ self.groupers = nn.ModuleList(groupers)
+ self.mlps = nn.ModuleList(mlps)
+
+ def forward(self, inputs):
+ features, coords, temb = inputs
+ centers_coords = F.furthest_point_sample(coords, self.num_centers) # use this to reduce the number of points to next layer
+ features_list = []
+ # print("Pointnet input shape:", features.shape)
+ for grouper, mlp in zip(self.groupers, self.mlps):
+ features, temb = mlp(grouper(coords, centers_coords, temb, features))
+ features_list.append(features.max(dim=-1).values)
+ # print("Point net output shape:", features.shape)
+ if len(features_list) > 1:
+ # multi-scale grouping: concatenate the features from all ball-query radii
+ return torch.cat(features_list, dim=1), centers_coords, temb.max(dim=-1).values if temb.shape[1] > 0 else temb
+ else:
+ return features_list[0], centers_coords, temb.max(dim=-1).values if temb.shape[1] > 0 else temb
+
+ def extra_repr(self):
+ return f'num_centers={self.num_centers}, out_channels={self.out_channels}'
+
+
+class PointNetFPModule(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ # print(f"IN channels={in_channels}, out channels={out_channels}")
+ super().__init__()
+ self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1)
+
+ def forward(self, inputs):
+ # print(inputs.shape)
+ if len(inputs) == 4:
+ points_coords, centers_coords, centers_features, temb = inputs
+ points_features = None
+ else:
+ points_coords, centers_coords, centers_features, points_features, temb = inputs
+ interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
+ interpolated_temb = F.nearest_neighbor_interpolate(points_coords, centers_coords, temb)
+ if points_features is not None:
+ interpolated_features = torch.cat(
+ [interpolated_features, points_features], dim=1
+ ) # concatenate interpolated features with the original per-point features (e.g. 394 channels x N points)
+ return self.mlp(interpolated_features), points_coords, interpolated_temb
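+
+# Usage note (illustrative, not part of the original file): PointNetFPModule(in_channels=C_centers + C_skip,
+# out_channels=[256, 256]) interpolates the coarse center features back onto the denser point set via
+# F.nearest_neighbor_interpolate (inverse-distance weighting over the nearest centers in the PVCNN backend),
+# optionally concatenates the encoder skip features, and runs a SharedMLP on the result.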
diff --git a/model/pvcnn/modules/pvconv.py b/model/pvcnn/modules/pvconv.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bc915a28f1179171ef88ab5efffc96781aacf44
--- /dev/null
+++ b/model/pvcnn/modules/pvconv.py
@@ -0,0 +1,140 @@
+import torch.nn as nn
+import torch
+
+from .voxelization import Voxelization
+from .shared_mlp import SharedMLP
+from .se import SE3d
+from . import functional as F
+
+__all__ = ['PVConv', 'Attention', 'Swish', 'PVConvReLU']
+
+
+class Swish(nn.Module):
+ def forward(self,x):
+ return x * torch.sigmoid(x)
+
+
+class Attention(nn.Module):
+ def __init__(self, in_ch, num_groups, D=3):
+ super(Attention, self).__init__()
+ assert in_ch % num_groups == 0
+ # it also has some learnable parameters
+ if D == 3:
+ self.q = nn.Conv3d(in_ch, in_ch, 1)
+ self.k = nn.Conv3d(in_ch, in_ch, 1)
+ self.v = nn.Conv3d(in_ch, in_ch, 1)
+
+ self.out = nn.Conv3d(in_ch, in_ch, 1)
+ elif D == 1:
+ self.q = nn.Conv1d(in_ch, in_ch, 1)
+ self.k = nn.Conv1d(in_ch, in_ch, 1)
+ self.v = nn.Conv1d(in_ch, in_ch, 1)
+
+ self.out = nn.Conv1d(in_ch, in_ch, 1)
+
+ self.norm = nn.GroupNorm(num_groups, in_ch)
+ self.nonlin = Swish()
+
+ self.sm = nn.Softmax(-1)
+
+
+ def forward(self, x):
+ """
+ Self-attention over the flattened spatial dimensions.
+ Example shapes: reso=32 gives x=(16, 64, 16, 16, 16) with q/k/v=(16, 64, 4096);
+ reso=48 gives x=(16, 64, 24, 24, 24) with q/k/v=(16, 64, 13824).
+ Note: attention is quadratic in the number of voxels, so high resolutions can cause OOM.
+
+ :param x: (B, C, reso, reso, reso) voxel features when D=3, or (B, C, N) point features when D=1
+ :return: attended features with the same shape as x
+ """
+ B, C = x.shape[:2]
+ h = x
+
+ q = self.q(h).reshape(B,C,-1)
+ k = self.k(h).reshape(B,C,-1)
+ v = self.v(h).reshape(B,C,-1)
+
+ qk = torch.matmul(q.permute(0, 2, 1), k) #* (int(C) ** (-0.5))
+
+ w = self.sm(qk)
+
+ h = torch.matmul(v, w.permute(0, 2, 1)).reshape(B,C,*x.shape[2:])
+
+ h = self.out(h)
+
+ x = h + x
+
+ x = self.nonlin(self.norm(x)) # group norm + swish
+
+ return x
+
+class PVConv(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False,
+ dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.resolution = resolution
+
+ self.voxelization = Voxelization(resolution, normalize=normalize, eps=eps)
+ voxel_layers = [
+ nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
+ nn.GroupNorm(num_groups=8, num_channels=out_channels),
+ Swish()
+ ]
+ voxel_layers += [nn.Dropout(dropout)] if dropout is not None else []
+ voxel_layers += [
+ nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
+ nn.GroupNorm(num_groups=8, num_channels=out_channels),
+ Attention(out_channels, 8) if attention else Swish()
+ ]
+ if with_se:
+ voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu))
+ self.voxel_layers = nn.Sequential(*voxel_layers)
+ self.point_features = SharedMLP(in_channels, out_channels) # this is basically an MLP
+
+ def forward(self, inputs):
+ features, coords, temb = inputs # features: (B, F, N), temb: sinusoidal embedding of diffusion timesteps
+ voxel_features, voxel_coords = self.voxelization(features, coords)
+ voxel_features = self.voxel_layers(voxel_features)
+ voxel_features = F.trilinear_devoxelize(voxel_features, voxel_coords, self.resolution, self.training)
+ fused_features = voxel_features + self.point_features(features)
+ return fused_features, coords, temb # coords is not changed, and also temb
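+
+# Usage sketch (illustrative, not part of the original file; assumes the custom CUDA
+# voxelization/devoxelization ops are compiled and tensors live on the GPU):
+# conv = PVConv(in_channels=64, out_channels=128, kernel_size=3, resolution=32).cuda()
+# feats, coords, temb = conv((features, coords, temb)) # (B, 64, N) -> (B, 128, N); coords/temb unchanged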
+
+
+
+class PVConvReLU(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False, leak=0.2,
+ dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.resolution = resolution
+
+ self.voxelization = Voxelization(resolution, normalize=normalize, eps=eps)
+ voxel_layers = [
+ nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
+ nn.BatchNorm3d(out_channels),
+ nn.LeakyReLU(leak, True)
+ ]
+ voxel_layers += [nn.Dropout(dropout)] if dropout is not None else []
+ voxel_layers += [
+ nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
+ nn.BatchNorm3d(out_channels),
+ Attention(out_channels, 8) if attention else nn.LeakyReLU(leak, True)
+ ]
+ if with_se:
+ voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu))
+ self.voxel_layers = nn.Sequential(*voxel_layers)
+ self.point_features = SharedMLP(in_channels, out_channels)
+
+ def forward(self, inputs):
+ features, coords, temb = inputs
+ voxel_features, voxel_coords = self.voxelization(features, coords)
+ voxel_features = self.voxel_layers(voxel_features)
+ voxel_features = F.trilinear_devoxelize(voxel_features, voxel_coords, self.resolution, self.training)
+ fused_features = voxel_features + self.point_features(features)
+ return fused_features, coords, temb
diff --git a/model/pvcnn/modules/se.py b/model/pvcnn/modules/se.py
new file mode 100644
index 0000000000000000000000000000000000000000..c34eef769cebbb0f9df44d573288b02169ddfbac
--- /dev/null
+++ b/model/pvcnn/modules/se.py
@@ -0,0 +1,19 @@
+import torch.nn as nn
+import torch
+__all__ = ['SE3d']
+
+class Swish(nn.Module):
+ def forward(self,x):
+ return x * torch.sigmoid(x)
+class SE3d(nn.Module):
+ def __init__(self, channel, reduction=8, use_relu=False):
+ super().__init__()
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction, bias=False),
+ nn.ReLU(True) if use_relu else Swish() ,
+ nn.Linear(channel // reduction, channel, bias=False),
+ nn.Sigmoid()
+ )
+
+ def forward(self, inputs):
+ return inputs * self.fc(inputs.mean(-1).mean(-1).mean(-1)).view(inputs.shape[0], inputs.shape[1], 1, 1, 1)
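+
+# Illustrative note (not in the original file): SE3d(channel=64, reduction=8) global-average-pools
+# (B, 64, R, R, R) voxel features to (B, 64), passes them through Linear(64, 8) -> Swish -> Linear(8, 64)
+# -> Sigmoid, and rescales each channel of the input by the resulting per-channel weight.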
diff --git a/model/pvcnn/modules/shared_mlp.py b/model/pvcnn/modules/shared_mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..71fd6350024d5fbaa8a97adfd304013f6df5234e
--- /dev/null
+++ b/model/pvcnn/modules/shared_mlp.py
@@ -0,0 +1,38 @@
+import torch.nn as nn
+import torch
+
+__all__ = ['SharedMLP']
+
+
+class Swish(nn.Module):
+ def forward(self,x):
+ return x * torch.sigmoid(x)
+
+class SharedMLP(nn.Module):
+ def __init__(self, in_channels, out_channels, dim=1):
+ super().__init__()
+ if dim == 1: # default value
+ conv = nn.Conv1d
+ bn = nn.GroupNorm
+ elif dim == 2:
+ conv = nn.Conv2d
+ bn = nn.GroupNorm
+ else:
+ raise ValueError
+ if not isinstance(out_channels, (list, tuple)):
+ out_channels = [out_channels]
+ layers = []
+ for oc in out_channels:
+ layers.extend([
+ conv(in_channels, oc, 1),
+ bn(8, oc),
+ Swish(),
+ ])
+ in_channels = oc
+ self.layers = nn.Sequential(*layers)
+
+ def forward(self, inputs):
+ if isinstance(inputs, (list, tuple)):
+ return (self.layers(inputs[0]), *inputs[1:])
+ else:
+ return self.layers(inputs)
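+
+# Usage sketch (illustrative, not part of the original file):
+# mlp = SharedMLP(64, [128, 256]) # two 1x1 Conv1d + GroupNorm(8) + Swish blocks
+# out = mlp(torch.randn(2, 64, 1024)) # (B, 64, N) -> (B, 256, N)
+# With dim=2 the same blocks use Conv2d, e.g. for (B, C, M, K) grouped neighborhood features from ball query.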
diff --git a/model/pvcnn/modules/voxelization.py b/model/pvcnn/modules/voxelization.py
new file mode 100644
index 0000000000000000000000000000000000000000..d19597ed23d6753a670613f42d56755e7ffd7b41
--- /dev/null
+++ b/model/pvcnn/modules/voxelization.py
@@ -0,0 +1,28 @@
+import torch
+import torch.nn as nn
+
+from . import functional as F
+
+__all__ = ['Voxelization']
+
+
+class Voxelization(nn.Module):
+ def __init__(self, resolution, normalize=True, eps=0):
+ super().__init__()
+ self.r = int(resolution)
+ self.normalize = normalize
+ self.eps = eps
+
+ def forward(self, features, coords):
+ coords = coords.detach()
+ norm_coords = coords - coords.mean(2, keepdim=True)
+ if self.normalize:
+ norm_coords = norm_coords / (norm_coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + self.eps) + 0.5 # within a unit cube of size 1x1x1
+ else:
+ norm_coords = (norm_coords + 1) / 2.0
+ norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1)
+ vox_coords = torch.round(norm_coords).to(torch.int32)
+ return F.avg_voxelize(features, vox_coords, self.r), norm_coords
+
+ def extra_repr(self):
+ return 'resolution={}{}'.format(self.r, ', normalized eps = {}'.format(self.eps) if self.normalize else '')
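+
+# Usage sketch (illustrative, not part of the original file; avg_voxelize needs the compiled CUDA backend):
+# vox = Voxelization(resolution=32).cuda()
+# voxel_feats, norm_coords = vox(torch.randn(2, 64, 1024).cuda(), torch.randn(2, 3, 1024).cuda())
+# voxel_feats: (B, 64, 32, 32, 32) averaged per voxel, norm_coords: (B, 3, N) clamped to [0, 31]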
diff --git a/model/pvcnn/pos_enc.py b/model/pvcnn/pos_enc.py
new file mode 100644
index 0000000000000000000000000000000000000000..67e6070297775ff0f8f6ec2b4f3075af1f338c9e
--- /dev/null
+++ b/model/pvcnn/pos_enc.py
@@ -0,0 +1,88 @@
+"""
+positional encoding
+
+
+
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+class Embedder:
+ "adapted from https://github.com/yenchenlin/nerf-pytorch/blob/master/run_nerf_helpers.py#L48"
+ def __init__(self, **kwargs):
+ """
+ :param kwargs: embedding config with keys: input_dims, include_input, max_freq_log2,
+ num_freqs, log_sampling, periodic_fns (see get_embedder for the default values)
+ """
+ self.kwargs = kwargs
+ self.create_embedding_fn()
+
+ def create_embedding_fn(self):
+ embed_fns = []
+ d = self.kwargs['input_dims']
+ out_dim = 0
+ if self.kwargs['include_input']:
+ embed_fns.append(lambda x: x)
+ out_dim += d
+
+ max_freq = self.kwargs['max_freq_log2']
+ N_freqs = self.kwargs['num_freqs']
+
+ if self.kwargs['log_sampling']:
+ freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs)
+ else:
+ freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs)
+
+ for freq in freq_bands:
+ for p_fn in self.kwargs['periodic_fns']:
+ embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
+ out_dim += d
+
+ self.embed_fns = embed_fns
+ self.out_dim = out_dim
+
+ def embed(self, inputs):
+ """
+
+ :param inputs: (N_rays, N_samples, 3)
+ :return: (N_rays, N_samples, D)
+ """
+ return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
+
+
+def get_embedder(multires, i=0, input_dims=3):
+ if i == -1:
+ return nn.Identity(), 3
+
+ embed_kwargs = {
+ 'include_input': True,
+ 'input_dims': input_dims,
+ 'max_freq_log2': multires - 1,
+ 'num_freqs': multires,
+ 'log_sampling': True,
+ 'periodic_fns': [torch.sin, torch.cos],
+ }
+
+ embedder_obj = Embedder(**embed_kwargs)
+ embed = lambda x, eo=embedder_obj: eo.embed(x)
+ return embed, embedder_obj.out_dim
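+
+# Worked example (not in the original file): get_embedder(10, input_dims=3) returns an embedding of
+# dimension 3 + 2 * 10 * 3 = 63 (identity plus sin/cos at 10 frequencies per coordinate), which is the
+# 63-channel positional feature used for the human-object cross attention in pvcnn_ho.py.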
+
+
+def test():
+ ""
+ # x = torch.randn(10, 50, 3)
+ # embed, _ = get_embedder(10)
+ # enc = embed(x)
+ # print(enc.shape) # torch.Size([10, 50, 63])
+ # print(x[0, :2])
+ # print(enc[0, :2]) # this encoding already includes the input coordinates
+
+ embed, _ = get_embedder(15, input_dims=1)
+ enc = embed(torch.randn(1, 1, 1))
+ print(enc.shape) # (1, 1, 31) 2*multires + input_dims
+
+if __name__ == '__main__':
+ test()
\ No newline at end of file
diff --git a/model/pvcnn/pvcnn.py b/model/pvcnn/pvcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c2f18de2f7cad44217a4b5738c717b8338e92c6
--- /dev/null
+++ b/model/pvcnn/pvcnn.py
@@ -0,0 +1,174 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+from model.pvcnn.modules import Attention
+from model.pvcnn.pvcnn_utils import create_mlp_components, create_pointnet2_sa_components, create_pointnet2_fp_modules
+from model.pvcnn.pvcnn_utils import get_timestep_embedding
+
+
+class PVCNN2Base(nn.Module):
+ def __init__(
+ self,
+ num_classes: int,
+ embed_dim: int,
+ use_att: bool = True,
+ dropout: float = 0.1,
+ extra_feature_channels: int = 3,
+ width_multiplier: int = 1,
+ voxel_resolution_multiplier: int = 1
+ ):
+ super().__init__()
+ assert extra_feature_channels >= 0
+ self.embed_dim = embed_dim
+ self.dropout = dropout
+ self.width_multiplier = width_multiplier
+
+ self.in_channels = extra_feature_channels + 3
+
+ # Create PointNet-2 model
+ sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
+ sa_blocks_config=self.sa_blocks,
+ extra_feature_channels=extra_feature_channels,
+ with_se=True,
+ embed_dim=embed_dim,
+ use_att=use_att,
+ dropout=dropout,
+ width_multiplier=width_multiplier,
+ voxel_resolution_multiplier=voxel_resolution_multiplier
+ )
+ self.sa_layers = nn.ModuleList(sa_layers)
+
+ # Additional global attention module, default true
+ self.global_att = None if not use_att else Attention(channels_sa_features, 8, D=1)
+
+ # Only use extra features in the last fp module
+ sa_in_channels[0] = extra_feature_channels
+ fp_layers, channels_fp_features = create_pointnet2_fp_modules(
+ fp_blocks=self.fp_blocks,
+ in_channels=channels_sa_features,
+ sa_in_channels=sa_in_channels,
+ with_se=True,
+ embed_dim=embed_dim,
+ use_att=use_att,
+ dropout=dropout,
+ width_multiplier=width_multiplier,
+ voxel_resolution_multiplier=voxel_resolution_multiplier
+ )
+ self.fp_layers = nn.ModuleList(fp_layers)
+
+ # Create MLP layers
+ self.channels_fp_features = channels_fp_features
+ layers, _ = create_mlp_components(
+ in_channels=channels_fp_features,
+ out_channels=[128, dropout, num_classes], # was 0.5
+ classifier=True,
+ dim=2,
+ width_multiplier=width_multiplier
+ )
+ self.classifier = nn.Sequential(*layers) # applied to point features directly
+
+ # Time embedding function
+ self.embedf = nn.Sequential(
+ nn.Linear(embed_dim, embed_dim),
+ nn.LeakyReLU(0.1, inplace=True),
+ nn.Linear(embed_dim, embed_dim),
+ )
+
+ def forward(self, inputs: torch.Tensor, t: torch.Tensor, ret_feats=False):
+ """
+ The inputs have size (B, 3 + S, N), where S is the number of additional
+ feature channels and N is the number of points. The timesteps t can be either
+ continuous or discrete. The model has a U-Net-like structure: the set-abstraction (SA)
+ layers progressively downsample the point cloud and the feature-propagation (FP) layers
+ upsample it back to the input resolution, as the example shapes below illustrate:
+
+ torch.Size([16, 394, 16384])
+ Downscaling step 0 feature shape: torch.Size([16, 64, 1024])
+ Downscaling step 1 feature shape: torch.Size([16, 128, 256])
+ Downscaling step 2 feature shape: torch.Size([16, 256, 64])
+ Downscaling step 3 feature shape: torch.Size([16, 512, 16])
+ Upscaling step 0 feature shape: torch.Size([16, 256, 64])
+ Upscaling step 1 feature shape: torch.Size([16, 256, 256])
+ Upscaling step 2 feature shape: torch.Size([16, 128, 1024])
+ Upscaling step 3 feature shape: torch.Size([16, 64, 16384])
+
+ """
+
+ # Embed timesteps, sinusoidal encoding
+ t_emb = get_timestep_embedding(self.embed_dim, t, inputs.device).float()
+ t_emb = self.embedf(t_emb)[:, :, None].expand(-1, -1, inputs.shape[-1])
+
+ # Separate input coordinates and features
+ coords = inputs[:, :3, :].contiguous() # (B, 3, N) range (-3.5, 3.5)
+ features = inputs # (B, 3 + S, N)
+
+ # Downscaling layers
+ coords_list = []
+ in_features_list = []
+ for i, sa_blocks in enumerate(self.sa_layers):
+ in_features_list.append(features)
+ coords_list.append(coords)
+ if i == 0:
+ features, coords, t_emb = sa_blocks((features, coords, t_emb))
+ else:
+ features, coords, t_emb = sa_blocks((torch.cat([features, t_emb], dim=1), coords, t_emb))
+
+ # Replace the input features
+ in_features_list[0] = inputs[:, 3:, :].contiguous()
+
+ # Apply global attention layer
+ if self.global_att is not None:
+ features = self.global_att(features)
+
+ # Upscaling layers
+ feats_list = [] # save intermediate features from the decoder layers
+ for fp_idx, fp_blocks in enumerate(self.fp_layers):
+ features, coords, t_emb = fp_blocks(
+ ( # this is a tuple because of nn.Sequential
+ coords_list[-1 - fp_idx], # reverse coords list from above
+ coords, # original point coordinates
+ torch.cat([features, t_emb], dim=1), # keep concatenating upsampled features with timesteps
+ in_features_list[-1 - fp_idx], # reverse features list from above
+ t_emb # original timestep embedding
+ ) # this is where point voxel convolution is carried out, the point feature network preserves the order.
+ )
+ feats_list.append((features, coords)) # t_emb is always the same
+
+ # exit(0)
+ # Output MLP layers
+ output = self.classifier(features)
+
+ if ret_feats:
+ return output, feats_list # return intermediate features
+
+ return output
+
+
+class PVCNN2(PVCNN2Base):
+ # exact same configuration from PVD: https://github.com/alexzhou907/PVD/blob/9747265a5f141e5546fd4f862bfa66aa59f1bd33/train_completion.py#L375
+ # conv_configs, sa_configs
+ # conv_configs: (out_ch, num_blocks, voxel_reso), sa_configs: (num_centers, radius, num_neighbors, out_channels)
+ sa_blocks = [
+ ((32, 2, 32), (1024, 0.1, 32, (32, 64))), # the first is out_channels, num_blocks, voxel_resolution
+ ((64, 3, 16), (256, 0.2, 32, (64, 128))),
+ ((128, 3, 8), (64, 0.4, 32, (128, 256))),
+ (None, (16, 0.8, 32, (256, 256, 512))),
+ ]
+ fp_blocks = [
+ ((256, 256), (256, 3, 8)),
+ ((256, 256), (256, 3, 8)),
+ ((256, 128), (128, 2, 16)),
+ ((128, 128, 64), (64, 2, 32)),
+ ]
+
+ def __init__(self, num_classes, embed_dim, use_att=True, dropout=0.1, extra_feature_channels=3,
+ width_multiplier=1, voxel_resolution_multiplier=1):
+ super().__init__(
+ num_classes=num_classes, embed_dim=embed_dim, use_att=use_att,
+ dropout=dropout, extra_feature_channels=extra_feature_channels,
+ width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
+ )
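+
+ # Instantiation sketch (illustrative, not part of the original file; the channel counts are hypothetical,
+ # chosen to match the example shapes in the forward() docstring, and the forward pass needs the compiled
+ # CUDA point ops on a GPU):
+ # model = PVCNN2(num_classes=3, embed_dim=64, extra_feature_channels=391).cuda()
+ # noise_pred = model(torch.randn(2, 3 + 391, 1024).cuda(), torch.randint(0, 1000, (2,)).float().cuda())
+ # noise_pred: (B, num_classes, N), following the (B, 3 + S, N) input layout described above.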
+
+
+
+
diff --git a/model/pvcnn/pvcnn_ho.py b/model/pvcnn/pvcnn_ho.py
new file mode 100644
index 0000000000000000000000000000000000000000..00aa0e1e3947cff1379a230c8429d5cf2a271a0d
--- /dev/null
+++ b/model/pvcnn/pvcnn_ho.py
@@ -0,0 +1,416 @@
+"""
+two separate diffusion models for human+object, with cross attention to communicate between them
+"""
+import numpy as np
+import torch
+import torch.nn as nn
+from torch import Tensor
+from typing import Optional, Tuple
+import math
+
+from model.pvcnn.modules import Attention, PVConv, BallQueryHO
+from model.pvcnn.pvcnn_utils import create_mlp_components, create_pointnet2_sa_components, create_pointnet2_fp_modules
+from model.pvcnn.pvcnn_utils import get_timestep_embedding
+import torch.nn.functional as F
+from .pos_enc import get_embedder
+
+
+def _scaled_dot_product_attention(
+ q: Tensor,
+ k: Tensor,
+ v: Tensor,
+ attn_mask: Optional[Tensor] = None,
+ dropout_p: float = 0.0,
+) -> Tuple[Tensor, Tensor]:
+ r"""
+ Computes scaled dot product attention on query, key and value tensors, using
+ an optional attention mask if passed, and applying dropout if a probability
+ greater than 0.0 is specified.
+ Returns a tensor pair containing attended values and attention weights.
+
+ Args:
+ q, k, v: query, key and value tensors. See Shape section for shape details.
+ attn_mask: optional tensor containing mask values to be added to calculated
+ attention. May be 2D or 3D; see Shape section for details.
+ dropout_p: dropout probability. If greater than 0.0, dropout is applied.
+
+ Shape:
+ - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
+ and E is embedding dimension.
+ - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
+ and E is embedding dimension.
+ - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
+ and E is embedding dimension.
+ - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
+ shape :math:`(Nt, Ns)`.
+
+ - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
+ have shape :math:`(B, Nt, Ns)`
+ """
+ B, Nt, E = q.shape
+ q = q / math.sqrt(E)
+ # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
+ if attn_mask is not None:
+ attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
+ else:
+ attn = torch.bmm(q, k.transpose(-2, -1))
+
+ attn = F.softmax(attn, dim=-1)
+ if dropout_p > 0.0:
+ attn = F.dropout(attn, p=dropout_p) # note: functional dropout here is not gated by module.training
+ # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
+ output = torch.bmm(attn, v)
+ return output, attn
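+
+# Shape sanity check (illustrative, not part of the original file):
+# out, attn = _scaled_dot_product_attention(torch.randn(2, 16, 8), torch.randn(2, 32, 8), torch.randn(2, 32, 8))
+# out: (2, 16, 8), attn: (2, 16, 32)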
+
+
+class PVCNN2HumObj(nn.Module):
+ sa_blocks = [
+ # (out_channel, num_blocks, voxel_reso), (num_centers, radius, num_neighbors, out_channels)
+ ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
+ ((64, 3, 16), (256, 0.2, 32, (64, 128))),
+ ((128, 3, 8), (64, 0.4, 32, (128, 256))),
+ (None, (16, 0.8, 32, (256, 256, 512))),
+ ]
+ fp_blocks = [
+ # (out, in_channels), (out_channels, num_blocks, voxel_resolution)
+ ((256, 256), (256, 3, 8)),
+ ((256, 256), (256, 3, 8)),
+ ((256, 128), (128, 2, 16)),
+ ((128, 128, 64), (64, 2, 32)),
+ ]
+
+ def __init__(
+ self,
+ num_classes: int,
+ embed_dim: int,
+ use_att: bool = True,
+ dropout: float = 0.1,
+ extra_feature_channels: int = 3,
+ width_multiplier: int = 1,
+ voxel_resolution_multiplier: int = 1,
+ attn_type: str = 'simple-cross', # cross-attention variant; the constructor currently asserts 'coord3d+posenc-learnable'
+ attn_weight: float=1.0, # attention feature weight
+ multires: int = 10, # positional encoding resolution
+ num_neighbours: int = 32 # ball query neighbours
+ ):
+ super(PVCNN2HumObj, self).__init__()
+ assert extra_feature_channels >= 0
+ self.embed_dim = embed_dim
+ self.dropout = dropout
+ self.width_multiplier = width_multiplier
+ self.num_neighbours = num_neighbours
+ self.in_channels = extra_feature_channels + 3
+
+ self.attn_type = attn_type # how to compute attention
+ self.attn_weight = attn_weight
+
+ # separate human/object model
+ classifier, embedf, fp_layers, global_att, sa_layers = self.make_modules(dropout, embed_dim,
+ extra_feature_channels, num_classes,
+ use_att, voxel_resolution_multiplier,
+ width_multiplier)
+
+ self.sa_layers_hum = sa_layers
+ self.global_att_hum = global_att
+ self.fp_layers_hum = fp_layers
+ self.classifier_hum = classifier
+ self.embedf_hum = embedf
+ self.posi_encoder, _ = get_embedder(multires)
+
+ classifier, embedf, fp_layers, global_att, sa_layers = self.make_modules(dropout, embed_dim,
+ extra_feature_channels, num_classes,
+ use_att, voxel_resolution_multiplier,
+ width_multiplier)
+
+ self.sa_layers_obj = sa_layers
+ self.global_att_obj = global_att
+ self.fp_layers_obj = fp_layers
+ self.classifier_obj = classifier
+ self.embedf_obj = embedf
+
+ self.make_coord_attn()
+ assert self.attn_type == 'coord3d+posenc-learnable', f'unknown attention type {self.attn_type}'
+
+ def make_modules(self, dropout, embed_dim, extra_feature_channels, num_classes, use_att,
+ voxel_resolution_multiplier, width_multiplier):
+ """
+ make module for human/object
+ :param dropout:
+ :param embed_dim:
+ :param extra_feature_channels:
+ :param num_classes:
+ :param use_att:
+ :param voxel_resolution_multiplier:
+ :param width_multiplier:
+ :return:
+ """
+ in_ch_multiplier = 1
+ extra_in_channel = 63 # the segmentation+positional feature is projected to dim 63
+
+ # Create PointNet-2 model
+ sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
+ sa_blocks_config=self.sa_blocks,
+ extra_feature_channels=extra_feature_channels,
+ with_se=True,
+ embed_dim=embed_dim,
+ use_att=use_att,
+ dropout=dropout,
+ width_multiplier=width_multiplier,
+ voxel_resolution_multiplier=voxel_resolution_multiplier,
+ in_ch_multiplier=in_ch_multiplier,
+ extra_in_channel=extra_in_channel
+ )
+ sa_layers = nn.ModuleList(sa_layers)
+ # Additional global attention module, default true
+ if self.attn_type == 'coord3d+posenc+rgb':
+ # reduce channel number, only for the global attention layer, the decoders remain unchanged
+ global_att = None if not use_att else Attention(channels_sa_features//2, 8, D=1)
+ else:
+ global_att = None if not use_att else Attention(channels_sa_features, 8, D=1)
+
+ # Only use extra features in the last fp module
+ sa_in_channels[0] = extra_feature_channels
+ fp_layers, channels_fp_features = create_pointnet2_fp_modules(
+ fp_blocks=self.fp_blocks,
+ in_channels=channels_sa_features,
+ sa_in_channels=sa_in_channels,
+ with_se=True,
+ embed_dim=embed_dim,
+ use_att=use_att,
+ dropout=dropout,
+ width_multiplier=width_multiplier,
+ voxel_resolution_multiplier=voxel_resolution_multiplier,
+ in_ch_multiplier=in_ch_multiplier,
+ extra_in_channel=extra_in_channel
+ )
+ fp_layers = nn.ModuleList(fp_layers)
+
+ # Create MLP layers for output prediction
+ layers, _ = create_mlp_components(
+ in_channels=channels_fp_features,
+ out_channels=[128, dropout, num_classes], # was 0.5
+ classifier=True,
+ dim=2,
+ width_multiplier=width_multiplier
+ )
+ classifier = nn.Sequential(*layers) # applied to point features directly
+ # Time embedding function
+ embedf = nn.Sequential(
+ nn.Linear(embed_dim, embed_dim),
+ nn.LeakyReLU(0.1, inplace=True),
+ nn.Linear(embed_dim, embed_dim),
+ )
+ return classifier, embedf, fp_layers, global_att, sa_layers
+
+ def make_coord_attn(self):
+ "learnable attention only on point coordinate + positional encoding "
+ pvconv_encoders = []
+ for i, (conv_configs, sa_configs) in enumerate(self.sa_blocks):
+ # should use point net out channel
+ out_channel = 63
+ layer = nn.MultiheadAttention(out_channel, 1, batch_first=True, kdim=out_channel, vdim=out_channel+2)
+ pvconv_encoders.append(layer) # only one block for conv
+ pvconv_decoders = []
+ for fp_configs, conv_configs in self.fp_blocks:
+ out_channel = 63
+ layer = nn.MultiheadAttention(out_channel, 1, batch_first=True, kdim=out_channel, vdim=out_channel + 2)
+ pvconv_decoders.append(layer)
+ self.cross_conv_encoders = nn.ModuleList(pvconv_encoders)
+ self.cross_conv_decoders = nn.ModuleList(pvconv_decoders)
+
+ def forward(self, inputs_hum: torch.Tensor, inputs_obj: torch.Tensor, t: torch.Tensor, norm_params=None):
+ """
+ :param inputs_hum: (B, N, D) human points, where N is the number of points and D = 3 coordinates + conditional feature channels
+ :param inputs_obj: (B, N, D) object points, same layout as inputs_hum
+ :param t: (B, ) diffusion timesteps
+ :param norm_params: (2, B, 4), transformation parameters that move points back to the H+O joint space; the first 3 values are the center, the last is the radius/scale
+ :return: (B, N, D_out) x2, predictions for the human and the object branch
+ """
+ inputs_hum = inputs_hum.transpose(1, 2)
+ inputs_obj = inputs_obj.transpose(1, 2)
+
+ # Embed timesteps, sinusoidal encoding
+ t_emb_init = get_timestep_embedding(self.embed_dim, t, inputs_hum.device).float()
+ t_emb_hum = self.embedf_hum(t_emb_init)[:, :, None].expand(-1, -1, inputs_hum.shape[-1]).float()
+ t_emb_obj = self.embedf_obj(t_emb_init)[:, :, None].expand(-1, -1, inputs_obj.shape[-1]).float()
+
+ # Separate input coordinates and features
+ coords_hum, coords_obj = inputs_hum[:, :3, :].contiguous(), inputs_obj[:, :3, :].contiguous() # (B, 3, N) range (-3.5, 3.5)
+ features_hum, features_obj = inputs_hum, inputs_obj # (B, 3 + S, N)
+
+ DEBUG = False
+
+ # Encoder: Downscaling layers
+ coords_list_hum, coords_list_obj = [], []
+ in_features_list_hum, in_features_list_obj = [], []
+ for i, (sa_blocks_h, sa_blocks_o) in enumerate(zip(self.sa_layers_hum, self.sa_layers_obj)):
+ in_features_list_hum.append(features_hum)
+ coords_list_hum.append(coords_hum)
+ in_features_list_obj.append(features_obj)
+ coords_list_obj.append(coords_obj)
+ if i == 0:
+ # First step no timestamp embedding
+ features_hum, coords_hum, t_emb_hum = sa_blocks_h((features_hum, coords_hum, t_emb_hum))
+ features_obj, coords_obj, t_emb_obj = sa_blocks_o((features_obj, coords_obj, t_emb_obj))
+ else:
+ features_hum, coords_hum, t_emb_hum = sa_blocks_h((torch.cat([features_hum, t_emb_hum], dim=1), coords_hum, t_emb_hum))
+ features_obj, coords_obj, t_emb_obj = sa_blocks_o((torch.cat([features_obj, t_emb_obj], dim=1), coords_obj, t_emb_obj))
+
+ if i < len(self.sa_layers_hum)-1:
+ features_hum, features_obj = self.add_attn_feature(features_hum, features_obj,
+ self.transform_coords(coords_hum, norm_params, 0),
+ self.transform_coords(coords_obj, norm_params, 1),
+ self.cross_conv_encoders[i],
+ temb_hum=t_emb_hum,
+ temb_obj=t_emb_obj)
+
+ # for debug: save some point clouds
+ if DEBUG:
+ for i, (ch, co) in enumerate(zip(coords_list_hum, coords_list_obj)):
+ import trimesh
+ ch_ho = self.transform_coords(ch, norm_params, 0)
+ co_ho = self.transform_coords(co, norm_params, 1)
+ points = torch.cat([ch_ho, co_ho], -1).transpose(1, 2)
+ L = ch_ho.shape[-1]
+ vc = np.concatenate(
+ [np.zeros((L, 3)) + np.array([0.5, 1.0, 0]),
+ np.zeros((L, 3)) + np.array([0.05, 1.0, 1.0])]
+ )
+ trimesh.PointCloud(points[0].cpu().numpy(), colors=vc).export(
+ f'/BS/xxie-2/work/pc2-diff/experiments/debug/meshes/encoder_step{i:02d}.ply')
+
+ # Replace the input features
+ in_features_list_hum[0] = inputs_hum[:, 3:, :].contiguous()
+ in_features_list_obj[0] = inputs_obj[:, 3:, :].contiguous()
+
+ # Apply global attention layer
+ if self.global_att_hum is not None:
+ features_hum = self.global_att_hum(features_hum)
+
+ if self.global_att_obj is not None:
+ features_obj = self.global_att_obj(features_obj)
+ # Do cross attention after self-attention
+ if self.attn_type in ['coord3d+posenc-learnable']:
+ features_hum, features_obj = self.add_attn_feature(features_hum, features_obj,
+ self.transform_coords(coords_hum, norm_params, 0),
+ self.transform_coords(coords_obj, norm_params, 1),
+ self.cross_conv_encoders[-1] if self.attn_type in [
+ 'coord3d+posenc-learnable'] else None,
+ temb_hum=t_emb_hum,
+ temb_obj=t_emb_obj)
+
+ # Upscaling layers
+ for fp_idx, (fp_blocks_h, fp_blocks_o) in enumerate(zip(self.fp_layers_hum, self.fp_layers_obj)):
+ features_hum, coords_hum, t_emb_hum = fp_blocks_h(
+ ( # this is a tuple because of nn.Sequential
+ coords_list_hum[-1 - fp_idx], # reverse coords list from above
+ coords_hum, # original point coordinates
+ torch.cat([features_hum, t_emb_hum], dim=1), # keep concatenating upsampled features with timesteps
+ in_features_list_hum[-1 - fp_idx], # reverse features list from above
+ t_emb_hum # original timestep embedding
+ )
+ # this is where point voxel convolution is carried out, the point feature network preserves the order.
+ )
+ features_obj, coords_obj, t_emb_obj = fp_blocks_o(
+ ( # this is a tuple because of nn.Sequential
+ coords_list_obj[-1 - fp_idx], # reverse coords list from above
+ coords_obj, # original point coordinates
+ torch.cat([features_obj, t_emb_obj], dim=1), # keep concatenating upsampled features with timesteps
+ in_features_list_obj[-1 - fp_idx], # reverse features list from above
+ t_emb_obj # original timestep embedding
+ )
+ # this is where point voxel convolution is carried out, the point feature network preserves the order.
+ )
+
+ # these features are reused as input for next layer
+ # add attention except for the last layer
+ if fp_idx < len(self.fp_layers_hum) - 1:
+ # Perform cross attention between human and object branches
+ features_hum, features_obj = self.add_attn_feature(features_hum, features_obj,
+ self.transform_coords(coords_hum, norm_params, 0),
+ self.transform_coords(coords_obj, norm_params, 1),
+ self.cross_conv_decoders[fp_idx] if self.attn_type in ['coord3d+posenc-learnable'] else None,
+ temb_hum=t_emb_hum,
+ temb_obj=t_emb_obj
+ )
+
+ if DEBUG:
+ import trimesh
+ ch_ho = self.transform_coords(coords_hum, norm_params, 0)
+ co_ho = self.transform_coords(coords_obj, norm_params, 1)
+ points = torch.cat([ch_ho, co_ho], -1).transpose(1, 2)
+ L = ch_ho.shape[-1]
+ vc = np.concatenate(
+ [np.zeros((L, 3)) + np.array([0.5, 1.0, 0]),
+ np.zeros((L, 3)) + np.array([0.05, 1.0, 1.0])]
+ )
+ trimesh.PointCloud(points[0].cpu().numpy(), colors=vc).export(
+ f'/BS/xxie-2/work/pc2-diff/experiments/debug/meshes/decoder_step{fp_idx:02d}.ply')
+
+ if DEBUG:
+ exit(0)
+ # Output MLP layers
+ output_hum = self.classifier_hum(features_hum).transpose(1, 2) # convert back to (B, N, D) format
+ output_obj = self.classifier_obj(features_obj).transpose(1, 2)
+
+ return output_hum, output_obj
+
+ def transform_coords(self, coords, norm_params, target_ind):
+ """
+ transform coordinates such that the points align back to H+O interaction space
+ :param coords: (B, 3, N)
+ :param norm_params: (2, B, 4)
+ :param target_ind: 0 or 1
+ :return:
+ """
+ scale = norm_params[target_ind, :, 3:].unsqueeze(1)
+ cent = norm_params[target_ind, :, :3].unsqueeze(-1)
+ coords_ho = coords * 2 * scale + cent
+ return coords_ho
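+
+ # Worked example (illustrative, not part of the original file): with norm_params[target_ind, b] = [cx, cy, cz, s],
+ # a normalized point p is mapped back to the joint human+object frame as p * 2 * s + (cx, cy, cz).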
+
+ def add_attn_feature(self, features_hum, features_obj,
+ coords_hum=None, coords_obj=None,
+ attn_module=None,
+ temb_hum=None, temb_obj=None):
+ """
+ compute cross attention between human and object points
+ :param features_hum: (B, D, N)
+ :param features_obj: (B, D, N)
+ :param coords_hum: (B, 3, N), human points in the H+O frame
+ :param coords_obj: (B, 3, N), object points in the H+O frame
+ :param attn_module: learnable nn.MultiheadAttention applied in both directions
+ :param temb_hum, temb_obj: timestep embeddings (currently unused here, kept for interface consistency)
+ :return: cross-attended human and object features, each with the attention output concatenated along the channel dimension
+ """
+ B, D, N = features_hum.shape
+ # attn_module is a learnable cross-attention layer: queries/keys are positional encodings, values also carry the one-hot human/object identity
+ onehot_hum, onehot_obj = self.get_onehot_feat(features_hum)
+ pos_hum = self.posi_encoder(coords_hum.permute(0, 2, 1)).permute(0, 2, 1)
+ pos_obj = self.posi_encoder(coords_obj.permute(0, 2, 1)).permute(0, 2, 1)
+ feat_hum = torch.cat([pos_hum, onehot_hum], 1)
+ feat_obj = torch.cat([pos_obj, onehot_obj], 1) # (B, 65, N)
+
+ attn_h2o = attn_module(pos_obj.permute(0, 2, 1),
+ pos_hum.permute(0, 2, 1),
+ feat_hum.permute(0, 2, 1))[0].permute(0, 2, 1)
+ attn_o2h = attn_module(pos_hum.permute(0, 2, 1),
+ pos_obj.permute(0, 2, 1),
+ feat_obj.permute(0, 2, 1))[0].permute(0, 2, 1)
+ features_hum = torch.cat([features_hum, attn_o2h * self.attn_weight], 1)
+ features_obj = torch.cat([features_obj, attn_h2o * self.attn_weight], 1)
+
+ return features_hum, features_obj
+
+ def get_onehot_feat(self, features_hum):
+ """
+ compute one-hot feature vectors identifying whether a point belongs to the human or the object
+ :param features_hum: (B, D, N), only used to read the batch size, number of points and device
+ :return: (B, 2, N) x2, one-hot features for human and object
+ """
+ B, D, N = features_hum.shape
+ onehot_hum = torch.zeros(B, 2, N).to(features_hum.device)
+ onehot_hum[:, 0] = 1.
+ onehot_obj = torch.zeros(B, 2, N).to(features_hum.device)
+ onehot_obj[:, 1] = 1.0
+ return onehot_hum, onehot_obj
+
+
diff --git a/model/pvcnn/pvcnn_plus_plus.py b/model/pvcnn/pvcnn_plus_plus.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bdc235f29c43f41f0a43ef47abc0ba683a1a46a
--- /dev/null
+++ b/model/pvcnn/pvcnn_plus_plus.py
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+
+from model.pvcnn.pvcnn import PVCNN2
+from model.pvcnn.pvcnn_utils import create_mlp_components
+from model.simple.simple_model import SimplePointModel
+
+
+class PVCNN2PlusPlus(nn.Module):
+ def __init__(
+ self,
+ *,
+ embed_dim,
+ num_classes,
+ extra_feature_channels,
+ ):
+ super().__init__()
+
+ # Create models
+ self.simple_point_model = SimplePointModel(num_classes=embed_dim, embed_dim=embed_dim,
+ extra_feature_channels=extra_feature_channels, num_layers=3)
+ self.pvcnn = PVCNN2(num_classes=embed_dim, embed_dim=embed_dim,
+ extra_feature_channels=(embed_dim - 3))
+
+ # Tie timestep embeddings
+ self.pvcnn.embedf = self.simple_point_model.timestep_projection
+
+ # # Remove output projections
+ # self.pvcnn.classifier = nn.Identity()
+ # self.simple_point_model.output_projection = nn.Identity()
+
+ # Create new output projection
+ layers, _ = create_mlp_components(
+ in_channels=embed_dim, out_channels=[128, self.pvcnn.dropout, num_classes],
+ classifier=True, dim=2, width_multiplier=self.pvcnn.width_multiplier)
+ self.output_projection = nn.Sequential(*layers)
+
+ def forward(self, inputs: torch.Tensor, t: torch.Tensor):
+ x = self.simple_point_model(inputs, t) # (B, D_emb, N)
+ x = x + self.pvcnn(x, t) # (B, D_emb, N)
+ x = self.output_projection(x) # (B, D_out, N)
+ return x
diff --git a/model/pvcnn/pvcnn_utils.py b/model/pvcnn/pvcnn_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..81533821fba921f8f8e6a1bedeed62cc88963cb0
--- /dev/null
+++ b/model/pvcnn/pvcnn_utils.py
@@ -0,0 +1,216 @@
+import functools
+import torch
+import torch.nn as nn
+import numpy as np
+
+from model.pvcnn.modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Swish
+
+
+def _linear_gn_relu(in_channels, out_channels):
+ return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish())
+
+
+def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
+ r = width_multiplier
+
+ if dim == 1:
+ block = _linear_gn_relu
+ else:
+ block = SharedMLP
+ if not isinstance(out_channels, (list, tuple)):
+ out_channels = [out_channels]
+ if len(out_channels) == 0 or (len(out_channels) == 1 and out_channels[0] is None):
+ return nn.Sequential(), in_channels, in_channels
+
+ layers = []
+ for oc in out_channels[:-1]:
+ if oc < 1:
+ layers.append(nn.Dropout(oc))
+ else:
+ oc = int(r * oc)
+ layers.append(block(in_channels, oc))
+ in_channels = oc
+ if dim == 1:
+ if classifier:
+ layers.append(nn.Linear(in_channels, out_channels[-1]))
+ else:
+ layers.append(_linear_gn_relu(in_channels, int(r * out_channels[-1])))
+ else:
+ if classifier:
+ layers.append(nn.Conv1d(in_channels, out_channels[-1], 1))
+ else:
+ layers.append(SharedMLP(in_channels, int(r * out_channels[-1])))
+ return layers, out_channels[-1] if classifier else int(r * out_channels[-1])
+
+
+def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0,
+ width_multiplier=1, voxel_resolution_multiplier=1):
+ r, vr = width_multiplier, voxel_resolution_multiplier
+
+ layers, concat_channels = [], 0
+ c = 0
+ for k, (out_channels, num_blocks, voxel_resolution) in enumerate(blocks):
+ out_channels = int(r * out_channels)
+ for p in range(num_blocks):
+ attention = k % 2 == 0 and k > 0 and p == 0
+ if voxel_resolution is None:
+ block = SharedMLP
+ else:
+ block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
+ with_se=with_se, normalize=normalize, eps=eps)
+
+ if c == 0:
+ layers.append(block(in_channels, out_channels))
+ else:
+ layers.append(block(in_channels+embed_dim, out_channels))
+ in_channels = out_channels
+ concat_channels += out_channels
+ c += 1
+ return layers, in_channels, concat_channels
+
+
+def create_pointnet2_sa_components(sa_blocks_config, extra_feature_channels, embed_dim=64, use_att=False,
+ dropout=0.1, with_se=False, normalize=True, eps=0,
+ width_multiplier=1, voxel_resolution_multiplier=1,
+ in_ch_multiplier=1,
+ extra_in_channel=0):
+ "use_att is True by default, in_ch_multiplier: increase the input channel dimension"
+ r, vr = width_multiplier, voxel_resolution_multiplier
+ in_channels = extra_feature_channels + 3
+
+ sa_layers, sa_in_channels = [], []
+ block_count = 0
+ for conv_configs, sa_configs in sa_blocks_config:
+ k = 0
+ sa_in_channels.append(in_channels)
+ sa_blocks = []
+
+ if conv_configs is not None:
+ out_channels, num_blocks, voxel_resolution = conv_configs
+ out_channels = int(r * out_channels)
+ for p in range(num_blocks): # pconv is repeated
+ attention = (block_count+1) % 2 == 0 and use_att and p == 0
+ if voxel_resolution is None:
+ block = SharedMLP
+ else:
+ block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
+ dropout=dropout,
+ with_se=with_se, with_se_relu=True,
+ normalize=normalize, eps=eps)
+
+ if block_count == 0:
+ sa_blocks.append(block(in_channels, out_channels))
+ elif k ==0:
+ sa_blocks.append(block(in_channels+embed_dim, out_channels))
+ in_channels = out_channels
+ k += 1
+ extra_feature_channels = in_channels
+ num_centers, radius, num_neighbors, out_channels = sa_configs
+ _out_channels = []
+ for oc in out_channels:
+ if isinstance(oc, (list, tuple)):
+ _out_channels.append([int(r * _oc) for _oc in oc])
+ else:
+ _out_channels.append(int(r * oc))
+ out_channels = _out_channels
+ if num_centers is None:
+ block = PointNetAModule # always not-none
+ else:
+ block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
+ num_neighbors=num_neighbors)
+ sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim if k==0 else 0 ), out_channels=out_channels,
+ include_coordinates=True))
+ block_count += 1
+ # XH: double the channel for concat, or add additional channel for cross attention
+ if block_count < len(sa_blocks_config):
+ in_channels = extra_feature_channels = int(sa_blocks[-1].out_channels * in_ch_multiplier + extra_in_channel)
+ else:
+ # no cross attention before the self attention module
+ in_channels = extra_feature_channels = int(sa_blocks[-1].out_channels * in_ch_multiplier)
+ if len(sa_blocks) == 1:
+ sa_layers.append(sa_blocks[0]) # first pconv is repeated ?
+ else:
+ sa_layers.append(nn.Sequential(*sa_blocks))
+
+ return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
+
+
+def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False,
+ dropout=0.1,
+ with_se=False, normalize=True, eps=0,
+ width_multiplier=1, voxel_resolution_multiplier=1,
+ in_ch_multiplier=1, extra_in_channel=0):
+ """
+
+ :param fp_blocks:
+ :param in_channels:
+ :param sa_in_channels:
+ :param embed_dim:
+ :param use_att:
+ :param dropout:
+ :param with_se:
+ :param normalize:
+ :param eps:
+ :param width_multiplier:
+ :param voxel_resolution_multiplier:
+ :param in_ch_multiplier: increase the input channel dimension
+ :return:
+ """
+ r, vr = width_multiplier, voxel_resolution_multiplier
+
+ fp_layers = []
+ c = 0
+ for fp_idx, (fp_configs, conv_configs) in enumerate(fp_blocks):
+ fp_blocks = []
+ out_channels = tuple(int(r * oc) for oc in fp_configs)
+ if fp_idx > 0:
+ # to handle additional channel from concatenating human + object features
+ sa_in_concat = int(in_channels*in_ch_multiplier + extra_in_channel)
+ else:
+ sa_in_concat = in_channels + extra_in_channel # this is for simple-coord3d, where the decoder first layer also has cross attention
+ fp_blocks.append(
+ PointNetFPModule(in_channels=sa_in_concat + sa_in_channels[-1 - fp_idx] + embed_dim,
+ out_channels=out_channels)
+ ) # interpolate + Conv1d, does not change number of points
+ in_channels = out_channels[-1]
+
+ if conv_configs is not None:
+ out_channels, num_blocks, voxel_resolution = conv_configs
+ out_channels = int(r * out_channels)
+ for p in range(num_blocks):
+ attention = (c+1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0
+ if voxel_resolution is None:
+ block = SharedMLP
+ else:
+ block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
+ dropout=dropout,
+ with_se=with_se, with_se_relu=True,
+ normalize=normalize, eps=eps)
+
+ fp_blocks.append(block(in_channels, out_channels))
+ in_channels = out_channels # this should not change!
+ if len(fp_blocks) == 1:
+ fp_layers.append(fp_blocks[0]) # this is the last block, no PVConv layer
+ else:
+ fp_layers.append(nn.Sequential(*fp_blocks))
+
+ c += 1
+
+ return fp_layers, in_channels
+
+
+def get_timestep_embedding(embed_dim, timesteps, device):
+ """
+ Timestep embedding function. Note that this should work just as well for
+ continuous values as for discrete values.
+ """
+ assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32
+ half_dim = embed_dim // 2
+ emb = np.log(10000) / (half_dim - 1)
+ emb = torch.from_numpy(np.exp(np.arange(0, half_dim) * -emb)).float().to(device)
+ emb = timesteps[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embed_dim % 2 == 1: # zero pad
+ emb = nn.functional.pad(emb, (0, 1), "constant", 0)
+ assert emb.shape == torch.Size([timesteps.shape[0], embed_dim])
+ return emb
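+
+# Usage sketch (illustrative, not part of the original file):
+# t = torch.randint(0, 1000, (16,)).float()
+# emb = get_timestep_embedding(64, t, 'cpu') # (16, 64): sin half / cos half per log-spaced frequency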
diff --git a/model/simple/__init__.py b/model/simple/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/simple/simple_model.py b/model/simple/simple_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ceb25a5cc0dae72298bceed7739050891810ce35
--- /dev/null
+++ b/model/simple/simple_model.py
@@ -0,0 +1,80 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import nn
+
+from .simple_model_utils import FeedForward, BasePointModel
+
+
+class SimplePointModel(BasePointModel):
+ """
+ A simple model that processes a point cloud by applying a series of MLPs to each point
+ individually, along with some pooled global features.
+ """
+
+ def get_layers(self):
+ return nn.ModuleList([FeedForward(
+ d_in=(3 * self.dim), d_hidden=(4 * self.dim), d_out=self.dim,
+ activation=nn.SiLU(), is_gated=True, bias1=False, bias2=False, bias_gate=False, use_layernorm=True
+ ) for _ in range(self.num_layers)])
+
+ def forward(self, inputs: torch.Tensor, t: torch.Tensor):
+
+ # Prepare inputs
+ x, coords = self.prepare_inputs(inputs, t)
+
+ # Model
+ for layer in self.layers:
+ x_pool_max, x_pool_std = self.get_global_tensors(x)
+ x_input = torch.cat((x, x_pool_max, x_pool_std), dim=-1) # (B, N, 3 * D)
+ x = x + layer(x_input) # (B, N, D_model)
+
+ # Project
+ x = self.output_projection(x) # (B, N, D_out)
+ x = torch.transpose(x, -2, -1) # -> (B, D_out, N)
+
+ return x
+
+
+class SimpleNearestNeighborsPointModel(BasePointModel):
+ """
+ A simple model that processes a point cloud by applying a series of MLPs to each point
+ individually, along with some pooled global features, and the features of its nearest
+ neighbors.
+ """
+
+ def __init__(self, num_neighbors: int = 4, **kwargs):
+ self.num_neighbors = num_neighbors
+ super().__init__(**kwargs)
+ from pytorch3d.ops import knn_points
+ self.knn_points = knn_points
+
+ def get_layers(self):
+ return nn.ModuleList([FeedForward(
+ d_in=((3 + self.num_neighbors) * self.dim), d_hidden=(4 * self.dim), d_out=self.dim,
+ activation=nn.SiLU(), is_gated=True, bias1=False, bias2=False, bias_gate=False, use_layernorm=True
+ ) for _ in range(self.num_layers)])
+
+ def forward(self, inputs: torch.Tensor, t: torch.Tensor):
+
+ # Prepare inputs
+ x, coords = self.prepare_inputs(inputs, t) # (B, N, D), (B, N, 3)
+
+ # Get nearest neighbors. Note that the first neighbor is the identity, which is convenient
+ _dists, indices, _neighbors = self.knn_points(
+ p1=coords, p2=coords, K=(self.num_neighbors + 1),
+ return_nn=False) # (B, N, K), (B, N, K)
+ (B, N, D), (_B, _N, K) = x.shape, indices.shape
+
+ # Model
+ for layer in self.layers:
+ x_neighbor = torch.stack([x_i[idx] for x_i, idx in zip(x, indices.reshape(B, N * K))]).reshape(B, N, K * D)
+ x_pool_max, x_pool_std = self.get_global_tensors(x)
+ x_input = torch.cat((x_neighbor, x_pool_max, x_pool_std), dim=-1) # (B, N, (3+K)*D)
+ x = x + layer(x_input) # (B, N, D_model)
+
+ # Project
+ x = self.output_projection(x) # (B, N, D_out)
+ x = torch.transpose(x, -2, -1) # -> (B, D_out, N)
+
+ return x
diff --git a/model/simple/simple_model_utils.py b/model/simple/simple_model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd84cd1f2f2a235715b98a74c428dcf6d2e36104
--- /dev/null
+++ b/model/simple/simple_model_utils.py
@@ -0,0 +1,282 @@
+from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
+
+import torch
+import torch.jit as jit
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Size, Tensor, nn
+from torch.nn import LayerNorm
+
+from model.pvcnn.pvcnn_utils import get_timestep_embedding
+
+
+def sample_b(size: Size, sigma: float) -> Tensor:
+ """Sample b matrix for fourier features
+
+ Arguments:
+ size (Size): b matrix size
+ sigma (float): std of the gaussian
+
+ Returns:
+ b (Tensor): b matrix
+ """
+ return torch.randn(size) * sigma
+
+
+@jit.script
+def map_positional_encoding(v: Tensor, freq_bands: Tensor) -> Tensor:
+ """Map v to positional encoding representation phi(v)
+
+ Arguments:
+ v (Tensor): input features (B, IFeatures)
+ freq_bands (Tensor): frequency bands (N_freqs, )
+
+ Returns:
+ phi(v) (Tensor): fourier features (B, 3 + (2 * N_freqs) * 3)
+ """
+ pe = [v]
+ for freq in freq_bands:
+ fv = freq * v
+ pe += [torch.sin(fv), torch.cos(fv)]
+ return torch.cat(pe, dim=-1)
+
+
+@jit.script
+def map_fourier_features(v: Tensor, b: Tensor) -> Tensor:
+ """Map v to fourier features representation phi(v)
+
+ Arguments:
+ v (Tensor): input features (B, IFeatures)
+ b (Tensor): b matrix (OFeatures, IFeatures)
+
+ Returns:
+ phi(v) (Tensor): fourier features (B, 2 * Features)
+ """
+ PI = 3.141592653589793
+ a = 2 * PI * v @ b.T
+ return torch.cat((torch.sin(a), torch.cos(a)), dim=-1)
+
+
+class FeatureMapping(nn.Module):
+ """FeatureMapping nn.Module
+
+ Maps v to features following transformation phi(v)
+
+ Arguments:
+ i_dim (int): input dimensions
+ o_dim (int): output dimensions
+ """
+
+ def __init__(self, i_dim: int, o_dim: int) -> None:
+ super().__init__()
+ self.i_dim = i_dim
+ self.o_dim = o_dim
+
+ def forward(self, v: Tensor) -> Tensor:
+ """FeratureMapping forward pass
+
+ Arguments:
+ v (Tensor): input features (B, IFeatures)
+
+ Returns:
+ phi(v) (Tensor): mapped features (B, OFeatures)
+ """
+ raise NotImplementedError("Forward pass not implemented yet!")
+
+
+class PositionalEncoding(FeatureMapping):
+ """PositionalEncoding module
+
+ Maps v to positional encoding representation phi(v)
+
+ Arguments:
+ i_dim (int): input dimension for v
+ N_freqs (int): #frequency to sample (default: 10)
+ """
+
+ def __init__(
+ self,
+ i_dim: int,
+ N_freqs: int = 10,
+ ) -> None:
+ super().__init__(i_dim, 3 + (2 * N_freqs) * 3)
+ self.N_freqs = N_freqs
+
+ a, b = 1, self.N_freqs - 1
+ freq_bands = 2 ** torch.linspace(a, b, self.N_freqs)
+ self.register_buffer("freq_bands", freq_bands)
+
+ def forward(self, v: Tensor) -> Tensor:
+ """Map v to positional encoding representation phi(v)
+
+ Arguments:
+ v (Tensor): input features (B, IFeatures)
+
+ Returns:
+ phi(v) (Tensor): fourier features (B, 3 + (2 * N_freqs) * 3)
+ """
+ return map_positional_encoding(v, self.freq_bands)
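+
+ # Worked example (illustrative, not part of the original file): PositionalEncoding(i_dim=3, N_freqs=10)
+ # maps (..., 3) coordinates to (..., 3 + 2 * 10 * 3) = (..., 63) features: the input itself plus
+ # sin/cos at 10 frequencies per coordinate.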
+
+
+class FourierFeatures(FeatureMapping):
+
+ """Fourier Features module
+
+ Maps v to fourier features representation phi(v)
+
+ Arguments:
+ i_dim (int): input dimension for v
+ features (int): output dimension (default: 256)
+ sigma (float): std of the gaussian (default: 26.)
+ """
+
+ def __init__(
+ self,
+ i_dim: int,
+ features: int = 256,
+ sigma: float = 26.,
+ ) -> None:
+ super().__init__(i_dim, 2 * features)
+ self.features = features
+ self.sigma = sigma
+
+ self.size = Size((self.features, self.i_dim))
+ self.register_buffer("b", sample_b(self.size, self.sigma))
+
+ def forward(self, v: Tensor) -> Tensor:
+ """Map v to fourier features representation phi(v)
+
+ Arguments:
+ v (Tensor): input features (B, IFeatures)
+
+ Returns:
+ phi(v) (Tensor): fourier features (B, 2 * Features)
+ """
+ return map_fourier_features(v, self.b)
+
+
+class FeedForward(nn.Module):
+ """ Adapted from the FeedForward layer from labmlai """
+
+ def __init__(
+ self,
+ d_in: int,
+ d_hidden: int,
+ d_out: int,
+ activation: Callable = nn.ReLU(),
+ is_gated: bool = False,
+ bias1: bool = True,
+ bias2: bool = True,
+ bias_gate: bool = True,
+ dropout: float = 0.1,
+ use_layernorm: bool = False,
+ ):
+ super().__init__()
+ # Layer one parameterized by weight $W_1$ and bias $b_1$
+ self.layer1 = nn.Linear(d_in, d_hidden, bias=bias1)
+ # Layer two parameterized by weight $W_2$ and bias $b_2$
+ self.layer2 = nn.Linear(d_hidden, d_out, bias=bias2)
+ # Hidden layer dropout
+ self.dropout = nn.Dropout(dropout)
+ # Activation function $f$
+ self.activation = activation
+ # Whether there is a gate
+ self.is_gated = is_gated
+ if is_gated:
+ # If there is a gate the linear layer to transform inputs to
+ # be multiplied by the gate, parameterized by weight $V$ and bias $c$
+ self.linear_v = nn.Linear(d_in, d_hidden, bias=bias_gate)
+ # Whether to add a layernorm layer
+ self.use_layernorm = use_layernorm
+ if use_layernorm:
+ self.layernorm = LayerNorm(d_in)
+
+ def forward(self, x: Tensor, coords: Tensor = None) -> Tensor:
+ """Applies a simple feed forward layer"""
+ x = self.layernorm(x) if self.use_layernorm else x
+ g = self.activation(self.layer1(x))
+ x = (g * self.linear_v(x)) if self.is_gated else g
+ x = self.dropout(x)
+ x = self.layer2(x)
+ return x
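+
+ # Usage sketch (illustrative, not part of the original file), matching how SimplePointModel builds its layers:
+ # ff = FeedForward(d_in=384, d_hidden=512, d_out=128, activation=nn.SiLU(), is_gated=True,
+ # bias1=False, bias2=False, bias_gate=False, use_layernorm=True)
+ # y = ff(torch.randn(2, 1024, 384)) # (B, N, 3 * dim) -> (B, N, dim) with dim=128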
+
+
+class BasePointModel(nn.Module):
+ """ A base class providing useful methods for point cloud processing. """
+
+ def __init__(
+ self,
+ *,
+ num_classes,
+ embed_dim,
+ extra_feature_channels,
+ dim: int = 128,
+ num_layers: int = 6
+ ):
+
+ super().__init__()
+ self.extra_feature_channels = extra_feature_channels
+ self.timestep_embed_dim = embed_dim
+ self.output_dim = num_classes
+ self.dim = dim
+ self.num_layers = num_layers
+
+ # Time embedding function
+ self.timestep_projection = nn.Sequential(
+ nn.Linear(embed_dim, embed_dim),
+ nn.LeakyReLU(0.1, inplace=True),
+ nn.Linear(embed_dim, embed_dim),
+ )
+
+ # Positional encoding
+ self.positional_encoding = PositionalEncoding(i_dim=3, N_freqs=10)
+ positional_encoding_d_out = 3 + (2 * 10) * 3
+
+ # Input projection (point coords, point coord encodings, other features, and timestep embeddings)
+ self.input_projection = nn.Linear(
+ in_features=(3 + positional_encoding_d_out + extra_feature_channels + self.timestep_embed_dim),
+ out_features=self.dim
+ )
+
+ # Transformer layers
+ self.layers = self.get_layers()
+
+ # Output projection
+ self.output_projection = nn.Linear(self.dim, self.output_dim)
+
+ def get_layers(self):
+ raise NotImplementedError('This method should be implemented by subclasses')
+
+ def prepare_inputs(self, inputs: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ The inputs have size (B, 3 + S, N), where S is the number of additional
+ feature channels and N is the number of points. The timesteps t can be either
+ continuous or discrete. Returns the projected per-point features (B, N, D_model)
+ together with the raw point coordinates (B, N, 3).
+ """
+
+ # Embed and project timesteps
+ t_emb = get_timestep_embedding(self.timestep_embed_dim, t, inputs.device)
+ t_emb = self.timestep_projection(t_emb)[:, None, :].expand(-1, inputs.shape[-1], -1) # (B, N, D_t_emb)
+
+ # Separate input coordinates and features
+ x = torch.transpose(inputs, -2, -1) # -> (B, N, 3 + S)
+ coords = x[:, :, :3] # (B, N, 3), point coordinates
+
+ # Positional encoding of point coords
+ coords_posenc = self.positional_encoding(coords) # (B, N, D_p_enc)
+
+ # Project
+ x = torch.cat((x, coords_posenc, t_emb), dim=2) # (B, N, 3 + S + D_p_enc + D_t_emb)
+ x = self.input_projection(x) # (B, N, D_model)
+
+ return x, coords
+
+ def get_global_tensors(self, x: Tensor):
+ B, N, D = x.shape
+ x_pool_max = torch.max(x, dim=1, keepdim=True).values.repeat(1, N, 1) # (B, N, D) after repeat
+ x_pool_std = torch.std(x, dim=1, keepdim=True).repeat(1, N, 1) # (B, N, D) after repeat
+ return x_pool_max, x_pool_std
+
+ def forward(self, inputs: torch.Tensor, t: torch.Tensor):
+ raise NotImplementedError('This method should be implemented by subclasses')
\ No newline at end of file
diff --git a/render/__init__.py b/render/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/render/pyt3d_wrapper.py b/render/pyt3d_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba2c025ee5383cfe1bde3412b0e051569dfffd80
--- /dev/null
+++ b/render/pyt3d_wrapper.py
@@ -0,0 +1,302 @@
+"""
+a simple wrapper for pytorch3d rendering
+Cite: BEHAVE: Dataset and Method for Tracking Human Object Interaction
+"""
+import numpy as np
+import torch
+from copy import deepcopy
+# Data structures and functions for rendering
+from pytorch3d.renderer import (
+ PointLights,
+ RasterizationSettings,
+ MeshRenderer,
+ MeshRasterizer,
+ SoftPhongShader,
+ TexturesVertex,
+ PerspectiveCameras,
+ PointsRasterizer,
+ AlphaCompositor,
+ PointsRasterizationSettings,
+)
+from pytorch3d.structures import Meshes, join_meshes_as_scene, Pointclouds
+
+SMPL_OBJ_COLOR_LIST = [
+ [0.65098039, 0.74117647, 0.85882353], # SMPL
+ [251 / 255.0, 128 / 255.0, 114 / 255.0], # object
+ ]
+
+
+class MeshRendererWrapper:
+ "a simple wrapper for the pytorch3d mesh renderer"
+ def __init__(self, image_size=1200,
+ faces_per_pixel=1,
+ device='cuda:0',
+ blur_radius=0, lights=None,
+ materials=None, max_faces_per_bin=50000):
+ self.image_size = image_size
+ self.faces_per_pixel=faces_per_pixel
+ self.max_faces_per_bin=max_faces_per_bin # prevent overflow, see https://github.com/facebookresearch/pytorch3d/issues/348
+ self.blur_radius = blur_radius
+ self.device = device
+ self.lights=lights if lights is not None else PointLights(
+ ((0.5, 0.5, 0.5),), ((0.5, 0.5, 0.5),), ((0.05, 0.05, 0.05),), ((0, -2, 0),), device
+ )
+ self.materials = materials
+ self.renderer = self.setup_renderer()
+
+ def setup_renderer(self):
+        # for silhouette rendering
+ sigma = 1e-4
+ raster_settings = RasterizationSettings(
+ image_size=self.image_size,
+ blur_radius=self.blur_radius,
+ # blur_radius=np.log(1. / 1e-4 - 1.) * sigma, # this will create large sphere for each face
+ faces_per_pixel=self.faces_per_pixel,
+ clip_barycentric_coords=False,
+ max_faces_per_bin=self.max_faces_per_bin
+ )
+ shader = SoftPhongShader(
+ device=self.device,
+ lights=self.lights,
+ materials=self.materials)
+ renderer = MeshRenderer(
+ rasterizer=MeshRasterizer(
+ raster_settings=raster_settings),
+ shader=shader
+ )
+ return renderer
+
+ def render(self, meshes, cameras, ret_mask=False, mode='rgb'):
+ assert len(meshes.faces_list()) == 1, 'currently only support batch size =1 rendering!'
+ images = self.renderer(meshes, cameras=cameras)
+ # print(images.shape)
+ if ret_mask or mode=='mask':
+ mask = images[0, ..., 3].cpu().detach().numpy()
+ return images[0, ..., :3].cpu().detach().numpy(), mask > 0
+ return images[0, ..., :3].cpu().detach().numpy()
+
+
+def get_kinect_camera(device='cuda:0', kid=1):
+ R, T = torch.eye(3), torch.zeros(3)
+ R[0, 0] = R[1, 1] = -1 # pytorch3d y-axis up, need to rotate to kinect coordinate
+ R = R.unsqueeze(0)
+ T = T.unsqueeze(0)
+ assert kid in [0, 1, 2, 3], f'invalid kinect index {kid}!'
+ if kid == 0:
+ fx, fy = 976.212, 976.047
+ cx, cy = 1017.958, 787.313
+ elif kid == 1:
+ fx, fy = 979.784, 979.840 # for original kinect coordinate system
+ cx, cy = 1018.952, 779.486
+ elif kid == 2:
+ fx, fy = 974.899, 974.337
+ cx, cy = 1018.747, 786.176
+ else:
+ fx, fy = 972.873, 972.790
+ cx, cy = 1022.0565, 770.397
+ color_w, color_h = 2048, 1536 # kinect color image size
+ cam_center = torch.tensor((cx, cy), dtype=torch.float32).unsqueeze(0)
+ focal_length = torch.tensor((fx, fy), dtype=torch.float32).unsqueeze(0)
+ cam = PerspectiveCameras(focal_length=focal_length, principal_point=cam_center,
+ image_size=((color_w, color_h),),
+ device=device,
+ R=R, T=T)
+ return cam
+
+
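+# A minimal usage sketch (editor's illustration, not part of the original module): render a mesh
+# with MeshRendererWrapper using the Kinect intrinsics helper above. The mesh path is hypothetical.
+def _example_render_mesh(mesh_path='example/person.ply', device='cuda:0'):
+    import trimesh
+    m = trimesh.load(mesh_path, process=False)
+    verts = torch.from_numpy(np.asarray(m.vertices)).float().to(device)
+    faces = torch.from_numpy(np.asarray(m.faces)).long().to(device)
+    colors = torch.tensor(SMPL_OBJ_COLOR_LIST[0], dtype=torch.float32, device=device)
+    textures = TexturesVertex(torch.ones_like(verts)[None] * colors)  # constant vertex color
+    mesh = Meshes([verts], [faces], textures=textures)
+    renderer = MeshRendererWrapper(image_size=512, device=device)
+    camera = get_kinect_camera(device)
+    rgb = renderer.render(mesh, camera)  # (H, W, 3) numpy array
+    return rgb
+
+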
+class PcloudRenderer:
+ "a simple wrapper for pytorch3d point cloud renderer"
+ def __init__(self, image_size=1024, radius=0.005, points_per_pixel=10,
+ device='cuda:0', bin_size=128, batch_size=1, ret_depth=False):
+ camera_centers = []
+ focal_lengths = []
+ for i in range(batch_size):
+ camera_centers.append(torch.Tensor([image_size / 2., image_size / 2.]).to(device))
+ focal_lengths.append(torch.Tensor([image_size / 2., image_size / 2.]).to(device))
+ self.image_size = image_size
+ self.device = device
+ self.camera_center = torch.stack(camera_centers)
+ self.focal_length = torch.stack(focal_lengths)
+ self.ret_depth = ret_depth # return depth map or not
+ self.renderer = self.setup_renderer(radius, points_per_pixel, bin_size)
+
+ def render(self, pc, cameras, mode='image'):
+ # TODO: support batch rendering
+ """
+ render the point cloud, compute the world coordinate of each pixel based on zbuf
+ image: (H, W, 3)
+ xyz_world: (H, W, 3), the third dimension is the xyz coordinate in world space
+ """
+ # assert cameras.R.shape[0]==1, "batch rendering is not supported for now!"
+ images, fragments = self.renderer(pc, cameras=cameras)
+ if mode=='image':
+ if images.shape[0] == 1:
+ img = images[0, ..., :3].cpu().numpy().copy()
+ else:
+ img = images[..., :3].cpu().numpy().copy()
+ return img
+ elif mode=='mask':
+ zbuf = torch.mean(fragments.zbuf, -1) # (B, H, W)
+ masks = zbuf >= 0
+ if images.shape[0] == 1:
+ img = images[0, ..., :3].cpu().numpy()
+ masks = masks[0].cpu().numpy().astype(bool)
+ else:
+ img = images[..., :3].cpu().numpy()
+ masks = masks.cpu().numpy().astype(bool)
+
+ return img, masks
+
+ def get_xy_ndc(self):
+ """
+ return (H, W, 2), each pixel is the x,y coordinate in NDC space
+ """
+ py, px = torch.meshgrid(torch.linspace(0, self.image_size-1, self.image_size),
+ torch.linspace(0, self.image_size-1, self.image_size))
+ x_ndc = 1 - 2*px/(self.image_size - 1)
+ y_ndc = 1 - 2*py/(self.image_size - 1)
+ xy_ndc = torch.stack([x_ndc, y_ndc], axis=-1).to(self.device)
+ return xy_ndc.squeeze(0).unsqueeze(0)
+
+ def setup_renderer(self, radius, points_per_pixel, bin_size):
+ raster_settings = PointsRasterizationSettings(
+ image_size=self.image_size,
+ # radius=0.003,
+ radius=radius,
+ points_per_pixel=points_per_pixel,
+ bin_size=bin_size,
+ max_points_per_bin=500000
+ )
+ # Create a points renderer by compositing points using an alpha compositor (nearer points
+ # are weighted more heavily). See [1] for an explanation.
+ rasterizer = PointsRasterizer(raster_settings=raster_settings)
+ renderer = PointsRendererWithFragments(
+ rasterizer=rasterizer,
+ compositor=AlphaCompositor()
+ )
+ return renderer
+
+
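+# A minimal usage sketch (editor's illustration, not part of the original module): render a random
+# colored point cloud with PcloudRenderer; the point count and colors are hypothetical.
+def _example_render_pointcloud(device='cuda:0'):
+    points = torch.rand(5000, 3, device=device) - 0.5  # points around the origin
+    points[:, 2] += 2.0  # push the cloud in front of the camera
+    colors = torch.rand(5000, 3, device=device)  # per-point RGB in [0, 1]
+    pc = Pointclouds(points=[points], features=[colors])
+    renderer = PcloudRenderer(image_size=512, radius=0.01, device=device)
+    camera = get_kinect_camera(device)
+    image, mask = renderer.render(pc, camera, mode='mask')  # (H, W, 3) image and (H, W) bool mask
+    return image, mask
+
+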
+class PointsRendererWithFragments(torch.nn.Module):
+ def __init__(self, rasterizer, compositor):
+ super().__init__()
+ self.rasterizer = rasterizer
+ self.compositor = compositor
+
+    def forward(self, point_clouds, **kwargs) -> "tuple[torch.Tensor, torch.Tensor]":
+ fragments = self.rasterizer(point_clouds, **kwargs)
+ # Construct weights based on the distance of a point to the true point.
+ # However, this could be done differently: e.g. predicted as opposed
+ # to a function of the weights.
+ r = self.rasterizer.raster_settings.radius
+
+ dists2 = fragments.dists.permute(0, 3, 1, 2)
+ weights = 1 - dists2 / (r * r)
+ images = self.compositor(
+ fragments.idx.long().permute(0, 3, 1, 2),
+ weights,
+ point_clouds.features_packed().permute(1, 0),
+ **kwargs,
+ )
+
+ # permute so image comes at the end
+ images = images.permute(0, 2, 3, 1)
+
+ return images, fragments
+
+# class PcloudsRenderer
+
+
+class DepthRasterizer(torch.nn.Module):
+ """
+ simply rasterize a mesh or point cloud to depth image
+ """
+ def __init__(self, image_size, dtype='pc',
+ radius=0.005, points_per_pixel=1,
+ bin_size=128,
+ blur_radius=0,
+ max_faces_per_bin=50000,
+ faces_per_pixel=1,):
+ """
+ image_size: (height, width)
+ """
+ super(DepthRasterizer, self).__init__()
+ if dtype == 'pc':
+ raster_settings = PointsRasterizationSettings(
+ image_size=image_size,
+ radius=radius,
+ points_per_pixel=points_per_pixel,
+ bin_size=bin_size
+ )
+ self.rasterizer = PointsRasterizer(raster_settings=raster_settings)
+ elif dtype == 'mesh':
+ raster_settings = RasterizationSettings(
+ image_size=image_size,
+ blur_radius=blur_radius,
+ # blur_radius=np.log(1. / 1e-4 - 1.) * sigma, # this will create large sphere for each face
+ faces_per_pixel=faces_per_pixel,
+ clip_barycentric_coords=False,
+ max_faces_per_bin=max_faces_per_bin
+ )
+ self.rasterizer=MeshRasterizer(raster_settings=raster_settings)
+ else:
+            raise NotImplementedError(f"unknown dtype '{dtype}', expected 'pc' or 'mesh'")
+
+ def forward(self, data, to_np=True, **kwargs):
+ fragments = self.rasterizer(data, **kwargs)
+ if to_np:
+ zbuf = fragments.zbuf # (B, H, W, points_per_pixel)
+ return zbuf[0, ..., 0].cpu().numpy()
+ return fragments.zbuf
+
+
+def test_depth_rasterizer():
+ from psbody.mesh import Mesh
+ import cv2
+ m = Mesh()
+ m.load_from_file("/BS/xxie-4/work/kindata/Sep29_shuo_chairwood_hand/t0003.000/person/person.ply")
+ device = 'cuda:0'
+ pc = Pointclouds([torch.from_numpy(m.v).float().to(device)],
+ features=[torch.from_numpy(m.vc).float().to(device)])
+ rasterizer = DepthRasterizer(image_size=(480, 640))
+ camera = get_kinect_camera(device)
+
+    depth = rasterizer(pc, to_np=False, cameras=camera)  # keep the tensor output so the std statistics below work
+ std = torch.std(depth, -1)
+ print('max std', torch.max(std)) # maximum std is up to 1.7m, too much!
+ print('min std', torch.min(std))
+
+ print(depth.shape)
+ dmap = depth[0, ..., 0].cpu().numpy()
+ dmap[dmap<0] = 0
+ cv2.imwrite('debug/depth.png', (dmap*1000).astype(np.uint16))
+
+def test_mesh_rasterizer():
+ from psbody.mesh import Mesh
+ import cv2
+ m = Mesh()
+ m.load_from_file("/BS/xxie-4/work/kindata/Sep29_shuo_chairwood_hand/t0003.000/person/fit02/person_fit.ply")
+ device = 'cuda:0'
+ mesh = Meshes([torch.from_numpy(m.v).float().to(device)],
+ [torch.from_numpy(m.f.astype(int)).to(device)])
+ rasterizer = DepthRasterizer(image_size=(480, 640), dtype='mesh')
+ camera = get_kinect_camera(device)
+
+ depth = rasterizer(mesh, to_np=False, cameras=camera)
+
+ print(depth.shape)
+ dmap = depth[0, ..., 0].cpu().numpy()
+ dmap[dmap < 0] = 0
+ cv2.imwrite('debug/depth_mesh.png', (dmap * 1000).astype(np.uint16))
+
+
+
+if __name__ == '__main__':
+ # test_depth_rasterizer()
+ test_mesh_rasterizer()
+
+
+
+
+
+
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..42632834d92da7c3585976b86e3071f3a9fe8aca
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,15 @@
+accelerate
+diffusers
+hydra-core
+imageio
+ninja
+opencv-python-headless
+plotly
+rich
+scipy
+timm
+torch-ema
+tqdm
+transformers
+wandb
+trimesh
\ No newline at end of file
diff --git a/training_utils.py b/training_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa9d6035ef9ca2ee8457f335e2283b6ea28ac546
--- /dev/null
+++ b/training_utils.py
@@ -0,0 +1,443 @@
+"""
+Misc functions, including distributed helpers, mostly from torchvision
+"""
+import glob
+import math
+import os
+import time
+import datetime
+import random
+from dataclasses import dataclass
+from collections import defaultdict, deque
+from typing import Callable, Optional
+from PIL import Image
+import numpy as np
+import torch
+import torch.distributed as dist
+import torchvision
+from accelerate import Accelerator
+from omegaconf import DictConfig
+
+from configs.structured import ProjectConfig
+
+
+
+@dataclass
+class TrainState:
+ epoch: int = 0
+ step: int = 0
+ best_val: Optional[float] = None
+
+
+def get_optimizer(cfg: ProjectConfig, model: torch.nn.Module, accelerator: Accelerator) -> torch.optim.Optimizer:
+ """Gets optimizer from configs"""
+
+ # Determine the learning rate
+ if cfg.optimizer.scale_learning_rate_with_batch_size:
+ lr = accelerator.state.num_processes * cfg.dataloader.batch_size * cfg.optimizer.lr
+ print('lr = {ws} (num gpus) * {bs} (batch_size) * {blr} (base learning rate) = {lr}'.format(
+ ws=accelerator.state.num_processes, bs=cfg.dataloader.batch_size, blr=cfg.optimizer.lr, lr=lr))
+    else:  # use the absolute learning rate as given
+ lr = cfg.optimizer.lr
+ print('lr = {lr} (absolute learning rate)'.format(lr=lr))
+
+ # Get optimizer parameters, excluding certain parameters from weight decay
+ no_decay = ["bias", "LayerNorm.weight"]
+ parameters = [
+ {
+ "params": [p for n, p in model.named_parameters() if p.requires_grad and not any(nd in n for nd in no_decay)],
+ "weight_decay": cfg.optimizer.weight_decay,
+ },
+ {
+ "params": [p for n, p in model.named_parameters() if p.requires_grad and any(nd in n for nd in no_decay)],
+ "weight_decay": 0.0,
+ },
+ ]
+
+ # Construct optimizer
+ if cfg.optimizer.type == 'torch':
+ Optimizer: torch.optim.Optimizer = getattr(torch.optim, cfg.optimizer.name)
+ optimizer = Optimizer(parameters, lr=lr, **cfg.optimizer.kwargs)
+ elif cfg.optimizer.type == 'timm':
+ from timm.optim import create_optimizer_v2
+ optimizer = create_optimizer_v2(model_or_params=parameters, lr=lr, **cfg.optimizer.kwargs)
+ elif cfg.optimizer.type == 'transformers':
+ import transformers
+ Optimizer: torch.optim.Optimizer = getattr(transformers, cfg.optimizer.name)
+ optimizer = Optimizer(parameters, lr=lr, **cfg.optimizer.kwargs)
+ else:
+ raise NotImplementedError(f'Invalid optimizer configs: {cfg.optimizer}')
+
+ return optimizer
+
+
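+# Editor's note: get_optimizer above reads the following fields from cfg.optimizer; the values
+# below are purely illustrative, not the project defaults:
+#   optimizer:
+#     type: torch                               # one of 'torch' | 'timm' | 'transformers'
+#     name: AdamW                               # class looked up in torch.optim / transformers
+#     lr: 1.0e-4
+#     weight_decay: 0.01
+#     scale_learning_rate_with_batch_size: false
+#     kwargs: {}                                # forwarded to the optimizer constructor
+
+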
+def get_scheduler(cfg: ProjectConfig, optimizer: torch.optim.Optimizer) -> Callable:
+ """Gets scheduler from configs"""
+
+ # Get scheduler
+ if cfg.scheduler.type == 'torch':
+        Scheduler: torch.optim.lr_scheduler._LRScheduler = getattr(torch.optim.lr_scheduler, cfg.scheduler.name)  # look up the scheduler class by name, mirroring get_optimizer
+ scheduler = Scheduler(optimizer=optimizer, **cfg.scheduler.kwargs)
+ if cfg.scheduler.get('warmup', 0):
+ from warmup_scheduler import GradualWarmupScheduler
+ scheduler = GradualWarmupScheduler(optimizer, multiplier=1,
+ total_epoch=cfg.scheduler.warmup, after_scheduler=scheduler)
+ elif cfg.scheduler.type == 'timm':
+ from timm.scheduler import create_scheduler
+ scheduler, _ = create_scheduler(optimizer=optimizer, args=cfg.scheduler.kwargs)
+ elif cfg.scheduler.type == 'transformers':
+        from transformers import get_scheduler  # e.g. a linear schedule with no warmup and linear decay
+ scheduler = get_scheduler(optimizer=optimizer, **cfg.scheduler.kwargs)
+ else:
+ raise NotImplementedError(f'invalid scheduler configs: {cfg.scheduler}')
+
+ return scheduler
+
+
+@torch.no_grad()
+def accuracy(output, target, topk=(1,)):
+ """Computes the accuracy over the k top predictions for the specified values of k"""
+ maxk = max(topk)
+ batch_size = target.size(0)
+
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
+
+ res = []
+ for k in topk:
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+ res.append(correct_k.mul_(100.0 / batch_size))
+ return res
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self, device='cuda'):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not using_distributed():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device=device)
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / max(self.count, 1)
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value) if len(self.deque) > 0 else ""
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ n = kwargs.pop('n', 1)
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v, n=n)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self, device='cuda'):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes(device=device)
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+ log_msg = [
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ]
+ if torch.cuda.is_available():
+ log_msg.append('max mem: {memory:.0f}')
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ else:
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / len(iterable)))
+
+
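+# A minimal usage sketch for MetricLogger (editor's illustration; dataloader, model_step and
+# optimizer are hypothetical names):
+#
+#     logger = MetricLogger(delimiter='  ')
+#     for batch in logger.log_every(dataloader, print_freq=50, header='Epoch [0]'):
+#         loss = model_step(batch)
+#         logger.update(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])
+#     logger.synchronize_between_processes()
+
+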
+class NormalizeInverse(torchvision.transforms.Normalize):
+ """
+ Undoes the normalization and returns the reconstructed images in the input domain.
+ """
+
+ def __init__(self, mean, std):
+ mean = torch.as_tensor(mean)
+ std = torch.as_tensor(std)
+ std_inv = 1 / (std + 1e-7)
+ mean_inv = -mean * std_inv
+ super().__init__(mean=mean_inv, std=std_inv)
+
+ def __call__(self, tensor):
+ return super().__call__(tensor.clone())
+
+
+def resume_from_checkpoint(cfg: ProjectConfig, model, optimizer=None, scheduler=None, model_ema=None):
+
+ # Check if resuming training from a checkpoint
+ if not cfg.checkpoint.resume:
+ print('Starting training from scratch')
+ return TrainState()
+
+    # XH: find checkpoint path automatically
+ if not os.path.isfile(cfg.checkpoint.resume):
+ print(f"The given checkpoint path {cfg.checkpoint.resume} does not exist, trying to find one...")
+ # print(os.getcwd())
+ ckpt_file = os.path.join(cfg.run.code_dir_abs, f'outputs/{cfg.run.name}/single/checkpoint-latest.pth')
+ if not os.path.isfile(ckpt_file):
+            # just get the first dir, for backward compatibility
+ folders = sorted(glob.glob(os.path.join(cfg.run.code_dir_abs, f'outputs/{cfg.run.name}/2023-*')))
+ assert len(folders) <= 1
+ if len(folders) > 0:
+ ckpt_file = os.path.join(folders[0], 'checkpoint-latest.pth')
+
+ if os.path.isfile(ckpt_file):
+ print(f"Found checkpoint at {ckpt_file}!")
+ cfg.checkpoint.resume = ckpt_file
+ else:
+ print(f"No checkpoint found in outputs/{cfg.run.name}/single/!")
+ return TrainState()
+
+ # If resuming, load model state dict
+ print(f'Loading checkpoint ({datetime.datetime.now()})')
+ checkpoint = torch.load(cfg.checkpoint.resume, map_location='cpu')
+ if 'model' in checkpoint:
+ state_dict, key = checkpoint['model'], 'model'
+ else:
+ print("Warning: no model found in checkpoint!")
+ state_dict, key = checkpoint, 'N/A'
+ if any(k.startswith('module.') for k in state_dict.keys()):
+ state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
+ print('Removed "module." from checkpoint state dict')
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+ print(f'Loaded model checkpoint key {key} from {cfg.checkpoint.resume}')
+ if len(missing_keys):
+ print(f' - Missing_keys: {missing_keys}')
+ if len(unexpected_keys):
+ print(f' - Unexpected_keys: {unexpected_keys}')
+ # 298 missing, 328 unexpected! total 448 modules.
+ print(f"{len(missing_keys)} missing, {len(unexpected_keys)} unexpected! total {len(model.state_dict().keys())} modules.")
+ # print("First 10 keys:")
+ # for i in range(10):
+ # print(missing_keys[i], unexpected_keys[i])
+ # exit(0)
+ if 'step' in checkpoint:
+ print("Number of trained steps:", checkpoint['step'])
+
+ # TODO: implement better loading for fine tuning
+ # Resume model ema
+ if cfg.ema.use_ema:
+ if checkpoint['model_ema']:
+ model_ema.load_state_dict(checkpoint['model_ema'])
+ print('Loaded model ema from checkpoint')
+ else:
+ model_ema.load_state_dict(model.parameters())
+            print('No model ema in checkpoint; initialized model ema from current model parameters')
+ else:
+ if 'model_ema' in checkpoint and checkpoint['model_ema']:
+ print('Not using model ema, but model_ema found in checkpoint (you probably want to resume it!)')
+ else:
+ print('Not using model ema, and no model_ema found in checkpoint.')
+
+ # Resume optimizer and/or training state
+ train_state = TrainState()
+ if 'train' in cfg.run.job:
+ if cfg.checkpoint.resume_training:
+ assert (
+ cfg.checkpoint.resume_training_optimizer
+ or cfg.checkpoint.resume_training_scheduler
+ or cfg.checkpoint.resume_training_state
+ or cfg.checkpoint.resume_training
+ ), f'Invalid configs: {cfg.checkpoint}'
+ if cfg.checkpoint.resume_training_optimizer:
+ if 'optimizer' not in checkpoint:
+ assert 'tune' in cfg.run.name, f'please check the checkpoint for run {cfg.run.name}'
+ print("Warning: not loading optimizer!")
+ else:
+ assert 'optimizer' in checkpoint, f'Value not in {checkpoint.keys()}'
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ print(f'Loaded optimizer from checkpoint')
+ else:
+ print(f'Did not load optimizer from checkpoint')
+ if cfg.checkpoint.resume_training_scheduler:
+ if 'scheduler' not in checkpoint:
+ assert 'tune' in cfg.run.name, f'please check the checkpoint for run {cfg.run.name}'
+ print("Warning: not loading scheduler!")
+ else:
+ assert 'scheduler' in checkpoint, f'Value not in {checkpoint.keys()}'
+ scheduler.load_state_dict(checkpoint['scheduler'])
+ print(f'Loaded scheduler from checkpoint')
+ else:
+ print(f'Did not load scheduler from checkpoint')
+ if cfg.checkpoint.resume_training_state:
+ if 'steps' in checkpoint and 'step' not in checkpoint: # fixes an old typo
+ checkpoint['step'] = checkpoint.pop('steps')
+ assert {'epoch', 'step', 'best_val'}.issubset(set(checkpoint.keys()))
+ epoch, step, best_val = checkpoint['epoch'] + 1, checkpoint['step'], checkpoint['best_val']
+ train_state = TrainState(epoch=epoch, step=step, best_val=best_val)
+ print(f'Resumed state from checkpoint: step {step}, epoch {epoch}, best_val {best_val}')
+ else:
+ print(f'Did not load train state from checkpoint')
+ else:
+ print('Did not resume optimizer, scheduler, or epoch from checkpoint')
+
+ print(f'Finished loading checkpoint ({datetime.datetime.now()})')
+
+ return train_state
+
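+# Checkpoint layout assumed by resume_from_checkpoint above (editor's note, derived from the keys
+# it reads): a dict with 'model', and optionally 'model_ema', 'optimizer', 'scheduler',
+# 'epoch', 'step' (or the legacy 'steps'), and 'best_val'.
+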
+
+def setup_distributed_print(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ import builtins as __builtin__
+ from rich import print as __richprint__
+ builtin_print = __richprint__ # __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def using_distributed():
+ return dist.is_available() and dist.is_initialized()
+
+
+def get_rank():
+ return dist.get_rank() if using_distributed() else 0
+
+
+def set_seed(seed):
+ rank = get_rank()
+ seed = seed + rank
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ torch.backends.cudnn.enabled = True
+ torch.backends.cudnn.benchmark = True
+ if using_distributed():
+ print(f'Seeding node {rank} with seed {seed}', force=True)
+ else:
+ print(f'Seeding node {rank} with seed {seed}')
+
+
+def compute_grad_norm(parameters):
+ # total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2).item()
+ total_norm = 0
+ for p in parameters:
+ if p.grad is not None and p.requires_grad:
+ param_norm = p.grad.detach().data.norm(2)
+ total_norm += param_norm.item() ** 2
+ total_norm = total_norm ** 0.5
+ return total_norm
+
+
+class dotdict(dict):
+ """dot.notation access to dictionary attributes"""
+ __getattr__ = dict.get
+ __setattr__ = dict.__setitem__
+ __delattr__ = dict.__delitem__
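+
+
+# A minimal usage sketch for dotdict (editor's illustration):
+#     opts = dotdict({'lr': 1e-4, 'batch_size': 8})
+#     opts.epochs = 10       # attribute-style assignment
+#     print(opts.lr)         # 1e-4, attribute-style access
+#     print(opts.missing)    # None (dict.get semantics) instead of raising AttributeError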