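"""
Conditional point cloud diffusion model: a PointCloudModel backbone driven by
diffusers schedulers (DDPM/DDIM/PNDM), with image-conditioned projection
features inherited from PointCloudProjectionModel.
"""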
import inspect
import random
from typing import Optional
import torch
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_pndm import PNDMScheduler
from pytorch3d.implicitron.dataset.data_loader_map_provider import FrameData
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
from torch import Tensor
from tqdm import tqdm
from .model_utils import get_num_points, get_custom_betas
from .point_cloud_model import PointCloudModel
from .projection_model import PointCloudProjectionModel


class ConditionalPointCloudDiffusionModel(PointCloudProjectionModel):
def __init__(
self,
beta_start: float,
beta_end: float,
beta_schedule: str,
point_cloud_model: str,
point_cloud_model_embed_dim: int,
**kwargs, # projection arguments
):
super().__init__(**kwargs)
# Checks
if not self.predict_shape:
raise NotImplementedError('Must predict shape if performing diffusion.')
# Create diffusion model schedulers which define the sampling timesteps
self.dm_pred_type = kwargs.get('dm_pred_type', "epsilon")
        assert self.dm_pred_type in ['epsilon', 'sample']
scheduler_kwargs = {"prediction_type": self.dm_pred_type}
if beta_schedule == 'custom':
scheduler_kwargs.update(dict(trained_betas=get_custom_betas(beta_start=beta_start, beta_end=beta_end)))
else:
scheduler_kwargs.update(dict(beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule))
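        # All three schedulers below share the same betas; they differ only in how
        # the reverse process is sampled (stochastic DDPM, deterministic DDIM,
        # pseudo-numerical PNDM).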
self.schedulers_map = {
'ddpm': DDPMScheduler(**scheduler_kwargs, clip_sample=False),
'ddim': DDIMScheduler(**scheduler_kwargs, clip_sample=False),
'pndm': PNDMScheduler(**scheduler_kwargs),
}
self.scheduler = self.schedulers_map['ddpm'] # this can be changed for inference
# Create point cloud model for processing point cloud at each diffusion step
self.init_pcloud_model(kwargs, point_cloud_model, point_cloud_model_embed_dim)
self.load_sample_init = kwargs.get('load_sample_init', False)
self.sample_init_scale = kwargs.get('sample_init_scale', 1.0)
self.test_init_with_gtpc = kwargs.get('test_init_with_gtpc', False)
self.consistent_center = kwargs.get('consistent_center', False)
        self.cam_noise_std = kwargs.get('cam_noise_std', 0.0)  # std of timestep-dependent noise added to the camera pose

    def init_pcloud_model(self, kwargs, point_cloud_model, point_cloud_model_embed_dim):
        self.point_cloud_model = PointCloudModel(
            model_type=point_cloud_model,
            embed_dim=point_cloud_model_embed_dim,
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            voxel_resolution_multiplier=kwargs.get('voxel_resolution_multiplier', 1),  # defaults to 1
        )

    def forward_train(
self,
pc: Pointclouds,
camera: Optional[CamerasBase],
image_rgb: Optional[Tensor],
mask: Optional[Tensor],
return_intermediate_steps: bool = False,
**kwargs
):
# Normalize colors and convert to tensor
x_0 = self.point_cloud_to_tensor(pc, normalize=True, scale=True) # this will not pack the point colors
B, N, D = x_0.shape
# Sample random noise
noise = torch.randn_like(x_0)
if self.consistent_center:
# modification suggested by https://arxiv.org/pdf/2308.07837.pdf
noise = noise - torch.mean(noise, dim=1, keepdim=True)
# Sample random timesteps for each point_cloud
timestep = torch.randint(0, self.scheduler.num_train_timesteps, (B,),
device=self.device, dtype=torch.long)
# Add noise to points
        x_t = self.scheduler.add_noise(x_0, noise, timestep)  # add diffusion noise to the coordinates only, not the features
        # Add timestep-dependent noise to the camera pose
        if self.cam_noise_std > 1e-6:
            # noise magnitude grows with the camera distance and the diffusion timestep
            camera = camera.clone()
            camT = camera.T  # (B, 3)
            dist = torch.sqrt(torch.sum(camT ** 2, -1, keepdim=True))  # camera distance to the origin
            nratio = timestep[:, None] / self.scheduler.num_train_timesteps  # time-dependent noise ratio
            tnoise = torch.randn(B, 3).to(dist.device) / 3. * dist * self.cam_noise_std * nratio
            camera.T = camera.T + tnoise
# Conditioning, the pixel-aligned feature is based on points with noise (new points)
x_t_input = self.get_diffu_input(camera, image_rgb, mask, timestep, x_t, **kwargs)
# Forward
loss, noise_pred = self.compute_loss(noise, timestep, x_0, x_t_input)
# Whether to return intermediate steps
if return_intermediate_steps:
return loss, (x_0, x_t, noise, noise_pred)
return loss

    def compute_loss(self, noise, timestep, x_0, x_t_input):
x_pred = torch.zeros_like(x_0)
if self.self_conditioning:
# self conditioning, from https://openreview.net/pdf?id=3itjR9QxFw
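            # with probability 0.5, run a detached first pass conditioned on zeros and
            # feed its prediction back as extra conditioning; this teaches the model to
            # refine its own estimate, as it does iteratively at sampling time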
if random.uniform(0, 1.) > 0.5:
with torch.no_grad():
x_pred = self.point_cloud_model(torch.cat([x_t_input, x_pred], -1), timestep)
noise_pred = self.point_cloud_model(torch.cat([x_t_input, x_pred], -1), timestep)
else:
noise_pred = self.point_cloud_model(x_t_input, timestep)
# Check
        if noise_pred.shape != noise.shape:
            raise ValueError(f'{noise_pred.shape=} and {noise.shape=}')
# Loss
if self.dm_pred_type == 'epsilon':
loss = F.mse_loss(noise_pred, noise)
elif self.dm_pred_type == 'sample':
loss = F.mse_loss(noise_pred, x_0) # predicting sample
else:
raise NotImplementedError
return loss, noise_pred

    def get_diffu_input(self, camera, image_rgb, mask, timestep, x_t, **kwargs):
        """Return the exact input to the diffusion model, shape (B, N, D); x_t is (B, N, 3)."""
x_t_input = self.get_input_with_conditioning(x_t, camera=camera,
image_rgb=image_rgb, mask=mask, t=timestep)
return x_t_input

    @torch.no_grad()
def forward_sample(
self,
num_points: int,
camera: Optional[CamerasBase],
image_rgb: Optional[Tensor],
mask: Optional[Tensor],
# Optional overrides
scheduler: Optional[str] = 'ddpm',
# Inference parameters
num_inference_steps: Optional[int] = 1000,
eta: Optional[float] = 0.0, # for DDIM
# Whether to return all the intermediate steps in generation
return_sample_every_n_steps: int = -1,
# Whether to disable tqdm
disable_tqdm: bool = False,
        gt_pc: Optional[Pointclouds] = None,
**kwargs
):
# Get scheduler from mapping, or use self.scheduler if None
scheduler = self.scheduler if scheduler is None else self.schedulers_map[scheduler]
# Get the size of the noise
N = num_points
B = 1 if image_rgb is None else image_rgb.shape[0]
D = self.get_x_T_channel()
device = self.device if image_rgb is None else image_rgb.device
sample_from_interm = kwargs.get('sample_from_interm', False)
interm_steps = kwargs.get('noise_step') if sample_from_interm else -1
x_t = self.initialize_x_T(device, gt_pc, (B, N, D), interm_steps, scheduler)
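        # x_pred buffers the previous prediction for self-conditioning; it starts at
        # zeros, matching the zero conditioning used half the time during training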
x_pred = torch.zeros_like(x_t)
# Set timesteps
extra_step_kwargs = self.setup_reverse_process(eta, num_inference_steps, scheduler)
# Loop over timesteps
all_outputs = []
return_all_outputs = (return_sample_every_n_steps > 0)
progress_bar = tqdm(scheduler.timesteps.to(device), desc=f'Sampling ({x_t.shape})', disable=disable_tqdm)
for i, t in enumerate(progress_bar):
add_interm_output = (return_all_outputs and (
i % return_sample_every_n_steps == 0 or i == len(scheduler.timesteps) - 1))
# Conditioning
x_t_input = self.get_diffu_input(camera, image_rgb, mask, t, x_t, **kwargs)
if self.self_conditioning:
x_t_input = torch.cat([x_t_input, x_pred], -1) # add self-conditioning
            inference_binary = (i == len(progress_bar) - 1) or add_interm_output
# One reverse step with conditioning
x_t = self.reverse_step(extra_step_kwargs, scheduler, t, x_t, x_t_input,
inference_binary=inference_binary) # (B, N, D), D=3 or 4
x_pred = x_t # for next iteration self conditioning
# Append to output list if desired
if add_interm_output:
all_outputs.append(x_t)
# Convert output back into a point cloud, undoing normalization and scaling
        output = self.tensor_to_point_cloud(x_t, denormalize=True, unscale=True)  # converts the points back to the original scale
if return_all_outputs:
all_outputs = torch.stack(all_outputs, dim=1) # (B, sample_steps, N, D)
all_outputs = [self.tensor_to_point_cloud(o, denormalize=True, unscale=True) for o in all_outputs]
return (output, all_outputs) if return_all_outputs else output

    def get_x_T_channel(self):
D = 3 + (self.color_channels if self.predict_color else 0)
return D

    def initialize_x_T(self, device, gt_pc, shape, interm_steps: int = -1, scheduler=None):
B, N, D = shape
# Sample noise initialization
if interm_steps > 0:
# Sample from some intermediate steps
x_0 = self.point_cloud_to_tensor(gt_pc, normalize=True, scale=True)
noise = torch.randn(B, N, D, device=device)
            # keep the noise zero-centered so it does not shift the point cloud center; important, reduces CD by ~0.1 cm
noise = noise - torch.mean(noise, dim=1, keepdim=True)
x_t = scheduler.add_noise(x_0, noise, torch.tensor([interm_steps - 1] * B).long().to(device)) # Add noise
else:
# Sample from random Gaussian
x_t = torch.randn(B, N, D, device=device)
        x_t = x_t * self.sample_init_scale  # optionally scale the initial noise (test-time option)
if self.consistent_center:
x_t = x_t - torch.mean(x_t, dim=1, keepdim=True)
return x_t

    def reverse_step(self, extra_step_kwargs, scheduler, t, x_t, x_t_input, **kwargs):
"""
run one reverse step to compute x_t
:param extra_step_kwargs:
:param scheduler:
:param t: [1], diffusion time step
:param x_t: (B, N, 3)
:param x_t_input: conditional features (B, N, F)
:param kwargs: other configurations to run diffusion step
:return: denoised x_t
"""
B = x_t.shape[0]
# Forward
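        # expand the scalar timestep t to a (B,) vector, one entry per sample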
noise_pred = self.point_cloud_model(x_t_input, t.reshape(1).expand(B))
if self.consistent_center:
            assert self.dm_pred_type != 'sample', 'incompatible dm prediction type for CCD!'
# suggested by the CCD-3DR paper
noise_pred = noise_pred - torch.mean(noise_pred, dim=1, keepdim=True)
# Step
x_t = scheduler.step(noise_pred, t, x_t, **extra_step_kwargs).prev_sample
if self.consistent_center:
x_t = x_t - torch.mean(x_t, dim=1, keepdim=True)
return x_t

    def setup_reverse_process(self, eta, num_inference_steps, scheduler):
        """
        Set up the reverse diffusion chain: configure the scheduler timesteps and
        collect scheduler-specific kwargs for the step function.
        """
accepts_offset = "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
extra_set_kwargs = {"offset": 1} if accepts_offset else {}
scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
# Prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
extra_step_kwargs = {"eta": eta} if accepts_eta else {}
return extra_step_kwargs

    def forward(self, batch: FrameData, mode: str = 'train', **kwargs):
"""
A wrapper around the forward method for training and inference
"""
        if isinstance(batch, dict):  # fixes a bug with multiprocessing where batch becomes a dict
            batch = FrameData(**batch)  # rebuild the FrameData so attribute access works again
if mode == 'train':
return self.forward_train(
pc=batch.sequence_point_cloud,
camera=batch.camera,
image_rgb=batch.image_rgb,
mask=batch.fg_probability,
**kwargs)
elif mode == 'sample':
num_points = kwargs.pop('num_points', get_num_points(batch.sequence_point_cloud))
return self.forward_sample(
num_points=num_points,
camera=batch.camera,
image_rgb=batch.image_rgb,
mask=batch.fg_probability,
**kwargs)
        else:
            raise NotImplementedError(f'unknown mode: {mode}')
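

if __name__ == '__main__':
    # A minimal standalone sketch (not part of the model): it illustrates the
    # forward-noising step used in forward_train, i.e. sampling x_t ~ q(x_t | x_0)
    # with scheduler.add_noise plus the consistent-center trick. The beta values
    # and tensor shapes below are illustrative assumptions, not trained settings.
    scheduler = DDPMScheduler(beta_start=1e-5, beta_end=8e-3,
                              beta_schedule='linear', clip_sample=False)
    x_0 = torch.randn(2, 4096, 3)  # (B, N, 3) clean point coordinates
    noise = torch.randn_like(x_0)
    noise = noise - noise.mean(dim=1, keepdim=True)  # zero-center the noise
    t = torch.randint(0, scheduler.config.num_train_timesteps, (2,), dtype=torch.long)
    x_t = scheduler.add_noise(x_0, noise, t)  # noisy coordinates at timestep t
    print(x_t.shape)  # torch.Size([2, 4096, 3])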