| import copy
|
| import math
|
| import os
|
| from typing import Any, List, Optional, Union
|
|
|
| import numpy as np
|
| import torch
|
| from pytorch_lightning import LightningDataModule, LightningModule
|
| from torch.distributions import MultivariateNormal
|
| from torchdyn.core import NeuralODE
|
| from torchvision import transforms
|
|
|
| from .components.augmentation import (
|
| AugmentationModule,
|
| AugmentedVectorField,
|
| Sequential,
|
| )
|
| from .components.distribution_distances import compute_distribution_distances
|
| from .components.optimal_transport import OTPlanSampler
|
| from .components.plotting import (
|
| plot_samples,
|
| plot_trajectory,
|
| store_trajectories,
|
| )
|
| from .components.schedule import ConstantNoiseScheduler, NoiseScheduler
|
| from .components.solver import FlowSolver
|
| from .utils import get_wandb_logger
|
|
|
|
|
class CFMLitModule(LightningModule):
    """Conditional Flow Matching Module for training generative models and models over time."""

    def __init__(
        self,
        net: Any,
        optimizer: Any,
        datamodule: LightningDataModule,
        augmentations: AugmentationModule,
        partial_solver: FlowSolver,
        scheduler: Optional[Any] = None,
        neural_ode: Optional[Any] = None,
        ot_sampler: Optional[Union[str, Any]] = None,
        sigma_min: float = 0.1,
        avg_size: int = -1,
        leaveout_timepoint: int = -1,
        test_nfe: int = 100,
        plot: bool = False,
        nice_name: str = "CFM",
    ) -> None:
        """Initialize a conditional flow matching network either as a generative model or for a
        sequence of timepoints.

        Note: DDP does not currently work with NeuralODE objects from torchdyn
        in the init so we initialize them every time we need to do a sampling
        step.

        Args:
            net: torch module representing dx/dt = f(t, x) for t in [1, T] missing dimension.
            optimizer: partial torch.optimizer missing parameters.
            datamodule: datamodule object needs to have "dim", "IS_TRAJECTORY" properties.
            ot_sampler: ot_sampler specified as an object or string. If none then no OT is used in minibatch.
            sigma_min: sigma_min determines the width of the Gaussian smoothing of the data and interpolations.
            leaveout_timepoint: which (if any) timepoint to leave out during the training phase
            plot: if true, log intermediate plots during validation
        """
        super().__init__()
        # Module-valued / partial arguments are not serializable as hparams.
        self.save_hyperparameters(
            ignore=[
                "net",
                "optimizer",
                "scheduler",
                "datamodule",
                "augmentations",
                "partial_solver",
            ],
            logger=False,
        )
        self.datamodule = datamodule
        self.is_trajectory = False
        if hasattr(datamodule, "IS_TRAJECTORY"):
            self.is_trajectory = datamodule.IS_TRAJECTORY

        # "dim" (int) marks flat vector data; "dims" marks image-shaped data.
        if hasattr(datamodule, "dim"):
            self.dim = datamodule.dim
            self.is_image = False
        elif hasattr(datamodule, "dims"):
            self.dim = datamodule.dims
            self.is_image = True
        else:
            raise NotImplementedError("Datamodule must have either dim or dims")
        self.net = net(dim=self.dim)
        self.augmentations = augmentations
        # Wraps the net so regularization terms are integrated alongside the state.
        self.aug_net = AugmentedVectorField(self.net, self.augmentations.regs, self.dim)
        # Fixed regularizer set reported during validation, independent of the
        # training-time augmentations.
        self.val_augmentations = AugmentationModule(
            l1_reg=1,
            l2_reg=1,
            squared_l2_reg=1,
        )
        self.val_aug_net = AugmentedVectorField(self.net, self.val_augmentations.regs, self.dim)
        if neural_ode is not None:
            self.aug_node = Sequential(
                self.augmentations.augmenter,
                neural_ode(self.aug_net),
            )

        self.partial_solver = partial_solver
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.ot_sampler = ot_sampler
        # The literal string "None" (e.g. from a config file) also disables OT.
        if ot_sampler == "None":
            self.ot_sampler = None
        if isinstance(self.ot_sampler, str):
            # entropic regularization 2 * sigma_min^2 matches the path variance
            self.ot_sampler = OTPlanSampler(method=ot_sampler, reg=2 * sigma_min**2)
        self.criterion = torch.nn.MSELoss()
|
|
|
    def forward_integrate(self, batch: Any, t_span: torch.Tensor):
        """Forward pass with integration over t_span intervals.

        (t, x, t_span) -> [x_t_span].
        """
        # NOTE(review): `self.node` is never assigned in this class —
        # __init__ only creates `self.aug_node`. As written this method would
        # raise AttributeError if called; confirm the intended attribute.
        X = self.unpack_batch(batch)
        # assumes X is (batch, time, dim); t_span[0] picks the starting timepoint
        X_start = X[:, t_span[0], :]
        traj = self.node.trajectory(X_start, t_span=t_span)
        return traj
|
|
|
| def forward(self, t: torch.Tensor, x: torch.Tensor):
|
| """Forward pass (t, x) -> dx/dt."""
|
| return self.net(t, x)
|
|
|
| def unpack_batch(self, batch):
|
| """Unpacks a batch of data to a single tensor."""
|
| if self.is_trajectory:
|
| return torch.stack(batch, dim=1)
|
| if not isinstance(self.dim, int):
|
|
|
| return batch[0]
|
| return batch
|
|
|
| def preprocess_batch(self, X, training=False):
|
| """Converts a batch of data into matched a random pair of (x0, x1)"""
|
| t_select = torch.zeros(1, device=X.device)
|
| if self.is_trajectory:
|
| batch_size, times, dim = X.shape
|
| if not hasattr(self.datamodule, "HAS_JOINT_PLANS"):
|
|
|
|
|
| tmp_ot_list = []
|
| for t in range(times - 1):
|
| if training and t + 1 == self.hparams.leaveout_timepoint:
|
| tmp_ot = torch.stack((X[:, t], X[:, t + 2]))
|
| else:
|
| tmp_ot = torch.stack((X[:, t], X[:, t + 1]))
|
| if (
|
| training
|
| and self.ot_sampler is not None
|
| and t != self.hparams.leaveout_timepoint
|
| ):
|
| tmp_ot = torch.stack(self.ot_sampler.sample_plan(tmp_ot[0], tmp_ot[1]))
|
|
|
| tmp_ot_list.append(tmp_ot)
|
| tmp_ot_list = torch.stack(tmp_ot_list)
|
|
|
|
|
| if training and self.hparams.leaveout_timepoint > 0:
|
|
|
| t_select = torch.randint(times - 2, size=(batch_size,), device=X.device)
|
| t_select[t_select >= self.hparams.leaveout_timepoint] += 1
|
| else:
|
| t_select = torch.randint(times - 1, size=(batch_size,))
|
| x0 = []
|
| x1 = []
|
| for i in range(batch_size):
|
| ti = t_select[i]
|
| ti_next = ti + 1
|
| if training and ti_next == self.hparams.leaveout_timepoint:
|
| ti_next += 1
|
| if hasattr(self.datamodule, "HAS_JOINT_PLANS"):
|
| x0.append(torch.tensor(self.datamodule.timepoint_data[ti][X[i, ti]]))
|
| pi = self.datamodule.pi[ti]
|
| if training and ti + 1 == self.hparams.leaveout_timepoint:
|
| pi = self.datamodule.pi_leaveout[ti]
|
| index_batch = X[i][ti]
|
| i_next = np.random.choice(
|
| pi.shape[1], p=pi[index_batch] / pi[index_batch].sum()
|
| )
|
| x1.append(torch.tensor(self.datamodule.timepoint_data[ti_next][i_next]))
|
| else:
|
| x0.append(tmp_ot_list[ti][0][i])
|
| x1.append(tmp_ot_list[ti][1][i])
|
| x0, x1 = torch.stack(x0), torch.stack(x1)
|
| else:
|
| batch_size = X.shape[0]
|
|
|
| x0 = torch.randn_like(X)
|
| x1 = X
|
| return x0, x1, t_select
|
|
|
    def average_ut(self, x, t, mu_t, sigma_t, ut):
        # Approximate the marginal target by averaging conditional targets over
        # a random subset of conditioning points:
        #   u(x) ~= sum_j p_t(x|z_j) u_t(x|z_j) / sum_j p_t(x|z_j)
        # Gaussian kernel between each sample location and every path mean.
        pt = torch.exp(-0.5 * (torch.cdist(x, mu_t) ** 2) / (sigma_t**2))
        batch_size = x.shape[0]
        # avg_size - 1 random partner indices per row...
        ind = torch.randint(
            batch_size, size=(batch_size, self.hparams.avg_size - 1)
        )
        # ...plus each sample's own index, so the subset always contains it
        ind = torch.cat([ind, torch.arange(batch_size)[:, None]], dim=1)
        pt_sub = torch.stack([pt[i, ind[i]] for i in range(batch_size)])
        ut_sub = torch.stack([ut[ind[i]] for i in range(batch_size)])
        p_sum = torch.sum(pt_sub, dim=1, keepdim=True)
        ut = torch.sum(pt_sub[:, :, None] * ut_sub, dim=1) / p_sum
        # NOTE(review): only the first batch element is returned, so the loss
        # is computed on a single sample per step — presumably intentional
        # (one shared t per batch), but confirm this isn't a leftover.
        return x[:1], ut[:1], t[:1]
