|
import argparse
import functools
import hmac
import json
import os
import re
import tempfile
from pathlib import Path

import gradio as gr
import gradio.processing_utils as gr_processing_utils
import gradio.utils as gr_utils
import librosa
import numpy as np
import torch
from torch import no_grad, LongTensor

import commons
import utils
from models import SynthesizerTrn
from text import text_to_sequence, _clean_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Every tensor and the model are placed on the CPU.
device = torch.device("cpu")
# Feature flag read by tts_fn: when True, over-long inputs are rejected
# before synthesis.
limitation: bool = False
|
|
|
|
|
# Client-side JavaScript snippet, presumably intended as a Gradio `_js`
# handler for a "download" button.
# It appears to be a str.format template: "{audio_id}" is the placeholder to
# fill in, and the doubled braces "{{" / "}}" turn into literal JS braces
# once formatted.
# When run, it finds the <audio> element inside the element whose id is
# audio_id (looking inside the gradio-app shadow root when one exists). It then
# triggers a browser download of that audio's src under a random "<n>.wav" name.
# NOTE(review): not referenced anywhere else in this file; possibly dead code.
download_audio_js = """
() =>{{
let root = document.querySelector("body > gradio-app");
if (root.shadowRoot != null)
root = root.shadowRoot;
let audio = root.querySelector("#{audio_id}").querySelector("audio");
if (audio == undefined)
return;
audio = audio.src;
let oA = document.createElement("a");
oA.download = Math.floor(Math.random()*100000000)+'.wav';
oA.href = audio;
document.body.appendChild(oA);
oA.click();
oA.remove();
}}
"""
|
|
|
|
|
# Standalone Gradio components for text input and audio output.
# The element ids evaluate to the same values the original f"...{0}" strings
# produced ("tts-input0" / "tts-audio0"). "tts-audio0" is presumably the id
# the download_audio_js template expects for audio_id.
# NOTE(review): neither component is wired into the Interface in this file.
tts_input1 = gr.TextArea(label="inputText", value="あなたと一緒にいると、とても興奮します", elem_id="tts-input0")
tts_output2 = gr.Audio(label="outputAudio", elem_id="tts-audio0")
|
|
|
def get_text(text, hps, is_symbol):
    """Turn *text* into a LongTensor of symbol ids for the synthesizer.

    Args:
        text: Input string. When ``is_symbol`` is true it is treated as
            already-cleaned symbols and no cleaners are applied.
        hps: Hyper-parameter object. ``hps.symbols`` is the symbol table,
            ``hps.data.text_cleaners`` lists the cleaners, and
            ``hps.data.add_blank`` enables blank interleaving.
        is_symbol: Skip text cleaning when true.

    Returns:
        ``LongTensor`` built from the id sequence. When ``add_blank`` is set,
        the sequence first goes through ``commons.intersperse(..., 0)``
        (presumably inserting the blank id 0 between tokens).
    """
    cleaners = [] if is_symbol else hps.data.text_cleaners
    sequence = text_to_sequence(text, hps.symbols, cleaners)
    if hps.data.add_blank:
        sequence = commons.intersperse(sequence, 0)
    return LongTensor(sequence)
|
|
|
def create_tts_fn(model, hps, speaker_ids, noise_scale=.667, noise_scale_w=0.8):
    """Build a TTS callback bound to one synthesizer.

    Args:
        model: Synthesizer exposing ``infer(x, x_lengths, sid=..., noise_scale=...,
            noise_scale_w=..., length_scale=...)``.
        hps: Hyper-parameter object (``hps.data.sampling_rate`` is returned
            alongside the audio).
        speaker_ids: Speaker ids accepted by the model. The callback's
            ``speaker`` argument is an INDEX into this list, not an id.
        noise_scale, noise_scale_w: Sampling noise passed to ``model.infer``.
            The defaults are the values previously hard-coded here.

    Returns:
        ``tts_fn(text, speaker, speed, is_symbol) -> (status, audio)``, where
        ``audio`` is ``(sampling_rate, numpy_array)`` on success. When the
        module-level ``limitation`` flag rejects an over-long input, the
        result is ``("Error: Text is too long", None)``.
    """
    # Raw string: in a normal literal "\[" is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning since Python 3.12).
    # Matches bracketed two-uppercase-letter tags such as "[JA]" (presumably
    # language markers), which do not count toward the length limit.
    tag_pattern = re.compile(r"\[([A-Z]{2})\]")

    def tts_fn(text, speaker, speed, is_symbol):
        """Synthesize *text* with speaker index *speaker* at rate *speed*."""
        if limitation:
            text_len = len(tag_pattern.sub("", text))
            max_len = 150
            if is_symbol:
                # Symbol input is longer per spoken unit, so allow 3x.
                max_len *= 3
            if text_len > max_len:
                return "Error: Text is too long", None

        speaker_id = speaker_ids[speaker]
        stn_tst = get_text(text, hps, is_symbol)
        with no_grad():
            x_tst = stn_tst.unsqueeze(0).to(device)  # add a batch dimension
            x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device)
            sid = LongTensor([speaker_id]).to(device)
            # length_scale is the inverse of speed: larger speed -> shorter audio.
            # [0][0, 0] takes the first item/channel of the first output;
            # presumably (batch, channel, samples). TODO confirm.
            audio = model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=noise_scale,
                                noise_scale_w=noise_scale_w,
                                length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
        del stn_tst, x_tst, x_tst_lengths, sid
        return "Success", (hps.data.sampling_rate, audio)

    return tts_fn
