|
import os |
|
import torch |
|
import sys |
|
import gradio as gr |
|
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
if "vits" not in sys.path: |
|
sys.path.append("vits") |
|
|
|
from vits import commons, utils |
|
from vits.models import SynthesizerTrn |
|
|
|
class TextMapper(object): |
|
def __init__(self, vocab_file): |
|
self.symbols = [x.strip() for x in open(vocab_file, encoding="utf-8").readlines()] |
|
self.SPACE_ID = self.symbols.index(" ") |
|
self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)} |
|
|
|
def text_to_sequence(self, text, cleaner_names): |
|
sequence = [self._symbol_to_id[symbol] for symbol in text.strip()] |
|
return sequence |
|
|
|
def get_text(self, text, hps): |
|
text_norm = self.text_to_sequence(text, hps.data.text_cleaners) |
|
if hps.data.add_blank: |
|
text_norm = commons.intersperse(text_norm, 0) |
|
return torch.LongTensor(text_norm) |
|
|
|
def filter_oov(self, text, lang=None): |
|
val_chars = self._symbol_to_id |
|
return "".join(filter(lambda x: x in val_chars, text)) |
|
|
|
def synthesize(text, speed):
    """Synthesize speech for ``text`` with the MMS-TTS ``fao`` checkpoint.

    The vocabulary, config and generator checkpoint are obtained through
    ``hf_hub_download`` and the network is rebuilt and its checkpoint
    reloaded on every call.

    Args:
        text: Input text. It is lower-cased and characters missing from the
            model vocabulary are removed before synthesis.
        speed: ``None`` is treated as ``1.0``. The value is passed to the
            model as ``length_scale = 1.0 / speed``.

    Returns:
        A 2-tuple: the result of ``gr.Audio.update`` whose value is
        ``(sampling_rate, samples)`` with ``samples`` a NumPy array, and the
        filtered text that was actually fed to the model.
    """
    rate = 1.0 if speed is None else speed

    # "fao" is presumably the ISO 639-3 code selecting the language model.
    lang_code = "fao"
    repo_id = "facebook/mms-tts"
    model_dir = f"models/{lang_code}"
    vocab_file, config_file, g_pth = [
        hf_hub_download(repo_id=repo_id, filename=name, subfolder=model_dir)
        for name in ("vocab.txt", "config.json", "G_100000.pth")
    ]

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    hps = utils.get_hparams_from_file(config_file)
    text_mapper = TextMapper(vocab_file)

    # Build the network, move it to the device, then load the weights.
    n_symbols = len(text_mapper.symbols)
    spec_channels = hps.data.filter_length // 2 + 1
    segment_frames = hps.train.segment_size // hps.data.hop_length
    net_g = SynthesizerTrn(n_symbols, spec_channels, segment_frames, **hps.model)
    net_g.to(device)
    net_g.eval()
    utils.load_checkpoint(g_pth, net_g, None)

    # Normalise the input so that every remaining character has a symbol id.
    text = text_mapper.filter_oov(text.lower())
    token_ids = text_mapper.get_text(text, hps)

    with torch.no_grad():
        batch = token_ids.unsqueeze(0).to(device)
        batch_lengths = torch.LongTensor([token_ids.size(0)]).to(device)
        outputs = net_g.infer(
            batch,
            batch_lengths,
            noise_scale=0.667,
            noise_scale_w=0.8,
            length_scale=1.0 / rate,
        )
        hyp = outputs[0][0, 0].cpu().float().numpy()

    audio = (hps.data.sampling_rate, hyp)
    return gr.Audio.update(value=audio), text
|
|