|
|
|
| def calc_mu_sigma(self, x0, x1, t):
|
| mu_t = t * x1 + (1 - t) * x0
|
| sigma_t = self.hparams.sigma_min
|
| return mu_t, sigma_t
|
|
|
| def calc_u(self, x0, x1, x, t, mu_t, sigma_t):
|
| del x, t, mu_t, sigma_t
|
| return x1 - x0
|
|
|
    def calc_loc_and_target(self, x0, x1, t, t_select, training):
        """Computes the loss on a batch of data.

        Samples a location x ~ N(mu_t, sigma_t^2) on the conditional path and
        the regression target u_t there; returns (x, ut, t, mu_t, sigma_t, eps_t)
        with t shifted to global time by t_select.
        """
        # broadcast t across the non-batch dimensions of x0
        t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1)))
        mu_t, sigma_t = self.calc_mu_sigma(x0, x1, t_xshape)
        eps_t = torch.randn_like(mu_t)
        x = mu_t + sigma_t * eps_t
        ut = self.calc_u(x0, x1, x, t_xshape, mu_t, sigma_t)

        # When a timepoint is held out, its interpolation spans two intervals:
        # halve the velocity target and double the local time accordingly.
        if training and self.hparams.leaveout_timepoint > 0:
            ut[t_select + 1 == self.hparams.leaveout_timepoint] /= 2
            t[t_select + 1 == self.hparams.leaveout_timepoint] *= 2

        # shift local time in [0, 1] to global time ti + t
        t = t + t_select.reshape(-1, *t.shape[1:])
        return x, ut, t, mu_t, sigma_t, eps_t
|
|
|
    def step(self, batch: Any, training: bool = False):
        """Computes the loss on a batch of data.

        Returns a tuple (mean regularization, flow-matching MSE).
        """
        X = self.unpack_batch(batch)
        x0, x1, t_select = self.preprocess_batch(X, training)

        # Batch-averaged targets (avg_size > 0) need one shared t per batch.
        if self.hparams.avg_size > 0:
            t = torch.rand(1).repeat(X.shape[0]).type_as(X)
        else:
            t = torch.rand(X.shape[0]).type_as(X)

        # Trajectory data is already OT-resampled inside preprocess_batch.
        if self.ot_sampler is not None and not self.is_trajectory:
            x0, x1 = self.ot_sampler.sample_plan(x0, x1)

        x, ut, t, mu_t, sigma_t, eps_t = self.calc_loc_and_target(x0, x1, t, t_select, training)

        if self.hparams.avg_size > 0:
            x, ut, t = self.average_ut(x, t, mu_t, sigma_t, ut)
        aug_x = self.aug_net(t, x, augmented_input=False)
        reg, vt = self.augmentations(aug_x)
        return torch.mean(reg), self.criterion(vt, ut)
|
|
|
| def training_step(self, batch: Any, batch_idx: int):
|
| reg, mse = self.step(batch, training=True)
|
| loss = mse + reg
|
| prefix = "train"
|
| self.log_dict(
|
| {f"{prefix}/loss": loss, f"{prefix}/mse": mse, f"{prefix}/reg": reg},
|
| on_step=True,
|
| on_epoch=False,
|
| prog_bar=True,
|
| )
|
| return loss
|
|
|
    def image_eval_step(self, batch: Any, batch_idx: int, prefix: str):
        """Sample a batch of images from noise and dump them as PNGs.

        Images are written to ./images/ so that eval_epoch_end can compute FID
        over them.
        """
        import os

        from torchvision.utils import save_image

        solver = self.partial_solver(self.net, self.dim)
        if isinstance(self.hparams.test_nfe, int):
            # fixed-step integration with test_nfe function evaluations
            t_span = torch.linspace(0, 1, int(self.hparams.test_nfe) + 1)
        elif isinstance(self.hparams.test_nfe, str):
            # adaptive solver; t_span then only marks the endpoints
            solver.ode_solver = "tsit5"
            t_span = torch.linspace(0, 1, 2)
        else:
            raise NotImplementedError(f"Unknown test procedure {self.hparams.test_nfe}")
        traj = solver.odeint(torch.randn(batch[0].shape[0], *self.dim).type_as(batch[0]), t_span)[
            -1
        ]
        os.makedirs("images", exist_ok=True)
        # NOTE(review): these constants undo a CIFAR-10-style per-channel
        # normalization (means 125.3/123.0/113.9, stds 63.0/62.1/66.7 in pixel
        # units) — confirm they match the datamodule's transform before reuse.
        mean = [-x / 255.0 for x in [125.3, 123.0, 113.9]]
        std = [255.0 / x for x in [63.0, 62.1, 66.7]]
        inv_normalize = transforms.Compose(
            [
                transforms.Normalize(mean=[0.0, 0.0, 0.0], std=std),
                transforms.Normalize(mean=mean, std=[1.0, 1.0, 1.0]),
            ]
        )
        traj = inv_normalize(traj)
        traj = torch.clip(traj, min=0, max=1.0)
        for i, image in enumerate(traj):
            save_image(image, fp=f"images/{batch_idx}_{i}.png")
        return {"x": batch[0]}
|
|
|
    def eval_step(self, batch: Any, batch_idx: int, prefix: str):
        """Shared val/test step.

        Optionally samples images (test + image data), logs the training
        objective on validation, and always passes the batch through for the
        epoch-end evaluation.
        """
        if prefix == "test" and self.is_image:
            self.image_eval_step(batch, batch_idx, prefix)
        shapes = [b.shape[0] for b in batch]

        # The loss is only well defined when all timepoints share a batch size.
        if not self.is_image and prefix == "val" and shapes.count(shapes[0]) == len(shapes):
            reg, mse = self.step(batch, training=False)
            loss = mse + reg
            self.log_dict(
                {f"{prefix}/loss": loss, f"{prefix}/mse": mse, f"{prefix}/reg": reg},
                on_step=False,
                on_epoch=True,
                sync_dist=True,
            )
            return {"loss": loss, "mse": mse, "reg": reg, "x": self.unpack_batch(batch)}

        return {"x": batch}
|
|
|
    def preprocess_epoch_end(self, outputs: List[Any], prefix: str):
        """Preprocess the outputs of the epoch end function.

        Returns (ts, x, x0, x_rest): number of timepoints, the full data, the
        initial condition, and the remaining timepoints.
        """
        if self.is_trajectory and prefix == "test" and isinstance(outputs[0]["x"], list):
            # unequal batch sizes per timepoint: keep the list-of-tensors form
            x = outputs[0]["x"]
            ts = len(x)
            x0 = x[0]
            x_rest = x[1:]
        elif self.is_trajectory:
            if hasattr(self.datamodule, "HAS_JOINT_PLANS"):
                # batches carried indices; evaluate on the full timepoint data
                x = [torch.tensor(dd) for dd in self.datamodule.timepoint_data]
                x0 = x[0]
                x_rest = x[1:]
                ts = len(x)
            else:
                v = {k: torch.cat([d[k] for d in outputs]) for k in ["x"]}
                x = v["x"]
                ts = x.shape[1]
                x0 = x[:, 0, :]
                x_rest = x[:, 1:]
        else:
            if isinstance(self.dim, int):
                v = {k: torch.cat([d[k] for d in outputs]) for k in ["x"]}
                x = v["x"]
            else:
                # images: keep only the first 100 samples of the first batch
                x = [d["x"] for d in outputs][0][0][:100]

            # Generative setting: prepend a noise "timepoint" as the source.
            rand = torch.randn_like(x)

            x = torch.stack([rand, x], dim=1)
            ts = x.shape[1]
            x0 = x[:, 0]
            x_rest = x[:, 1:]
        return ts, x, x0, x_rest
