radames (HF staff) committed
Commit 4a67689
1 Parent(s): 3e8007d
app.py DELETED
@@ -1,215 +0,0 @@
-from __future__ import annotations
-
-import gradio as gr
-import numpy as np
-
-import asyncio
-from simuleval_transcoder import SimulevalTranscoder, logger
-
-import time
-from simuleval.utils.agent import build_system_from_dir
-import torch
-
-
-language_code_to_name = {
-    "cmn": "Mandarin Chinese",
-    "deu": "German",
-    "eng": "English",
-    "fra": "French",
-    "spa": "Spanish",
-}
-S2ST_TARGET_LANGUAGE_NAMES = language_code_to_name.values()
-LANGUAGE_NAME_TO_CODE = {v: k for k, v in language_code_to_name.items()}
-
-DEFAULT_TARGET_LANGUAGE = "English"
-
-
-def build_agent(model_path, config_name=None):
-    agent = build_system_from_dir(
-        model_path, config_name=config_name,
-    )
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    agent.to(device, fp16=True)
-
-    return agent
-
-agent = build_agent("models", "vad_s2st_sc_24khz_main.yaml")
-transcoder = SimulevalTranscoder(
-    agent,
-    sample_rate=48_000,
-    debug=False,
-    buffer_limit=1,
-)
-
-def start_recording():
-    logger.debug(f"start_recording: starting transcoder")
-    transcoder.reset_states()
-    transcoder.close = False
-    transcoder.start()
-
-def stop_recording():
-    transcoder.close = True
-
-class MyState:
-    def __init__(self):
-        self.queue = asyncio.Queue()
-        self.close = False
-
-
-s = MyState()
-
-def process_incoming_bytes(audio):
-    logger.debug(f"process_bytes: incoming audio")
-    sample_rate, data = audio
-    transcoder.process_incoming_bytes(data.tobytes(), 'eng', sample_rate)
-    s.queue.put_nowait(audio)
-
-
-
-def get_buffered_output():
-
-    speech_and_text_output = transcoder.get_buffered_output()
-    if speech_and_text_output is None:
-        logger.debug("No output from transcoder.get_buffered_output()")
-        return None, None, None
-
-    logger.debug(f"We DID get output from the transcoder!")
-
-    text = None
-    speech = None
-
-    if speech_and_text_output.speech_samples:
-        speech = (speech_and_text_output.speech_sample_rate, speech_and_text_output.speech_samples)
-
-    if speech_and_text_output.text:
-        text = speech_and_text_output.text
-        if speech_and_text_output.final:
-            text += "\n"
-
-    return speech, text, speech_and_text_output.final
-
-from scipy.io.wavfile import write as scipy_write
-def streaming_input_callback():
-    final = False
-    max_wait_s = 15
-    wait_s = 0
-    translated_text_state = ""
-    sample_rate = 24000
-    while not transcoder.close:
-        translated_wav_segment, translated_text, final = get_buffered_output()
-
-        if translated_wav_segment is None and translated_text is None:
-            time.sleep(0.3)
-            wait_s += 0.3
-            if wait_s >= max_wait_s:
-                transcoder.close = True
-            continue
-        wait_s = 0
-        if translated_wav_segment is not None:
-            sample_rate, audio_bytes = translated_wav_segment
-            print("output sample rate", sample_rate)
-            translated_wav_segment = sample_rate, np.array(audio_bytes)
-        else:
-            translated_wav_segment = sample_rate, np.empty(0, dtype=np.int16)
-
-        if translated_text is not None:
-            translated_text_state += " | " + str(translated_text)
-
-        stream_output_text = translated_text_state
-        if translated_text is not None:
-            print("translated:", translated_text_state)
-        yield [
-            translated_wav_segment,
-            stream_output_text,
-            translated_text_state,
-        ]
-
-
-def streaming_callback_dummy():
-    i = 0
-    out_text = ""
-    while not transcoder.close:
-        if s.queue.empty():
-            yield (
-                (48000, np.empty(0, dtype=np.int16)), out_text, out_text
-            )
-            time.sleep(0.3)
-        else:
-            i += 1
-            out_text += " | " + str(i)
-            print(out_text)
-            audio = s.queue.get_nowait()
-            if i == 0:
-                print(audio[0], type(audio[1]))
-            s.queue.task_done()
-            yield audio, out_text, out_text
-
-def clear():
-    logger.debug(f"Clearing State")
-    return [bytes(), ""]
-
-
-def blocks():
-    with gr.Blocks() as demo:
-
-        with gr.Row():
-            # TODO: add target language switching
-            target_language = gr.Dropdown(
-                label="Target language",
-                choices=S2ST_TARGET_LANGUAGE_NAMES,
-                value=DEFAULT_TARGET_LANGUAGE,
-            )
-
-        translated_text_state = gr.State("")
-
-        input_audio = gr.Audio(
-            label="Input Audio",
-            sources=["microphone"],
-            streaming=True,
-        )
-
-        output_translation_segment = gr.Audio(
-            label="Translated audio segment",
-            autoplay=True,
-            streaming=True,
-        )
-
-        # Output text segment
-        stream_output_text = gr.Textbox(label="Translated text")
-
-        input_audio.clear(
-            clear, None, [output_translation_segment, translated_text_state]
-        )
-        input_audio.start_recording(
-            clear, None, [output_translation_segment, translated_text_state]
-        ).then(
-            start_recording
-        ).then(
-            # TODO: streaming speech autoplay works fine with streaming_callback_dummy,
-            # but speech output from streaming_input_callback has a huge delay
-            # when comparing print/debugging logs vs. output speech
-            # TODO: text output works fine with one output, but is not
-            # updating when output is both text + speech
-            # streaming_callback_dummy,
-            streaming_input_callback,
-            None,
-            [
-                output_translation_segment,
-                stream_output_text,
-                translated_text_state,
-            ]
-        )
-        input_audio.stop_recording(
-            stop_recording
-        )
-        input_audio.stream(
-            # TODO: *only when streaming speech output* about half the time
-            # there is some race condition in gradio where process_incoming_bytes
-            # stops getting called once the first speech chunk is yield-ed
-            # in streaming_input_callback (or streaming_callback_dummy)
-            process_incoming_bytes, [input_audio], None
-        )
-
-    demo.launch(server_port=6010)
-
-blocks()

