import inspect
import random
from typing import Optional

import torch
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_pndm import PNDMScheduler
from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
from torch import Tensor
from tqdm import tqdm

from .model_utils import get_num_points, get_custom_betas
from .point_cloud_model import PointCloudModel
from .projection_model import PointCloudProjectionModel


class ConditionalPointCloudDiffusionModel(PointCloudProjectionModel):
    def __init__(
        self,
        beta_start: float,
        beta_end: float,
        beta_schedule: str,
        point_cloud_model: str,
        point_cloud_model_embed_dim: int,
        **kwargs,  # projection arguments
    ):
        super().__init__(**kwargs)

        # Checks
        if not self.predict_shape:
            raise NotImplementedError('Must predict shape if performing diffusion.')

        # Create the diffusion schedulers, which define the sampling timesteps
        self.dm_pred_type = kwargs.get('dm_pred_type', 'epsilon')
        assert self.dm_pred_type in ['epsilon', 'sample']
        scheduler_kwargs = {"prediction_type": self.dm_pred_type}
        if beta_schedule == 'custom':
            scheduler_kwargs.update(dict(trained_betas=get_custom_betas(beta_start=beta_start, beta_end=beta_end)))
        else:
            scheduler_kwargs.update(dict(beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule))
        self.schedulers_map = {
            'ddpm': DDPMScheduler(**scheduler_kwargs, clip_sample=False),
            'ddim': DDIMScheduler(**scheduler_kwargs, clip_sample=False),
            'pndm': PNDMScheduler(**scheduler_kwargs),
        }
        self.scheduler = self.schedulers_map['ddpm']  # can be overridden at inference time

        # Create the point cloud model that processes the point cloud at each diffusion step
        self.init_pcloud_model(kwargs, point_cloud_model, point_cloud_model_embed_dim)

        self.load_sample_init = kwargs.get('load_sample_init', False)
        self.sample_init_scale = kwargs.get('sample_init_scale', 1.0)
        self.test_init_with_gtpc = kwargs.get('test_init_with_gtpc', False)
        self.consistent_center = kwargs.get('consistent_center', False)
        self.cam_noise_std = kwargs.get('cam_noise_std', 0.0)  # add timestep-dependent noise to the camera

    def init_pcloud_model(self, kwargs, point_cloud_model, point_cloud_model_embed_dim):
        self.point_cloud_model = PointCloudModel(
            model_type=point_cloud_model,
            embed_dim=point_cloud_model_embed_dim,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            voxel_resolution_multiplier=kwargs.get('voxel_resolution_multiplier', 1),  # defaults to 1
        )
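    # Training follows the standard denoising-diffusion recipe: diffuse a clean cloud
    # x_0 to x_t in closed form, then train the network to predict either the noise
    # ('epsilon') or the clean sample ('sample'). As a sketch of what the scheduler's
    # add_noise computes (the usual DDPM identity, not the diffusers source):
    #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise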
    def forward_train(
        self,
        pc: Pointclouds,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        return_intermediate_steps: bool = False,
        **kwargs,
    ):
        # Normalize colors and convert to tensor
        x_0 = self.point_cloud_to_tensor(pc, normalize=True, scale=True)  # this does not pack the point colors
        B, N, D = x_0.shape

        # Sample random noise
        noise = torch.randn_like(x_0)
        if self.consistent_center:
            # Modification suggested by https://arxiv.org/pdf/2308.07837.pdf
            noise = noise - torch.mean(noise, dim=1, keepdim=True)

        # Sample random timesteps for each point cloud
        timestep = torch.randint(0, self.scheduler.num_train_timesteps, (B,),
                                 device=self.device, dtype=torch.long)

        # Add noise to the points (coordinates only, not the conditioning features)
        x_t = self.scheduler.add_noise(x_0, noise, timestep)

        # Add timestep-dependent noise to the camera pose
        if self.cam_noise_std > 1e-6:
            camera = camera.clone()
            camT = camera.T  # (B, 3)
            dist = torch.sqrt(torch.sum(camT ** 2, -1, keepdim=True))
            nratio = timestep[:, None] / self.scheduler.num_train_timesteps  # time-dependent noise scale
            tnoise = torch.randn(B, 3).to(dist.device) / 3. * dist * self.cam_noise_std * nratio
            camera.T = camera.T + tnoise

        # Conditioning: the pixel-aligned features are computed from the noisy points
        x_t_input = self.get_diffu_input(camera, image_rgb, mask, timestep, x_t, **kwargs)

        # Forward
        loss, noise_pred = self.compute_loss(noise, timestep, x_0, x_t_input)

        # Whether to return intermediate steps
        if return_intermediate_steps:
            return loss, (x_0, x_t, noise, noise_pred)

        return loss
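    # Self-conditioning below follows the two-pass recipe of the reference cited in
    # compute_loss: with probability 0.5 the model first runs with a zero placeholder
    # to produce an estimate x_pred (no gradients), and that estimate is concatenated
    # to the input of the second, gradient-carrying pass. At sampling time the
    # previous step's output plays the role of x_pred, so train and test inputs match.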
    def compute_loss(self, noise, timestep, x_0, x_t_input):
        x_pred = torch.zeros_like(x_0)
        if self.self_conditioning:
            # Self-conditioning, from https://openreview.net/pdf?id=3itjR9QxFw
            if random.uniform(0, 1.) > 0.5:
                with torch.no_grad():
                    x_pred = self.point_cloud_model(torch.cat([x_t_input, x_pred], -1), timestep)
            noise_pred = self.point_cloud_model(torch.cat([x_t_input, x_pred], -1), timestep)
        else:
            noise_pred = self.point_cloud_model(x_t_input, timestep)

        # Check
        if not noise_pred.shape == noise.shape:
            raise ValueError(f'{noise_pred.shape=} and {noise.shape=}')

        # Loss
        if self.dm_pred_type == 'epsilon':
            loss = F.mse_loss(noise_pred, noise)
        elif self.dm_pred_type == 'sample':
            loss = F.mse_loss(noise_pred, x_0)  # predicting the clean sample
        else:
            raise NotImplementedError
        return loss, noise_pred

    def get_diffu_input(self, camera, image_rgb, mask, timestep, x_t, **kwargs):
        """Return (B, N, D), the exact input to the diffusion model; x_t: (B, N, 3)."""
        x_t_input = self.get_input_with_conditioning(x_t, camera=camera,
                                                     image_rgb=image_rgb, mask=mask, t=timestep)
        return x_t_input

    @torch.no_grad()
    def forward_sample(
        self,
        num_points: int,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        # Optional overrides
        scheduler: Optional[str] = 'ddpm',
        # Inference parameters
        num_inference_steps: Optional[int] = 1000,
        eta: Optional[float] = 0.0,  # for DDIM
        # Whether to return all the intermediate steps in generation
        return_sample_every_n_steps: int = -1,
        # Whether to disable tqdm
        disable_tqdm: bool = False,
        gt_pc: Pointclouds = None,
        **kwargs,
    ):
        # Get scheduler from mapping, or use self.scheduler if None
        scheduler = self.scheduler if scheduler is None else self.schedulers_map[scheduler]

        # Get the size of the noise
        N = num_points
        B = 1 if image_rgb is None else image_rgb.shape[0]
        D = self.get_x_T_channel()
        device = self.device if image_rgb is None else image_rgb.device

        sample_from_interm = kwargs.get('sample_from_interm', False)
        interm_steps = kwargs.get('noise_step') if sample_from_interm else -1
        x_t = self.initialize_x_T(device, gt_pc, (B, N, D), interm_steps, scheduler)
        x_pred = torch.zeros_like(x_t)

        # Set timesteps
        extra_step_kwargs = self.setup_reverse_process(eta, num_inference_steps, scheduler)

        # Loop over timesteps
        all_outputs = []
        return_all_outputs = (return_sample_every_n_steps > 0)
        progress_bar = tqdm(scheduler.timesteps.to(device), desc=f'Sampling ({x_t.shape})', disable=disable_tqdm)

        for i, t in enumerate(progress_bar):
            add_interm_output = (return_all_outputs
                                 and (i % return_sample_every_n_steps == 0 or i == len(scheduler.timesteps) - 1))

            # Conditioning
            x_t_input = self.get_diffu_input(camera, image_rgb, mask, t, x_t, **kwargs)
            if self.self_conditioning:
                x_t_input = torch.cat([x_t_input, x_pred], -1)  # add self-conditioning
            inference_binary = (i == len(progress_bar) - 1) | add_interm_output

            # One reverse step with conditioning
            x_t = self.reverse_step(extra_step_kwargs, scheduler, t, x_t, x_t_input,
                                    inference_binary=inference_binary)  # (B, N, D), D = 3 or 4
            x_pred = x_t  # for self-conditioning in the next iteration

            # Append to output list if desired
            if add_interm_output:
                all_outputs.append(x_t)

        # Convert output back into a point cloud, undoing normalization and scaling
        output = self.tensor_to_point_cloud(x_t, denormalize=True, unscale=True)  # back to the original scale
        if return_all_outputs:
            all_outputs = torch.stack(all_outputs, dim=1)  # (B, sample_steps, N, D)
            all_outputs = [self.tensor_to_point_cloud(o, denormalize=True, unscale=True) for o in all_outputs]

        return (output, all_outputs) if return_all_outputs else output

    def get_x_T_channel(self):
        D = 3 + (self.color_channels if self.predict_color else 0)
        return D
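    # Initialization below supports a warm start: when interm_steps > 0 the chain
    # starts from a partially-diffused version of the GT cloud (in the spirit of
    # SDEdit-style editing) so the reverse process only traverses the final
    # interm_steps steps instead of the full chain.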
    def initialize_x_T(self, device, gt_pc, shape, interm_steps: int = -1, scheduler=None):
        B, N, D = shape
        # Sample noise initialization
        if interm_steps > 0:
            # Sample from an intermediate diffusion step of the GT point cloud
            x_0 = self.point_cloud_to_tensor(gt_pc, normalize=True, scale=True)
            noise = torch.randn(B, N, D, device=device)
            # Always make sure the noise does not shift the point cloud center;
            # this is important and reduces the Chamfer distance by ~0.1 cm
            noise = noise - torch.mean(noise, dim=1, keepdim=True)
            x_t = scheduler.add_noise(x_0, noise, torch.tensor([interm_steps - 1] * B).long().to(device))
        else:
            # Sample from a random Gaussian
            x_t = torch.randn(B, N, D, device=device)
            x_t = x_t * self.sample_init_scale  # test-time scaling
            if self.consistent_center:
                x_t = x_t - torch.mean(x_t, dim=1, keepdim=True)
        return x_t

    def reverse_step(self, extra_step_kwargs, scheduler, t, x_t, x_t_input, **kwargs):
        """
        Run one reverse step to compute the denoised x_t.
        :param extra_step_kwargs: extra kwargs for scheduler.step (e.g. eta for DDIM)
        :param scheduler: the diffusion scheduler
        :param t: [1], diffusion time step
        :param x_t: (B, N, 3)
        :param x_t_input: conditional features (B, N, F)
        :param kwargs: other configurations for the diffusion step
        :return: denoised x_t
        """
        B = x_t.shape[0]
        # Forward
        noise_pred = self.point_cloud_model(x_t_input, t.reshape(1).expand(B))
        if self.consistent_center:
            assert self.dm_pred_type != 'sample', 'incompatible prediction type for CCD!'
            # Suggested by the CCD-3DR paper
            noise_pred = noise_pred - torch.mean(noise_pred, dim=1, keepdim=True)
        # Step
        x_t = scheduler.step(noise_pred, t, x_t, **extra_step_kwargs).prev_sample
        if self.consistent_center:
            x_t = x_t - torch.mean(x_t, dim=1, keepdim=True)
        return x_t

    def setup_reverse_process(self, eta, num_inference_steps, scheduler):
        """Set up the reverse diffusion chain."""
        accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {"offset": 1} if accepts_offset else {}
        scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # Prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
        # eta (η) is only used with the DDIMScheduler and is ignored by other schedulers.
        # It corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502)
        # and should be in [0, 1].
        accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
        extra_step_kwargs = {"eta": eta} if accepts_eta else {}
        return extra_step_kwargs

    def forward(self, batch: FrameData, mode: str = 'train', **kwargs):
        """A wrapper around the forward method for training and inference."""
        if isinstance(batch, dict):
            # Workaround for a multiprocessing issue where the FrameData batch is collated into a dict
            batch = FrameData(**batch)
        if mode == 'train':
            return self.forward_train(
                pc=batch.sequence_point_cloud,
                camera=batch.camera,
                image_rgb=batch.image_rgb,
                mask=batch.fg_probability,
                **kwargs)
        elif mode == 'sample':
            num_points = kwargs.pop('num_points', get_num_points(batch.sequence_point_cloud))
            return self.forward_sample(
                num_points=num_points,
                camera=batch.camera,
                image_rgb=batch.image_rgb,
                mask=batch.fg_probability,
                **kwargs)
        else:
            raise NotImplementedError()
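
# A minimal, self-contained sketch of the scheduler mechanics this model relies on:
# forward noising via add_noise and reverse denoising via step. It uses only the
# diffusers scheduler API plus a zero-noise stand-in for point_cloud_model, so it
# runs without the pytorch3d / projection-model dependencies. The shapes and beta
# values are arbitrary illustrative choices, not the training configuration.
if __name__ == '__main__':
    sched = DDPMScheduler(num_train_timesteps=1000, beta_start=1e-5, beta_end=8e-3,
                          beta_schedule='linear', clip_sample=False)

    # Forward diffusion: diffuse a toy (B, N, 3) "point cloud" to random timesteps.
    x_0 = torch.randn(2, 128, 3)
    noise = torch.randn_like(x_0)
    noise = noise - noise.mean(dim=1, keepdim=True)  # consistent-center trick
    t = torch.randint(0, sched.config.num_train_timesteps, (2,)).long()
    x_t = sched.add_noise(x_0, noise, t)

    # Reverse diffusion: 50 denoising steps from pure Gaussian noise; the zero
    # prediction stands in for point_cloud_model(x_t_input, t).
    sched.set_timesteps(50)
    x = torch.randn(2, 128, 3)
    for step_t in sched.timesteps:
        eps_hat = torch.zeros_like(x)
        x = sched.step(eps_hat, step_t, x).prev_sample
    print(x_t.shape, x.shape)  # both torch.Size([2, 128, 3])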