|
|
|
    def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix):
        """Integrate the learned ODE across all timepoints, log regularization
        values and distribution distances, and return the sampled trajectories.
        """
        t_span = torch.linspace(0, 1, 101)
        regs = []
        trajs = []
        full_trajs = []
        solver = self.partial_solver(self.net, self.dim)
        nfe = 0
        x0_tmp = x0.clone()

        if self.is_image:
            # single integration pass; keep first and last frames only
            traj = solver.odeint(x0, t_span)
            full_trajs.append(traj)
            trajs.append(traj[0])
            trajs.append(traj[-1])
            nfe += solver.nfe

        if not self.is_image:
            solver.augmentations = self.val_augmentations
            # integrate interval by interval, feeding each endpoint forward
            for i in range(ts - 1):
                traj, aug = solver.odeint(x0_tmp, t_span + i)
                full_trajs.append(traj)
                traj, aug = traj[-1], aug[-1]
                x0_tmp = traj
                regs.append(torch.mean(aug, dim=0).detach().cpu().numpy())
                trajs.append(traj)
                nfe += solver.nfe

        full_trajs = torch.cat(full_trajs)

        if not self.is_image:
            regs = np.stack(regs).mean(axis=0)
            names = [f"{prefix}/{name}" for name in self.val_augmentations.names]
            self.log_dict(dict(zip(names, regs)), sync_dist=True)

        if (
            self.is_trajectory
            and prefix == "test"
            and isinstance(outputs[0]["x"], list)
            and not hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM")
        ):
            # Unequal batch sizes: re-integrate, restarting each interval from
            # ground-truth data instead of the model's own previous samples.
            trajs = []
            full_trajs = []
            nfe = 0
            x0_tmp = x0
            for i in range(ts - 1):
                traj, _ = solver.odeint(x0_tmp, t_span + i)
                traj = traj[-1]
                x0_tmp = x_rest[i]
                trajs.append(traj)
                nfe += solver.nfe
            names, dists = compute_distribution_distances(trajs[:-1], x_rest[:-1])
        else:
            names, dists = compute_distribution_distances(trajs, x_rest)
        names = [f"{prefix}/{name}" for name in names]
        d = dict(zip(names, dists))
        if self.hparams.leaveout_timepoint >= 0:
            # duplicate the held-out timepoint's metrics under a stable key
            to_add = {
                f"{prefix}/t_out/{key.split('/')[-1]}": val
                for key, val in d.items()
                if key.startswith(f"{prefix}/t{self.hparams.leaveout_timepoint}")
            }
            d.update(to_add)
        d[f"{prefix}/nfe"] = nfe

        self.log_dict(d, sync_dist=True)

        if hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM"):
            solver.augmentations = None
            # compare against the closed-form marginals at 21 timepoints
            t_span = torch.linspace(0, 1, 21)
            traj = solver.odeint(x0, t_span)
            assert traj.shape[0] == t_span.shape[0]
            kls = [
                self.datamodule.KL(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj)
            ]
            self.log_dict({f"{prefix}/kl/mean": torch.stack(kls).mean().item()}, sync_dist=True)
            self.log_dict({f"{prefix}/kl/tp_{i}": kls[i] for i in range(21)}, sync_dist=True)

        return trajs, full_trajs
|
|
|
    def eval_epoch_end(self, outputs: List[Any], prefix: str):
        """Epoch-level evaluation: FID for images, ODE-based distribution
        metrics otherwise, plus optional plotting and trajectory dumping.
        """
        wandb_logger = get_wandb_logger(self.loggers)
        if prefix == "test" and self.is_image:
            os.makedirs("images", exist_ok=True)
            if len(os.listdir("images")) > 0:
                # NOTE(review): hard-coded, machine-specific path to the
                # precomputed CIFAR-10 FID statistics — should be configurable.
                path = "/home/mila/a/alexander.tong/scratch/trajectory-inference/data/fid_stats_cifar10_train.npz"
                from pytorch_fid import fid_score

                fid = fid_score.calculate_fid_given_paths(["images", path], 256, "cuda", 2048, 0)
                self.log(f"{prefix}/fid", fid)

        ts, x, x0, x_rest = self.preprocess_epoch_end(outputs, prefix)
        trajs, full_trajs = self.forward_eval_integrate(ts, x0, x_rest, outputs, prefix)

        if self.hparams.plot:
            if isinstance(self.dim, int):
                plot_trajectory(
                    x,
                    full_trajs,
                    title=f"{self.current_epoch}_ode",
                    key="ode_path",
                    wandb_logger=wandb_logger,
                )
            else:
                plot_samples(
                    trajs[-1],
                    title=f"{self.current_epoch}_samples",
                    wandb_logger=wandb_logger,
                )

        if prefix == "test" and not self.is_image:
            store_trajectories(x, self.net)
|
|
|
    def validation_step(self, batch: Any, batch_idx: int):
        # Lightning hook: delegate to the shared eval step with the "val" prefix.
        return self.eval_step(batch, batch_idx, "val")

    def validation_epoch_end(self, outputs: List[Any]):
        # Lightning hook: shared epoch-end evaluation, "val" prefix.
        self.eval_epoch_end(outputs, "val")

    def test_step(self, batch: Any, batch_idx: int):
        # Lightning hook: delegate to the shared eval step with the "test" prefix.
        return self.eval_step(batch, batch_idx, "test")

    def test_epoch_end(self, outputs: List[Any]):
        # Lightning hook: shared epoch-end evaluation, "test" prefix.
        self.eval_epoch_end(outputs, "test")
|
|
|
| def configure_optimizers(self):
|
| """Pass model parameters to optimizer."""
|
| optimizer = self.optimizer(params=self.parameters())
|
| if self.scheduler is None:
|
| return optimizer
|
|
|
| scheduler = self.scheduler(optimizer)
|
| return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]
|
|
|
    def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
        # Step the scheduler on the epoch counter; `metric` is unused.
        scheduler.step(epoch=self.current_epoch)
|
|
|
|
|
class RectifiedFlowLitModule(CFMLitModule):
    """CFM variant that periodically re-pairs (x0, x1) using samples generated
    by a frozen copy of its own flow (see preprocess_batch / training_epoch_end).
    """

    def __init__(
        self,
        net: Any,
        optimizer: Any,
        datamodule: LightningDataModule,
        augmentations: AugmentationModule,
        partial_solver: FlowSolver,
        val_augmentations: Optional[AugmentationModule] = None,
        scheduler: Optional[Any] = None,
        neural_ode: Optional[Any] = None,
        ot_sampler: Optional[Union[str, Any]] = None,
        sigma_min: float = 0.1,
        rectify_epochs: Optional[List[int]] = None,
        test_nfe: int = 100,
        avg_size: int = -1,
        leaveout_timepoint: int = -1,
        plot: bool = False,
        nice_name: str = "Rect",
    ) -> None:
        """Initialize a conditional flow matching network either as a generative model or for a
        sequence of timepoints.

        Args:
            net: torch module representing dx/dt = f(t, x) for t in [1, T] missing dimension.
            optimizer: partial torch.optimizer missing parameters.
            datamodule: datamodule object needs to have "dim", "IS_TRAJECTORY" properties.
            ot_sampler: ot_sampler specified as an object or string. If none then no OT is used in minibatch.
            sigma_min: sigma_min determines the width of the Gaussian smoothing of the data and interpolations.
            leaveout_timepoint: which (if any) timepoint to leave out during the training phase
            plot: if true, log intermediate plots during validation
        """
        # Deliberately skip CFMLitModule.__init__ (different hparams set) and
        # run LightningModule.__init__ directly.
        super(CFMLitModule, self).__init__()
        self.save_hyperparameters(
            ignore=[
                "net",
                "optimizer",
                "scheduler",
                "datamodule",
                "augmentations",
                "val_augmentations",
                "partial_solver",
            ],
            logger=False,
        )
        self.datamodule = datamodule
        self.is_trajectory = False
        if hasattr(datamodule, "IS_TRAJECTORY"):
            self.is_trajectory = datamodule.IS_TRAJECTORY
        # "dim" (int) marks flat vector data; "dims" marks image-shaped data.
        if hasattr(datamodule, "dim"):
            self.dim = datamodule.dim
            self.is_image = False
        elif hasattr(datamodule, "dims"):
            self.dim = datamodule.dims
            self.is_image = True
        else:
            raise NotImplementedError("Datamodule must have either dim or dims")
        self.net = net(dim=self.dim)
        # Snapshot of the net taken at rectify_epochs; None until then.
        self.frozen_net = None
        self.augmentations = augmentations
        self.aug_net = AugmentedVectorField(self.net, self.augmentations.regs, self.dim)
        self.val_augmentations = val_augmentations
        if val_augmentations is None:
            # default validation regularizer set
            self.val_augmentations = AugmentationModule(
                l1_reg=1,
                l2_reg=1,
                squared_l2_reg=1,
            )
        self.val_aug_net = AugmentedVectorField(self.net, self.val_augmentations.regs, self.dim)
        if neural_ode is not None:
            self.aug_node = Sequential(
                self.augmentations.augmenter,
                neural_ode(self.aug_net),
            )
        self.partial_solver = partial_solver
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.ot_sampler = ot_sampler
        # The literal string "None" (e.g. from a config file) also disables OT.
        if ot_sampler == "None":
            self.ot_sampler = None
        if isinstance(self.ot_sampler, str):
            # entropic regularization 2 * sigma_min^2 matches the path variance
            self.ot_sampler = OTPlanSampler(method=ot_sampler, reg=2 * sigma_min**2)
        self.criterion = torch.nn.MSELoss()