assets/sample_input.mp3 DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:982369687f05bf8fcd6923c4ffcccda0fcce92f44eceae5a9d00a431f07ea87b
-size 10272

assets/sample_input_2.mp3 DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:6a505a4641e3f5f0ddec9508832793aa20e63d2545530b66bc04a9bd19a742e6
-size 30624

lang_list.py DELETED
@@ -1,254 +0,0 @@
-# Language dict
-language_code_to_name = {
-    "afr": "Afrikaans",
-    "amh": "Amharic",
-    "arb": "Modern Standard Arabic",
-    "ary": "Moroccan Arabic",
-    "arz": "Egyptian Arabic",
-    "asm": "Assamese",
-    "ast": "Asturian",
-    "azj": "North Azerbaijani",
-    "bel": "Belarusian",
-    "ben": "Bengali",
-    "bos": "Bosnian",
-    "bul": "Bulgarian",
-    "cat": "Catalan",
-    "ceb": "Cebuano",
-    "ces": "Czech",
-    "ckb": "Central Kurdish",
-    "cmn": "Mandarin Chinese",
-    "cym": "Welsh",
-    "dan": "Danish",
-    "deu": "German",
-    "ell": "Greek",
-    "eng": "English",
-    "est": "Estonian",
-    "eus": "Basque",
-    "fin": "Finnish",
-    "fra": "French",
-    "gaz": "West Central Oromo",
-    "gle": "Irish",
-    "glg": "Galician",
-    "guj": "Gujarati",
-    "heb": "Hebrew",
-    "hin": "Hindi",
-    "hrv": "Croatian",
-    "hun": "Hungarian",
-    "hye": "Armenian",
-    "ibo": "Igbo",
-    "ind": "Indonesian",
-    "isl": "Icelandic",
-    "ita": "Italian",
-    "jav": "Javanese",
-    "jpn": "Japanese",
-    "kam": "Kamba",
-    "kan": "Kannada",
-    "kat": "Georgian",
-    "kaz": "Kazakh",
-    "kea": "Kabuverdianu",
-    "khk": "Halh Mongolian",
-    "khm": "Khmer",
-    "kir": "Kyrgyz",
-    "kor": "Korean",
-    "lao": "Lao",
-    "lit": "Lithuanian",
-    "ltz": "Luxembourgish",
-    "lug": "Ganda",
-    "luo": "Luo",
-    "lvs": "Standard Latvian",
-    "mai": "Maithili",
-    "mal": "Malayalam",
-    "mar": "Marathi",
-    "mkd": "Macedonian",
-    "mlt": "Maltese",
-    "mni": "Meitei",
-    "mya": "Burmese",
-    "nld": "Dutch",
-    "nno": "Norwegian Nynorsk",
-    "nob": "Norwegian Bokm\u00e5l",
-    "npi": "Nepali",
-    "nya": "Nyanja",
-    "oci": "Occitan",
-    "ory": "Odia",
-    "pan": "Punjabi",
-    "pbt": "Southern Pashto",
-    "pes": "Western Persian",
-    "pol": "Polish",
-    "por": "Portuguese",
-    "ron": "Romanian",
-    "rus": "Russian",
-    "slk": "Slovak",
-    "slv": "Slovenian",
-    "sna": "Shona",
-    "snd": "Sindhi",
-    "som": "Somali",
-    "spa": "Spanish",
-    "srp": "Serbian",
-    "swe": "Swedish",
-    "swh": "Swahili",
-    "tam": "Tamil",
-    "tel": "Telugu",
-    "tgk": "Tajik",
-    "tgl": "Tagalog",
-    "tha": "Thai",
-    "tur": "Turkish",
-    "ukr": "Ukrainian",
-    "urd": "Urdu",
-    "uzn": "Northern Uzbek",
-    "vie": "Vietnamese",
-    "xho": "Xhosa",
-    "yor": "Yoruba",
-    "yue": "Cantonese",
-    "zlm": "Colloquial Malay",
-    "zsm": "Standard Malay",
-    "zul": "Zulu",
-}
-LANGUAGE_NAME_TO_CODE = {v: k for k, v in language_code_to_name.items()}
-
-# Source langs: S2ST / S2TT / ASR don't need source lang
-# T2TT / T2ST use this
-text_source_language_codes = [
-    "afr",
-    "amh",
-    "arb",
-    "ary",
-    "arz",
-    "asm",
-    "azj",
-    "bel",
-    "ben",
-    "bos",
-    "bul",
-    "cat",
-    "ceb",
-    "ces",
-    "ckb",
-    "cmn",
-    "cym",
-    "dan",
-    "deu",
-    "ell",
-    "eng",
-    "est",
-    "eus",
-    "fin",
-    "fra",
-    "gaz",
-    "gle",
-    "glg",
-    "guj",
-    "heb",
-    "hin",
-    "hrv",
-    "hun",
-    "hye",
-    "ibo",
-    "ind",
-    "isl",
-    "ita",
-    "jav",
-    "jpn",
-    "kan",
-    "kat",
-    "kaz",
-    "khk",
-    "khm",
-    "kir",
-    "kor",
-    "lao",
-    "lit",
-    "lug",
-    "luo",
-    "lvs",
-    "mai",
-    "mal",
-    "mar",
-    "mkd",
-    "mlt",
-    "mni",
-    "mya",
-    "nld",
-    "nno",
-    "nob",
-    "npi",
-    "nya",
-    "ory",
-    "pan",
-    "pbt",
-    "pes",
-    "pol",
-    "por",
-    "ron",
-    "rus",
-    "slk",
-    "slv",
-    "sna",
-    "snd",
-    "som",
-    "spa",
-    "srp",
-    "swe",
-    "swh",
-    "tam",
-    "tel",
-    "tgk",
-    "tgl",
-    "tha",
-    "tur",
-    "ukr",
-    "urd",
-    "uzn",
-    "vie",
-    "yor",
-    "yue",
-    "zsm",
-    "zul",
-]
-TEXT_SOURCE_LANGUAGE_NAMES = sorted([language_code_to_name[code] for code in text_source_language_codes])
-
-# Target langs:
-# S2ST / T2ST
-s2st_target_language_codes = [
-    "eng",
-    "arb",
-    "ben",
-    "cat",
-    "ces",
-    "cmn",
-    "cym",
-    "dan",
-    "deu",
-    "est",
-    "fin",
-    "fra",
-    "hin",
-    "ind",
-    "ita",
-    "jpn",
-    "kor",
-    "mlt",
-    "nld",
-    "pes",
-    "pol",
-    "por",
-    "ron",
-    "rus",
-    "slk",
-    "spa",
-    "swe",
-    "swh",
-    "tel",
-    "tgl",
-    "tha",
-    "tur",
-    "ukr",
-    "urd",
-    "uzn",
-    "vie",
-]
-S2ST_TARGET_LANGUAGE_NAMES = sorted([language_code_to_name[code] for code in s2st_target_language_codes])
-
-# S2TT / ASR
-S2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
-# T2TT
-T2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES

