aadnk commited on
Commit
fdb8dbd
·
1 Parent(s): 7d67dd2

Add a "full" interface with every option supported by Whisper

Browse files

Note that some of these options may crash the model if they are set incorrectly.
Use with care.

Files changed (1) hide show
  1. app.py +52 -3
app.py CHANGED
@@ -8,6 +8,7 @@ import os
8
  import pathlib
9
  import tempfile
10
  import zipfile
 
11
 
12
  import torch
13
  from src.modelCache import ModelCache
@@ -85,7 +86,28 @@ class WhisperTranscriber:
85
  self.vad_cpu_cores = min(os.cpu_count(), MAX_AUTO_CPU_CORES)
86
  print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
87
 
88
- def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  try:
90
  sources = self.__get_source(urlData, multipleFiles, microphoneData)
91
 
@@ -118,7 +140,7 @@ class WhisperTranscriber:
118
  print("Transcribing ", source.source_path)
119
 
120
  # Transcribe
121
- result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
122
  filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
123
 
124
  source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory)
@@ -356,7 +378,7 @@ def create_ui(input_audio_max_duration, share=False, server_name: str = None, se
356
 
357
  ui_article = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)"
358
 
359
- demo = gr.Interface(fn=ui.transcribe_webui, description=ui_description, article=ui_article, inputs=[
360
  gr.Dropdown(choices=WHISPER_MODELS, value=default_model_name, label="Model"),
361
  gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
362
  gr.Text(label="URL (YouTube, etc.)"),
@@ -368,12 +390,39 @@ def create_ui(input_audio_max_duration, share=False, server_name: str = None, se
368
  gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=30),
369
  gr.Number(label="VAD - Padding (s)", precision=None, value=1),
370
  gr.Number(label="VAD - Prompt Window (s)", precision=None, value=3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
  ], outputs=[
372
  gr.File(label="Download"),
373
  gr.Text(label="Transcription"),
374
  gr.Text(label="Segments")
375
  ])
376
 
 
 
377
  demo.launch(share=share, server_name=server_name, server_port=server_port)
378
 
379
  # Clean up
 
8
  import pathlib
9
  import tempfile
10
  import zipfile
11
+ import numpy as np
12
 
13
  import torch
14
  from src.modelCache import ModelCache
 
86
  self.vad_cpu_cores = min(os.cpu_count(), MAX_AUTO_CPU_CORES)
87
  print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
88
 
89
+ # Entry function for the simple tab
90
+ def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
91
+ return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
92
+
93
+ # Entry function for the full tab
94
+ def transcribe_webui_full(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
95
+ initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
96
+ condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
97
+ compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float):
98
+
99
+ # Handle temperature_increment_on_fallback
100
+ if temperature_increment_on_fallback is not None:
101
+ temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
102
+ else:
103
+ temperature = [temperature]
104
+
105
+ return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow,
106
+ initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
107
+ condition_on_previous_text=condition_on_previous_text, fp16=fp16,
108
+ compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold)
109
+
110
+ def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, **decodeOptions: dict):
111
  try:
112
  sources = self.__get_source(urlData, multipleFiles, microphoneData)
113
 
 
140
  print("Transcribing ", source.source_path)
141
 
142
  # Transcribe
143
+ result = self.transcribe_file(model, source.source_path, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, **decodeOptions)
144
  filePrefix = slugify(source_prefix + source.get_short_name(), allow_unicode=True)
145
 
146
  source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory)
 
378
 
379
  ui_article = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)"
380
 
381
+ simple_inputs = lambda : [
382
  gr.Dropdown(choices=WHISPER_MODELS, value=default_model_name, label="Model"),
383
  gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
384
  gr.Text(label="URL (YouTube, etc.)"),
 
390
  gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=30),
391
  gr.Number(label="VAD - Padding (s)", precision=None, value=1),
392
  gr.Number(label="VAD - Prompt Window (s)", precision=None, value=3)
393
+ ]
394
+
395
+ simple_transcribe = gr.Interface(fn=ui.transcribe_webui_simple, description=ui_description, article=ui_article, inputs=simple_inputs(), outputs=[
396
+ gr.File(label="Download"),
397
+ gr.Text(label="Transcription"),
398
+ gr.Text(label="Segments")
399
+ ])
400
+
401
+ full_description = ui_description + "\n\n\n\n" + "Be careful when changing some of the options in the full interface - this can cause the model to crash."
402
+
403
+ full_transcribe = gr.Interface(fn=ui.transcribe_webui_full, description=full_description, article=ui_article, inputs=[
404
+ *simple_inputs(),
405
+ gr.TextArea(label="Initial Prompt"),
406
+ gr.Number(label="Temperature", value=0),
407
+ gr.Number(label="Best Of - Non-zero temperature", value=5, precision=0),
408
+ gr.Number(label="Beam Size - Zero temperature", value=5, precision=0),
409
+ gr.Number(label="Patience - Zero temperature", value=None),
410
+ gr.Number(label="Length Penalty - Any temperature", value=None),
411
+ gr.Text(label="Suppress Tokens - Comma-separated list of token IDs", value="-1"),
412
+ gr.Checkbox(label="Condition on previous text", value=True),
413
+ gr.Checkbox(label="FP16", value=True),
414
+ gr.Number(label="Temperature increment on fallback", value=0.2),
415
+ gr.Number(label="Compression ratio threshold", value=2.4),
416
+ gr.Number(label="Logprob threshold", value=-1.0),
417
+ gr.Number(label="No speech threshold", value=0.6)
418
  ], outputs=[
419
  gr.File(label="Download"),
420
  gr.Text(label="Transcription"),
421
  gr.Text(label="Segments")
422
  ])
423
 
424
+ demo = gr.TabbedInterface([simple_transcribe, full_transcribe], tab_names=["Simple", "Full"])
425
+
426
  demo.launch(share=share, server_name=server_name, server_port=server_port)
427
 
428
  # Clean up