|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""All functions and modules related to model definition. |
|
""" |
|
|
|
import torch |
|
|
|
import numpy as np |
|
from ...sdes import OUVESDE, OUVPSDE |
|
|
|
|
|
# Registry mapping a name to a model class. It is written by `register_model`
# and read by `get_model`.
_MODELS = {}
|
|
|
|
|
def register_model(cls=None, *, name=None):
    """Register a model class in the module-level registry.

    Works both as a bare decorator (``@register_model``) and with a keyword
    argument (``@register_model(name='my_model')``).

    Args:
        cls: The decorated class, or ``None`` when the decorator is called
            with keyword arguments.
        name: Optional registry key; ``cls.__name__`` is used when omitted.

    Returns:
        The class itself (bare form), or a decorator that registers and
        returns its argument (keyword form).

    Raises:
        ValueError: If the resolved name is already registered.
    """

    def _register(model_cls):
        key = model_cls.__name__ if name is None else name
        if key in _MODELS:
            raise ValueError(f'Already registered model with name: {key}')
        _MODELS[key] = model_cls
        return model_cls

    return _register if cls is None else _register(cls)
|
|
|
|
|
def get_model(name):
    """Look up a registered model class by name.

    Args:
        name: Key under which the class was registered via `register_model`.

    Returns:
        The registered model class (not an instance).

    Raises:
        KeyError: If no model is registered under ``name``. The message lists
            the names that are currently registered.
    """
    try:
        return _MODELS[name]
    except KeyError:
        # Same exception type as a plain dict lookup, but with an actionable message.
        raise KeyError(
            f'No model registered with name {name!r}; '
            f'registered models: {list(_MODELS)}') from None
|
|
|
|
|
def get_sigmas(sigma_min, sigma_max, num_scales):
    """Get sigmas --- the set of noise levels for SMLD.

    The levels form a geometric sequence running from ``sigma_max`` to
    ``sigma_min``, with both endpoints included. In other words, they are
    evenly spaced in log-space.

    Args:
        sigma_min: Noise level of the last entry; should be positive because
            its logarithm is taken.
        sigma_max: Noise level of the first entry; should be positive because
            its logarithm is taken.
        num_scales: Number of noise levels to generate.

    Returns:
        sigmas: a NumPy array of shape ``(num_scales,)`` holding the noise
        levels, starting at ``sigma_max`` and ending at ``sigma_min``.
    """
    # Interpolate linearly between the log-endpoints, then exponentiate, so
    # consecutive sigmas share a constant ratio.
    return np.exp(np.linspace(np.log(sigma_max), np.log(sigma_min), num_scales))
|
|
|
|
|
def get_ddpm_params(config, *, num_diffusion_timesteps=1000):
    """Get betas and alphas --- parameters used in the original DDPM paper.

    Builds a linear beta schedule and the cumulative alpha products derived
    from it.

    Args:
        config: Config object. Reads ``config.model.beta_min``,
            ``config.model.beta_max`` and ``config.model.num_scales``.
        num_diffusion_timesteps: Length of the discrete schedule. Defaults
            to 1000.

    Returns:
        A dict containing:

        - NumPy float64 arrays of length ``num_diffusion_timesteps`` under
          ``'betas'``, ``'alphas'``, ``'alphas_cumprod'``,
          ``'sqrt_alphas_cumprod'`` and ``'sqrt_1m_alphas_cumprod'``.
        - The scalars ``'beta_min'``, ``'beta_max'`` and
          ``'num_diffusion_timesteps'``.
    """
    # NOTE: the betas are divided by ``config.model.num_scales``, but the
    # schedule length is ``num_diffusion_timesteps``. These should normally
    # agree; otherwise the schedule is rescaled inconsistently.
    beta_start = config.model.beta_min / config.model.num_scales
    beta_end = config.model.beta_max / config.model.num_scales
    betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)

    alphas = 1. - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
    sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)

    return {
        'betas': betas,
        'alphas': alphas,
        'alphas_cumprod': alphas_cumprod,
        'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
        'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
        # Endpoint betas multiplied by (num_diffusion_timesteps - 1).
        'beta_min': beta_start * (num_diffusion_timesteps - 1),
        'beta_max': beta_end * (num_diffusion_timesteps - 1),
        'num_diffusion_timesteps': num_diffusion_timesteps
    }
|
|
|
|
|
def create_model(config):
    """Instantiate the configured score model.

    The class registered under ``config.model.name`` is constructed with the
    full ``config``. The result is moved to ``config.device`` and wrapped in
    ``torch.nn.DataParallel``.
    """
    model_cls = get_model(config.model.name)
    return torch.nn.DataParallel(model_cls(config).to(config.device))
|
|
|
|
|
def get_model_fn(model, train=False):
    """Create a function to give the output of the score-based model.

    Args:
        model: The score model; any object exposing ``train()``, ``eval()``
            and ``__call__(x, labels)``, e.g. a ``torch.nn.Module``.
        train: `True` for training and `False` for evaluation.

    Returns:
        A model function ``model_fn(x, labels)``.
    """

    def model_fn(x, labels):
        """Compute the output of the score-based model.

        The model's train/eval mode is set on every call, before the forward
        pass.

        Args:
            x: A mini-batch of input data.
            labels: A mini-batch of conditioning variables for time steps.
                Should be interpreted differently for different models.

        Returns:
            Exactly what ``model(x, labels)`` returns.
        """
        if train:
            model.train()
        else:
            model.eval()
        return model(x, labels)

    return model_fn
|
|
|
|
|
def get_score_fn(sde, model, train=False, continuous=False):
    """Wrap the model so its output corresponds to a real time-dependent score function.

    Args:
        sde: An `sde_lib.SDE` object that represents the forward SDE. Only
            ``OUVPSDE`` and ``OUVESDE`` instances are supported.
        model: A score model.
        train: `True` for training and `False` for evaluation.
        continuous: If `True`, the score-based model is expected to directly
            take continuous time steps.

    Returns:
        A score function ``score_fn(x, t)``.

    Raises:
        NotImplementedError: If ``sde`` is neither an ``OUVPSDE`` nor an
            ``OUVESDE``.
    """
    model_fn = get_model_fn(model, train=train)

    if isinstance(sde, OUVPSDE):
        def score_fn(x, t):
            if continuous:
                # Continuously-trained VP models are conditioned on t scaled by
                # a fixed 999, independent of sde.N.
                labels = t * 999
                score = model_fn(x, labels)
                std = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # Discrete VP models are conditioned on the timestep index.
                # NOTE(review): `.long()` truncates, so t is presumably expected
                # to lie exactly on the grid k / (sde.N - 1) — confirm.
                labels = t * (sde.N - 1)
                score = model_fn(x, labels)
                std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]

            # Reshape std to (batch, 1, ..., 1) so it broadcasts over every
            # non-batch dimension of x, whatever its rank. For 4-D x this
            # equals std[:, None, None, None]; it also handles a 0-dim t.
            std = std.reshape(-1, *([1] * (x.dim() - 1)))
            return -score / std

    elif isinstance(sde, OUVESDE):
        def score_fn(x, t):
            if continuous:
                # Continuously-trained VE models are conditioned on the
                # marginal std at time t.
                labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
            else:
                # Discrete VE models: integer index counted backwards from sde.T.
                labels = torch.round((sde.T - t) * (sde.N - 1)).long()
            # The model output is used as the score without rescaling.
            return model_fn(x, labels)

    else:
        raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

    return score_fn
|
|
|
|
|
def to_flattened_numpy(x):
    """Return tensor `x` as a 1-D NumPy array, detached from autograd and moved to the CPU."""
    host_array = x.detach().cpu().numpy()
    return host_array.reshape(-1)
|
|
|
|
|
def from_flattened_numpy(x, shape):
    """Reshape the flat NumPy array `x` to `shape` and wrap it as a torch tensor.

    ``torch.from_numpy`` shares memory with the reshaped array.
    """
    reshaped = x.reshape(shape)
    return torch.from_numpy(reshaped)