Mark Duppenthaler committed
Commit
2d522b6
1 Parent(s): 1727d3b

Update with temp work

app.py CHANGED
@@ -10,6 +10,8 @@ from seamless_communication.models.inference.translator import Translator
 
 
 from m4t_app import *
+from simuleval_transcoder import *
+# from simuleval_transcoder import *
 
 from pydub import AudioSegment
 import time
@@ -19,6 +21,7 @@ from time import sleep
 
 USE_M4T = True
 
+Transcoder = SimulevalTranscoder()
 
 def translate_audio_file_segment(audio_file):
     print("translate_m4t state")
@@ -90,7 +93,9 @@ def blocks():
     )
 
     most_recent_input_audio_segment = gr.Audio(
-        label="Recent Input Audio Segment segments", format="bytes", streaming=True
+        label="Recent Input Audio Segment segments",
+        format="bytes",
+        streaming=True
    )
     # TODO: Should add combined input audio segments...
 
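For context on the `streaming=True` change above: in Gradio 3.x a streaming `gr.Audio` input fires a `stream` event per chunk. A hedged, illustrative sketch (not part of the commit; assumes the default `type="numpy"`, where each chunk arrives as a `(sample_rate, samples)` tuple):

import gradio as gr

def on_chunk(chunk, text_so_far):
    # chunk is (sample_rate, np.ndarray) for the latest microphone segment
    sample_rate, samples = chunk
    text_so_far += f" [{len(samples) / sample_rate:.2f}s]"
    return text_so_far, text_so_far

with gr.Blocks() as demo:
    state = gr.State("")
    mic = gr.Audio(source="microphone", streaming=True)
    out = gr.Textbox(label="chunks received")
    mic.stream(on_chunk, inputs=[mic, state], outputs=[out, state])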
internal_demo_simuleval_transcoder.py ADDED
@@ -0,0 +1,272 @@
+from simuleval.utils.agent import build_system_from_dir
+from typing import Any, Tuple
+import numpy as np
+import soundfile
+from fairseq.data.audio.audio_utils import convert_waveform
+import io
+import asyncio
+from simuleval.data.segments import SpeechSegment, EmptySegment
+import threading
+import math
+import logging
+import sys
+from pathlib import Path
+import time
+from g2p_en import G2p
+import torch
+import traceback
+import random
+
+from .speech_and_text_output import SpeechAndTextOutput
+
+MODEL_SAMPLE_RATE = 16_000
+
+logger = logging.getLogger()
+logger.addHandler(logging.StreamHandler(sys.stdout))
+
+
+class SimulevalTranscoder:
+    def __init__(self, agent, sample_rate, debug, buffer_limit):
+        self.agent = agent
+        self.input_queue = asyncio.Queue()
+        self.output_queue = asyncio.Queue()
+        self.states = self.agent.build_states()
+        if debug:
+            self.states[0].debug = True
+        self.incoming_sample_rate = sample_rate
+        self.close = False
+        self.g2p = G2p()
+
+        # buffer all outgoing translations within this amount of time
+        self.output_buffer_idle_ms = 5000
+        self.output_buffer_size_limit = (
+            buffer_limit  # phonemes for text, seconds for speech
+        )
+        self.output_buffer_cur_size = 0
+        self.output_buffer = []
+        self.speech_output_sample_rate = None
+
+        self.last_output_ts = time.time() * 1000
+        self.timeout_ms = (
+            30000  # close the transcoder thread after this amount of silence
+        )
+        self.first_input_ts = None
+        self.first_output_ts = None
+        self.output_data_type = None  # speech or text
+        self.debug = debug
+        self.debug_ts = f"{time.time()}_{random.randint(1000, 9999)}"
+        if self.debug:
+            debug_folder = Path(__file__).resolve().parent.parent / "debug"
+            self.test_incoming_wav = soundfile.SoundFile(
+                debug_folder / f"{self.debug_ts}_test_incoming.wav",
+                mode="w+",
+                format="WAV",
+                subtype="PCM_16",
+                samplerate=self.incoming_sample_rate,
+                channels=1,
+            )
+            self.states[0].test_input_segments_wav = soundfile.SoundFile(
+                debug_folder / f"{self.debug_ts}_test_input_segments.wav",
+                mode="w+",
+                format="WAV",
+                samplerate=MODEL_SAMPLE_RATE,
+                channels=1,
+            )
+
+    def debug_log(self, *args):
+        if self.debug:
+            logger.info(*args)
+
+    @classmethod
+    def build_agent(cls, model_path):
+        logger.info(f"Building simuleval agent: {model_path}")
+        agent = build_system_from_dir(
+            Path(__file__).resolve().parent.parent / f"models/{model_path}",
+            config_name="vad_main.yaml",
+        )
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        agent.to(device, fp16=True)
+        logger.info(
+            f"Successfully built simuleval agent {model_path} on device {device}"
+        )
+
+        return agent
+
+    def process_incoming_bytes(self, incoming_bytes):
+        segment, _sr = self._preprocess_wav(incoming_bytes)
+        # segment is array([0, 0, 0, ..., 0, 0, 0], dtype=int16)
+        self.input_queue.put_nowait(segment)
+
+    def get_input_segment(self):
+        if self.input_queue.empty():
+            return None
+        chunk = self.input_queue.get_nowait()
+        self.input_queue.task_done()
+        return chunk
+
+    def _preprocess_wav(self, data: Any) -> Tuple[np.ndarray, int]:
+        segment, sample_rate = soundfile.read(
+            io.BytesIO(data),
+            dtype="float32",
+            always_2d=True,
+            frames=-1,
+            start=0,
+            format="RAW",
+            subtype="PCM_16",
+            samplerate=self.incoming_sample_rate,
+            channels=1,
+        )
+        if self.debug:
+            self.test_incoming_wav.seek(0, soundfile.SEEK_END)
+            self.test_incoming_wav.write(segment)
+
+        segment = segment.T
+        segment, new_sample_rate = convert_waveform(
+            segment,
+            sample_rate,
+            normalize_volume=False,
+            to_mono=True,
+            to_sample_rate=MODEL_SAMPLE_RATE,
+        )
+
+        assert MODEL_SAMPLE_RATE == new_sample_rate
+        segment = segment.squeeze(axis=0)
+        return segment, new_sample_rate
+
+    def process_pipeline_impl(self, input_segment):
+        try:
+            output_segment = self.agent.pushpop(input_segment, self.states)
+            if (
+                self.states[0].first_input_ts is not None
+                and self.first_input_ts is None
+            ):
+                # TODO: this is hacky
+                self.first_input_ts = self.states[0].first_input_ts
+
+            if not output_segment.is_empty:
+                self.output_queue.put_nowait(output_segment)
+
+            if output_segment.finished:
+                self.debug_log("OUTPUT SEGMENT IS FINISHED. Resetting states.")
+
+                for state in self.states:
+                    state.reset()
+
+                if self.debug:
+                    # when we rebuild states, this value is reset to whatever
+                    # is in the system dir config, which defaults debug=False.
+                    self.states[0].debug = True
+        except Exception as e:
+            logger.error(f"Got exception while processing pipeline: {e}")
+            traceback.print_exc()
+        return input_segment
+
+    def process_pipeline_loop(self):
+        if self.close:
+            return  # closes the thread
+
+        self.debug_log("processing_pipeline")
+        while not self.close:
+            input_segment = self.get_input_segment()
+            if input_segment is None:
+                if self.states[0].is_fresh_state:  # TODO: this is hacky
+                    time.sleep(0.3)
+                else:
+                    time.sleep(0.03)
+                continue
+            self.process_pipeline_impl(input_segment)
+        self.debug_log("finished processing_pipeline")
+
+    def process_pipeline_once(self):
+        if self.close:
+            return
+
+        self.debug_log("processing pipeline once")
+        input_segment = self.get_input_segment()
+        if input_segment is None:
+            return
+        self.process_pipeline_impl(input_segment)
+        self.debug_log("finished processing_pipeline_once")
+
+    def get_output_segment(self):
+        if self.output_queue.empty():
+            return None
+
+        output_chunk = self.output_queue.get_nowait()
+        self.output_queue.task_done()
+        return output_chunk
+
+    def start(self):
+        self.debug_log("starting transcoder in a thread")
+        threading.Thread(target=self.process_pipeline_loop).start()
+
+    def first_translation_time(self):
+        return round((self.first_output_ts - self.first_input_ts) / 1000, 2)
+
+    def get_buffered_output(self) -> SpeechAndTextOutput:
+        now = time.time() * 1000
+        self.debug_log(f"get_buffered_output queue size: {self.output_queue.qsize()}")
+        while not self.output_queue.empty():
+            tmp_out = self.get_output_segment()
+            if tmp_out and len(tmp_out.content) > 0:
+                if not self.output_data_type:
+                    self.output_data_type = tmp_out.data_type
+                if len(self.output_buffer) == 0:
+                    self.last_output_ts = now
+                self._populate_output_buffer(tmp_out)
+                self._increment_output_buffer_size(tmp_out)
+
+                if tmp_out.finished:
+                    res = self._gather_output_buffer_data(final=True)
+                    self.output_buffer = []
+                    self.output_buffer_cur_size = 0  # reset the buffer size counter
+                    self.last_output_ts = now
+                    self.first_output_ts = now
+                    return res
+
+        if len(self.output_buffer) > 0 and (
+            now - self.last_output_ts >= self.output_buffer_idle_ms
+            or self.output_buffer_cur_size >= self.output_buffer_size_limit
+        ):
+            self.last_output_ts = now
+            res = self._gather_output_buffer_data(final=False)
+            self.output_buffer = []
+            self.output_buffer_cur_size = 0  # reset the buffer size counter
+            self.first_output_ts = now
+            return res
+        else:
+            return None
+
+    def _gather_output_buffer_data(self, final):
+        if self.output_data_type == "text":
+            return SpeechAndTextOutput(text=" ".join(self.output_buffer), final=final)
+        elif self.output_data_type == "speech":
+            return SpeechAndTextOutput(
+                speech_samples=self.output_buffer,
+                speech_sample_rate=MODEL_SAMPLE_RATE,
+                final=final,
+            )
+        else:
+            raise ValueError(
+                f"Invalid output buffer data type: {self.output_data_type}"
+            )
+
+    def _increment_output_buffer_size(self, segment):
+        if segment.data_type == "text":
+            self.output_buffer_cur_size += self._compute_phoneme_count(segment.content)
+        elif segment.data_type == "speech":
+            self.output_buffer_cur_size += (
+                len(segment.content) / MODEL_SAMPLE_RATE
+            )  # seconds
+
+    def _populate_output_buffer(self, segment):
+        if segment.data_type == "text":
+            self.output_buffer.append(segment.content)
+        elif segment.data_type == "speech":
+            self.output_buffer += segment.content
+        else:
+            raise ValueError(f"Invalid segment data type: {segment.data_type}")
+
+    def _compute_phoneme_count(self, string: str) -> int:
+        return len([x for x in self.g2p(string) if x != " "])
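For reference, a minimal usage sketch of the class above. The model directory name, incoming sample rate, buffer limit, and the silent test chunk are illustrative assumptions, not values from the commit; build_agent expects a directory under models/ containing a vad_main.yaml config.

agent = SimulevalTranscoder.build_agent("s2st_agent")  # "s2st_agent" is a placeholder name
transcoder = SimulevalTranscoder(agent, sample_rate=48_000, debug=False, buffer_limit=5)
transcoder.start()  # runs process_pipeline_loop in a background thread

pcm_chunk = b"\x00\x00" * 4800  # 100 ms of 16-bit PCM silence at 48 kHz
transcoder.process_incoming_bytes(pcm_chunk)  # enqueue each incoming audio chunk
result = transcoder.get_buffered_output()  # None until the idle/size thresholds flush
if result is not None:
    print(result.text if result.text else len(result.speech_samples))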
requirements.txt CHANGED
@@ -1,9 +1,23 @@
 # fairseq2==0.1.0
+
+# Temp to skip
 git+https://github.com/mduppes/fairseq2.git@93420c86ba01349ee8f90d7adda439b666b50557
-git+https://github.com/facebookresearch/seamless_communication
+# git+https://github.com/facebookresearch/seamless_communication
+./seamless_communication
+# comment this out to test fairseq1 first
+# git+https://github.com/facebookresearch/SimulEval.git
 gradio==3.41.0
 huggingface_hub==0.16.4
 torch==2.0.1
 torchaudio==2.0.2
 transformers==4.32.1
-pydub
+pydub
+
+
+# Can't install fairseq1 alongside these; it causes a pip conflict:
+# The conflict is caused by:
+#     The user requested simuleval 1.1.0 (from git+ssh://****@github.com/facebookresearch/SimulEval.git@tree_pipeline)
+#     seamless-communication 1.0.0 depends on simuleval 1.0.3.dev36+gd84fa60 (from git+https://github.com/mduppes/SimulEval.git@main)
+# From fairseq1 pipeline
+# git+ssh://git@github.com/fairinternal/fairseq-py.git@emma_incremental_decoder
+# git+ssh://git@github.com/facebookresearch/SimulEval.git@tree_pipeline
simuleval_transcoder.py ADDED
@@ -0,0 +1,178 @@
+
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from fairseq2.assets.card import AssetCard
+from fairseq2.data import Collater
+from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
+from fairseq2.data.text.text_tokenizer import TextTokenizer
+from fairseq2.data.typing import StringLike
+from fairseq2.generation import SequenceToTextOutput, SequenceGeneratorOptions
+from fairseq2.memory import MemoryBlock
+from fairseq2.typing import DataType, Device
+from torch import Tensor
+from enum import Enum, auto
+from seamless_communication.models.inference.ngram_repeat_block_processor import (
+    NGramRepeatBlockProcessor,
+)
+
+from seamless_communication.models.unity import (
+    UnitTokenizer,
+    UnitYGenerator,
+    UnitYModel,
+    load_unity_model,
+    load_unity_text_tokenizer,
+    load_unity_unit_tokenizer,
+)
+from seamless_communication.models.unity.generator import SequenceToUnitOutput
+from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
+
+
+from seamless_communication.models.streaming.agents import (
+    SileroVADAgent,
+    TestTimeWaitKS2TVAD,
+    TestTimeWaitKUnityV1M4T,
+)
+
+### From test_pipeline
+import math
+import soundfile
+from argparse import Namespace, ArgumentParser
+from simuleval.data.segments import SpeechSegment, EmptySegment
+from simuleval.utils import build_system_from_dir
+import numpy as np
+
+
+class AudioFrontEnd:
+    def __init__(self, wav_file, segment_size) -> None:
+        self.samples, self.sample_rate = soundfile.read(wav_file)
+        # print(len(self.samples), self.samples[:100])
+        self.samples = self.samples.tolist()
+        self.segment_size = segment_size
+        self.step = 0
+
+    def send_segment(self):
+        """
+        This is the front-end logic in simuleval instance.py
+        """
+        num_samples = math.ceil(self.segment_size / 1000 * self.sample_rate)
+        print("self.segment_size", self.segment_size)
+        print("num_samples is", num_samples)
+        print("self.sample_rate is", self.sample_rate)
+        if self.step < len(self.samples):
+            if self.step + num_samples >= len(self.samples):
+                samples = self.samples[self.step :]
+                is_finished = True
+            else:
+                samples = self.samples[self.step : self.step + num_samples]
+                is_finished = False
+            self.step = min(self.step + num_samples, len(self.samples))
+            # print("len(samples) is", len(samples))
+            segment = SpeechSegment(
+                index=self.step / self.sample_rate * 1000,
+                content=samples,
+                sample_rate=self.sample_rate,
+                finished=is_finished,
+            )
+        else:
+            # Finished reading this audio
+            segment = EmptySegment(
+                index=self.step / self.sample_rate * 1000,
+                finished=True,
+            )
+        return segment
+
+
+def load_model_for_inference(
+    load_model_fn: Callable[..., nn.Module],
+    model_name_or_card: Union[str, AssetCard],
+    device: Device,
+    dtype: DataType,
+) -> nn.Module:
+    model = load_model_fn(model_name_or_card, device=device, dtype=dtype)
+    model.eval()
+    return model
+
+
+class SimulevalTranscoder:
+    # def __init__(self, agent, sample_rate, debug, buffer_limit):
+    def __init__(self):
+        print("MDUPPES in here", SileroVADAgent, TestTimeWaitKS2TVAD)
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+        device = "cpu"
+        print("DEVICE", device)
+        model_name_or_card = "seamlessM4T_medium"
+        vocoder_name_or_card = "vocoder_36langs"
+        # dtype = torch.float16
+        # For CPU mode we need float32; float16 causes errors downstream
+        dtype = torch.float32
+
+        model: UnitYModel = load_model_for_inference(
+            load_unity_model, model_name_or_card, device, dtype
+        )
+
+        print(model, type(model))
+        parser = ArgumentParser()
+        source_segment_size = 320  # milliseconds
+        audio_frontend = AudioFrontEnd(
+            wav_file="/checkpoint/mduppes/samples/marta.wav",
+            segment_size=source_segment_size,
+        )
+
+        # mostly taken from S2S first agent: OnlineFeatureExtractorAgent defaults
+        SHIFT_SIZE = 10
+        WINDOW_SIZE = 25
+        SAMPLE_RATE = 16000
+        FEATURE_DIM = 80
+
+        # build args and convert to a Namespace so fields can be accessed via dot notation
+        args = {
+            "shift_size": SHIFT_SIZE,
+            "window_size": WINDOW_SIZE,
+            "sample_rate": audio_frontend.sample_rate,
+            "feature_dim": 160,  # from Wav2Vec2Frontend
+            "denormalize": False,  # not sure..
+            "global_stats": None,  # default file path containing cmvn stats..
+        }
+        print(args)
+        args = Namespace(**args)
+
+        pipeline = TestTimeWaitKUnityV1M4T(model, args)
+        system_states = pipeline.build_states()
+        print("system states")
+        print(system_states)
+        input_segment = np.empty(0, dtype=np.int16)
+        segments = []
+        while True:
+            speech_segment = audio_frontend.send_segment()
+            input_segment = np.concatenate((input_segment, np.array(speech_segment.content)))
+            # Translation happens here
+            output_segment = pipeline.pushpop(speech_segment, system_states)
+            print("pushpop result")
+            print(output_segment)
+            if output_segment.finished:
+                segments.append(input_segment)
+                input_segment = np.empty(0, dtype=np.int16)
+                print("Resetting states")
+                for state in system_states:
+                    state.reset()
+            if speech_segment.finished:
+                break
+        # The VAD-segmented samples from the full input audio
+        for i, seg in enumerate(segments):
+            with soundfile.SoundFile(
+                Path("/checkpoint/mduppes/samples") / f"marta_{i}.wav",
+                mode="w+",
+                format="WAV",
+                samplerate=16000,
+                channels=1,
+            ) as f:
+                f.seek(0, soundfile.SEEK_END)
+                f.write(seg)