File size: 7,387 Bytes
b6ac700
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8077be2
b6ac700
 
 
8077be2
b6ac700
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1da02d
 
 
3cd7d59
 
b6ac700
 
8077be2
b6ac700
 
 
 
 
 
 
 
 
 
 
8077be2
b6ac700
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1da02d
 
 
 
 
 
3cd7d59
a1da02d
b6ac700
 
 
8077be2
 
 
b6ac700
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8077be2
 
b6ac700
8077be2
b6ac700
8077be2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from enum import Enum
import urllib

import os
from typing import List
from urllib.parse import urlparse
import json5
import torch

from tqdm import tqdm

class ModelConfig:
    """Describes a single model entry: its name, where it comes from and its kind."""

    def __init__(self, name: str, url: str, path: str = None, type: str = "whisper"):
        """
        Store the description of one model.

        name: Name of the model
        url: URL the model can be downloaded from
        path: Path to the model file. When left unset, the model is expected to be
              downloaded from the URL instead.
        type: Kind of model, either whisper or huggingface.
        """
        # The arguments are kept exactly as given; nothing is validated or resolved here.
        self.name, self.url = name, url
        self.path, self.type = path, type

# String spellings of the VAD initial-prompt modes; these are the same values
# that VadInitialPromptMode.from_string recognises.
VAD_INITIAL_PROMPT_MODE_VALUES = [
    "prepend_all_segments",
    "prepend_first_segment",
    "json_prompt_mode",
]

class VadInitialPromptMode(Enum):
    """Enumeration of the supported VAD initial-prompt modes."""

    PREPEND_ALL_SEGMENTS = 1
    # NOTE: the name is misspelled ("PREPREND"), but it is the established public
    # name, so it stays the canonical member.
    PREPREND_FIRST_SEGMENT = 2
    # Correctly spelled alias for the member above. Enum aliases share the value
    # and are skipped when iterating, so the enum still has three members.
    PREPEND_FIRST_SEGMENT = 2
    JSON_PROMPT_MODE = 3

    @staticmethod
    def from_string(s: str):
        """
        Convert a textual mode name into the matching enum member.

        Matching ignores case and surrounding whitespace, so a value such as
        "prepend_first_segment " (with a stray trailing space) is accepted.

        s: The mode name, or None.

        Returns the matching member, or None when s is None, empty or only
        whitespace.

        Raises ValueError when s is any other unrecognised text.
        """
        if s is None:
            return None

        normalized = s.strip().lower()
        if normalized == "":
            return None

        modes = {
            "prepend_all_segments": VadInitialPromptMode.PREPEND_ALL_SEGMENTS,
            "prepend_first_segment": VadInitialPromptMode.PREPREND_FIRST_SEGMENT,
            "json_prompt_mode": VadInitialPromptMode.JSON_PROMPT_MODE,
        }
        try:
            return modes[normalized]
        except KeyError:
            # Report the caller's original text, not the normalized form.
            raise ValueError(f"Invalid value for VadInitialPromptMode: {s}") from None

