{ "cells": [ { "cell_type": "markdown", "id": "f37f4e3b-f764-4502-a6a2-6417bd9bfab9", "metadata": {}, "source": [ "# Matcha-TTS: A fast TTS architecture with conditional flow matching\n", "---\n", "[Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/)\n", "\n", "We introduce Matcha-TTS, a new encoder-decoder architecture for speedy TTS acoustic modelling, trained using optimal-transport conditional flow matching (OT-CFM). This yields an ODE-based decoder capable of high output quality in fewer synthesis steps than models trained using score matching. Careful design choices additionally ensure each synthesis step is fast to run. The method is probabilistic, non-autoregressive, and learns to speak from scratch without external alignments. Compared to strong pre-trained baseline models, the Matcha-TTS system has the smallest memory footprint, rivals the speed of the fastest models on long utterances, and attains the highest mean opinion score in a listening test.\n", "\n", "Demo Page: https://shivammehta25.github.io/Matcha-TTS \\\n", "Code: https://github.com/shivammehta25/Matcha-TTS\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "148f4bc0-c28e-4670-9a5e-4c7928ab8992", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "env: CUDA_VISIBLE_DEVICES=0\n" ] } ], "source": [ "%env CUDA_VISIBLE_DEVICES=0" ] }, { "cell_type": "code", "execution_count": 2, "id": "8d5876c0-b47e-4c80-9e9c-62550f81b64e", "metadata": {}, "outputs": [], "source": [ "import datetime as dt\n", "from pathlib import Path\n", "\n", "import IPython.display as ipd\n", "import numpy as np\n", "import soundfile as sf\n", "import torch\n", "from tqdm.auto import tqdm\n", "\n", "# Hifigan imports\n", "from matcha.hifigan.config import v1\n", "from 
matcha.hifigan.denoiser import Denoiser\n", "from matcha.hifigan.env import AttrDict\n", "from matcha.hifigan.models import Generator as HiFiGAN\n", "# Matcha imports\n", "from matcha.models.matcha_tts import MatchaTTS\n", "from matcha.text import sequence_to_text, text_to_sequence\n", "from matcha.utils.model import denormalize\n", "from matcha.utils.utils import get_user_data_dir, intersperse" ] }, { "cell_type": "code", "execution_count": 3, "id": "b1a30306-588c-4f22-8d9b-e2676880b0e5", "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "# autoreload picks up edits to imported modules, so there is no need to restart the kernel" ] }, { "cell_type": "code", "execution_count": 4, "id": "a312856b-01a9-4d75-a4c8-4666dffa0692", "metadata": {}, "outputs": [], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "markdown", "id": "88f3b3c3-d014-443b-84eb-e143cdec3e21", "metadata": {}, "source": [ "## Filepaths" ] }, { "cell_type": "code", "execution_count": 5, "id": "7640a4c1-44ce-447c-a8ff-45012fb7bddd", "metadata": {}, "outputs": [], "source": [ "MATCHA_CHECKPOINT = get_user_data_dir() / \"matcha_ljspeech.ckpt\"\n", "HIFIGAN_CHECKPOINT = get_user_data_dir() / \"hifigan_T2_v1\"\n", "OUTPUT_FOLDER = \"synth_output\"" ] }, { "cell_type": "markdown", "id": "6477a3a9-71f2-4d2f-bb86-bdf3e31c2461", "metadata": {}, "source": [ "## Load Matcha-TTS" ] }, { "cell_type": "code", "execution_count": 6, "id": "26a16230-04ba-4825-a844-2fb5ab945e24", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model loaded! 
Parameter count: 18,204,193\n" ] } ], "source": [ "def load_model(checkpoint_path):\n", " model = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device)\n", " model.eval()\n", " return model\n", "count_params = lambda x: f\"{sum(p.numel() for p in x.parameters()):,}\"\n", "\n", "\n", "model = load_model(MATCHA_CHECKPOINT)\n", "print(f\"Model loaded! Parameter count: {count_params(model)}\")" ] }, { "cell_type": "markdown", "id": "3077b84b-e3b6-42e1-a84b-2f7084b13f92", "metadata": {}, "source": [ "## Load HiFi-GAN (Vocoder)" ] }, { "cell_type": "code", "execution_count": 7, "id": "f6b68184-968d-4868-9029-f0c40e9e68af", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Removing weight norm...\n" ] } ], "source": [ "def load_vocoder(checkpoint_path):\n", " h = AttrDict(v1)\n", " hifigan = HiFiGAN(h).to(device)\n", " hifigan.load_state_dict(torch.load(checkpoint_path, map_location=device)['generator'])\n", " _ = hifigan.eval()\n", " hifigan.remove_weight_norm()\n", " return hifigan\n", "\n", "vocoder = load_vocoder(HIFIGAN_CHECKPOINT)\n", "denoiser = Denoiser(vocoder, mode='zeros')" ] }, { "cell_type": "markdown", "id": "4cbc2ba0-09ff-40e2-9e60-6b77b534f9fb", "metadata": {}, "source": [ "### Helper functions to synthesise" ] }, { "cell_type": "code", "execution_count": 8, "id": "880a1879-24fd-4757-849c-850339120796", "metadata": {}, "outputs": [], "source": [ "@torch.inference_mode()\n", "def process_text(text: str):\n", " x = torch.tensor(intersperse(text_to_sequence(text, ['english_cleaners2']), 0),dtype=torch.long, device=device)[None]\n", " x_lengths = torch.tensor([x.shape[-1]],dtype=torch.long, device=device)\n", " x_phones = sequence_to_text(x.squeeze(0).tolist())\n", " return {\n", " 'x_orig': text,\n", " 'x': x,\n", " 'x_lengths': x_lengths,\n", " 'x_phones': x_phones\n", " }\n", "\n", "\n", "@torch.inference_mode()\n", "def synthesise(text, spks=None):\n", " text_processed = process_text(text)\n", " start_t = 
dt.datetime.now()\n", " output = model.synthesise(\n", " text_processed['x'],\n", " text_processed['x_lengths'],\n", " n_timesteps=n_timesteps,\n", " temperature=temperature,\n", " spks=spks,\n", " length_scale=length_scale\n", " )\n", " # Merge the start time and the processed text into the output dict\n", " output.update({'start_t': start_t, **text_processed})\n", " return output\n", "\n", "@torch.inference_mode()\n", "def to_waveform(mel, vocoder):\n", " audio = vocoder(mel).clamp(-1, 1)\n", " # Denoise, then return a 1-D waveform on the CPU\n", " audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()\n", " return audio\n", "\n", "def save_to_folder(filename: str, output: dict, folder: str):\n", " folder = Path(folder)\n", " folder.mkdir(exist_ok=True, parents=True)\n", " np.save(folder / f'{filename}', output['mel'].cpu().numpy())\n", " sf.write(folder / f'{filename}.wav', output['waveform'], 22050, 'PCM_24')" ] }, { "cell_type": "markdown", "id": "78f857e3-2ef7-4c86-b776-596c4d3cf875", "metadata": {}, "source": [ "## Setup text to synthesise" ] }, { "cell_type": "code", "execution_count": 9, "id": "2e0a9acd-0845-4192-ba09-b9683e28a3ac", "metadata": {}, "outputs": [], "source": [ "texts = [\n", " \"The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.\"\n", "]" ] }, { "cell_type": "markdown", "id": "a9da9e2d-99b9-4c6f-8a08-c828e2cba121", "metadata": {}, "source": [ "### Hyperparameters" ] }, { "cell_type": "code", "execution_count": 10, "id": "f0d216e5-4895-4da8-9d24-9e61021d2556", "metadata": {}, "outputs": [], "source": [ "## Number of ODE Solver steps\n", "n_timesteps = 10\n", "\n", "## Changes to the speaking rate\n", "length_scale = 1.0\n", "\n", "## Sampling temperature\n", "temperature = 0.667" ] }, { "cell_type": "markdown", "id": "b93aac89-c7f8-4975-8510-4e763c9689f4", "metadata": {}, "source": [ "## Synthesis" ] }, { "cell_type": "code", "execution_count": 15, "id": "5a227963-aa12-43b9-a706-1168b6fc0ba5", 
"metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8342d12401c54017b0e19b8d293a06bf", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1 [00:00\n", " \n", " Your browser does not support the audio element.\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Number of ODE steps: 10\n", "Mean RTF:\t\t\t\t0.017228 ± 0.000000\n", "Mean RTF Waveform (incl. vocoder):\t0.021445 ± 0.000000\n" ] } ], "source": [ "outputs, rtfs = [], []\n", "rtfs_w = []\n", "for i, text in enumerate(tqdm(texts)):\n", " output = synthesise(text) #, torch.tensor([15], device=device, dtype=torch.long).unsqueeze(0))\n", " output['waveform'] = to_waveform(output['mel'], vocoder)\n", "\n", " # Compute Real Time Factor (RTF) with HiFi-GAN\n", " t = (dt.datetime.now() - output['start_t']).total_seconds()\n", " rtf_w = t * 22050 / (output['waveform'].shape[-1])\n", "\n", " ## Pretty print\n", " print(f\"{'*' * 53}\")\n", " print(f\"Input text - {i}\")\n", " print(f\"{'-' * 53}\")\n", " print(output['x_orig'])\n", " print(f\"{'*' * 53}\")\n", " print(f\"Phonetised text - {i}\")\n", " print(f\"{'-' * 53}\")\n", " print(output['x_phones'])\n", " print(f\"{'*' * 53}\")\n", " print(f\"RTF:\\t\\t{output['rtf']:.6f}\")\n", " print(f\"RTF Waveform:\\t{rtf_w:.6f}\")\n", " rtfs.append(output['rtf'])\n", " rtfs_w.append(rtf_w)\n", "\n", " ## Display the synthesised waveform\n", " ipd.display(ipd.Audio(output['waveform'], rate=22050))\n", "\n", " ## Save the generated waveform\n", " save_to_folder(i, output, OUTPUT_FOLDER)\n", "\n", "print(f\"Number of ODE steps: {n_timesteps}\")\n", "print(f\"Mean RTF:\\t\\t\\t\\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}\")\n", "print(f\"Mean RTF Waveform (incl. 
vocoder):\\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "e3e85c3f-1623-4647-b40c-fa96907656fc", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }