Vaibhav Srivastav commited on
Commit
db4c88c
·
1 Parent(s): 51eeef5
Files changed (1) hide show
  1. app.py +34 -413
app.py CHANGED
@@ -1,417 +1,38 @@
1
- from __future__ import annotations
2
-
3
- import os
4
- import pathlib
5
-
6
- import gradio as gr
7
- import numpy as np
8
  import torch
9
- import torchaudio
10
- from fairseq2.assets import InProcAssetMetadataProvider, asset_store
11
- from huggingface_hub import snapshot_download
12
- from seamless_communication.inference import Translator
13
-
14
- from lang_list import (
15
- ASR_TARGET_LANGUAGE_NAMES,
16
- LANGUAGE_NAME_TO_CODE,
17
- S2ST_TARGET_LANGUAGE_NAMES,
18
- S2TT_TARGET_LANGUAGE_NAMES,
19
- T2ST_TARGET_LANGUAGE_NAMES,
20
- T2TT_TARGET_LANGUAGE_NAMES,
21
- TEXT_SOURCE_LANGUAGE_NAMES,
22
- )
23
-
24
- CHECKPOINTS_PATH = pathlib.Path(os.getenv("CHECKPOINTS_PATH", "/home/user/app/models"))
25
- if not CHECKPOINTS_PATH.exists():
26
- snapshot_download(repo_id="facebook/seamless-m4t-v2-large", repo_type="model", local_dir=CHECKPOINTS_PATH)
27
- asset_store.env_resolvers.clear()
28
- asset_store.env_resolvers.append(lambda: "demo")
29
- demo_metadata = [
30
- {
31
- "name": "seamlessM4T_v2_large@demo",
32
- "checkpoint": f"file://{CHECKPOINTS_PATH}/seamlessM4T_v2_large.pt",
33
- "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
34
- },
35
- {
36
- "name": "vocoder_v2@demo",
37
- "checkpoint": f"file://{CHECKPOINTS_PATH}/vocoder_v2.pt",
38
- },
39
- ]
40
- asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))
41
-
42
- DESCRIPTION = """\
43
- # SeamlessM4T
44
-
45
- [SeamlessM4T](https://github.com/facebookresearch/seamless_communication) is designed to provide high-quality
46
- translation, allowing people from different linguistic communities to communicate effortlessly through speech and text.
47
- This unified model enables multiple tasks like Speech-to-Speech (S2ST), Speech-to-Text (S2TT), Text-to-Speech (T2ST)
48
- translation and more, without relying on multiple separate models.
49
- """
50
-
51
- CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1" and torch.cuda.is_available()
52
-
53
- AUDIO_SAMPLE_RATE = 16000.0
54
- MAX_INPUT_AUDIO_LENGTH = 60 # in seconds
55
- DEFAULT_TARGET_LANGUAGE = "French"
56
-
57
- if torch.cuda.is_available():
58
- device = torch.device("cuda:0")
59
- dtype = torch.float16
60
- else:
61
- device = torch.device("cpu")
62
- dtype = torch.float32
63
-
64
- translator = Translator(
65
- model_name_or_card="seamlessM4T_v2_large",
66
- vocoder_name_or_card="vocoder_v2",
67
- device=device,
68
- dtype=dtype,
69
- apply_mintox=True,
70
- )
71
-
72
-
73
- def preprocess_audio(input_audio: str) -> None:
74
- arr, org_sr = torchaudio.load(input_audio)
75
- new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE)
76
- max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
77
- if new_arr.shape[1] > max_length:
78
- new_arr = new_arr[:, :max_length]
79
- gr.Warning(f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used.")
80
- torchaudio.save(input_audio, new_arr, sample_rate=int(AUDIO_SAMPLE_RATE))
81
-
82
-
83
- def run_s2st(
84
- input_audio: str, source_language: str, target_language: str
85
- ) -> tuple[tuple[int, np.ndarray] | None, str]:
86
- preprocess_audio(input_audio)
87
- source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
88
- target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
89
- out_texts, out_audios = translator.predict(
90
- input=input_audio,
91
- task_str="S2ST",
92
- src_lang=source_language_code,
93
- tgt_lang=target_language_code,
94
- )
95
- out_text = str(out_texts[0])
96
- out_wav = out_audios.audio_wavs[0].cpu().detach().numpy()
97
- return (int(AUDIO_SAMPLE_RATE), out_wav), out_text
98
-
99
-
100
- def run_s2tt(input_audio: str, source_language: str, target_language: str) -> str:
101
- preprocess_audio(input_audio)
102
- source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
103
- target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
104
- out_texts, _ = translator.predict(
105
- input=input_audio,
106
- task_str="S2TT",
107
- src_lang=source_language_code,
108
- tgt_lang=target_language_code,
109
- )
110
- return str(out_texts[0])
111
-
112
-
113
- def run_t2st(input_text: str, source_language: str, target_language: str) -> tuple[tuple[int, np.ndarray] | None, str]:
114
- source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
115
- target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
116
- out_texts, out_audios = translator.predict(
117
- input=input_text,
118
- task_str="T2ST",
119
- src_lang=source_language_code,
120
- tgt_lang=target_language_code,
121
- )
122
- out_text = str(out_texts[0])
123
- out_wav = out_audios.audio_wavs[0].cpu().detach().numpy()
124
- return (int(AUDIO_SAMPLE_RATE), out_wav), out_text
125
-
126
-
127
- def run_t2tt(input_text: str, source_language: str, target_language: str) -> str:
128
- source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
129
- target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
130
- out_texts, _ = translator.predict(
131
- input=input_text,
132
- task_str="T2TT",
133
- src_lang=source_language_code,
134
- tgt_lang=target_language_code,
135
- )
136
- return str(out_texts[0])
137
-
138
-
139
- def run_asr(input_audio: str, target_language: str) -> str:
140
- preprocess_audio(input_audio)
141
- target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
142
- out_texts, _ = translator.predict(
143
- input=input_audio,
144
- task_str="ASR",
145
- src_lang=target_language_code,
146
- tgt_lang=target_language_code,
147
- )
148
- return str(out_texts[0])
149
-
150
-
151
- with gr.Blocks() as demo_s2st:
152
- with gr.Row():
153
- with gr.Column():
154
- with gr.Group():
155
- input_audio = gr.Audio(label="Input speech", type="filepath")
156
- source_language = gr.Dropdown(
157
- label="Source language",
158
- choices=ASR_TARGET_LANGUAGE_NAMES,
159
- value="English",
160
- )
161
- target_language = gr.Dropdown(
162
- label="Target language",
163
- choices=S2ST_TARGET_LANGUAGE_NAMES,
164
- value=DEFAULT_TARGET_LANGUAGE,
165
- )
166
- btn = gr.Button("Translate")
167
- with gr.Column():
168
- with gr.Group():
169
- output_audio = gr.Audio(
170
- label="Translated speech",
171
- autoplay=False,
172
- streaming=False,
173
- type="numpy",
174
- )
175
- output_text = gr.Textbox(label="Translated text")
176
-
177
- gr.Examples(
178
- examples=[
179
- ["assets/sample_input.mp3", "English", "French"],
180
- ["assets/sample_input.mp3", "English", "Mandarin Chinese"],
181
- ["assets/sample_input_2.mp3", "English", "Hindi"],
182
- ["assets/sample_input_2.mp3", "English", "Spanish"],
183
- ],
184
- inputs=[input_audio, source_language, target_language],
185
- outputs=[output_audio, output_text],
186
- fn=run_s2st,
187
- cache_examples=CACHE_EXAMPLES,
188
- api_name=False,
189
- )
190
-
191
- btn.click(
192
- fn=run_s2st,
193
- inputs=[input_audio, source_language, target_language],
194
- outputs=[output_audio, output_text],
195
- api_name="s2st",
196
- )
197
-
198
- with gr.Blocks() as demo_s2tt:
199
- with gr.Row():
200
- with gr.Column():
201
- with gr.Group():
202
- input_audio = gr.Audio(label="Input speech", type="filepath")
203
- source_language = gr.Dropdown(
204
- label="Source language",
205
- choices=ASR_TARGET_LANGUAGE_NAMES,
206
- value="English",
207
- )
208
- target_language = gr.Dropdown(
209
- label="Target language",
210
- choices=S2TT_TARGET_LANGUAGE_NAMES,
211
- value=DEFAULT_TARGET_LANGUAGE,
212
- )
213
- btn = gr.Button("Translate")
214
- with gr.Column():
215
- output_text = gr.Textbox(label="Translated text")
216
-
217
- gr.Examples(
218
- examples=[
219
- ["assets/sample_input.mp3", "English", "French"],
220
- ["assets/sample_input.mp3", "English", "Mandarin Chinese"],
221
- ["assets/sample_input_2.mp3", "English", "Hindi"],
222
- ["assets/sample_input_2.mp3", "English", "Spanish"],
223
- ],
224
- inputs=[input_audio, source_language, target_language],
225
- outputs=output_text,
226
- fn=run_s2tt,
227
- cache_examples=CACHE_EXAMPLES,
228
- api_name=False,
229
- )
230
-
231
- btn.click(
232
- fn=run_s2tt,
233
- inputs=[input_audio, source_language, target_language],
234
- outputs=output_text,
235
- api_name="s2tt",
236
- )
237
-
238
- with gr.Blocks() as demo_t2st:
239
- with gr.Row():
240
- with gr.Column():
241
- with gr.Group():
242
- input_text = gr.Textbox(label="Input text")
243
- with gr.Row():
244
- source_language = gr.Dropdown(
245
- label="Source language",
246
- choices=TEXT_SOURCE_LANGUAGE_NAMES,
247
- value="English",
248
- )
249
- target_language = gr.Dropdown(
250
- label="Target language",
251
- choices=T2ST_TARGET_LANGUAGE_NAMES,
252
- value=DEFAULT_TARGET_LANGUAGE,
253
- )
254
- btn = gr.Button("Translate")
255
- with gr.Column():
256
- with gr.Group():
257
- output_audio = gr.Audio(
258
- label="Translated speech",
259
- autoplay=False,
260
- streaming=False,
261
- type="numpy",
262
- )
263
- output_text = gr.Textbox(label="Translated text")
264
-
265
- gr.Examples(
266
- examples=[
267
- [
268
- "My favorite animal is the elephant.",
269
- "English",
270
- "French",
271
- ],
272
- [
273
- "My favorite animal is the elephant.",
274
- "English",
275
- "Mandarin Chinese",
276
- ],
277
- [
278
- "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
279
- "English",
280
- "Hindi",
281
- ],
282
- [
283
- "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
284
- "English",
285
- "Spanish",
286
- ],
287
- ],
288
- inputs=[input_text, source_language, target_language],
289
- outputs=[output_audio, output_text],
290
- fn=run_t2st,
291
- cache_examples=CACHE_EXAMPLES,
292
- api_name=False,
293
- )
294
-
295
- gr.on(
296
- triggers=[input_text.submit, btn.click],
297
- fn=run_t2st,
298
- inputs=[input_text, source_language, target_language],
299
- outputs=[output_audio, output_text],
300
- api_name="t2st",
301
- )
302
-
303
- with gr.Blocks() as demo_t2tt:
304
- with gr.Row():
305
- with gr.Column():
306
- with gr.Group():
307
- input_text = gr.Textbox(label="Input text")
308
- with gr.Row():
309
- source_language = gr.Dropdown(
310
- label="Source language",
311
- choices=TEXT_SOURCE_LANGUAGE_NAMES,
312
- value="English",
313
- )
314
- target_language = gr.Dropdown(
315
- label="Target language",
316
- choices=T2TT_TARGET_LANGUAGE_NAMES,
317
- value=DEFAULT_TARGET_LANGUAGE,
318
- )
319
- btn = gr.Button("Translate")
320
- with gr.Column():
321
- output_text = gr.Textbox(label="Translated text")
322
-
323
- gr.Examples(
324
- examples=[
325
- [
326
- "My favorite animal is the elephant.",
327
- "English",
328
- "French",
329
- ],
330
- [
331
- "My favorite animal is the elephant.",
332
- "English",
333
- "Mandarin Chinese",
334
- ],
335
- [
336
- "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
337
- "English",
338
- "Hindi",
339
- ],
340
- [
341
- "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
342
- "English",
343
- "Spanish",
344
- ],
345
- ],
346
- inputs=[input_text, source_language, target_language],
347
- outputs=output_text,
348
- fn=run_t2tt,
349
- cache_examples=CACHE_EXAMPLES,
350
- api_name=False,
351
- )
352
-
353
- gr.on(
354
- triggers=[input_text.submit, btn.click],
355
- fn=run_t2tt,
356
- inputs=[input_text, source_language, target_language],
357
- outputs=output_text,
358
- api_name="t2tt",
359
- )
360
-
361
- with gr.Blocks() as demo_asr:
362
- with gr.Row():
363
- with gr.Column():
364
- with gr.Group():
365
- input_audio = gr.Audio(label="Input speech", type="filepath")
366
- target_language = gr.Dropdown(
367
- label="Target language",
368
- choices=ASR_TARGET_LANGUAGE_NAMES,
369
- value=DEFAULT_TARGET_LANGUAGE,
370
- )
371
- btn = gr.Button("Translate")
372
- with gr.Column():
373
- output_text = gr.Textbox(label="Translated text")
374
-
375
- gr.Examples(
376
- examples=[
377
- ["assets/sample_input.mp3", "English"],
378
- ["assets/sample_input_2.mp3", "English"],
379
- ],
380
- inputs=[input_audio, target_language],
381
- outputs=output_text,
382
- fn=run_asr,
383
- cache_examples=CACHE_EXAMPLES,
384
- api_name=False,
385
- )
386
-
387
- btn.click(
388
- fn=run_asr,
389
- inputs=[input_audio, target_language],
390
- outputs=output_text,
391
- api_name="asr",
392
- )
393
-
394
-
395
- with gr.Blocks(css="style.css") as demo:
396
- gr.Markdown(DESCRIPTION)
397
- gr.DuplicateButton(
398
- value="Duplicate Space for private use",
399
- elem_id="duplicate-button",
400
- visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
401
- )
402
-
403
- with gr.Tabs():
404
- with gr.Tab(label="S2ST"):
405
- demo_s2st.render()
406
- with gr.Tab(label="S2TT"):
407
- demo_s2tt.render()
408
- with gr.Tab(label="T2ST"):
409
- demo_t2st.render()
410
- with gr.Tab(label="T2TT"):
411
- demo_t2tt.render()
412
- with gr.Tab(label="ASR"):
413
- demo_asr.render()
414
 
 
 
415
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
416
  if __name__ == "__main__":
417
- demo.queue(max_size=50).launch()
 
 
 
 
 
 
 
 
1
  import torch
2
+ import torch.nn.functional as F
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
+ from einops import rearrange
5
+ import gradio as gr
6
 
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM
8
+
9
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
10
+
11
+ device = "cuda"
12
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
13
+ model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device=device, dtype=torch.float16)
14
+
15
+ def pred(text_in):
16
+ tokens = tokenizer(text_in, return_tensors="pt")
17
+ input_ids = tokens.input_ids.to(device=device)
18
+ attn_mask = tokens.attention_mask.to(device=device)
19
+ max_length = input_ids.shape[1] + 100
20
+ fn = lambda: model.generate(
21
+ input_ids=input_ids,
22
+ max_length=max_length,
23
+ cg=True,
24
+ return_dict_in_generate=True,
25
+ output_scores=True,
26
+ enable_timing=False,
27
+ temperature=1.0,
28
+ top_k=1,
29
+ top_p=1.0,
30
+ )
31
+ out = fn()
32
+ text_out = tokenizer.batch_decode(out.sequences.tolist())
33
+ return text_out
34
+
35
+ demo = gr.Interface(fn=pred, inputs="text", outputs="text")
36
+
37
  if __name__ == "__main__":
38
+ demo.launch()