|
|
|
| def preprocess_batch(self, X, training=False):
|
| """Converts a batch of data into matched a random pair of (x0, x1)"""
|
| t_select = torch.zeros(1, device=X.device)
|
| if self.is_trajectory:
|
| batch_size, times, dim = X.shape
|
| if training and self.hparams.leaveout_timepoint > 0:
|
|
|
| t_select = torch.randint(times - 2, size=(batch_size,), device=X.device)
|
| t_select[t_select >= self.hparams.leaveout_timepoint] += 1
|
| else:
|
| t_select = torch.randint(times - 1, size=(batch_size,))
|
| x0 = []
|
| x1 = []
|
| for i in range(batch_size):
|
| ti = t_select[i]
|
| ti_next = ti + 1
|
| if training and ti_next == self.hparams.leaveout_timepoint:
|
| ti_next += 1
|
| x0.append(X[i, ti])
|
| x1.append(X[i, ti_next])
|
| x0, x1 = torch.stack(x0), torch.stack(x1)
|
| else:
|
| batch_size = X.shape[0]
|
|
|
| x0 = torch.randn_like(X)
|
| x1 = X
|
|
|
| if self.frozen_net is not None:
|
|
|
| assert t_select[0] == 0
|
| t_span = torch.linspace(0, 1, 100)
|
| val_node = NeuralODE(self.frozen_net, solver="euler")
|
| with torch.no_grad():
|
| _, traj = val_node(x0, t_span)
|
| x1 = traj[-1]
|
| return x0, x1, t_select
|
|
|
| def training_epoch_end(self, training_step_outputs):
|
| if (
|
| self.hparams.rectify_epochs is not None
|
| and self.current_epoch in self.hparams.rectify_epochs
|
| ):
|
| self.frozen_net = copy.deepcopy(self.net)
|
|
|
|
|
class ActionMatchingLitModule(CFMLitModule):
    """Implements Action Matching: Learning Stochastic Dynamics from Samples (Neklyudov et al.
    2022)

    Requires net to have a .energy function where net.energy(t, x): \\mathbb{R}^{d+1} \to
    \\mathbb{R} and net.forward is equal to \nabla_x(net.energy).
    """

    def step(self, batch: Any, training: bool = False):
        """Computes the action-matching loss on a batch of data.

        Returns (mean regularization, action-matching loss).
        """
        assert not self.is_trajectory
        energy = self.net.energy
        X = self.unpack_batch(batch)
        x0, x1, t_select = self.preprocess_batch(X, training)

        if self.ot_sampler is not None:
            x0, x1 = self.ot_sampler.sample_plan(x0, x1)

        t = torch.rand(X.shape[0]).type_as(X)
        # broadcast t across the non-batch dimensions of x0
        t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1)))
        xt = t_xshape * x1 + (1 - t_xshape) * x0

        # shift local time to global time ti + t
        t = t + t_select.reshape(-1, *t.shape[1:])

        # grads of the energy w.r.t. location and time are part of the loss
        xt.requires_grad, t_xshape.requires_grad = True, True
        with torch.set_grad_enabled(True):
            st = torch.sum(energy(torch.cat([xt, t_xshape], dim=-1)))
            dsdx, dsdt = torch.autograd.grad(st, (xt, t_xshape), create_graph=True)
        xt.requires_grad, t_xshape.requires_grad = False, False
        # FIX: create the boundary time columns with the data's device/dtype;
        # plain torch.zeros/ones live on CPU and break torch.cat on GPU input.
        a0 = energy(torch.cat([x0, torch.zeros(x0.shape[0], 1).type_as(x0)], dim=-1))
        a1 = energy(torch.cat([x1, torch.ones(x1.shape[0], 1).type_as(x1)], dim=-1))
        loss = a0 - a1 + 0.5 * (dsdx**2).sum(1, keepdims=True) + dsdt
        loss = loss.mean()
        aug_x = self.aug_net(t, xt, augmented_input=False)
        reg, vt = self.augmentations(aug_x)
        return torch.mean(reg), loss
|
|
|
|
|
class VariancePreservingCFM(CFMLitModule):
    """Variance-preserving time schedule as suggested in (Albergo et al. 2023):
    the interpolant is cos(t pi/2) x_0 + sin(t pi/2) x_1.
    """

    def calc_mu_sigma(self, x0, x1, t):
        """Trigonometric interpolant mean with constant width sigma_min."""
        assert not self.is_trajectory
        half_pi_t = math.pi / 2 * t
        mu_t = torch.cos(half_pi_t) * x0 + torch.sin(half_pi_t) * x1
        return mu_t, self.hparams.sigma_min

    def calc_u(self, x0, x1, x, t, mu_t, sigma_t):
        """Time derivative of the trigonometric interpolant."""
        del x, mu_t, sigma_t
        half_pi_t = math.pi / 2 * t
        return math.pi / 2 * (torch.cos(half_pi_t) * x1 - torch.sin(half_pi_t) * x0)
|
|
|
|
|
class SBCFMLitModule(CFMLitModule):
    """Schrodinger-bridge-based conditional flow matching model.

    Similar to the OT-CFM loss, but the conditional variance scales with
    t * (1 - t) (a Brownian bridge). This has provably equal probability flow
    to the Schrodinger bridge solution when the transport is computed with the
    squared Euclidean distance on R^d.
    """

    def calc_mu_sigma(self, x0, x1, t):
        """Bridge mean (linear) and bridge std sigma_min * sqrt(t(1-t))."""
        assert not self.is_trajectory
        mean = t * x1 + (1 - t) * x0
        bridge_std = self.hparams.sigma_min * torch.sqrt(t - t**2)
        return mean, bridge_std

    def calc_u(self, x0, x1, x, t, mu_t, sigma_t):
        """Conditional target: straight-line drift plus the bridge correction."""
        del sigma_t
        # d/dt log sigma_t for sigma_t proportional to sqrt(t(1-t))
        log_std_rate = (1 - 2 * t) / (2 * t * (1 - t))
        return log_std_rate * (x - mu_t) + x1 - x0
|
|
|
|
|
class SF2MLitModule(CFMLitModule):
    """Simulation-free score and flow matching: jointly learns a flow field and
    a score field under a (possibly time-varying) noise schedule.
    """

    def __init__(
        self,
        net: Any,
        optimizer: Any,
        datamodule: LightningDataModule,
        augmentations: AugmentationModule,
        partial_solver: FlowSolver,
        score_net: Optional[Any] = None,
        scheduler: Optional[Any] = None,
        ot_sampler: Optional[Union[str, Any]] = None,
        sigma: Optional[NoiseScheduler] = None,
        sigma_min: float = 0.1,
        outer_loop_epochs: Optional[int] = None,
        score_weight: float = 1.0,
        avg_size: int = -1,
        leaveout_timepoint: int = -1,
        test_nfe: int = 100,
        test_sde: bool = False,
        plot: bool = False,
        nice_name: Optional[str] = "SF2M",
    ) -> None:
        """Initialize a conditional flow matching network either as a generative model or for a
        sequence of timepoints.

        Args:
            net: torch module representing dx/dt = f(t, x) for t in [1, T] missing dimension.
            score_net: torch module representing the score function of the flow.
                If not supplied it is assumed that the net contains both flow and
                score.
            optimizer: partial torch.optimizer missing parameters.
            datamodule: datamodule object needs to have "dim", "IS_TRAJECTORY" properties.
            ot_sampler: ot_sampler specified as an object or string. If none then no OT is used in minibatch.
            sigma: sigma determines the width of the Gaussian smoothing of the data and interpolations.
            leaveout_timepoint: which (if any) timepoint to leave out during the training phase
            plot: if true, log intermediate plots during validation
        """
        # Deliberately skip CFMLitModule.__init__ (different hparams set) and
        # run LightningModule.__init__ directly.
        super(CFMLitModule, self).__init__()
        self.save_hyperparameters(
            ignore=[
                "net",
                "optimizer",
                "scheduler",
                "datamodule",
                "augmentations",
                "sigma_scheduler",
                "partial_solver",
            ],
            logger=False,
        )
        self.datamodule = datamodule
        self.is_trajectory = False
        if hasattr(datamodule, "IS_TRAJECTORY"):
            self.is_trajectory = datamodule.IS_TRAJECTORY

        # "dim" (int) marks flat vector data; "dims" marks image-shaped data.
        if hasattr(datamodule, "dim"):
            self.dim = datamodule.dim
            self.is_image = False
        elif hasattr(datamodule, "dims"):
            self.dim = datamodule.dims
            self.is_image = True
        else:
            raise NotImplementedError("Datamodule must have either dim or dims")
        self.net = net(dim=self.dim)
        # When no score_net is given, self.net outputs flow and score stacked
        # along the channel dimension (see forward_flow_and_score).
        self.separate_score = score_net is not None
        self.score_net = score_net
        if self.separate_score:
            self.score_net = score_net(dim=self.dim)
        self.partial_solver = partial_solver
        self.augmentations = augmentations
        self.aug_net = AugmentedVectorField(self.net, self.augmentations.regs, self.dim)
        # Fixed regularizer set reported during validation.
        self.val_augmentations = AugmentationModule(
            l1_reg=1,
            l2_reg=1,
            squared_l2_reg=1,
        )
        self.val_aug_net = AugmentedVectorField(self.net, self.val_augmentations.regs, self.dim)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.sigma = sigma
        if sigma is None:
            # default: constant noise level sigma_min
            self.sigma = ConstantNoiseScheduler(sigma_min)
        self.ot_sampler = ot_sampler
        # The literal string "None" (e.g. from a config file) also disables OT.
        if ot_sampler == "None":
            self.ot_sampler = None
        if isinstance(self.ot_sampler, str):
            # entropic regularization matches the integrated noise F(1)
            self.ot_sampler = OTPlanSampler(method=ot_sampler, reg=2 * self.sigma.F(1))
        self.criterion = torch.nn.MSELoss()

        # Simulated (x0, x1) pairs collected during "outer loop" epochs;
        # None until the first outer loop completes.
        self.stored_data = None
        self.tmp_stored_data = None
