import inspect
import random
from typing import Optional

import torch
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_pndm import PNDMScheduler
from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
from torch import Tensor
from tqdm import tqdm

from .model_utils import get_num_points, get_custom_betas
from .point_cloud_model import PointCloudModel
from .projection_model import PointCloudProjectionModel


class ConditionalPointCloudDiffusionModel(PointCloudProjectionModel):
    def __init__(
        self,
        beta_start: float,
        beta_end: float,
        beta_schedule: str,
        point_cloud_model: str,
        point_cloud_model_embed_dim: int,
        **kwargs,  # projection arguments
    ):
        super().__init__(**kwargs)

        # Checks
        if not self.predict_shape:
            raise NotImplementedError('Must predict shape if performing diffusion.')
        # Create the diffusion schedulers, which define the sampling timesteps
        self.dm_pred_type = kwargs.get('dm_pred_type', 'epsilon')
        assert self.dm_pred_type in ['epsilon', 'sample'], f'unknown prediction type {self.dm_pred_type}'
        scheduler_kwargs = {"prediction_type": self.dm_pred_type}
        if beta_schedule == 'custom':
            scheduler_kwargs.update(dict(trained_betas=get_custom_betas(beta_start=beta_start, beta_end=beta_end)))
        else:
            scheduler_kwargs.update(dict(beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule))
        self.schedulers_map = {
            'ddpm': DDPMScheduler(**scheduler_kwargs, clip_sample=False),
            'ddim': DDIMScheduler(**scheduler_kwargs, clip_sample=False),
            'pndm': PNDMScheduler(**scheduler_kwargs),
        }
        self.scheduler = self.schedulers_map['ddpm']  # can be overridden at inference time
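        # Note: all three schedulers share the same beta schedule. 'ddpm' performs
        # ancestral sampling, 'ddim' supports deterministic sampling (eta=0), and
        # 'pndm' uses a pseudo-numerical multi-step solver. With
        # prediction_type='epsilon' the network regresses the added noise; with
        # 'sample' it regresses x_0 directly.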
        # Create the point cloud model that processes the point cloud at each diffusion step
        self.init_pcloud_model(kwargs, point_cloud_model, point_cloud_model_embed_dim)

        self.load_sample_init = kwargs.get('load_sample_init', False)
        self.sample_init_scale = kwargs.get('sample_init_scale', 1.0)
        self.test_init_with_gtpc = kwargs.get('test_init_with_gtpc', False)
        self.consistent_center = kwargs.get('consistent_center', False)
        self.cam_noise_std = kwargs.get('cam_noise_std', 0.0)  # noise added to the camera, scaled by the diffusion timestep

    def init_pcloud_model(self, kwargs, point_cloud_model, point_cloud_model_embed_dim):
        self.point_cloud_model = PointCloudModel(
            model_type=point_cloud_model,
            embed_dim=point_cloud_model_embed_dim,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            voxel_resolution_multiplier=kwargs.get('voxel_resolution_multiplier', 1)  # defaults to 1
        )

    def forward_train(
        self,
        pc: Pointclouds,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        return_intermediate_steps: bool = False,
        **kwargs
    ):
        # Normalize colors and convert to tensor
        x_0 = self.point_cloud_to_tensor(pc, normalize=True, scale=True)  # does not pack the point colors
        B, N, D = x_0.shape
        # Sample random noise
        noise = torch.randn_like(x_0)
        if self.consistent_center:
            # zero-center the noise so it cannot shift the point cloud center,
            # a modification suggested by https://arxiv.org/pdf/2308.07837.pdf
            noise = noise - torch.mean(noise, dim=1, keepdim=True)

        # Sample a random timestep for each point cloud
        timestep = torch.randint(0, self.scheduler.num_train_timesteps, (B,),
                                 device=self.device, dtype=torch.long)

        # Add noise to the points (coordinates only, not features)
        x_t = self.scheduler.add_noise(x_0, noise, timestep)
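        # add_noise implements the closed-form forward process
        #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
        # so x_t can be sampled for any t without simulating the whole chain.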

        # Add noise to the camera pose, scaled by the diffusion timestep
        if self.cam_noise_std > 1e-6:
            # the noise magnitude scales with the camera distance and the timestep
            camera = camera.clone()
            camT = camera.T  # (B, 3)
            dist = torch.sqrt(torch.sum(camT ** 2, -1, keepdim=True))
            nratio = timestep[:, None] / self.scheduler.num_train_timesteps  # time-dependent noise ratio
            tnoise = torch.randn(B, 3).to(dist.device) / 3. * dist * self.cam_noise_std * nratio
            camera.T = camera.T + tnoise

        # Conditioning: the pixel-aligned features are computed from the noisy points
        x_t_input = self.get_diffu_input(camera, image_rgb, mask, timestep, x_t, **kwargs)

        # Forward
        loss, noise_pred = self.compute_loss(noise, timestep, x_0, x_t_input)

        # Optionally return intermediate steps
        if return_intermediate_steps:
            return loss, (x_0, x_t, noise, noise_pred)
        return loss

    def compute_loss(self, noise, timestep, x_0, x_t_input):
        x_pred = torch.zeros_like(x_0)
        if self.self_conditioning:
            # Self-conditioning, from https://openreview.net/pdf?id=3itjR9QxFw:
            # with probability 0.5, first predict without gradients, then feed
            # that prediction back in as additional conditioning.
            if random.uniform(0, 1.) > 0.5:
                with torch.no_grad():
                    x_pred = self.point_cloud_model(torch.cat([x_t_input, x_pred], -1), timestep)
            noise_pred = self.point_cloud_model(torch.cat([x_t_input, x_pred], -1), timestep)
        else:
            noise_pred = self.point_cloud_model(x_t_input, timestep)

        # Check
        if noise_pred.shape != noise.shape:
            raise ValueError(f'{noise_pred.shape=} and {noise.shape=}')

        # Loss
        if self.dm_pred_type == 'epsilon':
            loss = F.mse_loss(noise_pred, noise)
        elif self.dm_pred_type == 'sample':
            loss = F.mse_loss(noise_pred, x_0)  # the model predicts the sample x_0 directly
        else:
            raise NotImplementedError
        return loss, noise_pred

    def get_diffu_input(self, camera, image_rgb, mask, timestep, x_t, **kwargs):
        """return: (B, N, D), the exact input to the diffusion model; x_t: (B, N, 3)"""
        x_t_input = self.get_input_with_conditioning(x_t, camera=camera,
                                                     image_rgb=image_rgb, mask=mask, t=timestep)
        return x_t_input

    def forward_sample(
        self,
        num_points: int,
        camera: Optional[CamerasBase],
        image_rgb: Optional[Tensor],
        mask: Optional[Tensor],
        # Optional overrides
        scheduler: Optional[str] = 'ddpm',
        # Inference parameters
        num_inference_steps: Optional[int] = 1000,
        eta: Optional[float] = 0.0,  # for DDIM
        # Whether to return all the intermediate steps in generation
        return_sample_every_n_steps: int = -1,
        # Whether to disable tqdm
        disable_tqdm: bool = False,
        gt_pc: Pointclouds = None,
        **kwargs
    ):
        # Get the scheduler from the mapping, or fall back to self.scheduler if None
        scheduler = self.scheduler if scheduler is None else self.schedulers_map[scheduler]

        # Get the shape of the noise
        N = num_points
        B = 1 if image_rgb is None else image_rgb.shape[0]
        D = self.get_x_T_channel()
        device = self.device if image_rgb is None else image_rgb.device

        sample_from_interm = kwargs.get('sample_from_interm', False)
        interm_steps = kwargs.get('noise_step') if sample_from_interm else -1
        x_t = self.initialize_x_T(device, gt_pc, (B, N, D), interm_steps, scheduler)
        x_pred = torch.zeros_like(x_t)

        # Set timesteps
        extra_step_kwargs = self.setup_reverse_process(eta, num_inference_steps, scheduler)

        # Loop over timesteps
        all_outputs = []
        return_all_outputs = (return_sample_every_n_steps > 0)
        progress_bar = tqdm(scheduler.timesteps.to(device), desc=f'Sampling ({x_t.shape})', disable=disable_tqdm)
        for i, t in enumerate(progress_bar):
            add_interm_output = (return_all_outputs and (
                    i % return_sample_every_n_steps == 0 or i == len(scheduler.timesteps) - 1))

            # Conditioning
            x_t_input = self.get_diffu_input(camera, image_rgb, mask, t, x_t, **kwargs)
            if self.self_conditioning:
                x_t_input = torch.cat([x_t_input, x_pred], -1)  # add self-conditioning
            inference_binary = (i == len(progress_bar) - 1) or add_interm_output

            # One reverse step with conditioning
            x_t = self.reverse_step(extra_step_kwargs, scheduler, t, x_t, x_t_input,
                                    inference_binary=inference_binary)  # (B, N, D), D=3 or 4
            x_pred = x_t  # for self-conditioning in the next iteration

            # Append to the output list if desired
            if add_interm_output:
                all_outputs.append(x_t)

        # Convert the output back into a point cloud, undoing normalization and scaling
        output = self.tensor_to_point_cloud(x_t, denormalize=True, unscale=True)  # back to the original scale
        if return_all_outputs:
            all_outputs = torch.stack(all_outputs, dim=1)  # (B, sample_steps, N, D)
            all_outputs = [self.tensor_to_point_cloud(o, denormalize=True, unscale=True) for o in all_outputs]
        return (output, all_outputs) if return_all_outputs else output

    def get_x_T_channel(self):
        D = 3 + (self.color_channels if self.predict_color else 0)
        return D

    def initialize_x_T(self, device, gt_pc, shape, interm_steps: int = -1, scheduler=None):
        B, N, D = shape
        # Sample the noise initialization
        if interm_steps > 0:
            # Initialize from an intermediate diffusion step of the GT point cloud
            x_0 = self.point_cloud_to_tensor(gt_pc, normalize=True, scale=True)
            noise = torch.randn(B, N, D, device=device)
            # Always make sure the noise does not shift the point cloud center;
            # this is important and reduces the Chamfer distance by ~0.1 cm!
            noise = noise - torch.mean(noise, dim=1, keepdim=True)
            x_t = scheduler.add_noise(x_0, noise, torch.tensor([interm_steps - 1] * B).long().to(device))
        else:
            # Sample from a random Gaussian
            x_t = torch.randn(B, N, D, device=device)
            x_t = x_t * self.sample_init_scale  # optional scaling at test time
            if self.consistent_center:
                x_t = x_t - torch.mean(x_t, dim=1, keepdim=True)
        return x_t
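    # Note on the consistent-center trick used above: add_noise is linear in the
    # noise (x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps), so a
    # zero-mean eps over the point dimension keeps the centroid of x_t at
    # sqrt(alpha_bar_t) * centroid(x_0); the noise never translates the cloud.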

    def reverse_step(self, extra_step_kwargs, scheduler, t, x_t, x_t_input, **kwargs):
        """
        run one reverse diffusion step
        :param extra_step_kwargs: scheduler-specific step kwargs, e.g. eta for DDIM
        :param scheduler: the noise scheduler used for this step
        :param t: [1], diffusion timestep
        :param x_t: (B, N, 3)
        :param x_t_input: conditional features (B, N, F)
        :param kwargs: other configurations for running the diffusion step
        :return: x_{t-1}, the sample after one reverse step
        """
        B = x_t.shape[0]
        # Forward
        noise_pred = self.point_cloud_model(x_t_input, t.reshape(1).expand(B))
        if self.consistent_center:
            assert self.dm_pred_type != 'sample', 'incompatible dm prediction type for CCD!'
            # suggested by the CCD-3DR paper
            noise_pred = noise_pred - torch.mean(noise_pred, dim=1, keepdim=True)

        # Step: compute the previous sample x_{t-1} from x_t and the predicted noise
        x_t = scheduler.step(noise_pred, t, x_t, **extra_step_kwargs).prev_sample
        if self.consistent_center:
            x_t = x_t - torch.mean(x_t, dim=1, keepdim=True)
        return x_t

    def setup_reverse_process(self, eta, num_inference_steps, scheduler):
        """
        Set up the reverse diffusion chain and scheduler-specific step kwargs.
        """
        accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {"offset": 1} if accepts_offset else {}
        scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # Prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
        # eta corresponds to the η in the DDIM paper (https://arxiv.org/abs/2010.02502), should be in
        # [0, 1], and is only used by the DDIMScheduler; other schedulers ignore it.
        accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
        extra_step_kwargs = {"eta": eta} if accepts_eta else {}
        return extra_step_kwargs

    def forward(self, batch: FrameData, mode: str = 'train', **kwargs):
        """
        A wrapper around the forward method for training and inference
        """
        if isinstance(batch, dict):
            # Works around a multiprocessing issue where the FrameData batch
            # arrives collated as a plain dict; rebuild the FrameData object.
            batch = FrameData(**batch)
        if mode == 'train':
            return self.forward_train(
                pc=batch.sequence_point_cloud,
                camera=batch.camera,
                image_rgb=batch.image_rgb,
                mask=batch.fg_probability,
                **kwargs)
        elif mode == 'sample':
            num_points = kwargs.pop('num_points', get_num_points(batch.sequence_point_cloud))
            return self.forward_sample(
                num_points=num_points,
                camera=batch.camera,
                image_rgb=batch.image_rgb,
                mask=batch.fg_probability,
                **kwargs)
        else:
            raise NotImplementedError()
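

# Minimal standalone sketch (not part of the model): the forward-noising and
# single reverse-step mechanics used above, shown with a bare DDPMScheduler.
# The batch size, point count, and beta values below are illustrative
# assumptions, not values from this repository's configs.
if __name__ == '__main__':
    _scheduler = DDPMScheduler(beta_start=1e-4, beta_end=0.02,
                               beta_schedule='linear', clip_sample=False)
    _x0 = torch.randn(2, 1024, 3)  # (B, N, 3) toy "point clouds"
    _noise = torch.randn_like(_x0)
    _noise = _noise - _noise.mean(dim=1, keepdim=True)  # consistent-center trick
    _t = torch.randint(0, _scheduler.config.num_train_timesteps, (2,), dtype=torch.long)
    _xt = _scheduler.add_noise(_x0, _noise, _t)  # q(x_t | x_0) in closed form
    print('noised:', _xt.shape)

    # One reverse step at a fixed timestep (a real sampler loops over all steps)
    _scheduler.set_timesteps(50)
    _t0 = _scheduler.timesteps[0]
    _eps_pred = torch.randn_like(_xt)  # stands in for the network prediction
    _x_prev = _scheduler.step(_eps_pred, _t0, _xt).prev_sample
    print('one reverse step:', _x_prev.shape)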