# Hugging Face Space — status: Running
import os | |
import sys | |
import json | |
import time | |
from os.path import getsize | |
from pathlib import Path | |
from importlib.metadata import version, PackageNotFoundError | |
import gradio as gr | |
from gradio.utils import is_zero_gpu_space | |
try: | |
import spaces | |
except ImportError: | |
print("ZeroGPU is not available, skipping...") | |
import torch | |
import torchaudio | |
from huggingface_hub import hf_hub_download | |
# RAD-TTS code | |
from radtts import RADTTS | |
from data import TextProcessor | |
from common import update_params | |
from torch_env import device | |
# Vocoder | |
from vocos import Vocos | |
def download_file_from_repo(
    repo_id: str,
    filename: str,
    local_dir: str = ".",
    repo_type: str = "model",
) -> str:
    """Fetch a single file from a Hugging Face Hub repository.

    ``local_dir`` is created when it does not exist yet, then the request is
    delegated to ``hf_hub_download`` with ``cache_dir=None`` and
    ``force_download=False``.

    Args:
        repo_id: Repository identifier forwarded to ``hf_hub_download``.
        filename: Path of the file inside the repository.
        local_dir: Directory forwarded to ``hf_hub_download`` as ``local_dir``.
        repo_type: Repository type forwarded to ``hf_hub_download``.

    Returns:
        The path returned by ``hf_hub_download``.

    Raises:
        Exception: Any failure while creating the directory or downloading is
            re-raised as a generic ``Exception`` chained to the original error.
    """
    hub_kwargs = {
        "repo_id": repo_id,
        "filename": filename,
        "local_dir": local_dir,
        "cache_dir": None,
        "force_download": False,
        "repo_type": repo_type,
    }
    try:
        os.makedirs(local_dir, exist_ok=True)
        downloaded_path = hf_hub_download(**hub_kwargs)
    except Exception as err:
        raise Exception(f"An error occurred during download: {err}") from err
    else:
        return downloaded_path
# Fetch the RAD-TTS++ checkpoint into ./models/ at import time; the model
# loading code further down reads it from that directory.
download_file_from_repo(
    repo_id="Yehor/radtts-uk",
    filename="radtts-pp-dap-model/model_dap_84000_state.pt",
    local_dir="./models/",
)
# Record the installed version of the `spaces` package (shown later in the
# "Libraries" section) and log whether the package is present at all.
try:
    spaces_version = version("spaces")
except PackageNotFoundError:
    spaces_version = "N/A"
    print("ZeroGPU is not available, skipping...")
else:
    print("ZeroGPU is available, changing inference call.")

# Whether the inference function is wrapped for ZeroGPU is decided by Gradio's
# own check, independently of the package probe above.
use_zero_gpu = is_zero_gpu_space()
# Init the model
params = []

# Load the config from config.json in the working directory and let
# update_params() apply the (empty) list of overrides to it.
with open("config.json") as config_file:
    config = json.load(config_file)
update_params(config, params)

data_config, model_config = config["data_config"], config["model_config"]
# Load vocoder: both the config and the weights come from the same Hub repo.
_vocos_repo_id = "patriotyk/vocos-mel-hifigan-compat-44100khz"
vocos_config = hf_hub_download(_vocos_repo_id, "config.yaml")
vocos_model = hf_hub_download(_vocos_repo_id, "pytorch_model.bin")
vocos_model_path = Path(vocos_model)

# NOTE(review): torch.load unpickles a downloaded file; if this checkpoint is a
# plain state dict, passing weights_only=True would be safer — confirm.
state_dict = torch.load(vocos_model_path, map_location="cpu")

vocos = Vocos.from_hparams(vocos_config).to(device)
vocos.load_state_dict(state_dict, strict=True)
vocos.eval()
# Load RAD-TTS
radtts = RADTTS(**model_config).to(device)
radtts.enable_inverse_cache()  # cache inverse matrix for 1x1 invertible convs

radtts_model_path = Path("models/radtts-pp-dap-model/model_dap_84000_state.pt")
# NOTE(review): torch.load unpickles the downloaded checkpoint; only the
# "state_dict" entry of it is used below.
checkpoint_dict = torch.load(radtts_model_path, map_location="cpu")
state_dict = checkpoint_dict["state_dict"]
# strict=False: missing/unexpected keys in the checkpoint are tolerated.
radtts.load_state_dict(state_dict, strict=False)
radtts.eval()


def _format_param_count(model):
    """Return the total number of elements in ``model.parameters()`` with thousands separators."""
    total = sum(param.numel() for param in model.parameters())
    return f"{total:,}"