|
|
|
| def calc_mu_sigma(self, x0, x1, t):
|
|
|
| ft = self.sigma.F(t)
|
| fone = self.sigma.F(1)
|
| mu_t = x0 + (x1 - x0) * ft / fone
|
|
|
| sigma_t = torch.sqrt(ft - ft**2 / fone)
|
| return mu_t, sigma_t
|
|
|
| def calc_u(self, x0, x1, x, t, mu_t, sigma_t):
|
| ft = self.sigma.F(t)
|
| fone = self.sigma.F(1)
|
| sigma_t_prime = self.sigma(t) ** 2 - 2 * ft * self.sigma(t) ** 2 / fone
|
| sigma_t_prime_over_sigma_t = sigma_t_prime / (sigma_t + 1e-8)
|
| mu_t_prime = (x1 - x0) * self.sigma(t) ** 2 / fone
|
| ut = sigma_t_prime_over_sigma_t * (x - mu_t) + mu_t_prime
|
| return ut
|
|
|
    def calc_loc_and_target(self, x0, x1, t, t_select, training):
        """Sample a location on the conditional bridge and compute both the
        flow target ut and the score target (the sampled noise eps_t).

        Returns (x, ut, t, mu_t, sigma_t, score_target) with t shifted to
        global time by t_select.
        """
        # broadcast t across the non-batch dimensions of x0
        t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1)))
        mu_t, sigma_t = self.calc_mu_sigma(x0, x1, t_xshape)
        eps_t = torch.randn_like(mu_t)
        x = mu_t + sigma_t * eps_t
        ut = self.calc_u(x0, x1, x, t_xshape, mu_t, sigma_t)

        # When a timepoint is held out, its interpolation spans two intervals:
        # halve the velocity target and double the local time accordingly.
        if training and self.hparams.leaveout_timepoint > 0:
            ut[t_select + 1 == self.hparams.leaveout_timepoint] /= 2
            t[t_select + 1 == self.hparams.leaveout_timepoint] *= 2

        # the score regression target is the noise used to perturb mu_t
        score_target = eps_t

        # shift local time in [0, 1] to global time ti + t
        t = t + t_select.reshape(-1, *t.shape[1:])
        return x, ut, t, mu_t, sigma_t, score_target
|
|
|
| def forward_flow_and_score(self, t, x):
|
| if self.separate_score:
|
| reg, vt = self.augmentations(self.aug_net(t, x, augmented_input=False))
|
| st = self.score_net(t, x)
|
| return reg, vt, st
|
| reg, vtst = self.augmentations(self.aug_net(t, x, augmented_input=False))
|
| split_idx = vtst.shape[1] // 2
|
| vt, st = vtst[:, :split_idx], vtst[:, split_idx:]
|
| return reg, vt, st
|
|
|
    def step(self, batch: Any, training: bool = False):
        """Computes the loss on a batch of data.

        Returns (reg + score_weight * score loss, flow loss); the parent's
        training_step sums the two terms.
        """
        X = self.unpack_batch(batch)
        x0, x1, t_select = self.preprocess_batch(X, training)

        # Batch-averaged targets (avg_size > 0) need one shared t per batch.
        if self.hparams.avg_size > 0:
            t = torch.rand(1).repeat(X.shape[0]).type_as(X)
        else:
            t = torch.rand(X.shape[0]).type_as(X)

        # Stored simulated pairs already define the coupling, so skip OT then.
        if self.ot_sampler is not None and self.stored_data is None:
            x0, x1 = self.ot_sampler.sample_plan(x0, x1)
        # keep the unshifted local time for evaluating sigma below
        t_orig = t.clone()

        x, ut, t, mu_t, sigma_t, score_target = self.calc_loc_and_target(
            x0, x1, t, t_select, training
        )

        if self.hparams.avg_size > 0:
            x, ut, t = self.average_ut(x, t, mu_t, sigma_t, ut)

        reg, vt, st = self.forward_flow_and_score(t, x)
        flow_loss = self.criterion(vt, ut)
        # Rescale the predicted score so its regression target is eps_t.
        score_loss = self.criterion(
            -sigma_t * st / (self.sigma(t_orig.reshape(sigma_t.shape)) ** 2) * 2,
            score_target,
        )
        return torch.mean(reg) + self.hparams.score_weight * score_loss, flow_loss
|
|
|
| def forward_sde_eval(self, ts, x0, x_rest, outputs, prefix):
|
|
|
| t_span = torch.linspace(0, 1, 2)
|
| solver = self.partial_solver(
|
| self.net, self.dim, score_field=self.score_net, sigma=self.sigma
|
| )
|
| if False and self.is_image:
|
| traj = solver.sdeint(x0, t_span, logqp=False)
|
|
|
| trajs = []
|
| full_trajs = []
|
| nfe = 0
|
| kldiv_total = 0
|
| x0_tmp = x0.clone()
|
| for i in range(ts - 1):
|
| traj, kldiv = solver.sdeint(x0_tmp, t_span + i, logqp=True)
|
| kldiv_total += torch.mean(kldiv[-1])
|
| x0_tmp = traj[-1]
|
| trajs.append(traj[-1])
|
| full_trajs.append(traj)
|
| nfe += solver.nfe
|
| full_trajs = torch.cat(full_trajs)
|
| if not self.is_image:
|
|
|
| if (
|
| self.is_trajectory
|
| and prefix == "test"
|
| and isinstance(outputs[0]["x"], list)
|
| and not hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM")
|
| ):
|
| trajs = []
|
| full_trajs = []
|
| nfe = 0
|
| kldiv_total = 0
|
| x0_tmp = x0.clone()
|
| for i in range(ts - 1):
|
| traj, kldiv = solver.sdeint(x0_tmp, t_span + i, logqp=True)
|
| x0_tmp = x_rest[i]
|
| kldiv_total += torch.mean(kldiv[-1])
|
| trajs.append(traj[-1])
|
| full_trajs.append(traj)
|
| nfe += solver.nfe
|
| names, dists = compute_distribution_distances(trajs[:-1], x_rest[:-1])
|
| else:
|
| names, dists = compute_distribution_distances(trajs, x_rest)
|
| names = [f"{prefix}/sde/{name}" for name in names]
|
| d = dict(zip(names, dists))
|
| if self.hparams.leaveout_timepoint >= 0:
|
| to_add = {
|
| f"{prefix}/sde/t_out/{key.split('/')[-1]}": val
|
| for key, val in d.items()
|
| if key.startswith(f"{prefix}/sde/t{self.hparams.leaveout_timepoint}")
|
| }
|
| d.update(to_add)
|
| d[f"{prefix}/sde/nfe"] = nfe
|
| d[f"{prefix}/sde/kldiv"] = kldiv_total
|
| self.log_dict(d, sync_dist=True)
|
| if hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM"):
|
| solver.augmentations = None
|
| t_span = torch.linspace(0, 1, 21)
|
| solver.dt = 0.05
|
|
|
| traj = solver.sdeint(x0, t_span)
|
| assert traj.shape[0] == t_span.shape[0]
|
| kls = [
|
| self.datamodule.KL(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj)
|
| ]
|
| self.log_dict(
|
| {f"{prefix}/sde/kl/mean": torch.stack(kls).mean().item()},
|
| sync_dist=True,
|
| )
|
| self.log_dict({f"{prefix}/sde/kl/tp_{i}": kls[i] for i in range(21)}, sync_dist=True)
|
| return trajs, full_trajs
|
|
|
    def eval_epoch_end(self, outputs: List[Any], prefix: str):
        """Run the parent's ODE evaluation, then additionally evaluate SDE
        sampling (only for flat, non-image data where dim is an int)."""
        super().eval_epoch_end(outputs, prefix)
        wandb_logger = get_wandb_logger(self.loggers)
        ts, x, x0, x_rest = self.preprocess_epoch_end(outputs, prefix)
        if isinstance(self.dim, int):
            traj, sde_traj = self.forward_sde_eval(ts, x0, x_rest, outputs, prefix)

        if self.hparams.plot:
            # same guard as above, so sde_traj is always bound when used here
            if isinstance(self.dim, int):
                plot_trajectory(
                    x,
                    sde_traj,
                    title=f"{self.current_epoch}_sde_traj",
                    key="sde",
                    wandb_logger=wandb_logger,
                )
