SoybeanMilk committed · Commit 22034c0 · Parent(s): 2601bfd

Add support for the ALMA model.
Files changed:
- app.py +19 -0
- config.json5 +7 -0
- config.py +177 -0
- translationModel.py +259 -0
- utils.py +314 -0
app.py
CHANGED
@@ -231,6 +231,8 @@ class WhisperTranscriber:
         nllbLangName: str = decodeOptions.pop("nllbLangName")
         mt5ModelName: str = decodeOptions.pop("mt5ModelName")
         mt5LangName: str = decodeOptions.pop("mt5LangName")
+        ALMAModelName: str = decodeOptions.pop("ALMAModelName")
+        ALMALangName: str = decodeOptions.pop("ALMALangName")
 
         translationBatchSize: int = decodeOptions.pop("translationBatchSize")
         translationNoRepeatNgramSize: int = decodeOptions.pop("translationNoRepeatNgramSize")
@@ -334,6 +336,10 @@ class WhisperTranscriber:
             selectedModelName = mt5ModelName if mt5ModelName is not None and len(mt5ModelName) > 0 else "mt5-zh-ja-en-trimmed/K024"
             selectedModel = next((modelConfig for modelConfig in self.app_config.models["mt5"] if modelConfig.name == selectedModelName), None)
             translationLang = get_lang_from_m2m100_name(mt5LangName)
+        elif translateInput == "ALMA" and ALMALangName is not None and len(ALMALangName) > 0:
+            selectedModelName = ALMAModelName if ALMAModelName is not None and len(ALMAModelName) > 0 else "ALMA-13B-GPTQ/TheBloke"
+            selectedModel = next((modelConfig for modelConfig in self.app_config.models["ALMA"] if modelConfig.name == selectedModelName), None)
+            translationLang = get_lang_from_m2m100_name(ALMALangName)
 
         if translationLang is not None:
             translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=whisperLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams)
@@ -826,6 +832,7 @@ def create_ui(app_config: ApplicationConfig):
     nllb_models = app_config.get_model_names("nllb")
     m2m100_models = app_config.get_model_names("m2m100")
     mt5_models = app_config.get_model_names("mt5")
+    ALMA_models = app_config.get_model_names("ALMA")
 
     common_whisper_inputs = lambda : {
         gr.Dropdown(label="Whisper - Model (for audio)", choices=whisper_models, value=app_config.default_model_name, elem_id="whisperModelName"),
@@ -843,6 +850,10 @@ def create_ui(app_config: ApplicationConfig):
         gr.Dropdown(label="MT5 - Model (for translate)", choices=mt5_models, elem_id="mt5ModelName"),
         gr.Dropdown(label="MT5 - Language", choices=sorted(get_lang_m2m100_names(["en", "ja", "zh"])), elem_id="mt5LangName"),
     }
+    common_ALMA_inputs = lambda : {
+        gr.Dropdown(label="ALMA - Model (for translate)", choices=ALMA_models, elem_id="ALMAModelName"),
+        gr.Dropdown(label="ALMA - Language", choices=sorted(get_lang_m2m100_names(["en", "ja", "zh"])), elem_id="ALMALangName"),
+    }
 
     common_translation_inputs = lambda : {
         gr.Number(label="Translation - Batch Size", precision=0, value=app_config.translation_batch_size, elem_id="translationBatchSize"),
@@ -903,9 +914,13 @@ def create_ui(app_config: ApplicationConfig):
                 with gr.Tab(label="MT5") as simpleMT5Tab:
                     with gr.Row():
                         simpleInputDict.update(common_mt5_inputs())
+                with gr.Tab(label="ALMA") as simpleALMATab:
+                    with gr.Row():
+                        simpleInputDict.update(common_ALMA_inputs())
                 simpleM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [simpleTranslateInput] )
                 simpleNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [simpleTranslateInput] )
                 simpleMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [simpleTranslateInput] )
+                simpleALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [simpleTranslateInput] )
             with gr.Column():
                 with gr.Tab(label="URL") as simpleUrlTab:
                     simpleInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
@@ -962,9 +977,13 @@ def create_ui(app_config: ApplicationConfig):
                 with gr.Tab(label="MT5") as fullMT5Tab:
                     with gr.Row():
                         fullInputDict.update(common_mt5_inputs())
+                with gr.Tab(label="ALMA") as fullALMATab:
+                    with gr.Row():
+                        fullInputDict.update(common_ALMA_inputs())
                 fullM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [fullTranslateInput] )
                 fullNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [fullTranslateInput] )
                 fullMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [fullTranslateInput] )
+                fullALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [fullTranslateInput] )
             with gr.Column():
                 with gr.Tab(label="URL") as fullUrlTab:
                     fullInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
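Note: the UI wiring above follows one pattern per translation backend: each family gets its own tab, and selecting a tab writes the family's name into a shared input (simpleTranslateInput / fullTranslateInput) that WhisperTranscriber later dispatches on. A minimal, self-contained sketch of that pattern, assuming only gradio (the component names below are illustrative, not the ones in app.py):

import gradio as gr

with gr.Blocks() as demo:
    # Shared state naming the active translation backend, like simpleTranslateInput
    translate_input = gr.State("m2m100")
    with gr.Tab(label="MT5") as mt5_tab:
        gr.Markdown("MT5 options go here")
    with gr.Tab(label="ALMA") as alma_tab:
        gr.Markdown("ALMA options go here")
    # Selecting a tab rewrites the state; the transcriber reads it on submit
    mt5_tab.select(fn=lambda: "mt5", inputs=[], outputs=[translate_input])
    alma_tab.select(fn=lambda: "ALMA", inputs=[], outputs=[translate_input])

demo.launch()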
config.json5
CHANGED
@@ -187,6 +187,13 @@
       "url": "engmatic-earth/mt5-zh-ja-en-trimmed-fine-tuned-v1",
       "type": "huggingface"
     }
+    ],
+    "ALMA": [
+      {
+        "name": "ALMA-13B-GPTQ/TheBloke",
+        "url": "TheBloke/ALMA-13B-GPTQ",
+        "type": "huggingface",
+      },
     ]
   },
   // Configuration options that will be used if they are not specified in the command line arguments.
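The new block relies on JSON5, which tolerates the trailing commas and // comments used here. A short sketch of how this section is consumed, following the shape that config.py's parse_file expects (assumes the json5 package is installed and config.json5 is in the working directory):

import json5

with open("config.json5", "r", encoding="utf-8") as f:
    data = json5.load(f)  # json5 accepts trailing commas and comments

# parse_file turns each entry under "models" -> "ALMA" into ModelConfig kwargs
alma_entries = data["models"]["ALMA"]
print([entry["name"] for entry in alma_entries])  # ['ALMA-13B-GPTQ/TheBloke']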
config.py
ADDED
@@ -0,0 +1,177 @@
from enum import Enum

import os
from typing import List, Dict, Literal


class ModelConfig:
    def __init__(self, name: str, url: str, path: str = None, type: str = "whisper", tokenizer_url: str = None):
        """
        Initialize a model configuration.

        name: Name of the model
        url: URL to download the model from
        path: Path to the model file. If not set, the model will be downloaded from the URL.
        type: Type of model. Can be whisper or huggingface.
        """
        self.name = name
        self.url = url
        self.path = path
        self.type = type
        self.tokenizer_url = tokenizer_url

VAD_INITIAL_PROMPT_MODE_VALUES = ["prepend_all_segments", "prepend_first_segment", "json_prompt_mode"]

class VadInitialPromptMode(Enum):
    PREPEND_ALL_SEGMENTS = 1
    PREPREND_FIRST_SEGMENT = 2
    JSON_PROMPT_MODE = 3

    @staticmethod
    def from_string(s: str):
        normalized = s.lower() if s is not None and len(s) > 0 else None

        if normalized == "prepend_all_segments":
            return VadInitialPromptMode.PREPEND_ALL_SEGMENTS
        elif normalized == "prepend_first_segment":
            return VadInitialPromptMode.PREPREND_FIRST_SEGMENT
        elif normalized == "json_prompt_mode":
            return VadInitialPromptMode.JSON_PROMPT_MODE
        elif normalized is not None and normalized != "":
            raise ValueError(f"Invalid value for VadInitialPromptMode: {s}")
        else:
            return None

class ApplicationConfig:
    def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA"], List[ModelConfig]],
                 input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
                 queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
                 whisper_implementation: str = "whisper", default_model_name: str = "medium",
                 default_nllb_model_name: str = "distilled-600M", default_vad: str = "silero-vad",
                 vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
                 auto_parallel: bool = False, output_dir: str = None,
                 model_dir: str = None, device: str = None,
                 verbose: bool = True, task: str = "transcribe", language: str = None,
                 vad_initial_prompt_mode: str = "prepend_first_segment",
                 vad_merge_window: float = 5, vad_max_merge_size: float = 30,
                 vad_padding: float = 1, vad_prompt_window: float = 3,
                 temperature: float = 0, best_of: int = 5, beam_size: int = 5,
                 patience: float = None, length_penalty: float = None,
                 suppress_tokens: str = "-1", initial_prompt: str = None,
                 condition_on_previous_text: bool = True, fp16: bool = True,
                 compute_type: str = "float16",
                 temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
                 logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
                 repetition_penalty: float = 1.0, no_repeat_ngram_size: int = 0,
                 # Word timestamp settings
                 word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
                 append_punctuations: str = "\"\'.。,,!!??::”)]}、",
                 highlight_words: bool = False,
                 # Diarization
                 auth_token: str = None, diarization: bool = False, diarization_speakers: int = 2,
                 diarization_min_speakers: int = 1, diarization_max_speakers: int = 5,
                 diarization_process_timeout: int = 60,
                 # Translation
                 translation_batch_size: int = 2,
                 translation_no_repeat_ngram_size: int = 3,
                 translation_num_beams: int = 2,
                 ):

        self.models = models

        # WebUI settings
        self.input_audio_max_duration = input_audio_max_duration
        self.share = share
        self.server_name = server_name
        self.server_port = server_port
        self.queue_concurrency_count = queue_concurrency_count
        self.delete_uploaded_files = delete_uploaded_files

        self.whisper_implementation = whisper_implementation
        self.default_model_name = default_model_name
        self.default_nllb_model_name = default_nllb_model_name
        self.default_vad = default_vad
        self.vad_parallel_devices = vad_parallel_devices
        self.vad_cpu_cores = vad_cpu_cores
        self.vad_process_timeout = vad_process_timeout
        self.auto_parallel = auto_parallel
        self.output_dir = output_dir

        self.model_dir = model_dir
        self.device = device
        self.verbose = verbose
        self.task = task
        self.language = language
        self.vad_initial_prompt_mode = vad_initial_prompt_mode
        self.vad_merge_window = vad_merge_window
        self.vad_max_merge_size = vad_max_merge_size
        self.vad_padding = vad_padding
        self.vad_prompt_window = vad_prompt_window
        self.temperature = temperature
        self.best_of = best_of
        self.beam_size = beam_size
        self.patience = patience
        self.length_penalty = length_penalty
        self.suppress_tokens = suppress_tokens
        self.initial_prompt = initial_prompt
        self.condition_on_previous_text = condition_on_previous_text
        self.fp16 = fp16
        self.compute_type = compute_type
        self.temperature_increment_on_fallback = temperature_increment_on_fallback
        self.compression_ratio_threshold = compression_ratio_threshold
        self.logprob_threshold = logprob_threshold
        self.no_speech_threshold = no_speech_threshold
        self.repetition_penalty = repetition_penalty
        self.no_repeat_ngram_size = no_repeat_ngram_size

        # Word timestamp settings
        self.word_timestamps = word_timestamps
        self.prepend_punctuations = prepend_punctuations
        self.append_punctuations = append_punctuations
        self.highlight_words = highlight_words

        # Diarization settings
        self.auth_token = auth_token
        self.diarization = diarization
        self.diarization_speakers = diarization_speakers
        self.diarization_min_speakers = diarization_min_speakers
        self.diarization_max_speakers = diarization_max_speakers
        self.diarization_process_timeout = diarization_process_timeout
        # Translation
        self.translation_batch_size = translation_batch_size
        self.translation_no_repeat_ngram_size = translation_no_repeat_ngram_size
        self.translation_num_beams = translation_num_beams

    def get_model_names(self, name: str):
        return [ x.name for x in self.models[name] ]

    def update(self, **new_values):
        result = ApplicationConfig(**self.__dict__)

        for key, value in new_values.items():
            setattr(result, key, value)
        return result

    @staticmethod
    def create_default(**kwargs):
        app_config = ApplicationConfig.parse_file(os.environ.get("WHISPER_WEBUI_CONFIG", "config.json5"))

        # Update with kwargs
        if len(kwargs) > 0:
            app_config = app_config.update(**kwargs)
        return app_config

    @staticmethod
    def parse_file(config_path: str):
        import json5

        with open(config_path, "r", encoding="utf-8") as f:
            # Load using json5
            data = json5.load(f)
            data_models = data.pop("models", [])
            models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA"], List[ModelConfig]] = {
                key: [ModelConfig(**item) for item in value]
                for key, value in data_models.items()
            }

            return ApplicationConfig(models, **data)
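A short usage sketch for ApplicationConfig (the import path mirrors the one used in translationModel.py and may differ in the Space's actual layout; the "large-v2" override is illustrative, not a project default):

from src.config import ApplicationConfig

# Reads WHISPER_WEBUI_CONFIG if set, otherwise config.json5
app_config = ApplicationConfig.create_default()
print(app_config.get_model_names("ALMA"))  # names from the new "ALMA" section

# update() copies the config, so per-request overrides leave the original intact
overridden = app_config.update(default_model_name="large-v2")
print(overridden.default_model_name)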
translationModel.py
ADDED
@@ -0,0 +1,259 @@
import os
import warnings
import huggingface_hub
import requests
import torch

import ctranslate2
import transformers

import re

from typing import Optional
from src.config import ModelConfig
from src.translation.translationLangs import TranslationLang, get_lang_from_whisper_code
from peft import PeftModel

class TranslationModel:
    def __init__(
        self,
        modelConfig: ModelConfig,
        device: str = None,
        whisperLang: TranslationLang = None,
        translationLang: TranslationLang = None,
        batchSize: int = 2,
        noRepeatNgramSize: int = 3,
        numBeams: int = 2,
        downloadRoot: Optional[str] = None,
        localFilesOnly: bool = False,
        loadModel: bool = False,
    ):
        """Initializes the M2M100 / Nllb-200 / mt5 / ALMA model.

        Args:
          modelConfig: Config of the model to use (distilled-600M, distilled-1.3B,
            1.3B, 3.3B...) or a path to a converted model directory. When a size is
            configured, the converted model is downloaded from the Hugging Face Hub.
          device: Device to use for computation (cpu, cuda, ipu, xpu, mkldnn, opengl, opencl,
            ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia).
          device_index: Device ID to use.
            The model can also be loaded on multiple GPUs by passing a list of IDs
            (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel
            when transcribe() is called from multiple Python threads (see also num_workers).
          compute_type: Type to use for computation.
            See https://opennmt.net/CTranslate2/quantization.html.
          cpu_threads: Number of threads to use when running on CPU (4 by default).
            A non zero value overrides the OMP_NUM_THREADS environment variable.
          num_workers: When transcribe() is called from multiple Python threads,
            having multiple workers enables true parallelism when running the model
            (concurrent calls to self.model.generate() will run in parallel).
            This can improve the global throughput at the cost of increased memory usage.
          downloadRoot: Directory where the models should be saved. If not set, the models
            are saved in the standard Hugging Face cache directory.
          localFilesOnly: If True, avoid downloading the file and return the path to the
            local cached file if it exists.
        """
        self.modelConfig = modelConfig
        self.whisperLang = whisperLang  # self.translationLangWhisper = get_lang_from_whisper_code(whisperLang.code.lower() if whisperLang is not None else "en")
        self.translationLang = translationLang

        if translationLang is None:
            return

        self.batchSize = batchSize
        self.noRepeatNgramSize = noRepeatNgramSize
        self.numBeams = numBeams

        if os.path.isdir(modelConfig.url):
            self.modelPath = modelConfig.url
        else:
            self.modelPath = download_model(
                modelConfig,
                localFilesOnly=localFilesOnly,
                cacheDir=downloadRoot,
            )

        if device is None:
            if torch.cuda.is_available():
                device = "cuda" if "ct2" in self.modelPath else "cuda:0"
            else:
                device = "cpu"

        self.device = device

        if loadModel:
            self.load_model()

    def load_model(self):
        print('\n\nLoading model: %s\n\n' % self.modelPath)
        if "ct2" in self.modelPath:
            if "nllb" in self.modelPath:
                self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.nllb.code)
                self.targetPrefix = [self.translationLang.nllb.code]
            elif "m2m100" in self.modelPath:
                self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.m2m100.code)
                self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
            self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
        elif "mt5" in self.modelPath:
            self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
            self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False)  # requires spiece.model
            self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath)
            self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
        elif "ALMA" in self.modelPath:
            self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.code + " to " + self.translationLang.whisper.code + ":" + self.whisperLang.whisper.code + ":"
            self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath, use_fast=True)
            self.transModel = transformers.AutoModelForCausalLM.from_pretrained(self.modelPath, device_map="auto", trust_remote_code=False, revision="main")
            self.transTranslator = transformers.pipeline("text-generation", model=self.transModel, tokenizer=self.transTokenizer, batch_size=2, do_sample=True, temperature=0.7, top_p=0.95, top_k=40, repetition_penalty=1.1)
        else:
            self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
            self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
            if "m2m100" in self.modelPath:
                self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.m2m100.code, tgt_lang=self.translationLang.m2m100.code)
            else:  # NLLB
                self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.nllb.code, tgt_lang=self.translationLang.nllb.code)

    def release_vram(self):
        try:
            if torch.cuda.is_available():
                if "ct2" not in self.modelPath:
                    device = torch.device("cpu")
                    self.transModel.to(device)
                del self.transModel
                torch.cuda.empty_cache()
            print("release vram end.")
        except Exception as e:
            print("Error release vram: " + str(e))


    def translation(self, text: str, max_length: int = 400):
        output = None
        result = None
        try:
            if "ct2" in self.modelPath:
                source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(text))
                output = self.transModel.translate_batch([source], target_prefix=[self.targetPrefix], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
                target = output[0].hypotheses[0][1:]
                result = self.transTokenizer.decode(self.transTokenizer.convert_tokens_to_ids(target))
            elif "mt5" in self.modelPath:
                output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)  # , num_return_sequences=2
                result = output[0]['generated_text']
            elif "ALMA" in self.modelPath:
                output = self.transTranslator(self.ALMAPrefix + text + self.translationLang.whisper.code + ":", max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
                result = output[0]['generated_text']
                result = re.sub(rf'^(.*{self.translationLang.whisper.code}: )', '', result)  # Remove the prompt from the result
                result = re.sub(rf'^(Translate this from .* to .*:)', '', result)  # Remove the translation instruction
                return result.strip()
            else:  # M2M100 & NLLB
                output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
                result = output[0]['translation_text']
        except Exception as e:
            print("Error translation text: " + str(e))

        return result


_MODELS = ["distilled-600M", "distilled-1.3B", "1.3B", "3.3B",
           "ct2fast-nllb-200-distilled-1.3B-int8_float16",
           "ct2fast-nllb-200-3.3B-int8_float16",
           "nllb-200-3.3B-ct2-float16", "nllb-200-1.3B-ct2", "nllb-200-1.3B-ct2-int8", "nllb-200-1.3B-ct2-float16",
           "nllb-200-distilled-1.3B-ct2", "nllb-200-distilled-1.3B-ct2-int8", "nllb-200-distilled-1.3B-ct2-float16",
           "nllb-200-distilled-600M-ct2", "nllb-200-distilled-600M-ct2-int8", "nllb-200-distilled-600M-ct2-float16",
           "m2m100_1.2B-ct2", "m2m100_418M-ct2", "m2m100-12B-ct2",
           "m2m100_1.2B", "m2m100_418M",
           "mt5-zh-ja-en-trimmed",
           "mt5-zh-ja-en-trimmed-fine-tuned-v1",
           "ALMA-13B-GPTQ"]

def check_model_name(name):
    return any(allowed_name in name for allowed_name in _MODELS)

def download_model(
    modelConfig: ModelConfig,
    outputDir: Optional[str] = None,
    localFilesOnly: bool = False,
    cacheDir: Optional[str] = None,
):
    """"download_model" is referenced from the "utils.py" script
    of the "faster_whisper" project, authored by guillaumekln.

    Downloads a translation model from the Hugging Face Hub
    (originally an nllb-200 model from https://huggingface.co/facebook).

    Args:
      modelConfig: config of the model to download (facebook/nllb-distilled-600M,
        facebook/nllb-distilled-1.3B, facebook/nllb-1.3B, facebook/nllb-3.3B...).
      outputDir: Directory where the model should be saved. If not set, the model is saved in
        the cache directory.
      localFilesOnly: If True, avoid downloading the file and return the path to the local
        cached file if it exists.
      cacheDir: Path to the folder where cached files are stored.

    Returns:
      The path to the downloaded model.

    Raises:
      ValueError: if the model size is invalid.
    """
    if not check_model_name(modelConfig.name):
        raise ValueError(
            "Invalid model name '%s', expected one of: %s" % (modelConfig.name, ", ".join(_MODELS))
        )

    repoId = modelConfig.url  # "facebook/nllb-200-%s" %

    allowPatterns = [
        "config.json",
        "generation_config.json",
        "model.bin",
        "pytorch_model.bin",
        "pytorch_model.bin.index.json",
        "pytorch_model-*.bin",
        "pytorch_model-00001-of-00003.bin",
        "pytorch_model-00002-of-00003.bin",
        "pytorch_model-00003-of-00003.bin",
        "sentencepiece.bpe.model",
        "tokenizer.json",
        "tokenizer_config.json",
        "shared_vocabulary.txt",
        "shared_vocabulary.json",
        "special_tokens_map.json",
        "spiece.model",
        "vocab.json",  # m2m100
        "model.safetensors",
        "quantize_config.json",
        "tokenizer.model"
    ]

    kwargs = {
        "local_files_only": localFilesOnly,
        "allow_patterns": allowPatterns,
        # "tqdm_class": disabled_tqdm,
    }

    if outputDir is not None:
        kwargs["local_dir"] = outputDir
        kwargs["local_dir_use_symlinks"] = False

    if cacheDir is not None:
        kwargs["cache_dir"] = cacheDir

    try:
        return huggingface_hub.snapshot_download(repoId, **kwargs)
    except (
        huggingface_hub.utils.HfHubHTTPError,
        requests.exceptions.ConnectionError,
    ) as exception:
        # warnings.warn does not do printf-style formatting, so the message is
        # formatted explicitly here.
        warnings.warn(
            "An error occurred while synchronizing the model %s from the Hugging Face Hub:\n%s"
            % (repoId, exception)
        )
        warnings.warn(
            "Trying to load the model directly from the local cache, if it exists."
        )

        kwargs["local_files_only"] = True
        return huggingface_hub.snapshot_download(repoId, **kwargs)
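A usage sketch for the new ALMA path through TranslationModel. The module paths below are inferred from the imports inside translationModel.py itself and may not match the Space's exact layout; the ja-to-en pair is illustrative:

from src.config import ModelConfig
from src.translation.translationLangs import get_lang_from_whisper_code
from src.translation.translationModel import TranslationModel

config = ModelConfig(name="ALMA-13B-GPTQ/TheBloke", url="TheBloke/ALMA-13B-GPTQ", type="huggingface")
model = TranslationModel(modelConfig=config,
                         whisperLang=get_lang_from_whisper_code("ja"),
                         translationLang=get_lang_from_whisper_code("en"),
                         loadModel=True)  # downloads the GPTQ weights on first use
print(model.translation("こんにちは、世界。"))
model.release_vram()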
utils.py
ADDED
@@ -0,0 +1,314 @@
import textwrap
import unicodedata
import re

import zlib
from typing import Iterator, TextIO, Union
from tqdm import tqdm

import urllib.request


def exact_div(x, y):
    assert x % y == 0
    return x // y


def str2bool(string):
    str2val = {"True": True, "False": False}
    if string in str2val:
        return str2val[string]
    else:
        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")


def optional_int(string):
    return None if string == "None" else int(string)


def optional_float(string):
    return None if string == "None" else float(string)


def compression_ratio(text) -> float:
    return len(text) / len(zlib.compress(text.encode("utf-8")))


def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
    assert seconds >= 0, "non-negative timestamp expected"
    milliseconds = round(seconds * 1000.0)

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{hours_marker}{minutes:02d}:{seconds:02d}{fractionalSeperator}{milliseconds:03d}"


def write_txt(transcript: Iterator[dict], file: TextIO):
    for segment in transcript:
        print(segment['text'].strip(), file=file, flush=True)


def write_vtt(transcript: Iterator[dict], file: TextIO,
              maxLineWidth=None, highlight_words: bool = False):
    iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)

    print("WEBVTT\n", file=file)

    for segment in iterator:
        text = segment['text'].replace('-->', '->')

        print(
            f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
            f"{text}\n",
            file=file,
            flush=True,
        )

def write_srt(transcript: Iterator[dict], file: TextIO,
              maxLineWidth=None, highlight_words: bool = False):
    """
    Write a transcript to a file in SRT format.
    Example usage:
        from pathlib import Path
        from whisper.utils import write_srt
        result = transcribe(model, audio_path, temperature=temperature, **args)
        # save SRT
        audio_basename = Path(audio_path).stem
        with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
            write_srt(result["segments"], file=srt)
    """
    iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)

    for i, segment in enumerate(iterator, start=1):
        text = segment['text'].replace('-->', '->')

        # write srt lines
        print(
            f"{i}\n"
            f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
            f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}\n"
            f"{text}\n",
            file=file,
            flush=True,
        )

def write_srt_original(transcript: Iterator[dict], file: TextIO,
                       maxLineWidth=None, highlight_words: bool = False, bilingual: bool = False):
    """
    Write a transcript to a file in SRT format, keeping the original (untranslated)
    text of each segment, optionally followed by the translated text (bilingual).
    Example usage:
        from pathlib import Path
        from whisper.utils import write_srt
        result = transcribe(model, audio_path, temperature=temperature, **args)
        # save SRT
        audio_basename = Path(audio_path).stem
        with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
            write_srt(result["segments"], file=srt)
    """
    iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)

    for i, segment in enumerate(iterator, start=1):
        if "original" not in segment:
            continue

        original = segment['original'].replace('-->', '->')

        # write srt lines
        print(
            f"{i}\n"
            f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
            f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}",
            file=file,
            flush=True,
        )

        if original is not None: print(f"{original}\n",
                                       file=file,
                                       flush=True)

        if bilingual:
            text = segment['text'].replace('-->', '->')
            print(f"{text}\n",
                  file=file,
                  flush=True)

def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: int = None, highlight_words: bool = False):
    for segment in transcript:
        words: list = segment.get('words', [])

        # Append longest speaker ID if available
        segment_longest_speaker = segment.get('longest_speaker', None)

        # Yield the segment as-is when there is nothing to process
        if len(words) == 0 and (maxLineWidth is None or maxLineWidth < 0) and segment_longest_speaker is None:
            yield segment
            continue

        if segment_longest_speaker is not None:
            segment_longest_speaker = segment_longest_speaker.replace("SPEAKER", "S")

        subtitle_start = segment['start']
        subtitle_end = segment['end']
        text = segment['text'].strip()
        original_text = segment['original'].strip() if 'original' in segment else None

        if len(words) == 0:
            # Prepend the longest speaker ID if available
            if segment_longest_speaker is not None:
                text = f"({segment_longest_speaker}) {text}"

            result = {
                'start': subtitle_start,
                'end'  : subtitle_end,
                'text' : process_text(text, maxLineWidth)
            }
            if original_text is not None and len(original_text) > 0:
                result.update({'original': process_text(original_text, maxLineWidth)})
            yield result

            # We are done
            continue

        if segment_longest_speaker is not None:
            # Add the speaker ID as a zero-duration word at the beginning
            words.insert(0, {
                'start': subtitle_start,
                'end'  : subtitle_start,
                'word' : f"({segment_longest_speaker})"
            })

        text_words = [text] if not highlight_words and original_text is not None and len(original_text) > 0 else [ this_word["word"] for this_word in words ]
        subtitle_text = __join_words(text_words, maxLineWidth)

        # Iterate over the words in the segment
        if highlight_words:
            last = subtitle_start

            for i, this_word in enumerate(words):
                start = this_word['start']
                end = this_word['end']

                if last != start:
                    # Display the text up to this point
                    yield {
                        'start': last,
                        'end'  : start,
                        'text' : subtitle_text
                    }

                # Display the text with the current word highlighted
                yield {
                    'start': start,
                    'end'  : end,
                    'text' : __join_words(
                        [
                            {
                                "word": re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
                                if j == i
                                else word,
                                # The HTML tags <u> and </u> are not displayed,
                                # so they should not be counted in the word length
                                "length": len(word)
                            } for j, word in enumerate(text_words)
                        ], maxLineWidth)
                }
                last = end

            if last != subtitle_end:
                # Display the last part of the text
                yield {
                    'start': last,
                    'end'  : subtitle_end,
                    'text' : subtitle_text
                }

        # Just return the subtitle text
        else:
            result = {
                'start': subtitle_start,
                'end'  : subtitle_end,
                'text' : subtitle_text
            }
            if original_text is not None and len(original_text) > 0:
                result.update({'original': original_text})
            yield result

def __join_words(words: Iterator[Union[str, dict]], maxLineWidth: int = None):
    if maxLineWidth is None or maxLineWidth < 0:
        return " ".join(words)

    lines = []
    current_line = ""
    current_length = 0

    for entry in words:
        # Either accept a string or a dict with a 'word' and 'length' field
        if isinstance(entry, dict):
            word = entry['word']
            word_length = entry['length']
        else:
            word = entry
            word_length = len(word)

        if current_length > 0 and current_length + word_length > maxLineWidth:
            lines.append(current_line)
            current_line = ""
            current_length = 0

        current_length += word_length
        # The word will be prefixed with a space by Whisper, so we don't need to add one here
        current_line += word

    if len(current_line) > 0:
        lines.append(current_line)

    return "\n".join(lines)

def process_text(text: str, maxLineWidth=None):
    if (maxLineWidth is None or maxLineWidth < 0):
        return text

    lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
    return '\n'.join(lines)

def slugify(value, allow_unicode=False, is_lower=False):
    """
    Taken from https://github.com/django/django/blob/master/django/utils/text.py
    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
    dashes to single dashes. Remove characters that aren't alphanumerics,
    underscores, or hyphens. Convert to lowercase. Also strip leading and
    trailing whitespace, dashes, and underscores.
    """
    value = str(value)
    if allow_unicode:
        value = unicodedata.normalize('NFKC', value)
    else:
        value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
    if is_lower:
        value = value.lower()
    value = re.sub(r'[^\w\s-]', '', value.replace("/", "_").replace("⧸", "_"))
    return re.sub(r'[-\s]+', '-', value).strip('-_')

def download_file(url: str, destination: str):
    # Uses the standard-library urllib.request; the tqdm progress bar tracks
    # the Content-Length reported by the server.
    with urllib.request.urlopen(url) as source, open(destination, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")),
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))
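A usage sketch for the subtitle writers, using segment dicts shaped like Whisper's output (start/end in seconds; the data below is illustrative):

segments = [
    {"start": 0.0, "end": 2.5, "text": " Hello there."},
    {"start": 2.5, "end": 5.0, "text": " This is a longer line that will be wrapped."},
]

with open("example.srt", "w", encoding="utf-8") as srt:
    # maxLineWidth=40 triggers the wrapping path in __subtitle_preprocessor_iterator
    write_srt(segments, file=srt, maxLineWidth=40)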