radtts_params = _format_param_count(radtts)
vocos_params = _format_param_count(vocos)

print(f"Loaded checkpoint (RAD-TTS++), number of parameters: {radtts_params}")
print(f"Loaded checkpoint (Vocos), number of parameters: {vocos_params}")
# Build the text processor from the data config: the training file list is the
# positional argument, every other entry except the two file lists is passed
# through as a keyword argument.
_file_list_keys = ("training_files", "validation_files")
_text_processor_kwargs = {
    key: value for key, value in data_config.items() if key not in _file_list_keys
}
text_processor = TextProcessor(
    data_config["training_files"],
    **_text_processor_kwargs,
)
# Config
concurrency_limit = 5

title = "RAD-TTS++ Ukrainian"

# https://www.tablesgenerator.com/markdown_tables
authors_table = """
## Authors
Follow them on social networks and **contact** if you need any help or have any questions:
| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram |
| https://x.com/yehor_smoliakov at X |
| https://github.com/egorsmkv at GitHub |
| https://huggingface.co/Yehor at Hugging Face |
| or use egorsmkv@gmail.com |
""".strip()

# Markdown shown at the top of the page.
description_head = f"""
# {title}
## Overview
Type your text in Ukrainian and select a voice to synthesize speech using [the RAD-TTS++ model](https://huggingface.co/Yehor/radtts-uk) and [Vocos](https://huggingface.co/patriotyk/vocos-mel-hifigan-compat-44100khz) with 44100 Hz.
""".strip()

# Markdown shown at the bottom of the page: just the authors table.
description_foot = authors_table.strip()

# Checkpoint sizes on disk, in megabytes (bytes / 1e6).
_radtts_size_mb = getsize(radtts_model_path) / 1e6
_vocos_size_mb = getsize(vocos_model_path) / 1e6

tech_env = f"""
#### Environment
- Python: {sys.version}
- Torch device: {device}
#### Models
##### Acoustic model (Text-to-MEL)
- Name: RAD-TTS++ (DAP)
- Parameters: {radtts_params}
- File size: {_radtts_size_mb:.2f} MB
##### Vocoder (MEL-to-WAVE)
- Name: Vocos
- Parameters: {vocos_params}
- File size: {_vocos_size_mb:.2f} MB
""".strip()

# Installed versions of the libraries listed in the "Libraries" section; the
# `spaces` entry reuses the value probed at startup ("N/A" when not installed).
_library_versions = {
    "vocos": version("vocos"),
    "gradio": version("gradio"),
    "huggingface_hub": version("huggingface_hub"),
    "spaces": spaces_version,
    "torch": version("torch"),
    "torchaudio": version("torchaudio"),
    "scipy": version("scipy"),
    "numba": version("numba"),
    "librosa": version("librosa"),
}
tech_libraries = "\n".join(
    [
        "#### Libraries",
        *(f"- {name}: {installed}" for name, installed in _library_versions.items()),
    ]
).strip()
# Lower-case voice name -> speaker index used to build the speaker tensor.
voices = {name: index for index, name in enumerate(("lada", "mykyta", "tetiana"))}

# Example inputs for the UI: each row is [text, voice label as shown in the radio].
_example_texts = (
    "Прокинувся ґазда вранці. Пішов, вичистив з-під коня, вичистив з-під бика, вичистив з-під овечок, вибрав молодняк, відніс його набік.",
    "Пішов взяв сіна, дав корові. Пішов взяв сіна, дав бикові. Ячміню коняці насипав. Зайшов почистив корову, зайшов почистив бика, зайшов почистив коня, за яйця його мацнув.",
    "Кінь ногою здригнув, на хазяїна ласкавим оком подивився. Тоді дядько пішов відкрив курей, гусей, качок, повиносив їм зерна, огірків нарізаних, нагодував. Коли чує – з хати дружина кличе. Зайшов. Дітки повмивані, сидять за столом, всі чекають тата. Взяв він ложку, перехрестив дітей, перехрестив лоба, почали снідати. Поснідали, він дістав пряників, роздав дітям. Діти зібралися, пішли в школу. Дядько вийшов, сів на призьбі, взяв сапку, почав мантачити. Мантачив-мантачив, коли – жінка виходить. Він їй ту сапку дає, ласкаво за сраку вщипнув, жінка до нього лагідно всміхнулася, пішла на город – сапати. Коли – йде пастух і товар кличе в череду. Повідмикав дядько овечок, коровку, бика, коня, все відпустив. Сів попри хати, дістав табАку, відірвав шмат газети, насипав, наслинив собі гарну таку цигарку. Благодать божа – і сонечко вже здійнялося над деревами. Дядько встромив цигарку в рота, дістав сірники, тільки чиркати – коли раптом з хати: Доброе утро! Московское время – шесть часов утра! Витяг дядько цигарку с рота, сплюнув набік, і сам собі каже: Ана маєш. Прокинулись, бляді!",
)
_example_voice_labels = ("Mykyta", "Lada", "Tetiana")