|
|
|
def preprocess_batch(self, X, training=False):
    """Convert a batch into a matched (x0, x1) pair, preferring cached pairs.

    During training, if ``self.stored_data`` holds previously simulated
    trajectories, a random subset of those cached pairs is returned instead
    of the incoming batch; otherwise this defers to the parent class.
    """
    use_cache = training and self.stored_data is not None
    if not use_cache:
        return super().preprocess_batch(X, training)
    # Draw one random cached trajectory row per element of the incoming batch.
    sample_idx = torch.randint(self.stored_data.shape[0], size=(X.shape[0],))
    pairs = self.stored_data[sample_idx]
    # Cached pairs always belong to the first (t=0) interval.
    t_select = torch.zeros(1, device=pairs.device)
    return pairs[:, 0], pairs[:, 1], t_select
|
|
|
def training_step(self, batch: Any, batch_idx: int):
    """Standard training step with an optional outer-loop resimulation phase.

    Every ``outer_loop_epochs`` epochs, the current forward/backward SDEs are
    simulated over the batch and the resulting endpoint pairs are stashed in
    ``self.tmp_stored_data`` for use as matched training pairs later on.
    Always delegates the actual loss computation to the parent class.
    """
    if (
        self.hparams.outer_loop_epochs is not None
        and (self.current_epoch + 1) % self.hparams.outer_loop_epochs == 0
    ):
        X = self.unpack_batch(batch)
        x0, x1, t_select = self.preprocess_batch(X, training=True)
        # Resimulation is only defined for single-interval (non-trajectory) data.
        assert not torch.any(t_select)
        # Fixed: removed a dead `solver = self.partial_solver` assignment that
        # was immediately overwritten by the call below.
        solver = self.partial_solver(
            self.net, self.dim, score_field=self.score_net, sigma=self.sigma
        )
        # NOTE(review): t_span is created on the default device; presumably the
        # solver handles device placement -- confirm on GPU runs.
        t_span = torch.linspace(0, 1, 2)
        batch_size = x0.shape[0]
        with torch.no_grad():
            # Push the first half of x0 forwards and the second half of x1
            # backwards, flipping the backward path so both run 0 -> 1.
            forward_traj = solver.sdeint(x0[: batch_size // 2], t_span)
            backward_traj = torch.flip(
                solver.sdeint(x1[batch_size // 2 :], t_span, reverse=True), (0,)
            )
        stored_traj = torch.cat([forward_traj, backward_traj], dim=1)
        # (time, batch, ...) -> (batch, time, ...) to match stored_data layout.
        stored_traj = stored_traj.transpose(0, 1)
        if batch_idx == 0:
            self.tmp_stored_data = []
        self.tmp_stored_data.append(stored_traj)
    return super().training_step(batch, batch_idx)
|
|
|
def training_epoch_end(self, training_step_outputs):
    """Promote the per-step simulated pairs to ``self.stored_data`` on outer-loop epochs."""
    outer = self.hparams.outer_loop_epochs
    if outer is None:
        return
    if (self.current_epoch + 1) % outer != 0:
        return
    cached = torch.cat(self.tmp_stored_data, dim=0)
    self.stored_data = cached.detach().clone()
|
|
|
def image_eval_step(self, batch: Any, batch_idx: int, prefix: str):
    """Sample images from the model and write them (plus a tensor dump) to disk.

    Generates 5x the batch size of samples from Gaussian noise, un-normalizes
    them back to [0, 1] pixel range, and saves individual PNGs under
    ``images/`` and a single tensor under ``compressed_images/``.
    Returns the input batch under the "x" key for epoch-end hooks.
    """
    # Fixed: removed a redundant local `import os`; os is already imported at
    # the top of the file. save_image stays local to keep torchvision.utils
    # off the import path for non-image runs.
    from torchvision.utils import save_image

    solver = self.partial_solver(self.net, self.dim)
    if isinstance(self.hparams.test_nfe, int):
        # Fixed-step integration with exactly test_nfe steps.
        t_span = torch.linspace(0, 1, int(self.hparams.test_nfe) + 1)
    elif isinstance(self.hparams.test_nfe, str):
        # Adaptive integration: the solver picks its own internal steps.
        solver.ode_solver = "tsit5"
        t_span = torch.linspace(0, 1, 2).type_as(batch[0])
    else:
        raise NotImplementedError(f"Unknown test procedure {self.hparams.test_nfe}")
    if self.hparams.test_sde:
        solver = self.partial_solver(
            self.net, self.dim, score_field=self.score_net, sigma=self.sigma
        )
        # NOTE(review): assumes test_nfe is numeric here; combining test_sde
        # with an adaptive ("str") test_nfe would fail this int() conversion.
        solver.dt = 1 / int(self.hparams.test_nfe)
        t_span = torch.linspace(0, 1, 2).type_as(batch[0])
        integrator = solver.sdeint
    else:
        integrator = solver.odeint
    x0 = torch.randn(5 * batch[0].shape[0], *self.dim).type_as(batch[0])
    traj = integrator(x0, t_span)[-1]
    os.makedirs("images", exist_ok=True)
    # Invert the dataset normalization back to the [0, 1] pixel range.
    mean = [-x / 255.0 for x in [125.3, 123.0, 113.9]]
    std = [255.0 / x for x in [63.0, 62.1, 66.7]]
    inv_normalize = transforms.Compose(
        [
            transforms.Normalize(mean=[0.0, 0.0, 0.0], std=std),
            transforms.Normalize(mean=mean, std=[1.0, 1.0, 1.0]),
        ]
    )
    traj = inv_normalize(traj)
    traj = torch.clip(traj, min=0, max=1.0)
    for i, image in enumerate(traj):
        save_image(image, fp=f"images/{batch_idx}_{i}.png")
    os.makedirs("compressed_images", exist_ok=True)
    torch.save(traj.cpu(), f"compressed_images/{batch_idx}.pt")
    return {"x": batch[0]}
|
|
|
|
|
class OneWaySF2MLitModule(SF2MLitModule):
    """SF2M variant that regresses a single forward SDE drift.

    The parent's score target is folded into one forward drift target, so
    only the forward field is fit; the score slot of the target tuple is
    returned as ``None``.
    """

    def calc_loc_and_target(self, x0, x1, t, t_select, training):
        """Compute sample locations and the combined forward-drift target.

        Same return layout as the parent, except ``ut`` is replaced by the
        forward SDE drift target and the final (score) entry is ``None``.
        """
        x, ut, t, mu_t, sigma_t, score_target = super().calc_loc_and_target(
            x0, x1, t, t_select, training
        )
        # Reshape t so it broadcasts over all non-batch dimensions of x0.
        t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1)))
        # Invert score_target = -eps * sigma(t)^2 / 2 to recover the noise eps.
        eps_t = -score_target * 2 / (self.sigma(t_xshape) ** 2)
        # Forward drift target: interpolant velocity minus a noise correction
        # scaled by sigma(t) * sqrt(t / (1 - t)); 1e-6 guards division at t=1.
        forward_target = (
            x1 - x0 - (self.sigma(t_xshape) * torch.sqrt(t_xshape / (1 - t_xshape + 1e-6))) * eps_t
        )
        return x, forward_target, t, mu_t, sigma_t, None

    def step(self, batch: Any, training: bool = False):
        """Computes the (regularization, forward flow matching) losses on a batch."""
        X = self.unpack_batch(batch)
        x0, x1, t_select = self.preprocess_batch(X, training)

        # With batch averaging, a single random time is shared by the batch.
        if self.hparams.avg_size > 0:
            t = torch.rand(1).repeat(X.shape[0]).type_as(X)
        else:
            t = torch.rand(X.shape[0]).type_as(X)

        # Re-match pairs with OT unless they come from the stored outer-loop
        # cache (which is presumably already matched).
        if self.ot_sampler is not None and self.stored_data is None:
            x0, x1 = self.ot_sampler.sample_plan(x0, x1)

        x, forward_target, t, _, _, _ = self.calc_loc_and_target(x0, x1, t, t_select, training)
        t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1)))
        # Weight that shrinks towards 0 as t -> 1, where the drift target's
        # variance blows up (the 1e-6 avoids division by zero at t=1).
        forward_scaling = (1 + self.sigma(t_xshape) ** 2 * t_xshape / (1 - t_xshape + 1e-6)) ** -1
        reg, vt, st = self.forward_flow_and_score(t, x)
        forward_flow_loss = torch.mean(forward_scaling * (vt - forward_target) ** 2)
        return torch.mean(reg), forward_flow_loss

    def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix):
        """Simulate the learned SDE over all ts - 1 unit intervals and log metrics.

        Returns:
            trajs: list of endpoint samples, one per interval.
            full_trajs: all integration steps concatenated along time.
        """
        t_span = torch.linspace(0, 1, 101).type_as(x0)
        regs = []
        trajs = []
        full_trajs = []
        solver = self.partial_solver(
            self.net, self.dim, score_field=self.score_net, sigma=self.sigma
        )
        nfe = 0
        x0_tmp = x0.clone()
        for i in range(ts - 1):
            if not self.is_image:
                # Track augmentation (regularization) values alongside the state.
                solver.augmentations = self.val_augmentations
                traj, aug = solver.sdeint(x0_tmp, t_span + i)
                aug = aug[-1]
                regs.append(torch.mean(aug, dim=0).detach().cpu().numpy())
            else:
                traj = solver.sdeint(x0_tmp, t_span + i)
            full_trajs.append(traj)
            traj = traj[-1]
            # Chain intervals: the endpoint of one is the start of the next.
            x0_tmp = traj
            trajs.append(traj)
            nfe += solver.nfe

        if not self.is_image:
            regs = np.stack(regs).mean(axis=0)
            names = [f"{prefix}/{name}" for name in self.val_augmentations.names]
            self.log_dict(dict(zip(names, regs)), sync_dist=True)

        names, dists = compute_distribution_distances(trajs, x_rest)
        names = [f"{prefix}/{name}" for name in names]
        d = dict(zip(names, dists))
        if self.hparams.leaveout_timepoint >= 0:
            # Duplicate the held-out timepoint's metrics under a "t_out" key.
            to_add = {
                f"{prefix}/t_out/{key.split('/')[-1]}": val
                for key, val in d.items()
                if key.startswith(f"{prefix}/t{self.hparams.leaveout_timepoint}")
            }
            d.update(to_add)
        d[f"{prefix}/nfe"] = nfe
        self.log_dict(d, sync_dist=True)

        # Datamodules with a closed-form Gaussian marginal also get a KL check.
        if hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM"):
            solver.augmentations = None
            t_span = torch.linspace(0, 1, 21)
            traj = solver.odeint(x0, t_span)

            assert traj.shape[0] == t_span.shape[0]
            kls = [
                self.datamodule.KL(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj)
            ]

            self.log_dict({f"{prefix}/kl/mean": torch.stack(kls).mean().item()}, sync_dist=True)
            self.log_dict({f"{prefix}/kl/tp_{i}": kls[i] for i in range(21)}, sync_dist=True)

        full_trajs = torch.cat(full_trajs)
        return trajs, full_trajs
