import json
import logging
import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from encodec import EncodecModel
from vocos import Vocos

from mars5.model import CodecLM, ResidualTransformer
from mars5.diffuser import MultinomialDiffusion, DSH, perform_simple_inference
from mars5.minbpe.regex import RegexTokenizer, GPT4_SPLIT_PATTERN
from mars5.minbpe.codebook import CodebookTokenizer
from mars5.ar_generate import ar_generate
from mars5.utils import nuke_weight_norm
from mars5.trim import trim

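# High-level pipeline implemented below: an autoregressive CodecLM generates the
# coarse (quantizer-0) EnCodec tokens for the target text, conditioned on a
# reference speaker prompt; a non-autoregressive ResidualTransformer then fills
# in the residual codebooks via multinomial diffusion; finally, Vocos vocodes
# the full 8-codebook token grid back to a 24 kHz waveform.
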
@dataclass
class InferenceConfig:
    """ Default configuration values for TTS inference. """

    # -- AR model sampling settings
    temperature: float = 0.7
    top_k: int = 200
    top_p: float = 0.2
    typical_p: float = 1.0
    freq_penalty: float = 3
    presence_penalty: float = 0.4
    rep_penalty_window: int = 80

    # Settings for the penalty applied to the EOS token during AR decoding.
    eos_penalty_decay: float = 0.5
    eos_penalty_factor: float = 1
    eos_estimated_gen_length_factor: float = 1.0

    # -- NAR model (multinomial diffusion) settings
    timesteps: int = 200
    x_0_temp: float = 0.7
    q0_override_steps: int = 20
    nar_guidance_w: float = 3

    # Maximum suggested duration of the reference audio, in seconds.
    max_prompt_dur: float = 12

    # Hard cap on the AR generation length; -1 falls back to the default cap.
    generate_max_len_override: int = -1

    # Whether to perform a deep clone: the reference transcript and its speech
    # tokens are prepended to the AR prompt. Requires `ref_transcript` to be set.
    deep_clone: bool = True

    use_kv_cache: bool = True
    trim_db: float = 27
    beam_width: int = 1
    ref_audio_pad: float = 0

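# Example of overriding a few defaults (a minimal sketch; the chosen values are
# illustrative only, not recommendations):
#   cfg = InferenceConfig(temperature=0.8, top_k=100, deep_clone=False)
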
class Mars5TTS(nn.Module):

    def __init__(self, ar_ckpt, nar_ckpt, device: Optional[str] = None) -> None:
        super().__init__()

        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        # EnCodec codec used to convert between waveforms and acoustic tokens.
        self.codec = EncodecModel.encodec_model_24khz().to(self.device).eval()
        self.codec.set_target_bandwidth(6.0)

        # The minbpe tokenizers load from files on disk, so round-trip the vocabs
        # stored in the AR checkpoint through temporary files.
        self.texttok = RegexTokenizer(GPT4_SPLIT_PATTERN)
        fd, tfn = tempfile.mkstemp(suffix='texttok.model')
        os.close(fd)
        Path(tfn).write_text(ar_ckpt['vocab']['texttok.model'])
        self.texttok.load(tfn)
        os.remove(tfn)

        self.speechtok = CodebookTokenizer(GPT4_SPLIT_PATTERN)
        fd, sfn = tempfile.mkstemp(suffix='speechtok.model')
        os.close(fd)
        Path(sfn).write_text(ar_ckpt['vocab']['speechtok.model'])
        self.speechtok.load(sfn)
        os.remove(sfn)

        # The AR model shares one vocabulary across text and speech tokens.
        self.n_vocab = len(self.texttok.vocab) + len(self.speechtok.vocab)
        self.n_text_vocab = len(self.texttok.vocab) + 1
        self.diffusion_n_classes: int = 1025  # 1024 codebook entries + 1 extra class

        # Autoregressive codec language model (coarse codes).
        self.codeclm = CodecLM(n_vocab=self.n_vocab, dim=1536, dim_ff_scale=7/3)
        self.codeclm.load_state_dict(ar_ckpt['model'])
        self.codeclm = self.codeclm.to(self.device).eval()

        # Non-autoregressive residual model (diffusion over the residual codebooks).
        self.codecnar = ResidualTransformer(n_text_vocab=self.n_text_vocab, n_quant=self.diffusion_n_classes,
                                            p_cond_drop=0, dropout=0)
        self.codecnar.load_state_dict(nar_ckpt['model'])
        self.codecnar = self.codecnar.to(self.device).eval()
        self.default_T = 200  # default number of diffusion timesteps

        self.sr = 24000      # output sample rate, Hz
        self.latent_sr = 75  # EnCodec latent frame rate, Hz

        # Vocos vocoder for waveform reconstruction from EnCodec tokens.
        self.vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(self.device).eval()
        # Strip weight norm for faster inference.
        nuke_weight_norm(self.codec)
        nuke_weight_norm(self.vocos)

    @torch.inference_mode()
    def vocode(self, tokens: Tensor) -> Tensor:
        """ Vocodes `tokens` of shape (seq_len, n_q) into a (1, T) waveform tensor. """
        tokens = tokens.T.to(self.device)
        features = self.vocos.codes_to_features(tokens)
        # Vocos bandwidth ids for EnCodec 24 kHz map [1.5, 3, 6, 12] kbps to [0, 1, 2, 3],
        # so id 2 matches the 6 kbps bandwidth set on the codec above.
        bandwidth_id = torch.tensor([2], device=self.device)
        wav_diffusion = self.vocos.decode(features, bandwidth_id=bandwidth_id)
        return wav_diffusion.cpu().squeeze()[None]

    @torch.inference_mode()
    def tts(self, text: str, ref_audio: Tensor, ref_transcript: Optional[str] = None,
            cfg: Optional[InferenceConfig] = None) -> tuple[Tensor, Tensor]:
        """ Performs TTS for `text`, given a reference audio `ref_audio` (of shape [sequence_length,],
        sampled at 24kHz) which has an associated `ref_transcript`. Inference is controlled by the
        `InferenceConfig` given in `cfg` (temperature, top_p, etc.).
        Returns:
        - `ar_codes`: (seq_len,) long tensor of discrete coarse code outputs from the AR model.
        - `out_wav`: (T,) float output audio tensor sampled at 24kHz.
        """
        if cfg is None: cfg = InferenceConfig()

        if cfg.deep_clone and ref_transcript is None:
            raise AssertionError(
                ("Inference config deep_clone is set to True, but no reference transcript was specified! "
                 "Please specify the transcript of the prompt, or set deep_clone=False in the inference `cfg` argument."))
        ref_dur = ref_audio.shape[-1]/self.sr
        if ref_dur > cfg.max_prompt_dur:
            logging.warning((f"Reference audio duration ({ref_dur:.2f}s) exceeds the suggested maximum of "
                             f"{cfg.max_prompt_dur}s. Expect quality degradation; we recommend trimming the prompt."))

        # Tokenize the target text; for deep clone, also tokenize the reference
        # transcript concatenated with the target text.
        text_tokens = self.texttok.encode("<|startoftext|>"+text.strip()+"<|endoftext|>",
                                          allowed_special='all')
        if ref_transcript is not None:
            text_tokens_full = self.texttok.encode("<|startoftext|>"+ref_transcript+' '+text.strip()+"<|endoftext|>",
                                                   allowed_special='all')

        # Normalize the reference audio to mono (1, T) and optionally pad the front.
        if ref_audio.dim() == 1: ref_audio = ref_audio[None]
        if ref_audio.shape[0] != 1: ref_audio = ref_audio.mean(dim=0, keepdim=True)
        ref_audio = F.pad(ref_audio, (int(self.sr*cfg.ref_audio_pad), 0))

        # Encode the reference into EnCodec tokens: (1, n_q, T).
        prompt_codec = self.codec.encode(ref_audio[None].to(self.device))[0][0]

        n_speech_inp = 0
        n_start_skip = 0
        # Serialize the coarse (quantizer-0) codes so the speech BPE tokenizer can encode them.
        q0_str = ' '.join([str(t) for t in prompt_codec[0, 0].tolist()])

        speech_tokens = self.speechtok.encode(q0_str.strip())
        spk_ref_codec = prompt_codec[0, :, :].T

        raw_prompt_acoustic_len = len(prompt_codec[0, 0].squeeze())
        # Offset speech token ids so they sit after the text vocab in the shared AR vocabulary.
        offset_speech_codes = [p + len(self.texttok.vocab) for p in speech_tokens]
        if not cfg.deep_clone:
            # Shallow clone: keep no speech tokens in the prompt; the speaker is
            # conditioned only through `spk_ref_codec`.
            offset_speech_codes = offset_speech_codes[:n_speech_inp]
        else:
            # Deep clone: the reference transcript and its speech tokens form the prompt prefix.
            text_tokens = text_tokens_full

        n_speech_inp = len(offset_speech_codes)
        prompt = torch.tensor(text_tokens + offset_speech_codes, dtype=torch.long, device=self.device)
        first_codec_idx = prompt.shape[-1] - n_speech_inp + 1

        logging.debug(f"Raw acoustic prompt length: {raw_prompt_acoustic_len}")

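        # Stage 1 (autoregressive): generate the coarse speech codes for the target
        # text with the CodecLM, sampling with the settings from `cfg`.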
        ar_codes = ar_generate(self.texttok, self.speechtok, self.codeclm,
                               prompt, spk_ref_codec, first_codec_idx,
                               max_len=cfg.generate_max_len_override if cfg.generate_max_len_override > 1 else 2000,
                               temperature=cfg.temperature, topk=cfg.top_k, top_p=cfg.top_p, typical_p=cfg.typical_p,
                               alpha_frequency=cfg.freq_penalty, alpha_presence=cfg.presence_penalty,
                               penalty_window=cfg.rep_penalty_window,
                               eos_penalty_decay=cfg.eos_penalty_decay, eos_penalty_factor=cfg.eos_penalty_factor,
                               beam_width=cfg.beam_width, beam_length_penalty=1,
                               n_phones_gen=round(cfg.eos_estimated_gen_length_factor*len(text)),
                               vocode=False, use_kv_cache=cfg.use_kv_cache)

        # Shift the generated ids back into the speech-token space and decode the
        # BPE into raw quantizer-0 codec indices.
        output_tokens = ar_codes - len(self.texttok.vocab)
        output_tokens = output_tokens.clamp(min=0).squeeze()[first_codec_idx:].cpu().tolist()
        gen_codes_decoded = self.speechtok.decode_int(output_tokens)
        gen_codes_decoded = torch.tensor([s for s in gen_codes_decoded if isinstance(s, int)],
                                         dtype=torch.long, device=self.device)

        # Conditioning tensors for the NAR model.
        c_text = torch.tensor(text_tokens, dtype=torch.long, device=self.device)[None]
        c_codes = prompt_codec.permute(0, 2, 1)  # (1, T, n_q)
        c_texts_lengths = torch.tensor([len(text_tokens)], dtype=torch.long, device=self.device)
        c_codes_lengths = torch.tensor([c_codes.shape[1]], dtype=torch.long, device=self.device)

        # Tile the coarse codes across all 8 quantizer levels as the diffusion input.
        _x = gen_codes_decoded[None, n_start_skip:, None].repeat(1, 1, 8)
        x_padding_mask = torch.zeros((1, _x.shape[1]), dtype=torch.bool, device=_x.device)

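        # Stage 2 (non-autoregressive): refine the residual codebooks with multinomial
        # diffusion, conditioned on the text, the reference codec tokens, and the
        # coarse codes generated above.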
        # Use the configured diffusion timesteps, falling back to the model default.
        T = cfg.timesteps if cfg.timesteps else self.default_T
        diff = MultinomialDiffusion(self.diffusion_n_classes, timesteps=T, device=self.device)

        dsh_cfg = DSH(last_greedy=True, x_0_temp=cfg.x_0_temp,
                      guidance_w=cfg.nar_guidance_w,
                      deep_clone=cfg.deep_clone, jump_len=1, jump_n_sample=1,
                      q0_override_steps=cfg.q0_override_steps,
                      enable_kevin_scaled_inference=True,
                      progress=False)

        final_output = perform_simple_inference(self.codecnar, (
            c_text, c_codes, c_texts_lengths, c_codes_lengths, _x, x_padding_mask
        ), diff, diff.num_timesteps, torch.float16, dsh=dsh_cfg, retain_quant0=True)

        # With deep clone the output still contains the acoustic prompt; strip it off.
        skip_front = raw_prompt_acoustic_len if cfg.deep_clone else 0
        final_output = final_output[0, skip_front:].to(self.device)

        # Vocode to a waveform and trim leading/trailing silence.
        final_audio = self.vocode(final_output).squeeze()
        final_audio, _ = trim(final_audio.cpu(), top_db=cfg.trim_db)

        return gen_codes_decoded, final_audio
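
# Example usage (a minimal sketch: the checkpoint filenames and the torchaudio-based
# audio loading are illustrative assumptions, not part of this module's API):
#
#   import torchaudio
#
#   ar_ckpt = torch.load('mars5_ar.ckpt', map_location='cpu')
#   nar_ckpt = torch.load('mars5_nar.ckpt', map_location='cpu')
#   model = Mars5TTS(ar_ckpt, nar_ckpt)
#
#   wav, sr = torchaudio.load('reference.wav')
#   wav = torchaudio.functional.resample(wav, sr, model.sr).mean(dim=0)  # mono, 24 kHz
#
#   ar_codes, out_wav = model.tts("Hello there!", wav,
#                                 ref_transcript="transcript of reference.wav",
#                                 cfg=InferenceConfig(deep_clone=True))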