{
"cells": [
{
"cell_type": "markdown",
"id": "f37f4e3b-f764-4502-a6a2-6417bd9bfab9",
"metadata": {},
"source": [
"# Matcha-TTS: A fast TTS architecture with conditional flow matching\n",
"---\n",
"[Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/)\n",
"\n",
"We introduce Matcha-TTS, a new encoder-decoder architecture for speedy TTS acoustic modelling, trained using optimal-transport conditional flow matching (OT-CFM). This yields an ODE-based decoder capable of high output quality in fewer synthesis steps than models trained using score matching. Careful design choices additionally ensure each synthesis step is fast to run. The method is probabilistic, non-autoregressive, and learns to speak from scratch without external alignments. Compared to strong pre-trained baseline models, the Matcha-TTS system has the smallest memory footprint, rivals the speed of the fastest models on long utterances, and attains the highest mean opinion score in a listening test.\n",
"\n",
"Demo Page: https://shivammehta25.github.io/Matcha-TTS \\\n",
"Code: https://github.com/shivammehta25/Matcha-TTS\n",
"\n",
"\n"
]
},
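{
"cell_type": "markdown",
"id": "3d7c1a52-9b0e-4f7a-8c2d-5e1f6a3b9c04",
"metadata": {},
"source": [
"### How the ODE-based decoder samples (illustrative sketch)\n",
"\n",
"Synthesis with an OT-CFM-trained decoder amounts to numerically integrating an ODE from Gaussian noise at $t=0$ to a mel-spectrogram at $t=1$. The sketch below is illustrative only: `estimator` is a hypothetical stand-in for the learned vector-field network (the real decoder also conditions on the text encoding), but it shows the fixed-step Euler integration whose step count corresponds to the `n_timesteps` hyperparameter used later in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b9e4d21-6c3f-4a85-b1d0-2f8e5c7a9d13",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"@torch.inference_mode()\n",
"def euler_ode_solve(estimator, x0, n_steps):\n",
"    \"\"\"Integrate dx/dt = v(x, t) from t=0 (noise) to t=1 (mel) with fixed Euler steps.\n",
"\n",
"    `estimator` is a hypothetical callable standing in for the learned vector field.\n",
"    \"\"\"\n",
"    t_span = torch.linspace(0, 1, n_steps + 1)\n",
"    x = x0\n",
"    for i in range(n_steps):\n",
"        dt = t_span[i + 1] - t_span[i]\n",
"        x = x + dt * estimator(x, t_span[i])  # one Euler step along the learned flow\n",
"    return x"
]
},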
{
"cell_type": "code",
"execution_count": 1,
"id": "148f4bc0-c28e-4670-9a5e-4c7928ab8992",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"env: CUDA_VISIBLE_DEVICES=0\n"
]
}
],
"source": [
"%env CUDA_VISIBLE_DEVICES=0"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8d5876c0-b47e-4c80-9e9c-62550f81b64e",
"metadata": {},
"outputs": [],
"source": [
"import datetime as dt\n",
"from pathlib import Path\n",
"\n",
"import IPython.display as ipd\n",
"import numpy as np\n",
"import soundfile as sf\n",
"import torch\n",
"from tqdm.auto import tqdm\n",
"\n",
"# Hifigan imports\n",
"from matcha.hifigan.config import v1\n",
"from matcha.hifigan.denoiser import Denoiser\n",
"from matcha.hifigan.env import AttrDict\n",
"from matcha.hifigan.models import Generator as HiFiGAN\n",
"# Matcha imports\n",
"from matcha.models.matcha_tts import MatchaTTS\n",
"from matcha.text import sequence_to_text, text_to_sequence\n",
"from matcha.utils.model import denormalize\n",
"from matcha.utils.utils import get_user_data_dir, intersperse"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b1a30306-588c-4f22-8d9b-e2676880b0e5",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline\n",
"# This allows for real time code changes being reflected in the notebook, no need to restart the kernel"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a312856b-01a9-4d75-a4c8-4666dffa0692",
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "markdown",
"id": "88f3b3c3-d014-443b-84eb-e143cdec3e21",
"metadata": {},
"source": [
"## Filepaths"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "7640a4c1-44ce-447c-a8ff-45012fb7bddd",
"metadata": {},
"outputs": [],
"source": [
"MATCHA_CHECKPOINT = get_user_data_dir()/\"matcha_ljspeech.ckpt\"\n",
"HIFIGAN_CHECKPOINT = get_user_data_dir() / \"hifigan_T2_v1\"\n",
"OUTPUT_FOLDER = \"synth_output\""
]
},
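{
"cell_type": "markdown",
"id": "c5a2e8f1-3d4b-4c6a-9e7f-8b1d2c3a4e56",
"metadata": {},
"source": [
"Both checkpoints are assumed to have been downloaded beforehand (see the repository linked above). A quick sanity check makes a missing file fail here, with a clear message, rather than inside the loaders below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f3b6c84-7a2e-4d15-8c9b-0e1f2a3b4c67",
"metadata": {},
"outputs": [],
"source": [
"for ckpt in (MATCHA_CHECKPOINT, HIFIGAN_CHECKPOINT):\n",
"    if not Path(ckpt).exists():\n",
"        raise FileNotFoundError(f\"Checkpoint not found: {ckpt}\")"
]
},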
{
"cell_type": "markdown",
"id": "6477a3a9-71f2-4d2f-bb86-bdf3e31c2461",
"metadata": {},
"source": [
"## Load Matcha-TTS"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "26a16230-04ba-4825-a844-2fb5ab945e24",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model loaded! Parameter count: 18,204,193\n"
]
}
],
"source": [
"def load_model(checkpoint_path):\n",
" model = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device)\n",
" model.eval()\n",
" return model\n",
"count_params = lambda x: f\"{sum(p.numel() for p in x.parameters()):,}\"\n",
"\n",
"\n",
"model = load_model(MATCHA_CHECKPOINT)\n",
"print(f\"Model loaded! Parameter count: {count_params(model)}\")"
]
},
{
"cell_type": "markdown",
"id": "3077b84b-e3b6-42e1-a84b-2f7084b13f92",
"metadata": {},
"source": [
"## Load HiFi-GAN (Vocoder)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f6b68184-968d-4868-9029-f0c40e9e68af",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Removing weight norm...\n"
]
}
],
"source": [
"def load_vocoder(checkpoint_path):\n",
" h = AttrDict(v1)\n",
" hifigan = HiFiGAN(h).to(device)\n",
" hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)['generator'])\n",
" _ = hifigan.eval()\n",
" hifigan.remove_weight_norm()\n",
" return hifigan\n",
"\n",
"vocoder = load_vocoder(HIFIGAN_CHECKPOINT)\n",
"denoiser = Denoiser(vocoder, mode='zeros')"
]
},
{
"cell_type": "markdown",
"id": "4cbc2ba0-09ff-40e2-9e60-6b77b534f9fb",
"metadata": {},
"source": [
"### Helper functions to synthesise"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "880a1879-24fd-4757-849c-850339120796",
"metadata": {},
"outputs": [],
"source": [
"@torch.inference_mode()\n",
"def process_text(text: str):\n",
" x = torch.tensor(intersperse(text_to_sequence(text, ['english_cleaners2']), 0),dtype=torch.long, device=device)[None]\n",
" x_lengths = torch.tensor([x.shape[-1]],dtype=torch.long, device=device)\n",
" x_phones = sequence_to_text(x.squeeze(0).tolist())\n",
" return {\n",
" 'x_orig': text,\n",
" 'x': x,\n",
" 'x_lengths': x_lengths,\n",
" 'x_phones': x_phones\n",
" }\n",
"\n",
"\n",
"@torch.inference_mode()\n",
"def synthesise(text, spks=None):\n",
" text_processed = process_text(text)\n",
" start_t = dt.datetime.now()\n",
" output = model.synthesise(\n",
" text_processed['x'], \n",
" text_processed['x_lengths'],\n",
" n_timesteps=n_timesteps,\n",
" temperature=temperature,\n",
" spks=spks,\n",
" length_scale=length_scale\n",
" )\n",
" # merge everything to one dict \n",
" output.update({'start_t': start_t, **text_processed})\n",
" return output\n",
"\n",
"@torch.inference_mode()\n",
"def to_waveform(mel, vocoder):\n",
" audio = vocoder(mel).clamp(-1, 1)\n",
" audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()\n",
" return audio.cpu().squeeze()\n",
" \n",
"def save_to_folder(filename: str, output: dict, folder: str):\n",
" folder = Path(folder)\n",
" folder.mkdir(exist_ok=True, parents=True)\n",
" np.save(folder / f'{filename}', output['mel'].cpu().numpy())\n",
" sf.write(folder / f'{filename}.wav', output['waveform'], 22050, 'PCM_24')"
]
},
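{
"cell_type": "markdown",
"id": "1e8d4f72-5b6a-4c3d-9a0e-7f2b3c4d5e68",
"metadata": {},
"source": [
"`process_text` intersperses a blank token (id 0) between every phone id, which is why the phonetised printout further down shows a `_` between each symbol. A minimal illustration with toy ids (hypothetical, not real phone ids):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a4c2e91-8d7b-4f5a-b3c1-9e0f1a2b3c79",
"metadata": {},
"outputs": [],
"source": [
"# Toy ids only; real inputs come from text_to_sequence\n",
"print(intersperse([7, 8, 9], 0))  # -> [0, 7, 0, 8, 0, 9, 0]"
]
},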
{
"cell_type": "markdown",
"id": "78f857e3-2ef7-4c86-b776-596c4d3cf875",
"metadata": {},
"source": [
"## Setup text to synthesise"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2e0a9acd-0845-4192-ba09-b9683e28a3ac",
"metadata": {},
"outputs": [],
"source": [
"texts = [\n",
" \"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.\"\n",
"]"
]
},
{
"cell_type": "markdown",
"id": "a9da9e2d-99b9-4c6f-8a08-c828e2cba121",
"metadata": {},
"source": [
"### Hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f0d216e5-4895-4da8-9d24-9e61021d2556",
"metadata": {},
"outputs": [],
"source": [
"## Number of ODE Solver steps\n",
"n_timesteps = 10\n",
"\n",
"## Changes to the speaking rate\n",
"length_scale=1.0\n",
"\n",
"## Sampling temperature\n",
"temperature = 0.667"
]
},
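{
"cell_type": "markdown",
"id": "2f9e5a83-6c7d-4b1e-8d2f-0a3b4c5d6e80",
"metadata": {},
"source": [
"To see the quality/speed trade-off that `n_timesteps` controls, one can sweep the step count and compare the reported RTFs. A small sketch using the helpers defined above; note that `synthesise` reads the global `n_timesteps`, so the loop variable rebinds it and the default is restored afterwards:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8c5d3f94-1e2a-4d6b-9f3c-b4a5c6d7e891",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: re-synthesise the first text with different ODE step counts\n",
"for n_timesteps in (2, 4, 10):\n",
"    out = synthesise(texts[0])\n",
"    print(f\"n_timesteps={n_timesteps}: RTF={out['rtf']:.6f}\")\n",
"\n",
"n_timesteps = 10  # restore the default used in the rest of the notebook"
]
},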
{
"cell_type": "markdown",
"id": "b93aac89-c7f8-4975-8510-4e763c9689f4",
"metadata": {},
"source": [
"## Synthesis"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "5a227963-aa12-43b9-a706-1168b6fc0ba5",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8342d12401c54017b0e19b8d293a06bf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"*****************************************************\n",
"Input text - 0\n",
"-----------------------------------------------------\n",
"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.\n",
"*****************************************************\n",
"Phonetised text - 0\n",
"-----------------------------------------------------\n",
"_ð_ə_ _s_ˈ_i_ː_k_ɹ_ᵻ_t_ _s_ˈ_ɜ_ː_v_ɪ_s_ _b_ᵻ_l_ˈ_i_ː_v_d_ _ð_ˌ_ɐ_ɾ_ɪ_t_ _w_ʌ_z_ _v_ˈ_ɛ_ɹ_i_ _d_ˈ_a_ʊ_t_f_ə_l_ _ð_æ_t_ _ˌ_ɛ_n_i_ _p_ɹ_ˈ_ɛ_z_ɪ_d_ə_n_t_ _w_ʊ_d_ _ɹ_ˈ_a_ɪ_d_ _ɹ_ˈ_ɛ_ɡ_j_ʊ_l_ɚ_l_i_ _ɪ_n_ _ɐ_ _v_ˈ_i_ə_k_ə_l_ _w_ɪ_ð_ _ɐ_ _f_ˈ_ɪ_k_s_t_ _t_ˈ_ɑ_ː_p_,_ _ˈ_i_ː_v_ə_n_ _ð_ˌ_o_ʊ_ _t_ɹ_æ_n_s_p_ˈ_æ_ɹ_ə_n_t_._\n",
"*****************************************************\n",
"RTF:\t\t0.017228\n",
"RTF Waveform:\t0.021445\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of ODE steps: 10\n",
"Mean RTF:\t\t\t\t0.017228 ± 0.000000\n",
"Mean RTF Waveform (incl. vocoder):\t0.021445 ± 0.000000\n"
]
}
],
"source": [
"outputs, rtfs = [], []\n",
"rtfs_w = []\n",
"for i, text in enumerate(tqdm(texts)):\n",
" output = synthesise(text) #, torch.tensor([15], device=device, dtype=torch.long).unsqueeze(0))\n",
" output['waveform'] = to_waveform(output['mel'], vocoder)\n",
"\n",
" # Compute Real Time Factor (RTF) with HiFi-GAN\n",
" t = (dt.datetime.now() - output['start_t']).total_seconds()\n",
" rtf_w = t * 22050 / (output['waveform'].shape[-1])\n",
"\n",
" ## Pretty print\n",
" print(f\"{'*' * 53}\")\n",
" print(f\"Input text - {i}\")\n",
" print(f\"{'-' * 53}\")\n",
" print(output['x_orig'])\n",
" print(f\"{'*' * 53}\")\n",
" print(f\"Phonetised text - {i}\")\n",
" print(f\"{'-' * 53}\")\n",
" print(output['x_phones'])\n",
" print(f\"{'*' * 53}\")\n",
" print(f\"RTF:\\t\\t{output['rtf']:.6f}\")\n",
" print(f\"RTF Waveform:\\t{rtf_w:.6f}\")\n",
" rtfs.append(output['rtf'])\n",
" rtfs_w.append(rtf_w)\n",
"\n",
" ## Display the synthesised waveform\n",
" ipd.display(ipd.Audio(output['waveform'], rate=22050))\n",
"\n",
" ## Save the generated waveform\n",
" save_to_folder(i, output, OUTPUT_FOLDER)\n",
"\n",
"print(f\"Number of ODE steps: {n_timesteps}\")\n",
"print(f\"Mean RTF:\\t\\t\\t\\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}\")\n",
"print(f\"Mean RTF Waveform (incl. vocoder):\\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}\")"
]
},
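{
"cell_type": "markdown",
"id": "4b7a6e05-2d3c-4e8f-a1b2-c5d6e7f8a902",
"metadata": {},
"source": [
"For reference, the real-time factor reported above is synthesis time divided by the duration of the generated audio,\n",
"\n",
"$$\\mathrm{RTF} = \\frac{t_{\\mathrm{synth}}}{t_{\\mathrm{audio}}} = \\frac{t_{\\mathrm{synth}} \\cdot 22050}{N_{\\mathrm{samples}}},$$\n",
"\n",
"which is what `t * 22050 / output['waveform'].shape[-1]` computes at the 22.05 kHz output rate. RTF < 1 means faster-than-real-time synthesis; \"RTF Waveform\" additionally includes HiFi-GAN vocoding."
]
},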
{
"cell_type": "code",
"execution_count": null,
"id": "e3e85c3f-1623-4647-b40c-fa96907656fc",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}