aadnk committed
Commit 95261ed
1 Parent(s): 8f3aedf

Add support for parallel execution on multiple GPUs

Files changed (5)
  1. app.py +37 -14
  2. cli.py +2 -0
  3. src/vad.py +42 -24
  4. src/vadParallel.py +81 -0
  5. src/whisperContainer.py +91 -0
app.py CHANGED
@@ -1,9 +1,13 @@
 from typing import Iterator
+import argparse
 
 from io import StringIO
 import os
 import pathlib
 import tempfile
+from src.vadParallel import ParallelTranscription
+
+from src.whisperContainer import WhisperContainer
 
 # External programs
 import whisper
@@ -14,7 +18,7 @@ import gradio as gr
 
 from src.download import ExceededMaximumDuration, download_url
 from src.utils import slugify, write_srt, write_vtt
-from src.vad import NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
+from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
 
 # Limitations (set to -1 to disable)
 DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
@@ -48,6 +52,7 @@ LANGUAGES = [
 class WhisperTranscriber:
     def __init__(self, inputAudioMaxDuration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, deleteUploadedFiles: bool = DELETE_UPLOADED_FILES):
         self.model_cache = dict()
+        self.parallel_device_list = None
 
         self.vad_model = None
         self.inputAudioMaxDuration = inputAudioMaxDuration
@@ -64,7 +69,7 @@ class WhisperTranscriber:
         model = self.model_cache.get(selectedModel, None)
 
         if not model:
-            model = whisper.load_model(selectedModel)
+            model = WhisperContainer(selectedModel)
             self.model_cache[selectedModel] = model
 
         # Execute whisper
@@ -87,7 +92,7 @@ class WhisperTranscriber:
         except ExceededMaximumDuration as e:
             return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
 
-    def transcribe_file(self, model: whisper.Whisper, audio_path: str, language: str, task: str = None, vad: str = None,
+    def transcribe_file(self, model: WhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None,
                         vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions: dict):
 
         initial_prompt = decodeOptions.pop('initial_prompt', None)
@@ -96,35 +101,42 @@ class WhisperTranscriber:
             task = decodeOptions.pop('task')
 
         # Callable for processing an audio file
-        whisperCallable = lambda audio, segment_index, prompt, detected_language : model.transcribe(audio, \
-            language=language if language else detected_language, task=task, \
-            initial_prompt=self._concat_prompt(initial_prompt, prompt) if segment_index == 0 else prompt, \
-            **decodeOptions)
+        whisperCallable = model.create_callback(language, task, initial_prompt, **decodeOptions)
 
         # The results
         if (vad == 'silero-vad'):
             # Silero VAD where non-speech gaps are transcribed
             process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
-            result = self.vad_model.transcribe(audio_path, whisperCallable, process_gaps)
+            result = self.process_vad(audio_path, whisperCallable, self.vad_model, process_gaps)
         elif (vad == 'silero-vad-skip-gaps'):
             # Silero VAD where non-speech gaps are simply ignored
             skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
-            result = self.vad_model.transcribe(audio_path, whisperCallable, skip_gaps)
+            result = self.process_vad(audio_path, whisperCallable, self.vad_model, skip_gaps)
         elif (vad == 'silero-vad-expand-into-gaps'):
             # Use Silero VAD where speech-segments are expanded into non-speech gaps
             expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
-            result = self.vad_model.transcribe(audio_path, whisperCallable, expand_gaps)
+            result = self.process_vad(audio_path, whisperCallable, self.vad_model, expand_gaps)
         elif (vad == 'periodic-vad'):
             # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
             # it may create a break in the middle of a sentence, causing some artifacts.
             periodic_vad = VadPeriodicTranscription()
-            result = periodic_vad.transcribe(audio_path, whisperCallable, PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow))
+            period_config = PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow)
+            result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config)
+
         else:
             # Default VAD
             result = whisperCallable(audio_path, 0, None, None)
 
         return result
 
+    def process_vad(self, audio_path, whisperCallable, vadModel: AbstractTranscription, vadConfig: TranscriptionConfig):
+        if (self.parallel_device_list is None or len(self.parallel_device_list) == 0):
+            # No parallel devices, so just run the VAD and Whisper in sequence
+            return vadModel.transcribe(audio_path, whisperCallable, vadConfig)
+
+        parallel_vad = ParallelTranscription()
+        return parallel_vad.transcribe_parallel(transcription=vadModel, audio=audio_path, whisperCallable=whisperCallable, config=vadConfig, devices=self.parallel_device_list)
+
     def _concat_prompt(self, prompt1, prompt2):
         if (prompt1 is None):
             return prompt2
@@ -218,9 +230,12 @@ class WhisperTranscriber:
         return file.name
 
 
-def create_ui(inputAudioMaxDuration, share=False, server_name: str = None):
+def create_ui(inputAudioMaxDuration, share=False, server_name: str = None, server_port: int = 7860, vad_parallel_devices: str = None):
     ui = WhisperTranscriber(inputAudioMaxDuration)
 
+    # Specify a list of devices to use for parallel processing
+    ui.parallel_device_list = [ device.strip() for device in vad_parallel_devices.split(",") ] if vad_parallel_devices else None
+
     ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
     ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
     ui_description += " as well as speech translation and language identification. "
@@ -250,7 +265,15 @@ def create_ui(inputAudioMaxDuration, share=False, server_name: str = None):
         gr.Text(label="Segments")
     ])
 
-    demo.launch(share=share, server_name=server_name)
+    demo.launch(share=share, server_name=server_name, server_port=server_port)
 
 if __name__ == '__main__':
-    create_ui(DEFAULT_INPUT_AUDIO_MAX_DURATION)
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument("--inputAudioMaxDuration", type=int, default=600, help="Maximum audio file length in seconds, or -1 for no limit.")
+    parser.add_argument("--share", type=bool, default=False, help="True to share the app on HuggingFace.")
+    parser.add_argument("--server_name", type=str, default=None, help="The host or IP to bind to. If None, bind to localhost.")
+    parser.add_argument("--server_port", type=int, default=7860, help="The port to bind to.")
+    parser.add_argument("--vad_parallel_devices", type=str, default="0,1", help="A comma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
+
+    args = parser.parse_args().__dict__
+    create_ui(**args)
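With the new `__main__` block, the UI can also be started programmatically. A minimal sketch of an equivalent call (the argument values are illustrative, not defaults taken from the commit):

    # Illustrative only: start the web UI with two CUDA devices used for parallel VAD/transcription.
    # Mirrors the argparse options added above.
    from app import create_ui

    create_ui(inputAudioMaxDuration=600, share=False, server_name=None,
              server_port=7860, vad_parallel_devices="0,1")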
cli.py CHANGED
@@ -31,6 +31,7 @@ def cli():
     parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, help="The maximum size (in seconds) of a voice segment")
     parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
     parser.add_argument("--vad_prompt_window", type=optional_float, default=3, help="The window size of the prompt to pass to Whisper")
+    parser.add_argument("--vad_parallel_devices", type=str, default="0", help="A comma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")
 
     parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
     parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
@@ -74,6 +75,7 @@ def cli():
 
     model = whisper.load_model(model_name, device=device, download_root=model_dir)
     transcriber = WhisperTranscriber(deleteUploadedFiles=False)
+    transcriber.parallel_device_list = args.pop("vad_parallel_devices")
 
     for audio_path in args.pop("audio"):
         sources = []
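Unlike `create_ui` in app.py, the CLI assigns the raw `--vad_parallel_devices` string to `parallel_device_list`. A sketch of normalizing it into a list the same way app.py does (this normalization is not part of the commit; shown only for illustration):

    # Hypothetical normalization (not in the commit): turn "0,1" into ["0", "1"],
    # matching how create_ui populates parallel_device_list in app.py.
    devices_arg = args.pop("vad_parallel_devices")
    transcriber.parallel_device_list = [d.strip() for d in devices_arg.split(",")] if devices_arg else None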
src/vad.py CHANGED
@@ -6,6 +6,7 @@ from typing import Any, Deque, Iterator, List, Dict
 from pprint import pprint
 
 from src.segments import merge_timestamps
+from src.whisperContainer import WhisperCallback
 
 # Workaround for https://github.com/tensorflow/tensorflow/issues/48797
 try:
@@ -51,19 +52,20 @@ VAD_MAX_PROCESSING_CHUNK = 60 * 60 # 60 minutes of audio
 class TranscriptionConfig(ABC):
     def __init__(self, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
                  segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
-                 max_merge_size: float = None, max_prompt_window: float = None):
+                 max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
         self.non_speech_strategy = non_speech_strategy
         self.segment_padding_left = segment_padding_left
         self.segment_padding_right = segment_padding_right
         self.max_silent_period = max_silent_period
         self.max_merge_size = max_merge_size
         self.max_prompt_window = max_prompt_window
+        self.initial_segment_index = initial_segment_index
 
 class PeriodicTranscriptionConfig(TranscriptionConfig):
     def __init__(self, periodic_duration: float, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
                  segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
-                 max_merge_size: float = None, max_prompt_window: float = None):
-        super().__init__(non_speech_strategy, segment_padding_left, segment_padding_right, max_silent_period, max_merge_size, max_prompt_window)
+                 max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
+        super().__init__(non_speech_strategy, segment_padding_left, segment_padding_right, max_silent_period, max_merge_size, max_prompt_window, initial_segment_index)
         self.periodic_duration = periodic_duration
 
 class AbstractTranscription(ABC):
@@ -91,37 +93,26 @@ class AbstractTranscription(ABC):
         """
         return
 
-    def transcribe(self, audio: str, whisperCallable, config: TranscriptionConfig):
+    def get_merged_timestamps(self, audio: str, config: TranscriptionConfig):
         """
-        Transcribe the given audo file.
+        Get the start and end timestamps of the sections that should be transcribed by this VAD method,
+        after merging the segments using the specified configuration.
 
         Parameters
         ----------
         audio: str
-            The audio file.
-
-        whisperCallable: Callable[[Union[str, np.ndarray, torch.Tensor], int, str, str], dict[str, Union[dict, Any]]]
-            The callback that is used to invoke Whisper on an audio file/buffer. The first parameter is the audio file/buffer,
-            the second parameter is an optional text prompt, and the last is the current detected language. The return value is the result of the Whisper call.
+            The audio file.
+        config: TranscriptionConfig
+            The transcription configuration.
 
         Returns
         -------
         A list of start and end timestamps, in fractional seconds.
         """
-
-        # get speech timestamps from full audio file
         seconds_timestamps = self.get_transcribe_timestamps(audio, config)
 
-        #for seconds_timestamp in seconds_timestamps:
-        #    print("VAD timestamp ", format_timestamp(seconds_timestamp['start']), " to ", format_timestamp(seconds_timestamp['end']))
-
-        merged = merge_timestamps(seconds_timestamps, config.max_silent_period, config.max_merge_size, config.segment_padding_left, config.segment_padding_right)
-
-        # A deque of transcribed segments that is passed to the next segment as a prompt
-        prompt_window = deque()
-
-        print("Timestamps:")
-        pprint(merged)
+        merged = merge_timestamps(seconds_timestamps, config.max_silent_period, config.max_merge_size,
+                                  config.segment_padding_left, config.segment_padding_right)
 
         if config.non_speech_strategy != NonSpeechStrategy.SKIP:
             max_audio_duration = get_audio_duration(audio)
@@ -138,6 +129,32 @@ class AbstractTranscription(ABC):
 
             print("Transcribing non-speech:")
             pprint(merged)
+        return merged
+
+    def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig):
+        """
+        Transcribe the given audio file.
+
+        Parameters
+        ----------
+        audio: str
+            The audio file.
+        whisperCallable: WhisperCallback
+            A callback object to call to transcribe each segment.
+
+        Returns
+        -------
+        A list of start and end timestamps, in fractional seconds.
+        """
+
+        # Get speech timestamps from full audio file
+        merged = self.get_merged_timestamps(audio, config)
+
+        # A deque of transcribed segments that is passed to the next segment as a prompt
+        prompt_window = deque()
+
+        print("Processing timestamps:")
+        pprint(merged)
 
         result = {
             'text': "",
@@ -147,7 +164,7 @@ class AbstractTranscription(ABC):
         languageCounter = Counter()
         detected_language = None
 
-        segment_index = -1
+        segment_index = config.initial_segment_index
 
         # For each time segment, run whisper
         for segment in merged:
@@ -172,7 +189,7 @@ class AbstractTranscription(ABC):
 
             print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
                   segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
-            segment_result = whisperCallable(segment_audio, segment_index, segment_prompt, detected_language)
+            segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language)
 
             adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
 
@@ -373,6 +390,7 @@ class AbstractTranscription(ABC):
         })
         return result
 
+
 class VadSileroTranscription(AbstractTranscription):
     def __init__(self, sampling_rate: int = 16000):
         super().__init__(sampling_rate=sampling_rate)
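After this refactor, `transcribe` only requires an object that exposes `invoke(audio, segment_index, prompt, detected_language)` and returns a Whisper-style result dict. A minimal stand-in (purely illustrative, not part of the commit) that satisfies the contract, e.g. for exercising the VAD logic without loading a Whisper model:

    # Illustrative stub (not in the commit): any object with this invoke() signature
    # can be passed to AbstractTranscription.transcribe in place of a WhisperCallback.
    class FakeCallback:
        def invoke(self, audio, segment_index, prompt, detected_language):
            # Return a Whisper-style result for the given audio segment
            return {"text": "hello", "segments": [{"start": 0.0, "end": 1.0, "text": "hello"}], "language": "en"}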
src/vadParallel.py ADDED
@@ -0,0 +1,81 @@
+from src.vad import AbstractTranscription, TranscriptionConfig
+from src.whisperContainer import WhisperCallback
+
+from multiprocessing import Pool
+
+from typing import List
+import os
+
+class ParallelTranscriptionConfig(TranscriptionConfig):
+    def __init__(self, device_id: str, override_timestamps, initial_segment_index, copy: TranscriptionConfig = None):
+        super().__init__(copy.non_speech_strategy, copy.segment_padding_left, copy.segment_padding_right, copy.max_silent_period, copy.max_merge_size, copy.max_prompt_window, initial_segment_index)
+        self.device_id = device_id
+        self.override_timestamps = override_timestamps
+
+class ParallelTranscription(AbstractTranscription):
+    def __init__(self, sampling_rate: int = 16000):
+        super().__init__(sampling_rate=sampling_rate)
+
+    def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig, devices: List[str]):
+        # First, get the timestamps for the original audio
+        merged = transcription.get_merged_timestamps(audio, config)
+
+        # Split into a list for each device
+        merged_split = self._chunks(merged, len(merged) // len(devices))
+
+        # Parameters that will be passed to the transcribe function
+        parameters = []
+        segment_index = config.initial_segment_index
+
+        for i in range(len(devices)):
+            device_segment_list = merged_split[i]
+
+            # Create a new config with the given device ID
+            device_config = ParallelTranscriptionConfig(devices[i], device_segment_list, segment_index, config)
+            segment_index += len(device_segment_list)
+
+            parameters.append([audio, whisperCallable, device_config])
+
+        merged = {
+            'text': '',
+            'segments': [],
+            'language': None
+        }
+
+        with Pool(len(devices)) as p:
+            # Run the transcription in parallel
+            results = p.starmap(self.transcribe, parameters)
+
+            for result in results:
+                # Merge the results
+                if (result['text'] is not None):
+                    merged['text'] += result['text']
+                if (result['segments'] is not None):
+                    merged['segments'].extend(result['segments'])
+                if (result['language'] is not None):
+                    merged['language'] = result['language']
+
+        return merged
+
+    def get_transcribe_timestamps(self, audio: str, config: ParallelTranscriptionConfig):
+        return []
+
+    def get_merged_timestamps(self, audio: str, config: ParallelTranscriptionConfig):
+        # Override timestamps that will be processed
+        if (config.override_timestamps is not None):
+            print("Using override timestamps of size " + str(len(config.override_timestamps)))
+            return config.override_timestamps
+        return super().get_merged_timestamps(audio, config)
+
+    def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig):
+        # Override device ID
+        if (config.device_id is not None):
+            print("Using device " + config.device_id)
+            os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
+        return super().transcribe(audio, whisperCallable, config)
+
+    def _chunks(self, lst, n):
+        """Yield successive n-sized chunks from lst."""
+        return [lst[i:i + n] for i in range(0, len(lst), n)]
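The core of `transcribe_parallel` is splitting the merged VAD segments into one chunk per device while keeping segment indices globally unique across workers. A small worked example (the timestamps are made up for illustration):

    # Worked example (made-up values): four merged segments split across two devices.
    segments = [{"start": 0.0, "end": 30.0}, {"start": 40.0, "end": 70.0},
                {"start": 90.0, "end": 120.0}, {"start": 130.0, "end": 160.0}]
    devices = ["0", "1"]

    chunk_size = len(segments) // len(devices)  # 2 segments per device
    chunks = [segments[i:i + chunk_size] for i in range(0, len(segments), chunk_size)]

    # Device "0" gets segments 0-1, device "1" gets segments 2-3. Each worker receives a
    # ParallelTranscriptionConfig carrying its own initial_segment_index and device_id, so
    # CUDA_VISIBLE_DEVICES is set per process and the merged output keeps consecutive indices.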
src/whisperContainer.py ADDED
@@ -0,0 +1,91 @@
+# External programs
+import whisper
+
+class WhisperContainer:
+    def __init__(self, model_name: str, device: str = None):
+        self.model_name = model_name
+        self.device = device
+
+        # Will be created on demand
+        self.model = None
+
+    def get_model(self):
+        if self.model is None:
+            print("Loading model " + self.model_name)
+            self.model = whisper.load_model(self.model_name, device=self.device)
+        return self.model
+
+    def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
+        """
+        Create a WhisperCallback object that can be used to transcribe audio files.
+
+        Parameters
+        ----------
+        language: str
+            The target language of the transcription. If not specified, the language will be inferred from the audio content.
+        task: str
+            The task - either translate or transcribe.
+        initial_prompt: str
+            The initial prompt to use for the transcription.
+        decodeOptions: dict
+            Additional options to pass to the decoder. Must be pickleable.
+
+        Returns
+        -------
+        A WhisperCallback object.
+        """
+        return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)
+
+    # This is required for multiprocessing
+    def __getstate__(self):
+        return { "model_name": self.model_name, "device": self.device }
+
+    def __setstate__(self, state):
+        self.model_name = state["model_name"]
+        self.device = state["device"]
+        self.model = None
+
+
+class WhisperCallback:
+    def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
+        self.model_container = model_container
+        self.language = language
+        self.task = task
+        self.initial_prompt = initial_prompt
+        self.decodeOptions = decodeOptions
+
+    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str):
+        """
+        Perform the transcription of the given audio file or data.
+
+        Parameters
+        ----------
+        audio: Union[str, np.ndarray, torch.Tensor]
+            The audio file to transcribe, or the audio data as a numpy array or torch tensor.
+        segment_index: int
+            The index of the segment being transcribed. The initial prompt is only included for the first segment (index 0).
+        prompt: str
+            The prompt to use for the transcription.
+        detected_language: str
+            The detected language of the audio file.
+
+        Returns
+        -------
+        The result of the Whisper call.
+        """
+        model = self.model_container.get_model()
+
+        return model.transcribe(audio, \
+            language=self.language if self.language else detected_language, task=self.task, \
+            initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
+            **self.decodeOptions)
+
+    def _concat_prompt(self, prompt1, prompt2):
+        if (prompt1 is None):
+            return prompt2
+        elif (prompt2 is None):
+            return prompt1
+        else:
+            return prompt1 + " " + prompt2
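Because `multiprocessing` pickles the callback for each worker, `__getstate__`/`__setstate__` ensure only the model name and device cross the process boundary, and the model itself is reloaded lazily via `get_model()`. A short sketch of the intended behaviour (illustrative, not part of the commit):

    import pickle

    # Illustrative: the loaded model is dropped when pickling; each worker reloads it on demand.
    container = WhisperContainer("medium", device="cuda")
    callback = container.create_callback(language="en", task="transcribe")

    clone = pickle.loads(pickle.dumps(callback))
    assert clone.model_container.model is None  # weights are reloaded on the first invoke() in the worker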