aadnk commited on
Commit
8031785
·
1 Parent(s): 168184d

Add inital prompt mode. GITLAB #7

Browse files

Add a new configuration option `vad_initial_prompt_mode` in config.json5
and the application arguments that have the following two modes:
* prepend_all_segments: Preprend the initial prompt to each VAD segment
* prepend_first_segment: Only preprend the initial prompt the first VAD segment.

This is useful if you're using the prompt to improve the accuracy of
the transcription of unusual technical terms consistently throughout a lecture. You
can add these terms to the prompt, and then set `vad_initial_prompt_mode` to
`prepend_all_segments` to include the prompt in every VAD segment.

Note that this will have no effect if you're not using a VAD.

app.py CHANGED
@@ -12,7 +12,7 @@ import numpy as np
12
 
13
  import torch
14
 
15
- from src.config import ApplicationConfig
16
  from src.hooks.progressListener import ProgressListener
17
  from src.hooks.subTaskProgressListener import SubTaskProgressListener
18
  from src.hooks.whisperProgressHook import create_progress_listener_handle
@@ -43,6 +43,17 @@ MAX_AUTO_CPU_CORES = 8
43
 
44
  WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
45
 
 
 
 
 
 
 
 
 
 
 
 
46
  class WhisperTranscriber:
47
  def __init__(self, input_audio_max_duration: float = None, vad_process_timeout: float = None,
48
  vad_cpu_cores: int = 1, delete_uploaded_files: bool = False, output_dir: str = None,
@@ -75,11 +86,14 @@ class WhisperTranscriber:
75
  # Entry function for the simple tab
76
  def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
77
  progress=gr.Progress()):
78
- return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
79
- progress=progress)
 
 
80
 
81
  # Entry function for the full tab
82
- def transcribe_webui_full(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
 
83
  initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
84
  condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
85
  compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
@@ -91,14 +105,16 @@ class WhisperTranscriber:
91
  else:
92
  temperature = [temperature]
93
 
94
- return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
 
 
95
  initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
96
  condition_on_previous_text=condition_on_previous_text, fp16=fp16,
97
  compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
98
  progress=progress)
99
 
100
- def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
101
- progress: gr.Progress = None, **decodeOptions: dict):
102
  try:
103
  sources = self.__get_source(urlData, multipleFiles, microphoneData)
104
 
@@ -146,7 +162,7 @@ class WhisperTranscriber:
146
  sub_task_total=source_audio_duration)
147
 
148
  # Transcribe
149
- result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, scaled_progress_listener, **decodeOptions)
150
  filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
151
 
152
  # Update progress
@@ -210,8 +226,8 @@ class WhisperTranscriber:
210
  except ExceededMaximumDuration as e:
211
  return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
212
 
213
- def transcribe_file(self, model: AbstractWhisperContainer, audio_path: str, language: str, task: str = None, vad: str = None,
214
- vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1,
215
  progressListener: ProgressListener = None, **decodeOptions: dict):
216
 
217
  initial_prompt = decodeOptions.pop('initial_prompt', None)
@@ -224,26 +240,26 @@ class WhisperTranscriber:
224
  task = decodeOptions.pop('task')
225
 
226
  # Callable for processing an audio file
227
- whisperCallable = model.create_callback(language, task, initial_prompt, **decodeOptions)
228
 
229
  # The results
230
- if (vad == 'silero-vad'):
231
  # Silero VAD where non-speech gaps are transcribed
232
- process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
233
  result = self.process_vad(audio_path, whisperCallable, self.vad_model, process_gaps, progressListener=progressListener)
234
- elif (vad == 'silero-vad-skip-gaps'):
235
  # Silero VAD where non-speech gaps are simply ignored
236
- skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
237
  result = self.process_vad(audio_path, whisperCallable, self.vad_model, skip_gaps, progressListener=progressListener)
238
- elif (vad == 'silero-vad-expand-into-gaps'):
239
  # Use Silero VAD where speech-segments are expanded into non-speech gaps
240
- expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
241
  result = self.process_vad(audio_path, whisperCallable, self.vad_model, expand_gaps, progressListener=progressListener)
242
- elif (vad == 'periodic-vad'):
243
  # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
244
  # it may create a break in the middle of a sentence, causing some artifacts.
245
  periodic_vad = VadPeriodicTranscription()