|
|
|
|
|
class DSBMLitModule(SF2MLitModule):
    """Based on SF2M module except directly regresses against the target SDE drift rather than
    separating the ODE and Score components."""

    def calc_loc_and_target(self, x0, x1, t, t_select, training):
        """Compute sample locations plus forward and backward SDE drift targets.

        Same return layout as the parent, except ``ut`` is replaced by the
        forward drift target and the final slot carries the backward drift
        target instead of the score target.
        """
        # Reshape t for broadcasting; cloned so the parent's handling of t
        # (it returns a t_select-shifted time) cannot alias this tensor.
        t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1))).clone()
        x, ut, t_plus_t_select, mu_t, sigma_t, eps_t = super().calc_loc_and_target(
            x0, x1, t, t_select, training
        )
        # Drift targets are the interpolant velocity corrected by a noise term;
        # the 1e-6 guards the divisions at t=1 (forward) and t=0 (backward).
        forward_target = (
            x1 - x0 - (self.sigma(t_xshape) * torch.sqrt(t_xshape / (1 - t_xshape + 1e-6))) * eps_t
        )
        backward_target = (
            x0
            - x1
            - (self.sigma(t_xshape) * torch.sqrt((1 - t_xshape) / (t_xshape + 1e-6))) * eps_t
        )
        return x, forward_target, t_plus_t_select, mu_t, sigma_t, backward_target

    def step(self, batch: Any, training: bool = False):
        """Computes the loss on a batch of data."""
        X = self.unpack_batch(batch)
        x0, x1, t_select = self.preprocess_batch(X, training)

        # With batch averaging, a single random time is shared by the batch.
        if self.hparams.avg_size > 0:
            t = torch.rand(1).repeat(X.shape[0]).type_as(X)
        else:
            t = torch.rand(X.shape[0]).type_as(X)

        # Re-match pairs with OT unless they come from the stored cache.
        if self.ot_sampler is not None and self.stored_data is None:
            x0, x1 = self.ot_sampler.sample_plan(x0, x1)

        # Loss weights, computed from the raw t BEFORE calc_loc_and_target
        # replaces it with the t_select-shifted time.
        forward_scaling = (1 + self.sigma(t) ** 2 * t / (1 - t + 1e-6)) ** -1
        backward_scaling = (1 + self.sigma(t) ** 2 * (1 - t) / (t + 1e-6)) ** -1
        x, forward_target, t, _, _, backward_target = self.calc_loc_and_target(
            x0, x1, t, t_select, training
        )

        reg, vt, st = self.forward_flow_and_score(t, x)
        # NOTE(review): the [:, None] broadcast assumes 2D (batch, feature)
        # data; image-shaped inputs would need a fuller reshape -- confirm.
        forward_flow_loss = torch.mean(forward_scaling[:, None] * (vt - forward_target) ** 2)
        backward_flow_loss = torch.mean(backward_scaling[:, None] * (st - backward_target) ** 2)
        if not torch.isfinite(forward_flow_loss) or not torch.isfinite(backward_flow_loss):
            raise ValueError("Loss Not Finite")

        return torch.mean(reg) + backward_flow_loss, forward_flow_loss

    def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix):
        """Integrate the learned ODE over all ts - 1 unit intervals and log metrics.

        Returns:
            trajs: list of endpoint samples, one per interval.
            full_trajs: all integration steps concatenated along time.
        """
        # NOTE(review): unlike other variants, t_span is not moved onto x0's
        # device/dtype here -- confirm this is intentional.
        t_span = torch.linspace(0, 1, 101)
        regs = []
        trajs = []
        full_trajs = []
        solver = self.partial_solver(
            self.net, self.dim, score_field=self.score_net, sigma=self.sigma
        )
        nfe = 0
        x0_tmp = x0.clone()
        for i in range(ts - 1):
            if not self.is_image:
                # Track augmentation (regularization) values alongside the state.
                solver.augmentations = self.val_augmentations
                traj, aug = solver.odeint(x0_tmp, t_span + i)
            else:
                traj = solver.odeint(x0_tmp, t_span + i)
            full_trajs.append(traj)
            if not self.is_image:
                traj, aug = traj[-1], aug[-1]
            else:
                traj = traj[-1]
                # Placeholder; regs are only logged on the non-image path.
                aug = torch.tensor(0.0)
            # Chain intervals: the endpoint of one is the start of the next.
            x0_tmp = traj
            regs.append(torch.mean(aug, dim=0).detach().cpu().numpy())
            trajs.append(traj)
            nfe += solver.nfe

        if not self.is_image:
            regs = np.stack(regs).mean(axis=0)
            names = [f"{prefix}/{name}" for name in self.val_augmentations.names]
            self.log_dict(dict(zip(names, regs)), sync_dist=True)

        names, dists = compute_distribution_distances(trajs, x_rest)
        names = [f"{prefix}/{name}" for name in names]
        d = dict(zip(names, dists))
        if self.hparams.leaveout_timepoint >= 0:
            # Duplicate the held-out timepoint's metrics under a "t_out" key.
            to_add = {
                f"{prefix}/t_out/{key.split('/')[-1]}": val
                for key, val in d.items()
                if key.startswith(f"{prefix}/t{self.hparams.leaveout_timepoint}")
            }
            d.update(to_add)
        d[f"{prefix}/nfe"] = nfe
        self.log_dict(d, sync_dist=True)

        # Datamodules with a closed-form Gaussian marginal also get a KL check.
        if hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM"):
            solver.augmentations = None
            t_span = torch.linspace(0, 1, 21)
            traj = solver.odeint(x0, t_span)

            assert traj.shape[0] == t_span.shape[0]
            kls = [
                self.datamodule.KL(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj)
            ]

            self.log_dict({f"{prefix}/kl/mean": torch.stack(kls).mean().item()}, sync_dist=True)
            self.log_dict({f"{prefix}/kl/tp_{i}": kls[i] for i in range(21)}, sync_dist=True)

        full_trajs = torch.cat(full_trajs)
        return trajs, full_trajs
