aadnk committed
Commit 1acaa19
Parent: 8670926

Add more configuration options to config.json5

Files changed (7):
  1. app-local.py +3 -1
  2. app-network.py +3 -1
  3. app-shared.py +3 -1
  4. app.py +35 -39
  5. cli.py +27 -24
  6. config.json5 +54 -3
  7. src/config.py +64 -4
app-local.py CHANGED
@@ -1,3 +1,5 @@
 # Run the app with no audio file restrictions
 from app import create_ui
-create_ui(-1)
+from src.config import ApplicationConfig
+
+create_ui(ApplicationConfig.create_default(input_audio_max_duration=-1))
app-network.py CHANGED
@@ -1,3 +1,5 @@
 # Run the app with no audio file restrictions, and make it available on the network
 from app import create_ui
-create_ui(-1, server_name="0.0.0.0")
+from src.config import ApplicationConfig
+
+create_ui(ApplicationConfig.create_default(input_audio_max_duration=-1, server_name="0.0.0.0"))
app-shared.py CHANGED
@@ -1,3 +1,5 @@
 # Run the app with no audio file restrictions
 from app import create_ui
-create_ui(-1, share=True)
+from src.config import ApplicationConfig
+
+create_ui(ApplicationConfig.create_default(input_audio_max_duration=-1, share=True))
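
All three launchers now funnel their overrides through `ApplicationConfig.create_default`, which loads `config.json5` (or the file named by `WHISPER_WEBUI_CONFIG`) and applies keyword arguments on top. A minimal sketch of that pattern, using a simplified stand-in `Config` class rather than the project's real one:

    import os
    import json5  # the same JSON5 parser the project imports

    class Config:
        """Illustrative stand-in for src.config.ApplicationConfig."""
        def __init__(self, **values):
            self.__dict__.update(values)

    def create_default(**overrides):
        # Same lookup as ApplicationConfig.parse_file: the env var wins over the default path
        path = os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5")
        with open(path) as f:
            values = json5.load(f)
        values.update(overrides)  # launcher-specific keyword overrides take precedence
        return Config(**values)

    # Equivalent of app-shared.py: lift the duration limit and share the app
    config = create_default(input_audio_max_duration=-1, share=True)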
app.py CHANGED
@@ -27,11 +27,7 @@ from src.utils import slugify, write_srt, write_vtt
 from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
 from src.whisperContainer import WhisperContainer
 
-# Limitations (set to -1 to disable)
-DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
-
-# Whether or not to automatically delete all uploaded files, to save disk space
-DELETE_UPLOADED_FILES = True
+# Configure more application defaults in config.json5
 
 # Gradio seems to truncate files without keeping the extension, so we need to truncate the file prefix ourself
 MAX_FILE_PREFIX_LENGTH = 17
@@ -62,8 +58,8 @@ LANGUAGES = [
 WHISPER_MODELS = ["tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
 
 class WhisperTranscriber:
-    def __init__(self, input_audio_max_duration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, vad_process_timeout: float = None,
-                 vad_cpu_cores: int = 1, delete_uploaded_files: bool = DELETE_UPLOADED_FILES, output_dir: str = None,
+    def __init__(self, input_audio_max_duration: float = None, vad_process_timeout: float = None,
+                 vad_cpu_cores: int = 1, delete_uploaded_files: bool = False, output_dir: str = None,
                  app_config: ApplicationConfig = None):
         self.model_cache = ModelCache()
         self.parallel_device_list = None
@@ -361,15 +357,13 @@ class WhisperTranscriber:
            self.cpu_parallel_context.close()
 
 
-def create_ui(input_audio_max_duration, share=False, server_name: str = None, server_port: int = 7860,
-              default_model_name: str = "medium", default_vad: str = None, vad_parallel_devices: str = None,
-              vad_process_timeout: float = None, vad_cpu_cores: int = 1, auto_parallel: bool = False,
-              output_dir: str = None, app_config: ApplicationConfig = None):
-    ui = WhisperTranscriber(input_audio_max_duration, vad_process_timeout, vad_cpu_cores, DELETE_UPLOADED_FILES, output_dir, app_config)
+def create_ui(app_config: ApplicationConfig):
+    ui = WhisperTranscriber(app_config.input_audio_max_duration, app_config.vad_process_timeout, app_config.vad_cpu_cores,
+                            app_config.delete_uploaded_files, app_config.output_dir, app_config)
 
     # Specify a list of devices to use for parallel processing
-    ui.set_parallel_devices(vad_parallel_devices)
-    ui.set_auto_parallel(auto_parallel)
+    ui.set_parallel_devices(app_config.vad_parallel_devices)
+    ui.set_auto_parallel(app_config.auto_parallel)
 
     ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
     ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
@@ -377,25 +371,25 @@ def create_ui(input_audio_max_duration, share=False, server_name: str = None, se
 
     ui_description += "\n\n\n\nFor longer audio files (>10 minutes) not in English, it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option."
 
-    if input_audio_max_duration > 0:
-        ui_description += "\n\n" + "Max audio file length: " + str(input_audio_max_duration) + " s"
+    if app_config.input_audio_max_duration > 0:
+        ui_description += "\n\n" + "Max audio file length: " + str(app_config.input_audio_max_duration) + " s"
 
     ui_article = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)"
 
     whisper_models = app_config.get_model_names()
 
     simple_inputs = lambda : [
-        gr.Dropdown(choices=whisper_models, value=default_model_name, label="Model"),
-        gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
+        gr.Dropdown(choices=whisper_models, value=app_config.default_model_name, label="Model"),
+        gr.Dropdown(choices=sorted(LANGUAGES), label="Language", value=app_config.language),
         gr.Text(label="URL (YouTube, etc.)"),
         gr.File(label="Upload Files", file_count="multiple"),
         gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
-        gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
-        gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=default_vad, label="VAD"),
-        gr.Number(label="VAD - Merge Window (s)", precision=0, value=5),
-        gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=30),
-        gr.Number(label="VAD - Padding (s)", precision=None, value=1),
-        gr.Number(label="VAD - Prompt Window (s)", precision=None, value=3)
+        gr.Dropdown(choices=["transcribe", "translate"], label="Task", value=app_config.task),
+        gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=app_config.default_vad, label="VAD"),
+        gr.Number(label="VAD - Merge Window (s)", precision=0, value=app_config.vad_merge_window),
+        gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=app_config.vad_max_merge_size),
+        gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding),
+        gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window),
    ]
 
     simple_transcribe = gr.Interface(fn=ui.transcribe_webui_simple, description=ui_description, article=ui_article, inputs=simple_inputs(), outputs=[
@@ -409,18 +403,18 @@ def create_ui(input_audio_max_duration, share=False, server_name: str = None, se
     full_transcribe = gr.Interface(fn=ui.transcribe_webui_full, description=full_description, article=ui_article, inputs=[
         *simple_inputs(),
         gr.TextArea(label="Initial Prompt"),
-        gr.Number(label="Temperature", value=0),
-        gr.Number(label="Best Of - Non-zero temperature", value=5, precision=0),
-        gr.Number(label="Beam Size - Zero temperature", value=5, precision=0),
-        gr.Number(label="Patience - Zero temperature", value=None),
-        gr.Number(label="Length Penalty - Any temperature", value=None),
-        gr.Text(label="Suppress Tokens - Comma-separated list of token IDs", value="-1"),
-        gr.Checkbox(label="Condition on previous text", value=True),
-        gr.Checkbox(label="FP16", value=True),
-        gr.Number(label="Temperature increment on fallback", value=0.2),
-        gr.Number(label="Compression ratio threshold", value=2.4),
-        gr.Number(label="Logprob threshold", value=-1.0),
-        gr.Number(label="No speech threshold", value=0.6)
+        gr.Number(label="Temperature", value=app_config.temperature),
+        gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0),
+        gr.Number(label="Beam Size - Zero temperature", value=app_config.beam_size, precision=0),
+        gr.Number(label="Patience - Zero temperature", value=app_config.patience),
+        gr.Number(label="Length Penalty - Any temperature", value=app_config.length_penalty),
+        gr.Text(label="Suppress Tokens - Comma-separated list of token IDs", value=app_config.suppress_tokens),
+        gr.Checkbox(label="Condition on previous text", value=app_config.condition_on_previous_text),
+        gr.Checkbox(label="FP16", value=app_config.fp16),
+        gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
+        gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
+        gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
+        gr.Number(label="No speech threshold", value=app_config.no_speech_threshold)
     ], outputs=[
         gr.File(label="Download"),
         gr.Text(label="Transcription"),
@@ -429,13 +423,13 @@ def create_ui(input_audio_max_duration, share=False, server_name: str = None, se
 
     demo = gr.TabbedInterface([simple_transcribe, full_transcribe], tab_names=["Simple", "Full"])
 
-    demo.launch(share=share, server_name=server_name, server_port=server_port)
+    demo.launch(share=app_config.share, server_name=app_config.server_name, server_port=app_config.server_port)
 
     # Clean up
     ui.close()
 
 if __name__ == '__main__':
-    app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))
+    app_config = ApplicationConfig.create_default()
     whisper_models = app_config.get_model_names()
 
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@@ -463,4 +457,6 @@ if __name__ == '__main__':
                         help="directory to save the outputs") # None
 
     args = parser.parse_args().__dict__
-    create_ui(app_config=app_config, **args)
+
+    updated_config = app_config.update(**args)
+    create_ui(app_config=updated_config)
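
The `__main__` block now establishes a clear precedence: constructor defaults in `ApplicationConfig` are overridden by `config.json5`, which is in turn overridden by command-line flags, since every argparse default is drawn from `app_config` and the parsed values are merged back via `update`. A small sketch of that merge, assuming app.py's argparse defines a flag corresponding to `server_port`:

    from src.config import ApplicationConfig

    app_config = ApplicationConfig.create_default()    # constructor defaults + config.json5
    overrides = {"server_port": 8080}                  # e.g. parsed CLI args (flag name assumed)
    updated = app_config.update(**overrides)           # returns a modified copy

    assert updated.server_port == 8080
    assert app_config.server_port == 7860              # original untouched (shipped default)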
cli.py CHANGED
@@ -14,37 +14,40 @@ from src.utils import optional_float, optional_int, str2bool
 from src.whisperContainer import WhisperContainer
 
 def cli():
-    app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))
+    app_config = ApplicationConfig.create_default()
     whisper_models = app_config.get_model_names()
 
+    # For the CLI, we fallback to saving the output to the current directory
+    output_dir = app_config.output_dir if app_config.output_dir is not None else "."
+
     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument("audio", nargs="+", type=str, \
                         help="audio file(s) to transcribe")
     parser.add_argument("--model", default=app_config.default_model_name, choices=whisper_models, \
                         help="name of the Whisper model to use") # medium
-    parser.add_argument("--model_dir", type=str, default=None, \
+    parser.add_argument("--model_dir", type=str, default=app_config.model_dir, \
                         help="the path to save model files; uses ~/.cache/whisper by default")
-    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", \
+    parser.add_argument("--device", default=app_config.device, \
                         help="device to use for PyTorch inference")
-    parser.add_argument("--output_dir", "-o", type=str, default=".", \
+    parser.add_argument("--output_dir", "-o", type=str, default=output_dir, \
                         help="directory to save the outputs")
-    parser.add_argument("--verbose", type=str2bool, default=True, \
+    parser.add_argument("--verbose", type=str2bool, default=app_config.verbose, \
                         help="whether to print out the progress and debug messages")
 
-    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], \
+    parser.add_argument("--task", type=str, default=app_config.task, choices=["transcribe", "translate"], \
                         help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
-    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES), \
+    parser.add_argument("--language", type=str, default=app_config.language, choices=sorted(LANGUAGES), \
                         help="language spoken in the audio, specify None to perform language detection")
 
     parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], \
                         help="The voice activity detection algorithm to use") # silero-vad
-    parser.add_argument("--vad_merge_window", type=optional_float, default=5, \
+    parser.add_argument("--vad_merge_window", type=optional_float, default=app_config.vad_merge_window, \
                         help="The window size (in seconds) to merge voice segments")
-    parser.add_argument("--vad_max_merge_size", type=optional_float, default=30,\
+    parser.add_argument("--vad_max_merge_size", type=optional_float, default=app_config.vad_max_merge_size,\
                         help="The maximum size (in seconds) of a voice segment")
-    parser.add_argument("--vad_padding", type=optional_float, default=1, \
+    parser.add_argument("--vad_padding", type=optional_float, default=app_config.vad_padding, \
                         help="The padding (in seconds) to add to each voice segment")
-    parser.add_argument("--vad_prompt_window", type=optional_float, default=3, \
+    parser.add_argument("--vad_prompt_window", type=optional_float, default=app_config.vad_prompt_window, \
                         help="The window size of the prompt to pass to Whisper")
     parser.add_argument("--vad_cpu_cores", type=int, default=app_config.vad_cpu_cores, \
                         help="The number of CPU cores to use for VAD pre-processing.") # 1
@@ -53,33 +56,33 @@ def cli():
     parser.add_argument("--auto_parallel", type=bool, default=app_config.auto_parallel, \
                         help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.") # False
 
-    parser.add_argument("--temperature", type=float, default=0, \
+    parser.add_argument("--temperature", type=float, default=app_config.temperature, \
                         help="temperature to use for sampling")
-    parser.add_argument("--best_of", type=optional_int, default=5, \
+    parser.add_argument("--best_of", type=optional_int, default=app_config.best_of, \
                         help="number of candidates when sampling with non-zero temperature")
-    parser.add_argument("--beam_size", type=optional_int, default=5, \
+    parser.add_argument("--beam_size", type=optional_int, default=app_config.beam_size, \
                         help="number of beams in beam search, only applicable when temperature is zero")
-    parser.add_argument("--patience", type=float, default=None, \
+    parser.add_argument("--patience", type=float, default=app_config.patience, \
                         help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
-    parser.add_argument("--length_penalty", type=float, default=None, \
+    parser.add_argument("--length_penalty", type=float, default=app_config.length_penalty, \
                         help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")
 
-    parser.add_argument("--suppress_tokens", type=str, default="-1", \
+    parser.add_argument("--suppress_tokens", type=str, default=app_config.suppress_tokens, \
                         help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
-    parser.add_argument("--initial_prompt", type=str, default=None, \
+    parser.add_argument("--initial_prompt", type=str, default=app_config.initial_prompt, \
                         help="optional text to provide as a prompt for the first window.")
-    parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, \
+    parser.add_argument("--condition_on_previous_text", type=str2bool, default=app_config.condition_on_previous_text, \
                         help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
-    parser.add_argument("--fp16", type=str2bool, default=True, \
+    parser.add_argument("--fp16", type=str2bool, default=app_config.fp16, \
                         help="whether to perform inference in fp16; True by default")
 
-    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, \
+    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=app_config.temperature_increment_on_fallback, \
                         help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
-    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, \
+    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=app_config.compression_ratio_threshold, \
                         help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
-    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, \
+    parser.add_argument("--logprob_threshold", type=optional_float, default=app_config.logprob_threshold, \
                         help="if the average log probability is lower than this value, treat the decoding as failed")
-    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, \
+    parser.add_argument("--no_speech_threshold", type=optional_float, default=app_config.no_speech_threshold, \
                         help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
 
     args = parser.parse_args().__dict__
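
Every argparse default in cli.py is now sourced from `app_config`, so editing `config.json5` changes the CLI's behavior without code changes; only `output_dir` keeps an explicit `"."` fallback, because the CLI always has to write its output somewhere. The pattern in isolation, as a self-contained sketch with a dummy config object in place of `ApplicationConfig`:

    import argparse

    class FakeConfig:                       # stand-in for ApplicationConfig
        temperature = 0.0
        output_dir = None                   # null in config.json5

    app_config = FakeConfig()
    output_dir = app_config.output_dir if app_config.output_dir is not None else "."

    parser = argparse.ArgumentParser()
    parser.add_argument("--temperature", type=float, default=app_config.temperature)
    parser.add_argument("--output_dir", "-o", type=str, default=output_dir)

    args = parser.parse_args([])            # no flags given -> config values win
    assert args.temperature == 0.0 and args.output_dir == "."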
config.json5 CHANGED
@@ -45,7 +45,9 @@
     ],
     // Configuration options that will be used if they are not specified in the command line arguments.
 
-    // Maximum audio file length in seconds, or -1 for no limit.
+    // * WEBUI options *
+
+    // Maximum audio file length in seconds, or -1 for no limit. Ignored by CLI.
     "input_audio_max_duration": 600,
     // True to share the app on HuggingFace.
     "share": false,
@@ -53,6 +55,11 @@
     "server_name": null,
     // The port to bind to.
     "server_port": 7860,
+    // Whether or not to automatically delete all uploaded files, to save disk space
+    "delete_uploaded_files": true,
+
+    // * General options *
+
     // The default model name.
     "default_model_name": "medium",
     // The default VAD.
@@ -65,6 +72,50 @@
     "vad_process_timeout": 1800,
     // True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.
     "auto_parallel": false,
-    // Directory to save the outputs
-    "output_dir": null
+    // Directory to save the outputs (CLI will use the current directory if not specified)
+    "output_dir": null,
+    // The path to save model files; uses ~/.cache/whisper by default
+    "model_dir": null,
+    // Device to use for PyTorch inference, or Null to use the default device
+    "device": null,
+    // Whether to print out the progress and debug messages
+    "verbose": true,
+    // Whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')
+    "task": "transcribe",
+    // Language spoken in the audio, specify None to perform language detection
+    "language": null,
+    // The window size (in seconds) to merge voice segments
+    "vad_merge_window": 5,
+    // The maximum size (in seconds) of a voice segment
+    "vad_max_merge_size": 30,
+    // The padding (in seconds) to add to each voice segment
+    "vad_padding": 1,
+    // The window size of the prompt to pass to Whisper
+    "vad_prompt_window": 3,
+    // Temperature to use for sampling
+    "temperature": 0,
+    // Number of candidates when sampling with non-zero temperature
+    "best_of": 5,
+    // Number of beams in beam search, only applicable when temperature is zero
+    "beam_size": 5,
+    // Optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search
+    "patience": null,
+    // Optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default
+    "length_penalty": null,
+    // Comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations
+    "suppress_tokens": "-1",
+    // Optional text to provide as a prompt for the first window
+    "initial_prompt": null,
+    // If True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop
+    "condition_on_previous_text": true,
+    // Whether to perform inference in fp16; True by default
+    "fp16": true,
+    // Temperature to increase when falling back when the decoding fails to meet either of the thresholds below
+    "temperature_increment_on_fallback": 0.2,
+    // If the gzip compression ratio is higher than this value, treat the decoding as failed
+    "compression_ratio_threshold": 2.4,
+    // If the average log probability is lower than this value, treat the decoding as failed
+    "logprob_threshold": -1.0,
+    // If the probability of the <no-speech> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence
+    "no_speech_threshold": 0.6
 }
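
The new keys mirror the CLI flags one-to-one, so the file doubles as a persistent set of command-line defaults. Since `ApplicationConfig.parse_file` resolves the path via the `WHISPER_WEBUI_CONFIG` environment variable, a deployment can keep several variants of this file; a hedged example (the alternate file name below is made up):

    import os
    from src.config import ApplicationConfig

    # Point the app at an alternate config before it loads; parse_file falls
    # back to "config.json5" when the variable is unset.
    os.environ["WHISPER_WEBUI_CONFIG"] = "config.dev.json5"

    app_config = ApplicationConfig.create_default()
    print(app_config.beam_size)  # 5 unless config.dev.json5 overrides it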
src/config.py CHANGED
@@ -3,6 +3,8 @@ import urllib
 import os
 from typing import List
 from urllib.parse import urlparse
+import json5
+import torch
 
 from tqdm import tqdm
 
@@ -101,14 +103,33 @@ class ModelConfig:
 
 class ApplicationConfig:
     def __init__(self, models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
-                 share: bool = False, server_name: str = None, server_port: int = 7860, default_model_name: str = "medium",
-                 default_vad: str = "silero-vad", vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
-                 auto_parallel: bool = False, output_dir: str = None):
+                 share: bool = False, server_name: str = None, server_port: int = 7860, delete_uploaded_files: bool = True,
+                 default_model_name: str = "medium", default_vad: str = "silero-vad",
+                 vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
+                 auto_parallel: bool = False, output_dir: str = None,
+                 model_dir: str = None, device: str = None,
+                 verbose: bool = True, task: str = "transcribe", language: str = None,
+                 vad_merge_window: float = 5, vad_max_merge_size: float = 30,
+                 vad_padding: float = 1, vad_prompt_window: float = 3,
+                 temperature: float = 0, best_of: int = 5, beam_size: int = 5,
+                 patience: float = None, length_penalty: float = None,
+                 suppress_tokens: str = "-1", initial_prompt: str = None,
+                 condition_on_previous_text: bool = True, fp16: bool = True,
+                 temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
+                 logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6):
+
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+
         self.models = models
+
+        # WebUI settings
         self.input_audio_max_duration = input_audio_max_duration
         self.share = share
         self.server_name = server_name
         self.server_port = server_port
+        self.delete_uploaded_files = delete_uploaded_files
+
         self.default_model_name = default_model_name
         self.default_vad = default_vad
         self.vad_parallel_devices = vad_parallel_devices
@@ -117,9 +138,48 @@ class ApplicationConfig:
         self.auto_parallel = auto_parallel
         self.output_dir = output_dir
 
+        self.model_dir = model_dir
+        self.device = device
+        self.verbose = verbose
+        self.task = task
+        self.language = language
+        self.vad_merge_window = vad_merge_window
+        self.vad_max_merge_size = vad_max_merge_size
+        self.vad_padding = vad_padding
+        self.vad_prompt_window = vad_prompt_window
+        self.temperature = temperature
+        self.best_of = best_of
+        self.beam_size = beam_size
+        self.patience = patience
+        self.length_penalty = length_penalty
+        self.suppress_tokens = suppress_tokens
+        self.initial_prompt = initial_prompt
+        self.condition_on_previous_text = condition_on_previous_text
+        self.fp16 = fp16
+        self.temperature_increment_on_fallback = temperature_increment_on_fallback
+        self.compression_ratio_threshold = compression_ratio_threshold
+        self.logprob_threshold = logprob_threshold
+        self.no_speech_threshold = no_speech_threshold
+
     def get_model_names(self):
         return [ x.name for x in self.models ]
 
+    def update(self, **new_values):
+        result = ApplicationConfig(**self.__dict__)
+
+        for key, value in new_values.items():
+            setattr(result, key, value)
+        return result
+
+    @staticmethod
+    def create_default(**kwargs):
+        app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))
+
+        # Update with kwargs
+        if len(kwargs) > 0:
+            app_config = app_config.update(**kwargs)
+        return app_config
+
     @staticmethod
     def parse_file(config_path: str):
         import json5
@@ -131,4 +191,4 @@ class ApplicationConfig:
 
     models = [ ModelConfig(**x) for x in data_models ]
 
-    return ApplicationConfig(models, **data)
+    return ApplicationConfig(models, **data)
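
Two behaviors of the reworked `ApplicationConfig` are worth noting: `device=None` resolves to CUDA when available (mirroring the old CLI default), and `update` is copy-on-write, rebuilding a new instance from `self.__dict__` before applying the changes. A quick check of both, assuming `torch` is installed and the project is on the path:

    import torch
    from src.config import ApplicationConfig

    cfg = ApplicationConfig()                # device=None -> auto-detected
    assert cfg.device == ("cuda" if torch.cuda.is_available() else "cpu")

    tweaked = cfg.update(language="en")      # returns a new instance
    assert tweaked.language == "en"
    assert cfg.language is None              # the original is unchanged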