246
- period_config = PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow)
247
  result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config, progressListener=progressListener)
248
 
249
  else:
@@ -314,15 +330,15 @@ class WhisperTranscriber:
314
  else:
315
  return prompt1 + " " + prompt2
316
 
317
- def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1):
318
  # Use Silero VAD
319
  if (self.vad_model is None):
320
  self.vad_model = VadSileroTranscription()
321
 
322
  config = TranscriptionConfig(non_speech_strategy = non_speech_strategy,
323
- max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
324
- segment_padding_left=vadPadding, segment_padding_right=vadPadding,
325
- max_prompt_window=vadPromptWindow)
326
 
327
  return config
328
 
@@ -451,6 +467,7 @@ def create_ui(app_config: ApplicationConfig):
451
 
452
  full_transcribe = gr.Interface(fn=ui.transcribe_webui_full, description=full_description, article=ui_article, inputs=[
453
  *simple_inputs(),
 
454
  gr.TextArea(label="Initial Prompt"),
455
  gr.Number(label="Temperature", value=app_config.temperature),
456
  gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0),
@@ -503,6 +520,8 @@ if __name__ == '__main__':
503
  help="The default model name.") # medium
504
  parser.add_argument("--default_vad", type=str, default=default_app_config.default_vad, \
505
  help="The default VAD.") # silero-vad
 
 
506
  parser.add_argument("--vad_parallel_devices", type=str, default=default_app_config.vad_parallel_devices, \
507
  help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
508
  parser.add_argument("--vad_cpu_cores", type=int, default=default_app_config.vad_cpu_cores, \
 
12
 
13
  import torch
14
 
15
+ from src.config import ApplicationConfig, VadInitialPromptMode
16
  from src.hooks.progressListener import ProgressListener
17
  from src.hooks.subTaskProgressListener import SubTaskProgressListener
18
  from src.hooks.whisperProgressHook import create_progress_listener_handle
 
43
 
44
  WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
45
 
46
+ class VadOptions:
47
+ def __init__(self, vad: str = None, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1,
48
+ vadInitialPromptMode: Union[VadInitialPromptMode, str] = VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
49
+ self.vad = vad
50
+ self.vadMergeWindow = vadMergeWindow
51
+ self.vadMaxMergeSize = vadMaxMergeSize
52
+ self.vadPadding = vadPadding
53
+ self.vadPromptWindow = vadPromptWindow
54
+ self.vadInitialPromptMode = vadInitialPromptMode if isinstance(vadInitialPromptMode, VadInitialPromptMode) \
55
+ else VadInitialPromptMode.from_string(vadInitialPromptMode)
56
+
57
  class WhisperTranscriber:
58
  def __init__(self, input_audio_max_duration: float = None, vad_process_timeout: float = None,
59
  vad_cpu_cores: int = 1, delete_uploaded_files: bool = False, output_dir: str = None,
 
86
  # Entry function for the simple tab
87
  def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
88
  progress=gr.Progress()):
89
+
90
+ vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, self.app_config.vad_initial_prompt_mode)
91
+
92
+ return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions, progress=progress)
93
 
94
  # Entry function for the full tab
95
+ def transcribe_webui_full(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
96
+ vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
97
  initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
98
  condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
99
  compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
 
105
  else:
106
  temperature = [temperature]
107
 
108
+ vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)
109
+
110
+ return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
111
  initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
112
  condition_on_previous_text=condition_on_previous_text, fp16=fp16,
113
  compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
114
  progress=progress)
115
 
116
+ def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
117
+ vadOptions: VadOptions, progress: gr.Progress = None, **decodeOptions: dict):
118
  try:
119
  sources = self.__get_source(urlData, multipleFiles, microphoneData)
120
 
 
162
  sub_task_total=source_audio_duration)
163
 
164
  # Transcribe
165
+ result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vadOptions, scaled_progress_listener, **decodeOptions)
166
  filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
167
 
168
  # Update progress
 
226
  except ExceededMaximumDuration as e:
227
  return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
228
 
229
+ def transcribe_file(self, model: AbstractWhisperContainer, audio_path: str, language: str, task: str = None,
230
+ vadOptions: VadOptions = VadOptions(),
231
  progressListener: ProgressListener = None, **decodeOptions: dict):
232
 
233
  initial_prompt = decodeOptions.pop('initial_prompt', None)
 
