MARS5-TTS / mars5 /diffuser.py
arnavmehta7's picture
Add files (#1)
8520a55 verified
raw
history blame
No virus
20.3 kB
"""
Discrete multinomial diffusion code adapted from https://github.com/RF5/transfusion-asr,
which in turn is adapted from https://github.com/ehoogeboom/multinomial_diffusion.
Please see the original repo (https://github.com/ehoogeboom/multinomial_diffusion) and paper for full
details on how multinomial diffusion works -- thanks to the original authors!
"""
import torch
from torch import Tensor
from torch.functional import F
import numpy as np
from dataclasses import dataclass
from typing import Union
# -------------- Multinomial utility functions -----------
# Floor applied to values before a logarithm is taken (see `index_to_log_onehot`
# and `MultinomialDiffusion.log_sample_categorical`), so that log(0) is never
# evaluated. With this value the smallest resulting log is log(1e-7) ~= -16.1.
MIN_LOG_ARG = 1e-7 # originally was 1e-40
def log_1_min_a(a):
    """Return `log(1 - exp(a))` element-wise for a log-domain tensor `a`.

    The argument of the logarithm is floored at 1e-30, so `a == 0`
    (i.e. `exp(a) == 1`) yields a large negative finite value, not -inf.
    """
    one_minus_exp = 1 - torch.exp(a)
    return torch.log(torch.clamp(one_minus_exp, min=1e-30))
def log_add_exp(a, b):
    """Return `log(exp(a) + exp(b))` element-wise.

    The element-wise maximum of the two inputs is subtracted before
    exponentiating and added back afterwards, which keeps the exponentials
    from overflowing.
    """
    pivot = torch.max(a, b)
    shifted_sum = (a - pivot).exp() + (b - pivot).exp()
    return pivot + shifted_sum.log()
def extract(a: Tensor, t, x_shape):
    """Look up schedule values at timesteps `t` and shape them for broadcasting.

    `a` is a 1D vector of alpha/alpha_cum/betas and `t` holds one index per
    batch element, shape (bs,). The gathered values are returned with shape
    (bs, 1, ..., 1), having as many axes as `x_shape` has entries.
    """
    batch_size, *_ = t.shape
    singleton_axes = [1] * (len(x_shape) - 1)
    picked = torch.gather(a, -1, t)
    return picked.reshape(batch_size, *singleton_axes)
def index_to_log_onehot(x, num_classes, dim=-1, dtype=torch.float32):
    """Convert indices `x` (bs, ...) to approximate one-hot log-probs.

    The class axis is appended last, giving shape (bs, ..., num_classes),
    unless `dim == 1`, in which case it is moved to position 1. Any other
    value of `dim` leaves the class axis last.

    The one-hot values are floored at `MIN_LOG_ARG` before the log, so the
    "zero" entries become log(MIN_LOG_ARG) instead of -inf.
    """
    largest = x.max().item()
    assert largest < num_classes, \
        f'Error: {largest} >= {num_classes}'

    onehot = F.one_hot(x, num_classes)
    if dim == 1:
        # Bring the trailing class axis forward to sit right after the batch axis.
        onehot = onehot.permute(0, -1, *range(1, x.dim()))
    return onehot.to(dtype).clamp(min=MIN_LOG_ARG).log()
def sum_except_batch(x: Tensor, num_dims=1) -> Tensor:
    """Sum over every axis except the leading `num_dims` batch axes.

    Args:
        x: Tensor, shape (batch_size, ...)
        num_dims: int, number of leading batch dims to keep (default=1)
    Returns:
        Tensor whose shape is the first `num_dims` axes of `x`,
        i.e. (batch_size,) for the default.
    """
    kept_axes = tuple(x.shape[:num_dims])
    flattened = x.reshape(kept_axes + (-1,))
    return flattened.sum(dim=-1)
# -------------- Multinomial diffusion class -------------
class MultinomialDiffusion():
    """Multinomial (categorical) diffusion over `num_classes` classes, in the log domain.

    Distributions are passed around as log-probabilities whose last axis
    indexes the class. On construction the schedule terms log(alpha_t),
    log(1 - alpha_t), log(bar{alpha}_t) and log(1 - bar{alpha}_t) are
    precomputed as 1D tensors with one entry per timestep.
    """

    def __init__(self, num_classes, timesteps=100, diffusion_s=0.008,
                 loss_type='vb_stochastic', parametrization='x0',
                 dtype=torch.float32,
                 device='cpu'):
        """Build the cosine schedule and store its log-domain terms.

        Parameters:
        - `num_classes`: number of categories K.
        - `timesteps`: number of diffusion steps.
        - `diffusion_s`: offset `s` passed to `cosine_beta_schedule`.
        - `loss_type`: only 'vb_stochastic' is accepted.
        - `parametrization`: 'x0' or 'direct'; stored but not otherwise read in this class.
        - `dtype`, `device`: dtype/device the schedule tensors are stored with.
        """
        super(MultinomialDiffusion, self).__init__()
        assert loss_type in ('vb_stochastic',)
        assert parametrization in ('x0', 'direct')

        self.num_classes = num_classes
        self.loss_type = loss_type
        self.num_timesteps = timesteps
        self.parametrization = parametrization

        # The schedule terms are derived in float64 and only cast to `dtype` at the end.
        alphas = self.cosine_beta_schedule(timesteps, diffusion_s)
        alphas = alphas.to(torch.float64)
        log_alpha = alphas.log()
        log_cumprod_alpha = torch.cumsum(log_alpha, dim=-1)

        log_1_min_alpha = log_1_min_a(log_alpha)  # = log(betas)
        log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha)  # = log(1 - \bar{a})

        # Sanity checks: alpha + (1 - alpha) must be 1 (log == 0), per step and cumulatively.
        assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5
        assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5
        assert (torch.cumsum(log_alpha, dim=-1) - log_cumprod_alpha).abs().sum().item() < 1.e-5

        # Convert to the requested dtype and move to the requested device.
        self.log_alpha = log_alpha.to(dtype).to(device)
        self.log_1_min_alpha = log_1_min_alpha.to(dtype).to(device)
        self.log_cumprod_alpha = log_cumprod_alpha.to(dtype).to(device)
        self.log_1_min_cumprod_alpha = log_1_min_cumprod_alpha.to(dtype).to(device)

    @staticmethod
    def cosine_beta_schedule(timesteps, s=0.008) -> Tensor:
        """
        cosine schedule as proposed in https://arxiv.org/abs/2102.09672 .
        Returns alpha parameters, NOT Beta

        The per-step ratios of the cumulative cosine curve are clamped to
        [0.001, 1.0] and their square root is returned, as a 1D tensor with
        `timesteps` entries.
        """
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])
        alphas = torch.clamp(alphas, 0.001, 1.0)
        return torch.sqrt(alphas)

    def multinomial_kl(self, log_prob1: Tensor, log_prob2: Tensor, dim=-1) -> Tensor:
        """ Get KL divergence between two categorical distributions specified with `log_prob1` and `log_prob2`.
        Assumed probability dim is `dim` (i.e. log_prob1.exp().sum(dim=`dim`) should be tensor of ones).
        Returns KL(prob1 || prob2) with axis `dim` summed out.
        """
        return (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=dim)

    def q_pred_one_timestep(self, log_x_t: Tensor, t: Tensor) -> Tensor:
        """ Compute q(x_t | x_{t-1}) = C(x_t | alpha_t * x_{t-1} + (1-alpha_t)/K) in the log-domain
        given `log_x_t` as log one-hot encoding of x_t.

        Recall due to symmetry property we can compute
        this value using x_t instead of x_{t-1} (see appendix A of https://arxiv.org/pdf/2102.05379.pdf)
        """
        dt = log_x_t.dtype
        log_alpha_t = extract(self.log_alpha, t, log_x_t.shape).to(dt)
        log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape).to(dt)

        # alpha_t * E[xt] + (1 - alpha_t) 1 / K
        log_probs = log_add_exp(
            log_x_t + log_alpha_t,
            log_1_min_alpha_t - np.log(self.num_classes)
        )
        return log_probs

    def q_pred_one_timestep_scaled(self, log_x_t: Tensor, t: Tensor, c: int, jump_len: int) -> Tensor:
        """ Position-dependent variant of `q_pred_one_timestep`.

        A sigmoid gate `sig` is computed along axis 1 of `log_x_t`, centred
        20 positions before `log_x_t.shape[1] * (c / jump_len)` with a width of 8.
        log(alpha_t) is raised by log(1/sig) (and capped at 0, i.e. alpha <= 1),
        and log(1 - alpha_t) is lowered by log(sig). Positions well before the
        gate therefore keep their current value with probability close to 1,
        while positions well after it see the unmodified one-step noise.

        NOTE(review): the broadcasting `[None, :, None]` assumes the schedule
        terms broadcast against a (1, seq_len, 1) gate -- confirm the expected
        rank of `log_x_t` against the caller.
        """
        dt = log_x_t.dtype
        log_alpha_t = extract(self.log_alpha, t, log_x_t.shape).to(dt)
        log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape).to(dt)

        # Sigmoid gate over positions along axis 1.
        xax = torch.arange(0, log_x_t.shape[1], 1).to(log_x_t.device)
        aa = log_x_t.shape[1] * (c / jump_len)
        sig = 1 / (1 + torch.exp(-(xax - aa + 20) / 8))
        log_alpha_t = (torch.log(1 / sig)[None, :, None] + log_alpha_t).clamp(max=0)
        log_1_min_alpha_t = torch.log(sig)[None, :, None] + log_1_min_alpha_t

        # alpha_t * E[xt] + (1 - alpha_t) 1 / K
        log_probs = log_add_exp(
            log_x_t + log_alpha_t,
            log_1_min_alpha_t - np.log(self.num_classes)
        )
        return log_probs

    def q_pred(self, log_x_start: Tensor, t) -> Tensor:
        """ Compute q(x_t | x_0) = C(x_t | bar{alpha}_t * x_0 + (1 - bar{alpha}_t)/K ) in log domain,
        given `log_x_start` of log probs of x_0.
        """
        dt = log_x_start.dtype
        log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape).to(dt)
        log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape).to(dt)

        log_probs = log_add_exp(
            log_x_start + log_cumprod_alpha_t,
            log_1_min_cumprod_alpha - np.log(self.num_classes)
        )
        return log_probs

    def q_posterior(self, log_x_start, log_x_t, t):
        """ Compute `q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)`
        where q(xt | xt-1, x0) = q(xt | xt-1).

        Returns normalized log-probabilities with the same shape as `log_x_start`.
        """
        # Remove negative values, will not be used anyway for final decoder
        t_minus_1 = (t - 1).clamp(min=0)
        log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1)  # log( q(x_{t-1} | x_0) )

        # if t == 0, then log( q(x_0 | x_0) ) = log( one_hot(x_0) ), not even random at that point.
        # so, where t == 0, replace with log one-hot encoding of x0.
        num_axes = (1,) * (len(log_x_start.size()) - 1)
        is_first_step = t.view(-1, *num_axes) == 0  # broadcasts over the non-batch axes
        log_EV_qxtmin_x0 = torch.where(is_first_step, log_x_start, log_EV_qxtmin_x0)

        # Note: _NOT_ x_tmin1, which is how the formula is typically used!!!
        # Not very easy to see why this is true. But it is :)
        # log_EV_qxtmin_x0 ~ q(x_{t-1} | x_0)
        # q_pred_one_timestep(log_x_t, t) ~ q(x_t | x_{t-1}) (which due to symmetry can be computed using x_t)
        unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t)  # numerator of bayes

        # approximate denominator with just a normalizing sum.
        log_EV_xtmin_given_xt_given_xstart = \
            unnormed_logprobs \
            - torch.logsumexp(unnormed_logprobs, dim=-1, keepdim=True)
        return log_EV_xtmin_given_xt_given_xstart

    def p_pred(self, log_x_t, t, log_x0_pred):
        """ Predict `p(x_{t-1} | x_t)` using `q(xt-1 | xt, hat{x0})`, where `hat{x0}` is given by
        log probabilities from model as `log_x0_pred` (bs, ...., K) and x_t is given by
        `log_x_t` of shape `(bs, ..., K)`
        """
        return self.q_posterior(log_x_start=log_x0_pred, log_x_t=log_x_t, t=t)

    def log_sample_categorical(self, logprobs: Tensor, dim=-1) -> Tensor:
        """ Sample from categorical `logprobs` (bs, ..., probs), where position of probs is specified
        by `dim`. Sampling adds Gumbel noise to `logprobs` and takes the argmax.

        Returns sampled long indices of shape `(bs, ...)`
        """
        uniform = torch.rand_like(logprobs)
        # Both logs are floored at MIN_LOG_ARG so neither can see a zero argument.
        gumbel_noise = -torch.log((-torch.log(uniform.clamp_(min=MIN_LOG_ARG))).clamp_(min=MIN_LOG_ARG))
        return (gumbel_noise + logprobs).argmax(dim=dim)

    def q_sample(self, log_x_start, t):
        """ Draw `x_t` ~ q(x_t | x_0). `log_x_start` is of shape `(bs, ..., K)`;
        returns sampled class indices of shape `(bs, ...)` (not log one-hot).
        """
        log_EV_qxt_x0 = self.q_pred(log_x_start, t)
        return self.log_sample_categorical(log_EV_qxt_x0)

    def compute_Lt(self, log_x_start: Tensor, log_x_t: Tensor, log_x0_pred: Tensor, t,
                   detach_mean=False, include_kl_prior=True):
        """ Get loss given one-hot log x_0, one-hot log x_t, t, and model prediction `log_x0_pred`.

        Parameters:
        - `log_x_start`: ground-truth input x0, converted to log one-hot (bs, ..., K)
        - `log_x_t`: sampled noisy input at `x_t`, converted to log one-hot (bs, ..., K)
        - `log_x0_pred`: model prediction of log probabilities of x0, i.e. hat{x0}.
        - `t`: diffusion timestep (bs,)
        - `detach_mean`: if True, the model posterior is detached before the loss is formed.
        - `include_kl_prior`: add the prior KL term to the loss (does not change optimization problem).

        Returns a loss of shape (bs,): the decoder NLL where t == 0 and the
        posterior KL elsewhere, plus the prior KL when `include_kl_prior` is set.
        """
        dtype = log_x_start.dtype
        log_true_prob = self.q_posterior(
            log_x_start=log_x_start, log_x_t=log_x_t, t=t)
        log_model_prob = self.p_pred(log_x_t=log_x_t, t=t, log_x0_pred=log_x0_pred)

        if detach_mean:
            log_model_prob = log_model_prob.detach()

        kl = self.multinomial_kl(log_true_prob, log_model_prob)
        kl = sum_except_batch(kl)

        # Add L_0, -log(p(x_0 | x_1))
        decoder_nll = - (log_x_start.exp() * log_model_prob).sum(dim=-1)
        decoder_nll = sum_except_batch(decoder_nll)

        mask = (t == torch.zeros_like(t)).to(dtype)
        loss = mask * decoder_nll + (1. - mask) * kl  # only add L0 if t == 0.

        if include_kl_prior:
            # Add the prior term on top of the masked loss. Adding it to the raw
            # `kl` instead would silently drop the t == 0 decoder NLL selected above.
            loss = loss + self.kl_prior(log_x_start)

        return loss

    def kl_prior(self, log_x_start: Tensor) -> Tensor:
        """ This function computes -H_{q}(x_T | x_0)+H_{p}(x_T), which
        by some math (see wiki for KL div relation to conditional entropy)
        equals KL(q(x_T | x_0) || 1/K) for a categorical distribution.

        Given `log_x_start` (bs, ..., probs), return KL prior of shape (bs,)
        """
        b = log_x_start.size(0)
        device = log_x_start.device
        ones = torch.ones(b, device=device, dtype=torch.long)

        log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones)  # q(x_T | x_0)
        log_half_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob))  # log(1/K), broadcast to q(x_T|x_0) shape

        kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob)
        return sum_except_batch(kl_prior)