class ApplicationConfig:
    """
    Container for the application's settings.

    Each constructor argument is stored on the instance under the same name.
    Instances can be built directly, loaded from a JSON5 file with parse_file,
    or loaded from the default location with create_default.
    """

    def __init__(self, models: List["ModelConfig"] = None, nllb_models: List["ModelConfig"] = None, input_audio_max_duration: int = 600,
                 share: bool = False, server_name: str = None, server_port: int = 7860,
                 queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
                 whisper_implementation: str = "whisper",
                 default_model_name: str = "medium", default_nllb_model_name: str = "distilled-600M", default_vad: str = "silero-vad",
                 vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
                 auto_parallel: bool = False, output_dir: str = None,
                 model_dir: str = None, device: str = None,
                 verbose: bool = True, task: str = "transcribe", language: str = None,
                 vad_initial_prompt_mode: str = "prepend_first_segment",
                 vad_merge_window: float = 5, vad_max_merge_size: float = 30,
                 vad_padding: float = 1, vad_prompt_window: float = 3,
                 temperature: float = 0, best_of: int = 5, beam_size: int = 5,
                 patience: float = None, length_penalty: float = None,
                 suppress_tokens: str = "-1", initial_prompt: str = None,
                 condition_on_previous_text: bool = True, fp16: bool = True,
                 compute_type: str = "float16",
                 temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
                 logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
                 # Word timestamp settings
                 # NOTE(review): the doubled ASCII marks in append_punctuations (",,!!??::")
                 # look like full-width punctuation that was lost to text normalization -
                 # compare against the upstream Whisper defaults.
                 word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
                 append_punctuations: str = "\"\'.。,,!!??::”)]}、",
                 highlight_words: bool = False,
                 # Diarization
                 auth_token: str = None, diarization: bool = False, diarization_speakers: int = 2,
                 diarization_min_speakers: int = 1, diarization_max_speakers: int = 5,
                 diarization_process_timeout: int = 60):
        """
        Store every setting on the instance.

        models: Model configurations; a new empty list is used when omitted or None.
        nllb_models: NLLB model configurations; a new empty list is used when omitted or None.

        All remaining arguments are plain values copied to attributes of the same
        name without validation.
        """
        # Use a fresh list per instance: a literal [] default would be one shared
        # object, so appending to one config's models would leak into every other
        # config built without an explicit list.
        self.models = [] if models is None else models
        self.nllb_models = [] if nllb_models is None else nllb_models

        # WebUI settings
        self.input_audio_max_duration = input_audio_max_duration
        self.share = share
        self.server_name = server_name
        self.server_port = server_port
        self.queue_concurrency_count = queue_concurrency_count
        self.delete_uploaded_files = delete_uploaded_files

        self.whisper_implementation = whisper_implementation
        self.default_model_name = default_model_name
        self.default_nllb_model_name = default_nllb_model_name
        self.default_vad = default_vad
        self.vad_parallel_devices = vad_parallel_devices
        self.vad_cpu_cores = vad_cpu_cores
        self.vad_process_timeout = vad_process_timeout
        self.auto_parallel = auto_parallel
        self.output_dir = output_dir

        self.model_dir = model_dir
        self.device = device
        self.verbose = verbose
        self.task = task
        self.language = language
        self.vad_initial_prompt_mode = vad_initial_prompt_mode
        self.vad_merge_window = vad_merge_window
        self.vad_max_merge_size = vad_max_merge_size
        self.vad_padding = vad_padding
        self.vad_prompt_window = vad_prompt_window
        self.temperature = temperature
        self.best_of = best_of
        self.beam_size = beam_size
        self.patience = patience
        self.length_penalty = length_penalty
        self.suppress_tokens = suppress_tokens
        self.initial_prompt = initial_prompt
        self.condition_on_previous_text = condition_on_previous_text
        self.fp16 = fp16
        self.compute_type = compute_type
        self.temperature_increment_on_fallback = temperature_increment_on_fallback
        self.compression_ratio_threshold = compression_ratio_threshold
        self.logprob_threshold = logprob_threshold
        self.no_speech_threshold = no_speech_threshold

        # Word timestamp settings
        self.word_timestamps = word_timestamps
        self.prepend_punctuations = prepend_punctuations
        self.append_punctuations = append_punctuations
        self.highlight_words = highlight_words

        # Diarization settings
        self.auth_token = auth_token
        self.diarization = diarization
        self.diarization_speakers = diarization_speakers
        self.diarization_min_speakers = diarization_min_speakers
        self.diarization_max_speakers = diarization_max_speakers
        self.diarization_process_timeout = diarization_process_timeout

    def get_model_names(self):
        """Return the name of every entry in self.models, in order."""
        return [x.name for x in self.models]

    def get_nllb_model_names(self):
        """Return the name of every entry in self.nllb_models, in order."""
        return [x.name for x in self.nllb_models]

    def update(self, **new_values):
        """
        Return a copy of this configuration with the given attributes replaced.

        new_values: Attribute names mapped to their new values.

        The instance itself is left unchanged. The copy is rebuilt from the
        instance's attributes, so list attributes such as models refer to the
        same list objects as the original.
        """
        # The instance attributes mirror the constructor parameters one-to-one,
        # so they can be fed straight back in as keyword arguments.
        result = ApplicationConfig(**self.__dict__)

        for key, value in new_values.items():
            setattr(result, key, value)
        return result

    @staticmethod
    def create_default(**kwargs):
        """
        Load the configuration from the default file and apply overrides.

        The file path is read from the WHISPER_WEBUI_CONFIG environment variable,
        falling back to "config.json5" when the variable is not set.

        kwargs: Attribute overrides applied on top of the loaded values.
        """
        app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))

        # Update with kwargs
        if kwargs:
            app_config = app_config.update(**kwargs)
        return app_config

    @staticmethod
    def parse_file(config_path: str):
        """
        Build a configuration from a JSON5 file.

        config_path: Path of the file, read as UTF-8.

        The "models" and "nllb_models" entries (both optional) are converted to
        ModelConfig objects; every other top-level key is passed to the
        constructor as a keyword argument.
        """
        import json5

        with open(config_path, "r", encoding="utf-8") as f:
            # Load using json5
            data = json5.load(f)

        # Remove the model lists so they are not passed a second time through **data.
        data_models = data.pop("models", [])
        data_nllb_models = data.pop("nllb_models", [])

        models = [ModelConfig(**x) for x in data_models]
        nllb_models = [ModelConfig(**x) for x in data_nllb_models]

        return ApplicationConfig(models, nllb_models, **data)