"""Transport (interpolant path) training losses and ODE/SDE samplers."""

import torch as th
import numpy as np
import logging

import enum

from . import path
from .utils import EasyDict, log_state, mean_flat
from .integrators import ode, sde


class ModelType(enum.Enum):
    """
    Which type of output the model predicts.
    """

    NOISE = enum.auto()  # the model predicts epsilon
    SCORE = enum.auto()  # the model predicts \nabla \log p(x)
    VELOCITY = enum.auto()  # the model predicts v(x)


class PathType(enum.Enum):
    """
    Which type of path to use.
    """

    LINEAR = enum.auto()
    GVP = enum.auto()
    VP = enum.auto()


class WeightType(enum.Enum):
    """
    Which type of weighting to use.
    """

    NONE = enum.auto()
    VELOCITY = enum.auto()
    LIKELIHOOD = enum.auto()


class SNRType(enum.Enum):
    """
    Which distribution the training time t is drawn from (see Transport.sample).
    """

    UNIFORM = enum.auto()
    LOGNORM = enum.auto()


class Transport:
    """Couples a path plan with a model-output type for training and sampling."""

    def __init__(
        self, *, model_type, path_type, loss_type, train_eps, sample_eps, snr_type
    ):
        """Store the configuration and instantiate the path sampler.

        Args:
        - model_type: a ModelType member; what the backbone predicts
        - path_type: a PathType member; selects the plan class from `path`
        - loss_type: stored as-is (the weighting code that used it is
          currently commented out in `training_losses`)
        - train_eps: epsilon used by `check_interval` when eval=False
        - sample_eps: epsilon used by `check_interval` when eval=True
        - snr_type: an SNRType member; how `sample` draws t
        """
        path_options = {
            PathType.LINEAR: path.ICPlan,
            PathType.GVP: path.GVPCPlan,
            PathType.VP: path.VPCPlan,
        }

        self.loss_type = loss_type
        self.model_type = model_type
        self.path_sampler = path_options[path_type]()
        self.train_eps = train_eps
        self.sample_eps = sample_eps

        self.snr_type = snr_type

    def prior_logp(self, z):
        """
        Standard multivariate normal prior
        Assume z is batched

        Returns one value per batch element:
        -N / 2 * log(2 * pi) - sum(z_i ** 2) / 2, where N is the number of
        elements in a single sample (product of all non-batch dimensions).
        """
        shape = th.tensor(z.size())
        N = th.prod(shape[1:])

        def _logp_single(x):
            return -N / 2.0 * np.log(2 * np.pi) - th.sum(x**2) / 2.0

        # vmap maps the per-sample function over the leading (batch) dimension.
        return th.vmap(_logp_single)(z)

    def check_interval(
        self,
        train_eps,
        sample_eps,
        *,
        diffusion_form="SBDM",
        sde=False,
        reverse=False,
        eval=False,
        last_step_size=0.0,
    ):
        """Compute the time interval (t0, t1) to train or integrate over.

        Args:
        - train_eps: epsilon used when eval is False
        - sample_eps: epsilon used when eval is True
        - diffusion_form: only the value "SBDM" is inspected here
        - sde: whether the interval is for an SDE solver
        - reverse: if True, both endpoints are mapped through t -> 1 - t
        - eval: selects sample_eps instead of train_eps
        - last_step_size: when sde is True and this is non-zero, the upper
          end becomes 1 - last_step_size instead of 1 - eps

        Returns:
            (t0, t1); defaults to (0, 1) when no branch below applies.

        NOTE: the keyword names `sde` and `eval` shadow the imported `sde`
        integrator and the builtin inside this method; they are part of the
        public signature and therefore kept.
        """
        t0 = 0
        t1 = 1
        eps = sample_eps if eval else train_eps
        trim_by_eps = (not sde) or last_step_size == 0

        # Exact type comparison is kept on purpose: isinstance() would also
        # match subclasses, which could change which branch is taken.
        sampler_cls = type(self.path_sampler)
        if sampler_cls is path.VPCPlan:
            t1 = 1 - eps if trim_by_eps else 1 - last_step_size

        elif sampler_cls in (path.ICPlan, path.GVPCPlan) and (
            self.model_type != ModelType.VELOCITY or sde
        ):  # avoid numerical issue by taking a first semi-implicit step
            starts_at_eps = (
                diffusion_form == "SBDM" and sde
            ) or self.model_type != ModelType.VELOCITY
            t0 = eps if starts_at_eps else 0
            t1 = 1 - eps if trim_by_eps else 1 - last_step_size

        if reverse:
            t0, t1 = 1 - t0, 1 - t1

        return t0, t1

    def sample(self, x1):
        """Sampling x0 & t based on shape of x1 (if needed)
        Args:
          x1 - data point; [batch, *dim]
        Returns:
          (t, x0, x1): t has len(x1) entries inside the training interval,
          x0 is standard normal noise shaped like x1 (a list of tensors when
          x1 is a list or tuple), and x1 is returned unchanged.
        Raises:
          ValueError: if self.snr_type is not a known SNRType member.
        """
        if isinstance(x1, (list, tuple)):
            x0 = [th.randn_like(img_start) for img_start in x1]
        else:
            x0 = th.randn_like(x1)

        t0, t1 = self.check_interval(self.train_eps, self.sample_eps)

        batch = len(x1)
        if self.snr_type == SNRType.UNIFORM:
            t = th.rand((batch,)) * (t1 - t0) + t0
        elif self.snr_type == SNRType.LOGNORM:
            # Sigmoid of a standard normal draw, rescaled into [t0, t1].
            u = th.normal(mean=0.0, std=1.0, size=(batch,))
            t = 1 / (1 + th.exp(-u)) * (t1 - t0) + t0
        else:
            raise ValueError(f"Unknown snr type: {self.snr_type}")

        # Match the dtype/device of the first element of x1.
        t = t.to(x1[0])
        return t, x0, x1

    def training_losses(self, model, x1, model_kwargs=None):
        """Loss for training the score model
        Args:
        - model: backbone model; could be score, noise, or velocity
        - x1: datapoint
        - model_kwargs: additional arguments for the model
        Returns:
          dict with "loss" (the per-sample squared error, kept attached to
          the autograd graph) and "task_loss" (a detached copy of it).
        Raises:
          NotImplementedError: for any model type other than VELOCITY.
        """
        if model_kwargs is None:
            model_kwargs = {}

        t, x0, x1 = self.sample(x1)
        t, xt, ut = self.path_sampler.plan(t, x0, x1)
        model_output = model(xt, t, **model_kwargs)

        terms = {}
        # terms['pred'] = model_output
        if self.model_type == ModelType.VELOCITY:
            if isinstance(x1, (list, tuple)):
                # Per-sample tensors may differ in shape, so each sample is
                # reduced to its own mean before stacking.
                assert len(model_output) == len(ut) == len(x1)
                per_sample = []
                for out_i, ut_i, x1_i in zip(model_output, ut, x1):
                    assert (
                        out_i.shape == ut_i.shape == x1_i.shape
                    ), f"{out_i.shape} {ut_i.shape} {x1_i.shape}"
                    per_sample.append(((ut_i - out_i) ** 2).mean())
                terms["task_loss"] = th.stack(per_sample, dim=0)
            else:
                terms["task_loss"] = mean_flat(((model_output - ut) ** 2))
        else:
            raise NotImplementedError
            # NOTE: weighted noise/score losses are currently disabled; the
            # previous implementation is kept below for reference.
            # _, drift_var = self.path_sampler.compute_drift(xt, t)
            # sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))
            # if self.loss_type in [WeightType.VELOCITY]:
            #     weight = (drift_var / sigma_t) ** 2
            # elif self.loss_type in [WeightType.LIKELIHOOD]:
            #     weight = drift_var / (sigma_t ** 2)
            # elif self.loss_type in [WeightType.NONE]:
            #     weight = 1
            # else:
            #     raise NotImplementedError()
            #
            # if self.model_type == ModelType.NOISE:
            #     terms['task_loss'] = mean_flat(weight * ((model_output - x0) ** 2))
            # else:
            #     terms['task_loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))

        terms["loss"] = terms["task_loss"]
        terms["task_loss"] = terms["task_loss"].clone().detach()
        return terms

    def get_drift(self):
        """member function for obtaining the drift of the probability flow ODE

        Returns:
          a callable (x, t, model, **model_kwargs) -> drift that asserts the
          drift has the same shape as x.
        """

        def score_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            model_output = model(x, t, **model_kwargs)
            return -drift_mean + drift_var * model_output  # by change of variable

        def noise_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
            model_output = model(x, t, **model_kwargs)
            # Convert the predicted noise into a score before the change of variable.
            score = model_output / -sigma_t
            return -drift_mean + drift_var * score

        def velocity_ode(x, t, model, **model_kwargs):
            model_output = model(x, t, **model_kwargs)
            return model_output

        if self.model_type == ModelType.NOISE:
            drift_fn = noise_ode
        elif self.model_type == ModelType.SCORE:
            drift_fn = score_ode
        else:
            drift_fn = velocity_ode

        def body_fn(x, t, model, **model_kwargs):
            model_output = drift_fn(x, t, model, **model_kwargs)
            assert (
                model_output.shape == x.shape
            ), "Output shape from ODE solver must match input shape"
            return model_output

        return body_fn

    def get_score(
        self,
    ):
        """member function for obtaining score of
        x_t = alpha_t * x + sigma_t * eps

        Returns:
          a callable (x, t, model, **kwargs) -> score.
        Raises:
          NotImplementedError: if self.model_type is not a known ModelType.
        """
        if self.model_type == ModelType.NOISE:

            def score_fn(x, t, model, **kwargs):
                # score = -noise / sigma_t
                return (
                    model(x, t, **kwargs)
                    / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
                )

        elif self.model_type == ModelType.SCORE:

            def score_fn(x, t, model, **kwargs):
                return model(x, t, **kwargs)

        elif self.model_type == ModelType.VELOCITY:

            def score_fn(x, t, model, **kwargs):
                return self.path_sampler.get_score_from_velocity(
                    model(x, t, **kwargs), x, t
                )

        else:
            raise NotImplementedError()

        return score_fn