def index2logit(x: Tensor, vocab_size: int, dtype=torch.float32):
    """Map class indices to an affinely rescaled one-hot encoding.

    A trailing axis of size `vocab_size` is appended. The selected class
    gets the value 1 and every other class gets -1 / (vocab_size - 1).
    """
    onehot = F.one_hot(x, num_classes=vocab_size).to(dtype)
    scale = vocab_size / (vocab_size - 1)
    shift = 1 / (vocab_size - 1)
    return onehot * scale - shift
# ------------------------------
# Functions adapted from the full
@dataclass
class DSH:
    """Diffusion Sampling Hyperparameters [DSH] (Section 4).

    Bracketed "[default ...]" notes are carried over from the original
    field comments; the values actually used are the field defaults below.
    """

    # j in RePaint paper [default 10] (Section 4.1)
    jump_len: int = 1
    # r in RePaint paper [default 10] (Section 4.1)
    jump_n_sample: int = 1
    # whether to not sample at t=0, but take argmax prediction. [default False]
    last_greedy: bool = False
    # reweight temp for model prediction of x0
    x_0_temp: float = 1.0
    # classifier free guidance weight [default 1.5] (Section 4.3)
    guidance_w: float = 1.0
    # sequentially progressive diffusion [default True] (Section 4.2)
    enable_kevin_scaled_inference: bool = True
    # allow variable transcription sizes during inference (Section 4.4)
    T_override: Union[None, int] = None
    # whether to do deep clone.
    deep_clone: bool = False
    # number of steps that we allow overriding the input quant level 0 inputs.
    q0_override_steps: int = 0
    # whether to show progress bar
    progress: bool = False
