avans06 committed
Commit 8077be2 (1 parent: 7bdbe2a)

Added support for translation models (NLLB, NLLB-CT2, MT5) to provide full translation capabilities for Whisper.

The interface now includes optional NLLB Model (for translation) and NLLB Language selectors. If they are not set, the translation feature is not activated.
__________________

Whisper’s ‘translate’ task only translates other languages into English; OpenAI does not guarantee translation between arbitrary language pairs. In those cases you can select an NLLB model to perform the translation instead. Note, however, that NLLB models run slowly, and the total completion time may be roughly twice as long as usual.

The larger an NLLB model’s parameter count, the better its translation quality is expected to be, but it also demands more computational resources and runs more slowly. The versions converted with CTranslate2 (ct2), on the other hand, require fewer resources and run faster.
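As a rough sketch of what the two variants involve (not part of this commit; the model IDs come from this commit’s config.json5, and the calls mirror load_model() and translation() in src/nllb/nllbModel.py, though exact usage may differ):

```python
# Minimal sketch contrasting a plain Hugging Face NLLB checkpoint with a
# CTranslate2 conversion. Outputs and timings will vary by hardware.
import ctranslate2
import huggingface_hub
import transformers

text = "Bonjour le monde"

# Plain Hugging Face NLLB: simplest to use, but heavier and slower.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M", src_lang="fra_Latn")
model = transformers.AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
translate = transformers.pipeline("translation", model=model, tokenizer=tokenizer,
                                  src_lang="fra_Latn", tgt_lang="eng_Latn")
print(translate(text, max_length=400)[0]["translation_text"])

# CTranslate2 conversion of the same model: lower memory use, faster inference.
model_path = huggingface_hub.snapshot_download("JustFrederik/nllb-200-distilled-600M-ct2")
ct2_model = ctranslate2.Translator(model_path, device="cpu", compute_type="auto")
source = tokenizer.convert_ids_to_tokens(tokenizer.encode(text))
output = ct2_model.translate_batch([source], target_prefix=[["eng_Latn"]])
target = output[0].hypotheses[0][1:]  # drop the leading target-language token
print(tokenizer.decode(tokenizer.convert_tokens_to_ids(target)))
```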

Currently, word-level timestamps cannot be enabled together with NLLB translation: word timestamps split the source text into individual words, and after translation the result no longer aligns at the word level.

The ‘mt5-zh-ja-en-trimmed’ model is fine-tuned from Google’s ‘mt5-base’ model. It translates relatively quickly, but it supports only three languages: Chinese, Japanese, and English.
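A small sketch of using it directly (assumptions: the model ID is taken from config.json5, and the "<src>2<tgt>: " task prefix mirrors what load_model() builds for mt5 in src/nllb/nllbModel.py):

```python
import transformers

model_id = "K024/mt5-zh-ja-en-trimmed"
tokenizer = transformers.T5Tokenizer.from_pretrained(model_id)  # needs spiece.model
model = transformers.MT5ForConditionalGeneration.from_pretrained(model_id)
translate = transformers.pipeline("text2text-generation", model=model, tokenizer=tokenizer)

# Directions are pairs of zh/ja/en, expressed as a prefix such as "zh2en: " or "en2ja: ".
output = translate("zh2en: 你好，世界", max_length=400, num_beams=4)
print(output[0]["generated_text"])
```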

README.md CHANGED
@@ -1,5 +1,5 @@
  ---
- title: Faster Whisper Webui
+ title: Faster Whisper Webui with translate
  emoji: ✨
  colorFrom: blue
  colorTo: purple
app.py CHANGED
@@ -5,8 +5,8 @@ from typing import Iterator, Union
  import argparse

  from io import StringIO
+ import time
  import os
- import pathlib
  import tempfile
  import zipfile
  import numpy as np
@@ -37,9 +37,14 @@ from src.utils import optional_int, slugify, write_srt, write_vtt
  from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
  from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
  from src.whisper.whisperFactory import create_whisper_container
+ from src.nllb.nllbModel import NllbModel
+ from src.nllb.nllbLangs import _TO_NLLB_LANG_CODE
+ from src.nllb.nllbLangs import get_nllb_lang_names
+ from src.nllb.nllbLangs import get_nllb_lang_from_name

  import shutil
  import zhconv
+ import tqdm

  # Configure more application defaults in config.json5

@@ -92,26 +97,26 @@ class WhisperTranscriber:
  print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")

  # Entry function for the simple tab
- def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
+ def transcribe_webui_simple(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
  vad, vadMergeWindow, vadMaxMergeSize,
  word_timestamps: bool = False, highlight_words: bool = False):
- return self.transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
+ return self.transcribe_webui_simple_progress(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
  vad, vadMergeWindow, vadMaxMergeSize,
  word_timestamps, highlight_words)

  # Entry function for the simple tab progress
- def transcribe_webui_simple_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
+ def transcribe_webui_simple_progress(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
  vad, vadMergeWindow, vadMaxMergeSize,
  word_timestamps: bool = False, highlight_words: bool = False,
  progress=gr.Progress()):

  vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, self.app_config.vad_padding, self.app_config.vad_prompt_window, self.app_config.vad_initial_prompt_mode)

- return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
+ return self.transcribe_webui(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task, vadOptions,
  word_timestamps=word_timestamps, highlight_words=highlight_words, progress=progress)

  # Entry function for the full tab
- def transcribe_webui_full(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
+ def transcribe_webui_full(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
  vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
  # Word timestamps
  word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
@@ -119,7 +124,7 @@ class WhisperTranscriber:
  condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
  compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float):

- return self.transcribe_webui_full_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
+ return self.transcribe_webui_full_progress(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
  vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
  word_timestamps, highlight_words, prepend_punctuations, append_punctuations,
  initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
@@ -127,7 +132,7 @@ class WhisperTranscriber:
  compression_ratio_threshold, logprob_threshold, no_speech_threshold)

  # Entry function for the full tab with progress
- def transcribe_webui_full_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
+ def transcribe_webui_full_progress(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
  vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
  # Word timestamps
  word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
@@ -144,21 +149,21 @@ class WhisperTranscriber:

  vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)

- return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
+ return self.transcribe_webui(modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task, vadOptions,
  initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
  condition_on_previous_text=condition_on_previous_text, fp16=fp16,
  compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
  word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
  progress=progress)

- def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
+ def transcribe_webui(self, modelName, languageName, nllbModelName, nllbLangName, urlData, multipleFiles, microphoneData, task,
  vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False,
  **decodeOptions: dict):
  try:
  sources = self.__get_source(urlData, multipleFiles, microphoneData)

  try:
- langObj = get_language_from_name(languageName)
+ whisper_lang = get_language_from_name(languageName)
  selectedLanguage = languageName.lower() if languageName is not None and len(languageName) > 0 else None
  selectedModel = modelName if modelName is not None else "base"

@@ -166,6 +171,12 @@ class WhisperTranscriber:
  model_name=selectedModel, compute_type=self.app_config.compute_type,
  cache=self.model_cache, models=self.app_config.models)

+ nllb_lang = get_nllb_lang_from_name(nllbLangName)
+ selectedNllbModelName = nllbModelName if nllbModelName is not None and len(nllbModelName) > 0 else "nllb-200-distilled-600M/facebook"
+ selectedNllbModel = next((modelConfig for modelConfig in self.app_config.nllb_models if modelConfig.name == selectedNllbModelName), None)
+
+ nllb_model = NllbModel(model_config=selectedNllbModel, whisper_lang=whisper_lang, nllb_lang=nllb_lang) # load_model=True
+
  # Result
  download = []
  zip_file_lookup = {}
@@ -208,7 +219,7 @@ class WhisperTranscriber:
  # Update progress
  current_progress += source_audio_duration

- source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory, highlight_words)
+ source_download, source_text, source_vtt = self.write_result(result, nllb_model, filePrefix, outputDirectory, highlight_words)

  if len(sources) > 1:
  # Add new line separators
@@ -252,30 +263,19 @@ class WhisperTranscriber:
  return download, text, vtt

  finally:
- if languageName == "Chinese":
- for file_path in source_download:
- try:
- with open(file_path, "r+", encoding="utf-8") as source:
- content = source.read()
- content = zhconv.convert(content, "zh-tw")
- source.seek(0)
- source.write(content)
- except Exception as e:
- # Ignore error - it's just a cleanup
- print("Error converting Traditional Chinese with download source file: \n" + file_path + ", \n" + str(e))
-
  # Cleanup source
  if self.deleteUploadedFiles:
  for source in sources:
  if self.app_config.merge_subtitle_with_sources and self.app_config.output_dir is not None and len(source_download) > 0:
- print("merge subtitle(srt) with source file [" + source.source_name + "]")
+ print("\nmerge subtitle(srt) with source file [" + source.source_name + "]\n")
  outRsult = ""
  try:
  srt_path = source_download[0]
  save_path = os.path.join(self.app_config.output_dir, source.source_name)
  save_without_ext, ext = os.path.splitext(save_path)
- lang_ext = "." + langObj.code if langObj is not None else ""
- output_with_srt = save_without_ext + lang_ext + ext
+ source_lang = "." + whisper_lang.code if whisper_lang is not None else ""
+ translate_lang = "." + nllb_lang.code if nllb_lang is not None else ""
+ output_with_srt = save_without_ext + source_lang + translate_lang + ext

  #ffmpeg -i "input.mp4" -i "input.srt" -c copy -c:s mov_text output.mp4
  input_file = ffmpeg.input(source.source_path)
@@ -435,20 +435,41 @@ class WhisperTranscriber:

  return config

- def write_result(self, result: dict, source_name: str, output_dir: str, highlight_words: bool = False):
+ def write_result(self, result: dict, nllb_model: NllbModel, source_name: str, output_dir: str, highlight_words: bool = False):
  if not os.path.exists(output_dir):
  os.makedirs(output_dir)

  text = result["text"]
+ segments = result["segments"]
  language = result["language"]
  languageMaxLineWidth = self.__get_max_line_width(language)

- print("Max line width " + str(languageMaxLineWidth))
+ if nllb_model.nllb_lang is not None:
+ try:
+ pbar = tqdm.tqdm(total=len(segments))
+ perf_start_time = time.perf_counter()
+ nllb_model.load_model()
+ for idx, segment in enumerate(segments):
+ seg_text = segment["text"]
+ if language == "zh":
+ segment["text"] = zhconv.convert(seg_text, "zh-tw")
+ if nllb_model.nllb_lang is not None:
+ segment["text"] = nllb_model.translation(seg_text)
+ pbar.update(1)
+
+ nllb_model.release_vram()
+ perf_end_time = time.perf_counter()
+ print("\n\nprocess segments took {} seconds.\n\n".format(perf_end_time - perf_start_time))
+ except Exception as e:
+ # Ignore error - it's just a cleanup
+ print("Error process segments: " + str(e))
+
+ print("Max line width " + str(languageMaxLineWidth) + " for language:" + language)
  vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth, highlight_words=highlight_words)
  srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth, highlight_words=highlight_words)
  json_result = json.dumps(result, indent=4, ensure_ascii=False)

- if language == "zh":
+ if language == "zh" or (nllb_model.nllb_lang is not None and nllb_model.nllb_lang.code == "zho_Hant"):
  vtt = zhconv.convert(vtt, "zh-tw")
  srt = zhconv.convert(srt, "zh-tw")
  text = zhconv.convert(text, "zh-tw")
@@ -541,12 +562,29 @@ def create_ui(app_config: ApplicationConfig):
  ui_description += "\n\n" + "Max audio file length: " + str(app_config.input_audio_max_duration) + " s"

  ui_article = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)."
+ ui_article += "\n\nWhisper's Task 'translate' only implements the functionality of translating other languages into English. "
+ ui_article += "OpenAI does not guarantee translations between arbitrary languages. In such cases, you can choose to use the NLLB Model to implement the translation task. "
+ ui_article += "However, it's important to note that the NLLB Model runs slowly, and the completion time may be twice as long as usual. "
+ ui_article += "\n\nThe larger the parameters of the NLLB model, the better its performance is expected to be. "
+ ui_article += "However, it also requires higher computational resources, making it slower to operate. "
+ ui_article += "On the other hand, the version converted from ct2 (CTranslate2) requires lower resources and operates at a faster speed."
+ ui_article += "\n\nCurrently, enabling word-level timestamps cannot be used in conjunction with NLLB Model translation "
+ ui_article += "because Word Timestamps will split the source text, and after translation, it becomes a non-word-level string. "
+ ui_article += "\n\nThe 'mt5-zh-ja-en-trimmed' model is finetuned from Google's 'mt5-base' model. "
+ ui_article += "This model has a relatively good translation speed, but it only supports three languages: Chinese, Japanese, and English. "

  whisper_models = app_config.get_model_names()
+ nllb_models = app_config.get_nllb_model_names()
-
- common_inputs = lambda : [
- gr.Dropdown(choices=whisper_models, value=app_config.default_model_name, label="Model"),
- gr.Dropdown(choices=sorted(get_language_names()), label="Language", value=app_config.language),
+
+ common_whisper_inputs = lambda : [
+ gr.Dropdown(label="Whisper Model (for audio)", choices=whisper_models, value=app_config.default_model_name),
+ gr.Dropdown(label="Whisper Language", choices=sorted(get_language_names()), value=app_config.language),
+ ]
+ common_nllb_inputs = lambda : [
+ gr.Dropdown(label="NLLB Model (for translate)", choices=nllb_models),
+ gr.Dropdown(label="NLLB Language", choices=sorted(get_nllb_lang_names())),
+ ]
+ common_audio_inputs = lambda : [
  gr.Text(label="URL (YouTube, etc.)"),
  gr.File(label="Upload Files", file_count="multiple"),
  gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
@@ -579,7 +617,13 @@ def create_ui(app_config: ApplicationConfig):
  with gr.Row():
  with gr.Column():
  simple_submit = gr.Button("Submit", variant="primary")
- simple_input = common_inputs() + common_vad_inputs() + common_word_timestamps_inputs()
+ with gr.Column():
+ with gr.Row():
+ simple_input = common_whisper_inputs()
+ with gr.Row():
+ simple_input += common_nllb_inputs()
+ with gr.Column():
+ simple_input += common_audio_inputs() + common_vad_inputs() + common_word_timestamps_inputs()
  with gr.Column():
  simple_output = common_output()
  simple_flag = gr.Button("Flag")
@@ -602,27 +646,33 @@ def create_ui(app_config: ApplicationConfig):
  with gr.Row():
  with gr.Column():
  full_submit = gr.Button("Submit", variant="primary")
- full_input1 = common_inputs() + common_vad_inputs() + [
- gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding),
- gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window),
- gr.Dropdown(choices=VAD_INITIAL_PROMPT_MODE_VALUES, label="VAD - Initial Prompt Mode")]
-
- full_input2 = common_word_timestamps_inputs() + [
- gr.Text(label="Word Timestamps - Prepend Punctuations", value=app_config.prepend_punctuations),
- gr.Text(label="Word Timestamps - Append Punctuations", value=app_config.append_punctuations),
- gr.TextArea(label="Initial Prompt"),
- gr.Number(label="Temperature", value=app_config.temperature),
- gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0),
- gr.Number(label="Beam Size - Zero temperature", value=app_config.beam_size, precision=0),
- gr.Number(label="Patience - Zero temperature", value=app_config.patience),
- gr.Number(label="Length Penalty - Any temperature", value=app_config.length_penalty),
- gr.Text(label="Suppress Tokens - Comma-separated list of token IDs", value=app_config.suppress_tokens),
- gr.Checkbox(label="Condition on previous text", value=app_config.condition_on_previous_text),
- gr.Checkbox(label="FP16", value=app_config.fp16),
- gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
- gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
- gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
- gr.Number(label="No speech threshold", value=app_config.no_speech_threshold)]
+ with gr.Column():
+ with gr.Row():
+ full_input1 = common_whisper_inputs()
+ with gr.Row():
+ full_input1 += common_nllb_inputs()
+ with gr.Column():
+ full_input1 += common_audio_inputs() + common_vad_inputs() + [
+ gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding),
+ gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window),
+ gr.Dropdown(choices=VAD_INITIAL_PROMPT_MODE_VALUES, label="VAD - Initial Prompt Mode")]
+
+ full_input2 = common_word_timestamps_inputs() + [
+ gr.Text(label="Word Timestamps - Prepend Punctuations", value=app_config.prepend_punctuations),
+ gr.Text(label="Word Timestamps - Append Punctuations", value=app_config.append_punctuations),
+ gr.TextArea(label="Initial Prompt"),
+ gr.Number(label="Temperature", value=app_config.temperature),
+ gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0),
+ gr.Number(label="Beam Size - Zero temperature", value=app_config.beam_size, precision=0),
+ gr.Number(label="Patience - Zero temperature", value=app_config.patience),
+ gr.Number(label="Length Penalty - Any temperature", value=app_config.length_penalty),
+ gr.Text(label="Suppress Tokens - Comma-separated list of token IDs", value=app_config.suppress_tokens),
+ gr.Checkbox(label="Condition on previous text", value=app_config.condition_on_previous_text),
+ gr.Checkbox(label="FP16", value=app_config.fp16),
+ gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
+ gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
+ gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
+ gr.Number(label="No speech threshold", value=app_config.no_speech_threshold)]

  with gr.Column():
  full_output = common_output()
@@ -654,6 +704,7 @@ def create_ui(app_config: ApplicationConfig):
 if __name__ == '__main__':
  default_app_config = ApplicationConfig.create_default()
  whisper_models = default_app_config.get_model_names()
+ nllb_models = default_app_config.get_nllb_model_names()

  # Environment variable overrides
  default_whisper_implementation = os.environ.get("WHISPER_IMPLEMENTATION", default_app_config.whisper_implementation)
@@ -707,6 +758,14 @@ if __name__ == '__main__':

  updated_config = default_app_config.update(**args)

+ #updated_config.whisper_implementation = "faster-whisper"
+ #updated_config.input_audio_max_duration = -1
+ #updated_config.default_model_name = "large-v2"
+ #updated_config.output_dir = "output"
+ #updated_config.vad_max_merge_size = 90
+ #updated_config.merge_subtitle_with_sources = True
+ #updated_config.autolaunch = True
+
  if (threads := args.pop("threads")) > 0:
  torch.set_num_threads(threads)

config.json5 CHANGED
@@ -43,6 +43,102 @@
  // "url": "https://example.com/path/to/model",
  //}
  ],
+ "nllb_models": [
+ {
+ "name": "nllb-200-distilled-1.3B-ct2fast:int8_float16/michaelfeil",
+ "url": "michaelfeil/ct2fast-nllb-200-distilled-1.3B",
+ "type": "huggingface"
+ },
+ {
+ "name": "nllb-200-3.3B-ct2fast:int8_float16/michaelfeil",
+ "url": "michaelfeil/ct2fast-nllb-200-3.3B",
+ "type": "huggingface"
+ },
+ {
+ "name": "nllb-200-1.3B-ct2:float16/JustFrederik",
+ "url": "JustFrederik/nllb-200-1.3B-ct2-float16",
+ "type": "huggingface"
+ },
+ {
+ "name": "nllb-200-distilled-1.3B-ct2:float16/JustFrederik",
+ "url": "JustFrederik/nllb-200-distilled-1.3B-ct2-float16",
+ "type": "huggingface"
+ },
+ {
+ "name": "nllb-200-1.3B-ct2:int8/JustFrederik",
+ "url": "JustFrederik/nllb-200-1.3B-ct2-int8",
+ "type": "huggingface"
+ },
+ {
+ "name": "nllb-200-distilled-1.3B-ct2:int8/JustFrederik",
+ "url": "JustFrederik/nllb-200-distilled-1.3B-ct2-int8",
+ "type": "huggingface"
+ },
+ {
+ "name": "mt5-zh-ja-en-trimmed/K024",
+ "url": "K024/mt5-zh-ja-en-trimmed",
+ "type": "huggingface"
+ },
+ {
+ "name": "mt5-zh-ja-en-trimmed-fine-tuned-v1/engmatic-earth",
+ "url": "engmatic-earth/mt5-zh-ja-en-trimmed-fine-tuned-v1",
+ "type": "huggingface"
+ },
+ {
+ "name": "nllb-200-distilled-600M/facebook",
+ "url": "facebook/nllb-200-distilled-600M",
+ "type": "huggingface"
+ },
+ {
+ "name": "nllb-200-distilled-600M-ct2/JustFrederik",
+ "url": "JustFrederik/nllb-200-distilled-600M-ct2",
+ "type": "huggingface"
+ },
+ {
+ "name": "nllb-200-distilled-600M-ct2:float16/JustFrederik",
+ "url": "JustFrederik/nllb-200-distilled-600M-ct2-float16",
+ "type": "huggingface"
+ },
+ {
+ "name": "nllb-200-distilled-600M-ct2:int8/JustFrederik",
+ "url": "JustFrederik/nllb-200-distilled-600M-ct2-int8",
+ "type": "huggingface"
+ },
+ // Uncomment to add official Facebook 1.3B and 3.3B model
+ // The official Facebook 1.3B and 3.3B model files are too large,
+ // and to avoid occupying too much disk space on Hugging Face's free spaces,
+ // these models are not included in the config.
+ //{
+ // "name": "nllb-200-distilled-1.3B/facebook",
+ // "url": "facebook/nllb-200-distilled-1.3B",
+ // "type": "huggingface"
+ //},
+ //{
+ // "name": "nllb-200-1.3B/facebook",
+ // "url": "facebook/nllb-200-1.3B",
+ // "type": "huggingface"
+ //},
+ //{
+ // "name": "nllb-200-3.3B/facebook",
+ // "url": "facebook/nllb-200-3.3B",
+ // "type": "huggingface"
+ //},
+ //{
+ // "name": "nllb-200-distilled-1.3B-ct2/JustFrederik",
+ // "url": "JustFrederik/nllb-200-distilled-1.3B-ct2",
+ // "type": "huggingface"
+ //},
+ //{
+ // "name": "nllb-200-1.3B-ct2/JustFrederik",
+ // "url": "JustFrederik/nllb-200-1.3B-ct2",
+ // "type": "huggingface"
+ //},
+ //{
+ // "name": "nllb-200-3.3B-ct2:float16/JustFrederik",
+ // "url": "JustFrederik/nllb-200-3.3B-ct2-float16",
+ // "type": "huggingface"
+ //},
+ ],
  // Configuration options that will be used if they are not specified in the command line arguments.

  // * WEBUI options *
requirements-fasterWhisper.txt CHANGED
@@ -1,4 +1,4 @@
- ctranslate2
+ ctranslate2>=3.16.0
  faster-whisper
  ffmpeg-python==0.2.0
  gradio==3.36.0
@@ -7,4 +7,5 @@ json5
  torch
  torchaudio
  more_itertools
- zhconv
+ zhconv
+ sentencepiece
requirements-whisper.txt CHANGED
@@ -7,4 +7,5 @@ yt-dlp
  torchaudio
  altair
  json5
- zhconv
+ zhconv
+ sentencepiece
requirements.txt CHANGED
@@ -1,4 +1,4 @@
- ctranslate2
+ ctranslate2>=3.16.0
  faster-whisper
  ffmpeg-python==0.2.0
  gradio==3.36.0
@@ -7,4 +7,5 @@ json5
  torch
  torchaudio
  more_itertools
- zhconv
+ zhconv
+ sentencepiece
src/config.py CHANGED
@@ -47,11 +47,11 @@ class VadInitialPromptMode(Enum):
  return None

 class ApplicationConfig:
- def __init__(self, models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
+ def __init__(self, models: List[ModelConfig] = [], nllb_models: List[ModelConfig] = [], input_audio_max_duration: int = 600,
  share: bool = False, server_name: str = None, server_port: int = 7860,
  queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
  whisper_implementation: str = "whisper",
- default_model_name: str = "medium", default_vad: str = "silero-vad",
+ default_model_name: str = "medium", default_nllb_model_name: str = "distilled-600M", default_vad: str = "silero-vad",
  vad_parallel_devices: str = "", vad_cpu_cores: int = 1, vad_process_timeout: int = 1800,
  auto_parallel: bool = False, output_dir: str = None,
  model_dir: str = None, device: str = None,
@@ -72,6 +72,7 @@ class ApplicationConfig:
  highlight_words: bool = False):

  self.models = models
+ self.nllb_models = nllb_models

  # WebUI settings
  self.input_audio_max_duration = input_audio_max_duration
@@ -83,6 +84,7 @@ class ApplicationConfig:

  self.whisper_implementation = whisper_implementation
  self.default_model_name = default_model_name
+ self.default_nllb_model_name = default_nllb_model_name
  self.default_vad = default_vad
  self.vad_parallel_devices = vad_parallel_devices
  self.vad_cpu_cores = vad_cpu_cores
@@ -124,6 +126,9 @@ class ApplicationConfig:
  def get_model_names(self):
  return [ x.name for x in self.models ]

+ def get_nllb_model_names(self):
+ return [ x.name for x in self.nllb_models ]
+
  def update(self, **new_values):
  result = ApplicationConfig(**self.__dict__)

@@ -148,7 +153,9 @@ class ApplicationConfig:
  # Load using json5
  data = json5.load(f)
  data_models = data.pop("models", [])
-
+ data_nllb_models = data.pop("nllb_models", [])
+
  models = [ ModelConfig(**x) for x in data_models ]
+ nllb_models = [ ModelConfig(**x) for x in data_nllb_models ]

- return ApplicationConfig(models, **data)
+ return ApplicationConfig(models, nllb_models, **data)
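A minimal sketch (not from the commit) of how the new "nllb_models" section travels from config.json5 to the UI dropdown; it mirrors the config-loading code in this diff:

```python
import json5
from src.config import ApplicationConfig, ModelConfig

with open("config.json5", "r", encoding="utf-8") as f:
    data = json5.load(f)

# Each entry in "models"/"nllb_models" is a dict of name/url/type keys.
models = [ModelConfig(**x) for x in data.pop("models", [])]
nllb_models = [ModelConfig(**x) for x in data.pop("nllb_models", [])]
config = ApplicationConfig(models, nllb_models, **data)

print(config.get_nllb_model_names())  # the names offered in the "NLLB Model" dropdown
```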
src/nllb/nllbLangs.py ADDED
@@ -0,0 +1,251 @@
+ class NllbLang():
+ def __init__(self, code, name, code_whisper=None, name_whisper=None):
+ self.code = code
+ self.name = name
+ self.code_whisper = code_whisper
+ self.name_whisper = name_whisper
+
+ def __str__(self):
+ return "Language(code={}, name={})".format(self.code, self.name)
+
+ NLLB_LANGS = [
+ NllbLang('ace_Arab', 'Acehnese (Arabic script)'),
+ NllbLang('ace_Latn', 'Acehnese (Latin script)'),
+ NllbLang('acm_Arab', 'Mesopotamian Arabic', 'ar', 'Arabic'),
+ NllbLang('acq_Arab', 'Ta’izzi-Adeni Arabic', 'ar', 'Arabic'),
+ NllbLang('aeb_Arab', 'Tunisian Arabic'),
+ NllbLang('afr_Latn', 'Afrikaans', 'am', 'Amharic'),
+ NllbLang('ajp_Arab', 'South Levantine Arabic', 'ar', 'Arabic'),
+ NllbLang('aka_Latn', 'Akan'),
+ NllbLang('amh_Ethi', 'Amharic'),
+ NllbLang('apc_Arab', 'North Levantine Arabic', 'ar', 'Arabic'),
+ NllbLang('arb_Arab', 'Modern Standard Arabic', 'ar', 'Arabic'),
+ NllbLang('arb_Latn', 'Modern Standard Arabic (Romanized)'),
+ NllbLang('ars_Arab', 'Najdi Arabic', 'ar', 'Arabic'),
+ NllbLang('ary_Arab', 'Moroccan Arabic', 'ar', 'Arabic'),
+ NllbLang('arz_Arab', 'Egyptian Arabic', 'ar', 'Arabic'),
+ NllbLang('asm_Beng', 'Assamese', 'as', 'Assamese'),
+ NllbLang('ast_Latn', 'Asturian'),
+ NllbLang('awa_Deva', 'Awadhi'),
+ NllbLang('ayr_Latn', 'Central Aymara'),
+ NllbLang('azb_Arab', 'South Azerbaijani', 'az', 'Azerbaijani'),
+ NllbLang('azj_Latn', 'North Azerbaijani', 'az', 'Azerbaijani'),
+ NllbLang('bak_Cyrl', 'Bashkir', 'ba', 'Bashkir'),
+ NllbLang('bam_Latn', 'Bambara'),
+ NllbLang('ban_Latn', 'Balinese'),
+ NllbLang('bel_Cyrl', 'Belarusian', 'be', 'Belarusian'),
+ NllbLang('bem_Latn', 'Bemba'),
+ NllbLang('ben_Beng', 'Bengali', 'bn', 'Bengali'),
+ NllbLang('bho_Deva', 'Bhojpuri'),
+ NllbLang('bjn_Arab', 'Banjar (Arabic script)'),
+ NllbLang('bjn_Latn', 'Banjar (Latin script)'),
+ NllbLang('bod_Tibt', 'Standard Tibetan', 'bo', 'Tibetan'),
+ NllbLang('bos_Latn', 'Bosnian', 'bs', 'Bosnian'),
+ NllbLang('bug_Latn', 'Buginese'),
+ NllbLang('bul_Cyrl', 'Bulgarian', 'bg', 'Bulgarian'),
+ NllbLang('cat_Latn', 'Catalan', 'ca', 'Catalan'),
+ NllbLang('ceb_Latn', 'Cebuano'),
+ NllbLang('ces_Latn', 'Czech', 'cs', 'Czech'),
+ NllbLang('cjk_Latn', 'Chokwe'),
+ NllbLang('ckb_Arab', 'Central Kurdish'),
+ NllbLang('crh_Latn', 'Crimean Tatar'),
+ NllbLang('cym_Latn', 'Welsh', 'cy', 'Welsh'),
+ NllbLang('dan_Latn', 'Danish', 'da', 'Danish'),
+ NllbLang('deu_Latn', 'German', 'de', 'German'),
+ NllbLang('dik_Latn', 'Southwestern Dinka'),
+ NllbLang('dyu_Latn', 'Dyula'),
+ NllbLang('dzo_Tibt', 'Dzongkha'),
+ NllbLang('ell_Grek', 'Greek', 'el', 'Greek'),
+ NllbLang('eng_Latn', 'English', 'en', 'English'),
+ NllbLang('epo_Latn', 'Esperanto'),
+ NllbLang('est_Latn', 'Estonian', 'et', 'Estonian'),
+ NllbLang('eus_Latn', 'Basque', 'eu', 'Basque'),
+ NllbLang('ewe_Latn', 'Ewe'),
+ NllbLang('fao_Latn', 'Faroese', 'fo', 'Faroese'),
+ NllbLang('fij_Latn', 'Fijian'),
+ NllbLang('fin_Latn', 'Finnish', 'fi', 'Finnish'),
+ NllbLang('fon_Latn', 'Fon'),
+ NllbLang('fra_Latn', 'French', 'fr', 'French'),
+ NllbLang('fur_Latn', 'Friulian'),
+ NllbLang('fuv_Latn', 'Nigerian Fulfulde'),
+ NllbLang('gla_Latn', 'Scottish Gaelic'),
+ NllbLang('gle_Latn', 'Irish'),
+ NllbLang('glg_Latn', 'Galician', 'gl', 'Galician'),
+ NllbLang('grn_Latn', 'Guarani'),
+ NllbLang('guj_Gujr', 'Gujarati', 'gu', 'Gujarati'),
+ NllbLang('hat_Latn', 'Haitian Creole', 'ht', 'Haitian creole'),
+ NllbLang('hau_Latn', 'Hausa', 'ha', 'Hausa'),
+ NllbLang('heb_Hebr', 'Hebrew', 'he', 'Hebrew'),
+ NllbLang('hin_Deva', 'Hindi', 'hi', 'Hindi'),
+ NllbLang('hne_Deva', 'Chhattisgarhi'),
+ NllbLang('hrv_Latn', 'Croatian', 'hr', 'Croatian'),
+ NllbLang('hun_Latn', 'Hungarian', 'hu', 'Hungarian'),
+ NllbLang('hye_Armn', 'Armenian', 'hy', 'Armenian'),
+ NllbLang('ibo_Latn', 'Igbo'),
+ NllbLang('ilo_Latn', 'Ilocano'),
+ NllbLang('ind_Latn', 'Indonesian', 'id', 'Indonesian'),
+ NllbLang('isl_Latn', 'Icelandic', 'is', 'Icelandic'),
+ NllbLang('ita_Latn', 'Italian', 'it', 'Italian'),
+ NllbLang('jav_Latn', 'Javanese', 'jw', 'Javanese'),
+ NllbLang('jpn_Jpan', 'Japanese', 'ja', 'Japanese'),
+ NllbLang('kab_Latn', 'Kabyle'),
+ NllbLang('kac_Latn', 'Jingpho'),
+ NllbLang('kam_Latn', 'Kamba'),
+ NllbLang('kan_Knda', 'Kannada', 'kn', 'Kannada'),
+ NllbLang('kas_Arab', 'Kashmiri (Arabic script)'),
+ NllbLang('kas_Deva', 'Kashmiri (Devanagari script)'),
+ NllbLang('kat_Geor', 'Georgian', 'ka', 'Georgian'),
+ NllbLang('knc_Arab', 'Central Kanuri (Arabic script)'),
+ NllbLang('knc_Latn', 'Central Kanuri (Latin script)'),
+ NllbLang('kaz_Cyrl', 'Kazakh', 'kk', 'Kazakh'),
+ NllbLang('kbp_Latn', 'Kabiyè'),
+ NllbLang('kea_Latn', 'Kabuverdianu'),
+ NllbLang('khm_Khmr', 'Khmer', 'km', 'Khmer'),
+ NllbLang('kik_Latn', 'Kikuyu'),
+ NllbLang('kin_Latn', 'Kinyarwanda'),
+ NllbLang('kir_Cyrl', 'Kyrgyz'),
+ NllbLang('kmb_Latn', 'Kimbundu'),
+ NllbLang('kmr_Latn', 'Northern Kurdish'),
+ NllbLang('kon_Latn', 'Kikongo'),
+ NllbLang('kor_Hang', 'Korean', 'ko', 'Korean'),
+ NllbLang('lao_Laoo', 'Lao', 'lo', 'Lao'),
+ NllbLang('lij_Latn', 'Ligurian'),
+ NllbLang('lim_Latn', 'Limburgish'),
+ NllbLang('lin_Latn', 'Lingala', 'ln', 'Lingala'),
+ NllbLang('lit_Latn', 'Lithuanian', 'lt', 'Lithuanian'),
+ NllbLang('lmo_Latn', 'Lombard'),
+ NllbLang('ltg_Latn', 'Latgalian'),
+ NllbLang('ltz_Latn', 'Luxembourgish', 'lb', 'Luxembourgish'),
+ NllbLang('lua_Latn', 'Luba-Kasai'),
+ NllbLang('lug_Latn', 'Ganda'),
+ NllbLang('luo_Latn', 'Luo'),
+ NllbLang('lus_Latn', 'Mizo'),
+ NllbLang('lvs_Latn', 'Standard Latvian', 'lv', 'Latvian'),
+ NllbLang('mag_Deva', 'Magahi'),
+ NllbLang('mai_Deva', 'Maithili'),
+ NllbLang('mal_Mlym', 'Malayalam', 'ml', 'Malayalam'),
+ NllbLang('mar_Deva', 'Marathi', 'mr', 'Marathi'),
+ NllbLang('min_Arab', 'Minangkabau (Arabic script)'),
+ NllbLang('min_Latn', 'Minangkabau (Latin script)'),
+ NllbLang('mkd_Cyrl', 'Macedonian', 'mk', 'Macedonian'),
+ NllbLang('plt_Latn', 'Plateau Malagasy', 'mg', 'Malagasy'),
+ NllbLang('mlt_Latn', 'Maltese', 'mt', 'Maltese'),
+ NllbLang('mni_Beng', 'Meitei (Bengali script)'),
+ NllbLang('khk_Cyrl', 'Halh Mongolian', 'mn', 'Mongolian'),
+ NllbLang('mos_Latn', 'Mossi'),
+ NllbLang('mri_Latn', 'Maori', 'mi', 'Maori'),
+ NllbLang('mya_Mymr', 'Burmese', 'my', 'Myanmar'),
+ NllbLang('nld_Latn', 'Dutch', 'nl', 'Dutch'),
+ NllbLang('nno_Latn', 'Norwegian Nynorsk', 'nn', 'Nynorsk'),
+ NllbLang('nob_Latn', 'Norwegian Bokmål', 'no', 'Norwegian'),
+ NllbLang('npi_Deva', 'Nepali', 'ne', 'Nepali'),
+ NllbLang('nso_Latn', 'Northern Sotho'),
+ NllbLang('nus_Latn', 'Nuer'),
+ NllbLang('nya_Latn', 'Nyanja'),
+ NllbLang('oci_Latn', 'Occitan', 'oc', 'Occitan'),
+ NllbLang('gaz_Latn', 'West Central Oromo'),
+ NllbLang('ory_Orya', 'Odia'),
+ NllbLang('pag_Latn', 'Pangasinan'),
+ NllbLang('pan_Guru', 'Eastern Panjabi', 'pa', 'Punjabi'),
+ NllbLang('pap_Latn', 'Papiamento'),
+ NllbLang('pes_Arab', 'Western Persian', 'fa', 'Persian'),
+ NllbLang('pol_Latn', 'Polish', 'pl', 'Polish'),
+ NllbLang('por_Latn', 'Portuguese', 'pt', 'Portuguese'),
+ NllbLang('prs_Arab', 'Dari'),
+ NllbLang('pbt_Arab', 'Southern Pashto', 'ps', 'Pashto'),
+ NllbLang('quy_Latn', 'Ayacucho Quechua'),
+ NllbLang('ron_Latn', 'Romanian', 'ro', 'Romanian'),
+ NllbLang('run_Latn', 'Rundi'),
+ NllbLang('rus_Cyrl', 'Russian', 'ru', 'Russian'),
+ NllbLang('sag_Latn', 'Sango'),
+ NllbLang('san_Deva', 'Sanskrit', 'sa', 'Sanskrit'),
+ NllbLang('sat_Olck', 'Santali'),
+ NllbLang('scn_Latn', 'Sicilian'),
+ NllbLang('shn_Mymr', 'Shan'),
+ NllbLang('sin_Sinh', 'Sinhala', 'si', 'Sinhala'),
+ NllbLang('slk_Latn', 'Slovak', 'sk', 'Slovak'),
+ NllbLang('slv_Latn', 'Slovenian', 'sl', 'Slovenian'),
+ NllbLang('smo_Latn', 'Samoan'),
+ NllbLang('sna_Latn', 'Shona', 'sn', 'Shona'),
+ NllbLang('snd_Arab', 'Sindhi', 'sd', 'Sindhi'),
+ NllbLang('som_Latn', 'Somali', 'so', 'Somali'),
+ NllbLang('sot_Latn', 'Southern Sotho'),
+ NllbLang('spa_Latn', 'Spanish', 'es', 'Spanish'),
+ NllbLang('als_Latn', 'Tosk Albanian', 'sq', 'Albanian'),
+ NllbLang('srd_Latn', 'Sardinian'),
+ NllbLang('srp_Cyrl', 'Serbian', 'sr', 'Serbian'),
+ NllbLang('ssw_Latn', 'Swati'),
+ NllbLang('sun_Latn', 'Sundanese', 'su', 'Sundanese'),
+ NllbLang('swe_Latn', 'Swedish', 'sv', 'Swedish'),
+ NllbLang('swh_Latn', 'Swahili', 'sw', 'Swahili'),
+ NllbLang('szl_Latn', 'Silesian'),
+ NllbLang('tam_Taml', 'Tamil', 'ta', 'Tamil'),
+ NllbLang('tat_Cyrl', 'Tatar', 'tt', 'Tatar'),
+ NllbLang('tel_Telu', 'Telugu', 'te', 'Telugu'),
+ NllbLang('tgk_Cyrl', 'Tajik', 'tg', 'Tajik'),
+ NllbLang('tgl_Latn', 'Tagalog', 'tl', 'Tagalog'),
+ NllbLang('tha_Thai', 'Thai', 'th', 'Thai'),
+ NllbLang('tir_Ethi', 'Tigrinya'),
+ NllbLang('taq_Latn', 'Tamasheq (Latin script)'),
+ NllbLang('taq_Tfng', 'Tamasheq (Tifinagh script)'),
+ NllbLang('tpi_Latn', 'Tok Pisin'),
+ NllbLang('tsn_Latn', 'Tswana'),
+ NllbLang('tso_Latn', 'Tsonga'),
+ NllbLang('tuk_Latn', 'Turkmen', 'tk', 'Turkmen'),
+ NllbLang('tum_Latn', 'Tumbuka'),
+ NllbLang('tur_Latn', 'Turkish', 'tr', 'Turkish'),
+ NllbLang('twi_Latn', 'Twi'),
+ NllbLang('tzm_Tfng', 'Central Atlas Tamazight'),
+ NllbLang('uig_Arab', 'Uyghur'),
+ NllbLang('ukr_Cyrl', 'Ukrainian', 'uk', 'Ukrainian'),
+ NllbLang('umb_Latn', 'Umbundu'),
+ NllbLang('urd_Arab', 'Urdu', 'ur', 'Urdu'),
+ NllbLang('uzn_Latn', 'Northern Uzbek', 'uz', 'Uzbek'),
+ NllbLang('vec_Latn', 'Venetian'),
+ NllbLang('vie_Latn', 'Vietnamese', 'vi', 'Vietnamese'),
+ NllbLang('war_Latn', 'Waray'),
+ NllbLang('wol_Latn', 'Wolof'),
+ NllbLang('xho_Latn', 'Xhosa'),
+ NllbLang('ydd_Hebr', 'Eastern Yiddish', 'yi', 'Yiddish'),
+ NllbLang('yor_Latn', 'Yoruba', 'yo', 'Yoruba'),
+ NllbLang('yue_Hant', 'Yue Chinese', 'zh', 'Chinese'),
+ NllbLang('zho_Hans', 'Chinese (Simplified)', 'zh', 'Chinese'),
+ NllbLang('zho_Hant', 'Chinese (Traditional)', 'zh', 'Chinese'),
+ NllbLang('zsm_Latn', 'Standard Malay', 'ms', 'Malay'),
+ NllbLang('zul_Latn', 'Zulu'),
+ ]
+
+ _TO_NLLB_LANG_CODE = {language.code.lower(): language for language in NLLB_LANGS if language.code is not None}
+
+ _TO_NLLB_LANG_NAME = {language.name.lower(): language for language in NLLB_LANGS if language.name is not None}
+
+ _TO_NLLB_LANG_WHISPER_CODE = {language.code_whisper.lower(): language for language in NLLB_LANGS if language.code_whisper is not None}
+
+ _TO_NLLB_LANG_WHISPER_NAME = {language.name_whisper.lower(): language for language in NLLB_LANGS if language.name_whisper is not None}
+
+ def get_nllb_lang_from_code(lang_code, default=None) -> NllbLang:
+ """Return the language from the language code."""
+ return _TO_NLLB_LANG_CODE.get(lang_code, default)
+
+ def get_nllb_lang_from_name(lang_name, default=None) -> NllbLang:
+ """Return the language from the language name."""
+ return _TO_NLLB_LANG_NAME.get(lang_name.lower() if lang_name else None, default)
+
+ def get_nllb_lang_from_code_whisper(lang_code_whisper, default=None) -> NllbLang:
+ """Return the language from the language code."""
+ return _TO_NLLB_LANG_WHISPER_CODE.get(lang_code_whisper, default)
+
+ def get_nllb_lang_from_name_whisper(lang_name_whisper, default=None) -> NllbLang:
+ """Return the language from the language name."""
+ return _TO_NLLB_LANG_WHISPER_NAME.get(lang_name_whisper.lower() if lang_name_whisper else None, default)
+
+ def get_nllb_lang_names():
+ """Return a list of language names."""
+ return [language.name for language in NLLB_LANGS]
+
+ if __name__ == "__main__":
+ # Test lookup
+ print(get_nllb_lang_from_code('eng_Latn'))
+ print(get_nllb_lang_from_name('English'))
+
+ print(get_nllb_lang_names())
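A quick illustration (assumed usage, not part of the commit) of the lookup helpers this file defines, which bridge Whisper's two-letter codes and NLLB's script-tagged codes:

```python
from src.nllb.nllbLangs import get_nllb_lang_from_code_whisper, get_nllb_lang_from_name

lang = get_nllb_lang_from_name("Chinese (Traditional)")
print(lang)               # Language(code=zho_Hant, name=Chinese (Traditional))
print(lang.code_whisper)  # zh

# Reverse direction: several NLLB entries share Whisper code "zh" (yue_Hant,
# zho_Hans, zho_Hant); the dict comprehension keeps the last one registered.
print(get_nllb_lang_from_code_whisper("zh").code)  # zho_Hant
```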
src/nllb/nllbModel.py ADDED
@@ -0,0 +1,221 @@
+ import os
+ import warnings
+ import huggingface_hub
+ import requests
+ import torch
+
+ import ctranslate2
+ import transformers
+
+ from typing import Optional
+ from src.config import ModelConfig
+ from src.languages import Language
+ from src.nllb.nllbLangs import NllbLang, get_nllb_lang_from_code_whisper
+
+ class NllbModel:
+ def __init__(
+ self,
+ model_config: ModelConfig,
+ device: str = None,
+ whisper_lang: Language = None,
+ nllb_lang: NllbLang = None,
+ download_root: Optional[str] = None,
+ local_files_only: bool = False,
+ load_model: bool = False,
+ ):
+ """Initializes the Nllb-200 model.
+
+ Args:
+ model_config: Config of the model to use (distilled-600M, distilled-1.3B,
+ 1.3B, 3.3B...) or a path to a converted
+ model directory. When a size is configured, the converted model is downloaded
+ from the Hugging Face Hub.
+ device: Device to use for computation (cpu, cuda, ipu, xpu, mkldnn, opengl, opencl,
+ ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia).
+ device_index: Device ID to use.
+ The model can also be loaded on multiple GPUs by passing a list of IDs
+ (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel
+ when transcribe() is called from multiple Python threads (see also num_workers).
+ compute_type: Type to use for computation.
+ See https://opennmt.net/CTranslate2/quantization.html.
+ cpu_threads: Number of threads to use when running on CPU (4 by default).
+ A non zero value overrides the OMP_NUM_THREADS environment variable.
+ num_workers: When transcribe() is called from multiple Python threads,
+ having multiple workers enables true parallelism when running the model
+ (concurrent calls to self.model.generate() will run in parallel).
+ This can improve the global throughput at the cost of increased memory usage.
+ download_root: Directory where the models should be saved. If not set, the models
+ are saved in the standard Hugging Face cache directory.
+ local_files_only: If True, avoid downloading the file and return the path to the
+ local cached file if it exists.
+ """
+ self.whisper_lang = whisper_lang
+ self.nllb_whisper_lang = get_nllb_lang_from_code_whisper(whisper_lang.code.lower() if whisper_lang is not None else "en")
+ self.nllb_lang = nllb_lang
+ self.model_config = model_config
+
+ if os.path.isdir(model_config.url):
+ self.model_path = model_config.url
+ else:
+ self.model_path = download_model(
+ model_config,
+ local_files_only=local_files_only,
+ cache_dir=download_root,
+ )
+
+ if device is None:
+ if torch.cuda.is_available():
+ device = "cuda" if "ct2" in self.model_path else "cuda:0"
+ else:
+ device = "cpu"
+
+ self.device = device
+
+ if load_model:
+ self.load_model()
+
+ def load_model(self):
+ print('\n\nLoading model: %s\n\n' % self.model_path)
+ if "ct2" in self.model_path:
+ self.target_prefix = [self.nllb_lang.code]
+ self.trans_tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_path, src_lang=self.nllb_whisper_lang.code)
+ self.trans_model = ctranslate2.Translator(self.model_path, compute_type="auto", device=self.device)
+ elif "mt5" in self.model_path:
+ self.mt5_prefix = self.whisper_lang.code + "2" + self.nllb_lang.code_whisper + ": "
+ self.trans_tokenizer = transformers.T5Tokenizer.from_pretrained(self.model_path) #requires spiece.model
+ self.trans_model = transformers.MT5ForConditionalGeneration.from_pretrained(self.model_path)
+ self.trans_translator = transformers.pipeline('text2text-generation', model=self.trans_model, device=self.device, tokenizer=self.trans_tokenizer)
+ else: #NLLB
+ self.trans_tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_path)
+ self.trans_model = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.model_path)
+ self.trans_translator = transformers.pipeline('translation', model=self.trans_model, device=self.device, tokenizer=self.trans_tokenizer, src_lang=self.nllb_whisper_lang.code, tgt_lang=self.nllb_lang.code)
+
+ def release_vram(self):
+ try:
+ if torch.cuda.is_available():
+ if "ct2" not in self.model_path:
+ device = torch.device("cpu")
+ self.trans_model.to(device)
+ del self.trans_model
+ torch.cuda.empty_cache()
+ print("release vram end.")
+ except Exception as e:
+ print("Error release vram: " + str(e))
+
+
+ def translation(self, text: str, max_length: int = 400):
+ output = None
+ result = None
+ try:
+ if "ct2" in self.model_path:
+ source = self.trans_tokenizer.convert_ids_to_tokens(self.trans_tokenizer.encode(text))
+ output = self.trans_model.translate_batch([source], target_prefix=[self.target_prefix])
+ target = output[0].hypotheses[0][1:]
+ result = self.trans_tokenizer.decode(self.trans_tokenizer.convert_tokens_to_ids(target))
+ elif "mt5" in self.model_path:
+ output = self.trans_translator(self.mt5_prefix + text, max_length=max_length, num_beams=4)
+ result = output[0]['generated_text']
+ else: #NLLB
+ output = self.trans_translator(text, max_length=max_length)
+ result = output[0]['translation_text']
+ except Exception as e:
+ print("Error translation text: " + str(e))
+
+ return result
+
+
+ _MODELS = ["distilled-600M", "distilled-1.3B", "1.3B", "3.3B",
+ "ct2fast-nllb-200-distilled-1.3B-int8_float16",
+ "ct2fast-nllb-200-3.3B-int8_float16",
+ "nllb-200-3.3B-ct2-float16", "nllb-200-1.3B-ct2", "nllb-200-1.3B-ct2-int8", "nllb-200-1.3B-ct2-float16",
+ "nllb-200-distilled-1.3B-ct2", "nllb-200-distilled-1.3B-ct2-int8", "nllb-200-distilled-1.3B-ct2-float16",
+ "nllb-200-distilled-600M-ct2", "nllb-200-distilled-600M-ct2-int8", "nllb-200-distilled-600M-ct2-float16",
+ "mt5-zh-ja-en-trimmed",
+ "mt5-zh-ja-en-trimmed-fine-tuned-v1"]
+
+ def check_model_name(name):
+ return any(allowed_name in name for allowed_name in _MODELS)
+
+ def download_model(
+ model_config: ModelConfig,
+ output_dir: Optional[str] = None,
+ local_files_only: bool = False,
+ cache_dir: Optional[str] = None,
+ ):
+ """"download_model" is referenced from the "utils.py" script
+ of the "faster_whisper" project, authored by guillaumekln.
+
+ Downloads a nllb-200 model from the Hugging Face Hub.
+
+ The model is downloaded from https://huggingface.co/facebook.
+
+ Args:
+ model_config: config of the model to download (facebook/nllb-distilled-600M,
+ facebook/nllb-distilled-1.3B, facebook/nllb-1.3B, facebook/nllb-3.3B...).
+ output_dir: Directory where the model should be saved. If not set, the model is saved in
+ the cache directory.
+ local_files_only: If True, avoid downloading the file and return the path to the local
+ cached file if it exists.
+ cache_dir: Path to the folder where cached files are stored.
+
+ Returns:
+ The path to the downloaded model.
+
+ Raises:
+ ValueError: if the model size is invalid.
+ """
+ if not check_model_name(model_config.name):
+ raise ValueError(
+ "Invalid model name '%s', expected one of: %s" % (model_config.name, ", ".join(_MODELS))
+ )
+
+ repo_id = model_config.url #"facebook/nllb-200-%s" %
+
+ allow_patterns = [
+ "config.json",
+ "generation_config.json",
+ "model.bin",
+ "pytorch_model.bin",
+ "pytorch_model.bin.index.json",
+ "pytorch_model-00001-of-00003.bin",
+ "pytorch_model-00002-of-00003.bin",
+ "pytorch_model-00003-of-00003.bin",
+ "sentencepiece.bpe.model",
+ "tokenizer.json",
+ "tokenizer_config.json",
+ "shared_vocabulary.txt",
+ "shared_vocabulary.json",
+ "special_tokens_map.json",
+ "spiece.model",
+ ]
+
+ kwargs = {
+ "local_files_only": local_files_only,
+ "allow_patterns": allow_patterns,
+ #"tqdm_class": disabled_tqdm,
+ }
+
+ if output_dir is not None:
+ kwargs["local_dir"] = output_dir
+ kwargs["local_dir_use_symlinks"] = False
+
+ if cache_dir is not None:
+ kwargs["cache_dir"] = cache_dir
+
+ try:
+ return huggingface_hub.snapshot_download(repo_id, **kwargs)
+ except (
+ huggingface_hub.utils.HfHubHTTPError,
+ requests.exceptions.ConnectionError,
+ ) as exception:
+ warnings.warn(
+ "An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
+ repo_id,
+ exception,
+ )
+ warnings.warn(
+ "Trying to load the model directly from the local cache, if it exists."
+ )
+
+ kwargs["local_files_only"] = True
+ return huggingface_hub.snapshot_download(repo_id, **kwargs)
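A hypothetical end-to-end use of NllbModel, mirroring how app.py wires it up (the ModelConfig values are copied from config.json5; the import paths and the French-to-English direction are assumptions for illustration):

```python
from src.config import ModelConfig
from src.languages import get_language_from_name
from src.nllb.nllbLangs import get_nllb_lang_from_name
from src.nllb.nllbModel import NllbModel

model_config = ModelConfig(name="nllb-200-distilled-600M/facebook",
                           url="facebook/nllb-200-distilled-600M",
                           type="huggingface")

nllb_model = NllbModel(model_config=model_config,
                       whisper_lang=get_language_from_name("French"),
                       nllb_lang=get_nllb_lang_from_name("English"))

nllb_model.load_model()  # deferred by default; app.py calls this in write_result()
print(nllb_model.translation("Bonjour le monde"))
nllb_model.release_vram()  # free GPU memory once the segments are translated
```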
src/vadParallel.py CHANGED
@@ -204,7 +204,7 @@ class ParallelTranscription(AbstractTranscription):
  gpu_parallel_context.close()

  perf_end_gpu = time.perf_counter()
- print("Parallel transcription took " + str(perf_end_gpu - perf_start_gpu) + " seconds")
+ print("\nParallel transcription took " + str(perf_end_gpu - perf_start_gpu) + " seconds")

  return merged