| |
| |
| |
| |
| |
|
|
| """ |
| Functions for Noise Schedule, defines diffusion process, reverse process and data processor. |
| """ |
|
|
| from collections import namedtuple |
| import random |
| import typing as tp |
| import julius |
| import torch |
|
|
# Record produced by NoiseSchedule.get_training_item: the noised input, the
# Gaussian noise that was mixed in, and the diffusion step that was used.
TrainingItem = namedtuple("TrainingItem", ["noisy", "noise", "step"])
|
|
|
|
def betas_from_alpha_bar(alpha_bar: torch.Tensor) -> torch.Tensor:
    """Recover per-step betas from cumulative products ``alpha_bar``.

    With ``alpha_bar[i] = prod_{j <= i} (1 - beta_j)``, the per-step alpha is
    ``alpha_bar[0]`` for the first step and ``alpha_bar[i] / alpha_bar[i - 1]``
    afterwards; the returned betas are ``1 - alpha``.

    Args:
        alpha_bar (torch.Tensor): 1-D tensor of cumulative alpha products.
    Returns:
        torch.Tensor: Betas with the same length, dtype and device as ``alpha_bar``.
    """
    # Slice the first element instead of rebuilding it with torch.Tensor([...]):
    # that rebuild forced float32 on CPU, losing float64 precision and breaking
    # torch.cat for CUDA inputs.
    alphas = torch.cat([alpha_bar[:1], alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - alphas
|
|
|
|
class SampleProcessor(torch.nn.Module):
    """Identity sample processor.

    Defines the two hooks used to move samples into and out of the space in
    which diffusion happens. This base version returns its input unchanged in
    both directions; subclasses such as a multi-band processor override them.
    """

    def project_sample(self, x: torch.Tensor):
        """Map an original sample into diffusion space (identity here)."""
        projected = x
        return projected

    def return_sample(self, z: torch.Tensor):
        """Map a diffusion-space sample back to the original sample space (identity here)."""
        restored = z
        return restored
|
|
|
|
class MultiBandProcessor(SampleProcessor):
    """
    MultiBand sample processor. The input audio is split across
    frequency bands evenly distributed in mel-scale.

    Each band will be rescaled to match the power distribution
    of Gaussian noise in that band, using online metrics
    computed on the first few samples.

    Args:
        n_bands (int): Number of mel-bands to split the signal over.
        sample_rate (int): Sample rate of the audio.
        num_samples (int): Number of samples to use to fit the rescaling
            for each band. The processor won't be stable
            until it has seen that many samples.
        power_std (float or list/tensor): The rescaling factor computed to match the
            power of Gaussian noise in each band is taken to
            that power, i.e. `1.` means full correction of the energy
            in each band, and values less than `1` means only partial
            correction. Can be used to balance the relative importance
            of low vs. high freq in typical audio signals. A list must
            have one entry per band.
    """
    def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
                 num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.):
        super().__init__()
        self.n_bands = n_bands
        self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands)
        self.num_samples = num_samples
        if isinstance(power_std, list):
            assert len(power_std) == n_bands
            power_std = torch.tensor(power_std)
        # Store the *converted* value. Storing the raw list made `tensor ** list`
        # fail as soon as the rescaling was computed.
        self.power_std = power_std
        # Running statistics, updated in project_sample until `num_samples` items were seen.
        self.register_buffer('counts', torch.zeros(1))
        self.register_buffer('sum_x', torch.zeros(n_bands))
        self.register_buffer('sum_x2', torch.zeros(n_bands))
        self.register_buffer('sum_target_x2', torch.zeros(n_bands))
        self.counts: torch.Tensor
        self.sum_x: torch.Tensor
        self.sum_x2: torch.Tensor
        self.sum_target_x2: torch.Tensor

    @property
    def mean(self):
        """Per-band running mean of the observed data."""
        mean = self.sum_x / self.counts
        return mean

    @property
    def std(self):
        """Per-band running standard deviation of the observed data (variance clamped at 0)."""
        std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
        return std

    @property
    def target_std(self):
        """Per-band statistic of the reference Gaussian noise used as the rescaling target."""
        # NOTE(review): despite its name this is the mean squared value of the noise
        # bands (no square root is taken), i.e. a variance-like quantity. It is used
        # the same way by project_sample and return_sample, so the two stay inverse
        # to each other; adding `.sqrt()` would change the normalization (and
        # presumably the behaviour of models trained with it), so it is left as is.
        target_std = self.sum_target_x2 / self.counts
        return target_std

    def _power_std_like(self, ref: torch.Tensor):
        """Return `power_std` as-is if it is a number, or as a tensor on `ref`'s device/dtype."""
        power_std = self.power_std
        if isinstance(power_std, torch.Tensor):
            return power_std.to(device=ref.device, dtype=ref.dtype)
        return power_std

    def project_sample(self, x: torch.Tensor):
        """Split `x` into bands, center and rescale each band, then sum the bands.

        While fewer than `num_samples` items have been seen, the running
        statistics are first updated with the current batch (the first
        dimension of `x` is counted as the number of items).
        """
        assert x.dim() == 3
        # Bands are stacked on a new leading axis: reductions below use dims
        # (2, 3) per item and dim 1 across items.
        bands = self.split_bands(x)
        if self.counts.item() < self.num_samples:
            # The statistics are buffers, not trainable state: keep them out of the
            # autograd graph so they never accumulate history across calls.
            with torch.no_grad():
                ref_bands = self.split_bands(torch.randn_like(x))
                self.counts += len(x)
                self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1)
                self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
                self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
        ratio = self.target_std / self.std.clamp(min=1e-12)
        rescale = ratio ** self._power_std_like(ratio)
        bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1)
        return bands.sum(dim=0)

    def return_sample(self, x: torch.Tensor):
        """Invert `project_sample`: undo the per-band rescaling and re-add the band means."""
        assert x.dim() == 3
        bands = self.split_bands(x)
        # NOTE(review): if no statistics were collected yet (counts == 0) this yields NaNs.
        ratio = self.std / self.target_std
        rescale = ratio ** self._power_std_like(ratio)
        bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1)
        return bands.sum(dim=0)