examples = [
    [sample_text, voice_label]
    for sample_text, voice_label in zip(_example_texts, _example_voice_labels)
]
def inference(
    text,
    voice,
    n_takes,
    use_latest_take,
    token_dur_scaling,
    f0_mean,
    f0_std,
    energy_mean,
    energy_std,
    sigma_decoder,
    sigma_token_duration,
    sigma_f0,
    sigma_energy,
):
    """Synthesize speech for ``text`` and return the audio plus a timing summary.

    The text is encoded with ``text_processor``, ``radtts.infer`` is run
    ``n_takes`` times to produce mel spectrograms, ``vocos.decode`` turns them
    into waveforms, and the result is written to a per-request WAV file.

    Args:
        text: Text to synthesize; must contain non-whitespace characters.
        voice: Voice label; its lower-cased form must be a key of ``voices``.
        n_takes: Number of ``radtts.infer`` runs; coerced to an int >= 1.
        use_latest_take: When true only the last take is vocoded and returned,
            otherwise all takes are concatenated along dimension 1.
        token_dur_scaling: Forwarded positionally to ``radtts.infer``.
        f0_mean, f0_std, energy_mean, energy_std: Forwarded to ``radtts.infer``
            under the same keyword names.
        sigma_decoder, sigma_token_duration, sigma_f0, sigma_energy: Forwarded
            positionally to ``radtts.infer``.

    Returns:
        A two-item list: a ``gr.Audio`` pointing at the generated WAV file and
        a string with real-time factor, timings, and speech rate.

    Raises:
        gr.Error: If the text is empty, the voice is unknown, or no audio
            samples were produced.
    """
    import tempfile  # local import: only this function needs it

    if not text or not text.strip():
        raise gr.Error("Please paste your text.")

    request = {
        "text": text,
        "voice": voice,
        "n_takes": n_takes,
        "use_latest_take": use_latest_take,
        "token_dur_scaling": token_dur_scaling,
        "f0_mean": f0_mean,
        "f0_std": f0_std,
        "energy_mean": energy_mean,
        "energy_std": energy_std,
        "sigma_decoder": sigma_decoder,
        "sigma_token_duration": sigma_token_duration,
        "sigma_f0": sigma_f0,
        "sigma_energy": sigma_energy,
    }
    print(json.dumps(request, indent=2))

    speaker = voice.lower() if voice else ""
    if speaker not in voices:
        raise gr.Error("Please select a voice.")

    # A number field may deliver a float (or None when cleared); range() needs
    # an int, and at least one take is required to have something to return.
    n_takes = max(1, int(n_takes or 1))

    sample_rate = 44_100

    tensor_text = torch.LongTensor(text_processor.tp.encode_text(text)).to(device)

    # The same speaker is used for the voice, the text and the attributes, so
    # one tensor serves all three arguments.
    speaker_id = torch.LongTensor([voices[speaker]]).to(device)

    inference_start = time.time()

    mels = []
    for n_take in range(n_takes):
        gr.Info(f"Inferencing take {n_take + 1}", duration=1)

        with torch.autocast(device, enabled=False), torch.inference_mode():
            outputs = radtts.infer(
                speaker_id,
                tensor_text[None],
                sigma_decoder,
                sigma_token_duration,
                sigma_f0,
                sigma_energy,
                token_dur_scaling,
                token_duration_max=100,
                speaker_id_text=speaker_id,
                speaker_id_attributes=speaker_id,
                f0_mean=f0_mean,
                f0_std=f0_std,
                energy_mean=energy_mean,
                energy_std=energy_std,
            )

        mels.append(outputs["mel"])

    gr.Info("Synthesized MEL spectrograms, converting to WAVE.", duration=0.5)

    if use_latest_take:
        # Only the last take is returned, so only that one is vocoded.
        wav_gen = vocos.decode(mels[-1])
    else:
        # Concatenate all the generated wavs
        wav_gen = torch.cat([vocos.decode(mel) for mel in mels], dim=1)

    duration = wav_gen.shape[1] / sample_rate
    if duration <= 0:
        raise gr.Error("The model produced no audio for this text.")

    # Each request gets its own file: a shared "audio.wav" would be overwritten
    # by concurrent requests. The file is left on disk for Gradio to serve.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as wav_file:
        audio_path = wav_file.name
    torchaudio.save(audio_path, wav_gen.cpu(), sample_rate, encoding="PCM_S")

    elapsed_time = time.time() - inference_start
    rtf = elapsed_time / duration
    speed_ratio = duration / elapsed_time
    # split() without arguments ignores repeated whitespace and newlines.
    speech_rate = len(text.split()) / duration

    rtf_value = f"Real-Time Factor: {round(rtf, 4)}, time: {round(elapsed_time, 4)} seconds, audio duration: {round(duration, 4)} seconds. Speed ratio: {round(speed_ratio, 2)}x. Speech rate: {round(speech_rate, 4)} words-per-second."

    gr.Success("Finished!", duration=0.5)

    return [gr.Audio(audio_path), rtf_value]
