matxa-alvocat-tts-ca / infer_onnx.py
Baybars's picture
about page template added
fc52d83
raw
history blame
9.91 kB
import numpy as np
import onnxruntime
from text import text_to_sequence, sequence_to_text
import torch
import gradio as gr
import soundfile as sf
import tempfile
import yaml
import json
import os
from huggingface_hub import hf_hub_download
from time import perf_counter
# Speaker used as the initial dropdown value; can be overridden through the
# DEFAULT_SPEAKER_ID environment variable.
DEFAULT_SPEAKER_ID = os.getenv("DEFAULT_SPEAKER_ID", "caf_08106")
def intersperse(lst, item):
    """Return a new list with ``item`` before, between and after the elements of ``lst``.

    The result always has ``2 * len(lst) + 1`` entries, e.g.
    ``intersperse([1, 2], 0) -> [0, 1, 0, 2, 0]``.
    """
    padded = [item]
    for element in lst:
        padded.append(element)
        padded.append(item)
    return padded
def process_text(i: int, text: str, device: torch.device):
    """Convert raw text into the integer inputs expected by the acoustic model.

    The text is cleaned and mapped to symbol ids with the "catalan_cleaners"
    pipeline, and a 0 id is interleaved around every symbol.

    Args:
        i: Index of the utterance; only used in the log line.
        text: Input sentence.
        device: Device on which the intermediate tensors are created.

    Returns:
        A pair ``(x, x_lengths)`` of NumPy int64 arrays: the id sequence with a
        leading batch axis, and a one-element array holding its length.
        (Assumes ``text_to_sequence`` returns a flat list of ints - TODO confirm.)
    """
    print(f"[{i}] - Input text: {text}")
    token_ids = intersperse(text_to_sequence(text, ["catalan_cleaners"]), 0)
    x = torch.tensor(token_ids, dtype=torch.long, device=device)[None]
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device)
    # Log the symbol sequence that will actually be fed to the model.
    x_phones = sequence_to_text(x.squeeze(0).tolist())
    print(x_phones)
    # .cpu() is a no-op on CPU tensors and makes the conversion valid for
    # accelerator devices too (Tensor.numpy() raises on non-CPU tensors).
    return x.cpu().numpy(), x_lengths.cpu().numpy()
# --- Model artefacts -------------------------------------------------------
# The acoustic model (text -> mel) and the Vocos vocoder (mel -> spectrogram),
# plus the vocoder config, are fetched from the Hugging Face Hub.
MODEL_PATH_MATCHA_MEL = hf_hub_download(
    repo_id="BSC-LT/matcha-tts-cat-multispeaker",
    filename="matcha_multispeaker_cat_opset_15_10_steps_2399.onnx",
)
MODEL_PATH_MATCHA = "matcha_hifigan_multispeaker_cat.onnx"
MODEL_PATH_VOCOS = hf_hub_download(
    repo_id="BSC-LT/vocos-mel-22khz-cat",
    filename="mel_spec_22khz_cat.onnx",
)
CONFIG_PATH = hf_hub_download(
    repo_id="BSC-LT/vocos-mel-22khz-cat",
    filename="config.yaml",
)
# Local JSON file mapping speaker names to the ids passed to the model.
SPEAKER_ID_DICT = "spk_to_id.json"

# --- ONNX Runtime sessions (CPU only) --------------------------------------
sess_options = onnxruntime.SessionOptions()
model_matcha_mel = onnxruntime.InferenceSession(
    str(MODEL_PATH_MATCHA_MEL),
    sess_options=sess_options,
    providers=["CPUExecutionProvider"],
)
model_vocos = onnxruntime.InferenceSession(
    str(MODEL_PATH_VOCOS),
    sess_options=sess_options,
    providers=["CPUExecutionProvider"],
)
#model_matcha = onnxruntime.InferenceSession(str(MODEL_PATH_MATCHA), sess_options=sess_options, providers=["CPUExecutionProvider"])

# --- Speaker table ----------------------------------------------------------
# Use a context manager so the file handle is closed deterministically.
with open(SPEAKER_ID_DICT, encoding="utf-8") as speaker_file:
    speaker_id_dict = json.load(speaker_file)