|
|
|
|
|
class DSBMSharedLitModule(SF2MLitModule):
    """Based on SF2M module except directly regresses against the target SDE drift rather than
    separating the ODE and Score components."""

    def step(self, batch: Any, training: bool = False):
        """Computes the loss on a batch of data."""
        data = self.unpack_batch(batch)
        x0, x1, t_select = self.preprocess_batch(data, training)

        # One shared random time for the whole batch when averaging targets,
        # otherwise one random time per sample.
        n = data.shape[0]
        if self.hparams.avg_size > 0:
            t = torch.rand(1).repeat(n).type_as(data)
        else:
            t = torch.rand(n).type_as(data)

        if self.ot_sampler is not None:
            x0, x1 = self.ot_sampler.sample_plan(x0, x1)

        x, ut, t, mu_t, sigma_t, score_target = self.calc_loc_and_target(
            x0, x1, t, t_select, training
        )
        if self.hparams.avg_size > 0:
            x, ut, t = self.average_ut(x, t, mu_t, sigma_t, ut)

        reg, vt = self.augmentations(self.aug_net(t, x, augmented_input=False))
        # Forward drift = flow + sigma * score; backward drift = -flow + sigma * score.
        # score_net is evaluated separately per loss; with stochastic layers the
        # two evaluations may differ.
        forward_flow_loss = self.criterion(vt + sigma_t * self.score_net(t, x), ut + score_target)
        backward_flow_loss = self.criterion(
            -vt + sigma_t * self.score_net(t, x), -ut + score_target
        )

        return torch.mean(reg) + backward_flow_loss, forward_flow_loss
|
|
|
|
|
class FMLitModule(CFMLitModule):
    """Flow matching in the style of Lipman et al. (2023).

    Maps the standard normal distribution to the data distribution using
    conditional flows that are the optimal transport flow from N(x | 0, 1)
    to a narrow Gaussian centered on each datapoint.
    """

    def calc_mu_sigma(self, x0, x1, t):
        """Mean and std of the conditional probability path at time ``t``."""
        assert not self.is_trajectory
        del x0  # the path is anchored only on the data point x1
        shrink = 1 - self.hparams.sigma_min
        return t * x1, 1 - shrink * t

    def calc_u(self, x0, x1, x, t, mu_t, sigma_t):
        """Target conditional vector field u_t(x | x1) for this Gaussian path."""
        del x0, mu_t, sigma_t
        shrink = 1 - self.hparams.sigma_min
        return (x1 - shrink * x) / (1 - shrink * t)
|
|
|
|
|
class SplineCFMLitModule(CFMLitModule):
    """Implements cubic spline version of OT-CFM."""

    def preprocess_batch(self, X, training=False):
        """Convert a trajectory batch into a random matched pair (x0, x1).

        Samples an OT-matched trajectory through all timepoints, picks one
        random interval per sample, and fits a natural cubic spline through
        the (possibly leaveout-reduced) timepoints.

        Returns:
            x0, x1: interval endpoints for each trajectory in the batch.
            t_select: index of each interval's left timepoint (original axis).
            spline: NaturalCubicSpline fit through the matched trajectories.
        """
        # Fixed: removed a dead `t_select = torch.zeros(1)` (overwritten in
        # both branches below) and moved the stray description string into a
        # real docstring above the import.
        from torchcubicspline import NaturalCubicSpline, natural_cubic_spline_coeffs

        lotp = self.hparams.leaveout_timepoint
        valid_times = torch.arange(X.shape[1]).type_as(X)
        batch_size, times, dim = X.shape

        if training and lotp > 0:
            # Drop the held-out timepoint; only times - 2 intervals remain.
            t_select = torch.randint(times - 2, size=(batch_size,))
            X = torch.cat([X[:, :lotp], X[:, lotp + 1 :]], dim=1)
            valid_times = valid_times[valid_times != lotp]
        else:
            t_select = torch.randint(times - 1, size=(batch_size,))
        traj = torch.from_numpy(self.ot_sampler.sample_trajectory(X)).type_as(X)
        x0 = []
        x1 = []
        for i in range(batch_size):
            x0.append(traj[i, t_select[i]])
            x1.append(traj[i, t_select[i] + 1])
        x0, x1 = torch.stack(x0), torch.stack(x1)
        if training and lotp > 0:
            # Map interval indices past the held-out point back onto the
            # original (unreduced) time axis.
            t_select[t_select >= lotp] += 1

        coeffs = natural_cubic_spline_coeffs(valid_times, traj)
        spline = NaturalCubicSpline(coeffs)
        return x0, x1, t_select, spline

    def step(self, batch: Any, training: bool = False):
        """Computes the loss on a batch of data."""
        assert self.is_trajectory
        X = self.unpack_batch(batch)
        x0, x1, t_select, spline = self.preprocess_batch(X, training)

        # NOTE(review): t uses the default device/dtype; presumably trajectory
        # data lives on CPU float32 here -- confirm before GPU use.
        t = torch.rand(X.shape[0], 1)

        # Shift each per-sample time into its selected interval [k, k+1].
        t = t + t_select[:, None]
        # Spline derivative is the target velocity; its value is the path mean.
        ut = torch.stack([spline.derivative(b[0])[i] for i, b in enumerate(t)], dim=0)
        mu_t = torch.stack([spline.evaluate(b[0])[i] for i, b in enumerate(t)], dim=0)
        sigma_t = self.hparams.sigma_min

        if training and self.hparams.leaveout_timepoint > 0:
            # Intervals ending at the held-out timepoint are rescaled
            # (presumably because they span two original time units) --
            # velocity halved, time doubled.
            ut[t_select + 1 == self.hparams.leaveout_timepoint] /= 2
            t[t_select + 1 == self.hparams.leaveout_timepoint] *= 2

        x = mu_t + sigma_t * torch.randn_like(x0)
        aug_x = self.aug_net(t, x, augmented_input=False)
        reg, vt = self.augmentations(aug_x)
        return torch.mean(reg), self.criterion(vt, ut)
|
|
|
|
|
class CNFLitModule(CFMLitModule):
    """Continuous normalizing flow trained by maximum likelihood.

    Integrates each observation backwards in time to a standard normal prior
    and maximizes the change-of-variables log-likelihood, instead of using a
    flow matching loss.
    """

    def forward_integrate(self, batch: Any, t_span: torch.Tensor):
        """Forward pass with integration over t_span intervals.

        (t, x, t_span) -> [x_t_span].
        """
        # Times are shifted by one; presumably t=0 is reserved for the prior.
        return super().forward_integrate(batch, t_span + 1)

    def step(self, batch: Any, training: bool = False):
        """Compute (regularization, negative log-likelihood) for a batch.

        Each observed timepoint is integrated backwards to t=0 and scored
        under the standard normal prior, corrected by the accumulated
        log-determinant of the flow.
        """
        obs = self.unpack_batch(batch)
        if not self.is_trajectory:
            # Treat a static batch as a single-timepoint trajectory.
            obs = obs[:, None, :]
        even_ts = torch.arange(obs.shape[1]).to(obs) + 1
        # Standard normal prior at t=0; assumes self.dim is an int -- TODO
        # confirm this module is never used with image-shaped dims.
        self.prior = MultivariateNormal(
            torch.zeros(self.dim).type_as(obs), torch.eye(self.dim).type_as(obs)
        )

        # Reversed integration times ending at 0 (the prior's time).
        reversed_ts = torch.cat([torch.flip(even_ts, [0]), torch.tensor([0]).type_as(even_ts)])

        if self.is_trajectory:
            reversed_ts -= 1
        losses = []
        regs = []
        for t in range(len(reversed_ts) - 1):
            # Skip the held-out timepoint during leaveout training.
            if self.hparams.leaveout_timepoint == t:
                continue
            # Start from the observation at this timepoint and integrate over
            # all remaining (reversed) times down to the prior.
            ts, x = reversed_ts[t:], obs[:, len(even_ts) - t - 1, :]

            _, x = self.aug_node(x, ts)
            x = x[-1]

            # Split the augmented state into the log-det correction, other
            # regularization terms, and the transported sample.
            delta_logprob, reg, x = self.augmentations(x)
            logprob = self.prior.log_prob(x).to(x) - delta_logprob
            losses.append(-torch.mean(logprob))

            regs.append(-reg)

        reg = torch.mean(torch.stack(regs))
        loss = torch.mean(torch.stack(losses))
        return reg, loss
|
|
|