# On a ZeroGPU space the inference function is wrapped with spaces.GPU;
# everywhere else it is used directly.
inference_func = spaces.GPU(inference) if use_zero_gpu else inference
demo = gr.Blocks(
    title=title,
    analytics_enabled=False,
    theme=gr.themes.Base(),
)


def _unit_range_number(label, value):
    """Create a ``gr.Number`` limited to 0.0–1.0 with a 0.1 step."""
    return gr.Number(
        label=label,
        value=value,
        minimum=0.0,
        maximum=1.0,
        step=0.1,
    )


# NOTE(review): the nesting of the layout containers below was reconstructed
# from a source whose indentation was lost — confirm against the rendered page.
with demo:
    gr.Markdown(description_head)

    gr.Markdown("## Usage")

    # Output area: synthesized audio and the timing summary.
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="Synthesized speech")
            rtf = gr.Markdown(
                label="Real-Time Factor",
                value="Here you will see how fast the model and the speaker is.",
            )

    # Input area: text, voice and the advanced synthesis options.
    with gr.Row():
        with gr.Column():
            text = gr.Text(
                label="Text",
                value="Сл+ава Укра+їні! — українське вітання, національне гасло.",
            )
            voice = gr.Radio(
                label="Voice",
                choices=[name.title() for name in voices],
                value="Tetiana",
            )

            with gr.Accordion("Advanced options", open=False):
                gr.Markdown("You can change the voice, speed, and other parameters.")

                with gr.Column():
                    n_takes = gr.Number(
                        label="Number of takes",
                        value=1,
                        minimum=1,
                        maximum=10,
                        step=1,
                        # Deliver an int so the handler can use it as a count.
                        precision=0,
                    )
                    use_latest_take = gr.Checkbox(
                        label="Use the latest take",
                        value=False,
                    )
                    token_dur_scaling = gr.Number(
                        label="Token duration scaling",
                        value=1.0,
                        minimum=0.0,
                        maximum=10,
                        step=0.1,
                    )

                    with gr.Row():
                        f0_mean = _unit_range_number("F0 mean", 0)
                        f0_std = _unit_range_number("F0 std", 0)
                        energy_mean = _unit_range_number("Energy mean", 0)
                        energy_std = _unit_range_number("Energy std", 0)

                    with gr.Row():
                        sigma_decoder = _unit_range_number(
                            "Sampling sigma for decoder", 0.8
                        )
                        sigma_token_duration = _unit_range_number(
                            "Sampling sigma for duration", 0.666
                        )
                        sigma_f0 = _unit_range_number("Sampling sigma for F0", 1.0)
                        sigma_energy = _unit_range_number(
                            "Sampling sigma for energy avg", 1.0
                        )

            # The input order must match the positional parameters of the
            # inference function.
            gr.Button("Run").click(
                inference_func,
                concurrency_limit=concurrency_limit,
                inputs=[
                    text,
                    voice,
                    n_takes,
                    use_latest_take,
                    token_dur_scaling,
                    f0_mean,
                    f0_std,
                    energy_mean,
                    energy_std,
                    sigma_decoder,
                    sigma_token_duration,
                    sigma_f0,
                    sigma_energy,
                ],
                outputs=[audio, rtf],
            )

    with gr.Row():
        gr.Examples(
            label="Choose an example",
            inputs=[
                text,
                voice,
            ],
            examples=examples,
        )

    gr.Markdown(description_foot)

    gr.Markdown("### Gradio app uses:")
    gr.Markdown(tech_env)
    gr.Markdown(tech_libraries)
# Script entry point: enable the request queue, then start the web server.
if __name__ == "__main__":
    demo.queue().launch()