aadnk commited on
Commit
4a9d465
·
1 Parent(s): 33a2c1e

Support progress for multiple devices

Browse files
Files changed (3) hide show
  1. app.py +2 -2
  2. src/vad.py +78 -66
  3. src/vadParallel.py +50 -8
app.py CHANGED
@@ -279,7 +279,6 @@ class WhisperTranscriber:
279
  # No parallel devices, so just run the VAD and Whisper in sequence
280
  return vadModel.transcribe(audio_path, whisperCallable, vadConfig, progressListener=progressListener)
281
 
282
- # TODO: Handle progress listener
283
  gpu_devices = self.parallel_device_list
284
 
285
  if (gpu_devices is None or len(gpu_devices) == 0):
@@ -297,7 +296,8 @@ class WhisperTranscriber:
297
  parallel_vad = ParallelTranscription()
298
  return parallel_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable,
299
  config=vadConfig, cpu_device_count=self.vad_cpu_cores, gpu_devices=gpu_devices,
300
- cpu_parallel_context=self.cpu_parallel_context, gpu_parallel_context=self.gpu_parallel_context)
 
301
 
302
  def _has_parallel_devices(self):
303
  return (self.parallel_device_list is not None and len(self.parallel_device_list) > 0) or self.vad_cpu_cores > 1
 
279
  # No parallel devices, so just run the VAD and Whisper in sequence
280
  return vadModel.transcribe(audio_path, whisperCallable, vadConfig, progressListener=progressListener)
281
 
 
282
  gpu_devices = self.parallel_device_list
283
 
284
  if (gpu_devices is None or len(gpu_devices) == 0):
 
296
  parallel_vad = ParallelTranscription()
297
  return parallel_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable,
298
  config=vadConfig, cpu_device_count=self.vad_cpu_cores, gpu_devices=gpu_devices,
299
+ cpu_parallel_context=self.cpu_parallel_context, gpu_parallel_context=self.gpu_parallel_context,
300
+ progress_listener=progressListener)
301
 
302
  def _has_parallel_devices(self):
303
  return (self.parallel_device_list is not None and len(self.parallel_device_list) > 0) or self.vad_cpu_cores > 1
src/vad.py CHANGED
@@ -153,84 +153,96 @@ class AbstractTranscription(ABC):
153
  A list of start and end timestamps, in fractional seconds.
154
  """
155
 
156
- max_audio_duration = get_audio_duration(audio)
157
- timestamp_segments = self.get_transcribe_timestamps(audio, config, 0, max_audio_duration)
 
158
 
159
- # Get speech timestamps from full audio file
160
- merged = self.get_merged_timestamps(timestamp_segments, config, max_audio_duration)
161
 
162
- # A deque of transcribed segments that is passed to the next segment as a prompt
163
- prompt_window = deque()
164
 
165
- print("Processing timestamps:")
166
- pprint(merged)
167
-
168
- result = {
169
- 'text': "",
170
- 'segments': [],
171
- 'language': ""
172
- }
173
- languageCounter = Counter()
174
- detected_language = None
175
-
176
- segment_index = config.initial_segment_index
177
-
178
- # For each time segment, run whisper
179
- for segment in merged:
180
- segment_index += 1
181
- segment_start = segment['start']
182
- segment_end = segment['end']
183
- segment_expand_amount = segment.get('expand_amount', 0)
184
- segment_gap = segment.get('gap', False)
185
-
186
- segment_duration = segment_end - segment_start
187
-
188
- if segment_duration < MIN_SEGMENT_DURATION:
189
- continue
190
-
191
- # Audio to run on Whisper
192
- segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
193
- # Previous segments to use as a prompt
194
- segment_prompt = ' '.join([segment['text'] for segment in prompt_window]) if len(prompt_window) > 0 else None
195
-
196
- # Detected language
197
- detected_language = languageCounter.most_common(1)[0][0] if len(languageCounter) > 0 else None
198
-
199
- print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
200
- segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
201
-
202
- scaled_progress_listener = SubTaskProgressListener(progressListener, base_task_total=max_audio_duration, sub_task_start=segment_start, sub_task_total=segment_duration)
203
- segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language, progress_listener=scaled_progress_listener)
204
-
205
- adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
206
 
207
- # Propagate expand amount to the segments
208
- if (segment_expand_amount > 0):
209
- segment_without_expansion = segment_duration - segment_expand_amount
 
 
 
 
210
 
211
- for adjusted_segment in adjusted_segments:
212
- adjusted_segment_end = adjusted_segment['end']
213
 
214
- # Add expand amount if the segment got expanded
215
- if (adjusted_segment_end > segment_without_expansion):
216
- adjusted_segment["expand_amount"] = adjusted_segment_end - segment_without_expansion
217
 
218
- # Append to output
219
- result['text'] += segment_result['text']
220
- result['segments'].extend(adjusted_segments)
 
 
 
 
221
 
222
- # Increment detected language
223
- if not segment_gap:
224
- languageCounter[segment_result['language']] += 1
225
 
226
- # Update prompt window
227
- self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)
228
-
229
- if detected_language is not None:
230
- result['language'] = detected_language
231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  return result
233
 
 
 
 
234
  def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
235
  if (config.max_prompt_window is not None and config.max_prompt_window > 0):
236
  # Add segments to the current prompt window (unless it is a speech gap)
 
153
  A list of start and end timestamps, in fractional seconds.
154
  """
