Sang-Hoon Lee commited on
Commit
6ed416b
1 Parent(s): 6d99823

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +236 -0
app.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import numpy as np
5
+ from scipy.io.wavfile import write
6
+ import torchaudio
7
+ import utils
8
+ from Mels_preprocess import MelSpectrogramFixed
9
+
10
+ from hierspeechpp_speechsynthesizer import (
11
+ SynthesizerTrn
12
+ )
13
+ from ttv_v1.text import text_to_sequence
14
+ from ttv_v1.t2w2v_transformer import SynthesizerTrn as Text2W2V
15
+ from speechsr24k.speechsr import SynthesizerTrn as SpeechSR24
16
+ from speechsr48k.speechsr import SynthesizerTrn as SpeechSR48
17
+ from denoiser.generator import MPNet
18
+ from denoiser.infer import denoise
19
+
20
+ import gradio as gr
21
+
22
+ def load_text(fp):
23
+ with open(fp, 'r') as f:
24
+ filelist = [line.strip() for line in f.readlines()]
25
+ return filelist
26
+ def load_checkpoint(filepath, device):
27
+ print(filepath)
28
+ assert os.path.isfile(filepath)
29
+ print("Loading '{}'".format(filepath))
30
+ checkpoint_dict = torch.load(filepath, map_location=device)
31
+ print("Complete.")
32
+ return checkpoint_dict
33
+ def get_param_num(model):
34
+ num_param = sum(param.numel() for param in model.parameters())
35
+ return num_param
36
+ def intersperse(lst, item):
37
+ result = [item] * (len(lst) * 2 + 1)
38
+ result[1::2] = lst
39
+ return result
40
+ def add_blank_token(text):
41
+
42
+ text_norm = intersperse(text, 0)
43
+ text_norm = torch.LongTensor(text_norm)
44
+ return text_norm
45
+
46
+ def tts(text,
47
+ prompt,
48
+ ttv_temperature,
49
+ vc_temperature,
50
+ duratuion_temperature,
51
+ duratuion_length,
52
+ denoise_ratio,
53
+ random_seed):
54
+
55
+ torch.manual_seed(random_seed)
56
+ torch.cuda.manual_seed(random_seed)
57
+ np.random.seed(random_seed)
58
+
59
+ text_len = len(text)
60
+ if text_len > 200:
61
+ raise gr.Error("Text length limited to 200 characters for this demo. Current text length is " + str(text_len))
62
+
63
+ else:
64
+ text = text_to_sequence(str(text), ["english_cleaners2"])
65
+
66
+ token = add_blank_token(text).unsqueeze(0).cuda()
67
+ token_length = torch.LongTensor([token.size(-1)]).cuda()
68
+
69
+ # Prompt load
70
+ # sample_rate, audio = prompt
71
+ # audio = torch.FloatTensor([audio]).cuda()
72
+ # if audio.shape[0] != 1:
73
+ # audio = audio[:1,:]
74
+ # audio = audio / 32768
75
+ audio, sample_rate = torchaudio.load(prompt)
76
+
77
+ # support only single channel
78
+
79
+ # Resampling
80
+ if sample_rate != 16000:
81
+ audio = torchaudio.functional.resample(audio, sample_rate, 16000, resampling_method="kaiser_window")
82
+
83
+ # We utilize a hop size of 320 but denoiser uses a hop size of 400 so we utilize a hop size of 1600
84
+ ori_prompt_len = audio.shape[-1]
85
+ p = (ori_prompt_len // 1600 + 1) * 1600 - ori_prompt_len
86
+ audio = torch.nn.functional.pad(audio, (0, p), mode='constant').data
87
+
88
+ # If you have a memory issue during denosing the prompt, try to denoise the prompt with cpu before TTS
89
+ # We will have a plan to replace a memory-efficient denoiser
90
+ if denoise == 0:
91
+ audio = torch.cat([audio.cuda(), audio.cuda()], dim=0)
92
+ else:
93
+ with torch.no_grad():
94
+
95
+ if ori_prompt_len > 80000:
96
+ denoised_audio = []
97
+ for i in range((ori_prompt_len//80000)):
98
+ denoised_audio.append(denoise(audio.squeeze(0).cuda()[i*80000:(i+1)*80000], denoiser, hps_denoiser))
99
+
100
+ denoised_audio.append(denoise(audio.squeeze(0).cuda()[(i+1)*80000:], denoiser, hps_denoiser))
101
+ denoised_audio = torch.cat(denoised_audio, dim=1)
102
+ else:
103
+ denoised_audio = denoise(audio.squeeze(0).cuda(), denoiser, hps_denoiser)
104
+
105
+ audio = torch.cat([audio.cuda(), denoised_audio[:,:audio.shape[-1]]], dim=0)
106
+
107
+ audio = audio[:,:ori_prompt_len] # 20231108 We found that large size of padding decreases a performance so we remove the paddings after denosing.
108
+
109
+ if audio.shape[-1]<48000:
110
+ audio = torch.cat([audio,audio,audio,audio,audio], dim=1)
111
+
112
+ src_mel = mel_fn(audio.cuda())
113
+
114
+ src_length = torch.LongTensor([src_mel.size(2)]).to(device)
115
+ src_length2 = torch.cat([src_length,src_length], dim=0)
116
+
117
+ ## TTV (Text --> W2V, F0)
118
+ with torch.no_grad():
119
+ w2v_x, pitch = text2w2v.infer_noise_control(token, token_length, src_mel, src_length2,
120
+ noise_scale=ttv_temperature, noise_scale_w=duratuion_temperature,
121
+ length_scale=duratuion_length, denoise_ratio=denoise_ratio)
122
+ src_length = torch.LongTensor([w2v_x.size(2)]).cuda()
123
+
124
+ pitch[pitch<torch.log(torch.tensor([55]).cuda())] = 0
125
+
126
+ ## Hierarchical Speech Synthesizer (W2V, F0 --> 16k Audio)
127
+ converted_audio = \
128
+ net_g.voice_conversion_noise_control(w2v_x, src_length, src_mel, src_length2, pitch, noise_scale=vc_temperature, denoise_ratio=denoise_ratio)
129
+
130
+ converted_audio = speechsr(converted_audio)
131
+
132
+ converted_audio = converted_audio.squeeze()
133
+
134
+ converted_audio = converted_audio / (torch.abs(converted_audio).max()) * 32767.0 * 0.999
135
+ converted_audio = converted_audio.cpu().numpy().astype('int16')
136
+
137
+ write('output.wav', 48000, converted_audio)
138
+ return 'output.wav'
139
+
140
+ def main():
141
+ print('Initializing Inference Process..')
142
+
143
+ parser = argparse.ArgumentParser()
144
+ parser.add_argument('--input_prompt', default='example/steve-jobs-2005.wav')
145
+ parser.add_argument('--input_txt', default='example/abstract.txt')
146
+ parser.add_argument('--output_dir', default='output')
147
+ parser.add_argument('--ckpt', default='./logs/hierspeechpp_eng_kor/hierspeechpp_v2_ckpt.pth')
148
+ parser.add_argument('--ckpt_text2w2v', '-ct', help='text2w2v checkpoint path', default='./logs/ttv_libritts_v1/ttv_lt960_ckpt.pth')
149
+ parser.add_argument('--ckpt_sr', type=str, default='./speechsr24k/G_340000.pth')
150
+ parser.add_argument('--ckpt_sr48', type=str, default='./speechsr48k/G_100000.pth')
151
+ parser.add_argument('--denoiser_ckpt', type=str, default='denoiser/g_best')
152
+ parser.add_argument('--scale_norm', type=str, default='max')
153
+ parser.add_argument('--output_sr', type=float, default=48000)
154
+ parser.add_argument('--noise_scale_ttv', type=float,
155
+ default=0.333)
156
+ parser.add_argument('--noise_scale_vc', type=float,
157
+ default=0.333)
158
+ parser.add_argument('--denoise_ratio', type=float,
159
+ default=0.8)
160
+ parser.add_argument('--duration_ratio', type=float,
161
+ default=0.8)
162
+ parser.add_argument('--seed', type=int,
163
+ default=1111)
164
+ a = parser.parse_args()
165
+
166
+ global device, hps, hps_t2w2v,h_sr,h_sr48, hps_denoiser
167
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
168
+
169
+ hps = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt)[0], 'config.json'))
170
+ hps_t2w2v = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_text2w2v)[0], 'config.json'))
171
+ h_sr = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_sr)[0], 'config.json') )
172
+ h_sr48 = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_sr48)[0], 'config.json') )
173
+ hps_denoiser = utils.get_hparams_from_file(os.path.join(os.path.split(a.denoiser_ckpt)[0], 'config.json'))
174
+
175
+ global mel_fn, net_g, text2w2v, speechsr, denoiser
176
+
177
+ mel_fn = MelSpectrogramFixed(
178
+ sample_rate=hps.data.sampling_rate,
179
+ n_fft=hps.data.filter_length,
180
+ win_length=hps.data.win_length,
181
+ hop_length=hps.data.hop_length,
182
+ f_min=hps.data.mel_fmin,
183
+ f_max=hps.data.mel_fmax,
184
+ n_mels=hps.data.n_mel_channels,
185
+ window_fn=torch.hann_window
186
+ ).cuda()
187
+
188
+ net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1,
189
+ hps.train.segment_size // hps.data.hop_length,
190
+ **hps.model).cuda()
191
+ net_g.load_state_dict(torch.load(a.ckpt))
192
+ _ = net_g.eval()
193
+
194
+ text2w2v = Text2W2V(hps.data.filter_length // 2 + 1,
195
+ hps.train.segment_size // hps.data.hop_length,
196
+ **hps_t2w2v.model).cuda()
197
+ text2w2v.load_state_dict(torch.load(a.ckpt_text2w2v))
198
+ text2w2v.eval()
199
+
200
+ speechsr = SpeechSR48(h_sr48.data.n_mel_channels,
201
+ h_sr48.train.segment_size // h_sr48.data.hop_length,
202
+ **h_sr48.model).cuda()
203
+ utils.load_checkpoint(a.ckpt_sr48, speechsr, None)
204
+ speechsr.eval()
205
+
206
+ denoiser = MPNet(hps_denoiser).cuda()
207
+ state_dict = load_checkpoint(a.denoiser_ckpt, device)
208
+ denoiser.load_state_dict(state_dict['generator'])
209
+ denoiser.eval()
210
+
211
+ demo_play = gr.Interface(fn = tts,
212
+ inputs = [gr.Textbox(max_lines=6, label="Input Text", value="HierSpeech is a zero shot speech synthesis model, which can generate high-quality audio", info="Up to 200 characters"),
213
+ gr.Audio(type='filepath', value="./example/3_rick_gt.wav"),
214
+ gr.Slider(0,1,0.333),
215
+ gr.Slider(0,1,0.333),
216
+ gr.Slider(0,1,1.0),
217
+ gr.Slider(0.5,2,1.0),
218
+ gr.Slider(0,1,0),
219
+ gr.Slider(0,9999,1111)],
220
+ outputs = 'audio',
221
+ title = 'HierSpeech++',
222
+ description = '''<div>
223
+ <p style="text-align: left"> HierSpeech++ is a zero-shot speech synthesis model.</p>
224
+ <p style="text-align: left"> Our model is trained with LibriTTS dataset so this model only supports english. We will release a multi-lingual HierSpeech++ soon.</p>
225
+ <p style="text-align: left"> <a href="https://sh-lee-prml.github.io/HierSpeechpp-demo/">[Demo Page]</a> <a href="https://github.com/sh-lee-prml/HierSpeechpp">[Source Code]</a></p>
226
+ </div>''',
227
+ examples=[["HierSpeech is a zero-shot speech synthesis model, which can generate high-quality audio", "./example/3_rick_gt.wav", 0.333,0.333, 1.0, 1.0, 0, 1111],
228
+ ["HierSpeech is a zero-shot speech synthesis model, which can generate high-quality audio", "./example/ex01_whisper_00359.wav", 0.333,0.333, 1.0, 1.0, 0, 1111],
229
+ ["Hi there, I'm your new voice clone. Try your best to upload quality audio", "./example/female.wav", 0.333,0.333, 1.0, 1.0, 0, 1111],
230
+ ["Hello I'm HierSpeech++", "./example/reference_1.wav", 0.333,0.333, 1.0, 1.0, 0, 1111],
231
+ ]
232
+ )
233
+ demo_play.launch(share=True, server_port=8888)
234
+
235
+ if __name__ == '__main__':
236
+ main()