def get_schedule(t_T, jump_len=10, jump_n_sample=10):
    """Build the list of timesteps to visit, with RePaint-style resampling jumps.

    The walk counts down from `t_T - 1` to 0. At every multiple of
    `jump_len` below `t_T - jump_len` it jumps `jump_len` steps back up,
    and does so `jump_n_sample - 1` times per such timestep. A final `-1`
    is appended as an end marker.
    """
    # Number of upward jumps still to take at each jump point.
    jumps_left = dict.fromkeys(range(0, t_T - jump_len, jump_len), jump_n_sample - 1)

    schedule = []
    current = t_T
    while current >= 1:
        current -= 1
        schedule.append(current)
        if jumps_left.get(current, 0) <= 0:
            continue
        jumps_left[current] -= 1
        # Walk back up `jump_len` steps, recording each one.
        schedule.extend(range(current + 1, current + jump_len + 1))
        current += jump_len

    schedule.append(-1)
    return schedule
def forward_diffusion(diff: MultinomialDiffusion, dtype, x, t, c=None, dsh=DSH):
    """Simple forward diffusion process p: noise index tensor `x` by one step at timestep `t`.

    `x` is converted to log one-hot form in `dtype`. When `c` is given the
    position-scaled one-step transition is used (with `dsh.jump_len`),
    otherwise the plain one. A sample from the resulting categorical is
    returned as class indices.
    """
    log_x_t = index_to_log_onehot(x, diff.num_classes, dtype=dtype)
    if c is None:
        log_probs = diff.q_pred_one_timestep(log_x_t, t)
    else:
        log_probs = diff.q_pred_one_timestep_scaled(log_x_t, t, c, dsh.jump_len)
    return diff.log_sample_categorical(log_probs)