240
  task = decodeOptions.pop('task')
241
 
242
  # Callable for processing an audio file
243
+ whisperCallable = model.create_callback(language, task, initial_prompt, initial_prompt_mode=vadOptions.vadInitialPromptMode, **decodeOptions)
244
 
245
  # The results
246
+ if (vadOptions.vad == 'silero-vad'):
247
  # Silero VAD where non-speech gaps are transcribed
248
+ process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadOptions)
249
  result = self.process_vad(audio_path, whisperCallable, self.vad_model, process_gaps, progressListener=progressListener)
250
+ elif (vadOptions.vad == 'silero-vad-skip-gaps'):
251
  # Silero VAD where non-speech gaps are simply ignored
252
+ skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadOptions)
253
  result = self.process_vad(audio_path, whisperCallable, self.vad_model, skip_gaps, progressListener=progressListener)
254
+ elif (vadOptions.vad == 'silero-vad-expand-into-gaps'):
255
  # Use Silero VAD where speech-segments are expanded into non-speech gaps
256
+ expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadOptions)
257
  result = self.process_vad(audio_path, whisperCallable, self.vad_model, expand_gaps, progressListener=progressListener)
258
+ elif (vadOptions.vad == 'periodic-vad'):
259
  # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
260
  # it may create a break in the middle of a sentence, causing some artifacts.
261
  periodic_vad = VadPeriodicTranscription()
262
+ period_config = PeriodicTranscriptionConfig(periodic_duration=vadOptions.vadMaxMergeSize, max_prompt_window=vadOptions.vadPromptWindow)
263
  result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config, progressListener=progressListener)
264
 
265
  else:
 
330
  else:
331
  return prompt1 + " " + prompt2
332
 
333
+ def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadOptions: VadOptions):
334
  # Use Silero VAD
335
  if (self.vad_model is None):
336
  self.vad_model = VadSileroTranscription()
337
 
338
  config = TranscriptionConfig(non_speech_strategy = non_speech_strategy,
339
+ max_silent_period=vadOptions.vadMergeWindow, max_merge_size=vadOptions.vadMaxMergeSize,
340
+ segment_padding_left=vadOptions.vadPadding, segment_padding_right=vadOptions.vadPadding,
341
+ max_prompt_window=vadOptions.vadPromptWindow)
342
 
343
  return config
344
 
 
467
 
468
  full_transcribe = gr.Interface(fn=ui.transcribe_webui_full, description=full_description, article=ui_article, inputs=[
469
  *simple_inputs(),
470
+ gr.Dropdown(choices=["prepend_first_segment", "prepend_all_segments"], value=app_config.vad_initial_prompt_mode, label="VAD - Initial Prompt Mode"),
471
  gr.TextArea(label="Initial Prompt"),
472
  gr.Number(label="Temperature", value=app_config.temperature),
473
  gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0),
 
520
  help="The default model name.") # medium
521
  parser.add_argument("--default_vad", type=str, default=default_app_config.default_vad, \
522
  help="The default VAD.") # silero-vad