m4t_app.py DELETED
@@ -1,463 +0,0 @@
-from __future__ import annotations
-
-import os
-
-import gradio as gr
-import numpy as np
-import torch
-import torchaudio
-from seamless_communication.models.inference.translator import Translator
-
-from lang_list import (
-    LANGUAGE_NAME_TO_CODE,
-    S2ST_TARGET_LANGUAGE_NAMES,
-    S2TT_TARGET_LANGUAGE_NAMES,
-    T2TT_TARGET_LANGUAGE_NAMES,
-    TEXT_SOURCE_LANGUAGE_NAMES,
-)
-
-DESCRIPTION = """# SeamlessM4T
-
-# mduppes aaaaaa
-
-[SeamlessM4T](https://github.com/facebookresearch/seamless_communication) is designed to provide high-quality
-translation, allowing people from different linguistic communities to communicate effortlessly through speech and text.
-
-This unified model enables multiple tasks like Speech-to-Speech (S2ST), Speech-to-Text (S2TT), Text-to-Speech (T2ST)
-translation and more, without relying on multiple separate models.
-"""
-
-CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES") == "1"
-
-TASK_NAMES = [
-    "S2ST (Speech to Speech translation)",
-    "S2TT (Speech to Text translation)",
-    "T2ST (Text to Speech translation)",
-    "T2TT (Text to Text translation)",
-    "ASR (Automatic Speech Recognition)",
-]
-AUDIO_SAMPLE_RATE = 16000.0
-MAX_INPUT_AUDIO_LENGTH = 60  # in seconds
-DEFAULT_TARGET_LANGUAGE = "French"
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-print("DEVICE", device)
-translator = Translator(
-    model_name_or_card="seamlessM4T_medium",
-    vocoder_name_or_card="vocoder_36langs",
-    device=device,
-    # dtype=torch.float16,
-    # For CPU Mode need to use 32, float16 causes errors downstream
-    dtype=torch.float32,
-)
-
-def get_translator():
-    return translator
-
-
-def transcribe(audio):
-    print(audio)
-    text = p(audio)["text"]
-    return text
-
-def transcribe_state(audio, state = ""):
-    print(audio)
-    text = p(audio)["text"]
-    state += text + " "
-    return state, state
-
-
-def predict(
-    task_name: str,
-    audio_source: str,
-    input_audio_mic: str | None,
-    input_audio_file: str | None,
-    input_text: str | None,
-    source_language: str | None,
-    target_language: str,
-) -> tuple[tuple[int, np.ndarray] | None, str]:
-    task_name = task_name.split()[0]
-    source_language_code = LANGUAGE_NAME_TO_CODE[source_language] if source_language else None
-    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
-
-    if task_name in ["S2ST", "S2TT", "ASR"]:
-        if audio_source == "microphone":
-            input_data = input_audio_mic
-        else:
-            input_data = input_audio_file
-
-        arr, org_sr = torchaudio.load(input_data)
-        print(task_name, audio_source, input_audio_mic, type(input_audio_file), type(input_text), source_language, target_language)
-        new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE)
-        max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
-        if new_arr.shape[1] > max_length:
-            new_arr = new_arr[:, :max_length]
-            gr.Warning(f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used.")
-        torchaudio.save(input_data, new_arr, sample_rate=int(AUDIO_SAMPLE_RATE))
-    else:
-        input_data = input_text
-    text_out, wav, sr = translator.predict(
-        input=input_data,
-        task_str=task_name,
-        tgt_lang=target_language_code,
-        src_lang=source_language_code,
-        ngram_filtering=True,
-        sample_rate=AUDIO_SAMPLE_RATE,
-    )
-    print("translation response", text_out, wav, sr)
-    # text_out = "Testing"
-    # return None, text_out
-    if task_name in ["S2ST", "T2ST"]:
-        return (sr, wav.cpu().detach().numpy()), text_out
-    else:
-        return None, text_out
-
-
-def process_s2st_example(input_audio_file: str, target_language: str) -> tuple[tuple[int, np.ndarray] | None, str]:
-    return predict(
-        task_name="S2ST",
-        audio_source="file",
-        input_audio_mic=None,
-        input_audio_file=input_audio_file,
-        input_text=None,
-        source_language=None,
-        target_language=target_language,
-    )
-
-
-def process_s2tt_example(input_audio_file: str, target_language: str) -> tuple[tuple[int, np.ndarray] | None, str]:
-    return predict(
-        task_name="S2TT",
-        audio_source="file",
-        input_audio_mic=None,
-        input_audio_file=input_audio_file,
-        input_text=None,
-        source_language=None,
-        target_language=target_language,
-    )
-
-
-def process_t2st_example(
-    input_text: str, source_language: str, target_language: str
-) -> tuple[tuple[int, np.ndarray] | None, str]:
-    return predict(
-        task_name="T2ST",
-        audio_source="",
-        input_audio_mic=None,
-        input_audio_file=None,
-        input_text=input_text,
-        source_language=source_language,
-        target_language=target_language,
-    )
-
-
-def process_t2tt_example(
-    input_text: str, source_language: str, target_language: str
-) -> tuple[tuple[int, np.ndarray] | None, str]:
-    return predict(
-        task_name="T2TT",
-        audio_source="",
-        input_audio_mic=None,
-        input_audio_file=None,
-        input_text=input_text,
-        source_language=source_language,
-        target_language=target_language,
-    )
-
-
-def process_asr_example(input_audio_file: str, target_language: str) -> tuple[tuple[int, np.ndarray] | None, str]:
-    return predict(
-        task_name="ASR",
-        audio_source="file",
-        input_audio_mic=None,
-        input_audio_file=input_audio_file,
-        input_text=None,
-        source_language=None,
-        target_language=target_language,
-    )
-
-
-def update_audio_ui(audio_source: str) -> tuple[dict, dict]:
-    mic = audio_source == "microphone"
-    return (
-        gr.update(visible=mic, value=None),  # input_audio_mic
-        gr.update(visible=not mic, value=None),  # input_audio_file
-    )
-
-
-def update_input_ui(task_name: str) -> tuple[dict, dict, dict, dict]:
-    task_name = task_name.split()[0]
-    if task_name == "S2ST":
-        return (
-            gr.update(visible=True),  # audio_box
-            gr.update(visible=False),  # input_text
-            gr.update(visible=False),  # source_language
-            gr.update(
-                visible=True, choices=S2ST_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
-            ),  # target_language
-        )
-    elif task_name == "S2TT":
-        return (
-            gr.update(visible=True),  # audio_box
-            gr.update(visible=False),  # input_text
-            gr.update(visible=False),  # source_language
-            gr.update(
-                visible=True, choices=S2TT_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
-            ),  # target_language
-        )
-    elif task_name == "T2ST":
-        return (
-            gr.update(visible=False),  # audio_box
-            gr.update(visible=True),  # input_text
-            gr.update(visible=True),  # source_language
-            gr.update(
-                visible=True, choices=S2ST_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
-            ),  # target_language
-        )
-    elif task_name == "T2TT":
-        return (
-            gr.update(visible=False),  # audio_box
-            gr.update(visible=True),  # input_text
-            gr.update(visible=True),  # source_language
-            gr.update(
-                visible=True, choices=T2TT_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
-            ),  # target_language
-        )
-    elif task_name == "ASR":
-        return (
-            gr.update(visible=True),  # audio_box
-            gr.update(visible=False),  # input_text
-            gr.update(visible=False),  # source_language
-            gr.update(
-                visible=True, choices=S2TT_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
-            ),  # target_language
-        )
-    else:
-        raise ValueError(f"Unknown task: {task_name}")
-
-
-def update_output_ui(task_name: str) -> tuple[dict, dict]:
-    task_name = task_name.split()[0]
-    if task_name in ["S2ST", "T2ST"]:
-        return (
-            gr.update(visible=True, value=None),  # output_audio
-            gr.update(value=None),  # output_text
-        )
-    elif task_name in ["S2TT", "T2TT", "ASR"]:
-        return (
-            gr.update(visible=False, value=None),  # output_audio
-            gr.update(value=None),  # output_text
-        )
-    else:
-        raise ValueError(f"Unknown task: {task_name}")
-
-
-def update_example_ui(task_name: str) -> tuple[dict, dict, dict, dict, dict]:
-    task_name = task_name.split()[0]
-    return (
-        gr.update(visible=task_name == "S2ST"),  # s2st_example_row
-        gr.update(visible=task_name == "S2TT"),  # s2tt_example_row
-        gr.update(visible=task_name == "T2ST"),  # t2st_example_row
-        gr.update(visible=task_name == "T2TT"),  # t2tt_example_row
-        gr.update(visible=task_name == "ASR"),  # asr_example_row
-    )
-
-def m4t_demo():
-
-    with gr.Blocks(css="style.css") as demo:
-        gr.Markdown(DESCRIPTION)
-        gr.DuplicateButton(
-            value="Duplicate Space for private use",
-            elem_id="duplicate-button",
-            visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
-        )
-
-        with gr.Group():
-            task_name = gr.Dropdown(
-                label="Task",
-                choices=TASK_NAMES,
-                value=TASK_NAMES[0],
-            )
-
-
-            with gr.Row():
-                source_language = gr.Dropdown(
-                    label="Source language",
-                    choices=TEXT_SOURCE_LANGUAGE_NAMES,
-                    value="English",
-                    visible=False,
-                )
-                target_language = gr.Dropdown(
-                    label="Target language",
-                    choices=S2ST_TARGET_LANGUAGE_NAMES,
-                    value=DEFAULT_TARGET_LANGUAGE,
-                )
-            with gr.Row() as audio_box:
-                audio_source = gr.Radio(
-                    label="Audio source",
-                    choices=["file", "microphone"],
-                    value="file",
-                )
-                input_audio_mic = gr.Audio(
-                    label="Input speech",
-                    type="filepath",
-                    source="microphone",
-                    visible=False,
-                )
-                input_audio_file = gr.Audio(
-                    label="Input speech",
-                    type="filepath",
-                    source="upload",
-                    visible=True,
-                )
-            input_text = gr.Textbox(label="Input text", visible=False)
-            btn = gr.Button("Translate")
-        with gr.Column():
-            output_audio = gr.Audio(
-                label="Translated speech",
-                autoplay=False,
-                streaming=False,
-                type="numpy",
-            )
-            output_text = gr.Textbox(label="Translated text")
-
-        with gr.Row(visible=True) as s2st_example_row:
-            s2st_examples = gr.Examples(
-                examples=[
-                    ["assets/sample_input.mp3", "French"],
-                    ["assets/sample_input.mp3", "Mandarin Chinese"],
-                    ["assets/sample_input_2.mp3", "Hindi"],
-                    ["assets/sample_input_2.mp3", "Spanish"],
-                ],
-                inputs=[input_audio_file, target_language],
-                outputs=[output_audio, output_text],
-                fn=process_s2st_example,
-                cache_examples=CACHE_EXAMPLES,
-            )
-        with gr.Row(visible=False) as s2tt_example_row:
-            s2tt_examples = gr.Examples(
-                examples=[
-                    ["assets/sample_input.mp3", "French"],
-                    ["assets/sample_input.mp3", "Mandarin Chinese"],
-                    ["assets/sample_input_2.mp3", "Hindi"],
-                    ["assets/sample_input_2.mp3", "Spanish"],
-                ],
-                inputs=[input_audio_file, target_language],
-                outputs=[output_audio, output_text],
-                fn=process_s2tt_example,
-                cache_examples=CACHE_EXAMPLES,
-            )
-        with gr.Row(visible=False) as t2st_example_row:
-            t2st_examples = gr.Examples(
-                examples=[
-                    ["My favorite animal is the elephant.", "English", "French"],
-                    ["My favorite animal is the elephant.", "English", "Mandarin Chinese"],
-                    [
-                        "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
-                        "English",
-                        "Hindi",
-                    ],
-                    [
-                        "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
-                        "English",
-                        "Spanish",
-                    ],
-                ],
-                inputs=[input_text, source_language, target_language],
-                outputs=[output_audio, output_text],
-                fn=process_t2st_example,
-                cache_examples=CACHE_EXAMPLES,
-            )
-        with gr.Row(visible=False) as t2tt_example_row:
-            t2tt_examples = gr.Examples(
-                examples=[
-                    ["My favorite animal is the elephant.", "English", "French"],
-                    ["My favorite animal is the elephant.", "English", "Mandarin Chinese"],
-                    [
-                        "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
-                        "English",
-                        "Hindi",
-                    ],
-                    [
-                        "Meta AI's Seamless M4T model is democratising spoken communication across language barriers",
-                        "English",
-                        "Spanish",
-                    ],
-                ],
-                inputs=[input_text, source_language, target_language],
-                outputs=[output_audio, output_text],
-                fn=process_t2tt_example,
-                cache_examples=CACHE_EXAMPLES,
-            )
-        with gr.Row(visible=False) as asr_example_row:
-            asr_examples = gr.Examples(
-                examples=[
-                    ["assets/sample_input.mp3", "English"],
-                    ["assets/sample_input_2.mp3", "English"],
-                ],
-                inputs=[input_audio_file, target_language],
-                outputs=[output_audio, output_text],
-                fn=process_asr_example,
-                cache_examples=CACHE_EXAMPLES,
-            )
-
-        audio_source.change(
-            fn=update_audio_ui,
-            inputs=audio_source,
-            outputs=[
-                input_audio_mic,
-                input_audio_file,
-            ],
-            queue=False,
-            api_name=False,
-        )
-        task_name.change(
-            fn=update_input_ui,
-            inputs=task_name,
-            outputs=[
-                audio_box,
-                input_text,
-                source_language,
-                target_language,
-            ],
-            queue=False,
-            api_name=False,
-        ).then(
-            fn=update_output_ui,
-            inputs=task_name,
-            outputs=[output_audio, output_text],
-            queue=False,
-            api_name=False,
-        ).then(
-            fn=update_example_ui,
-            inputs=task_name,
-            outputs=[
-                s2st_example_row,
-                s2tt_example_row,
-                t2st_example_row,
-                t2tt_example_row,
-                asr_example_row,
-            ],
-            queue=False,
-            api_name=False,
-        )
-
-        btn.click(
-            fn=predict,
-            inputs=[
-                task_name,
-                audio_source,
-                input_audio_mic,
-                input_audio_file,
-                input_text,
-                source_language,
-                target_language,
-            ],
-            outputs=[output_audio, output_text],
-            api_name="run",
-        )
-    demo.queue(max_size=50).launch()
-
-# Linking models to the space
-# 'facebook/seamless-m4t-large'
-# 'facebook/SONAR'

models/vad_s2st_sc_24khz_main.yaml DELETED
@@ -1,24 +0,0 @@
-agent_class: seamless_communication.streaming.agents.mma_m4t_s2st.SeamlessS2STJointVADAgent
-# checkpoint: checkpoint_best.pt
-monotonic_decoder_model_name: seamless_streaming_monotonic_decoder
-unity_model_name: seamless_streaming_unity
-sentencepiece_model: spm_256k_nllb100.model
-
-task: s2st
-tgt_lang: "eng"
-min_unit_chunk_size: 50
-decision_threshold: 0.7
-no_early_stop: True
-block_ngrams: True
-vocoder_name: vocoder_pretssel
-wav2vec_yaml: wav2vec.yaml
-# min_starting_wait: 12
-# min_starting_wait_w2vbert: 192
-
-config_yaml: cfg_fbank_u2t.yaml
-vocoder_sample_rate: 24000
-upstream_idx: 1
-detokenize_only: True
-device: cuda:0
-max_len_a: 0
-max_len_b: 1000

requirements.txt DELETED
@@ -1,26 +0,0 @@
-# TODO: fairseq2 install is complicated so currently done outside
-
-# fairseq2==0.1.0
-
-# git+https://github.com/facebookresearch/seamless_communication
-# ./fairseq2
-# ./seamless_communication
-# comment this out to test fairseq1 first
-# git+https://github.com/facebookresearch/SimulEval.git
-gradio==3.41.0
-huggingface_hub==0.16.4
-# torch==2.1.0
-# torchaudio==2.0.2
-# transformers==4.32.1
-pydub
-g2p_en
-colorlog
-# git+ssh://git@github.com/facebookresearch/SimulEval.git
-
-# Can't import fairseq1 together.. causes conflict:
-#The conflict is caused by:
-# The user requested simuleval 1.1.0 (from git+ssh://****@github.com/facebookresearch/SimulEval.git@tree_pipeline)
-# seamless-communication 1.0.0 depends on simuleval 1.0.3.dev36+gd84fa60 (from git+https://github.com/mduppes/SimulEval.git@main)
-# From fairseq1 pipeline
-# git+ssh://git@github.com/fairinternal/fairseq-py.git@emma_incremental_decoder
-# git+ssh://git@github.com/facebookresearch/SimulEval.git@tree_pipeline

sample_wav.py DELETED
The diff for this file is too large to render. See raw diff
 
simuleval_transcoder.py DELETED
@@ -1,425 +0,0 @@
-
-from typing import Any, List, Tuple, Union, Optional
-import numpy as np
-import soundfile
-import io
-import asyncio
-from simuleval.agents.pipeline import TreeAgentPipeline
-from simuleval.agents.states import AgentStates
-from simuleval.data.segments import Segment, EmptySegment, SpeechSegment
-import threading
-import math
-import logging
-import sys
-from pathlib import Path
-import time
-from g2p_en import G2p
-import torch
-import traceback
-import time
-import random
-import colorlog
-
-
-MODEL_SAMPLE_RATE = 16_000
-
-logger = logging.getLogger(__name__)
-logger.propagate = False
-handler = colorlog.StreamHandler(stream=sys.stdout)
-formatter = colorlog.ColoredFormatter(
-    "%(log_color)s[%(asctime)s][%(levelname)s][%(module)s]:%(reset)s %(message)s",
-    reset=True,
-    log_colors={
-        "DEBUG": "cyan",
-        "INFO": "green",
-        "WARNING": "yellow",
-        "ERROR": "red",
-        "CRITICAL": "red,bg_white",
-    },
-)
-handler.setFormatter(formatter)
-logger.addHandler(handler)
-logger.setLevel(logging.DEBUG)
-
-
-class SpeechAndTextOutput:
-    def __init__(
-        self,
-        text: str = None,
-        speech_samples: list = None,
-        speech_sample_rate: float = None,
-        final: bool = False,
-    ):
-        self.text = text
-        self.speech_samples = speech_samples
-        self.speech_sample_rate = speech_sample_rate
-        self.final = final
-
-class OutputSegments:
-    def __init__(self, segments: Union[List[Segment], Segment]):
-        if isinstance(segments, Segment):
-            segments = [segments]
-        self.segments: List[Segment] = [s for s in segments]
-
-    @property
-    def is_empty(self):
-        return all(segment.is_empty for segment in self.segments)
-
-    @property
-    def finished(self):
-        return all(segment.finished for segment in self.segments)
-
-    def compute_length(self, g2p):
-        lengths = []
-        for segment in self.segments:
-            if segment.data_type == "text":
-                lengths.append(len([x for x in g2p(segment.content) if x != " "]))
-            elif segment.data_type == "speech":
-                lengths.append(len(segment.content) / MODEL_SAMPLE_RATE)
-            elif isinstance(segment, EmptySegment):
-                continue
-            else:
-                logger.warning(
-                    f"Unexpected data_type: {segment.data_type} not in 'speech', 'text'"
-                )
-        return max(lengths)
-
-    @classmethod
-    def join_output_buffer(
-        cls, buffer: List[List[Segment]], output: SpeechAndTextOutput
-    ):
-        num_segments = len(buffer[0])
-        for i in range(num_segments):
-            segment_list = [
-                buffer[j][i]
-                for j in range(len(buffer))
-                if buffer[j][i].data_type is not None
-            ]
-            if len(segment_list) == 0:
-                continue
-            if len(set(segment.data_type for segment in segment_list)) != 1:
-                logger.warning(
-                    f"Data type mismatch at {i}: {set(segment.data_type for segment in segment_list)}"
-                )
-                continue
-            data_type = segment_list[0].data_type
-            if data_type == "text":
-                if output.text is not None:
-                    logger.warning("Multiple text outputs, overwriting!")
-                output.text = " ".join([segment.content for segment in segment_list])
-            elif data_type == "speech":
-                if output.speech_samples is not None:
-                    logger.warning("Multiple speech outputs, overwriting!")
-                speech_out = []
-                for segment in segment_list:
-                    speech_out += segment.content
-                output.speech_samples = speech_out
-                output.speech_sample_rate = segment.sample_rate
-            elif isinstance(segment_list[0], EmptySegment):
-                continue
-            else:
-                logger.warning(
-                    f"Invalid output buffer data type: {data_type}, expected 'speech' or 'text"
-                )
-
-        return output
-
-    def __repr__(self) -> str:
-        repr_str = str(self.segments)
-        return f"{self.__class__.__name__}(\n\t{repr_str}\n)"
-
-
-def convert_waveform(
-    waveform: Union[np.ndarray, torch.Tensor],
-    sample_rate: int,
-    normalize_volume: bool = False,
-    to_mono: bool = False,
-    to_sample_rate: Optional[int] = None,
-) -> Tuple[Union[np.ndarray, torch.Tensor], int]:
-    """convert a waveform:
-    - to a target sample rate
-    - from multi-channel to mono channel
-    - volume normalization
-
-    Args:
-        waveform (numpy.ndarray or torch.Tensor): 2D original waveform
-            (channels x length)
-        sample_rate (int): original sample rate
-        normalize_volume (bool): perform volume normalization
-        to_mono (bool): convert to mono channel if having multiple channels
-        to_sample_rate (Optional[int]): target sample rate
-    Returns:
-        waveform (numpy.ndarray): converted 2D waveform (channels x length)
-        sample_rate (float): target sample rate
-    """
-    try:
-        import torchaudio.sox_effects as ta_sox
-    except ImportError:
-        raise ImportError("Please install torchaudio: pip install torchaudio")
-
-    effects = []
-    if normalize_volume:
-        effects.append(["gain", "-n"])
-    if to_sample_rate is not None and to_sample_rate != sample_rate:
-        effects.append(["rate", f"{to_sample_rate}"])
-    if to_mono and waveform.shape[0] > 1:
-        effects.append(["channels", "1"])
-    if len(effects) > 0:
-        is_np_input = isinstance(waveform, np.ndarray)
-        _waveform = torch.from_numpy(waveform) if is_np_input else waveform
-        converted, converted_sample_rate = ta_sox.apply_effects_tensor(
-            _waveform, sample_rate, effects
-        )
-        if is_np_input:
-            converted = converted.numpy()
-        return converted, converted_sample_rate
-    return waveform, sample_rate
-
-class SimulevalTranscoder:
-    def __init__(self, agent, sample_rate, debug, buffer_limit):
-        # agent is stateless
-        self.agent = agent
-        self.input_queue = asyncio.Queue()
-        self.output_queue = asyncio.Queue()
-        self.states = self.agent.build_states()
-        if debug:
-            self.get_states_root().debug = True
-        self.incoming_sample_rate = sample_rate
-        self.close = False
-        self.g2p = G2p()
-
-        # buffer all outgoing translations within this amount of time
-        self.output_buffer_idle_ms = 5000
-        self.output_buffer_size_limit = (
-            buffer_limit  # phonemes for text, seconds for speech
-        )
-        self.output_buffer_cur_size = 0
-        self.output_buffer: List[List[Segment]] = []
-        self.speech_output_sample_rate = None
-
-        self.last_output_ts = time.time() * 1000
-        self.timeout_ms = (
-            30000  # close the transcoder thread after this amount of silence
-        )
-        self.first_input_ts = None
-        self.first_output_ts = None
-        self.debug = debug
-        self.debug_ts = f"{time.time()}_{random.randint(1000, 9999)}"
-        if self.debug:
-            debug_folder = Path(__file__).resolve().parent.parent / "debug"
-            self.test_incoming_wav = soundfile.SoundFile(
-                debug_folder / f"{self.debug_ts}_test_incoming.wav",
-                mode="w+",
-                format="WAV",
-                subtype="PCM_16",
-                samplerate=self.incoming_sample_rate,
-                channels=1,
-            )
-            self.get_states_root().test_input_segments_wav = soundfile.SoundFile(
-                debug_folder / f"{self.debug_ts}_test_input_segments.wav",
-                mode="w+",
-                format="WAV",
-                samplerate=MODEL_SAMPLE_RATE,
-                channels=1,
-            )
-
-    def get_states_root(self) -> AgentStates:
-        if isinstance(self.agent, TreeAgentPipeline):
-            # self.states is a dict
-            return self.states[self.agent.source_module]
-        else:
-            # self.states is a list
-            return self.states[0]
-
-    def reset_states(self):
-        if isinstance(self.agent, TreeAgentPipeline):
-            states_iter = self.states.values()
-        else:
-            states_iter = self.states
-        for state in states_iter:
-            state.reset()
-
-    def debug_log(self, *args):
-        if self.debug:
-            logger.info(*args)
-
-    def process_incoming_bytes(self, incoming_bytes, target_language, sample_rate):
-        # TODO: currently just taking sample rate here, refactor sample rate
-        # bytes is 16bit signed int
-        self.incoming_sample_rate = sample_rate
-        segment, sr = self._preprocess_wav(incoming_bytes)
-
-        segment = SpeechSegment(
-            content=segment, sample_rate=sr, tgt_lang=target_language
-        )
-        # # segment is array([0, 0, 0, ..., 0, 0, 0], dtype=int16)
-        self.input_queue.put_nowait(segment)
-        print("process_incoming: put input_queue")
-
-    def get_input_segment(self):
-        if self.input_queue.empty():
-            return None
-        chunk = self.input_queue.get_nowait()
-        self.input_queue.task_done()
-        return chunk
-
-    def _preprocess_wav(self, data: Any) -> Tuple[np.ndarray, int]:
-        segment, sample_rate = soundfile.read(
-            io.BytesIO(data),
-            dtype="float32",
-            always_2d=True,
-            frames=-1,
-            start=0,
-            format="RAW",
-            subtype="PCM_16",
-            samplerate=self.incoming_sample_rate,
-            channels=1,
-        )
-        if self.debug:
-            self.test_incoming_wav.seek(0, soundfile.SEEK_END)
-            self.test_incoming_wav.write(segment)
-
-        segment = segment.T
-        segment, new_sample_rate = convert_waveform(
-            segment,
-            sample_rate,
-            normalize_volume=False,
-            to_mono=True,
-            to_sample_rate=MODEL_SAMPLE_RATE,
-        )
-
-        assert MODEL_SAMPLE_RATE == new_sample_rate
-        segment = segment.squeeze(axis=0)
-        return segment, new_sample_rate
-
-    def process_pipeline_impl(self, input_segment):
-        try:
-            with torch.no_grad():
-                output_segment = OutputSegments(
-                    self.agent.pushpop(input_segment, self.states)
-                )
-            if (
-                self.get_states_root().first_input_ts is not None
-                and self.first_input_ts is None
-            ):
-                # TODO: this is hacky
-                self.first_input_ts = self.get_states_root().first_input_ts
-
-            if not output_segment.is_empty:
-                print("PUT IN OUTPUT QUEUE")
-                self.output_queue.put_nowait(output_segment)
-
-            if output_segment.finished:
-                print("OUTPUT SEGMENT IS FINISHED. Resetting states.")
-
-                self.reset_states()
-
-                if self.debug:
-                    # when we rebuild states, this value is reset to whatever
-                    # is in the system dir config, which defaults debug=False.
-                    self.get_states_root().debug = True
-        except Exception as e:
-            logger.error(f"Got exception while processing pipeline: {e}")
-            traceback.print_exc()
-        return input_segment
-
-    def process_pipeline_loop(self):
-        if self.close:
-            print("transcoder closed")
-            return  # closes the thread
-
-        print("processing_pipeline")
-        while not self.close:
-            input_segment = self.get_input_segment()
-            if input_segment is None:
-                if self.get_states_root().is_fresh_state:  # TODO: this is hacky
-                    time.sleep(0.3)
-                    print("loop: input_queue empty")
-                else:
-                    time.sleep(0.03)
-                continue
-            print("loop: got input_segment")
-            self.process_pipeline_impl(input_segment)
-        print("finished processing_pipeline")
-
-    def process_pipeline_once(self):
-        if self.close:
-            return
-
-        self.debug_log("processing pipeline once")
-        input_segment = self.get_input_segment()
-        if input_segment is None:
-            return
-        self.process_pipeline_impl(input_segment)
-        self.debug_log("finished processing_pipeline_once")
-
-    def get_output_segment(self):
-        if self.output_queue.empty():
-            return None
-
-        output_chunk = self.output_queue.get_nowait()
-        self.output_queue.task_done()
-        return output_chunk
-
-    def start(self):
-        print("starting transcoder in a thread")
-        threading.Thread(target=self.process_pipeline_loop).start()
-
-    def first_translation_time(self):
-        return round((self.first_output_ts - self.first_input_ts) / 1000, 2)
-
-    def get_buffered_output(self) -> SpeechAndTextOutput:
-        now = time.time() * 1000
-        print(f"get_buffered_output queue size: {self.output_queue.qsize()}")
-        while not self.output_queue.empty():
-            tmp_out = self.get_output_segment()
-            if tmp_out and tmp_out.compute_length(self.g2p) > 0:
-                if len(self.output_buffer) == 0:
-                    self.last_output_ts = now
-                self._populate_output_buffer(tmp_out)
-                self._increment_output_buffer_size(tmp_out)
-
-                if tmp_out.finished:
-                    self.debug_log("tmp_out.finished")
-                    res = self._gather_output_buffer_data(final=True)
-                    self.debug_log(f"gathered output data: {res}")
-                    self.output_buffer = []
-                    self.increment_output_buffer_size = 0
-                    self.last_output_ts = now
-                    self.first_output_ts = now
-                    return res
-            else:
-                self.debug_log("tmp_out.compute_length is not > 0")
-
-        if len(self.output_buffer) > 0 and (
-            now - self.last_output_ts >= self.output_buffer_idle_ms
-            or self.output_buffer_cur_size >= self.output_buffer_size_limit
-        ):
-            self.debug_log(
-                "[get_buffered_output] output_buffer is not empty. getting res to return."
-            )
-            self.last_output_ts = now
-            res = self._gather_output_buffer_data(final=False)
-            self.debug_log(f"gathered output data: {res}")
-            self.output_buffer = []
-            self.output_buffer_phoneme_count = 0
-            self.first_output_ts = now
-            return res
-        else:
-            self.debug_log("[get_buffered_output] output_buffer is empty...")
-            return None
-
-    def _gather_output_buffer_data(self, final):
-        output = SpeechAndTextOutput()
-        output.final = final
-        output = OutputSegments.join_output_buffer(self.output_buffer, output)
-        return output
-
-    def _increment_output_buffer_size(self, segment: OutputSegments):
-        self.output_buffer_cur_size += segment.compute_length(self.g2p)
-
-    def _populate_output_buffer(self, segment: OutputSegments):
-        self.output_buffer.append(segment.segments)
-
-    def _compute_phoneme_count(self, string: str) -> int:
-        return len([x for x in self.g2p(string) if x != " "])

style.css DELETED
@@ -1,16 +0,0 @@
-h1 {
-  text-align: center;
-}
-
-#duplicate-button {
-  margin: auto;
-  color: #fff;
-  background: #1565c0;
-  border-radius: 100vh;
-}
-
-#component-0 {
-  max-width: 730px;
-  margin: auto;
-  padding-top: 1.5rem;
-}