aadnk commited on
Commit
530547e
1 Parent(s): de77829

Do not parallelize period vad

Browse files
Files changed (2) hide show
  1. src/vad.py +10 -0
  2. src/vadParallel.py +1 -1
src/vad.py CHANGED
@@ -77,6 +77,12 @@ class AbstractTranscription(ABC):
77
  def get_audio_segment(self, str, start_time: str = None, duration: str = None):
78
  return load_audio(str, self.sampling_rate, start_time, duration)
79
 
 
 
 
 
 
 
80
  @abstractmethod
81
  def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
82
  """
@@ -462,6 +468,10 @@ class VadPeriodicTranscription(AbstractTranscription):
462
  def __init__(self, sampling_rate: int = 16000):
463
  super().__init__(sampling_rate=sampling_rate)
464
 
 
 
 
 
465
  def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig, start_time: float, end_time: float):
466
  result = []
467
 
 
77
  def get_audio_segment(self, str, start_time: str = None, duration: str = None):
78
  return load_audio(str, self.sampling_rate, start_time, duration)
79
 
80
+ def is_transcribe_timestamps_fast(self):
81
+ """
82
+ Determine if get_transcribe_timestamps is fast enough to not need parallelization.
83
+ """
84
+ return False
85
+
86
  @abstractmethod
87
  def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
88
  """
 
468
  def __init__(self, sampling_rate: int = 16000):
469
  super().__init__(sampling_rate=sampling_rate)
470
 
471
+ def is_transcribe_timestamps_fast(self):
472
+ # This is a very fast VAD - no need to parallelize it
473
+ return True
474
+
475
  def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig, start_time: float, end_time: float):
476
  result = []
477
 
src/vadParallel.py CHANGED
@@ -90,7 +90,7 @@ class ParallelTranscription(AbstractTranscription):
90
  total_duration = get_audio_duration(audio)
91
 
92
  # First, get the timestamps for the original audio
93
- if (cpu_device_count > 1):
94
  merged = self._get_merged_timestamps_parallel(transcription, audio, config, total_duration, cpu_device_count, cpu_parallel_context)
95
  else:
96
  timestamp_segments = transcription.get_transcribe_timestamps(audio, config, 0, total_duration)
 
90
  total_duration = get_audio_duration(audio)
91
 
92
  # First, get the timestamps for the original audio
93
+ if (cpu_device_count > 1 and not transcription.is_transcribe_timestamps_fast()):
94
  merged = self._get_merged_timestamps_parallel(transcription, audio, config, total_duration, cpu_device_count, cpu_parallel_context)
95
  else:
96
  timestamp_segments = transcription.get_transcribe_timestamps(audio, config, 0, total_duration)