155
 
156
+ try:
157
+ max_audio_duration = self.get_audio_duration(audio, config)
158
+ timestamp_segments = self.get_transcribe_timestamps(audio, config, 0, max_audio_duration)
159
 
160
+ # Get speech timestamps from full audio file
161
+ merged = self.get_merged_timestamps(timestamp_segments, config, max_audio_duration)
162
 
163
+ # A deque of transcribed segments that is passed to the next segment as a prompt
164
+ prompt_window = deque()
165
 
166
+ print("Processing timestamps:")
167
+ pprint(merged)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
+ result = {
170
+ 'text': "",
171
+ 'segments': [],
172
+ 'language': ""
173
+ }
174
+ languageCounter = Counter()
175
+ detected_language = None
176
 
177
+ segment_index = config.initial_segment_index
 
178
 
179
+ # Calculate progress
180
+ progress_start_offset = merged[0]['start'] if len(merged) > 0 else 0
181
+ progress_total_duration = sum([segment['end'] - segment['start'] for segment in merged])
182
 
183
+ # For each time segment, run whisper
184
+ for segment in merged:
185
+ segment_index += 1
186
+ segment_start = segment['start']
187
+ segment_end = segment['end']
188
+ segment_expand_amount = segment.get('expand_amount', 0)
189
+ segment_gap = segment.get('gap', False)
190
 
191
+ segment_duration = segment_end - segment_start
 
 
192
 
193
+ if segment_duration < MIN_SEGMENT_DURATION:
194
+ continue
 
 
 
195
 
196
+ # Audio to run on Whisper
197
+ segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
198
+ # Previous segments to use as a prompt
199
+ segment_prompt = ' '.join([segment['text'] for segment in prompt_window]) if len(prompt_window) > 0 else None
200
+
201
+ # Detected language
202
+ detected_language = languageCounter.most_common(1)[0][0] if len(languageCounter) > 0 else None
203
+
204
+ print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
205
+ segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
206
+
207
+ scaled_progress_listener = SubTaskProgressListener(progressListener, base_task_total=progress_total_duration,
208
+ sub_task_start=segment_start - progress_start_offset, sub_task_total=segment_duration)
209
+ segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language, progress_listener=scaled_progress_listener)
210
+
211
+ adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
212
+
213
+ # Propagate expand amount to the segments
214
+ if (segment_expand_amount > 0):
215
+ segment_without_expansion = segment_duration - segment_expand_amount
216
+
217
+ for adjusted_segment in adjusted_segments:
218
+ adjusted_segment_end = adjusted_segment['end']
219
+
220
+ # Add expand amount if the segment got expanded
221
+ if (adjusted_segment_end > segment_without_expansion):
222
+ adjusted_segment["expand_amount"] = adjusted_segment_end - segment_without_expansion
223
+
224
+ # Append to output
225
+ result['text'] += segment_result['text']
226
+ result['segments'].extend(adjusted_segments)
227
+
228
+ # Increment detected language
229
+ if not segment_gap:
230
+ languageCounter[segment_result['language']] += 1
231
+
232
+ # Update prompt window
233
+ self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)
234
+
235
+ if detected_language is not None:
236
+ result['language'] = detected_language
237
+ finally:
238
+ # Notify progress listener that we are done
239
+ if progressListener is not None:
240
+ progressListener.on_finished()
241
  return result