def reverse_diffusion(diff: MultinomialDiffusion, model, batch, x_known=None, m=None,
                      last_greedy=False, temperature=1.0, alphas=None, ensemble_size=1, dsh=DSH):
    """Reverse diffusion process q: predict x_{t-1} given x, t, x_known, m. Optionally do not sample model output
    for t=0, but rather use the greedy argmax with `last_greedy`.

    Parameters:
    - `diff`: diffusion object supplying `p_pred`, `q_sample` and sampling.
    - `model`: called as `model(*batch)`, and additionally with `drop_cond=True`
      when `dsh.guidance_w != 1` (classifier-free guidance).
    - `batch`: model inputs; `batch[4]` is the current `x` (class indices) and
      `batch[-1]` is the timestep tensor `t`.
    - `x_known`, `m`: known values and mask; where `m` is set the output is
      taken from (a noised copy of) `x_known` rather than from the model.
      Both default to all zeros.
    - `temperature`: divides the model logits before the softmax.
    - `alphas`: optional per-timestep ensemble mixing weights, indexed by `t[0]`.
    - `ensemble_size`: number of consecutive batch items mixed together.

    Returns `(x_tm1, x_0_pred)`: the new indices and the (guided, temperature
    scaled) model logits.
    """
    x = batch[4]
    t = batch[-1]
    if x_known is None: x_known = torch.zeros_like(x)
    if m is None: m = torch.zeros_like(x)

    # Equation 8b
    x_0_pred = model(*batch)  # (bs, seq_len, logit_dim, n_quant)
    x_0_pred = x_0_pred.permute(0, 1, 3, 2)  # (bs, seq_len, n_quant, dim)

    if dsh.guidance_w != 1:
        uncond_x_0_pred = model(*(c.clone() if c is not None else None for c in batch), drop_cond=True)
        uncond_x_0_pred = uncond_x_0_pred.permute(0, 1, 3, 2)
        x_0_pred = dsh.guidance_w*x_0_pred + (1-dsh.guidance_w)*uncond_x_0_pred

    x_0_pred = x_0_pred / temperature
    log_x_0_pred = F.log_softmax(x_0_pred, dim=-1)
    log_x_t = index_to_log_onehot(x, diff.num_classes, dtype=x_0_pred.dtype)

    log_model_pred = diff.p_pred(log_x_t, t, log_x_0_pred)  # p(x_{t-1} | x_{t})

    # Ensemble mixing: within each group of `ensemble_size` batch items, every
    # item's distribution becomes a mixture of the group's distributions, with
    # weight (1 - a_t) + a_t / ensemble_size on itself and a_t / ensemble_size
    # on the others. Items in different groups get weight 0 (log-weight -inf).
    a_t = alphas[t[0]] if alphas is not None else 0
    mat = torch.eye(ensemble_size, device=x.device)*(1-a_t)
    mat += 1/ensemble_size * a_t
    mat = torch.block_diag(*([mat]*(x.shape[0]//ensemble_size)))
    # Keep the log-weights floating point: `x` holds integer indices, so casting
    # to `x.dtype` would truncate fractional log-weights and mangle -inf.
    log_mat = mat.log().to(log_model_pred.dtype)
    # Shape (bs_out, bs_in, 1, ..., 1) so it broadcasts against
    # log_model_pred[None] of shape (1, bs_in, ...) for any batch size.
    log_mat = log_mat.reshape(*mat.shape, *((1,) * (log_model_pred.dim() - 1)))
    log_model_pred = torch.logsumexp(log_mat + log_model_pred[None], dim=1)

    if (t==0).all() and last_greedy:  # Do not sample at t=0
        x_tm1_unknown = log_model_pred.argmax(dim=-1)
    else:
        x_tm1_unknown = diff.log_sample_categorical(log_model_pred)

    # Equation 8a
    x_known_log = index_to_log_onehot(x_known, diff.num_classes, dtype=x_0_pred.dtype)
    if (t==0).all():  # Do not sample at t=0
        x_tm1_known = x_known
    else:
        x_tm1_known = diff.q_sample(x_known_log, t)

    # Equation 8c
    x_tm1 = x_tm1_known * m.long() + x_tm1_unknown * (1 - m.long())
    return x_tm1, x_0_pred
@torch.inference_mode()
def perform_simple_inference(model: torch.nn.Module, batch: tuple, diff: MultinomialDiffusion, T, dtype=torch.float16,
                             retain_quant0: bool = True, dsh=DSH):
    """ Run the full RePaint-style sampling loop and return the sampled codes.

    If `retain_quant0`, then do not sample quant0 in each forward or reverse diffusion step.

    Parameters:
    - `model`: network passed through to `reverse_diffusion`.
    - `batch`: `(c_text, c_codes, c_text_lengths, c_codes_lengths, x, x_padding_mask)`.
      Only the shape, dtype and quantization level 0 (`x[..., 0]`) of `x` are
      used; every other level starts from uniformly random indices.
    - `diff`: diffusion object.
    - `T`: number of diffusion timesteps used to build the schedule.
    - `dtype`: dtype of the log one-hot tensors in the forward (noising) steps.
    - `dsh`: sampling hyperparameters (a `DSH` instance, or the class itself
      for the defaults).

    Returns the final index tensor `x`, with any prepended prompt frames
    (when `dsh.deep_clone` is set) cropped off again.
    """
    # (bs=1, N), (bs, seq_len2, 8), (bs,)
    c_text, c_codes, c_text_lengths, c_codes_lengths, x, x_padding_mask = batch
    device = c_text.device
    bs = c_text.shape[0]
    x_quant0 = x[..., 0].clone()  # (bs, seq_len) 0th quant level
    x = torch.randint(0, diff.num_classes, x.shape, dtype=x.dtype, device=device)
    # CRITICAL LINE: override quantization level 0 with provided quant0 level.
    x[..., 0] = x_quant0

    # RePaint paper resample scheduling
    times = get_schedule(T, jump_n_sample=dsh.jump_n_sample, jump_len=dsh.jump_len)

    x_known = torch.zeros_like(x)
    x_known[..., 0] = x[..., 0]  # override L0 codes
    m = torch.zeros_like(x).bool()
    # (bs, seq_len, 8)
    m[..., 0] = True

    offset = 0
    if dsh.deep_clone:
        print(f"Note: using deep clone. Assuming input `c_phones` is concatenated prompt and output phones.",
              "Also assuming no padded indices in `c_codes`.")
        prompt = c_codes
        x = torch.cat((prompt, x), dim=1)  # (bs=1, sl1 + sl2, 8)
        x_known = torch.cat((prompt, x_known), dim=1)
        x_padding_mask = torch.cat((
            torch.zeros(x_padding_mask.shape[0], c_codes_lengths[0], dtype=torch.bool, device=x_padding_mask.device),
            x_padding_mask), dim=-1
        )
        # (bs=1, :up to prompt duration, all 8 codebooks) = True/masked.
        # Built as bool so the concatenation does not mix bool and integer tensors.
        m = torch.cat((torch.ones_like(prompt, dtype=torch.bool), m), dim=1)
        x_quant0 = torch.cat((prompt[..., 0], x_quant0), dim=-1)
        offset = c_codes_lengths[0]

        print(f"New x: {x.shape} | new x_known: {x_known.shape} . Base prompt: {prompt.shape}. New padding mask: {x_padding_mask.shape} | m shape: {m.shape}")

    c = 0  # sequentially progressive diffusion offset (Section 4.2)

    # ensemble bs (not in paper)
    alphas = torch.linspace(1, 0, T).to(device)

    pb = zip(times[:-1], times[1:])
    if dsh.progress:
        from fastprogress import progress_bar
        pb = progress_bar(pb, total=len(times)-1)

    # See RePaint paper algorithm
    for t_last, t_cur in pb:
        t = torch.ones((bs,), dtype=torch.long, device=x.device) * (t_last)
        if t_cur < t_last:
            if c > dsh.jump_n_sample:
                c = 0
            c += 1/dsh.jump_len

            # Reverse diffusion: q
            cbatch = (c_text, c_codes, c_text_lengths, c_codes_lengths, x, x_padding_mask, t)
            # Forward `dsh.last_greedy`; otherwise the hyperparameter has no effect.
            x, _ = reverse_diffusion(diff, model, cbatch, x_known, m, last_greedy=dsh.last_greedy,
                                     temperature=dsh.x_0_temp, alphas=alphas, ensemble_size=1, dsh=dsh)
        else:
            # Forward diffusion: p
            if dsh.enable_kevin_scaled_inference: x = forward_diffusion(diff, dtype, x, t, c=c, dsh=dsh)
            else: x = forward_diffusion(diff, dtype, x, t, c=None, dsh=dsh)

        if retain_quant0 and dsh.q0_override_steps < t_last:
            x[..., 0] = x_quant0

    # crop offset:
    x = x[:, offset:]
    return x