File size: 2,369 Bytes
1180f3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import functools
import os
import sys

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

# Setup TTS env
# Add the local "vits" directory to the module search path (once). The entry
# is a relative path, so it resolves against the current working directory.
# NOTE(review): presumably required because modules inside vits import their
# siblings by bare name (e.g. `import commons`) -- confirm against the checkout.
if "vits" not in sys.path:
    sys.path.append("vits")

from vits import commons, utils
from vits.models import SynthesizerTrn

class TextMapper(object):
    def __init__(self, vocab_file):
        self.symbols = [x.strip() for x in open(vocab_file, encoding="utf-8").readlines()]
        self.SPACE_ID = self.symbols.index(" ")
        self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}

    def text_to_sequence(self, text, cleaner_names):
        sequence = [self._symbol_to_id[symbol] for symbol in text.strip()]
        return sequence

    def get_text(self, text, hps):
        text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
        if hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        return torch.LongTensor(text_norm)

    def filter_oov(self, text, lang=None):
        val_chars = self._symbol_to_id
        return "".join(filter(lambda x: x in val_chars, text))

# Only one language is served, so a single cached entry is enough.
@functools.lru_cache(maxsize=1)
def _load_tts_model(lang_code):
    """Fetch the MMS-TTS assets for ``lang_code`` and build the synthesizer.

    The result is memoized so that repeated requests reuse the already loaded
    model instead of re-resolving the hub files and reloading the checkpoint.
    A failed load raises and is therefore not cached; the next call retries.

    Args:
        lang_code: Sub-directory name under ``models/`` in the
            ``facebook/mms-tts`` repository.

    Returns:
        Tuple ``(net_g, hps, text_mapper, device)``: the model in eval mode,
        the parsed hyper-parameters, the ``TextMapper`` for the vocabulary,
        and the ``torch.device`` the model was moved to.
    """
    subfolder = f"models/{lang_code}"
    vocab_file, config_file, g_pth = (
        hf_hub_download(repo_id="facebook/mms-tts", filename=filename, subfolder=subfolder)
        for filename in ("vocab.txt", "config.json", "G_100000.pth")
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    hps = utils.get_hparams_from_file(config_file)
    text_mapper = TextMapper(vocab_file)
    net_g = SynthesizerTrn(
        len(text_mapper.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model,
    )
    net_g.to(device)
    net_g.eval()
    utils.load_checkpoint(g_pth, net_g, None)
    return net_g, hps, text_mapper, device


def synthesize(text, speed):
    """Synthesize speech for ``text`` with the "fao" MMS-TTS model.

    The text is lower-cased and stripped of characters that are not in the
    model vocabulary before being fed to the network.

    Args:
        text: Input string.
        speed: Speed factor; ``None`` means 1.0. The model is called with
            ``length_scale = 1.0 / speed``.

    Returns:
        Tuple ``(audio_update, text)`` where ``audio_update`` is a
        ``gr.Audio.update`` carrying ``(sampling_rate, waveform)`` and
        ``text`` is the lower-cased, filtered string that was synthesized.

    Raises:
        ZeroDivisionError: If ``speed`` is 0.
    """
    if speed is None:
        speed = 1.0

    lang_code = "fao"
    net_g, hps, text_mapper, device = _load_tts_model(lang_code)

    text = text.lower()
    text = text_mapper.filter_oov(text)
    stn_tst = text_mapper.get_text(text, hps)
    with torch.no_grad():
        # Add a batch dimension of 1 and pass the sequence length alongside.
        x_tst = stn_tst.unsqueeze(0).to(device)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
        # [0] takes the first output of infer(); [0, 0] then indexes its first
        # two axes before converting to a float numpy array on the CPU.
        hyp = net_g.infer(
            x_tst,
            x_tst_lengths,
            noise_scale=0.667,
            noise_scale_w=0.8,
            length_scale=1.0 / speed,
        )[0][0, 0].cpu().float().numpy()

    # NOTE(review): component-level ``update()`` is believed to be removed in
    # Gradio 4 -- confirm against the pinned Gradio version before upgrading.
    return gr.Audio.update(value=(hps.data.sampling_rate, hyp)), text