import copy
import os
import urllib
from enum import Enum
from typing import List
from urllib.parse import urlparse

import json5
import torch
from tqdm import tqdm


class ModelConfig:
    def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
        """
        Initialize a model configuration.

        name: Name of the model
        url: URL to download the model from
        path: Path to the model file. If not set, the model will be downloaded from the URL.
        type: Type of model. Can be whisper or huggingface.
        """
        # NOTE: `type` shadows the builtin, but it is kept because config entries are
        # passed straight in as keyword arguments (see ApplicationConfig.parse_file).
        self.name = name
        self.url = url
        self.path = path
        self.type = type


# String spellings accepted by VadInitialPromptMode.from_string.
VAD_INITIAL_PROMPT_MODE_VALUES = ["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]


class VadInitialPromptMode(Enum):
    """Initial-prompt modes for VAD processing; parse user input with `from_string`."""

    PREPEND_ALL_SEGMENTS = 1
    # NOTE: the "PREPREND" misspelling is part of the public name and is kept so that
    # existing references to this member keep working.
    PREPREND_FIRST_SEGMENT = 2
    JSON_PROMPT_MODE = 3

    @staticmethod
    def from_string(s: str):
        """
        Parse a mode name into a VadInitialPromptMode.

        Matching is case-insensitive and ignores surrounding whitespace.
        Returns None when `s` is None, empty, or only whitespace.
        Raises ValueError for any other unrecognized value.
        """
        # Strip so that values with stray whitespace (e.g. from config files) still parse.
        normalized = s.strip().lower() if s is not None else None

        if normalized == "prepend_all_segments":
            return VadInitialPromptMode.PREPEND_ALL_SEGMENTS
        elif normalized == "prepend_first_segment":
            return VadInitialPromptMode.PREPREND_FIRST_SEGMENT
        elif normalized == "json_prompt_mode":
            return VadInitialPromptMode.JSON_PROMPT_MODE
        elif normalized is not None and normalized != "":
            raise ValueError(f"Invalid value for VadInitialPromptMode: {s}")
        else:
            return None


class ApplicationConfig:
    """Settings for the web UI, the models, VAD, decoding, word timestamps and diarization."""

    def __init__(self, models: List[ModelConfig] = None, nllb_models: List[ModelConfig] = None,
                 input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
                 queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
                 whisper_implementation: str = "whisper",
                 default_model_name: str = "medium", default_nllb_model_name: str = "distilled-600M", default_vad: str = "silero-vad",
                 vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
                 auto_parallel: bool = False, output_dir: str = None,
                 model_dir: str = None, device: str = None,
                 verbose: bool = True, task: str = "transcribe", language: str = None,
                 vad_initial_prompt_mode: str = "prepend_first_segment",
                 vad_merge_window: float = 5, vad_max_merge_size: float = 30,
                 vad_padding: float = 1, vad_prompt_window: float = 3,
                 temperature: float = 0, best_of: int = 5, beam_size: int = 5,
                 patience: float = None, length_penalty: float = None,
                 suppress_tokens: str = "-1", initial_prompt: str = None,
                 condition_on_previous_text: bool = True, fp16: bool = True,
                 compute_type: str = "float16",
                 temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
                 logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
                 # Word timestamp settings
                 word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
                 append_punctuations: str = "\"\'.。,,!!??::”)]}、",
                 highlight_words: bool = False,
                 # Diarization
                 auth_token: str = None, diarization: bool = False, diarization_speakers: int = 2,
                 diarization_min_speakers: int = 1, diarization_max_speakers: int = 5,
                 diarization_process_timeout: int = 60):
        """
        Store every setting as an attribute of the same name.

        models / nllb_models: ModelConfig lists; None (the default) means an empty list.
        vad_initial_prompt_mode: stored as a string; parse it with VadInitialPromptMode.from_string.
        """
        # Build fresh lists instead of using mutable defaults, which would be shared
        # between every instance created without explicit model lists.
        self.models = models if models is not None else []
        self.nllb_models = nllb_models if nllb_models is not None else []

        # WebUI settings
        self.input_audio_max_duration = input_audio_max_duration
        self.share = share
        self.server_name = server_name
        self.server_port = server_port
        self.queue_concurrency_count = queue_concurrency_count
        self.delete_uploaded_files = delete_uploaded_files

        self.whisper_implementation = whisper_implementation
        self.default_model_name = default_model_name
        self.default_nllb_model_name = default_nllb_model_name
        self.default_vad = default_vad
        self.vad_parallel_devices = vad_parallel_devices
        self.vad_cpu_cores = vad_cpu_cores
        self.vad_process_timeout = vad_process_timeout
        self.auto_parallel = auto_parallel
        self.output_dir = output_dir

        self.model_dir = model_dir
        self.device = device
        self.verbose = verbose
        self.task = task
        self.language = language
        self.vad_initial_prompt_mode = vad_initial_prompt_mode
        self.vad_merge_window = vad_merge_window
        self.vad_max_merge_size = vad_max_merge_size
        self.vad_padding = vad_padding
        self.vad_prompt_window = vad_prompt_window
        self.temperature = temperature
        self.best_of = best_of
        self.beam_size = beam_size
        self.patience = patience
        self.length_penalty = length_penalty
        self.suppress_tokens = suppress_tokens
        self.initial_prompt = initial_prompt
        self.condition_on_previous_text = condition_on_previous_text
        self.fp16 = fp16
        self.compute_type = compute_type
        self.temperature_increment_on_fallback = temperature_increment_on_fallback
        self.compression_ratio_threshold = compression_ratio_threshold
        self.logprob_threshold = logprob_threshold
        self.no_speech_threshold = no_speech_threshold

        # Word timestamp settings
        self.word_timestamps = word_timestamps
        self.prepend_punctuations = prepend_punctuations
        self.append_punctuations = append_punctuations
        self.highlight_words = highlight_words

        # Diarization settings
        self.auth_token = auth_token
        self.diarization = diarization
        self.diarization_speakers = diarization_speakers
        self.diarization_min_speakers = diarization_min_speakers
        self.diarization_max_speakers = diarization_max_speakers
        self.diarization_process_timeout = diarization_process_timeout

    def get_model_names(self):
        """Return the `name` of each entry in `models`, in order."""
        return [x.name for x in self.models]

    def get_nllb_model_names(self):
        """Return the `name` of each entry in `nllb_models`, in order."""
        return [x.name for x in self.nllb_models]

    def update(self, **new_values):
        """
        Return a copy of this configuration with the given attributes overridden.

        The receiver is left unchanged. The copy is shallow: list attributes such as
        `models` are shared with the original.
        """
        # copy.copy (rather than re-running __init__ with self.__dict__) keeps working
        # even after attributes that are not __init__ parameters have been set.
        result = copy.copy(self)
        for key, value in new_values.items():
            setattr(result, key, value)
        return result

    @staticmethod
    def create_default(**kwargs):
        """
        Load the configuration file and apply any keyword overrides.

        The file path comes from the WHISPER_WEBUI_CONFIG environment variable and
        defaults to "config.json5".
        """
        app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))

        # Update with kwargs
        if kwargs:
            app_config = app_config.update(**kwargs)
        return app_config

    @staticmethod
    def parse_file(config_path: str):
        """
        Build an ApplicationConfig from a JSON5 file.

        The "models" and "nllb_models" entries become ModelConfig objects. Every other
        top-level key is passed as a keyword argument to ApplicationConfig, so an
        unknown key raises TypeError.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            # Load using json5 (module-level import)
            data = json5.load(f)

        data_models = data.pop("models", [])
        data_nllb_models = data.pop("nllb_models", [])

        models = [ModelConfig(**x) for x in data_models]
        nllb_models = [ModelConfig(**x) for x in data_nllb_models]

        return ApplicationConfig(models, nllb_models, **data)