Do not parallelize period vad
Browse files- src/vad.py +10 -0
- src/vadParallel.py +1 -1
src/vad.py
CHANGED
@@ -77,6 +77,12 @@ class AbstractTranscription(ABC):
|
|
77 |
def get_audio_segment(self, str, start_time: str = None, duration: str = None):
|
78 |
return load_audio(str, self.sampling_rate, start_time, duration)
|
79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
@abstractmethod
|
81 |
def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
|
82 |
"""
|
@@ -462,6 +468,10 @@ class VadPeriodicTranscription(AbstractTranscription):
|
|
462 |
def __init__(self, sampling_rate: int = 16000):
|
463 |
super().__init__(sampling_rate=sampling_rate)
|
464 |
|
|
|
|
|
|
|
|
|
465 |
def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig, start_time: float, end_time: float):
|
466 |
result = []
|
467 |
|
|
|
77 |
def get_audio_segment(self, str, start_time: str = None, duration: str = None):
|
78 |
return load_audio(str, self.sampling_rate, start_time, duration)
|
79 |
|
80 |
+
def is_transcribe_timestamps_fast(self):
|
81 |
+
"""
|
82 |
+
Determine if get_transcribe_timestamps is fast enough to not need parallelization.
|
83 |
+
"""
|
84 |
+
return False
|
85 |
+
|
86 |
@abstractmethod
|
87 |
def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
|
88 |
"""
|
|
|
468 |
def __init__(self, sampling_rate: int = 16000):
|
469 |
super().__init__(sampling_rate=sampling_rate)
|
470 |
|
471 |
+
def is_transcribe_timestamps_fast(self):
|
472 |
+
# This is a very fast VAD - no need to parallelize it
|
473 |
+
return True
|
474 |
+
|
475 |
def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig, start_time: float, end_time: float):
|
476 |
result = []
|
477 |
|
src/vadParallel.py
CHANGED
@@ -90,7 +90,7 @@ class ParallelTranscription(AbstractTranscription):
|
|
90 |
total_duration = get_audio_duration(audio)
|
91 |
|
92 |
# First, get the timestamps for the original audio
|
93 |
-
if (cpu_device_count > 1):
|
94 |
merged = self._get_merged_timestamps_parallel(transcription, audio, config, total_duration, cpu_device_count, cpu_parallel_context)
|
95 |
else:
|
96 |
timestamp_segments = transcription.get_transcribe_timestamps(audio, config, 0, total_duration)
|
|
|
90 |
total_duration = get_audio_duration(audio)
|
91 |
|
92 |
# First, get the timestamps for the original audio
|
93 |
+
if (cpu_device_count > 1 and not transcription.is_transcribe_timestamps_fast()):
|
94 |
merged = self._get_merged_timestamps_parallel(transcription, audio, config, total_duration, cpu_device_count, cpu_parallel_context)
|
95 |
else:
|
96 |
timestamp_segments = transcription.get_transcribe_timestamps(audio, config, 0, total_duration)
|