SoybeanMilk committed
Commit 22034c0
1 Parent(s): 2601bfd

Add support for the ALMA model.

Files changed (5)
  1. app.py +19 -0
  2. config.json5 +7 -0
  3. config.py +177 -0
  4. translationModel.py +259 -0
  5. utils.py +314 -0
app.py CHANGED
@@ -231,6 +231,8 @@ class WhisperTranscriber:
         nllbLangName: str = decodeOptions.pop("nllbLangName")
         mt5ModelName: str = decodeOptions.pop("mt5ModelName")
         mt5LangName: str = decodeOptions.pop("mt5LangName")
+        ALMAModelName: str = decodeOptions.pop("ALMAModelName")
+        ALMALangName: str = decodeOptions.pop("ALMALangName")

         translationBatchSize: int = decodeOptions.pop("translationBatchSize")
         translationNoRepeatNgramSize: int = decodeOptions.pop("translationNoRepeatNgramSize")
@@ -334,6 +336,10 @@ class WhisperTranscriber:
             selectedModelName = mt5ModelName if mt5ModelName is not None and len(mt5ModelName) > 0 else "mt5-zh-ja-en-trimmed/K024"
             selectedModel = next((modelConfig for modelConfig in self.app_config.models["mt5"] if modelConfig.name == selectedModelName), None)
             translationLang = get_lang_from_m2m100_name(mt5LangName)
+        elif translateInput == "ALMA" and ALMALangName is not None and len(ALMALangName) > 0:
+            selectedModelName = ALMAModelName if ALMAModelName is not None and len(ALMAModelName) > 0 else "ALMA-13B-GPTQ/TheBloke"
+            selectedModel = next((modelConfig for modelConfig in self.app_config.models["ALMA"] if modelConfig.name == selectedModelName), None)
+            translationLang = get_lang_from_m2m100_name(ALMALangName)

         if translationLang is not None:
             translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=whisperLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams)
@@ -826,6 +832,7 @@ def create_ui(app_config: ApplicationConfig):
     nllb_models = app_config.get_model_names("nllb")
     m2m100_models = app_config.get_model_names("m2m100")
     mt5_models = app_config.get_model_names("mt5")
+    ALMA_models = app_config.get_model_names("ALMA")

     common_whisper_inputs = lambda : {
         gr.Dropdown(label="Whisper - Model (for audio)", choices=whisper_models, value=app_config.default_model_name, elem_id="whisperModelName"),
@@ -843,6 +850,10 @@ def create_ui(app_config: ApplicationConfig):
         gr.Dropdown(label="MT5 - Model (for translate)", choices=mt5_models, elem_id="mt5ModelName"),
         gr.Dropdown(label="MT5 - Language", choices=sorted(get_lang_m2m100_names(["en", "ja", "zh"])), elem_id="mt5LangName"),
     }
+    common_ALMA_inputs = lambda : {
+        gr.Dropdown(label="ALMA - Model (for translate)", choices=ALMA_models, elem_id="ALMAModelName"),
+        gr.Dropdown(label="ALMA - Language", choices=sorted(get_lang_m2m100_names(["en", "ja", "zh"])), elem_id="ALMALangName"),
+    }

     common_translation_inputs = lambda : {
         gr.Number(label="Translation - Batch Size", precision=0, value=app_config.translation_batch_size, elem_id="translationBatchSize"),
@@ -903,9 +914,13 @@
                 with gr.Tab(label="MT5") as simpleMT5Tab:
                     with gr.Row():
                         simpleInputDict.update(common_mt5_inputs())
+                with gr.Tab(label="ALMA") as simpleALMATab:
+                    with gr.Row():
+                        simpleInputDict.update(common_ALMA_inputs())
                 simpleM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [simpleTranslateInput] )
                 simpleNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [simpleTranslateInput] )
                 simpleMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [simpleTranslateInput] )
+                simpleALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [simpleTranslateInput] )
             with gr.Column():
                 with gr.Tab(label="URL") as simpleUrlTab:
                     simpleInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
@@ -962,9 +977,13 @@
                 with gr.Tab(label="MT5") as fullMT5Tab:
                     with gr.Row():
                         fullInputDict.update(common_mt5_inputs())
+                with gr.Tab(label="ALMA") as fullALMATab:
+                    with gr.Row():
+                        fullInputDict.update(common_ALMA_inputs())
                 fullM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [fullTranslateInput] )
                 fullNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [fullTranslateInput] )
                 fullMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [fullTranslateInput] )
+                fullALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [fullTranslateInput] )
             with gr.Column():
                 with gr.Tab(label="URL") as fullUrlTab:
                     fullInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
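
The app.py changes repeat one pattern per translation backend: pop the UI options, resolve a ModelConfig by name with a hard-coded default, and map the language. Below is a minimal, self-contained sketch of that fallback idiom; FakeModelConfig and select_model are hypothetical stand-ins, not part of the commit:

# Sketch of the model-selection fallback the diff repeats for each backend.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FakeModelConfig:  # stands in for src.config.ModelConfig
    name: str
    url: str

def select_model(requested: Optional[str], models: List[FakeModelConfig],
                 default_name: str) -> Optional[FakeModelConfig]:
    # Same idiom as the diff: fall back to a default name, then scan the
    # configured models and return None when nothing matches.
    name = requested if requested else default_name
    return next((m for m in models if m.name == name), None)

models = [FakeModelConfig("ALMA-13B-GPTQ/TheBloke", "TheBloke/ALMA-13B-GPTQ")]
print(select_model(None, models, "ALMA-13B-GPTQ/TheBloke"))      # falls back to the default
print(select_model("missing", models, "ALMA-13B-GPTQ/TheBloke")) # -> None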
config.json5 CHANGED
@@ -187,6 +187,13 @@
             "url": "engmatic-earth/mt5-zh-ja-en-trimmed-fine-tuned-v1",
             "type": "huggingface"
         }
+    ],
+    "ALMA": [
+        {
+            "name": "ALMA-13B-GPTQ/TheBloke",
+            "url": "TheBloke/ALMA-13B-GPTQ",
+            "type": "huggingface",
+        },
     ]
 },
 // Configuration options that will be used if they are not specified in the command line arguments.
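
The new "ALMA" block is consumed by ApplicationConfig.parse_file, which turns each entry into a ModelConfig keyed under models["ALMA"]. A short sketch of that parsing step, assuming the json5 package the repo already uses and an inline string in place of the real config.json5:

# Sketch: how the new "ALMA" block is materialized (inline string, not the real file).
import json5  # the same parser config.py imports

snippet = """
{
    "models": {
        "ALMA": [
            {
                "name": "ALMA-13B-GPTQ/TheBloke",
                "url": "TheBloke/ALMA-13B-GPTQ",
                "type": "huggingface",  // json5 tolerates comments and trailing commas
            },
        ]
    }
}
"""

data = json5.loads(snippet)
for item in data["models"]["ALMA"]:
    # config.py does ModelConfig(**item) at this point; here we just show the fields.
    print(item["name"], "->", item["url"], f"({item['type']})")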
config.py ADDED
@@ -0,0 +1,177 @@
+from enum import Enum
+
+import os
+from typing import List, Dict, Literal
+
+
+class ModelConfig:
+    def __init__(self, name: str, url: str, path: str = None, type: str = "whisper", tokenizer_url: str = None):
+        """
+        Initialize a model configuration.
+
+        name: Name of the model
+        url: URL to download the model from
+        path: Path to the model file. If not set, the model will be downloaded from the URL.
+        type: Type of model. Can be whisper or huggingface.
+        """
+        self.name = name
+        self.url = url
+        self.path = path
+        self.type = type
+        self.tokenizer_url = tokenizer_url
+
+VAD_INITIAL_PROMPT_MODE_VALUES=["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]
+
+class VadInitialPromptMode(Enum):
+    PREPEND_ALL_SEGMENTS = 1
+    PREPREND_FIRST_SEGMENT = 2
+    JSON_PROMPT_MODE = 3
+
+    @staticmethod
+    def from_string(s: str):
+        normalized = s.lower() if s is not None and len(s) > 0 else None
+
+        if normalized == "prepend_all_segments":
+            return VadInitialPromptMode.PREPEND_ALL_SEGMENTS
+        elif normalized == "prepend_first_segment":
+            return VadInitialPromptMode.PREPREND_FIRST_SEGMENT
+        elif normalized == "json_prompt_mode":
+            return VadInitialPromptMode.JSON_PROMPT_MODE
+        elif normalized is not None and normalized != "":
+            raise ValueError(f"Invalid value for VadInitialPromptMode: {s}")
+        else:
+            return None
+
+class ApplicationConfig:
+    def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA"], List[ModelConfig]],
+                 input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
+                 queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
+                 whisper_implementation: str = "whisper", default_model_name: str = "medium",
+                 default_nllb_model_name: str = "distilled-600M", default_vad: str = "silero-vad",
+                 vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
+                 auto_parallel: bool = False, output_dir: str = None,
+                 model_dir: str = None, device: str = None,
+                 verbose: bool = True, task: str = "transcribe", language: str = None,
+                 vad_initial_prompt_mode: str = "prepend_first_segment",
+                 vad_merge_window: float = 5, vad_max_merge_size: float = 30,
+                 vad_padding: float = 1, vad_prompt_window: float = 3,
+                 temperature: float = 0, best_of: int = 5, beam_size: int = 5,
+                 patience: float = None, length_penalty: float = None,
+                 suppress_tokens: str = "-1", initial_prompt: str = None,
+                 condition_on_previous_text: bool = True, fp16: bool = True,
+                 compute_type: str = "float16",
+                 temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
+                 logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
+                 repetition_penalty: float = 1.0, no_repeat_ngram_size: int = 0,
+                 # Word timestamp settings
+                 word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
+                 append_punctuations: str = "\"\'.。,,!!??::”)]}、",
+                 highlight_words: bool = False,
+                 # Diarization
+                 auth_token: str = None, diarization: bool = False, diarization_speakers: int = 2,
+                 diarization_min_speakers: int = 1, diarization_max_speakers: int = 5,
+                 diarization_process_timeout: int = 60,
+                 # Translation
+                 translation_batch_size: int = 2,
+                 translation_no_repeat_ngram_size: int = 3,
+                 translation_num_beams: int = 2,
+                 ):
+
+        self.models = models
+
+        # WebUI settings
+        self.input_audio_max_duration = input_audio_max_duration
+        self.share = share
+        self.server_name = server_name
+        self.server_port = server_port
+        self.queue_concurrency_count = queue_concurrency_count
+        self.delete_uploaded_files = delete_uploaded_files
+
+        self.whisper_implementation = whisper_implementation
+        self.default_model_name = default_model_name
+        self.default_nllb_model_name = default_nllb_model_name
+        self.default_vad = default_vad
+        self.vad_parallel_devices = vad_parallel_devices
+        self.vad_cpu_cores = vad_cpu_cores
+        self.vad_process_timeout = vad_process_timeout
+        self.auto_parallel = auto_parallel
+        self.output_dir = output_dir
+
+        self.model_dir = model_dir
+        self.device = device
+        self.verbose = verbose
+        self.task = task
+        self.language = language
+        self.vad_initial_prompt_mode = vad_initial_prompt_mode
+        self.vad_merge_window = vad_merge_window
+        self.vad_max_merge_size = vad_max_merge_size
+        self.vad_padding = vad_padding
+        self.vad_prompt_window = vad_prompt_window
+        self.temperature = temperature
+        self.best_of = best_of
+        self.beam_size = beam_size
+        self.patience = patience
+        self.length_penalty = length_penalty
+        self.suppress_tokens = suppress_tokens
+        self.initial_prompt = initial_prompt
+        self.condition_on_previous_text = condition_on_previous_text
+        self.fp16 = fp16
+        self.compute_type = compute_type
+        self.temperature_increment_on_fallback = temperature_increment_on_fallback
+        self.compression_ratio_threshold = compression_ratio_threshold
+        self.logprob_threshold = logprob_threshold
+        self.no_speech_threshold = no_speech_threshold
+        self.repetition_penalty = repetition_penalty
+        self.no_repeat_ngram_size = no_repeat_ngram_size
+
+        # Word timestamp settings
+        self.word_timestamps = word_timestamps
+        self.prepend_punctuations = prepend_punctuations
+        self.append_punctuations = append_punctuations
+        self.highlight_words = highlight_words
+
+        # Diarization settings
+        self.auth_token = auth_token
+        self.diarization = diarization
+        self.diarization_speakers = diarization_speakers
+        self.diarization_min_speakers = diarization_min_speakers
+        self.diarization_max_speakers = diarization_max_speakers
+        self.diarization_process_timeout = diarization_process_timeout
+        # Translation
+        self.translation_batch_size = translation_batch_size
+        self.translation_no_repeat_ngram_size = translation_no_repeat_ngram_size
+        self.translation_num_beams = translation_num_beams
+
+    def get_model_names(self, name: str):
+        return [ x.name for x in self.models[name] ]
+
+    def update(self, **new_values):
+        result = ApplicationConfig(**self.__dict__)
+
+        for key, value in new_values.items():
+            setattr(result, key, value)
+        return result
+
+    @staticmethod
+    def create_default(**kwargs):
+        app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))
+
+        # Update with kwargs
+        if len(kwargs) > 0:
+            app_config = app_config.update(**kwargs)
+        return app_config
+
+    @staticmethod
+    def parse_file(config_path: str):
+        import json5
+
+        with open(config_path, "r", encoding="utf-8") as f:
+            # Load using json5
+            data = json5.load(f)
+            data_models = data.pop("models", [])
+            models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA"], List[ModelConfig]] = {
+                key: [ModelConfig(**item) for item in value]
+                for key, value in data_models.items()
+            }
+
+            return ApplicationConfig(models, **data)
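
A usage sketch for the two helpers the UI code calls, get_model_names and update. It builds an ApplicationConfig directly instead of going through create_default, so no config.json5 needs to exist; the model entries are illustrative only:

# Sketch: constructing ApplicationConfig by hand and exercising its helpers.
from src.config import ApplicationConfig, ModelConfig  # assumes the repo layout

models = {
    "whisper": [ModelConfig(name="medium", url="medium")],
    "ALMA": [ModelConfig(name="ALMA-13B-GPTQ/TheBloke",
                         url="TheBloke/ALMA-13B-GPTQ", type="huggingface")],
}

app_config = ApplicationConfig(models, server_port=7861)
print(app_config.get_model_names("ALMA"))  # ['ALMA-13B-GPTQ/TheBloke']

# update() copies the config, then overrides the given fields on the copy
tweaked = app_config.update(translation_batch_size=4)
print(tweaked.translation_batch_size)      # 4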
translationModel.py ADDED
@@ -0,0 +1,259 @@
+import os
+import warnings
+import huggingface_hub
+import requests
+import torch
+
+import ctranslate2
+import transformers
+
+import re
+
+from typing import Optional
+from src.config import ModelConfig
+from src.translation.translationLangs import TranslationLang, get_lang_from_whisper_code
+from peft import PeftModel
+
+class TranslationModel:
+    def __init__(
+        self,
+        modelConfig: ModelConfig,
+        device: str = None,
+        whisperLang: TranslationLang = None,
+        translationLang: TranslationLang = None,
+        batchSize: int = 2,
+        noRepeatNgramSize: int = 3,
+        numBeams: int = 2,
+        downloadRoot: Optional[str] = None,
+        localFilesOnly: bool = False,
+        loadModel: bool = False,
+    ):
+        """Initializes the M2M100 / Nllb-200 / mt5 / ALMA model.
+
+        Args:
+          modelConfig: Config of the model to use (distilled-600M, distilled-1.3B,
+            1.3B, 3.3B...) or a path to a converted
+            model directory. When a size is configured, the converted model is downloaded
+            from the Hugging Face Hub.
+          device: Device to use for computation (cpu, cuda, ipu, xpu, mkldnn, opengl, opencl,
+            ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia).
+          device_index: Device ID to use.
+            The model can also be loaded on multiple GPUs by passing a list of IDs
+            (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel
+            when transcribe() is called from multiple Python threads (see also num_workers).
+          compute_type: Type to use for computation.
+            See https://opennmt.net/CTranslate2/quantization.html.
+          cpu_threads: Number of threads to use when running on CPU (4 by default).
+            A non zero value overrides the OMP_NUM_THREADS environment variable.
+          num_workers: When transcribe() is called from multiple Python threads,
+            having multiple workers enables true parallelism when running the model
+            (concurrent calls to self.model.generate() will run in parallel).
+            This can improve the global throughput at the cost of increased memory usage.
+          downloadRoot: Directory where the models should be saved. If not set, the models
+            are saved in the standard Hugging Face cache directory.
+          localFilesOnly: If True, avoid downloading the file and return the path to the
+            local cached file if it exists.
+        """
+        self.modelConfig = modelConfig
+        self.whisperLang = whisperLang # self.translationLangWhisper = get_lang_from_whisper_code(whisperLang.code.lower() if whisperLang is not None else "en")
+        self.translationLang = translationLang
+
+        if translationLang is None:
+            return
+
+        self.batchSize = batchSize
+        self.noRepeatNgramSize = noRepeatNgramSize
+        self.numBeams = numBeams
+
+        if os.path.isdir(modelConfig.url):
+            self.modelPath = modelConfig.url
+        else:
+            self.modelPath = download_model(
+                modelConfig,
+                localFilesOnly=localFilesOnly,
+                cacheDir=downloadRoot,
+            )
+
+        if device is None:
+            if torch.cuda.is_available():
+                device = "cuda" if "ct2" in self.modelPath else "cuda:0"
+            else:
+                device = "cpu"
+
+        self.device = device
+
+        if loadModel:
+            self.load_model()
+
+    def load_model(self):
+        print('\n\nLoading model: %s\n\n' % self.modelPath)
+        if "ct2" in self.modelPath:
+            if "nllb" in self.modelPath:
+                self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.nllb.code)
+                self.targetPrefix = [self.translationLang.nllb.code]
+            elif "m2m100" in self.modelPath:
+                self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.m2m100.code)
+                self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
+            self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
+        elif "mt5" in self.modelPath:
+            self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
+            self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) # requires spiece.model
+            self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath)
+            self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
+        elif "ALMA" in self.modelPath:
+            self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.code + " to " + self.translationLang.whisper.code + ":" + self.whisperLang.whisper.code + ":"
+            self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath, use_fast=True)
+            self.transModel = transformers.AutoModelForCausalLM.from_pretrained(self.modelPath, device_map="auto", trust_remote_code=False, revision="main")
+            self.transTranslator = transformers.pipeline("text-generation", model=self.transModel, tokenizer=self.transTokenizer, batch_size=2, do_sample=True, temperature=0.7, top_p=0.95, top_k=40, repetition_penalty=1.1)
+        else:
+            self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
+            self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
+            if "m2m100" in self.modelPath:
+                self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.m2m100.code, tgt_lang=self.translationLang.m2m100.code)
+            else: # NLLB
+                self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.nllb.code, tgt_lang=self.translationLang.nllb.code)
+
+    def release_vram(self):
+        try:
+            if torch.cuda.is_available():
+                if "ct2" not in self.modelPath:
+                    device = torch.device("cpu")
+                    self.transModel.to(device)
+                del self.transModel
+                torch.cuda.empty_cache()
+            print("release vram end.")
+        except Exception as e:
+            print("Error release vram: " + str(e))
+
+
+    def translation(self, text: str, max_length: int = 400):
+        output = None
+        result = None
+        try:
+            if "ct2" in self.modelPath:
+                source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(text))
+                output = self.transModel.translate_batch([source], target_prefix=[self.targetPrefix], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
+                target = output[0].hypotheses[0][1:]
+                result = self.transTokenizer.decode(self.transTokenizer.convert_tokens_to_ids(target))
+            elif "mt5" in self.modelPath:
+                output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
+                result = output[0]['generated_text']
+            elif "ALMA" in self.modelPath:
+                output = self.transTranslator(self.ALMAPrefix + text + self.translationLang.whisper.code + ":", max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
+                result = output[0]['generated_text']
+                result = re.sub(rf'^(.*{self.translationLang.whisper.code}: )', '', result) # Remove the prompt from the result
+                result = re.sub(rf'^(Translate this from .* to .*:)', '', result) # Remove the translation instruction
+                return result.strip()
+            else: # M2M100 & NLLB
+                output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
+                result = output[0]['translation_text']
+        except Exception as e:
+            print("Error translation text: " + str(e))
+
+        return result
+
+
+_MODELS = ["distilled-600M", "distilled-1.3B", "1.3B", "3.3B",
+           "ct2fast-nllb-200-distilled-1.3B-int8_float16",
+           "ct2fast-nllb-200-3.3B-int8_float16",
+           "nllb-200-3.3B-ct2-float16", "nllb-200-1.3B-ct2", "nllb-200-1.3B-ct2-int8", "nllb-200-1.3B-ct2-float16",
+           "nllb-200-distilled-1.3B-ct2", "nllb-200-distilled-1.3B-ct2-int8", "nllb-200-distilled-1.3B-ct2-float16",
+           "nllb-200-distilled-600M-ct2", "nllb-200-distilled-600M-ct2-int8", "nllb-200-distilled-600M-ct2-float16",
+           "m2m100_1.2B-ct2", "m2m100_418M-ct2", "m2m100-12B-ct2",
+           "m2m100_1.2B", "m2m100_418M",
+           "mt5-zh-ja-en-trimmed",
+           "mt5-zh-ja-en-trimmed-fine-tuned-v1",
+           "ALMA-13B-GPTQ"]
+
+def check_model_name(name):
+    return any(allowed_name in name for allowed_name in _MODELS)
+
+def download_model(
+    modelConfig: ModelConfig,
+    outputDir: Optional[str] = None,
+    localFilesOnly: bool = False,
+    cacheDir: Optional[str] = None,
+):
+    """"download_model" is referenced from the "utils.py" script
+    of the "faster_whisper" project, authored by guillaumekln.
+
+    Downloads a translation model from the Hugging Face Hub.
+
+    Args:
+      modelConfig: config of the model to download (facebook/nllb-distilled-600M,
+        facebook/nllb-distilled-1.3B, facebook/nllb-1.3B, facebook/nllb-3.3B...).
+      outputDir: Directory where the model should be saved. If not set, the model is saved in
+        the cache directory.
+      localFilesOnly: If True, avoid downloading the file and return the path to the local
+        cached file if it exists.
+      cacheDir: Path to the folder where cached files are stored.
+
+    Returns:
+      The path to the downloaded model.
+
+    Raises:
+      ValueError: if the model size is invalid.
+    """
+    if not check_model_name(modelConfig.name):
+        raise ValueError(
+            "Invalid model name '%s', expected one of: %s" % (modelConfig.name, ", ".join(_MODELS))
+        )
+
+    repoId = modelConfig.url # "facebook/nllb-200-%s" %
+
+    allowPatterns = [
+        "config.json",
+        "generation_config.json",
+        "model.bin",
+        "pytorch_model.bin",
+        "pytorch_model.bin.index.json",
+        "pytorch_model-*.bin",
+        "pytorch_model-00001-of-00003.bin",
+        "pytorch_model-00002-of-00003.bin",
+        "pytorch_model-00003-of-00003.bin",
+        "sentencepiece.bpe.model",
+        "tokenizer.json",
+        "tokenizer_config.json",
+        "shared_vocabulary.txt",
+        "shared_vocabulary.json",
+        "special_tokens_map.json",
+        "spiece.model",
+        "vocab.json", # m2m100
+        "model.safetensors",
+        "quantize_config.json",
+        "tokenizer.model"
+    ]
+
+    kwargs = {
+        "local_files_only": localFilesOnly,
+        "allow_patterns": allowPatterns,
+        #"tqdm_class": disabled_tqdm,
+    }
+
+    if outputDir is not None:
+        kwargs["local_dir"] = outputDir
+        kwargs["local_dir_use_symlinks"] = False
+
+    if cacheDir is not None:
+        kwargs["cache_dir"] = cacheDir
+
+    try:
+        return huggingface_hub.snapshot_download(repoId, **kwargs)
+    except (
+        huggingface_hub.utils.HfHubHTTPError,
+        requests.exceptions.ConnectionError,
+    ) as exception:
+        warnings.warn(
+            "An error occurred while synchronizing the model %s from the Hugging Face Hub:\n%s"
+            % (repoId, exception)
+        )
+        warnings.warn(
+            "Trying to load the model directly from the local cache, if it exists."
+        )
+
+        kwargs["local_files_only"] = True
+        return huggingface_hub.snapshot_download(repoId, **kwargs)
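
How the new ALMA branch is driven end to end, as a hedged sketch: it assumes the repo layout (src.translation.translationModel is an assumed module path), a CUDA GPU, and the auto-gptq/optimum packages that loading a GPTQ checkpoint with transformers requires:

# Sketch: driving the ALMA translation path end to end (assumptions noted above).
from src.config import ModelConfig
from src.translation.translationLangs import get_lang_from_whisper_code
from src.translation.translationModel import TranslationModel  # assumed module path

config = ModelConfig(name="ALMA-13B-GPTQ/TheBloke",
                     url="TheBloke/ALMA-13B-GPTQ", type="huggingface")

model = TranslationModel(
    modelConfig=config,
    whisperLang=get_lang_from_whisper_code("en"),
    translationLang=get_lang_from_whisper_code("ja"),
    loadModel=True,  # downloads the snapshot and builds the text-generation pipeline
)

# The prompt becomes "Translate this from en to ja:en:<text>ja:", and the
# regexes in translation() strip the echoed prompt from the generated text.
print(model.translation("Hello, world!"))
model.release_vram()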
utils.py ADDED
@@ -0,0 +1,314 @@
+import textwrap
+import unicodedata
+import re
+
+import zlib
+from typing import Iterator, TextIO, Union
+from tqdm import tqdm
+
+import urllib.request
+
+
+def exact_div(x, y):
+    assert x % y == 0
+    return x // y
+
+
+def str2bool(string):
+    str2val = {"True": True, "False": False}
+    if string in str2val:
+        return str2val[string]
+    else:
+        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
+
+
+def optional_int(string):
+    return None if string == "None" else int(string)
+
+
+def optional_float(string):
+    return None if string == "None" else float(string)
+
+
+def compression_ratio(text) -> float:
+    return len(text) / len(zlib.compress(text.encode("utf-8")))
+
+
+def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
+    assert seconds >= 0, "non-negative timestamp expected"
+    milliseconds = round(seconds * 1000.0)
+
+    hours = milliseconds // 3_600_000
+    milliseconds -= hours * 3_600_000
+
+    minutes = milliseconds // 60_000
+    milliseconds -= minutes * 60_000
+
+    seconds = milliseconds // 1_000
+    milliseconds -= seconds * 1_000
+
+    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
+    return f"{hours_marker}{minutes:02d}:{seconds:02d}{fractionalSeperator}{milliseconds:03d}"
+
+
+def write_txt(transcript: Iterator[dict], file: TextIO):
+    for segment in transcript:
+        print(segment['text'].strip(), file=file, flush=True)
+
+
+def write_vtt(transcript: Iterator[dict], file: TextIO,
+              maxLineWidth=None, highlight_words: bool = False):
+    iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)
+
+    print("WEBVTT\n", file=file)
+
+    for segment in iterator:
+        text = segment['text'].replace('-->', '->')
+
+        print(
+            f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
+            f"{text}\n",
+            file=file,
+            flush=True,
+        )
+
+def write_srt(transcript: Iterator[dict], file: TextIO,
+              maxLineWidth=None, highlight_words: bool = False):
+    """
+    Write a transcript to a file in SRT format.
+    Example usage:
+        from pathlib import Path
+        from whisper.utils import write_srt
+        result = transcribe(model, audio_path, temperature=temperature, **args)
+        # save SRT
+        audio_basename = Path(audio_path).stem
+        with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
+            write_srt(result["segments"], file=srt)
+    """
+    iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)
+
+    for i, segment in enumerate(iterator, start=1):
+        text = segment['text'].replace('-->', '->')
+
+        # write srt lines
+        print(
+            f"{i}\n"
+            f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
+            f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}\n"
+            f"{text}\n",
+            file=file,
+            flush=True,
+        )
+
+def write_srt_original(transcript: Iterator[dict], file: TextIO,
+                       maxLineWidth=None, highlight_words: bool = False, bilingual: bool = False):
+    """
+    Write a transcript to a file in SRT format.
+    Example usage:
+        from pathlib import Path
+        from whisper.utils import write_srt
+        result = transcribe(model, audio_path, temperature=temperature, **args)
+        # save SRT
+        audio_basename = Path(audio_path).stem
+        with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
+            write_srt(result["segments"], file=srt)
+    """
+    iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)
+
+    for i, segment in enumerate(iterator, start=1):
+        if "original" not in segment:
+            continue
+
+        original = segment['original'].replace('-->', '->')
+
+        # write srt lines
+        print(
+            f"{i}\n"
+            f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
+            f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}",
+            file=file,
+            flush=True,
+        )
+
+        if original is not None: print(f"{original}\n",
+                                       file=file,
+                                       flush=True)
+
+        if bilingual:
+            text = segment['text'].replace('-->', '->')
+            print(f"{text}\n",
+                  file=file,
+                  flush=True)
+
+def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: int = None, highlight_words: bool = False):
+    for segment in transcript:
+        words: list = segment.get('words', [])
+
+        # Append longest speaker ID if available
+        segment_longest_speaker = segment.get('longest_speaker', None)
+
+        # Yield the segment as-is or processed
+        if len(words) == 0 and (maxLineWidth is None or maxLineWidth < 0) and segment_longest_speaker is None:
+            yield segment
+
+        if segment_longest_speaker is not None:
+            segment_longest_speaker = segment_longest_speaker.replace("SPEAKER", "S")
+
+        subtitle_start = segment['start']
+        subtitle_end = segment['end']
+        text = segment['text'].strip()
+        original_text = segment['original'].strip() if 'original' in segment else None
+
+        if len(words) == 0:
+            # Prepend the longest speaker ID if available
+            if segment_longest_speaker is not None:
+                text = f"({segment_longest_speaker}) {text}"
+
+            result = {
+                'start': subtitle_start,
+                'end'  : subtitle_end,
+                'text' : process_text(text, maxLineWidth)
+            }
+            if original_text is not None and len(original_text) > 0:
+                result.update({'original': process_text(original_text, maxLineWidth)})
+            yield result
+
+            # We are done
+            continue
+
+        if segment_longest_speaker is not None:
+            # Add the beginning
+            words.insert(0, {
+                'start': subtitle_start,
+                'end'  : subtitle_start,
+                'word' : f"({segment_longest_speaker})"
+            })
+
+        text_words = [text] if not highlight_words and original_text is not None and len(original_text) > 0 else [ this_word["word"] for this_word in words ]
+        subtitle_text = __join_words(text_words, maxLineWidth)
+
+        # Iterate over the words in the segment
+        if highlight_words:
+            last = subtitle_start
+
+            for i, this_word in enumerate(words):
+                start = this_word['start']
+                end = this_word['end']
+
+                if last != start:
+                    # Display the text up to this point
+                    yield {
+                        'start': last,
+                        'end'  : start,
+                        'text' : subtitle_text
+                    }
+
+                # Display the text with the current word highlighted
+                yield {
+                    'start': start,
+                    'end'  : end,
+                    'text' : __join_words(
+                        [
+                            {
+                                "word": re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
+                                if j == i
+                                else word,
+                                # The HTML tags <u> and </u> are not displayed,
+                                # so they should not be counted in the word length
+                                "length": len(word)
+                            } for j, word in enumerate(text_words)
+                        ], maxLineWidth)
+                }
+                last = end
+
+            if last != subtitle_end:
+                # Display the last part of the text
+                yield {
+                    'start': last,
+                    'end'  : subtitle_end,
+                    'text' : subtitle_text
+                }
+
+        # Just return the subtitle text
+        else:
+            result = {
+                'start': subtitle_start,
+                'end'  : subtitle_end,
+                'text' : subtitle_text
+            }
+            if original_text is not None and len(original_text) > 0:
+                result.update({'original': original_text})
+            yield result
+
+def __join_words(words: Iterator[Union[str, dict]], maxLineWidth: int = None):
+    if maxLineWidth is None or maxLineWidth < 0:
+        return " ".join(words)
+
+    lines = []
+    current_line = ""
+    current_length = 0
+
+    for entry in words:
+        # Either accept a string or a dict with a 'word' and 'length' field
+        if isinstance(entry, dict):
+            word = entry['word']
+            word_length = entry['length']
+        else:
+            word = entry
+            word_length = len(word)
+
+        if current_length > 0 and current_length + word_length > maxLineWidth:
+            lines.append(current_line)
+            current_line = ""
+            current_length = 0
+
+        current_length += word_length
+        # The word will be prefixed with a space by Whisper, so we don't need to add one here
+        current_line += word
+
+    if len(current_line) > 0:
+        lines.append(current_line)
+
+    return "\n".join(lines)
+
+def process_text(text: str, maxLineWidth=None):
+    if (maxLineWidth is None or maxLineWidth < 0):
+        return text
+
+    lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
+    return '\n'.join(lines)
+
+def slugify(value, allow_unicode=False, is_lower=False):
+    """
+    Taken from https://github.com/django/django/blob/master/django/utils/text.py
+    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
+    dashes to single dashes. Remove characters that aren't alphanumerics,
+    underscores, or hyphens. Convert to lowercase. Also strip leading and
+    trailing whitespace, dashes, and underscores.
+    """
+    value = str(value)
+    if allow_unicode:
+        value = unicodedata.normalize('NFKC', value)
+    else:
+        value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
+    if is_lower:
+        value = value.lower()
+    value = re.sub(r'[^\w\s-]', '', value.replace("/","_").replace("⧸","_"))
+    return re.sub(r'[-\s]+', '-', value).strip('-_')
+
+def download_file(url: str, destination: str):
+    with urllib.request.urlopen(url) as source, open(destination, "wb") as output:
+        with tqdm(
+            total=int(source.info().get("Content-Length")),
+            ncols=80,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as loop:
+            while True:
+                buffer = source.read(8192)
+                if not buffer:
+                    break
+
+                output.write(buffer)
+                loop.update(len(buffer))
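
A quick, self-contained check of format_timestamp and write_srt against an in-memory buffer (the src.utils import path is an assumption based on the repo's other imports):

# Sketch: exercising format_timestamp and write_srt without touching the file system.
import io
from src.utils import format_timestamp, write_srt  # assumed module path

print(format_timestamp(3661.5, always_include_hours=True, fractionalSeperator=','))
# -> 01:01:01,500

segments = [
    {'start': 0.0, 'end': 2.5, 'text': ' Hello there.'},
    {'start': 2.5, 'end': 5.0, 'text': ' This is a test.'},
]

buffer = io.StringIO()
write_srt(segments, file=buffer, maxLineWidth=40)
print(buffer.getvalue())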