523
+ parser.add_argument("--vad_initial_prompt_mode", type=str, default=default_app_config.vad_initial_prompt_mode, choices=["prepend_all_segments", "prepend_first_segment"], \
524
+ help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
525
  parser.add_argument("--vad_parallel_devices", type=str, default=default_app_config.vad_parallel_devices, \
526
  help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
527
  parser.add_argument("--vad_cpu_cores", type=int, default=default_app_config.vad_cpu_cores, \
cli.py CHANGED
@@ -6,8 +6,8 @@ import warnings
6
  import numpy as np
7
 
8
  import torch
9
- from app import WhisperTranscriber
10
- from src.config import ApplicationConfig
11
  from src.download import download_url
12
  from src.languages import get_language_names
13
 
@@ -47,6 +47,8 @@ def cli():
47
 
48
  parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
49
  help="The voice activity detection algorithm to use") # silero-vad
 
 
50
  parser.add_argument("--vad_merge_window", type=optional_float, default=app_config.vad_merge_window, \
51
  help="The window size (in seconds) to merge voice segments")
52
  parser.add_argument("--vad_max_merge_size", type=optional_float, default=app_config.vad_max_merge_size,\
@@ -115,6 +117,7 @@ def cli():
115
  temperature = [temperature]
116
 
117
  vad = args.pop("vad")
 
118
  vad_merge_window = args.pop("vad_merge_window")
119
  vad_max_merge_size = args.pop("vad_max_merge_size")
120
  vad_padding = args.pop("vad_padding")
@@ -150,9 +153,10 @@ def cli():
150
  source_path = source["path"]
151
  source_name = source["name"]
152
 
153
- result = transcriber.transcribe_file(model, source_path, temperature=temperature,
154
- vad=vad, vadMergeWindow=vad_merge_window, vadMaxMergeSize=vad_max_merge_size,
155
- vadPadding=vad_padding, vadPromptWindow=vad_prompt_window, **args)
 
156
 
157
  transcriber.write_result(result, source_name, output_dir)
158
 
 
6
  import numpy as np
7
 
8
  import torch
9
+ from app import VadOptions, WhisperTranscriber
10
+ from src.config import ApplicationConfig, VadInitialPromptMode
11
  from src.download import download_url
12
  from src.languages import get_language_names
13
 
 
47
 
48
  parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
49
  help="The voice activity detection algorithm to use") # silero-vad
50
+ parser.add_argument("--vad_initial_prompt_mode", type=str, default=app_config.vad_initial_prompt_mode, choices=["prepend_all_segments", "prepend_first_segment"], \
51
+ help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
52
  parser.add_argument("--vad_merge_window", type=optional_float, default=app_config.vad_merge_window, \
53
  help="The window size (in seconds) to merge voice segments")
54
  parser.add_argument("--vad_max_merge_size", type=optional_float, default=app_config.vad_max_merge_size,\
 
117
  temperature = [temperature]
118
 
119
  vad = args.pop("vad")
120
+ vad_initial_prompt_mode = args.pop("vad_initial_prompt_mode")
121
  vad_merge_window = args.pop("vad_merge_window")
122
  vad_max_merge_size = args.pop("vad_max_merge_size")
123
  vad_padding = args.pop("vad_padding")
 
153
  source_path = source["path"]
154
  source_name = source["name"]
155
 
156
+ vadOptions = VadOptions(vad, vad_merge_window, vad_max_merge_size, vad_padding, vad_prompt_window,
157
+ VadInitialPromptMode.from_string(vad_initial_prompt_mode))
158
+
159
+ result = transcriber.transcribe_file(model, source_path, temperature=temperature, vadOptions=vadOptions, **args)
160
 
161
  transcriber.write_result(result, source_name, output_dir)
162
 
config.json5 CHANGED
@@ -97,6 +97,8 @@
97
  "vad_max_merge_size": 30,
98
  // The padding (in seconds) to add to each voice segment
99
  "vad_padding": 1,
 
 
100
  // The window size of the prompt to pass to Whisper
101
  "vad_prompt_window": 3,
102
  // Temperature to use for sampling
 
97
  "vad_max_merge_size": 30,
98
  // The padding (in seconds) to add to each voice segment
99
  "vad_padding": 1,
100
+ // Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)
101
+ "vad_initial_prompt_mode": "prepend_first_segment",
102
  // The window size of the prompt to pass to Whisper
103
  "vad_prompt_window": 3,
104
  // Temperature to use for sampling
src/config.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import urllib
2
 
3
  import os
@@ -23,6 +24,21 @@ class ModelConfig:
23
  self.path = path
24
  self.type = type
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  class ApplicationConfig:
27
  def __init__(self, models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
28
  share: bool = False, server_name: str = None, server_port: int = 7860,
@@ -33,6 +49,7 @@ class ApplicationConfig:
33
  auto_parallel: bool = False, output_dir: str = None,
34
  model_dir: str = None, device: str = None,
35
  verbose: bool = True, task: str = "transcribe", language: str = None,
 
36
  vad_merge_window: float = 5, vad_max_merge_size: float = 30,
37
  vad_padding: float = 1, vad_prompt_window: float = 3,
38
  temperature: float = 0, best_of: int = 5, beam_size: int = 5,
@@ -67,6 +84,7 @@ class ApplicationConfig:
67
  self.verbose = verbose
68
  self.task = task
69
  self.language = language
 
70
  self.vad_merge_window = vad_merge_window
71
  self.vad_max_merge_size = vad_max_merge_size
72
  self.vad_padding = vad_padding
 
1
+ from enum import Enum
2
  import urllib
3
 
4
  import os
 
24
  self.path = path
25
  self.type = type
26
 
27
+ class VadInitialPromptMode(Enum):
28
+ PREPEND_ALL_SEGMENTS = 1
29
+ PREPREND_FIRST_SEGMENT = 2
30
+
31
+ @staticmethod
32
+ def from_string(s: str):
33
+ normalized = s.lower() if s is not None else None
34
+
35
+ if normalized == "prepend_all_segments":
36
+ return VadInitialPromptMode.PREPEND_ALL_SEGMENTS
37
+ elif normalized == "prepend_first_segment":
38
+ return VadInitialPromptMode.PREPREND_FIRST_SEGMENT
39
+ else:
40
+ raise ValueError(f"Invalid value for VadInitialPromptMode: {s}")
41
+
42
  class ApplicationConfig:
43
  def __init__(self, models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
44
  share: bool = False, server_name: str = None, server_port: int = 7860,
 
49
  auto_parallel: bool = False, output_dir: str = None,
50
  model_dir: str = None, device: str = None,
51
  verbose: bool = True, task: str = "transcribe", language: str = None,
52
+ vad_initial_prompt_mode: str = "prepend_first_segment ",
53
  vad_merge_window: float = 5, vad_max_merge_size: float = 30,
54
  vad_padding: float = 1, vad_prompt_window: float = 3,
55
  temperature: float = 0, best_of: int = 5, beam_size: int = 5,
 
84
  self.verbose = verbose
85
  self.task = task
86
  self.language = language
87
+ self.vad_initial_prompt_mode = vad_initial_prompt_mode
88
  self.vad_merge_window = vad_merge_window
89
  self.vad_max_merge_size = vad_max_merge_size
90
  self.vad_padding = vad_padding
src/whisper/abstractWhisperContainer.py CHANGED
@@ -1,6 +1,6 @@
1
  import abc
2
  from typing import List
3
- from src.config import ModelConfig
4
 
5
  from src.hooks.progressListener import ProgressListener
6
  from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
@@ -24,6 +24,15 @@ class AbstractWhisperCallback:
24
  """
25
  raise NotImplementedError()
26
 
 
 
 
 
 
 
 
 
 
27
  def _concat_prompt(self, prompt1, prompt2):
28
  if (prompt1 is None):
29
  return prompt2
@@ -66,7 +75,9 @@ class AbstractWhisperContainer:
66
  pass
67
 
68
  @abc.abstractmethod
69
- def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict) -> AbstractWhisperCallback:
 
 
70
  """
71
  Create a WhisperCallback object that can be used to transcript audio files.
72
 
@@ -78,6 +89,9 @@ class AbstractWhisperContainer:
78
  The task - either translate or transcribe.
79
  initial_prompt: str
80
  The initial prompt to use for the transcription.
 
 
 
81
  decodeOptions: dict
82
  Additional options to pass to the decoder. Must be pickleable.
83
 
 
1
  import abc
2
  from typing import List
3
+ from src.config import ModelConfig, VadInitialPromptMode
4
 
5
  from src.hooks.progressListener import ProgressListener
6
  from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
 
24
  """
25
  raise NotImplementedError()
26
 
27
+ def _get_initial_prompt(self, initial_prompt: str, initial_prompt_mode: VadInitialPromptMode,
28
+ prompt: str, segment_index: int):
29
+ if (initial_prompt_mode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS):
30
+ return self._concat_prompt(initial_prompt, prompt)
31
+ elif (initial_prompt_mode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
32
+ return self._concat_prompt(initial_prompt, prompt) if segment_index == 0 else prompt
33
+ else:
34
+ raise ValueError(f"Unknown initial prompt mode {initial_prompt_mode}")
35
+
36
  def _concat_prompt(self, prompt1, prompt2):
37
  if (prompt1 is None):
38
  return prompt2
 
75
  pass
76
 
77
  @abc.abstractmethod
78
+ def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None,
79
+ initial_prompt_mode: VadInitialPromptMode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
80
+ **decodeOptions: dict) -> AbstractWhisperCallback:
81
  """
82
  Create a WhisperCallback object that can be used to transcript audio files.
83
 
 
89
  The task - either translate or transcribe.
90
  initial_prompt: str
91
  The initial prompt to use for the transcription.
92
+ initial_prompt_mode: VadInitialPromptMode
93
+ The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
94
+ If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
95
  decodeOptions: dict
96
  Additional options to pass to the decoder. Must be pickleable.
97
 
src/whisper/fasterWhisperContainer.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  from typing import List, Union
3
 
4
  from faster_whisper import WhisperModel, download_model
5
- from src.config import ModelConfig
6
  from src.hooks.progressListener import ProgressListener
7
  from src.languages import get_language_from_name
8
  from src.modelCache import ModelCache
@@ -51,7 +51,9 @@ class FasterWhisperContainer(AbstractWhisperContainer):
51
  model = WhisperModel(model_config.url, device=device, compute_type=self.compute_type)
52
  return model
53
 
54
- def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
 
 
55
  """
56
  Create a WhisperCallback object that can be used to transcript audio files.
57
 
@@ -63,6 +65,9 @@ class FasterWhisperContainer(AbstractWhisperContainer):
63
  The task - either translate or transcribe.
64
  initial_prompt: str
65
  The initial prompt to use for the transcription.
 
 
 
66
  decodeOptions: dict
67
  Additional options to pass to the decoder. Must be pickleable.
68
 
@@ -70,14 +75,17 @@ class FasterWhisperContainer(AbstractWhisperContainer):
70
  -------
71
  A WhisperCallback object.
72
  """
73
- return FasterWhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)
74
 
75
  class FasterWhisperCallback(AbstractWhisperCallback):
76
- def __init__(self, model_container: FasterWhisperContainer, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
 
 
77
  self.model_container = model_container
78
  self.language = language
79
  self.task = task
80
  self.initial_prompt = initial_prompt
 
81
  self.decodeOptions = decodeOptions
82
 
83
  self._printed_warning = False
@@ -125,9 +133,11 @@ class FasterWhisperCallback(AbstractWhisperCallback):
125
  # See if supress_tokens is a string - if so, convert it to a list of ints
126
  decodeOptions["suppress_tokens"] = self._split_suppress_tokens(suppress_tokens)
127
 
 
 
128
  segments_generator, info = model.transcribe(audio, \
129
  language=language_code if language_code else detected_language, task=self.task, \
130
- initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
131
  **decodeOptions
132
  )
133
 
 
2
  from typing import List, Union
3
 
4
  from faster_whisper import WhisperModel, download_model
5
+ from src.config import ModelConfig, VadInitialPromptMode
6
  from src.hooks.progressListener import ProgressListener
7
  from src.languages import get_language_from_name
8
  from src.modelCache import ModelCache
 
51
  model = WhisperModel(model_config.url, device=device, compute_type=self.compute_type)
52
  return model
53
 
54
+ def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None,
55
+ initial_prompt_mode: VadInitialPromptMode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
56
+ **decodeOptions: dict) -> AbstractWhisperCallback:
57
  """
58
  Create a WhisperCallback object that can be used to transcript audio files.
59
 
 
65
  The task - either translate or transcribe.
66
  initial_prompt: str
67
  The initial prompt to use for the transcription.
68
+ initial_prompt_mode: VadInitialPromptMode
69
+ The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
70
+ If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
71
  decodeOptions: dict
72
  Additional options to pass to the decoder. Must be pickleable.
73
 
 
75
  -------
76
  A WhisperCallback object.
77
  """
78
+ return FasterWhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, initial_prompt_mode=initial_prompt_mode, **decodeOptions)
79
 
80
  class FasterWhisperCallback(AbstractWhisperCallback):
81
+ def __init__(self, model_container: FasterWhisperContainer, language: str = None, task: str = None,
82
+ initial_prompt: str = None, initial_prompt_mode: VadInitialPromptMode=VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
83
+ **decodeOptions: dict):
84
  self.model_container = model_container
85
  self.language = language
86
  self.task = task
87
  self.initial_prompt = initial_prompt
88
+ self.initial_prompt_mode = initial_prompt_mode
89
  self.decodeOptions = decodeOptions
90
 
91
  self._printed_warning = False
 
133
  # See if supress_tokens is a string - if so, convert it to a list of ints
134
  decodeOptions["suppress_tokens"] = self._split_suppress_tokens(suppress_tokens)
135
 
136
+ initial_prompt = self._get_initial_prompt(self.initial_prompt, self.initial_prompt_mode, prompt, segment_index)
137
+
138
  segments_generator, info = model.transcribe(audio, \
139
  language=language_code if language_code else detected_language, task=self.task, \
140
+ initial_prompt=initial_prompt, \
141
  **decodeOptions
142
  )
143
 
src/whisper/whisperContainer.py CHANGED
@@ -11,7 +11,7 @@ from src.hooks.progressListener import ProgressListener
11
  import whisper
12
  from whisper import Whisper
13
 
14
- from src.config import ModelConfig
15
  from src.hooks.whisperProgressHook import create_progress_listener_handle
16
 
17
  from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
@@ -69,7 +69,9 @@ class WhisperContainer(AbstractWhisperContainer):
69
 
70
  return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
71
 
72
- def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
 
 
73
  """
74
  Create a WhisperCallback object that can be used to transcript audio files.
75
 
@@ -81,6 +83,9 @@ class WhisperContainer(AbstractWhisperContainer):
81
  The task - either translate or transcribe.
82
  initial_prompt: str
83
  The initial prompt to use for the transcription.
 
 
 
84
  decodeOptions: dict
85
  Additional options to pass to the decoder. Must be pickleable.
86
 
@@ -88,7 +93,7 @@ class WhisperContainer(AbstractWhisperContainer):
88
  -------
89
  A WhisperCallback object.
90
  """
91
- return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)
92
 
93
  def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
94
  from src.conversion.hf_converter import convert_hf_whisper
@@ -157,11 +162,13 @@ class WhisperContainer(AbstractWhisperContainer):
157
  return model_config.path
158
 
159
  class WhisperCallback(AbstractWhisperCallback):
160
- def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
 
161
  self.model_container = model_container
162
  self.language = language
163
  self.task = task
164
  self.initial_prompt = initial_prompt
 
165
  self.decodeOptions = decodeOptions
166
 
167
  def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
@@ -194,8 +201,10 @@ class WhisperCallback(AbstractWhisperCallback):
194
  if self.model_container.compute_type in ["fp16", "float16"]:
195
  decodeOptions["fp16"] = True
196
 
 
 
197
  return model.transcribe(audio, \
198
  language=self.language if self.language else detected_language, task=self.task, \
199
- initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
200
  **decodeOptions
201
  )
 
11
  import whisper
12
  from whisper import Whisper
13
 
14
+ from src.config import ModelConfig, VadInitialPromptMode
15
  from src.hooks.whisperProgressHook import create_progress_listener_handle
16
 
17
  from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
 
69
 
70
  return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
71
 
72
+ def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None,
73
+ initial_prompt_mode: VadInitialPromptMode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
74
+ **decodeOptions: dict) -> AbstractWhisperCallback:
75
  """
76
  Create a WhisperCallback object that can be used to transcript audio files.
77
 
 
83
  The task - either translate or transcribe.
84
  initial_prompt: str
85
  The initial prompt to use for the transcription.
86
+ initial_prompt_mode: VadInitialPromptMode
87
+ The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
88
+ If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
89
  decodeOptions: dict
90
  Additional options to pass to the decoder. Must be pickleable.
91
 
 
93
  -------
94
  A WhisperCallback object.
95
  """
96
+ return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, initial_prompt_mode=initial_prompt_mode, **decodeOptions)
97
 
98
  def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
99
  from src.conversion.hf_converter import convert_hf_whisper
 
162
  return model_config.path
163
 
164
  class WhisperCallback(AbstractWhisperCallback):
165
+ def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None,
166
+ initial_prompt_mode: VadInitialPromptMode=VadInitialPromptMode.PREPREND_FIRST_SEGMENT, **decodeOptions: dict):
167
  self.model_container = model_container
168
  self.language = language
169
  self.task = task
170
  self.initial_prompt = initial_prompt
171
+ self.initial_prompt_mode = initial_prompt_mode
172
  self.decodeOptions = decodeOptions
173
 
174
  def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
 
201
  if self.model_container.compute_type in ["fp16", "float16"]:
202
  decodeOptions["fp16"] = True
203
 
204
+ initial_prompt = self._get_initial_prompt(self.initial_prompt, self.initial_prompt_mode, prompt, segment_index)
205
+
206
  return model.transcribe(audio, \
207
  language=self.language if self.language else detected_language, task=self.task, \
208
+ initial_prompt=initial_prompt, \
209
  **decodeOptions
210
  )