Shivam Mehta commited on
Commit
23f59c4
1 Parent(s): 2f40390

Adding multispeaker support for huggingface space

Browse files
Files changed (1) hide show
  1. app.py +266 -154
app.py CHANGED
@@ -5,7 +5,7 @@ from pathlib import Path
5
  import gradio as gr
6
  import soundfile as sf
7
  import torch
8
- from matcha.cli import (MATCHA_URLS, VOCODER_URL, assert_model_downloaded,
9
  get_device, load_matcha, load_vocoder, process_text,
10
  to_waveform)
11
  from matcha.utils.utils import get_user_data_dir, plot_tensor
@@ -16,18 +16,57 @@ args = Namespace(
16
  cpu=False,
17
  model="matcha_ljspeech",
18
  vocoder="hifigan_T2_v1",
19
- spk=None,
20
  )
21
 
22
- MATCHA_TTS_LOC = LOCATION / f"{args.model}.ckpt"
23
- VOCODER_LOC = LOCATION / f"{args.vocoder}"
 
24
  LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png"
25
- assert_model_downloaded(MATCHA_TTS_LOC, MATCHA_URLS[args.model])
26
- assert_model_downloaded(VOCODER_LOC, VOCODER_URL[args.vocoder])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  device = get_device(args)
28
 
29
- model = load_matcha(args.model, MATCHA_TTS_LOC, device)
30
- vocoder, denoiser = load_vocoder(args.vocoder, VOCODER_LOC, device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
 
33
  @torch.inference_mode()
@@ -37,173 +76,246 @@ def process_text_gradio(text):
37
 
38
 
39
  @torch.inference_mode()
40
- def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale):
41
- output = model.synthesise(
42
- text,
43
- text_length,
44
- n_timesteps=n_timesteps,
45
- temperature=temperature,
46
- spks=args.spk,
47
- length_scale=length_scale,
48
- )
49
- output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
51
  sf.write(fp.name, output["waveform"], 22050, "PCM_24")
52
 
53
  return fp.name, plot_tensor(output["mel"].squeeze().cpu().numpy())
54
 
55
 
56
- def run_full_synthesis(text, n_timesteps, mel_temp, length_scale):
 
 
 
 
 
 
57
  phones, text, text_lengths = process_text_gradio(text)
58
- audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale)
59
  return phones, audio, mel_spectrogram
60
 
61
 
62
- def main():
63
- description = """# 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching
64
- ### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/)
65
- We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up ODE-based speech synthesis. Our method:
66
 
67
 
68
- * Is probabilistic
69
- * Has compact memory footprint
70
- * Sounds highly natural
71
- * Is very fast to synthesise from
72
 
73
 
74
- Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS). Read our [arXiv preprint for more details](https://arxiv.org/abs/2309.03199).
75
- Code is available in our [GitHub repository](https://github.com/shivammehta25/Matcha-TTS), along with pre-trained models.
76
 
77
- Cached examples are available at the bottom of the page.
78
-
79
- Note: Synthesis speed may be slower than in our paper due to I/O latency and because this instance runs on CPUs.
80
- """
81
-
82
- with gr.Blocks(title="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching") as demo:
83
- processed_text = gr.State(value=None)
84
- processed_text_len = gr.State(value=None)
85
-
86
- with gr.Box():
87
- with gr.Row():
88
- gr.Markdown(description, scale=3)
89
- gr.Image(LOGO_URL, label="Matcha-TTS logo", height=150, width=150, scale=1, show_label=False)
90
-
91
- with gr.Box():
92
- with gr.Row():
93
- gr.Markdown("# Text Input")
94
- with gr.Row():
95
- text = gr.Textbox(value="", lines=2, label="Text to synthesise")
96
-
97
- with gr.Row():
98
- gr.Markdown("### Hyper parameters")
99
- with gr.Row():
100
- n_timesteps = gr.Slider(
101
- label="Number of ODE steps",
102
- minimum=1,
103
- maximum=100,
104
- step=1,
105
- value=10,
106
- interactive=True,
107
- )
108
- length_scale = gr.Slider(
109
- label="Length scale (Speaking rate)",
110
- minimum=0.5,
111
- maximum=1.5,
112
- step=0.05,
113
- value=1.0,
114
- interactive=True,
115
- )
116
- mel_temp = gr.Slider(
117
- label="Sampling temperature",
118
- minimum=0.00,
119
- maximum=2.001,
120
- step=0.16675,
121
- value=0.667,
122
- interactive=True,
123
- )
124
-
125
- synth_btn = gr.Button("Synthesise")
126
-
127
- with gr.Box():
128
- with gr.Row():
129
- gr.Markdown("### Phonetised text")
130
- phonetised_text = gr.Textbox(interactive=False, scale=10, label="Phonetised text")
131
-
132
- with gr.Box():
133
- with gr.Row():
134
- mel_spectrogram = gr.Image(interactive=False, label="mel spectrogram")
135
-
136
- # with gr.Row():
137
- audio = gr.Audio(interactive=False, label="Audio")
138
 
 
 
 
 
 
139
  with gr.Row():