# Alphabetically sorted speaker names, used as the dropdown choices.
speakers = sorted(speaker_id_dict)
def vocos_inference(mel, denoise):
    """Turn a mel spectrogram into a waveform with the Vocos ONNX model.

    The ONNX graph returns three arrays (magnitude and two phase components)
    that are combined into a complex spectrogram; the inverse STFT is then
    computed here with an explicit windowed overlap-add.

    Args:
        mel: Array fed to the "mels" input of the vocoder. Presumably float32
            of shape (batch, n_mels, frames), as produced by the acoustic
            model - TODO confirm.
        denoise: When true, the magnitude the vocoder produces for an all-zero
            mel input, scaled by a small fixed strength, is subtracted from
            the synthesized magnitude.

    Returns:
        A ``torch.Tensor`` of shape (batch, samples) holding the waveform.

    Raises:
        ValueError: If the overlap-add window envelope contains (near-)zero
            values, which would make the normalisation unstable.
    """
    with open(CONFIG_PATH, "r") as f:
        config = yaml.safe_load(f)

    params = config["feature_extractor"]["init_args"]
    n_fft = params["n_fft"]
    hop_length = params["hop_length"]
    win_length = n_fft

    # ONNX inference
    mag, x, y = model_vocos.run(None, {"mels": mel})
    # Complex spectrogram from the vocoder outputs. as_tensor wraps the NumPy
    # result once, so no tensor is re-wrapped later on.
    spectrogram = torch.as_tensor(mag * (x + 1j * y))
    window = torch.hann_window(win_length)

    if denoise:
        # Vocoder bias: what the model emits for an all-zero mel input.
        mel_zero = np.zeros_like(mel, dtype=np.float32)
        mag_bias, x_bias, y_bias = model_vocos.run(None, {"mels": mel_zero})
        spectrogram_bias = torch.as_tensor(mag_bias * (x_bias + 1j * y_bias))

        # Subtract a fraction of the bias magnitude, keep the original phase.
        strength = 0.0025
        mag_spec_denoised = torch.clamp(
            spectrogram.abs() - spectrogram_bias.abs() * strength, min=0.0
        )
        spectrogram = torch.polar(mag_spec_denoised, torch.angle(spectrogram))

    # Inverse STFT
    pad = (win_length - hop_length) // 2
    _, _, n_frames = spectrogram.shape  # also enforces a rank-3 input
    print("Spectrogram synthesized shape", spectrogram.shape)

    # Inverse FFT of every frame, then apply the synthesis window.
    ifft = torch.fft.irfft(spectrogram, n_fft, dim=1, norm="backward")
    ifft = ifft * window[None, :, None]

    # Overlap and add
    output_size = (n_frames - 1) * hop_length + win_length
    # Explicit end index: a plain ``pad:-pad`` slice would be empty for pad == 0.
    end = output_size - pad
    y = torch.nn.functional.fold(
        ifft,
        output_size=(1, output_size),
        kernel_size=(1, win_length),
        stride=(1, hop_length),
    )[:, 0, 0, pad:end]

    # Window envelope: overlap-added squared window, used for normalisation.
    window_sq = window.square().expand(1, n_frames, -1).transpose(1, 2)
    window_envelope = torch.nn.functional.fold(
        window_sq,
        output_size=(1, output_size),
        kernel_size=(1, win_length),
        stride=(1, hop_length),
    ).squeeze()[pad:end]

    # Normalize. An explicit check instead of ``assert`` so it survives ``-O``.
    if not (window_envelope > 1e-11).all():
        raise ValueError("window envelope has near-zero values; cannot normalise")
    return y / window_envelope
def tts(text: str, spk_name: str, temperature: float, length_scale: float, denoise: bool):
    """Synthesize ``text`` and return the path of the generated WAV file.

    Runs the acoustic ONNX model to obtain a mel spectrogram, converts it to
    audio with ``vocos_inference`` and writes the result to a temporary
    24-bit PCM WAV file. Inference timings and the real-time factor are
    printed.

    Args:
        text: Sentence to synthesize.
        spk_name: Key into ``speaker_id_dict`` selecting the voice.
        temperature: First entry of the model's "scales" input.
        length_scale: Second entry of the model's "scales" input.
        denoise: Forwarded to ``vocos_inference``.

    Returns:
        Path of the written ``.wav`` file (the file is not deleted on close).

    Raises:
        KeyError: If ``spk_name`` is not present in ``speaker_id_dict``.
    """
    # Sample rate used both for writing the file and for the RTF computation.
    sample_rate = 22050

    spk_id = speaker_id_dict[spk_name]
    sid = np.array([int(spk_id)]) if spk_id is not None else None
    text_matcha, text_lengths = process_text(0, text, "cpu")

    # MATCHA VOCOS
    inputs = {
        "x": text_matcha,
        "x_lengths": text_lengths,
        "scales": np.array([temperature, length_scale], dtype=np.float32),
        "spks": sid,
    }

    # matcha mel inference
    mel_t0 = perf_counter()
    mel, mel_lengths = model_matcha_mel.run(None, inputs)
    mel_infer_secs = perf_counter() - mel_t0
    print("Matcha Mel inference time", mel_infer_secs)

    # vocos inference
    vocos_t0 = perf_counter()
    wavs_vocos = vocos_inference(mel, denoise)
    vocos_infer_secs = perf_counter() - vocos_t0
    print("Vocos inference time", vocos_infer_secs)

    # delete=False: the file must outlive this function so the UI can read it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False, dir="/home/user/app") as fp_matcha_vocos:
        sf.write(fp_matcha_vocos.name, wavs_vocos.squeeze(0), sample_rate, "PCM_24")

    # Real-time factor: processing time divided by the audio duration.
    rtf = (mel_infer_secs + vocos_infer_secs) / (wavs_vocos.shape[1] / sample_rate)
    print(f"RTF matcha + vocos {rtf}")
    return fp_matcha_vocos.name
