# coding=utf-8
import logging
import sys
import os
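# Keep the 'numba' logger (pulled in indirectly by audio dependencies) from
# flooding the log with DEBUG output.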
logging.getLogger('numba').setLevel(logging.WARNING)

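# Root logger configuration; the level can be overridden via the LOGLEVEL env var.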
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("APP")

import time
import gradio as gr
import utils
import argparse
import commons
from models import SynthesizerTrn
from text import text_to_sequence
import torch
from torch import no_grad, LongTensor
from gradio_client import utils as client_utils
limitation = os.getenv("SYSTEM") == "spaces"  # limit input text length when running on Hugging Face Spaces

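# Monkey-patch gr.Audio.postprocess so synthesized audio is sent to the client
# as a base64 data URL rather than a server-side file path.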
audio_postprocess_ori = gr.Audio.postprocess
def audio_postprocess(self, y):
    data = audio_postprocess_ori(self, y)
    if data is None:
        return None
    return client_utils.encode_url_or_file_to_base64(data["name"])
gr.Audio.postprocess = audio_postprocess

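# Convert raw text into a tensor of symbol IDs. When add_blank is set, a blank
# token (0) is interspersed between symbols, as the model expects.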
def get_text(text, hps):
    text_norm, clean_text = text_to_sequence(text, hps.symbols, hps.data.text_cleaners)
    if hps.data.add_blank:
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = LongTensor(text_norm)
    return text_norm, clean_text

def vits(text, language, speaker_id, noise_scale, noise_scale_w, length_scale):
    start = time.perf_counter()
    if not text:
        return None
    # Strip line breaks and spaces; the cleaners work on contiguous CJK text.
    text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
    if len(text) > 200 and limitation:
        return None
    # Wrap the input in language tags so the matching text cleaner is applied.
    if language == "中文":  # Chinese
        text = f"[ZH]{text}[ZH]"
    elif language == "日语":  # Japanese
        text = f"[JA]{text}[JA]"
    # otherwise the mixed-language input already carries explicit [ZH]/[JA] tags
    stn_tst, clean_text = get_text(text, hps_ms)
    with no_grad():
        x_tst = stn_tst.unsqueeze(0).to(device)
        x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device)
        sid = LongTensor([speaker_id]).to(device)
        audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=noise_scale,
                               noise_scale_w=noise_scale_w, length_scale=length_scale)[0][0, 0].data.cpu().float().numpy()
    logger.info(f"gen: {(text[:100], language, speaker_id, noise_scale, noise_scale_w, length_scale)}")
    logger.info(f"synthesis took {time.perf_counter() - start:.2f}s")
    return (22050, audio)  # sampling rate of the model's audio output

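# Resolve a user-typed value to an entry in `speakers`, preferring an exact
# match over a substring match (not referenced elsewhere in this file).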
def search_speaker(search_value):
    for s in speakers:
        if search_value == s:
            return s
    for s in speakers:
        if search_value in s:
            return s


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    args = parser.parse_args()
    device = torch.device(args.device)

    hps_ms = utils.get_hparams_from_file('./model/config.json')
    net_g_ms = SynthesizerTrn(
        len(hps_ms.symbols),
        hps_ms.data.filter_length // 2 + 1,
        hps_ms.train.segment_size // hps_ms.data.hop_length,
        n_speakers=hps_ms.data.n_speakers,
        **hps_ms.model)
    _ = net_g_ms.eval().to(device)
    # Display each speaker as "index.name"; the dropdown below uses
    # type="index", so the selected position doubles as the speaker ID.
    speakers = [f"{i}.{s}" for i, s in enumerate(hps_ms.speakers)]
    # load_checkpoint restores the weights into net_g_ms in place; the returned
    # (model, optimizer, learning_rate, epochs) tuple is not needed here.
    utils.load_checkpoint('./model/G_953000.pth', net_g_ms, None)

    app = gr.Interface(
        fn=vits,
        inputs=[
            gr.Textbox(label="Text (200-character limit on Spaces)", lines=5, value="可莉不知道哦!", elem_id="input-text"),
            # Choices: Chinese / Japanese / mixed Chinese-Japanese (see the examples
            # below for the [ZH]/[JA] tag format); vits() matches on these literals.
            gr.Radio(label="language", choices=["中文", "日语", "中日混合(格式参考下面的example)"], value="中文"),
            gr.Dropdown(label="Speaker", choices=speakers, type="index", value=speakers[329]),
            gr.Slider(label="noise_scale (controls emotional variation)", minimum=0.1, maximum=1.0, step=0.1, value=0.1, interactive=True),
            gr.Slider(label="noise_scale_w (controls phoneme duration)", minimum=0.1, maximum=1.0, step=0.1, value=0.7, interactive=True),
            gr.Slider(label="length_scale (controls overall speaking speed)", minimum=0.1, maximum=2.0, step=0.1, value=1.2, interactive=True),
        ],
        outputs=gr.Audio(label="Output Audio", elem_id="tts-audio"),
        examples=[
            ["可莉不知道哦!", "中文", speakers[329], 0.1, 0.6, 1.2],
            ["该做什么好呢?", "中文", speakers[104], 0.1, 0.8, 1.2],
            ["我给你讲个故事吧!", "中文", speakers[122], 0.1, 0.8, 1.2],
            ["おはようございます~", "日语", speakers[335], 0.1, 0.6, 1.2],
            ["[ZH]我会用日语说早上好啦![ZH][JA]おはようございます~[JA]", "中日混合(格式参考下面的example)", speakers[317], 0.1, 0.6, 1.2],
        ],
        title="VITS Genshin",
        description="",
        cache_examples=False
    )

    app.queue(concurrency_count=1)  # process one synthesis request at a time
    app.launch()