SoybeanMilk committed
Commit 0052b96
1 Parent(s): a87bad5

Upload 2 files

Files changed (2)
  1. app.py +71 -37
  2. config.json5 +50 -1
app.py CHANGED
@@ -40,7 +40,7 @@ from src.whisper.whisperFactory import create_whisper_container
 from src.translation.translationModel import TranslationModel
 from src.translation.translationLangs import (TranslationLang,
     _TO_LANG_CODE_WHISPER, get_lang_whisper_names, get_lang_from_whisper_name, get_lang_from_whisper_code,
-    get_lang_nllb_names, get_lang_from_nllb_name, get_lang_m2m100_names, get_lang_from_m2m100_name)
+    get_lang_nllb_names, get_lang_from_nllb_name, get_lang_m2m100_names, get_lang_from_m2m100_name, sort_lang_by_whisper_codes)
 import shutil
 import zhconv
 import tqdm
@@ -233,6 +233,8 @@ class WhisperTranscriber:
         mt5LangName: str = decodeOptions.pop("mt5LangName")
         ALMAModelName: str = decodeOptions.pop("ALMAModelName")
         ALMALangName: str = decodeOptions.pop("ALMALangName")
+        madlad400ModelName: str = decodeOptions.pop("madlad400ModelName")
+        madlad400LangName: str = decodeOptions.pop("madlad400LangName")
 
         translationBatchSize: int = decodeOptions.pop("translationBatchSize")
         translationNoRepeatNgramSize: int = decodeOptions.pop("translationNoRepeatNgramSize")
@@ -250,6 +252,7 @@ class WhisperTranscriber:
         vadPadding: float = decodeOptions.pop("vadPadding", self.app_config.vad_padding)
         vadPromptWindow: float = decodeOptions.pop("vadPromptWindow", self.app_config.vad_prompt_window)
         vadInitialPromptMode: str = decodeOptions.pop("vadInitialPromptMode", self.app_config.vad_initial_prompt_mode)
+        self.vad_process_timeout: float = decodeOptions.pop("vadPocessTimeout", self.vad_process_timeout)
 
         diarization: bool = decodeOptions.pop("diarization", False)
         diarization_speakers: int = decodeOptions.pop("diarization_speakers", 2)
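
The new `vadPocessTimeout` line assigns through `self`, so the value persists across requests and the previous value serves as the default. A minimal standalone sketch of this pop-with-self-default idiom (the class and numbers here are hypothetical):

```python
class Settings:
    def __init__(self):
        self.vad_process_timeout = 1800.0  # assumed initial default, in seconds

    def update(self, options: dict):
        # falls back to the current value when the key is absent,
        # so the setting is "sticky" across calls
        self.vad_process_timeout = options.pop("vadPocessTimeout", self.vad_process_timeout)

s = Settings()
s.update({})                           # key absent: still 1800.0
s.update({"vadPocessTimeout": 0})      # key present: now 0
print(s.vad_process_timeout)
```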
@@ -267,19 +270,22 @@ class WhisperTranscriber:
         if whisperNoRepeatNgramSize is not None and whisperNoRepeatNgramSize <= 1:
             decodeOptions.pop("no_repeat_ngram_size")
 
-        # word_timestamps = options.get("word_timestamps", False)
-        # condition_on_previous_text = options.get("condition_on_previous_text", False)
-
-        # prepend_punctuations = options.get("prepend_punctuations", None)
-        # append_punctuations = options.get("append_punctuations", None)
-        # initial_prompt = options.get("initial_prompt", None)
-        # best_of = options.get("best_of", None)
-        # beam_size = options.get("beam_size", None)
-        # patience = options.get("patience", None)
-        # length_penalty = options.get("length_penalty", None)
-        # suppress_tokens = options.get("suppress_tokens", None)
-        # compression_ratio_threshold = options.get("compression_ratio_threshold", None)
-        # logprob_threshold = options.get("logprob_threshold", None)
+        for key, value in list(decodeOptions.items()):
+            if value == "":
+                del decodeOptions[key]
+
+        # word_timestamps = decodeOptions.get("word_timestamps", False)
+        # condition_on_previous_text = decodeOptions.get("condition_on_previous_text", False)
+        # prepend_punctuations = decodeOptions.get("prepend_punctuations", None)
+        # append_punctuations = decodeOptions.get("append_punctuations", None)
+        # initial_prompt = decodeOptions.get("initial_prompt", None)
+        # best_of = decodeOptions.get("best_of", None)
+        # beam_size = decodeOptions.get("beam_size", None)
+        # patience = decodeOptions.get("patience", None)
+        # length_penalty = decodeOptions.get("length_penalty", None)
+        # suppress_tokens = decodeOptions.get("suppress_tokens", None)
+        # compression_ratio_threshold = decodeOptions.get("compression_ratio_threshold", None)
+        # logprob_threshold = decodeOptions.get("logprob_threshold", None)
 
         vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)
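The filtering loop added above iterates over `list(decodeOptions.items())` rather than the dict itself; a self-contained sketch of why that snapshot is required:

```python
# Deleting from a dict while iterating it directly raises
# "RuntimeError: dictionary changed size during iteration";
# iterating over a snapshot of the items avoids that.
decodeOptions = {"initial_prompt": "", "beam_size": 5, "length_penalty": ""}

for key, value in list(decodeOptions.items()):
    if value == "":
        del decodeOptions[key]

print(decodeOptions)  # {'beam_size': 5}
```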
 
@@ -340,6 +346,10 @@ class WhisperTranscriber:
             selectedModelName = ALMAModelName if ALMAModelName is not None and len(ALMAModelName) > 0 else "ALMA-13B-GPTQ/TheBloke"
             selectedModel = next((modelConfig for modelConfig in self.app_config.models["ALMA"] if modelConfig.name == selectedModelName), None)
             translationLang = get_lang_from_m2m100_name(ALMALangName)
+        elif translateInput == "madlad400" and madlad400LangName is not None and len(madlad400LangName) > 0:
+            selectedModelName = madlad400ModelName if madlad400ModelName is not None and len(madlad400ModelName) > 0 else "madlad400-10b-mt-ct2-int8_float16/SoybeanMilk"
+            selectedModel = next((modelConfig for modelConfig in self.app_config.models["madlad400"] if modelConfig.name == selectedModelName), None)
+            translationLang = get_lang_from_m2m100_name(madlad400LangName)
 
         if translationLang is not None:
             translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=whisperLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams)
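Each translation backend resolves its model the same way: fall back to a default name when the dropdown sent nothing, then look the name up with `next(..., None)`. A standalone sketch of the pattern, using a hypothetical `ModelConfig` stand-in for the app's config entries:

```python
from dataclasses import dataclass

@dataclass
class ModelConfig:  # hypothetical stand-in for the app's model config
    name: str
    url: str

models = [ModelConfig(name="madlad400-10b-mt-ct2-int8_float16/SoybeanMilk",
                      url="SoybeanMilk/madlad400-10b-mt-ct2-int8_float16")]

requested = ""  # what the dropdown sent; empty means "use the default"
selectedName = requested if requested else "madlad400-10b-mt-ct2-int8_float16/SoybeanMilk"
# first config whose name matches, or None if the name is unknown
selected = next((m for m in models if m.name == selectedName), None)
print(selected)
```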
@@ -384,7 +394,7 @@ class WhisperTranscriber:
 
             # Transcribe
             result = self.transcribe_file(model, source.source_path, whisperLangCode, task, vadOptions, scaled_progress_listener, **decodeOptions)
-            if whisperLang is None and result["language"] is not None and len(result["language"]) > 0:
+            if translationModel is not None and whisperLang is None and result["language"] is not None and len(result["language"]) > 0:
                 whisperLang = get_lang_from_whisper_code(result["language"])
                 translationModel.whisperLang = whisperLang
 
@@ -413,7 +423,7 @@ class WhisperTranscriber:
                 out = ffmpeg.output(input_file, input_srt, output_with_srt, vcodec='copy', acodec='copy', scodec='mov_text')
                 outRsult = out.run(overwrite_output=True)
             except Exception as e:
-                # Ignore error - it's just a cleanup
+                print(traceback.format_exc())
                 print("Error merging subtitle with source file: \n" + source.source_path + ", \n" + str(e), outRsult)
         elif self.app_config.save_downloaded_files and self.app_config.output_dir is not None and urlData:
             print("Saving downloaded file [" + source.source_name + "]")
@@ -421,7 +431,7 @@ class WhisperTranscriber:
                 save_path = os.path.join(self.app_config.output_dir, filePrefix)
                 shutil.copy(source.source_path, save_path + suffix)
             except Exception as e:
-                # Ignore error - it's just a cleanup
+                print(traceback.format_exc())
                 print("Error saving downloaded file: \n" + source.source_path + ", \n" + str(e))
 
         if len(sources) > 1:
@@ -473,7 +483,7 @@ class WhisperTranscriber:
                 try:
                     os.remove(source.source_path)
                 except Exception as e:
-                    # Ignore error - it's just a cleanup
+                    print(traceback.format_exc())
                     print("Error deleting temporary source file: \n" + source.source_path + ", \n" + str(e))
 
         except ExceededMaximumDuration as e:
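The three handlers above (subtitle muxing, download saving, temp-file cleanup) now print the full stack trace before the one-line summary. The difference between the two calls, as a tiny runnable example:

```python
import traceback

try:
    1 / 0
except Exception as e:
    print(traceback.format_exc())  # full traceback with file and line context
    print("Error: " + str(e))      # str(e) alone is just "division by zero"
```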
@@ -619,7 +629,7 @@ class WhisperTranscriber:
     def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadOptions: VadOptions):
         # Use Silero VAD
         if (self.vad_model is None):
-            self.vad_model = VadSileroTranscription()
+            self.vad_model = VadSileroTranscription() # vad_model is snakers4/silero-vad
 
         config = TranscriptionConfig(non_speech_strategy = non_speech_strategy,
             max_silent_period=vadOptions.vadMergeWindow, max_merge_size=vadOptions.vadMaxMergeSize,
@@ -661,7 +671,6 @@ class WhisperTranscriber:
 
             print("\n\nprocess segments took {} seconds.\n\n".format(perf_end_time - perf_start_time))
         except Exception as e:
-            # Ignore error - it's just a cleanup
             print(traceback.format_exc())
             print("Error processing segments: " + str(e))
 
@@ -771,8 +780,15 @@ class WhisperTranscriber:
         self.diarization = None
 
 def create_ui(app_config: ApplicationConfig):
+    translateModelMd: str = None
     optionsMd: str = None
     readmeMd: str = None
+    try:
+        translateModelPath = pathlib.Path("docs/translateModel.md")
+        with open(translateModelPath, "r", encoding="utf-8") as translateModelFile:
+            translateModelMd = translateModelFile.read()
+    except Exception as e:
+        print("Error occurred while reading the translateModel.md file: ", str(e))
     try:
         optionsPath = pathlib.Path("docs/options.md")
         with open(optionsPath, "r", encoding="utf-8") as optionsFile:
@@ -817,23 +833,16 @@ def create_ui(app_config: ApplicationConfig):
     uiDescription += "\n\n" + "Max audio file length: " + str(app_config.input_audio_max_duration) + " s"
 
     uiArticle = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)."
-    uiArticle += "\n\nWhisper's Task 'translate' only implements the functionality of translating other languages into English. "
-    uiArticle += "OpenAI does not guarantee translations between arbitrary languages. In such cases, you can choose to use the NLLB Model to implement the translation task. "
-    uiArticle += "However, it's important to note that the NLLB Model runs slowly, and the completion time may be twice as long as usual. "
-    uiArticle += "\n\nThe larger the parameters of the NLLB model, the better its performance is expected to be. "
-    uiArticle += "However, it also requires higher computational resources, making it slower to operate. "
-    uiArticle += "On the other hand, the version converted from ct2 (CTranslate2) requires lower resources and operates at a faster speed."
-    uiArticle += "\n\nCurrently, enabling word-level timestamps cannot be used in conjunction with NLLB Model translation "
-    uiArticle += "because Word Timestamps will split the source text, and after translation, it becomes a non-word-level string. "
-    uiArticle += "\n\nThe 'mt5-zh-ja-en-trimmed' model is finetuned from Google's 'mt5-base' model. "
-    uiArticle += "This model has a relatively good translation speed, but it only supports three languages: Chinese, Japanese, and English. "
 
     whisper_models = app_config.get_model_names("whisper")
     nllb_models = app_config.get_model_names("nllb")
     m2m100_models = app_config.get_model_names("m2m100")
     mt5_models = app_config.get_model_names("mt5")
     ALMA_models = app_config.get_model_names("ALMA")
-
+    madlad400_models = app_config.get_model_names("madlad400")
+    if not torch.cuda.is_available(): # GPTQ is poorly supported on CPU (a single iteration can exceed a thousand seconds), so remove the GPTQ models from the list when no GPU is available.
+        ALMA_models = list(filter(lambda alma: "GPTQ" not in alma, ALMA_models))
+
     common_whisper_inputs = lambda : {
         gr.Dropdown(label="Whisper - Model (for audio)", choices=whisper_models, value=app_config.default_model_name, elem_id="whisperModelName"),
         gr.Dropdown(label="Whisper - Language", choices=sorted(get_lang_whisper_names()), value=app_config.language, elem_id="whisperLangName"),
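The GPTQ filter above can equivalently be written as a list comprehension; a standalone sketch of the same CPU-only filtering:

```python
ALMA_models = ["ALMA-13B-GPTQ/TheBloke", "ALMA-13B-ct2:int8_float16/avan"]
cuda_available = False  # pretend torch.cuda.is_available() returned False

if not cuda_available:
    # hide the GPTQ entries, which are impractically slow on CPU
    ALMA_models = [name for name in ALMA_models if "GPTQ" not in name]

print(ALMA_models)  # ['ALMA-13B-ct2:int8_float16/avan']
```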
@@ -852,7 +861,11 @@ def create_ui(app_config: ApplicationConfig):
     }
     common_ALMA_inputs = lambda : {
         gr.Dropdown(label="ALMA - Model (for translate)", choices=ALMA_models, elem_id="ALMAModelName"),
-        gr.Dropdown(label="ALMA - Language", choices=sorted(get_lang_m2m100_names(["en", "ja", "zh"])), elem_id="ALMALangName"),
+        gr.Dropdown(label="ALMA - Language", choices=sort_lang_by_whisper_codes(["en", "de", "cs", "is", "ru", "zh", "ja"]), elem_id="ALMALangName"),
+    }
+    common_madlad400_inputs = lambda : {
+        gr.Dropdown(label="madlad400 - Model (for translate)", choices=madlad400_models, elem_id="madlad400ModelName"),
+        gr.Dropdown(label="madlad400 - Language", choices=sorted(get_lang_m2m100_names()), elem_id="madlad400LangName"),
     }
 
     common_translation_inputs = lambda : {
@@ -865,6 +878,7 @@ def create_ui(app_config: ApplicationConfig):
         gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=app_config.default_vad, label="VAD", elem_id="vad"),
         gr.Number(label="VAD - Merge Window (s)", precision=0, value=app_config.vad_merge_window, elem_id="vadMergeWindow"),
         gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=app_config.vad_max_merge_size, elem_id="vadMaxMergeSize"),
+        gr.Number(label="VAD - Process Timeout (s)", precision=0, value=app_config.vad_process_timeout, elem_id="vadPocessTimeout"),
     }
 
     common_word_timestamps_inputs = lambda : {
@@ -917,10 +931,14 @@ def create_ui(app_config: ApplicationConfig):
                 with gr.Tab(label="ALMA") as simpleALMATab:
                     with gr.Row():
                         simpleInputDict.update(common_ALMA_inputs())
+                with gr.Tab(label="madlad400") as simplemadlad400Tab:
+                    with gr.Row():
+                        simpleInputDict.update(common_madlad400_inputs())
                 simpleM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [simpleTranslateInput] )
                 simpleNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [simpleTranslateInput] )
                 simpleMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [simpleTranslateInput] )
                 simpleALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [simpleTranslateInput] )
+                simplemadlad400Tab.select(fn=lambda: "madlad400", inputs = [], outputs= [simpleTranslateInput] )
             with gr.Column():
                 with gr.Tab(label="URL") as simpleUrlTab:
                     simpleInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
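The tab wiring relies on a hidden `gr.State` that each `Tab.select` handler overwrites with its backend's key. A minimal sketch of the pattern, assuming a Gradio version that supports `Tab.select` (as this commit evidently uses):

```python
import gradio as gr

with gr.Blocks() as demo:
    # hidden per-session value recording which translation backend is active
    translateInput = gr.State(value="m2m100")
    with gr.Tab(label="M2M100") as m2m100Tab:
        gr.Markdown("M2M100 options")
    with gr.Tab(label="madlad400") as madlad400Tab:
        gr.Markdown("madlad400 options")
    # each tab writes its own key into the state when the user selects it
    m2m100Tab.select(fn=lambda: "m2m100", inputs=[], outputs=[translateInput])
    madlad400Tab.select(fn=lambda: "madlad400", inputs=[], outputs=[translateInput])

# demo.launch()
```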
@@ -942,8 +960,10 @@ def create_ui(app_config: ApplicationConfig):
                     simpleInputDict.update(common_translation_inputs())
             with gr.Column():
                 simpleOutput = common_output()
-                with gr.Accordion("Article"):
-                    gr.Markdown(uiArticle)
+                gr.Markdown(uiArticle)
+                if translateModelMd is not None:
+                    with gr.Accordion("docs/translateModel.md", open=False):
+                        gr.Markdown(translateModelMd)
                 if optionsMd is not None:
                     with gr.Accordion("docs/options.md", open=False):
                         gr.Markdown(optionsMd)
@@ -957,7 +977,7 @@ def create_ui(app_config: ApplicationConfig):
 
     fullInputDict = {}
     fullDescription = uiDescription + "\n\n\n\n" + "Be careful when changing some of the options in the full interface - this can cause the model to crash."
-
+
     with gr.Blocks() as fullTranscribe:
         fullTranslateInput = gr.State(value="m2m100", elem_id = "translateInput")
         fullSourceInput = gr.State(value="urlData", elem_id = "sourceInput")
@@ -980,10 +1000,14 @@ def create_ui(app_config: ApplicationConfig):
                 with gr.Tab(label="ALMA") as fullALMATab:
                     with gr.Row():
                         fullInputDict.update(common_ALMA_inputs())
+                with gr.Tab(label="madlad400") as fullmadlad400Tab:
+                    with gr.Row():
+                        fullInputDict.update(common_madlad400_inputs())
                 fullM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [fullTranslateInput] )
                 fullNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [fullTranslateInput] )
                 fullMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [fullTranslateInput] )
                 fullALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [fullTranslateInput] )
+                fullmadlad400Tab.select(fn=lambda: "madlad400", inputs = [], outputs= [fullTranslateInput] )
             with gr.Column():
                 with gr.Tab(label="URL") as fullUrlTab:
                     fullInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
@@ -1013,7 +1037,7 @@ def create_ui(app_config: ApplicationConfig):
         gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0, elem_id = "best_of"),
         gr.Number(label="Beam Size - Zero temperature", value=app_config.beam_size, precision=0, elem_id = "beam_size"),
         gr.Number(label="Patience - Zero temperature", value=app_config.patience, elem_id = "patience"),
-        gr.Number(label="Length Penalty - Any temperature", value=app_config.length_penalty, elem_id = "length_penalty"),
+        gr.Number(label="Length Penalty - Any temperature", value=lambda : None if app_config.length_penalty is None else app_config.length_penalty, elem_id = "length_penalty"),
         gr.Text(label="Suppress Tokens - Comma-separated list of token IDs", value=app_config.suppress_tokens, elem_id = "suppress_tokens"),
         gr.Checkbox(label="Condition on previous text", value=app_config.condition_on_previous_text, elem_id = "condition_on_previous_text"),
         gr.Checkbox(label="FP16", value=app_config.fp16, elem_id = "fp16"),
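Wrapping the value in a lambda looks redundant, since `None if x is None else x` is just `x`, but Gradio evaluates a callable `value=` when the page loads rather than when the UI graph is built; the likely intent is to let a `None` length penalty leave the field empty. A hedged sketch of the idea:

```python
import gradio as gr

length_penalty = None  # stand-in for app_config.length_penalty

with gr.Blocks() as demo:
    # a callable default is re-evaluated on each page load instead of
    # being baked in once at build time
    gr.Number(label="Length Penalty - Any temperature",
              value=lambda: None if length_penalty is None else length_penalty)
```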
@@ -1054,7 +1078,7 @@ def create_ui(app_config: ApplicationConfig):
         print("Queue mode enabled (concurrency count: " + str(app_config.queue_concurrency_count) + ")")
     else:
         print("Queue mode disabled - progress bars will not be shown.")
-
+
     demo.launch(inbrowser=app_config.autolaunch, share=app_config.share, server_name=app_config.server_name, server_port=app_config.server_port)
 
     # Clean up
@@ -1136,6 +1160,16 @@ if __name__ == '__main__':
     # updated_config.autolaunch = True
     # updated_config.auto_parallel = False
     # updated_config.save_downloaded_files = True
+
+    try:
+        if torch.cuda.is_available():
+            deviceId = torch.cuda.current_device()
+            totalVram = torch.cuda.get_device_properties(deviceId).total_memory
+            if totalVram / (1024*1024*1024) <= 4: # VRAM <= 4 GB
+                updated_config.vad_process_timeout = 0
+    except Exception as e:
+        print(traceback.format_exc())
+        print("Error detecting VRAM: " + str(e))
 
     if (threads := args.pop("threads")) > 0:
         torch.set_num_threads(threads)
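The new startup probe sets `vad_process_timeout` to 0 when the current CUDA device has at most 4 GB of VRAM. The same probe, isolated as a runnable sketch:

```python
import torch

if torch.cuda.is_available():
    deviceId = torch.cuda.current_device()
    totalVram = torch.cuda.get_device_properties(deviceId).total_memory
    print(f"GPU {deviceId}: {totalVram / 1024**3:.1f} GiB total VRAM")
    if totalVram / (1024 ** 3) <= 4:
        print("<= 4 GiB: the commit sets vad_process_timeout = 0 here")
else:
    print("No CUDA device; the probe is skipped")
```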
 
config.json5 CHANGED
@@ -23,6 +23,10 @@
       "name": "large",
       "url": "large"
     },
+    {
+      "name": "large-v1",
+      "url": "large-v1"
+    },
     {
       "name": "large-v2",
       "url": "large-v2"
@@ -189,10 +193,55 @@
     }
   ],
   "ALMA": [
+    {
+      "name": "ALMA-7B-GPTQ/TheBloke",
+      "url": "TheBloke/ALMA-7B-GPTQ",
+      "type": "huggingface"
+    },
     {
       "name": "ALMA-13B-GPTQ/TheBloke",
       "url": "TheBloke/ALMA-13B-GPTQ",
-      "type": "huggingface",
+      "type": "huggingface"
+    },
+    {
+      "name": "ALMA-7B-GGUF-Q4_K_M/TheBloke",
+      "url": "TheBloke/ALMA-7B-GGUF",
+      "type": "huggingface",
+      "model_file": "alma-7b.Q4_K_M.gguf",
+      "tokenizer_url": "haoranxu/ALMA-7B"
+    },
+    {
+      "name": "ALMA-13B-GGUF-Q4_K_M/TheBloke",
+      "url": "TheBloke/ALMA-13B-GGUF",
+      "type": "huggingface",
+      "model_file": "alma-13b.Q4_K_M.gguf",
+      "tokenizer_url": "haoranxu/ALMA-13B"
+    },
+    {
+      "name": "ALMA-7B-ct2:int8_float16/avan",
+      "url": "avans06/ALMA-7B-ct2-int8_float16",
+      "type": "huggingface",
+      "tokenizer_url": "haoranxu/ALMA-7B"
+    },
+    {
+      "name": "ALMA-13B-ct2:int8_float16/avan",
+      "url": "avans06/ALMA-13B-ct2-int8_float16",
+      "type": "huggingface",
+      "tokenizer_url": "haoranxu/ALMA-13B"
+    },
+  ],
+  "madlad400": [
+    {
+      "name": "madlad400-3b-mt-ct2-int8_float16/SoybeanMilk",
+      "url": "SoybeanMilk/madlad400-3b-mt-ct2-int8_float16",
+      "type": "huggingface",
+      "tokenizer_url": "jbochi/madlad400-3b-mt"
+    },
+    {
+      "name": "madlad400-10b-mt-ct2-int8_float16/SoybeanMilk",
+      "url": "SoybeanMilk/madlad400-10b-mt-ct2-int8_float16",
+      "type": "huggingface",
+      "tokenizer_url": "jbochi/madlad400-10b-mt"
+    },
   ],
 },
 
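The `madlad400` entries point at CTranslate2 conversions of jbochi's MADLAD-400 models, with the tokenizer fetched separately via `tokenizer_url`. A hedged sketch of how such a converted model is typically driven; the local model path is a placeholder, the `<2xx>` target-language prefix follows the MADLAD-400 model cards, and this app's actual integration may differ:

```python
import ctranslate2
import transformers

# tokenizer comes from the original model repo (per "tokenizer_url" above)
tokenizer = transformers.AutoTokenizer.from_pretrained("jbochi/madlad400-3b-mt")
# placeholder path to the downloaded CTranslate2 conversion
translator = ctranslate2.Translator("path/to/madlad400-3b-mt-ct2-int8_float16")

# MADLAD-400 selects the target language with a <2xx> prefix on the source text
text = "<2ja> Hello, world!"
tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(text))

results = translator.translate_batch([tokens])
output_tokens = results[0].hypotheses[0]
print(tokenizer.decode(tokenizer.convert_tokens_to_ids(output_tokens)))
```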