import numpy as np
import torch as th

from .gaussian_diffusion import GaussianDiffusion, GaussianDiffusionDDPM


def space_timesteps(num_timesteps, sample_timesteps):
    """
    Pick evenly spaced timesteps from an original diffusion process.

    Retained step ``x`` (for ``x`` in ``range(sample_timesteps)``) maps to
    original step ``floor(num_timesteps * x / sample_timesteps)``, so step 0
    is always retained when ``sample_timesteps >= 1``.

    :param num_timesteps: the number of diffusion steps in the original process.
    :param sample_timesteps: how many timesteps to keep for sampling.
    :return: a set of diffusion steps from the original process to use. If
        ``sample_timesteps`` exceeds ``num_timesteps``, several indices collide
        and the set holds fewer than ``sample_timesteps`` entries.
    """
    # Integer floor division is exact. The previous float form
    # int((num_timesteps / sample_timesteps) * x) can evaluate to a value just
    # below an exact integer (e.g. (1 / 49) * 49 == 0.9999999999999999) and
    # then truncate to the wrong step.
    return {
        int(num_timesteps * x // sample_timesteps)
        for x in range(sample_timesteps)
    }


class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.

    Only the ``sqrt_etas`` entries whose index is in ``use_timesteps`` are
    kept; ``self.timestep_map[k]`` records the original index of the k-th
    retained step so model calls can be given original timesteps.

    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["sqrt_etas"])

        # Built only to read its sqrt_etas schedule.
        base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        new_sqrt_etas = []
        for ii, etas_current in enumerate(base_diffusion.sqrt_etas):
            if ii in self.use_timesteps:
                new_sqrt_etas.append(etas_current)
                self.timestep_map.append(ii)
        # kwargs is this call's own dict, so the caller's mapping is untouched.
        kwargs["sqrt_etas"] = np.array(new_sqrt_etas)
        super().__init__(**kwargs)

    def p_mean_variance(self, model, *args, **kwargs):  # pylint: disable=signature-differs
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    def training_losses(self, model, *args, **kwargs):  # pylint: disable=signature-differs
        return super().training_losses(self._wrap_model(model), *args, **kwargs)

    def _wrap_model(self, model):
        """Wrap ``model`` so it receives original (un-spaced) timesteps; idempotent."""
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(model, self.timestep_map, self.original_num_steps)


class _WrappedModel:
    """
    Callable adapter that translates respaced timestep indices into the
    original process's timesteps before invoking the wrapped model.

    ``original_num_steps`` is stored for reference; this class does not use it.
    """

    def __init__(self, model, timestep_map, original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        self.original_num_steps = original_num_steps

    def __call__(self, x, ts, **kwargs):
        # NOTE(review): ts is presumably an integer tensor of respaced indices
        # in [0, len(timestep_map)); it is used directly as an index below.
        map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
        new_ts = map_tensor[ts]
        return self.model(x, new_ts, **kwargs)


class SpacedDiffusionDDPM(GaussianDiffusionDDPM):
    """
    A diffusion process which can skip steps in a base diffusion process.

    New betas are derived from the base process's ``alphas_cumprod`` as
    ``1 - alphas_cumprod[i] / alphas_cumprod[previous retained step]``, so the
    product of ``(1 - beta)`` over the retained steps up to ``i`` equals the
    base process's ``alphas_cumprod[i]``.

    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        # Built only to read its alphas_cumprod schedule.
        base_diffusion = GaussianDiffusionDDPM(**kwargs)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        kwargs["betas"] = np.array(new_betas)
        super().__init__(**kwargs)

    def p_mean_variance(self, model, *args, **kwargs):  # pylint: disable=signature-differs
        return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)

    def training_losses(self, model, *args, **kwargs):  # pylint: disable=signature-differs
        return super().training_losses(self._wrap_model(model), *args, **kwargs)

    def _wrap_model(self, model):
        """Wrap ``model`` so it receives original (un-spaced) timesteps; idempotent."""
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(model, self.timestep_map, self.original_num_steps)