import sys
import logging
import os
import json
import torch
import argparse
import commons
import utils
import gradio as gr
from models import SynthesizerTrn
from text.symbols import symbols
from text import cleaned_text_to_sequence, get_bert
from text.cleaner import clean_text

logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("markdown_it").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)

logging.basicConfig(
    level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s"
)

logger = logging.getLogger(__name__)

# Limit text and audio length when running on Hugging Face Spaces.
limitation = os.getenv("SYSTEM") == "spaces"

# Device used by get_text()/infer(). This is selected at module level (not only
# under __main__) so both functions also work when this module is imported.
device = (
    "cuda:0"
    if torch.cuda.is_available()
    else (
        "mps"
        if sys.platform == "darwin" and torch.backends.mps.is_available()
        else "cpu"
    )
)


def get_text(text, hps):
    """Convert text into the tensors that ``infer`` feeds to the synthesizer.

    The text is always processed as Japanese (language code ``"JP"``).

    Args:
        text: Input text.
        hps: Hyper-parameters. If ``hps.data.add_blank`` is set, a 0 token is
            interleaved into the phone/tone/language sequences.

    Returns:
        Tuple ``(bert, ja_bert, phone, tone, language)``:
        ``bert`` is a zero tensor of shape ``(1024, len(phone))``;
        ``ja_bert`` holds the features returned by ``get_bert``;
        ``phone``, ``tone`` and ``language`` are ``torch.LongTensor``s.

    Raises:
        ValueError: If the BERT feature length does not match the phone
            sequence length.
    """
    language_str = "JP"
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

    if hps.data.add_blank:
        phone = commons.intersperse(phone, 0)
        tone = commons.intersperse(tone, 0)
        language = commons.intersperse(language, 0)
        # Keep word2ph aligned with the blank-interleaved phone sequence:
        # every word's phone count doubles and the first word absorbs the
        # extra blank (presumably intersperse yields 2n+1 items -- the length
        # check below enforces the alignment either way).
        for idx, count in enumerate(word2ph):
            word2ph[idx] = count * 2
        word2ph[0] += 1

    ja_bert = get_bert(norm_text, word2ph, language_str, device)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if ja_bert.shape[-1] != len(phone):
        raise ValueError(f"Bert seq len {ja_bert.shape[-1]} != {len(phone)}")

    # The model takes a separate `bert` input; for this JP-only path it is
    # fed zeros of matching length while the real features go in `ja_bert`.
    bert = torch.zeros(1024, len(phone))

    phone = torch.LongTensor(phone)
    tone = torch.LongTensor(tone)
    language = torch.LongTensor(language)
    return bert, ja_bert, phone, tone, language


def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, net_g_ms, hps):
    """Synthesize ``text`` with speaker id ``sid`` using ``net_g_ms``.

    The sampling knobs (``sdp_ratio``, ``noise_scale``, ``noise_scale_w``,
    ``length_scale``) are forwarded unchanged to ``net_g_ms.infer``.

    Returns:
        The ``[0][0, 0]`` slice of the model output as a float32 NumPy array
        on the CPU.
    """
    bert, ja_bert, phones, tones, lang_ids = get_text(text, hps)
    with torch.no_grad():
        # Add a batch dimension of 1 and move everything to the device.
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
        lang_ids = lang_ids.to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        ja_bert = ja_bert.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        del phones
        sid = torch.LongTensor([sid]).to(device)
        audio = (
            net_g_ms.infer(
                x_tst,
                x_tst_lengths,
                sid,
                tones,
                lang_ids,
                bert,
                ja_bert,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0]
            .data.cpu()
            .float()
            .numpy()
        )
        # Drop every device tensor (including ja_bert) before asking the CUDA
        # caching allocator to release its unused blocks.
        del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, sid
        torch.cuda.empty_cache()
        return audio


def create_tts_fn(net_g_ms, hps):
    """Build the Gradio callback bound to one model and its hyper-parameters."""

    def tts_fn(text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale):
        """Synthesize ``text`` for ``speaker``.

        Returns:
            ``(status_message, (sampling_rate, audio))`` on success, or
            ``(error_message, None)`` when the text is rejected.
        """
        logger.info("%s | %s", text, speaker)
        sid = hps.data.spk2id[speaker]
        # Remove newlines, carriage returns and spaces entirely.
        text = text.replace("\n", "").replace("\r", "").replace(" ", "")
        if not text:
            return "Error: Text is empty", None
        if limitation:
            max_len = 100
            if len(text) > max_len:
                return "Error: Text is too long", None
        audio = infer(
            text,
            sdp_ratio=sdp_ratio,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
            sid=sid,
            net_g_ms=net_g_ms,
            hps=hps,
        )
        return "Success", (hps.data.sampling_rate, audio)

    return tts_fn


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", default=False, help="make link public", action="store_true")
    parser.add_argument("-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log")
    args = parser.parse_args()
    if args.debug:
        logger.info("Enable DEBUG-LEVEL log")
        # logging.basicConfig() is a no-op once the root logger has handlers
        # (basicConfig already ran at import time), so set the level directly.
        logging.getLogger().setLevel(logging.DEBUG)
    models = []
    with open("pretrained_models/info.json", "r", encoding="utf-8") as f:
        models_info = json.load(f)
    for i, info in models_info.items():
        if not info['enable']:
            continue
        name = info['name']
        title = info['title']
        example = info['example']
        # NOTE(review): the config path uses `name` but the checkpoint path
        # uses the info.json key `i`; these only agree if the keys equal the
        # model names -- confirm against pretrained_models/.
        hps = utils.get_hparams_from_file(f"./pretrained_models/{name}/config.json")
        net_g_ms = SynthesizerTrn(
            len(symbols),
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            **hps.model)
        utils.load_checkpoint(f'pretrained_models/{i}/{i}.pth', net_g_ms, None, skip_optimizer=True)
_ = net_g_ms.eval().to(device) models.append((name, title, example, list(hps.data.spk2id.keys()), net_g_ms, create_tts_fn(net_g_ms, hps))) with gr.Blocks(theme='NoCrypt/miku') as app: with gr.Tabs(): for (name, title, example, speakers, net_g_ms, tts_fn) in models: with gr.TabItem(name): with gr.Row(): gr.Markdown( '