140
- examples = gr.Examples( # pylint: disable=unused-variable
141
- examples=[
142
- [
143
- "We propose Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up O D E-based speech synthesis.",
144
- 50,
145
- 0.677,
146
- 1.0,
147
- ],
148
- [
149
- "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
150
- 2,
151
- 0.677,
152
- 1.0,
153
- ],
154
- [
155
- "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
156
- 4,
157
- 0.677,
158
- 1.0,
159
- ],
160
- [
161
- "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
162
- 10,
163
- 0.677,
164
- 1.0,
165
- ],
166
- [
167
- "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
168
- 50,
169
- 0.677,
170
- 1.0,
171
- ],
172
- [
173
- "The narrative of these events is based largely on the recollections of the participants.",
174
- 10,
175
- 0.677,
176
- 1.0,
177
- ],
178
- [
179
- "The jury did not believe him, and the verdict was for the defendants.",
180
- 10,
181
- 0.677,
182
- 1.0,
183
- ],
184
- ],
185
- fn=run_full_synthesis,
186
- inputs=[text, n_timesteps, mel_temp, length_scale],
187
- outputs=[phonetised_text, audio, mel_spectrogram],
188
- cache_examples=True,
189
  )
190
 
191
- synth_btn.click(
192
- fn=process_text_gradio,
193
- inputs=[
194
- text,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  ],
196
- outputs=[phonetised_text, processed_text, processed_text_len],
197
- api_name="matcha_tts",
198
- queue=True,
199
- ).then(
200
- fn=synthesise_mel,
201
- inputs=[processed_text, processed_text_len, n_timesteps, mel_temp, length_scale],
202
- outputs=[audio, mel_spectrogram],
203
  )
204
 
205
- demo.queue(concurrency_count=5).launch()
 
 
 
 
206
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
- if __name__ == "__main__":
209
- main()
 
5
  import gradio as gr
6
  import soundfile as sf
7
  import torch
8
+ from matcha.cli import (MATCHA_URLS, VOCODER_URLS, assert_model_downloaded,
9
  get_device, load_matcha, load_vocoder, process_text,
10
  to_waveform)
11
  from matcha.utils.utils import get_user_data_dir, plot_tensor
 
16
  cpu=False,
17
  model="matcha_ljspeech",
18
  vocoder="hifigan_T2_v1",
19
+ spk=0,
20
  )
21
 
22
+
23
+ MATCHA_TTS_LOC = lambda x: LOCATION / f"{x}.ckpt" # noqa: E731
24
+ VOCODER_LOC = lambda x: LOCATION / f"{x}" # noqa: E731
25
  LOGO_URL = "https://shivammehta25.github.io/Matcha-TTS/images/logo.png"
