aadnk committed on
Commit a8eb534
2 Parent(s): 0819c3a 43189ac

Merge branch 'main' of https://huggingface.co/spaces/aadnk/whisper-webui into main

Files changed (7)
  1. app.py +72 -31
  2. cli.py +17 -2
  3. config.json5 +10 -1
  4. src/config.py +11 -1
  5. src/utils.py +118 -8
  6. src/vad.py +8 -0
  7. src/whisper/whisperContainer.py +3 -2
app.py CHANGED
@@ -1,4 +1,5 @@
 from datetime import datetime
+import json
 import math
 from typing import Iterator, Union
 import argparse
@@ -28,7 +29,7 @@ import ffmpeg
 import gradio as gr
 
 from src.download import ExceededMaximumDuration, download_url
-from src.utils import slugify, write_srt, write_vtt
+from src.utils import optional_int, slugify, write_srt, write_vtt
 from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
 from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
 from src.whisper.whisperFactory import create_whisper_container
@@ -84,37 +85,49 @@ class WhisperTranscriber:
             print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
 
     # Entry function for the simple tab
-    def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
-        return self.transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
+    def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
+                                vad, vadMergeWindow, vadMaxMergeSize,
+                                word_timestamps: bool = False, highlight_words: bool = False):
+        return self.transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
+                                                     vad, vadMergeWindow, vadMaxMergeSize,
+                                                     word_timestamps, highlight_words)
 
     # Entry function for the simple tab progress
-    def transcribe_webui_simple_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
-                                         progress=gr.Progress()):
+    def transcribe_webui_simple_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
+                                         vad, vadMergeWindow, vadMaxMergeSize,
+                                         word_timestamps: bool = False, highlight_words: bool = False,
+                                         progress=gr.Progress()):
 
-        vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, self.app_config.vad_initial_prompt_mode)
+        vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, self.app_config.vad_padding, self.app_config.vad_prompt_window, self.app_config.vad_initial_prompt_mode)
 
-        return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions, progress=progress)
+        return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
+                                     word_timestamps=word_timestamps, highlight_words=highlight_words, progress=progress)
 
     # Entry function for the full tab
     def transcribe_webui_full(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
                               vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
+                              # Word timestamps
+                              word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
                               initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
                               condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
                               compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float):
 
         return self.transcribe_webui_full_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
                                                    vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
+                                                   word_timestamps, highlight_words, prepend_punctuations, append_punctuations,
                                                    initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
                                                    condition_on_previous_text, fp16, temperature_increment_on_fallback,
                                                    compression_ratio_threshold, logprob_threshold, no_speech_threshold)
 
     # Entry function for the full tab with progress
     def transcribe_webui_full_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
                                        vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
+                                       # Word timestamps
+                                       word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
                                        initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
                                        condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
                                        compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
                                        progress=gr.Progress()):
 
         # Handle temperature_increment_on_fallback
         if temperature_increment_on_fallback is not None:
@@ -128,13 +141,15 @@ class WhisperTranscriber:
             initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
             condition_on_previous_text=condition_on_previous_text, fp16=fp16,
             compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
+            word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
             progress=progress)
 
     def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
-                         vadOptions: VadOptions, progress: gr.Progress = None, **decodeOptions: dict):
+                         vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False,
+                         **decodeOptions: dict):
         try:
             sources = self.__get_source(urlData, multipleFiles, microphoneData)
-
+
             try:
                 selectedLanguage = languageName.lower() if len(languageName) > 0 else None
                 selectedModel = modelName if modelName is not None else "base"
@@ -185,7 +200,7 @@ class WhisperTranscriber:
                 # Update progress
                 current_progress += source_audio_duration
 
-                source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory)
+                source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory, highlight_words)
 
                 if len(sources) > 1:
                     # Add new line separators
@@ -359,7 +374,7 @@ class WhisperTranscriber:
 
         return config
 
-    def write_result(self, result: dict, source_name: str, output_dir: str):
+    def write_result(self, result: dict, source_name: str, output_dir: str, highlight_words: bool = False):
         if not os.path.exists(output_dir):
             os.makedirs(output_dir)
 
@@ -368,13 +383,15 @@ class WhisperTranscriber:
         languageMaxLineWidth = self.__get_max_line_width(language)
 
         print("Max line width " + str(languageMaxLineWidth))
-        vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
-        srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)
+        vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth, highlight_words=highlight_words)
+        srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth, highlight_words=highlight_words)
+        json_result = json.dumps(result, indent=4, ensure_ascii=False)
 
         output_files = []
         output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
         output_files.append(self.__create_file(vtt, output_dir, source_name + "-subs.vtt"));
         output_files.append(self.__create_file(text, output_dir, source_name + "-transcript.txt"));
+        output_files.append(self.__create_file(json_result, output_dir, source_name + "-result.json"));
 
         return output_files, text, vtt
 
@@ -394,13 +411,13 @@ class WhisperTranscriber:
         # 80 latin characters should fit on a 1080p/720p screen
         return 80
 
-    def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
+    def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int, highlight_words: bool = False) -> str:
        segmentStream = StringIO()
 
        if format == 'vtt':
-            write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
+            write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
        elif format == 'srt':
-            write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
+            write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
        else:
            raise Exception("Unknown format " + format)
 
@@ -460,24 +477,34 @@ def create_ui(app_config: ApplicationConfig):
 
    whisper_models = app_config.get_model_names()
 
-    simple_inputs = lambda : [
+    common_inputs = lambda : [
        gr.Dropdown(choices=whisper_models, value=app_config.default_model_name, label="Model"),
        gr.Dropdown(choices=sorted(get_language_names()), label="Language", value=app_config.language),
        gr.Text(label="URL (YouTube, etc.)"),
        gr.File(label="Upload Files", file_count="multiple"),
        gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
        gr.Dropdown(choices=["transcribe", "translate"], label="Task", value=app_config.task),
+    ]
+
+    common_vad_inputs = lambda : [
        gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=app_config.default_vad, label="VAD"),
        gr.Number(label="VAD - Merge Window (s)", precision=0, value=app_config.vad_merge_window),
        gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=app_config.vad_max_merge_size),
-        gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding),
-        gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window),
+    ]
+
+    common_word_timestamps_inputs = lambda : [
+        gr.Checkbox(label="Word Timestamps", value=app_config.word_timestamps),
+        gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
    ]
 
    is_queue_mode = app_config.queue_concurrency_count is not None and app_config.queue_concurrency_count > 0
 
    simple_transcribe = gr.Interface(fn=ui.transcribe_webui_simple_progress if is_queue_mode else ui.transcribe_webui_simple,
-        description=ui_description, article=ui_article, inputs=simple_inputs(), outputs=[
+        description=ui_description, article=ui_article, inputs=[
+            *common_inputs(),
+            *common_vad_inputs(),
+            *common_word_timestamps_inputs(),
+        ], outputs=[
            gr.File(label="Download"),
            gr.Text(label="Transcription"),
            gr.Text(label="Segments")
@@ -487,8 +514,17 @@ def create_ui(app_config: ApplicationConfig):
 
    full_transcribe = gr.Interface(fn=ui.transcribe_webui_full_progress if is_queue_mode else ui.transcribe_webui_full,
        description=full_description, article=ui_article, inputs=[
-            *simple_inputs(),
+            *common_inputs(),
+
+            *common_vad_inputs(),
+            gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding),
+            gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window),
            gr.Dropdown(choices=["prepend_first_segment", "prepend_all_segments"], value=app_config.vad_initial_prompt_mode, label="VAD - Initial Prompt Mode"),
+
+            *common_word_timestamps_inputs(),
+            gr.Text(label="Word Timestamps - Prepend Punctuations", value=app_config.prepend_punctuations),
+            gr.Text(label="Word Timestamps - Append Punctuations", value=app_config.append_punctuations),
+
            gr.TextArea(label="Initial Prompt"),
            gr.Number(label="Temperature", value=app_config.temperature),
            gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0),
@@ -501,7 +537,7 @@ def create_ui(app_config: ApplicationConfig):
            gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
            gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
            gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
-            gr.Number(label="No speech threshold", value=app_config.no_speech_threshold)
+            gr.Number(label="No speech threshold", value=app_config.no_speech_threshold),
        ], outputs=[
            gr.File(label="Download"),
            gr.Text(label="Transcription"),
@@ -560,9 +596,14 @@ if __name__ == '__main__':
                        help="the Whisper implementation to use")
    parser.add_argument("--compute_type", type=str, default=default_app_config.compute_type, choices=["default", "auto", "int8", "int8_float16", "int16", "float16", "float32"], \
                        help="the compute type to use for inference")
+    parser.add_argument("--threads", type=optional_int, default=0,
+                        help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
 
    args = parser.parse_args().__dict__
 
    updated_config = default_app_config.update(**args)
 
+    if (threads := args.pop("threads")) > 0:
+        torch.set_num_threads(threads)
+
    create_ui(app_config=updated_config)
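
With `write_result` now also serializing the raw result dict, every transcription gains a `<source>-result.json` download alongside the .srt/.vtt/.txt files. A minimal sketch of reading the word-level timestamps back out of that file (the file name here is illustrative; the `words` field layout follows Whisper's word_timestamps output):

```python
import json

# "audio-result.json" is a placeholder; write_result names the file
# "<source_name>-result.json" inside the chosen output directory.
with open("audio-result.json", "r", encoding="utf-8") as f:
    result = json.load(f)

for segment in result["segments"]:
    # 'words' is only present when word timestamps were enabled
    for word in segment.get("words", []):
        print(f"{word['start']:7.2f} {word['end']:7.2f}  {word['word'].strip()}")
```
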
cli.py CHANGED
@@ -95,6 +95,17 @@ def cli():
     parser.add_argument("--no_speech_threshold", type=optional_float, default=app_config.no_speech_threshold, \
                         help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
 
+    parser.add_argument("--word_timestamps", type=str2bool, default=app_config.word_timestamps,
+                        help="(experimental) extract word-level timestamps and refine the results based on them")
+    parser.add_argument("--prepend_punctuations", type=str, default=app_config.prepend_punctuations,
+                        help="if word_timestamps is True, merge these punctuation symbols with the next word")
+    parser.add_argument("--append_punctuations", type=str, default=app_config.append_punctuations,
+                        help="if word_timestamps is True, merge these punctuation symbols with the previous word")
+    parser.add_argument("--highlight_words", type=str2bool, default=app_config.highlight_words,
+                        help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
+    parser.add_argument("--threads", type=optional_int, default=0,
+                        help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
+
     args = parser.parse_args().__dict__
     model_name: str = args.pop("model")
     model_dir: str = args.pop("model_dir")
@@ -102,6 +113,9 @@ def cli():
     device: str = args.pop("device")
     os.makedirs(output_dir, exist_ok=True)
 
+    if (threads := args.pop("threads")) > 0:
+        torch.set_num_threads(threads)
+
     whisper_implementation = args.pop("whisper_implementation")
     print(f"Using {whisper_implementation} for Whisper")
 
@@ -126,6 +140,7 @@ def cli():
     auto_parallel = args.pop("auto_parallel")
 
     compute_type = args.pop("compute_type")
+    highlight_words = args.pop("highlight_words")
 
     transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
     transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
@@ -133,7 +148,7 @@ def cli():
 
     model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
                                      device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
-
+
     if (transcriber._has_parallel_devices()):
         print("Using parallel devices:", transcriber.parallel_device_list)
 
@@ -158,7 +173,7 @@ def cli():
 
         result = transcriber.transcribe_file(model, source_path, temperature=temperature, vadOptions=vadOptions, **args)
 
-        transcriber.write_result(result, source_name, output_dir)
+        transcriber.write_result(result, source_name, output_dir, highlight_words)
 
     transcriber.close()
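
The new flags lean on the `str2bool` and `optional_int` argparse converters from `src.utils` (note that app.py's import list gains `optional_int` above). For reference, a sketch of what these helpers are assumed to look like, mirroring the equivalents in OpenAI's Whisper CLI:

```python
def optional_int(string: str):
    # Accept an integer, or the literal "None" to mean unset
    return None if string == "None" else int(string)

def str2bool(string: str):
    # argparse's built-in bool() treats any non-empty string (even "False") as True,
    # so map the two accepted spellings explicitly
    str2val = {"True": True, "False": False}
    if string in str2val:
        return str2val[string]
    raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
```

With these in place, a run such as `python cli.py --word_timestamps True --highlight_words True audio.mp3` (positional audio argument assumed from the existing parser) would produce underlined words in the generated .srt and .vtt files.
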
config.json5 CHANGED
@@ -128,5 +128,14 @@
     // If the average log probability is lower than this value, treat the decoding as failed
     "logprob_threshold": -1.0,
     // If the probability of the <no-speech> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence
-    "no_speech_threshold": 0.6
+    "no_speech_threshold": 0.6,
+
+    // (experimental) extract word-level timestamps and refine the results based on them
+    "word_timestamps": false,
+    // if word_timestamps is True, merge these punctuation symbols with the next word
+    "prepend_punctuations": "\"\'“¿([{-",
+    // if word_timestamps is True, merge these punctuation symbols with the previous word
+    "append_punctuations": "\"\'.。,,!!??::”)]}、",
+    // (requires --word_timestamps True) underline each word as it is spoken in srt and vtt
+    "highlight_words": false,
 }
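
Since config.json5 contains `//` comments, it cannot be parsed with the standard json module. A quick sanity check that the new keys load, assuming the json5 package is what the app uses (the .json5 extension suggests as much):

```python
import json5  # pip install json5; plain json rejects the // comments

with open("config.json5", "r", encoding="utf-8") as f:
    config = json5.load(f)

# The new word-timestamp keys and their defaults
print(config["word_timestamps"], config["highlight_words"])  # False False
assert config["prepend_punctuations"] and config["append_punctuations"]
```
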
src/config.py CHANGED
@@ -58,7 +58,11 @@ class ApplicationConfig:
                  condition_on_previous_text: bool = True, fp16: bool = True,
                  compute_type: str = "float16",
                  temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
-                 logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6):
+                 logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
+                 # Word timestamp settings
+                 word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
+                 append_punctuations: str = "\"\'.。,,!!??::”)]}、",
+                 highlight_words: bool = False):
 
         self.models = models
 
@@ -104,6 +108,12 @@ class ApplicationConfig:
         self.logprob_threshold = logprob_threshold
         self.no_speech_threshold = no_speech_threshold
 
+        # Word timestamp settings
+        self.word_timestamps = word_timestamps
+        self.prepend_punctuations = prepend_punctuations
+        self.append_punctuations = append_punctuations
+        self.highlight_words = highlight_words
+
     def get_model_names(self):
         return [ x.name for x in self.models ]
src/utils.py CHANGED
@@ -3,7 +3,7 @@ import unicodedata
 import re
 
 import zlib
-from typing import Iterator, TextIO
+from typing import Iterator, TextIO, Union
 import tqdm
 
 import urllib3
@@ -56,10 +56,14 @@ def write_txt(transcript: Iterator[dict], file: TextIO):
         print(segment['text'].strip(), file=file, flush=True)
 
 
-def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
+def write_vtt(transcript: Iterator[dict], file: TextIO,
+              maxLineWidth=None, highlight_words: bool = False):
+    iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)
+
     print("WEBVTT\n", file=file)
-    for segment in transcript:
-        text = process_text(segment['text'], maxLineWidth).replace('-->', '->')
+
+    for segment in iterator:
+        text = segment['text'].replace('-->', '->')
 
         print(
             f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
@@ -68,8 +72,8 @@ def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
             flush=True,
         )
 
-
-def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
+def write_srt(transcript: Iterator[dict], file: TextIO,
+              maxLineWidth=None, highlight_words: bool = False):
     """
     Write a transcript to a file in SRT format.
     Example usage:
@@ -81,8 +85,10 @@ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
         with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
             write_srt(result["segments"], file=srt)
     """
-    for i, segment in enumerate(transcript, start=1):
-        text = process_text(segment['text'].strip(), maxLineWidth).replace('-->', '->')
+    iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)
+
+    for i, segment in enumerate(iterator, start=1):
+        text = segment['text'].replace('-->', '->')
 
         # write srt lines
         print(
@@ -94,6 +100,110 @@ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
             flush=True,
         )
 
+def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: int = None, highlight_words: bool = False):
+    for segment in transcript:
+        words = segment.get('words', [])
+
+        if len(words) == 0:
+            # Yield the segment as-is or processed
+            if maxLineWidth is None or maxLineWidth < 0:
+                yield segment
+            else:
+                yield {
+                    'start': segment['start'],
+                    'end': segment['end'],
+                    'text': process_text(segment['text'].strip(), maxLineWidth)
+                }
+            # We are done
+            continue
+
+        subtitle_start = segment['start']
+        subtitle_end = segment['end']
+
+        text_words = [ this_word["word"] for this_word in words ]
+        subtitle_text = __join_words(text_words, maxLineWidth)
+
+        # Iterate over the words in the segment
+        if highlight_words:
+            last = subtitle_start
+
+            for i, this_word in enumerate(words):
+                start = this_word['start']
+                end = this_word['end']
+
+                if last != start:
+                    # Display the text up to this point
+                    yield {
+                        'start': last,
+                        'end': start,
+                        'text': subtitle_text
+                    }
+
+                # Display the text with the current word highlighted
+                yield {
+                    'start': start,
+                    'end': end,
+                    'text': __join_words(
+                        [
+                            {
+                                "word": re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
+                                        if j == i
+                                        else word,
+                                # The HTML tags <u> and </u> are not displayed,
+                                # so they should not be counted in the word length
+                                "length": len(word)
+                            } for j, word in enumerate(text_words)
+                        ], maxLineWidth)
+                }
+                last = end
+
+            if last != subtitle_end:
+                # Display the last part of the text
+                yield {
+                    'start': last,
+                    'end': subtitle_end,
+                    'text': subtitle_text
+                }
+
+        # Just return the subtitle text
+        else:
+            yield {
+                'start': subtitle_start,
+                'end': subtitle_end,
+                'text': subtitle_text
+            }
+
+def __join_words(words: Iterator[Union[str, dict]], maxLineWidth: int = None):
+    if maxLineWidth is None or maxLineWidth < 0:
+        return " ".join(words)
+
+    lines = []
+    current_line = ""
+    current_length = 0
+
+    for entry in words:
+        # Either accept a string or a dict with a 'word' and 'length' field
+        if isinstance(entry, dict):
+            word = entry['word']
+            word_length = entry['length']
+        else:
+            word = entry
+            word_length = len(word)
+
+        if current_length > 0 and current_length + word_length > maxLineWidth:
+            lines.append(current_line)
+            current_line = ""
+            current_length = 0
+
+        current_length += word_length
+        # The word will be prefixed with a space by Whisper, so we don't need to add one here
+        current_line += word
+
+    if len(current_line) > 0:
+        lines.append(current_line)
+
+    return "\n".join(lines)
+
 def process_text(text: str, maxLineWidth=None):
     if (maxLineWidth is None or maxLineWidth < 0):
         return text
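
To see what the new highlight path emits, here is a minimal sketch that feeds `write_srt` a hand-made two-word segment shaped like Whisper's word_timestamps output (the text and timings are invented; `maxLineWidth` is set because the no-wrap fallback in `__join_words` only handles plain strings):

```python
from io import StringIO
from src.utils import write_srt

# One segment with per-word timings, as produced when word_timestamps is enabled
segments = [{
    'start': 0.0, 'end': 1.2, 'text': ' Hello world',
    'words': [
        {'word': ' Hello', 'start': 0.0, 'end': 0.5},
        {'word': ' world', 'start': 0.5, 'end': 1.2},
    ],
}]

buffer = StringIO()
write_srt(segments, file=buffer, maxLineWidth=80, highlight_words=True)
print(buffer.getvalue())
# Expect one cue per word, with the spoken word underlined:
#   0.0-0.5  <u>Hello</u> world
#   0.5-1.2  Hello <u>world</u>
```
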
src/vad.py CHANGED
@@ -404,6 +404,14 @@ class AbstractTranscription(ABC):
             # Add to start and end
             new_segment['start'] = segment_start + adjust_seconds
             new_segment['end'] = segment_end + adjust_seconds
+
+            # Handle words
+            if ('words' in new_segment):
+                for word in new_segment['words']:
+                    # Adjust start and end
+                    word['start'] = word['start'] + adjust_seconds
+                    word['end'] = word['end'] + adjust_seconds
+
             result.append(new_segment)
         return result
src/whisper/whisperContainer.py CHANGED
@@ -203,8 +203,9 @@ class WhisperCallback(AbstractWhisperCallback):
 
         initial_prompt = self._get_initial_prompt(self.initial_prompt, self.initial_prompt_mode, prompt, segment_index)
 
-        return model.transcribe(audio, \
+        result = model.transcribe(audio, \
             language=self.language if self.language else detected_language, task=self.task, \
             initial_prompt=initial_prompt, \
             **decodeOptions
-        )
+        )
+        return result