## GUI space
# Static texts shown on the demo page. Each of them is passed to gr.Markdown
# further down in this file.

# Centered HTML heading displayed at the top of the page.
title = """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
<div
style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;"
> <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
Natural and efficient TTS in Catalan
</h1> </div>
</div>
"""
# Short Markdown introduction with links to the Vocos and Matcha model cards.
description = """
🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up ODE-based speech synthesis
For vocoders we use [Vocos](https://huggingface.co/BSC-LT/vocos-mel-22khz-cat) trained in a catalan set of ~28 hours.
[Matcha](https://huggingface.co/BSC-LT/matcha-tts-cat-onnx) was trained using openslr69 and festcat datasets
"""
# Markdown/HTML body of the "About" tab: a heading followed by a 3x3 table of
# embedded audio players.
# NOTE(review): the text, the Col1..Col3 headers and the samplelib.com URLs
# look like template placeholders - confirm and replace with real samples.
about = """
## 📄 About
The TTS test about.
## Samples
<table style="font-size:16px">
<col width="205">
<col width="205">
<thead>
<tr>
<td>Col1</td>
<td>Col2</td>
<td>Col3</td>
</tr>
</thead>
<tbody>
<tr>
<td><audio controls="" preload="none" style="width: 200px">audio not supported<source src="https://samplelib.com/lib/preview/mp3/sample-3s.mp3"></audio></td>
<td><audio controls="" preload="none" style="width: 200px">audio not supported<source src="https://samplelib.com/lib/preview/mp3/sample-6s.mp3"></audio></td>
<td><audio controls="" preload="none" style="width: 200px">audio not supported<source src="https://samplelib.com/lib/preview/mp3/sample-9s.mp3"></audio></td>
</tr>
<tr>
<td><audio controls="" preload="none" style="width: 200px">audio not supported<source src="https://samplelib.com/lib/preview/mp3/sample-3s.mp3"></audio></td>
<td><audio controls="" preload="none" style="width: 200px">audio not supported<source src="https://samplelib.com/lib/preview/mp3/sample-6s.mp3"></audio></td>
<td><audio controls="" preload="none" style="width: 200px">audio not supported<source src="https://samplelib.com/lib/preview/mp3/sample-9s.mp3"></audio></td>
</tr>
<tr>
<td><audio controls="" preload="none" style="width: 200px">audio not supported<source src="https://samplelib.com/lib/preview/mp3/sample-3s.mp3"></audio></td>
<td><audio controls="" preload="none" style="width: 200px">audio not supported<source src="https://samplelib.com/lib/preview/mp3/sample-6s.mp3"></audio></td>
<td><audio controls="" preload="none" style="width: 200px">audio not supported<source src="https://samplelib.com/lib/preview/mp3/sample-9s.mp3"></audio></td>
</tr>
</tbody></table>
"""
# Footer credit line displayed below the tabs.
article = "Training and demo by The Language Technologies Unit from Barcelona Supercomputing Center."
# "Demo" tab: a gr.Interface wrapping tts(). The inputs are listed in the same
# order as the parameters of tts(): text, speaker, temperature, length scale,
# denoise flag. The single output is the path of the generated audio file.
vits2_inference = gr.Interface(
    fn=tts,
    inputs=[
        gr.Textbox(
            value="m'ha costat molt desenvolupar una veu, i ara que la tinc no estaré en silenci.",
            max_lines=1,
            label="Input text",
        ),
        gr.Dropdown(
            choices=speakers,
            label="Speaker id",
            value=DEFAULT_SPEAKER_ID,
            info="Models are trained on 47 speakers. You can prompt the model using one of these speaker ids.",
        ),
        gr.Slider(
            0.1,
            2.0,
            value=0.667,
            step=0.01,
            label="Temperature",
            info="Temperature",
        ),
        gr.Slider(
            0.5,
            2.0,
            value=1.0,
            step=0.01,
            label="Length scale",
            info="Controls speech pace, larger values for slower pace and smaller values for faster pace",
        ),
        gr.Checkbox(label="Denoise", info="Removes model bias from vocos", value=True),
    ],
    outputs=[gr.Audio(label="Matcha vocos", interactive=False, type="filepath")],
)
# "About" page component; created before the layout context is opened and
# handed to the tabbed interface below.
about_article = gr.Markdown(about)

# Page layout: heading, description, the two tabs and the footer credit.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.TabbedInterface([vits2_inference, about_article], ["Demo", "About"])
    gr.Markdown(article)

# Queue at most 10 pending requests, then serve on all interfaces, port 7860.
demo.queue(max_size=10)
demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)