26
+ RADIO_OPTIONS = {
27
+ "Multi Speaker (VCTK)": {
28
+ "model": "matcha_vctk",
29
+ "vocoder": "hifigan_univ_v1",
30
+ },
31
+ "Single Speaker (LJ Speech)": {
32
+ "model": "matcha_ljspeech",
33
+ "vocoder": "hifigan_T2_v1",
34
+ },
35
+ }
36
+
37
+ # Ensure all the required models are downloaded
38
+ assert_model_downloaded(MATCHA_TTS_LOC("matcha_ljspeech"), MATCHA_URLS["matcha_ljspeech"])
39
+ assert_model_downloaded(VOCODER_LOC("hifigan_T2_v1"), VOCODER_URLS["hifigan_T2_v1"])
40
+ assert_model_downloaded(MATCHA_TTS_LOC("matcha_vctk"), MATCHA_URLS["matcha_vctk"])
41
+ assert_model_downloaded(VOCODER_LOC("hifigan_univ_v1"), VOCODER_URLS["hifigan_univ_v1"])
42
+
43
+ # get device
44
  device = get_device(args)
45
 
46
+ # Load default models
47
+ matcha_ljspeech = load_matcha(args.model, MATCHA_TTS_LOC(args.model), device)
48
+ hifigan_T2_v1, hifigan_T2_v1_denoiser = load_vocoder(args.vocoder, VOCODER_LOC(args.vocoder), device)
49
+
50
+ matcha_vctk = load_matcha("matcha_vctk", MATCHA_TTS_LOC("matcha_vctk"), device)
51
+ hifigan_univ_v1, hifigan_univ_v1_denoiser = load_vocoder("hifigan_univ_v1", VOCODER_LOC("hifigan_univ_v1"), device)
52
+
53
+
54
+
55
+ def load_model_ui(model_type, textbox):
56
+ model_name = RADIO_OPTIONS[model_type]["model"]
57
+
58
+ if model_name == "matcha_ljspeech":
59
+ spk_slider = gr.update(visible=False, value=-1)
60
+ single_speaker_examples = gr.update(visible=True)
61
+ multi_speaker_examples = gr.update(visible=False)
62
+ length_scale = gr.update(value=0.95)
63
+ else:
64
+ spk_slider = gr.update(visible=True, value=0)
65
+ single_speaker_examples = gr.update(visible=False)
66
+ multi_speaker_examples = gr.update(visible=True)
67
+ length_scale = gr.update(value=0.85)
68
+
69
+ return textbox, gr.update(interactive=True), spk_slider, single_speaker_examples, multi_speaker_examples, length_scale
70
 
71
 
72
  @torch.inference_mode()
 
76
 
77
 
78
  @torch.inference_mode()
79
+ def synthesise_mel(text, text_length, n_timesteps, temperature, length_scale, spk):
80
+ spk = torch.tensor([spk], device=device, dtype=torch.long) if spk >= 0 else None
81
+
82
+ if spk is None:
83
+ output = matcha_ljspeech.synthesise(
84
+ text,
85
+ text_length,
86
+ n_timesteps=n_timesteps,
87
+ temperature=temperature,
88
+ spks=None,
89
+ length_scale=length_scale,
90
+ )
91
+ output["waveform"] = to_waveform(output["mel"], hifigan_T2_v1, hifigan_T2_v1_denoiser)
92
+ else:
93
+ output = matcha_vctk.synthesise(
94
+ text,
95
+ text_length,
96
+ n_timesteps=n_timesteps,
97
+ temperature=temperature,
98
+ spks=spk,
99
+ length_scale=length_scale,
100
+ )
101
+ output["waveform"] = to_waveform(output["mel"], hifigan_univ_v1, hifigan_univ_v1_denoiser)
102
+
103
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
104
  sf.write(fp.name, output["waveform"], 22050, "PCM_24")
105
 
106
  return fp.name, plot_tensor(output["mel"].squeeze().cpu().numpy())
107
 
108
 
109
+ def multispeaker_example_cacher(text, n_timesteps, mel_temp, length_scale, spk):
110
+ phones, text, text_lengths = process_text_gradio(text)
111
+ audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)
112
+ return phones, audio, mel_spectrogram
113
+
114
+
115
+ def ljspeech_example_cacher(text, n_timesteps, mel_temp, length_scale, spk=-1):
116
  phones, text, text_lengths = process_text_gradio(text)