242
 
243
+ def get_audio_duration(self, audio: str, config: TranscriptionConfig):
244
+ return get_audio_duration(audio)
245
+
246
  def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
247
  if (config.max_prompt_window is not None and config.max_prompt_window > 0):
248
  # Add segments to the current prompt window (unless it is a speech gap)
src/vadParallel.py CHANGED
@@ -1,14 +1,33 @@
1
  import multiprocessing
 
2
  import threading
3
  import time
 
4
  from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration
5
  from src.whisperContainer import WhisperCallback
6
 
7
- from multiprocessing import Pool
8
 
9
- from typing import Any, Dict, List
10
  import os
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  class ParallelContext:
14
  def __init__(self, num_processes: int = None, auto_cleanup_timeout_seconds: float = None):
@@ -86,7 +105,8 @@ class ParallelTranscription(AbstractTranscription):
86
  super().__init__(sampling_rate=sampling_rate)
87
 
88
  def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig,
89
- cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None):
 
90
  total_duration = get_audio_duration(audio)
91
 
92
  # First, get the timestamps for the original audio
@@ -108,6 +128,9 @@ class ParallelTranscription(AbstractTranscription):
108
  parameters = []
109
  segment_index = config.initial_segment_index
110
 
 
 
 
111
  for i in range(len(gpu_devices)):
112
  # Note that device_segment_list can be empty. But we will still create a process for it,
113
  # as otherwise we run the risk of assigning the same device to multiple processes.
@@ -120,7 +143,8 @@ class ParallelTranscription(AbstractTranscription):
120
  device_config = ParallelTranscriptionConfig(device_id, device_segment_list, segment_index, config)
121
  segment_index += len(device_segment_list)
122
 
123
- parameters.append([audio, whisperCallable, device_config]);
 
124
 
125
  merged = {
126
  'text': '',
@@ -142,7 +166,24 @@ class ParallelTranscription(AbstractTranscription):
142
  pool = gpu_parallel_context.get_pool()
143
 
144
  # Run the transcription in parallel
145
- results = pool.starmap(self.transcribe, parameters)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  for result in results:
148
  # Merge the results
@@ -231,11 +272,12 @@ class ParallelTranscription(AbstractTranscription):
231
  def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: ParallelTranscriptionConfig, total_duration: float):
232
  # Override timestamps that will be processed
233
  if (config.override_timestamps is not None):
234
- print("Using override timestamps of size " + str(len(config.override_timestamps)))
235
  return config.override_timestamps
236
  return super().get_merged_timestamps(timestamps, config, total_duration)
237
 
238
- def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig):
 
239
  # Override device ID the first time
240
  if (os.environ.get("INITIALIZED", None) is None):
241
  os.environ["INITIALIZED"] = "1"
@@ -246,7 +288,7 @@ class ParallelTranscription(AbstractTranscription):
246
  print("Using device " + config.device_id)
247
  os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
248
 
249
- return super().transcribe(audio, whisperCallable, config)
250
 
251
  def _split(self, a, n):
252
  """Split a list into n approximately equal parts."""
 
