import json
import os
import re

import torch
from torch import no_grad, LongTensor
import gradio as gr
import scipy.io.wavfile as wf
import base64

import commons
import utils
from models import SynthesizerTrn
from text import text_to_sequence, _clean_text

# Set to True to cap input length (useful on shared/public deployments).
limitation = False
device = torch.device("cpu")


def get_text(text, hps, is_symbol):
    """Convert raw text (or pre-cleaned symbols) to a tensor of symbol IDs."""
    text_norm = text_to_sequence(text, hps.symbols,
                                 [] if is_symbol else hps.data.text_cleaners)
    if hps.data.add_blank:
        text_norm = commons.intersperse(text_norm, 0)
    return LongTensor(text_norm)


def create_tts_fn(model, hps, speaker_ids):
    def tts_fn(text, speaker, speed, is_symbol):
        if limitation:
            # Language tags like [JA] don't count toward the length limit.
            text_len = len(re.sub(r"\[([A-Z]{2})\]", "", text))
            max_len = 150
            if is_symbol:
                max_len *= 3
            if text_len > max_len:
                return "Error: Text is too long", None
        speaker_id = speaker_ids[speaker]
        stn_tst = get_text(text, hps, is_symbol)
        with no_grad():
            x_tst = stn_tst.unsqueeze(0).to(device)
            x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device)
            sid = LongTensor([speaker_id]).to(device)
            audio = model.infer(x_tst, x_tst_lengths, sid=sid,
                                noise_scale=.667, noise_scale_w=0.8,
                                length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
        del stn_tst, x_tst, x_tst_lengths, sid
        return "Success", (hps.data.sampling_rate, audio)
    return tts_fn


def create_to_symbol_fn(hps):
    def to_symbol_fn(is_symbol_input, input_text, temp_text):
        return (_clean_text(input_text, hps.data.text_cleaners), input_text) \
            if is_symbol_input else (temp_text, temp_text)
    return to_symbol_fn
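# The original script returned an HTML string whose template was lost in this
# copy (""" """.format(path) has no placeholder, so it always yields " ").
# The helper below is a presumed reconstruction: it embeds the rendered WAV as
# a playable <audio> element using a base64 data URI, which matches the
# otherwise-unused `import base64` above. The name `audio_file_to_html` is
# illustrative, not from the original.
def audio_file_to_html(path):
    """Embed a WAV file as a self-contained HTML5 <audio> element."""
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("ascii")
    return '<audio controls src="data:audio/wav;base64,{}"></audio>'.format(encoded)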
def main(text):
    with open("saved_model/info.json", "r", encoding="utf-8") as f:
        models_info = json.load(f)
    for i, info in models_info.items():
        if int(i) != 0:  # only model 0 is used here
            continue
        # Model metadata (unused below, kept for reference).
        name = info["title"]
        author = info["author"]
        lang = info["lang"]
        example = info["example"]
        cover = info["cover"]
        cover_path = f"saved_model/{i}/{cover}" if cover else None
        config_path = f"saved_model/{i}/config.json"
        model_path = f"saved_model/{i}/model.pth"
        hps = utils.get_hparams_from_file(config_path)
        model = SynthesizerTrn(
            len(hps.symbols),
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            **hps.model)
        utils.load_checkpoint(model_path, model, None)
        model.eval().to(device)
        speaker_ids = [sid for sid, name_ in enumerate(hps.speakers) if name_ != "None"]
        print(speaker_ids[0])  # debug: first valid speaker id
        tts_fn = create_tts_fn(model, hps, speaker_ids)
        # tts_fn expects an *index* into speaker_ids; the original passed
        # speaker_ids[0], which only works when that id happens to be 0.
        status, (rate, audio) = tts_fn(text, 0, 1, False)
        out_path = os.path.join(os.path.dirname(__file__), "audio", "animegirl.wav")
        os.makedirs(os.path.dirname(out_path), exist_ok=True)  # wf.write fails if the directory is missing
        wf.write(out_path, rate, audio)
        # The original HTML template here was lost; see audio_file_to_html
        # above for the presumed reconstruction.
        return audio_file_to_html(out_path)


demo = gr.Interface(fn=main, inputs="text", outputs="html")

if __name__ == "__main__":
    demo.launch(debug=True)
    # main(text="あなたと一緒にいると、とても興奮します")
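# Expected saved_model/info.json layout, inferred from the keys read in
# main() above (values are illustrative, not from the original repo):
#
#     {
#       "0": {
#         "title": "...",
#         "author": "...",
#         "lang": "...",
#         "example": "...",
#         "cover": "cover.png"
#       }
#     }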