|
|
|
def create_to_symbol_fn(hps):
    """Return a callback that switches a text box into or out of symbol mode.

    The returned ``to_symbol_fn(is_symbol_input, input_text, temp_text)``
    yields a pair, presumably ``(text to display, text to remember)``:

    * symbol mode on:  ``(_clean_text(input_text, hps.data.text_cleaners), input_text)``
    * symbol mode off: ``(temp_text, temp_text)``
    """
    def to_symbol_fn(is_symbol_input, input_text, temp_text):
        if not is_symbol_input:
            # Leaving symbol mode: restore the previously saved text.
            return temp_text, temp_text
        cleaned = _clean_text(input_text, hps.data.text_cleaners)
        return cleaned, input_text

    return to_symbol_fn
|
|
|
@functools.lru_cache(maxsize=4)
def _load_model(model_key):
    """Load the synthesizer under ``saved_model/<model_key>/`` once per key.

    Reads hyper-parameters from ``config.json`` and weights from
    ``model.pth``, puts the model in eval mode on ``device`` and returns
    ``(model, hps)``. Memoized so the checkpoint is not reloaded from disk
    on every request.
    """
    config_path = f"saved_model/{model_key}/config.json"
    model_path = f"saved_model/{model_key}/model.pth"
    hps = utils.get_hparams_from_file(config_path)
    model = SynthesizerTrn(
        len(hps.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model)
    utils.load_checkpoint(model_path, model, None)
    model.eval().to(device)
    return model, hps


def _is_authorized(authorization):
    """Return True if *authorization* equals the ``API_TOKEN`` env variable.

    An unset ``API_TOKEN`` denies every request instead of raising KeyError.
    """
    expected = os.environ.get('API_TOKEN')
    if expected is None:
        return False
    # Constant-time comparison so response timing does not leak the token.
    # Both sides are bytes because compare_digest rejects non-ASCII str.
    return hmac.compare_digest(expected.encode("utf-8"),
                               (authorization or "").encode("utf-8"))


def main(input, authorization):
    """Synthesize *input* with model ``0`` listed in ``saved_model/info.json``.

    Args:
        input: Text to speak (normal text, not symbols, at speed 1).
        authorization: Must match the ``API_TOKEN`` environment variable.

    Returns:
        ``(sampling_rate, audio)`` for the ``audio`` output, or ``None`` when
        authorization fails.

    Raises:
        ValueError: info.json has no model with id 0, the model has no
            usable speakers, or synthesis rejected the input.
    """
    if not _is_authorized(authorization):
        return None

    with open("saved_model/info.json", "r", encoding="utf-8") as f:
        models_info = json.load(f)

    # Keys are compared numerically, as before, so "0" selects the model.
    model_key = next((key for key in models_info if int(key) == 0), None)
    if model_key is None:
        raise ValueError("saved_model/info.json has no model with id 0")

    model, hps = _load_model(model_key)
    speaker_ids = [sid for sid, name in enumerate(hps.speakers) if name != "None"]
    if not speaker_ids:
        raise ValueError("model 0 defines no usable speakers")

    tts_fn = create_tts_fn(model, hps, speaker_ids)
    # tts_fn takes an INDEX into speaker_ids, so 0 selects the first usable
    # speaker. The old code passed speaker_ids[0] (an id) and double-indexed.
    message, output = tts_fn(input, 0, 1, False)
    if output is None:
        # tts_fn reports rejections as (message, None); don't index into None.
        raise ValueError(message)
    sampling_rate, audio = output
    return sampling_rate, audio
|
|
|
|
|
# Web UI: two plain text boxes feed main(input, authorization) positionally,
# and main's (sampling_rate, audio) result is rendered as an audio player.
demo = gr.Interface(
    fn=main,
    inputs=["text", "text"],
    outputs="audio",
)

# Start the server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch(debug=True)
|
|
|
|
|
|