"""EDM-style diffusion sampler wrapped as a Lightning module.

Holds a denoising network plus its EMA copy and implements prediction-time
sampling with the Karras et al. (EDM) stochastic sampler, including
classifier-free guidance over a list of conditioning modalities.
"""

from omegaconf.dictconfig import DictConfig
from typing import List, Tuple
from ema_pytorch import EMA
import numpy as np
import torch
from torchtyping import TensorType
import torch.nn as nn
import lightning as L

from utils.random_utils import StackedRandomGenerator

# ------------------------------------------------------------------------------------- #

# Placeholder names so the TensorType dimension labels below are defined symbols.
batch_size, num_samples = None, None
num_feats, num_rawfeats, num_cams = None, None, None
RawTrajectory = TensorType["num_samples", "num_rawfeats", "num_cams"]

# ------------------------------------------------------------------------------------- #


class Diffuser(L.LightningModule):
    """Prediction-only diffusion model: denoising net + EMA + EDM sampling config."""

    def __init__(
        self,
        network: nn.Module,
        guidance_weight: float,
        ema_kwargs: DictConfig,
        sampling_kwargs: DictConfig,
        edm2_normalization: bool,
        **kwargs,
    ):
        """Store the network, build its EMA shadow, and cache sampler settings.

        Args:
            network: denoising network; must expose ``sigma_data``, ``num_feats``
                and ``num_cams`` attributes (read below and in ``sample``).
            guidance_weight: classifier-free guidance strength passed to
                ``edm_sampler`` (1.0 = no guidance amplification).
            ema_kwargs: forwarded verbatim to ``ema_pytorch.EMA``.
            sampling_kwargs: EDM sampler hyper-parameters (num_steps, sigma_min,
                sigma_max, rho, S_churn, S_min, S_max, S_noise).
            edm2_normalization: if True, inputs are pre-scaled by ``sigma_data``
                at predict time (EDM2 convention, arXiv:2312.02696).
            **kwargs: ignored; absorbs extra config keys.
        """
        super().__init__()

        # Network and EMA
        self.net = network
        self.ema = EMA(self.net, **ema_kwargs)
        self.guidance_weight = guidance_weight
        self.edm2_normalization = edm2_normalization
        self.sigma_data = network.sigma_data

        # Sampling
        self.num_steps = sampling_kwargs.num_steps
        self.sigma_min = sampling_kwargs.sigma_min
        self.sigma_max = sampling_kwargs.sigma_max
        self.rho = sampling_kwargs.rho
        self.S_churn = sampling_kwargs.S_churn
        self.S_noise = sampling_kwargs.S_noise
        self.S_min = sampling_kwargs.S_min
        # NOTE(review): any non-float S_max (e.g. the string "inf" or an int from
        # the config) is coerced to +inf; presumably intentional — confirm.
        self.S_max = (
            sampling_kwargs.S_max
            if isinstance(sampling_kwargs.S_max, float)
            else float("inf")
        )

    # ---------------------------------------------------------------------------------- #

    def on_predict_start(self):
        """Cache dataset hooks from the datamodule before prediction begins.

        Pulls the list of conditioning modalities from the eval dataset and the
        feature->matrix conversion callables from both train and eval datasets.
        """
        eval_dataset = self.trainer.datamodule.eval_dataset
        self.modalities = list(eval_dataset.modality_datasets.keys())
        # NOTE(review): generated samples use the *train* dataset's get_matrix
        # while reference samples use the eval one (see predict_step) — the two
        # may differ in normalization; confirm this asymmetry is intended.
        self.get_matrix = self.trainer.datamodule.train_dataset.get_matrix
        self.v_get_matrix = self.trainer.datamodule.eval_dataset.get_matrix

    def predict_step(self, batch, batch_idx):
        """Run one prediction batch: gather conditions, sample, convert to matrices.

        Expects ``batch`` to contain "traj_feat" and "padding_mask"; conditioning
        tensors are any keys containing "feat" but not "traj", with matching
        ``<name>_raw`` entries. Adds "conds", "ref_samples" and "gen_samples"
        to the batch and returns it.
        """
        ref_samples, mask = batch["traj_feat"], batch["padding_mask"]

        if len(self.modalities) > 0:
            # Conditioning feature keys, e.g. "<modality>_feat" (but never traj).
            cond_k = [x for x in batch.keys() if "traj" not in x and "feat" in x]
            cond_data = [batch[cond] for cond in cond_k]
            # Collect the raw (un-featurized) conditions for downstream logging/eval;
            # dict-valued raw entries are flattened one level.
            conds = {}
            for cond in cond_k:
                cond_name = cond.replace("_feat", "")
                if isinstance(batch[f"{cond_name}_raw"], dict):
                    for cond_name_, x in batch[f"{cond_name}_raw"].items():
                        conds[cond_name_] = x
                else:
                    conds[cond_name] = batch[f"{cond_name}_raw"]
            batch["conds"] = conds
        else:
            cond_data = None

        # cf edm2 sigma_data normalization / https://arxiv.org/pdf/2312.02696.pdf
        if self.edm2_normalization:
            ref_samples *= self.sigma_data
        # Sample with the EMA weights, not the raw training weights.
        _, gen_samples = self.sample(self.ema.ema_model, ref_samples, cond_data, mask)

        batch["ref_samples"] = torch.stack([self.v_get_matrix(x) for x in ref_samples])
        batch["gen_samples"] = torch.stack([self.get_matrix(x) for x in gen_samples])

        return batch

    # --------------------------------------------------------------------------------- #

    def sample(
        self,
        net: torch.nn.Module,
        traj_samples: RawTrajectory,
        cond_samples: TensorType["num_samples", "num_feats"],
        mask: TensorType["num_samples", "num_feats"],
        external_seeds: List[int] = None,
    ) -> Tuple[RawTrajectory, RawTrajectory]:
        """Draw latents and run the EDM sampler.

        Args:
            net: denoising network to sample with (typically the EMA model).
            traj_samples: reference batch; only its leading dim (batch size) is used.
            cond_samples: conditioning features forwarded as class labels (or None).
            mask: padding mask forwarded to the sampler.
            external_seeds: unused here; per-sample seeds come from ``self.gen_seeds``
                when set, else ``range(num_samples)``.

        Returns:
            (latents, generations): the initial noise and the sampled trajectories.
        """
        # Pick latents — one RNG stream per sample for reproducibility.
        num_samples = traj_samples.shape[0]
        seeds = self.gen_seeds if hasattr(self, "gen_seeds") else range(num_samples)
        rnd = StackedRandomGenerator(self.device, seeds)
        sz = [num_samples, self.net.num_feats, self.net.num_cams]
        latents = rnd.randn_rn(sz, device=self.device)

        # Generate trajectories.
        generations = self.edm_sampler(
            net,
            latents,
            class_labels=cond_samples,
            mask=mask,
            randn_like=rnd.randn_like,
            guidance_weight=self.guidance_weight,
            # ----------------------------------- #
            num_steps=self.num_steps,
            sigma_min=self.sigma_min,
            sigma_max=self.sigma_max,
            rho=self.rho,
            S_churn=self.S_churn,
            S_min=self.S_min,
            S_max=self.S_max,
            S_noise=self.S_noise,
        )

        return latents, generations

    @staticmethod
    def edm_sampler(
        net,
        latents,
        class_labels=None,
        mask=None,
        guidance_weight=2.0,
        randn_like=torch.randn_like,
        num_steps=18,
        sigma_min=0.002,
        sigma_max=80,
        rho=7,
        S_churn=0,
        S_min=0,
        S_max=float("inf"),
        S_noise=1,
    ):
        """Stochastic EDM (Karras et al. 2022) sampler with 2nd-order (Heun) steps.

        When ``class_labels`` is given, classifier-free guidance is applied by
        doubling the batch (conditioned half + zero-label half) and blending:
        ``denoised = uncond + (cond - uncond) * guidance_weight``.

        Args:
            net: callable as ``net(x, t, [y=labels,] mask=...)`` returning the
                denoised estimate; with a doubled batch, the output is chunked
                into (conditional, unconditional) halves.
            latents: unit-variance starting noise, scaled by the first sigma.
            class_labels: list of conditioning tensors, or None for unconditional.
            mask: padding mask; inverted to boolean and passed to ``net``.
            guidance_weight: CFG strength (1.0 disables amplification).
            randn_like: noise source for the churn step (seeded generator).
            num_steps / sigma_min / sigma_max / rho: noise schedule parameters.
            S_churn / S_min / S_max / S_noise: stochasticity ("churn") controls.

        Returns:
            The final denoised sample ``x_N`` (same shape as ``latents``).
        """
        # Time step discretization: rho-warped geometric-like schedule from
        # sigma_max down to sigma_min.
        step_indices = torch.arange(num_steps, device=latents.device)
        t_steps = (
            sigma_max ** (1 / rho)
            + step_indices
            / (num_steps - 1)
            * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
        ) ** rho
        t_steps = torch.cat(
            [torch.as_tensor(t_steps), torch.zeros_like(t_steps[:1])]
        )  # t_N = 0

        # Main sampling loop.
        bool_mask = ~mask.to(bool)  # net expects True where positions are padding
        x_next = latents * t_steps[0]
        bs = latents.shape[0]
        for i, (t_cur, t_next) in enumerate(
            zip(t_steps[:-1], t_steps[1:])
        ):  # 0, ..., N-1
            x_cur = x_next

            # Increase noise temporarily (churn) when t_cur is inside [S_min, S_max].
            gamma = (
                min(S_churn / num_steps, np.sqrt(2) - 1)
                if S_min <= t_cur <= S_max
                else 0
            )
            t_hat = torch.as_tensor(t_cur + gamma * t_cur)
            # Add exactly enough fresh noise to move from level t_cur up to t_hat.
            x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur)

            # Euler step.
            if class_labels is not None:
                # CFG: second half of the doubled batch gets zeroed labels
                # (the unconditional branch).
                class_label_knot = [torch.zeros_like(label) for label in class_labels]
                x_hat_both = torch.cat([x_hat, x_hat], dim=0)
                y_label_both = [
                    torch.cat([y, y_knot], dim=0)
                    for y, y_knot in zip(class_labels, class_label_knot)
                ]
                bool_mask_both = torch.cat([bool_mask, bool_mask], dim=0)
                t_hat_both = torch.cat([t_hat.expand(bs), t_hat.expand(bs)], dim=0)
                cond_denoised, denoised = net(
                    x_hat_both, t_hat_both, y=y_label_both, mask=bool_mask_both
                ).chunk(2, dim=0)
                # Blend unconditional toward conditional by guidance_weight.
                denoised = denoised + (cond_denoised - denoised) * guidance_weight
            else:
                denoised = net(x_hat, t_hat.expand(bs), mask=bool_mask)
            d_cur = (x_hat - denoised) / t_hat  # probability-flow ODE derivative
            x_next = x_hat + (t_next - t_hat) * d_cur

            # Apply 2nd order (Heun) correction — skipped on the last step
            # where t_next == 0 would make the derivative undefined.
            if i < num_steps - 1:
                if class_labels is not None:
                    class_label_knot = [
                        torch.zeros_like(label) for label in class_labels
                    ]
                    x_next_both = torch.cat([x_next, x_next], dim=0)
                    y_label_both = [
                        torch.cat([y, y_knot], dim=0)
                        for y, y_knot in zip(class_labels, class_label_knot)
                    ]
                    bool_mask_both = torch.cat([bool_mask, bool_mask], dim=0)
                    t_next_both = torch.cat(
                        [t_next.expand(bs), t_next.expand(bs)], dim=0
                    )
                    cond_denoised, denoised = net(
                        x_next_both, t_next_both, y=y_label_both, mask=bool_mask_both
                    ).chunk(2, dim=0)
                    denoised = denoised + (cond_denoised - denoised) * guidance_weight
                else:
                    denoised = net(x_next, t_next.expand(bs), mask=bool_mask)
                d_prime = (x_next - denoised) / t_next
                # Average the two slopes (trapezoidal rule).
                x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

        return x_next