class Sampler:
    """Sampler class for the transport model"""

    def __init__(
        self,
        transport,
    ):
        """Constructor for a general sampler; supporting different sampling methods
        Args:
        - transport: a transport object specifying model prediction & interpolant type
        """

        self.transport = transport
        self.drift = self.transport.get_drift()
        self.score = self.transport.get_score()

    def __get_sde_diffusion_and_drift(
        self,
        *,
        diffusion_form="SBDM",
        diffusion_norm=1.0,
    ):
        """Build the SDE drift and diffusion callables.

        Returns:
          (sde_drift, sde_diffusion) where
          sde_drift(x, t, model, **kwargs) = drift + diffusion(x, t) * score
          and sde_diffusion(x, t) is the path sampler's diffusion coefficient
          for the given form and norm.
        """

        def diffusion_fn(x, t):
            diffusion = self.transport.path_sampler.compute_diffusion(
                x, t, form=diffusion_form, norm=diffusion_norm
            )
            return diffusion

        def sde_drift(x, t, model, **kwargs):
            return self.drift(x, t, model, **kwargs) + diffusion_fn(
                x, t
            ) * self.score(x, t, model, **kwargs)

        sde_diffusion = diffusion_fn

        return sde_drift, sde_diffusion

    def __get_last_step(
        self,
        sde_drift,
        *,
        last_step,
        last_step_size,
    ):
        """Get the last step function of the SDE solver

        Args:
        - sde_drift: drift callable of the SDE (used by "Mean")
        - last_step: None (identity), "Mean", "Tweedie" or "Euler"
        - last_step_size: step size used by "Mean" and "Euler"
        Raises:
          NotImplementedError: for any other value of last_step.
        """

        if last_step is None:

            def last_step_fn(x, t, model, **model_kwargs):
                return x

        elif last_step == "Mean":

            def last_step_fn(x, t, model, **model_kwargs):
                return x + sde_drift(x, t, model, **model_kwargs) * last_step_size

        elif last_step == "Tweedie":
            alpha = (
                self.transport.path_sampler.compute_alpha_t
            )  # simple aliasing; the original name was too long
            sigma = self.transport.path_sampler.compute_sigma_t

            def last_step_fn(x, t, model, **model_kwargs):
                # x / alpha + sigma^2 / alpha * score, using the first entry
                # of the first value returned by compute_alpha_t/compute_sigma_t.
                alpha_t = alpha(t)[0][0]
                sigma_t = sigma(t)[0][0]
                return x / alpha_t + (sigma_t**2) / alpha_t * self.score(
                    x, t, model, **model_kwargs
                )

        elif last_step == "Euler":

            def last_step_fn(x, t, model, **model_kwargs):
                return x + self.drift(x, t, model, **model_kwargs) * last_step_size

        else:
            raise NotImplementedError()

        return last_step_fn

    def sample_sde(
        self,
        *,
        sampling_method="Euler",
        diffusion_form="SBDM",
        diffusion_norm=1.0,
        last_step="Mean",
        last_step_size=0.04,
        num_steps=250,
    ):
        """returns a sampling function with given SDE settings
        Args:
        - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama
        - diffusion_form: function form of diffusion coefficient; default to be matching SBDM
        - diffusion_norm: function magnitude of diffusion coefficient; default to 1
        - last_step: type of the last step; one of None (identity), "Mean",
          "Tweedie" or "Euler"; default to "Mean"
        - last_step_size: size of the last step; default to 0.04; forced to
          0.0 when last_step is None
        - num_steps: total integration step of SDE
        Returns:
          a callable (init, model, **model_kwargs) -> list of states, whose
          final entry is the output of the last-step function.
        """

        if last_step is None:
            last_step_size = 0.0

        sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
            diffusion_form=diffusion_form,
            diffusion_norm=diffusion_norm,
        )

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            diffusion_form=diffusion_form,
            sde=True,
            eval=True,
            reverse=False,
            last_step_size=last_step_size,
        )

        _sde = sde(
            sde_drift,
            sde_diffusion,
            t0=t0,
            t1=t1,
            num_steps=num_steps,
            sampler_type=sampling_method,
        )

        last_step_fn = self.__get_last_step(
            sde_drift, last_step=last_step, last_step_size=last_step_size
        )

        def _sample(init, model, **model_kwargs):
            xs = _sde.sample(init, model, **model_kwargs)
            # The last step is evaluated at t1 for every batch element.
            ts = th.ones(init.size(0), device=init.device) * t1
            x = last_step_fn(xs[-1], ts, model, **model_kwargs)
            xs.append(x)

            assert len(xs) == num_steps, "Samples does not match the number of steps"

            return xs

        return _sample

    def sample_ode(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
        reverse=False,
        time_shifting_factor=None,
    ):
        """returns a sampling function with given ODE settings
        Args:
        - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
        - num_steps:
            - fixed solver (Euler, Heun): the actual number of integration steps performed
            - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
        - atol: absolute error tolerance for the solver
        - rtol: relative error tolerance for the solver
        - reverse: whether solving the ODE in reverse (data to noise); default to False
        - time_shifting_factor: forwarded unchanged to the `ode` integrator
          (NOTE(review): semantics are defined there - confirm)
        """
        if reverse:

            def drift(x, t, model, **kwargs):
                # Evaluate the drift at the mirrored time 1 - t.
                return self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)

        else:
            drift = self.drift

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=reverse,
            last_step_size=0.0,
        )

        _ode = ode(
            drift=drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
            time_shifting_factor=time_shifting_factor,
        )

        return _ode.sample

    def sample_ode_likelihood(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
    ):
        """returns a sampling function for calculating likelihood with given ODE settings
        Args:
        - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
        - num_steps:
            - fixed solver (Euler, Heun): the actual number of integration steps performed
            - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
        - atol: absolute error tolerance for the solver
        - rtol: relative error tolerance for the solver
        Returns:
          a callable (x, model, **model_kwargs) -> (logp, final_state), where
          logp = prior_logp(final_state) - accumulated delta_logp.
        """

        def _likelihood_drift(x, t, model, **model_kwargs):
            x, _ = x
            # Random +/-1 probe vector for the Hutchinson divergence estimate.
            eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
            # Integrate along the mirrored time 1 - t.
            t = th.ones_like(t) * (1 - t)
            with th.enable_grad():
                x.requires_grad = True
                grad = th.autograd.grad(
                    th.sum(self.drift(x, t, model, **model_kwargs) * eps), x
                )[0]
                # eps^T J eps, summed over all non-batch dimensions.
                logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
                drift = self.drift(x, t, model, **model_kwargs)
            return (-drift, logp_grad)

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=False,
            last_step_size=0.0,
        )

        _ode = ode(
            drift=_likelihood_drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
        )

        def _sample_fn(x, model, **model_kwargs):
            init_logp = th.zeros(x.size(0)).to(x)
            init_state = (x, init_logp)
            states, delta_logps = _ode.sample(init_state, model, **model_kwargs)
            # Only the final entry of each returned sequence is used.
            final_state, delta_logp = states[-1], delta_logps[-1]
            prior_logp = self.transport.prior_logp(final_state)
            logp = prior_logp - delta_logp
            return logp, final_state

        return _sample_fn