DemoLou committed on
Commit
b19d8fe
1 Parent(s): f177ee9

Create test.py

Files changed (1)
  1. test.py +123 -0
test.py ADDED
@@ -0,0 +1,123 @@
+ import argparse
+ import json
+ import os
+ import re
+ import tempfile
+ from pathlib import Path
+
+ import librosa
+ import numpy as np
+ import torch
+ from torch import no_grad, LongTensor
+ import commons
+ import utils
+ import gradio as gr
+ import gradio.utils as gr_utils
+ import gradio.processing_utils as gr_processing_utils
+ from models import SynthesizerTrn
+ from text import text_to_sequence, _clean_text
+ from mel_processing import spectrogram_torch
+ # import sounddevice as sd
+ # from scipy.io.wavfile import write
+ # import scikits.audiolab
+ # import soundfile as sf
+ import scipy.io.wavfile as wf
+
+ limitation = False
+ device = torch.device('cpu')
+
+
+ # fs = 44100
+ # data = np.random.uniform(-1, 1, fs)
+ # sd.play(data, fs)
+ # rate = 44100
+ # data = np.random.uniform(-1, 1, rate)  # 1 second worth of random samples between -1 and 1
+ # scaled = np.int16(data / np.max(np.abs(data)) * 32767)
+ # write('test.wav', rate, scaled)
+ # data = np.random.uniform(-1, 1, 44100)
+ # sf.write('new_file.wav', data, 44100)
+
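+ # A consolidated, runnable sketch of the scratch idea above, kept commented out.
+ # It uses only modules this file already imports (np, and wf = scipy.io.wavfile);
+ # write_noise_wav is a hypothetical helper name, not part of the app below.
+ # def write_noise_wav(path="noise.wav", rate=44100):
+ #     data = np.random.uniform(-1, 1, rate)                   # 1 s of noise in [-1, 1]
+ #     scaled = np.int16(data / np.max(np.abs(data)) * 32767)  # rescale to the 16-bit PCM range
+ #     wf.write(path, rate, scaled)                            # write a 16-bit PCM WAV
+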
+ def get_text(text, hps, is_symbol):
+     # Convert raw text (or bare phoneme symbols) to the integer sequence the model expects.
+     text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
+     if hps.data.add_blank:
+         text_norm = commons.intersperse(text_norm, 0)
+     text_norm = LongTensor(text_norm)
+     return text_norm
+
+ def create_tts_fn(model, hps, speaker_ids):
+     def tts_fn(text, speaker, speed, is_symbol):
+         if limitation:
+             # Raw string so the brackets are matched literally, not treated as string escapes.
+             text_len = len(re.sub(r"\[([A-Z]{2})\]", "", text))
+             max_len = 150
+             if is_symbol:
+                 max_len *= 3
+             if text_len > max_len:
+                 return "Error: Text is too long", None
+
+         speaker_id = speaker_ids[speaker]
+         stn_tst = get_text(text, hps, is_symbol)
+         with no_grad():
+             x_tst = stn_tst.unsqueeze(0).to(device)
+             x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device)
+             sid = LongTensor([speaker_id]).to(device)
+             audio = model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8,
+                                 length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
+         del stn_tst, x_tst, x_tst_lengths, sid
+         return "Success", (hps.data.sampling_rate, audio)
+
+     return tts_fn
+
+ def create_to_symbol_fn(hps):
+     def to_symbol_fn(is_symbol_input, input_text, temp_text):
+         return (_clean_text(input_text, hps.data.text_cleaners), input_text) if is_symbol_input \
+             else (temp_text, temp_text)
+
+     return to_symbol_fn
+
+ def main(input_text):  # avoid shadowing the builtin input()
+     models_tts = []
+     models_vc = []
+     models_soft_vc = []
+     device = torch.device("cpu")
+     global result
+     with open("saved_model/info.json", "r", encoding="utf-8") as f:
+         models_info = json.load(f)
+     for i, info in models_info.items():
+         if int(i) == 0:  # only load model 0
+             name = info["title"]
+             author = info["author"]
+             lang = info["lang"]
+             example = info["example"]
+             config_path = f"saved_model/{i}/config.json"
+             model_path = f"saved_model/{i}/model.pth"
+             cover = info["cover"]
+             cover_path = f"saved_model/{i}/{cover}" if cover else None
+             hps = utils.get_hparams_from_file(config_path)
+             model = SynthesizerTrn(
+                 len(hps.symbols),
+                 hps.data.filter_length // 2 + 1,
+                 hps.train.segment_size // hps.data.hop_length,
+                 n_speakers=hps.data.n_speakers,
+                 **hps.model)
+             utils.load_checkpoint(model_path, model, None)
+             model.eval().to(device)
+             speaker_ids = [sid for sid, name in enumerate(hps.speakers) if name != "None"]
+             speakers = [name for sid, name in enumerate(hps.speakers) if name != "None"]
+             # input_text = get_text("ヨスガノソラ", hps, True)  # "Yosuga no Sora"
+             print(speaker_ids[0])
+             vtts = create_tts_fn(model, hps, speaker_ids)
+             symbol = create_to_symbol_fn(hps)
+             result = vtts(input_text, speaker_ids[0], 1, False)
+             # wf.write('anime_girl3.wav', result[1][0], result[1][1])
+             # print(type(result[1][0]), result[1][0])
+             return result[1]  # (sampling_rate, audio), the pair gr.Audio expects
+     print(models_tts)
+
+ tts_output2 = gr.Audio(label="Output Audio", elem_id="tts-audio0")
+
+ demo = gr.Interface(fn=main,
+                     inputs=gr.Textbox(value="あなたと一緒にいると、とても興奮します",  # "Being with you makes me very excited"
+                                       label="Input Text"),
+                     outputs=[tts_output2])
+
+ if __name__ == "__main__":
+     demo.launch()
+
+ # main(input_text="あなたと一緒にいると、とても興奮します")
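For quick offline checks, a minimal sketch of driving the same pipeline without the web UI, assuming the saved_model/ layout above is in place; main returns the (sampling_rate, audio) pair that scipy.io.wavfile.write accepts, and the output filename is arbitrary:

    sr, audio = main("あなたと一緒にいると、とても興奮します")  # "Being with you makes me very excited"
    wf.write("sample.wav", sr, audio)  # float32 waveform; scipy writes it as a float WAV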