|
|
|
|
class NoiseSchedule:
    """Noise schedule for diffusion.

    Args:
        beta_t0 (float): Variance of the first diffusion step.
        beta_t1 (float): Variance of the last diffusion step.
        num_steps (int): Number of diffusion steps.
        variance (str): Choice of the sigma value for the denoising equation.
            Choices: "beta", "beta_tilde" or "none".
        clip (float): Clipping value for the denoising steps (a falsy value disables clipping).
        rescale (float): Rescaling value to avoid vanishing signals; unused by default (i.e. 1).
        device: Device on which the betas are allocated.
        beta_exp (float): Power schedule exponent.
        repartition (str): Shape of the schedule; only "power" is supported.
        alpha_sigmoid (dict, optional): Accepted but not used by this class.
        n_bands (int, optional): Must be None; per-band schedules are not supported.
        sample_processor (SampleProcessor, optional): Module that normalizes data to better match
            the Gaussian distribution. Defaults to a new identity `SampleProcessor`.
        noise_scale (float): Scaling factor for the noise.
        **kwargs: Extra keyword arguments are accepted and ignored.
    """
    def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta',
                 clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1,
                 repartition: str = "power", alpha_sigmoid: tp.Optional[dict] = None, n_bands: tp.Optional[int] = None,
                 sample_processor: tp.Optional[SampleProcessor] = None, noise_scale: float = 1.0, **kwargs):
        self.beta_t0 = beta_t0
        self.beta_t1 = beta_t1
        self.variance = variance
        self.num_steps = num_steps
        self.clip = clip
        # Create a fresh identity processor per schedule instead of sharing a single
        # instance built once when the function was defined.
        self.sample_processor = SampleProcessor() if sample_processor is None else sample_processor
        self.rescale = rescale
        self.n_bands = n_bands
        self.noise_scale = noise_scale
        assert n_bands is None
        if repartition == "power":
            # Interpolate linearly in beta ** (1 / beta_exp) space, then map back with ** beta_exp.
            self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps,
                                        device=device, dtype=torch.float) ** beta_exp
        else:
            raise RuntimeError('Not implemented')
        # Seeded RNG used to pick the shared step in get_training_item(tensor_step=False).
        self.rng = random.Random(1234)

    def get_beta(self, step: tp.Union[int, torch.Tensor]):
        """Return the beta value(s) for `step`."""
        if self.n_bands is None:
            return self.betas[step]
        else:
            return self.betas[:, step]

    def get_initial_noise(self, x: torch.Tensor):
        """Return standard Gaussian noise shaped like `x` (the starting point of the reverse process)."""
        if self.n_bands is None:
            return torch.randn_like(x)
        return torch.randn((x.size(0), self.n_bands, x.size(2)))

    def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
        """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step.

        - ``step is None``: cumulative products of ``1 - beta`` for every step.
        - ``int`` step: scalar product of ``1 - beta`` over steps ``0..step``
          (``step == -1`` gives an empty product, i.e. 1).
        - tensor step: per-element values reshaped to ``(-1, 1, 1)`` for broadcasting.
        """
        if step is None:
            return (1 - self.betas).cumprod(dim=-1)
        if type(step) is int:
            return (1 - self.betas[:step + 1]).prod()
        else:
            return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)

    def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
        """Create a noisy data item for diffusion model training.

        Args:
            x (torch.Tensor): clean audio data torch.tensor(bs, 1, T)
            tensor_step (bool): If False, a single step t is sampled and the whole
                batch is diffused to that step (t is an int). If True, t is a tensor
                of size (x.size(0),) and every batch element is diffused to an
                independently sampled step.
        Returns:
            TrainingItem: ``(noisy, noise, step)`` where ``noisy`` mixes the projected
            sample and ``noise`` according to ``alpha_bar`` at ``step``.
        """
        step: tp.Union[int, torch.Tensor]
        if tensor_step:
            bs = x.size(0)
            step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
        else:
            step = self.rng.randrange(self.num_steps)
        alpha_bar = self.get_alpha_bar(step)

        x = self.sample_processor.project_sample(x)
        noise = torch.randn_like(x)
        noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
        return TrainingItem(noisy, noise, step)

    def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
                 condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
        """Full ddpm reverse process, from step ``num_steps - 1`` down to 0.

        Args:
            model (nn.Module): Diffusion model; called as ``model(current, step, condition=...)``
                and its ``.sample`` output is used as the noise estimate.
            initial (tensor): Initial noise. Required.
            condition (tensor): Input conditioning tensor (e.g. encodec compressed representation).
            return_list (bool): Whether to return the whole process or only the sampled point.
        Returns:
            If ``return_list``, a list holding ``initial`` followed by every step's output
            (moved to CPU, not passed through the sample processor); otherwise the final
            sample mapped back with ``sample_processor.return_sample``.
        Raises:
            ValueError: if ``initial`` is None.
        """
        if initial is None:
            raise ValueError("`initial` noise must be provided to generate().")
        # NOTE(review): unlike generate_subsampled, neither `initial` nor the model
        # estimate is multiplied by `noise_scale` here; this only matters when
        # noise_scale != 1 — confirm which behaviour is intended.
        alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
        current = initial
        iterates = [initial]
        for step in range(self.num_steps)[::-1]:
            with torch.no_grad():
                estimate = model(current, step, condition=condition).sample
            alpha = 1 - self.betas[step]
            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
            previous_alpha_bar = self.get_alpha_bar(step=step - 1)
            if step == 0:
                sigma2 = 0
            elif self.variance == 'beta':
                sigma2 = 1 - alpha
            elif self.variance == 'beta_tilde':
                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
            elif self.variance == 'none':
                sigma2 = 0
            else:
                raise ValueError(f'Invalid variance type {self.variance}')

            if sigma2 > 0:
                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
            if self.clip:
                previous = previous.clamp(-self.clip, self.clip)
            current = previous
            alpha_bar = previous_alpha_bar
            if step == 0:
                # Undo the 1 / rescale applied to the clean signal in get_training_item.
                previous *= self.rescale
            if return_list:
                iterates.append(previous.cpu())

        if return_list:
            return iterates
        else:
            return self.sample_processor.return_sample(previous)

    def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
                            condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
        """Reverse process that only goes through Markov chain states in step_list.

        Args:
            model (nn.Module): Diffusion model (same calling convention as in `generate`).
            initial (tensor): Initial noise; it is multiplied by ``noise_scale``.
            step_list (list, optional): Decreasing diffusion steps to visit, ending with 0.
                Defaults to every 50th step counting down from ``num_steps - 1``, plus 0.
            condition (tensor): Input conditioning tensor.
            return_list (bool): Whether to return the whole process or only the sampled point.
        Raises:
            ValueError: if ``step_list`` has fewer than two entries.
        """
        if step_list is None:
            step_list = list(range(self.num_steps))[::-50] + [0]
        if len(step_list) < 2:
            raise ValueError("step_list must contain at least two steps.")
        # Start from the first visited step (rather than assuming num_steps - 1),
        # so custom step lists are handled consistently with betas_subsampled below.
        alpha_bar = self.get_alpha_bar(step=step_list[0])
        alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
        # Betas of the coarser chain that jumps directly between the visited steps.
        betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
        current = initial * self.noise_scale
        iterates = [current]
        last_idx = len(step_list) - 2
        for idx, step in enumerate(step_list[:-1]):
            with torch.no_grad():
                estimate = model(current, step, condition=condition).sample * self.noise_scale
            alpha = 1 - betas_subsampled[-1 - idx]
            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
            is_last = idx == last_idx
            if is_last:
                # Final jump to the clean sample: no noise injected.
                sigma2 = 0
                previous_alpha_bar = torch.tensor(1.0)
            else:
                previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
            if sigma2 > 0:
                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
            if self.clip:
                previous = previous.clamp(-self.clip, self.clip)
            current = previous
            alpha_bar = previous_alpha_bar
            if is_last:
                # The loop never visits step 0 itself (it is the target of the last
                # jump), so apply the rescale on the last iteration instead.
                previous *= self.rescale
            if return_list:
                iterates.append(previous.cpu())
        if return_list:
            return iterates
        else:
            return self.sample_processor.return_sample(previous)
|
|