Add inital prompt mode. GITLAB #7
Browse filesAdd a new configuration option `vad_initial_prompt_mode` in config.json5
and the application arguments that have the following two modes:
* prepend_all_segments: Preprend the initial prompt to each VAD segment
* prepend_first_segment: Only preprend the initial prompt the first VAD segment.
This is useful if you're using the prompt to improve the accuracy of
the transcription of unusual technical terms consistently throughout a lecture. You
can add these terms to the prompt, and then set `vad_initial_prompt_mode` to
`prepend_all_segments` to include the prompt in every VAD segment.
Note that this will have no effect if you're not using a VAD.
- app.py +42 -23
- cli.py +9 -5
- config.json5 +2 -0
- src/config.py +18 -0
- src/whisper/abstractWhisperContainer.py +16 -2
- src/whisper/fasterWhisperContainer.py +15 -5
- src/whisper/whisperContainer.py +14 -5
app.py
CHANGED
@@ -12,7 +12,7 @@ import numpy as np
|
|
12 |
|
13 |
import torch
|
14 |
|
15 |
-
from src.config import ApplicationConfig
|
16 |
from src.hooks.progressListener import ProgressListener
|
17 |
from src.hooks.subTaskProgressListener import SubTaskProgressListener
|
18 |
from src.hooks.whisperProgressHook import create_progress_listener_handle
|
@@ -43,6 +43,17 @@ MAX_AUTO_CPU_CORES = 8
|
|
43 |
|
44 |
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
|
45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
class WhisperTranscriber:
|
47 |
def __init__(self, input_audio_max_duration: float = None, vad_process_timeout: float = None,
|
48 |
vad_cpu_cores: int = 1, delete_uploaded_files: bool = False, output_dir: str = None,
|
@@ -75,11 +86,14 @@ class WhisperTranscriber:
|
|
75 |
# Entry function for the simple tab
|
76 |
def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
|
77 |
progress=gr.Progress()):
|
78 |
-
|
79 |
-
|
|
|
|
|
80 |
|
81 |
# Entry function for the full tab
|
82 |
-
def transcribe_webui_full(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
|
|
83 |
initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
|
84 |
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
85 |
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
|
@@ -91,14 +105,16 @@ class WhisperTranscriber:
|
|
91 |
else:
|
92 |
temperature = [temperature]
|
93 |
|
94 |
-
|
|
|
|
|
95 |
initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
|
96 |
condition_on_previous_text=condition_on_previous_text, fp16=fp16,
|
97 |
compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
|
98 |
progress=progress)
|
99 |
|
100 |
-
def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
101 |
-
progress: gr.Progress = None, **decodeOptions: dict):
|
102 |
try:
|
103 |
sources = self.__get_source(urlData, multipleFiles, microphoneData)
|
104 |
|
@@ -146,7 +162,7 @@ class WhisperTranscriber:
|
|
146 |
sub_task_total=source_audio_duration)
|
147 |
|
148 |
# Transcribe
|
149 |
-
result = self.transcribe_file(model, source.source_path, selectedLanguage, task,
|
150 |
filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
|
151 |
|
152 |
# Update progress
|
@@ -210,8 +226,8 @@ class WhisperTranscriber:
|
|
210 |
except ExceededMaximumDuration as e:
|
211 |
return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
|
212 |
|
213 |
-
def transcribe_file(self, model: AbstractWhisperContainer, audio_path: str, language: str, task: str = None,
|
214 |
-
|
215 |
progressListener: ProgressListener = None, **decodeOptions: dict):
|
216 |
|
217 |
initial_prompt = decodeOptions.pop('initial_prompt', None)
|
@@ -224,26 +240,26 @@ class WhisperTranscriber:
|
|
224 |
task = decodeOptions.pop('task')
|
225 |
|
226 |
# Callable for processing an audio file
|
227 |
-
whisperCallable = model.create_callback(language, task, initial_prompt, **decodeOptions)
|
228 |
|
229 |
# The results
|
230 |
-
if (vad == 'silero-vad'):
|
231 |
# Silero VAD where non-speech gaps are transcribed
|
232 |
-
process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT,
|
233 |
result = self.process_vad(audio_path, whisperCallable, self.vad_model, process_gaps, progressListener=progressListener)
|
234 |
-
elif (vad == 'silero-vad-skip-gaps'):
|
235 |
# Silero VAD where non-speech gaps are simply ignored
|
236 |
-
skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP,
|
237 |
result = self.process_vad(audio_path, whisperCallable, self.vad_model, skip_gaps, progressListener=progressListener)
|
238 |
-
elif (vad == 'silero-vad-expand-into-gaps'):
|
239 |
# Use Silero VAD where speech-segments are expanded into non-speech gaps
|
240 |
-
expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT,
|
241 |
result = self.process_vad(audio_path, whisperCallable, self.vad_model, expand_gaps, progressListener=progressListener)
|
242 |
-
elif (vad == 'periodic-vad'):
|
243 |
# Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
|
244 |
# it may create a break in the middle of a sentence, causing some artifacts.
|
245 |
periodic_vad = VadPeriodicTranscription()
|
246 |
-
period_config = PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow)
|
247 |
result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config, progressListener=progressListener)
|
248 |
|
249 |
else:
|
@@ -314,15 +330,15 @@ class WhisperTranscriber:
|
|
314 |
else:
|
315 |
return prompt1 + " " + prompt2
|
316 |
|
317 |
-
def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy,
|
318 |
# Use Silero VAD
|
319 |
if (self.vad_model is None):
|
320 |
self.vad_model = VadSileroTranscription()
|
321 |
|
322 |
config = TranscriptionConfig(non_speech_strategy = non_speech_strategy,
|
323 |
-
max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
|
324 |
-
segment_padding_left=vadPadding, segment_padding_right=vadPadding,
|
325 |
-
max_prompt_window=vadPromptWindow)
|
326 |
|
327 |
return config
|
328 |
|
@@ -451,6 +467,7 @@ def create_ui(app_config: ApplicationConfig):
|
|
451 |
|
452 |
full_transcribe = gr.Interface(fn=ui.transcribe_webui_full, description=full_description, article=ui_article, inputs=[
|
453 |
*simple_inputs(),
|
|
|
454 |
gr.TextArea(label="Initial Prompt"),
|
455 |
gr.Number(label="Temperature", value=app_config.temperature),
|
456 |
gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0),
|
@@ -503,6 +520,8 @@ if __name__ == '__main__':
|
|
503 |
help="The default model name.") # medium
|
504 |
parser.add_argument("--default_vad", type=str, default=default_app_config.default_vad, \
|
505 |
help="The default VAD.") # silero-vad
|
|
|
|
|
506 |
parser.add_argument("--vad_parallel_devices", type=str, default=default_app_config.vad_parallel_devices, \
|
507 |
help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
|
508 |
parser.add_argument("--vad_cpu_cores", type=int, default=default_app_config.vad_cpu_cores, \
|
|
|
12 |
|
13 |
import torch
|
14 |
|
15 |
+
from src.config import ApplicationConfig, VadInitialPromptMode
|
16 |
from src.hooks.progressListener import ProgressListener
|
17 |
from src.hooks.subTaskProgressListener import SubTaskProgressListener
|
18 |
from src.hooks.whisperProgressHook import create_progress_listener_handle
|
|
|
43 |
|
44 |
WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
|
45 |
|
46 |
+
class VadOptions:
|
47 |
+
def __init__(self, vad: str = None, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1,
|
48 |
+
vadInitialPromptMode: Union[VadInitialPromptMode, str] = VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
|
49 |
+
self.vad = vad
|
50 |
+
self.vadMergeWindow = vadMergeWindow
|
51 |
+
self.vadMaxMergeSize = vadMaxMergeSize
|
52 |
+
self.vadPadding = vadPadding
|
53 |
+
self.vadPromptWindow = vadPromptWindow
|
54 |
+
self.vadInitialPromptMode = vadInitialPromptMode if isinstance(vadInitialPromptMode, VadInitialPromptMode) \
|
55 |
+
else VadInitialPromptMode.from_string(vadInitialPromptMode)
|
56 |
+
|
57 |
class WhisperTranscriber:
|
58 |
def __init__(self, input_audio_max_duration: float = None, vad_process_timeout: float = None,
|
59 |
vad_cpu_cores: int = 1, delete_uploaded_files: bool = False, output_dir: str = None,
|
|
|
86 |
# Entry function for the simple tab
|
87 |
def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
|
88 |
progress=gr.Progress()):
|
89 |
+
|
90 |
+
vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, self.app_config.vad_initial_prompt_mode)
|
91 |
+
|
92 |
+
return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions, progress=progress)
|
93 |
|
94 |
# Entry function for the full tab
|
95 |
+
def transcribe_webui_full(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
96 |
+
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
97 |
initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
|
98 |
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
99 |
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
|
|
|
105 |
else:
|
106 |
temperature = [temperature]
|
107 |
|
108 |
+
vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)
|
109 |
+
|
110 |
+
return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
|
111 |
initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
|
112 |
condition_on_previous_text=condition_on_previous_text, fp16=fp16,
|
113 |
compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
|
114 |
progress=progress)
|
115 |
|
116 |
+
def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
117 |
+
vadOptions: VadOptions, progress: gr.Progress = None, **decodeOptions: dict):
|
118 |
try:
|
119 |
sources = self.__get_source(urlData, multipleFiles, microphoneData)
|
120 |
|
|
|
162 |
sub_task_total=source_audio_duration)
|
163 |
|
164 |
# Transcribe
|
165 |
+
result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vadOptions, scaled_progress_listener, **decodeOptions)
|
166 |
filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
|
167 |
|
168 |
# Update progress
|
|
|
226 |
except ExceededMaximumDuration as e:
|
227 |
return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
|
228 |
|
229 |
+
def transcribe_file(self, model: AbstractWhisperContainer, audio_path: str, language: str, task: str = None,
|
230 |
+
vadOptions: VadOptions = VadOptions(),
|
231 |
progressListener: ProgressListener = None, **decodeOptions: dict):
|
232 |
|
233 |
initial_prompt = decodeOptions.pop('initial_prompt', None)
|
|
|
240 |
task = decodeOptions.pop('task')
|
241 |
|
242 |
# Callable for processing an audio file
|
243 |
+
whisperCallable = model.create_callback(language, task, initial_prompt, initial_prompt_mode=vadOptions.vadInitialPromptMode, **decodeOptions)
|
244 |
|
245 |
# The results
|
246 |
+
if (vadOptions.vad == 'silero-vad'):
|
247 |
# Silero VAD where non-speech gaps are transcribed
|
248 |
+
process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadOptions)
|
249 |
result = self.process_vad(audio_path, whisperCallable, self.vad_model, process_gaps, progressListener=progressListener)
|
250 |
+
elif (vadOptions.vad == 'silero-vad-skip-gaps'):
|
251 |
# Silero VAD where non-speech gaps are simply ignored
|
252 |
+
skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadOptions)
|
253 |
result = self.process_vad(audio_path, whisperCallable, self.vad_model, skip_gaps, progressListener=progressListener)
|
254 |
+
elif (vadOptions.vad == 'silero-vad-expand-into-gaps'):
|
255 |
# Use Silero VAD where speech-segments are expanded into non-speech gaps
|
256 |
+
expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadOptions)
|
257 |
result = self.process_vad(audio_path, whisperCallable, self.vad_model, expand_gaps, progressListener=progressListener)
|
258 |
+
elif (vadOptions.vad == 'periodic-vad'):
|
259 |
# Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
|
260 |
# it may create a break in the middle of a sentence, causing some artifacts.
|
261 |
periodic_vad = VadPeriodicTranscription()
|
262 |
+
period_config = PeriodicTranscriptionConfig(periodic_duration=vadOptions.vadMaxMergeSize, max_prompt_window=vadOptions.vadPromptWindow)
|
263 |
result = self.process_vad(audio_path, whisperCallable, periodic_vad, period_config, progressListener=progressListener)
|
264 |
|
265 |
else:
|
|
|
330 |
else:
|
331 |
return prompt1 + " " + prompt2
|
332 |
|
333 |
+
def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadOptions: VadOptions):
|
334 |
# Use Silero VAD
|
335 |
if (self.vad_model is None):
|
336 |
self.vad_model = VadSileroTranscription()
|
337 |
|
338 |
config = TranscriptionConfig(non_speech_strategy = non_speech_strategy,
|
339 |
+
max_silent_period=vadOptions.vadMergeWindow, max_merge_size=vadOptions.vadMaxMergeSize,
|
340 |
+
segment_padding_left=vadOptions.vadPadding, segment_padding_right=vadOptions.vadPadding,
|
341 |
+
max_prompt_window=vadOptions.vadPromptWindow)
|
342 |
|
343 |
return config
|
344 |
|
|
|
467 |
|
468 |
full_transcribe = gr.Interface(fn=ui.transcribe_webui_full, description=full_description, article=ui_article, inputs=[
|
469 |
*simple_inputs(),
|
470 |
+
gr.Dropdown(choices=["prepend_first_segment", "prepend_all_segments"], value=app_config.vad_initial_prompt_mode, label="VAD - Initial Prompt Mode"),
|
471 |
gr.TextArea(label="Initial Prompt"),
|
472 |
gr.Number(label="Temperature", value=app_config.temperature),
|
473 |
gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0),
|
|
|
520 |
help="The default model name.") # medium
|
521 |
parser.add_argument("--default_vad", type=str, default=default_app_config.default_vad, \
|
522 |
help="The default VAD.") # silero-vad
|
523 |
+
parser.add_argument("--vad_initial_prompt_mode", type=str, default=default_app_config.vad_initial_prompt_mode, choices=["prepend_all_segments", "prepend_first_segment"], \
|
524 |
+
help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
|
525 |
parser.add_argument("--vad_parallel_devices", type=str, default=default_app_config.vad_parallel_devices, \
|
526 |
help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.") # ""
|
527 |
parser.add_argument("--vad_cpu_cores", type=int, default=default_app_config.vad_cpu_cores, \
|
cli.py
CHANGED
@@ -6,8 +6,8 @@ import warnings
|
|
6 |
import numpy as np
|
7 |
|
8 |
import torch
|
9 |
-
from app import WhisperTranscriber
|
10 |
-
from src.config import ApplicationConfig
|
11 |
from src.download import download_url
|
12 |
from src.languages import get_language_names
|
13 |
|
@@ -47,6 +47,8 @@ def cli():
|
|
47 |
|
48 |
parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
|
49 |
help="The voice activity detection algorithm to use") # silero-vad
|
|
|
|
|
50 |
parser.add_argument("--vad_merge_window", type=optional_float, default=app_config.vad_merge_window, \
|
51 |
help="The window size (in seconds) to merge voice segments")
|
52 |
parser.add_argument("--vad_max_merge_size", type=optional_float, default=app_config.vad_max_merge_size,\
|
@@ -115,6 +117,7 @@ def cli():
|
|
115 |
temperature = [temperature]
|
116 |
|
117 |
vad = args.pop("vad")
|
|
|
118 |
vad_merge_window = args.pop("vad_merge_window")
|
119 |
vad_max_merge_size = args.pop("vad_max_merge_size")
|
120 |
vad_padding = args.pop("vad_padding")
|
@@ -150,9 +153,10 @@ def cli():
|
|
150 |
source_path = source["path"]
|
151 |
source_name = source["name"]
|
152 |
|
153 |
-
|
154 |
-
|
155 |
-
|
|
|
156 |
|
157 |
transcriber.write_result(result, source_name, output_dir)
|
158 |
|
|
|
6 |
import numpy as np
|
7 |
|
8 |
import torch
|
9 |
+
from app import VadOptions, WhisperTranscriber
|
10 |
+
from src.config import ApplicationConfig, VadInitialPromptMode
|
11 |
from src.download import download_url
|
12 |
from src.languages import get_language_names
|
13 |
|
|
|
47 |
|
48 |
parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
|
49 |
help="The voice activity detection algorithm to use") # silero-vad
|
50 |
+
parser.add_argument("--vad_initial_prompt_mode", type=str, default=app_config.vad_initial_prompt_mode, choices=["prepend_all_segments", "prepend_first_segment"], \
|
51 |
+
help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)") # prepend_first_segment
|
52 |
parser.add_argument("--vad_merge_window", type=optional_float, default=app_config.vad_merge_window, \
|
53 |
help="The window size (in seconds) to merge voice segments")
|
54 |
parser.add_argument("--vad_max_merge_size", type=optional_float, default=app_config.vad_max_merge_size,\
|
|
|
117 |
temperature = [temperature]
|
118 |
|
119 |
vad = args.pop("vad")
|
120 |
+
vad_initial_prompt_mode = args.pop("vad_initial_prompt_mode")
|
121 |
vad_merge_window = args.pop("vad_merge_window")
|
122 |
vad_max_merge_size = args.pop("vad_max_merge_size")
|
123 |
vad_padding = args.pop("vad_padding")
|
|
|
153 |
source_path = source["path"]
|
154 |
source_name = source["name"]
|
155 |
|
156 |
+
vadOptions = VadOptions(vad, vad_merge_window, vad_max_merge_size, vad_padding, vad_prompt_window,
|
157 |
+
VadInitialPromptMode.from_string(vad_initial_prompt_mode))
|
158 |
+
|
159 |
+
result = transcriber.transcribe_file(model, source_path, temperature=temperature, vadOptions=vadOptions, **args)
|
160 |
|
161 |
transcriber.write_result(result, source_name, output_dir)
|
162 |
|
config.json5
CHANGED
@@ -97,6 +97,8 @@
|
|
97 |
"vad_max_merge_size": 30,
|
98 |
// The padding (in seconds) to add to each voice segment
|
99 |
"vad_padding": 1,
|
|
|
|
|
100 |
// The window size of the prompt to pass to Whisper
|
101 |
"vad_prompt_window": 3,
|
102 |
// Temperature to use for sampling
|
|
|
97 |
"vad_max_merge_size": 30,
|
98 |
// The padding (in seconds) to add to each voice segment
|
99 |
"vad_padding": 1,
|
100 |
+
// Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)
|
101 |
+
"vad_initial_prompt_mode": "prepend_first_segment",
|
102 |
// The window size of the prompt to pass to Whisper
|
103 |
"vad_prompt_window": 3,
|
104 |
// Temperature to use for sampling
|
src/config.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import urllib
|
2 |
|
3 |
import os
|
@@ -23,6 +24,21 @@ class ModelConfig:
|
|
23 |
self.path = path
|
24 |
self.type = type
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
class ApplicationConfig:
|
27 |
def __init__(self, models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
|
28 |
share: bool = False, server_name: str = None, server_port: int = 7860,
|
@@ -33,6 +49,7 @@ class ApplicationConfig:
|
|
33 |
auto_parallel: bool = False, output_dir: str = None,
|
34 |
model_dir: str = None, device: str = None,
|
35 |
verbose: bool = True, task: str = "transcribe", language: str = None,
|
|
|
36 |
vad_merge_window: float = 5, vad_max_merge_size: float = 30,
|
37 |
vad_padding: float = 1, vad_prompt_window: float = 3,
|
38 |
temperature: float = 0, best_of: int = 5, beam_size: int = 5,
|
@@ -67,6 +84,7 @@ class ApplicationConfig:
|
|
67 |
self.verbose = verbose
|
68 |
self.task = task
|
69 |
self.language = language
|
|
|
70 |
self.vad_merge_window = vad_merge_window
|
71 |
self.vad_max_merge_size = vad_max_merge_size
|
72 |
self.vad_padding = vad_padding
|
|
|
1 |
+
from enum import Enum
|
2 |
import urllib
|
3 |
|
4 |
import os
|
|
|
24 |
self.path = path
|
25 |
self.type = type
|
26 |
|
27 |
+
class VadInitialPromptMode(Enum):
|
28 |
+
PREPEND_ALL_SEGMENTS = 1
|
29 |
+
PREPREND_FIRST_SEGMENT = 2
|
30 |
+
|
31 |
+
@staticmethod
|
32 |
+
def from_string(s: str):
|
33 |
+
normalized = s.lower() if s is not None else None
|
34 |
+
|
35 |
+
if normalized == "prepend_all_segments":
|
36 |
+
return VadInitialPromptMode.PREPEND_ALL_SEGMENTS
|
37 |
+
elif normalized == "prepend_first_segment":
|
38 |
+
return VadInitialPromptMode.PREPREND_FIRST_SEGMENT
|
39 |
+
else:
|
40 |
+
raise ValueError(f"Invalid value for VadInitialPromptMode: {s}")
|
41 |
+
|
42 |
class ApplicationConfig:
|
43 |
def __init__(self, models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
|
44 |
share: bool = False, server_name: str = None, server_port: int = 7860,
|
|
|
49 |
auto_parallel: bool = False, output_dir: str = None,
|
50 |
model_dir: str = None, device: str = None,
|
51 |
verbose: bool = True, task: str = "transcribe", language: str = None,
|
52 |
+
vad_initial_prompt_mode: str = "prepend_first_segment ",
|
53 |
vad_merge_window: float = 5, vad_max_merge_size: float = 30,
|
54 |
vad_padding: float = 1, vad_prompt_window: float = 3,
|
55 |
temperature: float = 0, best_of: int = 5, beam_size: int = 5,
|
|
|
84 |
self.verbose = verbose
|
85 |
self.task = task
|
86 |
self.language = language
|
87 |
+
self.vad_initial_prompt_mode = vad_initial_prompt_mode
|
88 |
self.vad_merge_window = vad_merge_window
|
89 |
self.vad_max_merge_size = vad_max_merge_size
|
90 |
self.vad_padding = vad_padding
|
src/whisper/abstractWhisperContainer.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import abc
|
2 |
from typing import List
|
3 |
-
from src.config import ModelConfig
|
4 |
|
5 |
from src.hooks.progressListener import ProgressListener
|
6 |
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
@@ -24,6 +24,15 @@ class AbstractWhisperCallback:
|
|
24 |
"""
|
25 |
raise NotImplementedError()
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
def _concat_prompt(self, prompt1, prompt2):
|
28 |
if (prompt1 is None):
|
29 |
return prompt2
|
@@ -66,7 +75,9 @@ class AbstractWhisperContainer:
|
|
66 |
pass
|
67 |
|
68 |
@abc.abstractmethod
|
69 |
-
def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None,
|
|
|
|
|
70 |
"""
|
71 |
Create a WhisperCallback object that can be used to transcript audio files.
|
72 |
|
@@ -78,6 +89,9 @@ class AbstractWhisperContainer:
|
|
78 |
The task - either translate or transcribe.
|
79 |
initial_prompt: str
|
80 |
The initial prompt to use for the transcription.
|
|
|
|
|
|
|
81 |
decodeOptions: dict
|
82 |
Additional options to pass to the decoder. Must be pickleable.
|
83 |
|
|
|
1 |
import abc
|
2 |
from typing import List
|
3 |
+
from src.config import ModelConfig, VadInitialPromptMode
|
4 |
|
5 |
from src.hooks.progressListener import ProgressListener
|
6 |
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
|
|
24 |
"""
|
25 |
raise NotImplementedError()
|
26 |
|
27 |
+
def _get_initial_prompt(self, initial_prompt: str, initial_prompt_mode: VadInitialPromptMode,
|
28 |
+
prompt: str, segment_index: int):
|
29 |
+
if (initial_prompt_mode == VadInitialPromptMode.PREPEND_ALL_SEGMENTS):
|
30 |
+
return self._concat_prompt(initial_prompt, prompt)
|
31 |
+
elif (initial_prompt_mode == VadInitialPromptMode.PREPREND_FIRST_SEGMENT):
|
32 |
+
return self._concat_prompt(initial_prompt, prompt) if segment_index == 0 else prompt
|
33 |
+
else:
|
34 |
+
raise ValueError(f"Unknown initial prompt mode {initial_prompt_mode}")
|
35 |
+
|
36 |
def _concat_prompt(self, prompt1, prompt2):
|
37 |
if (prompt1 is None):
|
38 |
return prompt2
|
|
|
75 |
pass
|
76 |
|
77 |
@abc.abstractmethod
|
78 |
+
def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None,
|
79 |
+
initial_prompt_mode: VadInitialPromptMode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
|
80 |
+
**decodeOptions: dict) -> AbstractWhisperCallback:
|
81 |
"""
|
82 |
Create a WhisperCallback object that can be used to transcript audio files.
|
83 |
|
|
|
89 |
The task - either translate or transcribe.
|
90 |
initial_prompt: str
|
91 |
The initial prompt to use for the transcription.
|
92 |
+
initial_prompt_mode: VadInitialPromptMode
|
93 |
+
The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
|
94 |
+
If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
|
95 |
decodeOptions: dict
|
96 |
Additional options to pass to the decoder. Must be pickleable.
|
97 |
|
src/whisper/fasterWhisperContainer.py
CHANGED
@@ -2,7 +2,7 @@ import os
|
|
2 |
from typing import List, Union
|
3 |
|
4 |
from faster_whisper import WhisperModel, download_model
|
5 |
-
from src.config import ModelConfig
|
6 |
from src.hooks.progressListener import ProgressListener
|
7 |
from src.languages import get_language_from_name
|
8 |
from src.modelCache import ModelCache
|
@@ -51,7 +51,9 @@ class FasterWhisperContainer(AbstractWhisperContainer):
|
|
51 |
model = WhisperModel(model_config.url, device=device, compute_type=self.compute_type)
|
52 |
return model
|
53 |
|
54 |
-
def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None,
|
|
|
|
|
55 |
"""
|
56 |
Create a WhisperCallback object that can be used to transcript audio files.
|
57 |
|
@@ -63,6 +65,9 @@ class FasterWhisperContainer(AbstractWhisperContainer):
|
|
63 |
The task - either translate or transcribe.
|
64 |
initial_prompt: str
|
65 |
The initial prompt to use for the transcription.
|
|
|
|
|
|
|
66 |
decodeOptions: dict
|
67 |
Additional options to pass to the decoder. Must be pickleable.
|
68 |
|
@@ -70,14 +75,17 @@ class FasterWhisperContainer(AbstractWhisperContainer):
|
|
70 |
-------
|
71 |
A WhisperCallback object.
|
72 |
"""
|
73 |
-
return FasterWhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)
|
74 |
|
75 |
class FasterWhisperCallback(AbstractWhisperCallback):
|
76 |
-
def __init__(self, model_container: FasterWhisperContainer, language: str = None, task: str = None,
|
|
|
|
|
77 |
self.model_container = model_container
|
78 |
self.language = language
|
79 |
self.task = task
|
80 |
self.initial_prompt = initial_prompt
|
|
|
81 |
self.decodeOptions = decodeOptions
|
82 |
|
83 |
self._printed_warning = False
|
@@ -125,9 +133,11 @@ class FasterWhisperCallback(AbstractWhisperCallback):
|
|
125 |
# See if supress_tokens is a string - if so, convert it to a list of ints
|
126 |
decodeOptions["suppress_tokens"] = self._split_suppress_tokens(suppress_tokens)
|
127 |
|
|
|
|
|
128 |
segments_generator, info = model.transcribe(audio, \
|
129 |
language=language_code if language_code else detected_language, task=self.task, \
|
130 |
-
initial_prompt=
|
131 |
**decodeOptions
|
132 |
)
|
133 |
|
|
|
2 |
from typing import List, Union
|
3 |
|
4 |
from faster_whisper import WhisperModel, download_model
|
5 |
+
from src.config import ModelConfig, VadInitialPromptMode
|
6 |
from src.hooks.progressListener import ProgressListener
|
7 |
from src.languages import get_language_from_name
|
8 |
from src.modelCache import ModelCache
|
|
|
51 |
model = WhisperModel(model_config.url, device=device, compute_type=self.compute_type)
|
52 |
return model
|
53 |
|
54 |
+
def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None,
|
55 |
+
initial_prompt_mode: VadInitialPromptMode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
|
56 |
+
**decodeOptions: dict) -> AbstractWhisperCallback:
|
57 |
"""
|
58 |
Create a WhisperCallback object that can be used to transcript audio files.
|
59 |
|
|
|
65 |
The task - either translate or transcribe.
|
66 |
initial_prompt: str
|
67 |
The initial prompt to use for the transcription.
|
68 |
+
initial_prompt_mode: VadInitialPromptMode
|
69 |
+
The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
|
70 |
+
If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
|
71 |
decodeOptions: dict
|
72 |
Additional options to pass to the decoder. Must be pickleable.
|
73 |
|
|
|
75 |
-------
|
76 |
A WhisperCallback object.
|
77 |
"""
|
78 |
+
return FasterWhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, initial_prompt_mode=initial_prompt_mode, **decodeOptions)
|
79 |
|
80 |
class FasterWhisperCallback(AbstractWhisperCallback):
|
81 |
+
def __init__(self, model_container: FasterWhisperContainer, language: str = None, task: str = None,
|
82 |
+
initial_prompt: str = None, initial_prompt_mode: VadInitialPromptMode=VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
|
83 |
+
**decodeOptions: dict):
|
84 |
self.model_container = model_container
|
85 |
self.language = language
|
86 |
self.task = task
|
87 |
self.initial_prompt = initial_prompt
|
88 |
+
self.initial_prompt_mode = initial_prompt_mode
|
89 |
self.decodeOptions = decodeOptions
|
90 |
|
91 |
self._printed_warning = False
|
|
|
133 |
# See if supress_tokens is a string - if so, convert it to a list of ints
|
134 |
decodeOptions["suppress_tokens"] = self._split_suppress_tokens(suppress_tokens)
|
135 |
|
136 |
+
initial_prompt = self._get_initial_prompt(self.initial_prompt, self.initial_prompt_mode, prompt, segment_index)
|
137 |
+
|
138 |
segments_generator, info = model.transcribe(audio, \
|
139 |
language=language_code if language_code else detected_language, task=self.task, \
|
140 |
+
initial_prompt=initial_prompt, \
|
141 |
**decodeOptions
|
142 |
)
|
143 |
|
src/whisper/whisperContainer.py
CHANGED
@@ -11,7 +11,7 @@ from src.hooks.progressListener import ProgressListener
|
|
11 |
import whisper
|
12 |
from whisper import Whisper
|
13 |
|
14 |
-
from src.config import ModelConfig
|
15 |
from src.hooks.whisperProgressHook import create_progress_listener_handle
|
16 |
|
17 |
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
@@ -69,7 +69,9 @@ class WhisperContainer(AbstractWhisperContainer):
|
|
69 |
|
70 |
return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
|
71 |
|
72 |
-
def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None,
|
|
|
|
|
73 |
"""
|
74 |
Create a WhisperCallback object that can be used to transcript audio files.
|
75 |
|
@@ -81,6 +83,9 @@ class WhisperContainer(AbstractWhisperContainer):
|
|
81 |
The task - either translate or transcribe.
|
82 |
initial_prompt: str
|
83 |
The initial prompt to use for the transcription.
|
|
|
|
|
|
|
84 |
decodeOptions: dict
|
85 |
Additional options to pass to the decoder. Must be pickleable.
|
86 |
|
@@ -88,7 +93,7 @@ class WhisperContainer(AbstractWhisperContainer):
|
|
88 |
-------
|
89 |
A WhisperCallback object.
|
90 |
"""
|
91 |
-
return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)
|
92 |
|
93 |
def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
|
94 |
from src.conversion.hf_converter import convert_hf_whisper
|
@@ -157,11 +162,13 @@ class WhisperContainer(AbstractWhisperContainer):
|
|
157 |
return model_config.path
|
158 |
|
159 |
class WhisperCallback(AbstractWhisperCallback):
|
160 |
-
def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None,
|
|
|
161 |
self.model_container = model_container
|
162 |
self.language = language
|
163 |
self.task = task
|
164 |
self.initial_prompt = initial_prompt
|
|
|
165 |
self.decodeOptions = decodeOptions
|
166 |
|
167 |
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
|
@@ -194,8 +201,10 @@ class WhisperCallback(AbstractWhisperCallback):
|
|
194 |
if self.model_container.compute_type in ["fp16", "float16"]:
|
195 |
decodeOptions["fp16"] = True
|
196 |
|
|
|
|
|
197 |
return model.transcribe(audio, \
|
198 |
language=self.language if self.language else detected_language, task=self.task, \
|
199 |
-
initial_prompt=
|
200 |
**decodeOptions
|
201 |
)
|
|
|
11 |
import whisper
|
12 |
from whisper import Whisper
|
13 |
|
14 |
+
from src.config import ModelConfig, VadInitialPromptMode
|
15 |
from src.hooks.whisperProgressHook import create_progress_listener_handle
|
16 |
|
17 |
from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
|
|
|
69 |
|
70 |
return whisper.load_model(model_path, device=self.device, download_root=self.download_root)
|
71 |
|
72 |
+
def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None,
|
73 |
+
initial_prompt_mode: VadInitialPromptMode = VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
|
74 |
+
**decodeOptions: dict) -> AbstractWhisperCallback:
|
75 |
"""
|
76 |
Create a WhisperCallback object that can be used to transcript audio files.
|
77 |
|
|
|
83 |
The task - either translate or transcribe.
|
84 |
initial_prompt: str
|
85 |
The initial prompt to use for the transcription.
|
86 |
+
initial_prompt_mode: VadInitialPromptMode
|
87 |
+
The mode to use for the initial prompt. If set to PREPEND_FIRST_SEGMENT, the initial prompt will be prepended to the first segment of audio.
|
88 |
+
If set to PREPEND_ALL_SEGMENTS, the initial prompt will be prepended to all segments of audio.
|
89 |
decodeOptions: dict
|
90 |
Additional options to pass to the decoder. Must be pickleable.
|
91 |
|
|
|
93 |
-------
|
94 |
A WhisperCallback object.
|
95 |
"""
|
96 |
+
return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, initial_prompt_mode=initial_prompt_mode, **decodeOptions)
|
97 |
|
98 |
def _get_model_path(self, model_config: ModelConfig, root_dir: str = None):
|
99 |
from src.conversion.hf_converter import convert_hf_whisper
|
|
|
162 |
return model_config.path
|
163 |
|
164 |
class WhisperCallback(AbstractWhisperCallback):
|
165 |
+
def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None,
|
166 |
+
initial_prompt_mode: VadInitialPromptMode=VadInitialPromptMode.PREPREND_FIRST_SEGMENT, **decodeOptions: dict):
|
167 |
self.model_container = model_container
|
168 |
self.language = language
|
169 |
self.task = task
|
170 |
self.initial_prompt = initial_prompt
|
171 |
+
self.initial_prompt_mode = initial_prompt_mode
|
172 |
self.decodeOptions = decodeOptions
|
173 |
|
174 |
def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
|
|
|
201 |
if self.model_container.compute_type in ["fp16", "float16"]:
|
202 |
decodeOptions["fp16"] = True
|
203 |
|
204 |
+
initial_prompt = self._get_initial_prompt(self.initial_prompt, self.initial_prompt_mode, prompt, segment_index)
|
205 |
+
|
206 |
return model.transcribe(audio, \
|
207 |
language=self.language if self.language else detected_language, task=self.task, \
|
208 |
+
initial_prompt=initial_prompt, \
|
209 |
**decodeOptions
|
210 |
)
|