|
""" |
|
ein notation: |
|
b - batch |
|
n - sequence |
|
nt - text sequence |
|
nw - raw wave length |
|
d - dimension |
|
""" |
|
|
|
from __future__ import annotations |
|
from typing import Callable |
|
from random import random |
|
|
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
from torchdiffeq import odeint |
|
|
|
from einops import rearrange |
|
|
|
from model.modules import MelSpec |
|
|
|
from model.utils import ( |
|
default, exists, |
|
list_str_to_idx, list_str_to_tensor, |
|
lens_to_mask, mask_from_frac_lengths, |
|
) |
|
|
|
|
|
class CFM(nn.Module):
    """Conditional flow matching wrapper around a transformer backbone.

    `forward` computes the training loss on a randomly masked span of the
    input, and `sample` integrates the learned vector field with `odeint`
    to generate a mel spectrogram (optionally passed through a vocoder).

    Shape letters used below follow the ein notation in the module docstring
    (b - batch, n - sequence, nt - text sequence, nw - raw wave length,
    d - dimension).
    """

    def __init__(
        self,
        transformer: nn.Module,
        sigma = 0.,
        odeint_kwargs: dict = dict(
            method = 'euler'
        ),
        audio_drop_prob = 0.3,
        cond_drop_prob = 0.2,
        num_channels = None,
        mel_spec_module: nn.Module | None = None,
        mel_spec_kwargs: dict = dict(),
        frac_lengths_mask: tuple[float, float] = (0.7, 1.),
        vocab_char_map: dict[str, int] | None = None
    ):
        """Store the backbone and the training / sampling hyper-parameters.

        Args:
            transformer: backbone called as
                `transformer(x, cond, text, time, ...)`; must expose `.dim`.
            sigma: stored on the instance as `self.sigma`.
            odeint_kwargs: extra keyword arguments forwarded to `odeint`
                in `sample`.
            audio_drop_prob: probability of dropping only the audio
                condition during training.
            cond_drop_prob: probability of dropping both the audio condition
                and the text during training.
            num_channels: feature dimension of the mel input; defaults to
                `self.mel_spec.n_mel_channels`.
            mel_spec_module: module turning raw waves into mel spectrograms;
                a `MelSpec(**mel_spec_kwargs)` is used when omitted.
            mel_spec_kwargs: keyword arguments for the fallback `MelSpec`.
            frac_lengths_mask: (low, high) bounds of the uniform
                distribution from which the masked fraction is drawn.
            vocab_char_map: optional character -> index map used to encode
                text given as a list of strings.
        """
        super().__init__()

        # bounds for the random span mask used in training
        self.frac_lengths_mask = frac_lengths_mask

        # mel spectrogram front-end
        # NOTE: the fallback MelSpec is constructed even when
        # `mel_spec_module` is supplied, because it is an eagerly evaluated
        # argument of `default`.
        self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
        num_channels = default(num_channels, self.mel_spec.n_mel_channels)
        self.num_channels = num_channels

        # condition dropout probabilities (used in `forward`)
        self.audio_drop_prob = audio_drop_prob
        self.cond_drop_prob = cond_drop_prob

        # backbone
        self.transformer = transformer
        dim = transformer.dim
        self.dim = dim

        # flow-matching noise level (stored only; not read elsewhere in this class)
        self.sigma = sigma

        # Copy so that mutating `self.odeint_kwargs` on one instance can never
        # leak into the shared default argument (or the caller's dict).
        self.odeint_kwargs = dict(odeint_kwargs)

        # text vocabulary
        self.vocab_char_map = vocab_char_map

    @property
    def device(self):
        """Device of the first registered parameter of this module."""
        return next(self.parameters()).device

    def _to_mel(self, audio):
        """Convert a 2-D raw-wave batch to `b n d` mel features.

        Inputs that are not 2-D are returned untouched.
        """
        if audio.ndim == 2:
            audio = self.mel_spec(audio)
            audio = rearrange(audio, 'b d n -> b n d')
            assert audio.shape[-1] == self.num_channels
        return audio

    def _text_to_tensor(self, text, batch, device):
        """Encode a list of strings as an index tensor on `device`.

        Uses `vocab_char_map` when one was given, otherwise
        `list_str_to_tensor`. Non-list input is returned untouched.
        """
        if isinstance(text, list):
            if exists(self.vocab_char_map):
                text = list_str_to_idx(text, self.vocab_char_map).to(device)
            else:
                text = list_str_to_tensor(text).to(device)
            assert text.shape[0] == batch
        return text

    @staticmethod
    def _pad_cond_to_duration(cond, cond_mask, max_duration, no_ref_audio = False):
        """Pad `cond` / `cond_mask` to `max_duration` and build the step condition.

        Args:
            cond: `b n d` conditioning features.
            cond_mask: `b n` boolean mask of positions kept as condition.
            max_duration: target sequence length after padding.
            no_ref_audio: when True the conditioning audio is replaced by
                zeros *before* the step condition is derived, so the
                backbone really sees no reference audio.

        Returns:
            `(cond, cond_mask, step_cond)` where `cond` is `b max_duration d`,
            `cond_mask` is `b max_duration 1` and `step_cond` equals `cond`
            where the mask is True and zero elsewhere.
        """
        cond_seq_len = cond.shape[1]
        cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)

        if no_ref_audio:
            cond = torch.zeros_like(cond)

        cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
        cond_mask = cond_mask[..., None]
        step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
        return cond, cond_mask, step_cond

    @torch.no_grad()
    def sample(
        self,
        cond: float['b n d'] | float['b nw'],
        text: int['b nt'] | list[str],
        duration: int | int['b'],
        *,
        lens: int['b'] | None = None,
        steps = 32,
        cfg_strength = 1.,
        sway_sampling_coef = None,
        seed: int | None = None,
        max_duration = 4096,
        vocoder: Callable[[float['b d n']], float['b nw']] | None = None,
        no_ref_audio = False,
        duplicate_test = False,
        t_inter = 0.1,
        edit_mask = None,
    ):
        """Generate audio features by integrating the learned flow.

        Puts the module in eval mode as a side effect.

        Args:
            cond: reference audio, either mel features (`b n d`) or raw
                waves (`b nw`, converted with `self.mel_spec`).
            text: index tensor or list of strings to synthesise.
            duration: total output length per sample (int or `b` tensor).
            lens: per-sample length of `cond` to keep as condition;
                defaults to the full `cond` length.
            steps: number of time points handed to `odeint`.
            cfg_strength: classifier-free guidance weight; values below
                1e-5 skip the unconditional pass entirely.
            sway_sampling_coef: if given, warps the time grid with
                `t + coef * (cos(pi / 2 * t) - 1 + t)`.
            seed: if given, the global torch RNG is re-seeded before the
                noise of every batch element is drawn.
            max_duration: upper clamp applied to `duration`.
            vocoder: optional callable applied to the `b d n` output.
            no_ref_audio: replace the reference audio by zeros.
            duplicate_test: start integration at `t_inter` from a blend of
                noise and a shifted copy of `cond`.
            t_inter: start time used when `duplicate_test` is set.
            edit_mask: optional boolean mask AND-ed into the condition mask.

        Returns:
            `(out, trajectory)`: the final sample with the conditioned
            region pasted back (vocoded when `vocoder` is given) and the
            full `odeint` trajectory.
        """
        self.eval()

        # raw wave -> mel
        cond = self._to_mel(cond)

        batch, cond_seq_len, device = *cond.shape[:2], cond.device
        if not exists(lens):
            lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)

        # text
        text = self._text_to_tensor(text, batch, device)

        if exists(text):
            # entries equal to -1 are not counted as text;
            # lengths are made at least as long as the text
            text_lens = (text != -1).sum(dim = -1)
            lens = torch.maximum(text_lens, lens)

        # duration
        cond_mask = lens_to_mask(lens)
        if edit_mask is not None:
            cond_mask = cond_mask & edit_mask

        if isinstance(duration, int):
            duration = torch.full((batch,), duration, device = device, dtype = torch.long)

        # always generate at least one frame beyond the conditioned part
        duration = torch.maximum(lens + 1, duration)
        duration = duration.clamp(max = max_duration)
        max_duration = duration.amax()

        # duplicate test: shifted copy of the (unpadded) condition
        if duplicate_test:
            test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2*cond_seq_len), value = 0.)

        # The reference audio is zeroed (when requested) inside the helper,
        # before step_cond is derived; zeroing it afterwards would leave the
        # backbone conditioned on the real audio.
        cond, cond_mask, step_cond = self._pad_cond_to_duration(
            cond, cond_mask, max_duration, no_ref_audio = no_ref_audio
        )

        # a padding mask is only built for batched inference
        if batch > 1:
            mask = lens_to_mask(duration)
        else:
            mask = None

        # vector field handed to the ODE solver; the condition is fixed
        # across steps
        def fn(t, x):
            pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = False, drop_text = False)
            if cfg_strength < 1e-5:
                return pred

            null_pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = True, drop_text = True)
            return pred + (pred - null_pred) * cfg_strength

        # noise input, drawn per sample so that re-seeding gives each batch
        # element the same noise stream regardless of the other elements
        y0 = []
        for dur in duration:
            if exists(seed):
                torch.manual_seed(seed)
            y0.append(torch.randn(dur, self.num_channels, device = self.device))
        y0 = pad_sequence(y0, padding_value = 0, batch_first = True)

        t_start = 0

        # duplicate test: start from an intermediate blend at t_inter
        if duplicate_test:
            t_start = t_inter
            y0 = (1 - t_start) * y0 + t_start * test_cond
            steps = int(steps * (1 - t_start))

        t = torch.linspace(t_start, 1, steps, device = self.device)
        if sway_sampling_coef is not None:
            t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

        trajectory = odeint(fn, y0, t, **self.odeint_kwargs)

        sampled = trajectory[-1]
        # paste the condition back over the conditioned region
        out = torch.where(cond_mask, cond, sampled)

        if exists(vocoder):
            out = rearrange(out, 'b n d -> b d n')
            out = vocoder(out)

        return out, trajectory

    def forward(
        self,
        inp: float['b n d'] | float['b nw'],
        text: int['b nt'] | list[str],
        *,
        lens: int['b'] | None = None,
        noise_scheduler: str | None = None,
    ):
        """Compute the flow-matching training loss.

        Args:
            inp: target audio, mel features (`b n d`) or raw waves (`b nw`).
            text: index tensor or list of strings.
            lens: valid length of each sample; defaults to the full length.
            noise_scheduler: accepted for interface compatibility; not read
                by this method.

        Returns:
            `(loss, cond, pred)`: mean squared error between the predicted
            and the target flow over the masked span, the conditioning
            tensor fed to the backbone, and the backbone prediction.
        """
        # raw wave -> mel
        inp = self._to_mel(inp)

        batch, seq_len, dtype, device = *inp.shape[:2], inp.dtype, self.device

        # text
        text = self._text_to_tensor(text, batch, device)

        # lens and mask
        if not exists(lens):
            lens = torch.full((batch,), seq_len, device = device)

        mask = lens_to_mask(lens, length = seq_len)

        # random span to infill, its fraction drawn uniformly from
        # frac_lengths_mask
        frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
        rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)

        if exists(mask):
            rand_span_mask &= mask

        # x1 is the data, x0 is gaussian noise
        x1 = inp
        x0 = torch.randn_like(x1)

        # one random time per sample
        time = torch.rand((batch,), dtype = dtype, device = self.device)

        # point on the straight path between noise and data, and its velocity
        t = rearrange(time, 'b -> b 1 1')
        phi = (1 - t) * x0 + t * x1
        flow = x1 - x0

        # only the part outside the masked span is visible as condition
        cond = torch.where(
            rand_span_mask[..., None],
            torch.zeros_like(x1), x1
        )

        # condition dropout for classifier-free guidance
        drop_audio_cond = random() < self.audio_drop_prob
        if random() < self.cond_drop_prob:
            drop_audio_cond = True
            drop_text = True
        else:
            drop_text = False

        pred = self.transformer(x = phi, cond = cond, text = text, time = time, drop_audio_cond = drop_audio_cond, drop_text = drop_text)

        # loss restricted to the masked span
        loss = F.mse_loss(pred, flow, reduction = 'none')
        loss = loss[rand_span_mask]

        return loss.mean(), cond, pred
|
|