1
  import multiprocessing
2
+ from queue import Empty
3
  import threading
4
  import time
5
+ from src.hooks.whisperProgressHook import ProgressListener
6
  from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration
7
  from src.whisperContainer import WhisperCallback
8
 
9
+ from multiprocessing import Pool, Queue
10
 
11
+ from typing import Any, Dict, List, Union
12
  import os
13
 
14
+ class _ProgressListenerToQueue(ProgressListener):
15
+ def __init__(self, progress_queue: Queue):
16
+ self.progress_queue = progress_queue
17
+ self.progress_total = 0
18
+ self.prev_progress = 0
19
+
20
+ def on_progress(self, current: Union[int, float], total: Union[int, float]):
21
+ delta = current - self.prev_progress
22
+ self.prev_progress = current
23
+ self.progress_total = total
24
+ self.progress_queue.put(delta)
25
+
26
+ def on_finished(self):
27
+ if self.progress_total > self.prev_progress:
28
+ delta = self.progress_total - self.prev_progress
29
+ self.progress_queue.put(delta)
30
+ self.prev_progress = self.progress_total
31
 
32
  class ParallelContext:
33
  def __init__(self, num_processes: int = None, auto_cleanup_timeout_seconds: float = None):
 
105
  super().__init__(sampling_rate=sampling_rate)
106
 
107
  def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig,
108
+ cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None,
109
+ progress_listener: ProgressListener = None):
110
  total_duration = get_audio_duration(audio)
111
 
112
  # First, get the timestamps for the original audio
 
128
  parameters = []
129
  segment_index = config.initial_segment_index
130
 
131
+ processing_manager = multiprocessing.Manager()
132
+ progress_queue = processing_manager.Queue()
133
+
134
  for i in range(len(gpu_devices)):
135
  # Note that device_segment_list can be empty. But we will still create a process for it,
136
  # as otherwise we run the risk of assigning the same device to multiple processes.
 
143
  device_config = ParallelTranscriptionConfig(device_id, device_segment_list, segment_index, config)
144
  segment_index += len(device_segment_list)
145
 
146
+ progress_listener_to_queue = _ProgressListenerToQueue(progress_queue)
147
+ parameters.append([audio, whisperCallable, device_config, progress_listener_to_queue]);
148
 
149
  merged = {
150
  'text': '',
 
166
  pool = gpu_parallel_context.get_pool()
167
 
168
  # Run the transcription in parallel
169
+ results_async = pool.starmap_async(self.transcribe, parameters)
170
+ total_progress = 0
171
+
172
+ while not results_async.ready():
173
+ try:
174
+ delta = progress_queue.get(timeout=5) # Set a timeout of 5 seconds
175
+ except Empty:
176
+ continue
177
+
178
+ total_progress += delta
179
+ if progress_listener is not None:
180
+ progress_listener.on_progress(total_progress, total_duration)
181
+
182
+ results = results_async.get()
183
+
184
+ # Call the finished callback
185
+ if progress_listener is not None:
186
+ progress_listener.on_finished()
187
 
188
  for result in results:
189
  # Merge the results
 
272
  def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: ParallelTranscriptionConfig, total_duration: float):
273
  # Override timestamps that will be processed
274
  if (config.override_timestamps is not None):
275
+ print("(get_merged_timestamps) Using override timestamps of size " + str(len(config.override_timestamps)))
276
  return config.override_timestamps
277
  return super().get_merged_timestamps(timestamps, config, total_duration)
278
 
279
+ def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig,
280
+ progressListener: ProgressListener = None):
281
  # Override device ID the first time
282
  if (os.environ.get("INITIALIZED", None) is None):
283
  os.environ["INITIALIZED"] = "1"
 
288
  print("Using device " + config.device_id)
289
  os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
290
 
291
+ return super().transcribe(audio, whisperCallable, config, progressListener)
292
 
293
  def _split(self, a, n):
294
  """Split a list into n approximately equal parts."""