117
+ audio, mel_spectrogram = synthesise_mel(text, text_lengths, n_timesteps, mel_temp, length_scale, spk)
118
  return phones, audio, mel_spectrogram
119
 
120
 
121
+ description = """# 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching
122
+ ### [Shivam Mehta](https://www.kth.se/profile/smehta), [Ruibo Tu](https://www.kth.se/profile/ruibo), [Jonas Beskow](https://www.kth.se/profile/beskow), [Éva Székely](https://www.kth.se/profile/szekely), and [Gustav Eje Henter](https://people.kth.se/~ghe/)
123
+ We propose 🍵 Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up ODE-based speech synthesis. Our method:
 
124
 
125
 
126
+ * Is probabilistic
127
+ * Has compact memory footprint
128
+ * Sounds highly natural
129
+ * Is very fast to synthesise from
130
 
131
 
132
+ Check out our [demo page](https://shivammehta25.github.io/Matcha-TTS). Read our [arXiv preprint for more details](https://arxiv.org/abs/2309.03199).
133
+ Code is available in our [GitHub repository](https://github.com/shivammehta25/Matcha-TTS), along with pre-trained models.
134
 
135
+ Cached examples are available at the bottom of the page.
136
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
+ with gr.Blocks(title="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching") as demo:
139
+ processed_text = gr.State(value=None)
140
+ processed_text_len = gr.State(value=None)
141
+
142
+ with gr.Box():
143
  with gr.Row():
144
+ gr.Markdown(description, scale=3)
145
+ gr.Image(LOGO_URL, label="Matcha-TTS logo", height=150, width=150, scale=1, show_label=False)
146
+
147
+ with gr.Box():
148
+ radio_options = list(RADIO_OPTIONS.keys())
149
+ model_type = gr.Radio(
150
+ radio_options, value=radio_options[0], label="Choose a Model", interactive=True, container=False
151
+ )
152
+
153
+ with gr.Row():
154
+ gr.Markdown("# Text Input")
155
+ with gr.Row():
156
+ text = gr.Textbox(value="", lines=2, label="Text to synthesise", scale=3)
157
+ spk_slider = gr.Slider(
158
+ minimum=0, maximum=107, step=1, value=args.spk, label="Speaker ID", interactive=True, scale=1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  )
160
 
161
+ with gr.Row():
162
+ gr.Markdown("### Hyper parameters")
163
+ with gr.Row():
164
+ n_timesteps = gr.Slider(
165
+ label="Number of ODE steps",
166
+ minimum=1,
167
+ maximum=100,
168
+ step=1,
169
+ value=10,
170
+ interactive=True,
171
+ )
172
+ length_scale = gr.Slider(
173
+ label="Length scale (Speaking rate)",
174
+ minimum=0.5,
175
+ maximum=1.5,
176
+ step=0.05,
177
+ value=1.0,
178
+ interactive=True,
179
+ )
180
+ mel_temp = gr.Slider(
181
+ label="Sampling temperature",
182
+ minimum=0.00,
183
+ maximum=2.001,
184
+ step=0.16675,
185
+ value=0.667,
186
+ interactive=True,
187
+ )
188
+
189
+ synth_btn = gr.Button("Synthesise")
190
+
191
+ with gr.Box():
192
+ with gr.Row():
193
+ gr.Markdown("### Phonetised text")
194
+ phonetised_text = gr.Textbox(interactive=False, scale=10, label="Phonetised text")
195
+
196
+ with gr.Box():
197
+ with gr.Row():
198
+ mel_spectrogram = gr.Image(interactive=False, label="mel spectrogram")
199
+
200
+ # with gr.Row():
201
+ audio = gr.Audio(interactive=False, label="Audio")
202
+
203
+ with gr.Row(visible=False) as example_row_lj_speech:
204
+ examples = gr.Examples( # pylint: disable=unused-variable
205
+ examples=[
206
+ [
207
+ "We propose Matcha-TTS, a new approach to non-autoregressive neural TTS, that uses conditional flow matching (similar to rectified flows) to speed up O D E-based speech synthesis.",
208
+ 50,
209
+ 0.677,
210
+ 0.95,
211
+ ],
212
+ [
213
+ "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
214
+ 2,
215
+ 0.677,
216
+ 0.95,
217
+ ],
218
+ [
219
+ "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
220
+ 4,
221
+ 0.677,
222
+ 0.95,
223
+ ],
224
+ [
225
+ "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
226
+ 10,
227
+ 0.677,
228
+ 0.95,
229
+ ],
230
+ [
231
+ "The Secret Service believed that it was very doubtful that any President would ride regularly in a vehicle with a fixed top, even though transparent.",
232
+ 50,
233
+ 0.677,
234
+ 0.95,
235
+ ],
236
+ [
237
+ "The narrative of these events is based largely on the recollections of the participants.",
238
+ 10,
239
+ 0.677,
240
+ 0.95,
241
+ ],
242
+ [
243
+ "The jury did not believe him, and the verdict was for the defendants.",
244
+ 10,
245
+ 0.677,
246
+ 0.95,
247
+ ],
248
+ ],
249
+ fn=ljspeech_example_cacher,
250
+ inputs=[text, n_timesteps, mel_temp, length_scale],
251
+ outputs=[phonetised_text, audio, mel_spectrogram],
252
+ cache_examples=True,
253
+ )
254
+
255
+ with gr.Row() as example_row_multispeaker:
256
+ multi_speaker_examples = gr.Examples( # pylint: disable=unused-variable
257
+ examples=[
258
+ [
259
+ "Hello everyone! I am speaker 0 and I am here to tell you that Matcha-TTS is amazing!",
260
+ 10,
261
+ 0.677,
262
+ 0.85,
263
+ 0,
264
+ ],
265
+ [
266
+ "Hello everyone! I am speaker 16 and I am here to tell you that Matcha-TTS is amazing!",
267
+ 10,
268
+ 0.677,
269
+ 0.85,
270
+ 16,
271
+ ],
272
+ [
273
+ "Hello everyone! I am speaker 44 and I am here to tell you that Matcha-TTS is amazing!",
274
+ 50,
275
+ 0.677,
276
+ 0.85,
277
+ 44,
278
+ ],
279
+ [
280
+ "Hello everyone! I am speaker 45 and I am here to tell you that Matcha-TTS is amazing!",
281
+ 50,
282
+ 0.677,
283
+ 0.85,
284
+ 45,
285
+ ],
286
+ [
287
+ "Hello everyone! I am speaker 58 and I am here to tell you that Matcha-TTS is amazing!",
288
+ 4,
289
+ 0.677,
290
+ 0.85,
291
+ 58,
292
+ ],
293
  ],
294
+ fn=multispeaker_example_cacher,
295
+ inputs=[text, n_timesteps, mel_temp, length_scale, spk_slider],
296
+ outputs=[phonetised_text, audio, mel_spectrogram],
297
+ cache_examples=True,
298
+ label="Multi Speaker Examples",
 
 
299
  )
300
 
301
+ model_type.change(lambda x: gr.update(interactive=False), inputs=[synth_btn], outputs=[synth_btn]).then(
302
+ load_model_ui,
303
+ inputs=[model_type, text],
304
+ outputs=[text, synth_btn, spk_slider, example_row_lj_speech, example_row_multispeaker, length_scale],
305
+ )
306
 
307
+ synth_btn.click(
308
+ fn=process_text_gradio,
309
+ inputs=[
310
+ text,
311
+ ],
312
+ outputs=[phonetised_text, processed_text, processed_text_len],
313
+ api_name="matcha_tts",
314
+ queue=True,
315
+ ).then(
316
+ fn=synthesise_mel,
317
+ inputs=[processed_text, processed_text_len, n_timesteps, mel_temp, length_scale, spk_slider],
318
+ outputs=[audio, mel_spectrogram],
319
+ )
320
 
321
+ demo.queue(concurrency_count=5).launch(debug=True)