import multiprocessing
from queue import Empty
import threading
import time

from src.hooks.progressListener import ProgressListener
from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration

from multiprocessing import Queue
from typing import Any, Dict, List, Union
import os

from src.whisper.abstractWhisperContainer import AbstractWhisperCallback

class _ProgressListenerToQueue(ProgressListener):
    """Forwards progress updates to a multiprocessing queue as incremental deltas,
    so that progress from several worker processes can be summed by the parent."""

    def __init__(self, progress_queue: Queue):
        self.progress_queue = progress_queue
        self.progress_total = 0
        self.prev_progress = 0

    def on_progress(self, current: Union[int, float], total: Union[int, float]):
        delta = current - self.prev_progress
        self.prev_progress = current
        self.progress_total = total
        self.progress_queue.put(delta)

    def on_finished(self):
        if self.progress_total > self.prev_progress:
            delta = self.progress_total - self.prev_progress
            self.progress_queue.put(delta)
            self.prev_progress = self.progress_total

class ParallelContext:
    """Reference-counted holder of a process pool that can automatically close
    the pool after a period of inactivity."""

    def __init__(self, num_processes: int = None, auto_cleanup_timeout_seconds: float = None):
        self.num_processes = num_processes
        self.auto_cleanup_timeout_seconds = auto_cleanup_timeout_seconds

        self.lock = threading.Lock()

        self.ref_count = 0
        self.pool = None
        self.cleanup_timer = None

    def get_pool(self):
        # Initialize the pool lazily
        if (self.pool is None):
            context = multiprocessing.get_context('spawn')
            self.pool = context.Pool(self.num_processes)

        self.ref_count += 1

        if (self.auto_cleanup_timeout_seconds is not None):
            self._stop_auto_cleanup()

        return self.pool

    def return_pool(self, pool):
        if (self.pool is pool and self.ref_count > 0):
            self.ref_count -= 1

            if (self.ref_count == 0):
                if (self.auto_cleanup_timeout_seconds is not None):
                    self._start_auto_cleanup()

    def _start_auto_cleanup(self):
        if (self.cleanup_timer is not None):
            self.cleanup_timer.cancel()
        self.cleanup_timer = threading.Timer(self.auto_cleanup_timeout_seconds, self._execute_cleanup)
        self.cleanup_timer.start()

        print("Started auto cleanup of pool in " + str(self.auto_cleanup_timeout_seconds) + " seconds")

    def _stop_auto_cleanup(self):
        if (self.cleanup_timer is not None):
            self.cleanup_timer.cancel()
            self.cleanup_timer = None

            print("Stopped auto cleanup of pool")

    def _execute_cleanup(self):
        print("Executing cleanup of pool")

        if (self.ref_count == 0):
            self.close()

    def close(self):
        self._stop_auto_cleanup()

        if (self.pool is not None):
            print("Closing pool of " + str(self.num_processes) + " processes")
            self.pool.close()
            self.pool.join()
        self.pool = None
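
# A minimal lifecycle sketch for ParallelContext (illustrative only; the
# _demo_parallel_context name is hypothetical and not part of the module API).
# Since the pool uses the 'spawn' start method, invoke this only from within
# an `if __name__ == "__main__":` guard.
def _demo_parallel_context():
    context = ParallelContext(num_processes=2, auto_cleanup_timeout_seconds=5)
    pool = context.get_pool()      # created lazily, ref_count becomes 1
    try:
        print(pool.starmap(pow, [(i, 2) for i in range(4)]))  # [0, 1, 4, 9]
    finally:
        context.return_pool(pool)  # ref_count back to 0, auto cleanup scheduled
    context.close()                # or let the 5 second timer close the pool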

class ParallelTranscriptionConfig(TranscriptionConfig):
    def __init__(self, device_id: str, override_timestamps, initial_segment_index, copy: TranscriptionConfig = None):
        super().__init__(copy.non_speech_strategy, copy.segment_padding_left, copy.segment_padding_right,
                         copy.max_silent_period, copy.max_merge_size, copy.max_prompt_window, initial_segment_index)
        self.device_id = device_id
        self.override_timestamps = override_timestamps

class ParallelTranscription(AbstractTranscription):
    # Silero VAD typically takes about 3 seconds per minute, so there's no need to split the chunks
    # into segments smaller than 2 minutes (a minimum of about 6 seconds of work per CPU core)
    MIN_CPU_CHUNK_SIZE_SECONDS = 2 * 60

    def __init__(self, sampling_rate: int = 16000):
        super().__init__(sampling_rate=sampling_rate)

    def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: AbstractWhisperCallback,
                            config: TranscriptionConfig, cpu_device_count: int, gpu_devices: List[str],
                            cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None,
                            progress_listener: ProgressListener = None):
        total_duration = get_audio_duration(audio)

        # First, get the timestamps for the original audio
        if (cpu_device_count > 1 and not transcription.is_transcribe_timestamps_fast()):
            merged = self._get_merged_timestamps_parallel(transcription, audio, config, total_duration, cpu_device_count, cpu_parallel_context)
        else:
            timestamp_segments = transcription.get_transcribe_timestamps(audio, config, 0, total_duration)
            merged = transcription.get_merged_timestamps(timestamp_segments, config, total_duration)

        # We must make sure the whisper model is downloaded
        if (len(gpu_devices) > 1):
            whisperCallable.model_container.ensure_downloaded()

        # Split into a list for each device
        # TODO: Split by time instead of by number of chunks
        merged_split = list(self._split(merged, len(gpu_devices)))

        # Parameters that will be passed to the transcribe function
        parameters = []
        segment_index = config.initial_segment_index

        processing_manager = multiprocessing.Manager()
        progress_queue = processing_manager.Queue()

        for i in range(len(gpu_devices)):
            # Note that device_segment_list can be empty. But we will still create a process for it,
            # as otherwise we run the risk of assigning the same device to multiple processes.
            device_segment_list = list(merged_split[i]) if i < len(merged_split) else []
            device_id = gpu_devices[i]

            print("Device " + str(device_id) + " (index " + str(i) + ") has " + str(len(device_segment_list)) + " segments")

            # Create a new config with the given device ID
            device_config = ParallelTranscriptionConfig(device_id, device_segment_list, segment_index, config)
            segment_index += len(device_segment_list)

            progress_listener_to_queue = _ProgressListenerToQueue(progress_queue)
            parameters.append([audio, whisperCallable, device_config, progress_listener_to_queue])

        merged = {
            'text': '',
            'segments': [],
            'language': None
        }

        created_context = False
        perf_start_gpu = time.perf_counter()

        # Spawn a separate process for each device
        pool = None
        try:
            if (gpu_parallel_context is None):
                gpu_parallel_context = ParallelContext(len(gpu_devices))
                created_context = True

            # Get a pool of processes
            pool = gpu_parallel_context.get_pool()

            # Run the transcription in parallel
            results_async = pool.starmap_async(self.transcribe, parameters)
            total_progress = 0

            while not results_async.ready():
                try:
                    delta = progress_queue.get(timeout=5)  # Set a timeout of 5 seconds
                except Empty:
                    continue

                total_progress += delta
                if progress_listener is not None:
                    progress_listener.on_progress(total_progress, total_duration)

            results = results_async.get()

            # Call the finished callback
            if progress_listener is not None:
                progress_listener.on_finished()

            # Merge the results
            for result in results:
                if (result['text'] is not None):
                    merged['text'] += result['text']
                if (result['segments'] is not None):
                    merged['segments'].extend(result['segments'])
                if (result['language'] is not None):
                    merged['language'] = result['language']

        finally:
            # Return the pool to the context (guard against get_pool having failed)
            if (gpu_parallel_context is not None and pool is not None):
                gpu_parallel_context.return_pool(pool)
            # Always close the context if we created it
            if (created_context):
                gpu_parallel_context.close()

        perf_end_gpu = time.perf_counter()
        print("Parallel transcription took " + str(perf_end_gpu - perf_start_gpu) + " seconds")

        return merged
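
    # Worked example of the chunking below: with total_duration = 300 seconds and
    # cpu_device_count = 4, the naive chunk size would be 75 seconds, but
    # MIN_CPU_CHUNK_SIZE_SECONDS raises it to 120. The loop then produces three
    # chunks (0-120, 120-240, 240-300), so only three of the four CPU devices
    # receive work.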
    def _get_merged_timestamps_parallel(self, transcription: AbstractTranscription, audio: str, config: TranscriptionConfig,
                                        total_duration: float, cpu_device_count: int, cpu_parallel_context: ParallelContext = None):
        parameters = []

        chunk_size = max(total_duration / cpu_device_count, self.MIN_CPU_CHUNK_SIZE_SECONDS)
        chunk_start = 0
        cpu_device_id = 0

        perf_start_time = time.perf_counter()

        # Create chunks that will be processed on the CPU
        while (chunk_start < total_duration):
            chunk_end = min(chunk_start + chunk_size, total_duration)

            if (chunk_end - chunk_start < 1):
                # No need to process chunks that are less than 1 second
                break

            print("Parallel VAD: Executing chunk from " + str(chunk_start) + " to " + str(chunk_end) + " on CPU device " + str(cpu_device_id))
            parameters.append([audio, config, chunk_start, chunk_end])

            cpu_device_id += 1
            chunk_start = chunk_end

        created_context = False

        # Spawn a separate process for each device
        pool = None
        try:
            if (cpu_parallel_context is None):
                cpu_parallel_context = ParallelContext(cpu_device_count)
                created_context = True

            # Get a pool of processes
            pool = cpu_parallel_context.get_pool()

            # Run the transcription in parallel. Note that transcription must be picklable.
            results = pool.starmap(transcription.get_transcribe_timestamps, parameters)

            timestamps = []

            # Flatten the results
            for result in results:
                timestamps.extend(result)

            merged = transcription.get_merged_timestamps(timestamps, config, total_duration)

            perf_end_time = time.perf_counter()
            print("Parallel VAD processing took {} seconds".format(perf_end_time - perf_start_time))

            return merged

        finally:
            # Return the pool to the context (guard against get_pool having failed)
            if (cpu_parallel_context is not None and pool is not None):
                cpu_parallel_context.return_pool(pool)
            # Always close the context if we created it
            if (created_context):
                cpu_parallel_context.close()

    def get_transcribe_timestamps(self, audio: str, config: ParallelTranscriptionConfig, start_time: float, duration: float):
        return []

    def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: ParallelTranscriptionConfig, total_duration: float):
        # Use the override timestamps if they are set
        if (config.override_timestamps is not None):
            print("(get_merged_timestamps) Using override timestamps of size " + str(len(config.override_timestamps)))
            return config.override_timestamps
        return super().get_merged_timestamps(timestamps, config, total_duration)

    def transcribe(self, audio: str, whisperCallable: AbstractWhisperCallback, config: ParallelTranscriptionConfig,
                   progressListener: ProgressListener = None):
        # Override the device ID the first time this worker process is used
        if (os.environ.get("INITIALIZED", None) is None):
            os.environ["INITIALIZED"] = "1"

            # Note that this may be None if the user didn't specify a device. In that case, Whisper will
            # just use the default GPU device.
            if (config.device_id is not None):
                print("Using device " + config.device_id)
                os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id

        return super().transcribe(audio, whisperCallable, config, progressListener)

    def _split(self, a, n):
        """Split a list into n approximately equal parts."""
        k, m = divmod(len(a), n)
        return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
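
if __name__ == "__main__":
    # Illustrative smoke test (not part of the original module): demonstrates how
    # _split distributes segments across devices and how _ProgressListenerToQueue
    # turns absolute progress values into deltas that sum to the reported total.
    segments = [{"start": i * 10.0, "end": i * 10.0 + 5.0} for i in range(7)]
    parts = list(ParallelTranscription()._split(segments, 3))
    print([len(part) for part in parts])  # [3, 2, 2]

    progress_queue = multiprocessing.Manager().Queue()
    listener = _ProgressListenerToQueue(progress_queue)
    listener.on_progress(30, 100)   # puts delta 30
    listener.on_progress(70, 100)   # puts delta 40
    listener.on_finished()          # puts the remaining 30

    total = 0
    while not progress_queue.empty():
        total += progress_queue.get()
    print(total)  # 100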