Anna Sun committed on
Commit c1e0588
1 Parent(s): 7fb1760

Model fixes

Files changed (3)
  1. app.py +102 -146
  2. models/vad_s2st_sc_24khz_main.yaml +24 -0
  3. simuleval_transcoder.py +18 -49
app.py CHANGED
@@ -1,24 +1,15 @@
  from __future__ import annotations

- import os
-
  import gradio as gr
  import numpy as np
- import torch
- import torchaudio
- import sys
- from sample_wav import sample_wav
- np.set_printoptions(threshold=sys.maxsize)

- from simuleval_transcoder import *
+ import asyncio
+ from simuleval_transcoder import SimulevalTranscoder, logger

- from pydub import AudioSegment
  import time
- from time import sleep
+ from simuleval.utils.agent import build_system_from_dir
+ import torch

- from seamless_communication.cli.streaming.agents.tt_waitk_unity_s2t_m4t import (
-     TestTimeWaitKUnityS2TM4T,
- )

  language_code_to_name = {
      "cmn": "Mandarin Chinese",
@@ -32,7 +23,17 @@ LANGUAGE_NAME_TO_CODE = {v: k for k, v in language_code_to_name.items()}

  DEFAULT_TARGET_LANGUAGE = "English"

- # TODO: Update this so it takes in target langs from input, refactor sample rate
+
+ def build_agent(model_path, config_name=None):
+     agent = build_system_from_dir(
+         model_path, config_name=config_name,
+     )
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     agent.to(device, fp16=True)
+
+     return agent
+
+ agent = build_agent("models", "vad_s2st_sc_24khz_main.yaml")
  transcoder = SimulevalTranscoder(
      sample_rate=48_000,
      debug=False,
@@ -41,93 +42,97 @@ transcoder = SimulevalTranscoder(

  def start_recording():
      logger.debug(f"start_recording: starting transcoder")
+     transcoder.reset_states()
      transcoder.start()
+     transcoder.close = False

- def translate_audio_segment(audio):
-     logger.debug(f"translate_audio_segment: incoming audio")
-     sample_rate, data = audio
-
-     # print(sample_rate)
-     # print("--------- start \n")
-     # # print(data)
-     # def map(x):
-     #     return x
-     # print(data.tolist())
-     # print("--------- end \n")
-
+ def stop_recording():
+     transcoder.close = True
+
+ class MyState:
+     def __init__(self):
+         self.queue = asyncio.Queue()
+         self.close = False
+
+ s = MyState()
+
+ def process_incoming_bytes(audio):
+     logger.debug(f"process_bytes: incoming audio")
+     sample_rate, data = audio
      transcoder.process_incoming_bytes(data.tobytes(), 'eng', sample_rate)
+     s.queue.put_nowait(audio)
+
+
+ def get_buffered_output():

      speech_and_text_output = transcoder.get_buffered_output()
      if speech_and_text_output is None:
          logger.debug("No output from transcoder.get_buffered_output()")
-         return None, None
+         return None, None, None

-     logger.debug(f"We DID get output from the transcoder! {speech_and_text_output}")
+     logger.debug(f"We DID get output from the transcoder!")

      text = None
      speech = None

      if speech_and_text_output.speech_samples:
-         speech = (speech_and_text_output.speech_samples, speech_and_text_output.speech_sample_rate)
+         speech = (speech_and_text_output.speech_sample_rate, speech_and_text_output.speech_samples)

      if speech_and_text_output.text:
          text = speech_and_text_output.text
          if speech_and_text_output.final:
              text += "\n"

-     return speech, text
-
- def dummy_ouput():
-     np.array()
-
- def streaming_input_callback(
-     audio_file, translated_audio_bytes_state, translated_text_state
- ):
-     translated_wav_segment, translated_text = translate_audio_segment(audio_file)
-     logger.debug(f'translated_audio_bytes_state {translated_audio_bytes_state}')
-     logger.debug(f'translated_wav_segment {translated_wav_segment}')
-
-     # TODO: accumulate each segment to provide a continuous audio segment
-
-     # TEMP
-     translated_wav_segment = (46_000, sample_wav())
-
-     if translated_wav_segment is not None:
-         sample_rate, audio_bytes = translated_wav_segment
-         # TODO: convert to 16 bit int
-         # audio_np_array = np.frombuffer(audio_bytes, dtype=np.float32, count=3)
-         audio_np_array = audio_bytes
-
-         # combine translated wav
-         if type(translated_audio_bytes_state) is not tuple:
-             translated_audio_bytes_state = (sample_rate, audio_np_array)
-             # translated_audio_bytes_state = np.array([])
+     return speech, text, speech_and_text_output.final
+
+ def streaming_input_callback():
+     final = False
+     max_wait_s = 15
+     wait_s = 0
+     translated_text_state = ""
+     while not transcoder.close:
+         translated_wav_segment, translated_text, final = get_buffered_output()
+
+         if translated_wav_segment is None and translated_text is None:
+             time.sleep(0.3)
+             wait_s += 0.3
+             if wait_s >= max_wait_s:
+                 transcoder.close = True
+             continue
+         wait_s = 0
+         if translated_wav_segment is not None:
+             sample_rate, audio_bytes = translated_wav_segment
+             print("output sample rate", sample_rate)
+             translated_wav_segment = sample_rate, np.array(audio_bytes)
          else:
-             translated_audio_bytes_state = (translated_audio_bytes_state[0], np.append(translated_audio_bytes_state[1], translated_wav_segment[1]))
-
-     if translated_text is not None:
-         translated_text_state += " | " + str(translated_text)
-
-     # most_recent_input_audio_segment = (most_recent_input_audio_segment[0], np.append(most_recent_input_audio_segment[1], audio_file[1]))
-
-     # Not necessary but for readability.
-     most_recent_input_audio_segment = audio_file
-     translated_wav_segment = translated_wav_segment
-     output_translation_combined = translated_audio_bytes_state
-     stream_output_text = translated_text_state
-     return [
-         most_recent_input_audio_segment,
-         translated_wav_segment,
-         output_translation_combined,
-         stream_output_text,
-         translated_audio_bytes_state,
-         translated_text_state,
-     ]
-
+             translated_wav_segment = bytes()
+
+         if translated_text is not None:
+             translated_text_state += " | " + str(translated_text)
+
+         stream_output_text = translated_text_state
+         if translated_text is not None:
+             print("translated:", translated_text_state)
+         yield [
+             translated_wav_segment,
+             stream_output_text,
+             translated_text_state,
+         ]
+
+ def streaming_callback_dummy():
+     while not transcoder.close:
+         if s.queue.empty():
+             print("empty")
+             yield bytes()
+             time.sleep(0.3)
+         else:
+             print("audio")
+             audio = s.queue.get_nowait()
+             s.queue.task_done()
+             yield audio

  def clear():
      logger.debug(f"Clearing State")
@@ -138,105 +143,56 @@ def blocks():
      with gr.Blocks() as demo:

          with gr.Row():
-             # Hook this up once supported
+             # TODO: add target language switching
              target_language = gr.Dropdown(
                  label="Target language",
                  choices=S2ST_TARGET_LANGUAGE_NAMES,
                  value=DEFAULT_TARGET_LANGUAGE,
              )

-         translated_audio_bytes_state = gr.State(None)
          translated_text_state = gr.State("")

          input_audio = gr.Audio(
              label="Input Audio",
-             # source="microphone", # gradio==3.41.0
-             sources=["microphone"], # new gradio seems to call this less often...
+             sources=["microphone"],
              streaming=True,
          )

-         # input_audio = gr.Audio(
-         #     label="Input Audio",
-         #     type="filepath",
-         #     source="microphone",
-         #     streaming=True,
-         # )
-
-         most_recent_input_audio_segment = gr.Audio(
-             label="Recent Input Audio Segment segments",
-             # format="bytes",
-             streaming=True
-         )
-
-         # Force translate
-         stream_as_bytes_btn = gr.Button("Force translate most recent recording segment (ask for model output)")
          output_translation_segment = gr.Audio(
              label="Translated audio segment",
-             autoplay=False,
-             streaming=True,
-             type="numpy",
-         )
-
-         output_translation_combined = gr.Audio(
-             label="Translated audio combined",
-             autoplay=False,
+             autoplay=True,
              streaming=True,
-             type="numpy",
          )

-         # Could add output text segment
+         # Output text segment
          stream_output_text = gr.Textbox(label="Translated text")

-         stream_as_bytes_btn.click(
-             streaming_input_callback,
-             [input_audio, translated_audio_bytes_state, translated_text_state],
-             [
-                 most_recent_input_audio_segment,
-                 output_translation_segment,
-                 output_translation_combined,
-                 stream_output_text,
-                 translated_audio_bytes_state,
-                 translated_text_state,
-             ],
+         input_audio.clear(
+             clear, None, [output_translation_segment, translated_text_state]
          )
-
-         # input_audio.change(
-         #     streaming_input_callback,
-         #     [input_audio, translated_audio_bytes_state, translated_text_state],
-         #     [
-         #         most_recent_input_audio_segment,
-         #         output_translation_segment,
-         #         output_translation_combined,
-         #         stream_output_text,
-         #         translated_audio_bytes_state,
-         #         translated_text_state,
-         #     ],
-         # )
-
-         input_audio.stream(
+         input_audio.start_recording(
+             clear, None, [output_translation_segment, translated_text_state]
+         ).then(
+             start_recording
+         ).then(
+             # streaming_callback_dummy,  # TODO: autoplay works fine with streaming_callback_dummy
+             # None,
+             # output_translation_segment
              streaming_input_callback,
-             [input_audio, translated_audio_bytes_state, translated_text_state],
+             None,
              [
-                 most_recent_input_audio_segment,
                  output_translation_segment,
-                 output_translation_combined,
                  stream_output_text,
-                 translated_audio_bytes_state,
                  translated_text_state,
              ],
          )
-
-         input_audio.start_recording(
-             start_recording,
-         )
-
-         input_audio.clear(
-             clear, None, [translated_audio_bytes_state, translated_text_state]
+         input_audio.stop_recording(
+             stop_recording
          )
-         input_audio.start_recording(
-             clear, None, [translated_audio_bytes_state, translated_text_state]
+         input_audio.stream(
+             process_incoming_bytes, [input_audio], None
          )

-     demo.queue().launch()
+     demo.launch(server_port=6010)

  blocks()
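
Note on the wiring above: this commit moves app.py to Gradio's streaming generator pattern. input_audio.stream() is the producer, pushing each microphone chunk into the transcoder (and MyState's queue), while streaming_input_callback is a long-running generator, chained after start_recording, whose yields stream translated chunks into the autoplaying output gr.Audio. A minimal self-contained sketch of just that pattern, with a hypothetical passthrough queue standing in for the real transcoder:

    import queue

    import gradio as gr
    import numpy as np

    chunks = queue.Queue()        # producer/consumer hand-off, like MyState.queue above
    state = {"closed": False}

    def on_start():
        state["closed"] = False

    def on_stop():
        state["closed"] = True

    def push_chunk(audio):
        # Producer: Gradio calls this for every microphone chunk.
        chunks.put(audio)

    def emit_chunks():
        # Consumer: long-running generator; each yield streams one
        # (sample_rate, ndarray) chunk to the autoplaying output.
        while not state["closed"]:
            try:
                sample_rate, data = chunks.get(timeout=0.3)
            except queue.Empty:
                continue
            yield sample_rate, np.asarray(data)

    with gr.Blocks() as demo:
        mic = gr.Audio(sources=["microphone"], streaming=True)
        out = gr.Audio(label="Echo", streaming=True, autoplay=True)
        mic.start_recording(on_start).then(emit_chunks, None, out)
        mic.stop_recording(on_stop)
        mic.stream(push_chunk, [mic], None)

    demo.launch()
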
models/vad_s2st_sc_24khz_main.yaml ADDED
@@ -0,0 +1,24 @@
+ agent_class: seamless_communication.streaming.agents.mma_m4t_s2st.SeamlessS2STJointVADAgent
+ # checkpoint: checkpoint_best.pt
+ monotonic_decoder_model_name: seamless_streaming_monotonic_decoder
+ unity_model_name: seamless_streaming_unity
+ sentencepiece_model: spm_256k_nllb100.model
+
+ task: s2st
+ tgt_lang: "eng"
+ min_unit_chunk_size: 50
+ decision_threshold: 0.7
+ no_early_stop: True
+ block_ngrams: True
+ vocoder_name: vocoder_pretssel
+ wav2vec_yaml: wav2vec.yaml
+ # min_starting_wait: 12
+ # min_starting_wait_w2vbert: 192
+
+ config_yaml: cfg_fbank_u2t.yaml
+ vocoder_sample_rate: 24000
+ upstream_idx: 1
+ detokenize_only: True
+ device: cuda:0
+ max_len_a: 0
+ max_len_b: 1000
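
For reference, this is the config directory that build_agent() in app.py above consumes through simuleval's build_system_from_dir; a minimal sketch of that load path, mirroring the call already shown in the diff (no API beyond what the commit uses):

    import torch
    from simuleval.utils.agent import build_system_from_dir

    # Load the agent pipeline from the directory holding this YAML.
    agent = build_system_from_dir(
        "models", config_name="vad_s2st_sc_24khz_main.yaml",
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    agent.to(device, fp16=True)
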
simuleval_transcoder.py CHANGED
@@ -20,13 +20,6 @@ import time
  import random
  import colorlog

- # Sanity check that pipeline is loadable
- from seamless_communication.cli.streaming.agents.tt_waitk_unity_s2t_m4t import (
-     # TestTimeWaitKUnityS2TM4T,
-     TestTimeWaitKUnityS2TM4TVAD
- )
-
- from simuleval.utils.agent import build_system_args

  MODEL_SAMPLE_RATE = 16_000

@@ -49,35 +42,6 @@ logger.addHandler(handler)
  logger.setLevel(logging.DEBUG)


- # TODO: Integrate this better so target lang and others can be changed. Also currently dependent on devserver internals
- def build_agent():
-     config = {
-         'dataloader': 'fairseq2_s2t',
-         'data_file': '/large_experiments/seamless/ust/abinesh/data/s2st50_manifests/50-10/simuleval/dev_mtedx_filt_50-10_debug.tsv',
-         'model_name': 'seamlessM4T_v2_large',
-         'device': 'cuda:0',
-         'source_segment_size': 320,
-         'waitk_lagging': 7,
-         'fixed_pre_decision_ratio': 2,
-         'init_target_tokens': '</s> __eng__',
-         'max_len_a': 0,
-         'max_len_b': 200,
-         'agent_class': 'seamless_communication.cli.streaming.agents.tt_waitk_unity_s2t_m4t.TestTimeWaitKUnityS2TM4TVAD',
-         'task': 's2st',
-         'tgt_lang': 'eng',
-         'latency_metrics': 'StartOffset EndOffset AL',
-         'output': 'TestTimeWaitKUnityS2TM4TVAD-wait7-debug'
-     }
-
-     agent, _ = build_system_args(config)
-     # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     # agent.to(device, fp16=True)
-     logger.info(
-         f"Successfully built simuleval agent"
-     )
-
-     return agent
-
  class SpeechAndTextOutput:
      def __init__(
          self,
@@ -150,7 +114,7 @@ class OutputSegments:
              for segment in segment_list:
                  speech_out += segment.content
              output.speech_samples = speech_out
-             output.speech_sample_rate = MODEL_SAMPLE_RATE
+             output.speech_sample_rate = segment.sample_rate
          elif isinstance(segment_list[0], EmptySegment):
              continue
          else:
@@ -212,8 +176,9 @@ def convert_waveform(
      return waveform, sample_rate

  class SimulevalTranscoder:
-     def __init__(self, sample_rate, debug, buffer_limit):
-         self.agent = build_agent()
+     def __init__(self, agent, sample_rate, debug, buffer_limit):
+         # agent is stateless
+         self.agent = agent
          self.input_queue = asyncio.Queue()
          self.output_queue = asyncio.Queue()
          self.states = self.agent.build_states()
@@ -289,6 +254,7 @@ class SimulevalTranscoder:
          )
          # # segment is array([0, 0, 0, ..., 0, 0, 0], dtype=int16)
          self.input_queue.put_nowait(segment)
+         print("process_incoming: put input_queue")

      def get_input_segment(self):
          if self.input_queue.empty():
@@ -340,10 +306,11 @@ class SimulevalTranscoder:
          self.first_input_ts = self.get_states_root().first_input_ts

          if not output_segment.is_empty:
+             print("PUT IN OUTPUT QUEUE")
              self.output_queue.put_nowait(output_segment)

          if output_segment.finished:
-             self.debug_log("OUTPUT SEGMENT IS FINISHED. Resetting states.")
+             print("OUTPUT SEGMENT IS FINISHED. Resetting states.")

              self.reset_states()

@@ -360,17 +327,19 @@ class SimulevalTranscoder:
          if self.close:
              return  # closes the thread

-         self.debug_log("processing_pipeline")
+         print("processing_pipeline")
          while not self.close:
              input_segment = self.get_input_segment()
              if input_segment is None:
-                 # if self.get_states_root().is_fresh_state: # TODO: this is hacky
-                 #     time.sleep(0.3)
-                 # else:
-                 time.sleep(0.03)
+                 if self.get_states_root().is_fresh_state:  # TODO: this is hacky
+                     time.sleep(0.3)
+                     print("loop: input_queue empty")
+                 else:
+                     time.sleep(0.03)
                  continue
+             print("loop: got input_segment")
              self.process_pipeline_impl(input_segment)
-         self.debug_log("finished processing_pipeline")
+         print("finished processing_pipeline")

      def process_pipeline_once(self):
          if self.close:
@@ -392,7 +361,7 @@ class SimulevalTranscoder:
          return output_chunk

      def start(self):
-         self.debug_log("starting transcoder in a thread")
+         print("starting transcoder in a thread")
          threading.Thread(target=self.process_pipeline_loop).start()

      def first_translation_time(self):
@@ -400,7 +369,7 @@ class SimulevalTranscoder:

      def get_buffered_output(self) -> SpeechAndTextOutput:
          now = time.time() * 1000
-         self.debug_log(f"get_buffered_output queue size: {self.output_queue.qsize()}")
+         print(f"get_buffered_output queue size: {self.output_queue.qsize()}")
          while not self.output_queue.empty():
              tmp_out = self.get_output_segment()
              if tmp_out and tmp_out.compute_length(self.g2p) > 0:
@@ -452,4 +421,4 @@ class SimulevalTranscoder:
          self.output_buffer.append(segment.segments)

      def _compute_phoneme_count(self, string: str) -> int:
-         return len([x for x in self.g2p(string) if x != " "])
+         return len([x for x in self.g2p(string) if x != " "])
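
Taken together, the transcoder's new lifecycle as app.py drives it: construct with a prebuilt agent, start() the background pipeline thread, feed raw int16 PCM bytes, and poll get_buffered_output(). A hedged usage sketch; the buffer_limit value and the silent test chunk are illustrative assumptions, and `agent` is the object returned by build_agent() in app.py:

    import numpy as np
    from simuleval_transcoder import SimulevalTranscoder

    # `agent` as built by build_agent() in app.py; buffer_limit=1 is an assumption.
    transcoder = SimulevalTranscoder(agent, sample_rate=48_000, debug=False, buffer_limit=1)
    transcoder.start()  # runs process_pipeline_loop in a background thread

    # Producer side: push one second of int16 PCM (silence as a stand-in).
    chunk = np.zeros(48_000, dtype=np.int16)
    transcoder.process_incoming_bytes(chunk.tobytes(), 'eng', 48_000)

    # Consumer side: poll until the pipeline emits something; None until then.
    out = transcoder.get_buffered_output()
    if out is not None and out.speech_samples:
        print(out.speech_sample_rate, len(out.speech_samples), out.text)
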