import functools
import os
import sys

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

# Setup TTS env
if "vits" not in sys.path:
    sys.path.append("vits")

from vits import commons, utils
from vits.models import SynthesizerTrn

# Hub repository that hosts the per-language vocab/config/checkpoint files.
_REPO_ID = "facebook/mms-tts"


class TextMapper(object):
    """Map text to the integer symbol ids listed in a vocabulary file.

    The vocabulary file holds one symbol per line; a symbol's id is its
    zero-based line index.
    """

    def __init__(self, vocab_file):
        """Load the symbol table from *vocab_file* (UTF-8, one symbol per line).

        Raises:
            ValueError: if the vocabulary contains no single-space symbol.
        """
        # Drop only the line terminator. A full ``strip()`` would turn the
        # " " symbol into "", so the SPACE_ID lookup below could never succeed.
        with open(vocab_file, encoding="utf-8") as vocab:
            self.symbols = [line.rstrip("\n") for line in vocab]
        self.SPACE_ID = self.symbols.index(" ")
        self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}

    def text_to_sequence(self, text, cleaner_names):
        """Return the list of symbol ids for *text*.

        Leading and trailing whitespace is stripped first. *cleaner_names* is
        accepted for interface compatibility but is not used.

        Raises:
            KeyError: if *text* contains a character that is not in the
                vocabulary (call ``filter_oov`` first to avoid this).
        """
        return [self._symbol_to_id[symbol] for symbol in text.strip()]

    def get_text(self, text, hps):
        """Encode *text* as a ``torch.LongTensor`` of symbol ids.

        When ``hps.data.add_blank`` is truthy, the id sequence is passed
        through ``commons.intersperse(..., 0)`` before conversion
        (presumably inserting blank id 0 between symbols -- confirm against
        the vits ``commons`` module).
        """
        text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
        if hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        return torch.LongTensor(text_norm)

    def filter_oov(self, text, lang=None):
        """Return *text* with every out-of-vocabulary character removed.

        *lang* is accepted for interface compatibility but is not used.
        """
        known = self._symbol_to_id
        return "".join(ch for ch in text if ch in known)


@functools.lru_cache(maxsize=1)
def _load_model(lang_code):
    """Fetch and build the TTS model for *lang_code*, once per process.

    Returns a ``(hps, text_mapper, net_g, device)`` tuple. The result is
    cached so that repeated synthesis requests do not rebuild the network
    and reload the checkpoint each time.
    """
    subfolder = f"models/{lang_code}"
    vocab_file = hf_hub_download(
        repo_id=_REPO_ID, filename="vocab.txt", subfolder=subfolder
    )
    config_file = hf_hub_download(
        repo_id=_REPO_ID, filename="config.json", subfolder=subfolder
    )
    g_pth = hf_hub_download(
        repo_id=_REPO_ID, filename="G_100000.pth", subfolder=subfolder
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    hps = utils.get_hparams_from_file(config_file)
    text_mapper = TextMapper(vocab_file)

    net_g = SynthesizerTrn(
        len(text_mapper.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model,
    )
    net_g.to(device)
    net_g.eval()
    utils.load_checkpoint(g_pth, net_g, None)

    return hps, text_mapper, net_g, device


def synthesize(text, speed):
    """Synthesize speech for *text* with the "fao" MMS-TTS model.

    Args:
        text: Input text. It is lower-cased and characters missing from the
            model vocabulary are dropped before synthesis.
        speed: Speaking-rate multiplier; ``None`` means 1.0. The model is
            called with ``length_scale=1.0 / speed``.

    Returns:
        A ``(audio_update, text)`` tuple: a Gradio audio update carrying
        ``(sampling_rate, waveform)`` and the filtered text that was
        actually synthesized.

    Raises:
        ValueError: if *speed* is zero or negative.
    """
    if speed is None:
        speed = 1.0
    # Fail early with a clear message instead of a late ZeroDivisionError
    # (speed == 0) or a negative length scale handed to the model.
    if speed <= 0:
        raise ValueError(f"speed must be positive, got {speed!r}")

    hps, text_mapper, net_g, device = _load_model("fao")

    text = text_mapper.filter_oov(text.lower())
    stn_tst = text_mapper.get_text(text, hps)

    with torch.no_grad():
        x_tst = stn_tst.unsqueeze(0).to(device)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
        # [0] selects the first value returned by infer(); [0, 0] then takes
        # the first entry along its two leading axes (presumably batch and
        # channel -- confirm against SynthesizerTrn.infer).
        hyp = (
            net_g.infer(
                x_tst,
                x_tst_lengths,
                noise_scale=0.667,
                noise_scale_w=0.8,
                length_scale=1.0 / speed,
            )[0][0, 0]
            .cpu()
            .float()
            .numpy()
        )

    # NOTE(review): gr.Audio.update is the Gradio 3.x update API -- confirm
    # the pinned Gradio version before upgrading.
    return gr.Audio.update(value=(hps.data.sampling_rate, hyp)), text