SoybeanMilk committed • Commit 0052b96
Parent(s): a87bad5

Upload 2 files

- app.py +71 -37
- config.json5 +50 -1
app.py CHANGED
@@ -40,7 +40,7 @@ from src.whisper.whisperFactory import create_whisper_container
 from src.translation.translationModel import TranslationModel
 from src.translation.translationLangs import (TranslationLang,
     _TO_LANG_CODE_WHISPER, get_lang_whisper_names, get_lang_from_whisper_name, get_lang_from_whisper_code,
-    get_lang_nllb_names, get_lang_from_nllb_name, get_lang_m2m100_names, get_lang_from_m2m100_name)
+    get_lang_nllb_names, get_lang_from_nllb_name, get_lang_m2m100_names, get_lang_from_m2m100_name, sort_lang_by_whisper_codes)
 import shutil
 import zhconv
 import tqdm
@@ -233,6 +233,8 @@ class WhisperTranscriber:
         mt5LangName: str = decodeOptions.pop("mt5LangName")
         ALMAModelName: str = decodeOptions.pop("ALMAModelName")
         ALMALangName: str = decodeOptions.pop("ALMALangName")
+        madlad400ModelName: str = decodeOptions.pop("madlad400ModelName")
+        madlad400LangName: str = decodeOptions.pop("madlad400LangName")

         translationBatchSize: int = decodeOptions.pop("translationBatchSize")
         translationNoRepeatNgramSize: int = decodeOptions.pop("translationNoRepeatNgramSize")
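A self-contained sketch of the pop-then-forward pattern this hunk extends (the dict contents here are hypothetical): UI-only fields are removed from decodeOptions so that only genuine Whisper decode arguments remain for the later **decodeOptions call.

```python
# Hypothetical option dict; only "temperature" is a real decode argument.
decodeOptions = {
    "madlad400ModelName": "madlad400-10b-mt-ct2-int8_float16/SoybeanMilk",
    "madlad400LangName": "Japanese",
    "temperature": 0.0,
}

# pop() removes the UI-only fields and returns their values...
madlad400ModelName: str = decodeOptions.pop("madlad400ModelName")
madlad400LangName: str = decodeOptions.pop("madlad400LangName")

# ...so the remainder can be forwarded safely as **decodeOptions.
assert decodeOptions == {"temperature": 0.0}
```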
@@ -250,6 +252,7 @@ class WhisperTranscriber:
         vadPadding: float = decodeOptions.pop("vadPadding", self.app_config.vad_padding)
         vadPromptWindow: float = decodeOptions.pop("vadPromptWindow", self.app_config.vad_prompt_window)
         vadInitialPromptMode: str = decodeOptions.pop("vadInitialPromptMode", self.app_config.vad_initial_prompt_mode)
+        self.vad_process_timeout: float = decodeOptions.pop("vadPocessTimeout", self.vad_process_timeout)

         diarization: bool = decodeOptions.pop("diarization", False)
         diarization_speakers: int = decodeOptions.pop("diarization_speakers", 2)
@@ -267,19 +270,22 @@ class WhisperTranscriber:
         if whisperNoRepeatNgramSize is not None and whisperNoRepeatNgramSize <= 1:
             decodeOptions.pop("no_repeat_ngram_size")

-
-
-
-
-        #
-        #
-        #
-        #
-        #
-        #
-        #
-        #
-        #
+        for key, value in list(decodeOptions.items()):
+            if value == "":
+                del decodeOptions[key]
+
+        # word_timestamps = decodeOptions.get("word_timestamps", False)
+        # condition_on_previous_text = decodeOptions.get("condition_on_previous_text", False)
+        # prepend_punctuations = decodeOptions.get("prepend_punctuations", None)
+        # append_punctuations = decodeOptions.get("append_punctuations", None)
+        # initial_prompt = decodeOptions.get("initial_prompt", None)
+        # best_of = decodeOptions.get("best_of", None)
+        # beam_size = decodeOptions.get("beam_size", None)
+        # patience = decodeOptions.get("patience", None)
+        # length_penalty = decodeOptions.get("length_penalty", None)
+        # suppress_tokens = decodeOptions.get("suppress_tokens", None)
+        # compression_ratio_threshold = decodeOptions.get("compression_ratio_threshold", None)
+        # logprob_threshold = decodeOptions.get("logprob_threshold", None)

         vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode)

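The new loop deletes empty-string entries before the options reach the model; iterating over a snapshot (list(decodeOptions.items())) is what makes deleting keys mid-loop safe. A minimal runnable sketch with made-up values:

```python
decodeOptions = {"initial_prompt": "", "beam_size": 5, "patience": ""}

# Snapshot the items first: deleting from a dict while iterating the dict
# itself would raise RuntimeError.
for key, value in list(decodeOptions.items()):
    if value == "":
        del decodeOptions[key]

print(decodeOptions)  # {'beam_size': 5} - blank UI fields no longer shadow defaults
```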
@@ -340,6 +346,10 @@ class WhisperTranscriber:
             selectedModelName = ALMAModelName if ALMAModelName is not None and len(ALMAModelName) > 0 else "ALMA-13B-GPTQ/TheBloke"
             selectedModel = next((modelConfig for modelConfig in self.app_config.models["ALMA"] if modelConfig.name == selectedModelName), None)
             translationLang = get_lang_from_m2m100_name(ALMALangName)
+        elif translateInput == "madlad400" and madlad400LangName is not None and len(madlad400LangName) > 0:
+            selectedModelName = madlad400ModelName if madlad400ModelName is not None and len(madlad400ModelName) > 0 else "madlad400-10b-mt-ct2-int8_float16"
+            selectedModel = next((modelConfig for modelConfig in self.app_config.models["madlad400"] if modelConfig.name == selectedModelName), None)
+            translationLang = get_lang_from_m2m100_name(madlad400LangName)

         if translationLang is not None:
             translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=whisperLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams)
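The madlad400 branch reuses the app's selection idiom: fall back to a default model name, then look the name up with next(), which yields None instead of raising when nothing matches. A sketch with stand-in data (ModelConfig here is a hypothetical substitute for the app's real config class):

```python
from dataclasses import dataclass

@dataclass
class ModelConfig:  # stand-in for the app's model config entries
    name: str
    url: str

models = {"madlad400": [ModelConfig("madlad400-10b-mt-ct2-int8_float16",
                                    "SoybeanMilk/madlad400-10b-mt-ct2-int8_float16")]}

requested = ""  # e.g. nothing chosen in the dropdown
selectedModelName = requested if requested else "madlad400-10b-mt-ct2-int8_float16"

# next(generator, None) returns the first match, or None for unknown names.
selectedModel = next((m for m in models["madlad400"] if m.name == selectedModelName), None)
assert selectedModel is not None
```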
@@ -384,7 +394,7 @@ class WhisperTranscriber:

                 # Transcribe
                 result = self.transcribe_file(model, source.source_path, whisperLangCode, task, vadOptions, scaled_progress_listener, **decodeOptions)
-                if whisperLang is None and result["language"] is not None and len(result["language"]) > 0:
+                if translationModel is not None and whisperLang is None and result["language"] is not None and len(result["language"]) > 0:
                     whisperLang = get_lang_from_whisper_code(result["language"])
                     translationModel.whisperLang = whisperLang

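The added "translationModel is not None" guard matters because the body dereferences translationModel; when no translation backend is selected, the old condition could still hold and crash. Reduced illustration:

```python
translationModel = None      # no translation backend selected
whisperLang = None
result = {"language": "ja"}  # language code detected by Whisper

# Without the leading None check, translationModel.whisperLang below
# would raise AttributeError.
if translationModel is not None and whisperLang is None and result["language"]:
    translationModel.whisperLang = result["language"]
```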
@@ -413,7 +423,7 @@ class WhisperTranscriber:
                         out = ffmpeg.output(input_file, input_srt, output_with_srt, vcodec='copy', acodec='copy', scodec='mov_text')
                         outRsult = out.run(overwrite_output=True)
                     except Exception as e:
-
+                        print(traceback.format_exc())
                         print("Error merge subtitle with source file: \n" + source.source_path + ", \n" + str(e), outRsult)
                 elif self.app_config.save_downloaded_files and self.app_config.output_dir is not None and urlData:
                     print("Saving downloaded file [" + source.source_name + "]")
@@ -421,7 +431,7 @@ class WhisperTranscriber:
                         save_path = os.path.join(self.app_config.output_dir, filePrefix)
                         shutil.copy(source.source_path, save_path + suffix)
                     except Exception as e:
-
+                        print(traceback.format_exc())
                         print("Error saving downloaded file: \n" + source.source_path + ", \n" + str(e))

         if len(sources) > 1:
@@ -473,7 +483,7 @@ class WhisperTranscriber:
                 try:
                     os.remove(source.source_path)
                 except Exception as e:
-
+                    print(traceback.format_exc())
                     print("Error deleting temporary source file: \n" + source.source_path + ", \n" + str(e))

         except ExceededMaximumDuration as e:
@@ -619,7 +629,7 @@ class WhisperTranscriber:
     def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadOptions: VadOptions):
         # Use Silero VAD
         if (self.vad_model is None):
-            self.vad_model = VadSileroTranscription()
+            self.vad_model = VadSileroTranscription() #vad_model is snakers4/silero-vad

         config = TranscriptionConfig(non_speech_strategy = non_speech_strategy,
                                      max_silent_period=vadOptions.vadMergeWindow, max_merge_size=vadOptions.vadMaxMergeSize,
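The "if (self.vad_model is None)" check is plain lazy initialization: the Silero VAD model is constructed the first time it is needed, then cached on the instance. The same pattern in a runnable sketch (object() stands in for the real VadSileroTranscription):

```python
class Transcriber:
    def __init__(self):
        self.vad_model = None

    def _get_vad(self):
        # Construct the expensive model only on first use, then reuse it.
        if self.vad_model is None:
            self.vad_model = object()  # stand-in for VadSileroTranscription()
        return self.vad_model

t = Transcriber()
assert t._get_vad() is t._get_vad()  # built once, reused afterwards
```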
@@ -661,7 +671,6 @@ class WhisperTranscriber:

             print("\n\nprocess segments took {} seconds.\n\n".format(perf_end_time - perf_start_time))
         except Exception as e:
-            # Ignore error - it's just a cleanup
             print(traceback.format_exc())
             print("Error process segments: " + str(e))

@@ -771,8 +780,15 @@ class WhisperTranscriber:
         self.diarization = None

 def create_ui(app_config: ApplicationConfig):
+    translateModelMd: str = None
     optionsMd: str = None
     readmeMd: str = None
+    try:
+        translateModelPath = pathlib.Path("docs/translateModel.md")
+        with open(translateModelPath, "r", encoding="utf-8") as translateModelFile:
+            translateModelMd = translateModelFile.read()
+    except Exception as e:
+        print("Error occurred during read translateModel.md file: ", str(e))
     try:
         optionsPath = pathlib.Path("docs/options.md")
         with open(optionsPath, "r", encoding="utf-8") as optionsFile:
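Each markdown file is read in its own try/except, so a missing docs/translateModel.md only leaves translateModelMd as None (which later hides its accordion) instead of breaking create_ui. An equivalent compact sketch:

```python
import pathlib

translateModelMd = None
try:
    translateModelMd = pathlib.Path("docs/translateModel.md").read_text(encoding="utf-8")
except Exception as e:
    print("Error occurred during read translateModel.md file: ", str(e))

# Later: the accordion is only built when the text actually loaded.
if translateModelMd is not None:
    pass  # e.g. gr.Accordion("docs/translateModel.md") + gr.Markdown(translateModelMd)
```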
@@ -817,23 +833,16 @@ def create_ui(app_config: ApplicationConfig):
     uiDescription += "\n\n" + "Max audio file length: " + str(app_config.input_audio_max_duration) + " s"

     uiArticle = "Read the [documentation here](https://gitlab.com/aadnk/whisper-webui/-/blob/main/docs/options.md)."
-    uiArticle += "\n\nWhisper's Task 'translate' only implements the functionality of translating other languages into English. "
-    uiArticle += "OpenAI does not guarantee translations between arbitrary languages. In such cases, you can choose to use the NLLB Model to implement the translation task. "
-    uiArticle += "However, it's important to note that the NLLB Model runs slowly, and the completion time may be twice as long as usual. "
-    uiArticle += "\n\nThe larger the parameters of the NLLB model, the better its performance is expected to be. "
-    uiArticle += "However, it also requires higher computational resources, making it slower to operate. "
-    uiArticle += "On the other hand, the version converted from ct2 (CTranslate2) requires lower resources and operates at a faster speed."
-    uiArticle += "\n\nCurrently, enabling word-level timestamps cannot be used in conjunction with NLLB Model translation "
-    uiArticle += "because Word Timestamps will split the source text, and after translation, it becomes a non-word-level string. "
-    uiArticle += "\n\nThe 'mt5-zh-ja-en-trimmed' model is finetuned from Google's 'mt5-base' model. "
-    uiArticle += "This model has a relatively good translation speed, but it only supports three languages: Chinese, Japanese, and English. "

     whisper_models = app_config.get_model_names("whisper")
     nllb_models = app_config.get_model_names("nllb")
     m2m100_models = app_config.get_model_names("m2m100")
     mt5_models = app_config.get_model_names("mt5")
     ALMA_models = app_config.get_model_names("ALMA")
-
+    madlad400_models = app_config.get_model_names("madlad400")
+    if not torch.cuda.is_available(): #Due to the poor support of GPTQ for CPUs, the execution time per iteration exceeds a thousand seconds when operating on a CPU. Therefore, when the system does not support a GPU, the GPTQ model is removed from the list.
+        ALMA_models = list(filter(lambda alma: "GPTQ" not in alma, ALMA_models))
+
     common_whisper_inputs = lambda : {
         gr.Dropdown(label="Whisper - Model (for audio)", choices=whisper_models, value=app_config.default_model_name, elem_id="whisperModelName"),
         gr.Dropdown(label="Whisper - Language", choices=sorted(get_lang_whisper_names()), value=app_config.language, elem_id="whisperLangName"),
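The CPU fallback drops GPTQ checkpoints from the dropdown because, per the in-line comment, GPTQ inference without CUDA can take over a thousand seconds per iteration. Standalone sketch of the filter:

```python
ALMA_models = ["ALMA-7B-GPTQ/TheBloke", "ALMA-13B-GPTQ/TheBloke",
               "ALMA-7B-ct2:int8_float16/avan"]

cuda_available = False  # stand-in for torch.cuda.is_available()
if not cuda_available:
    ALMA_models = list(filter(lambda alma: "GPTQ" not in alma, ALMA_models))

print(ALMA_models)  # ['ALMA-7B-ct2:int8_float16/avan']
```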
@@ -852,7 +861,11 @@ def create_ui(app_config: ApplicationConfig):
     }
     common_ALMA_inputs = lambda : {
         gr.Dropdown(label="ALMA - Model (for translate)", choices=ALMA_models, elem_id="ALMAModelName"),
-        gr.Dropdown(label="ALMA - Language", choices=
+        gr.Dropdown(label="ALMA - Language", choices=sort_lang_by_whisper_codes(["en", "de", "cs", "is", "ru", "zh", "ja"]), elem_id="ALMALangName"),
+    }
+    common_madlad400_inputs = lambda : {
+        gr.Dropdown(label="madlad400 - Model (for translate)", choices=madlad400_models, elem_id="madlad400ModelName"),
+        gr.Dropdown(label="madlad400 - Language", choices=sorted(get_lang_m2m100_names()), elem_id="madlad400LangName"),
     }

     common_translation_inputs = lambda : {
@@ -865,6 +878,7 @@ def create_ui(app_config: ApplicationConfig):
         gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=app_config.default_vad, label="VAD", elem_id="vad"),
         gr.Number(label="VAD - Merge Window (s)", precision=0, value=app_config.vad_merge_window, elem_id="vadMergeWindow"),
         gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=app_config.vad_max_merge_size, elem_id="vadMaxMergeSize"),
+        gr.Number(label="VAD - Process Timeout (s)", precision=0, value=app_config.vad_process_timeout, elem_id="vadPocessTimeout"),
     }

     common_word_timestamps_inputs = lambda : {
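Note that the new field's elem_id "vadPocessTimeout" matches the key popped in the @@ -250,6 +252,7 hunk above, so the option round-trips despite the "Pocess" spelling; if the name is ever corrected, both sites must change together.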
@@ -917,10 +931,14 @@ def create_ui(app_config: ApplicationConfig):
                     with gr.Tab(label="ALMA") as simpleALMATab:
                         with gr.Row():
                             simpleInputDict.update(common_ALMA_inputs())
+                    with gr.Tab(label="madlad400") as simplemadlad400Tab:
+                        with gr.Row():
+                            simpleInputDict.update(common_madlad400_inputs())
                     simpleM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [simpleTranslateInput] )
                     simpleNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [simpleTranslateInput] )
                     simpleMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [simpleTranslateInput] )
                     simpleALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [simpleTranslateInput] )
+                    simplemadlad400Tab.select(fn=lambda: "madlad400", inputs = [], outputs= [simpleTranslateInput] )
                 with gr.Column():
                     with gr.Tab(label="URL") as simpleUrlTab:
                         simpleInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
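Each translation backend gets its own tab, and Tab.select writes the backend's name into a gr.State that the submit handler reads later. A reduced, runnable sketch of that wiring (assuming a Gradio version with the Tab select event, as the app itself uses):

```python
import gradio as gr

with gr.Blocks() as demo:
    translateInput = gr.State(value="m2m100")  # default backend

    with gr.Tab(label="M2M100") as m2m100Tab:
        gr.Markdown("M2M100 inputs go here")
    with gr.Tab(label="madlad400") as madlad400Tab:
        gr.Markdown("madlad400 inputs go here")

    # Selecting a tab records which backend the user is configuring.
    m2m100Tab.select(fn=lambda: "m2m100", inputs=[], outputs=[translateInput])
    madlad400Tab.select(fn=lambda: "madlad400", inputs=[], outputs=[translateInput])
```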
@@ -942,8 +960,10 @@ def create_ui(app_config: ApplicationConfig):
                     simpleInputDict.update(common_translation_inputs())
                 with gr.Column():
                     simpleOutput = common_output()
-
-
+                    gr.Markdown(uiArticle)
+                    if translateModelMd is not None:
+                        with gr.Accordion("docs/translateModel.md", open=False):
+                            gr.Markdown(translateModelMd)
                     if optionsMd is not None:
                         with gr.Accordion("docs/options.md", open=False):
                             gr.Markdown(optionsMd)
@@ -957,7 +977,7 @@ def create_ui(app_config: ApplicationConfig):

     fullInputDict = {}
     fullDescription = uiDescription + "\n\n\n\n" + "Be careful when changing some of the options in the full interface - this can cause the model to crash."
-
+
     with gr.Blocks() as fullTranscribe:
         fullTranslateInput = gr.State(value="m2m100", elem_id = "translateInput")
         fullSourceInput = gr.State(value="urlData", elem_id = "sourceInput")
@@ -980,10 +1000,14 @@ def create_ui(app_config: ApplicationConfig):
                     with gr.Tab(label="ALMA") as fullALMATab:
                         with gr.Row():
                             fullInputDict.update(common_ALMA_inputs())
+                    with gr.Tab(label="madlad400") as fullmadlad400Tab:
+                        with gr.Row():
+                            fullInputDict.update(common_madlad400_inputs())
                     fullM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [fullTranslateInput] )
                     fullNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [fullTranslateInput] )
                     fullMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [fullTranslateInput] )
                     fullALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [fullTranslateInput] )
+                    fullmadlad400Tab.select(fn=lambda: "madlad400", inputs = [], outputs= [fullTranslateInput] )
                 with gr.Column():
                     with gr.Tab(label="URL") as fullUrlTab:
                         fullInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
@@ -1013,7 +1037,7 @@ def create_ui(app_config: ApplicationConfig):
                         gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0, elem_id = "best_of"),
                         gr.Number(label="Beam Size - Zero temperature", value=app_config.beam_size, precision=0, elem_id = "beam_size"),
                         gr.Number(label="Patience - Zero temperature", value=app_config.patience, elem_id = "patience"),
-                        gr.Number(label="Length Penalty - Any temperature", value=app_config.length_penalty, elem_id = "length_penalty"),
+                        gr.Number(label="Length Penalty - Any temperature", value=lambda : None if app_config.length_penalty is None else app_config.length_penalty, elem_id = "length_penalty"),
                         gr.Text(label="Suppress Tokens - Comma-separated list of token IDs", value=app_config.suppress_tokens, elem_id = "suppress_tokens"),
                         gr.Checkbox(label="Condition on previous text", value=app_config.condition_on_previous_text, elem_id = "condition_on_previous_text"),
                         gr.Checkbox(label="FP16", value=app_config.fp16, elem_id = "fp16"),
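Gradio re-evaluates a callable value on every page load instead of freezing the value at component build time, so wrapping app_config.length_penalty in a lambda keeps a None default from being baked in. A sketch of the idiom:

```python
import gradio as gr

length_penalty = None  # stand-in for app_config.length_penalty

number = gr.Number(
    label="Length Penalty - Any temperature",
    # A callable default is re-evaluated on each page load.
    value=lambda: None if length_penalty is None else length_penalty,
)
```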
@@ -1054,7 +1078,7 @@ def create_ui(app_config: ApplicationConfig):
         print("Queue mode enabled (concurrency count: " + str(app_config.queue_concurrency_count) + ")")
     else:
         print("Queue mode disabled - progress bars will not be shown.")
-
+
     demo.launch(inbrowser=app_config.autolaunch, share=app_config.share, server_name=app_config.server_name, server_port=app_config.server_port)

     # Clean up
@@ -1136,6 +1160,16 @@ if __name__ == '__main__':
     # updated_config.autolaunch = True
     # updated_config.auto_parallel = False
     # updated_config.save_downloaded_files = True
+
+    try:
+        if torch.cuda.is_available():
+            deviceId = torch.cuda.current_device()
+            totalVram = torch.cuda.get_device_properties(deviceId).total_memory
+            if totalVram/(1024*1024*1024) <= 4: #VRAM <= 4 GB
+                updated_config.vad_process_timeout = 0
+    except Exception as e:
+        print(traceback.format_exc())
+        print("Error detect vram: " + str(e))

     if (threads := args.pop("threads")) > 0:
         torch.set_num_threads(threads)
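torch.cuda.get_device_properties(...).total_memory reports bytes, hence the 1024*1024*1024 divisor for the 4 GB threshold; on such cards the commit disables the VAD parallel-process timeout (0). A standalone probe (assumes torch is installed):

```python
import torch

if torch.cuda.is_available():
    deviceId = torch.cuda.current_device()
    totalVram = torch.cuda.get_device_properties(deviceId).total_memory  # bytes
    print("GPU VRAM: %.1f GB" % (totalVram / (1024 ** 3)))
else:
    print("No CUDA device; VRAM check skipped")
```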
config.json5 CHANGED
@@ -23,6 +23,10 @@
             "name": "large",
             "url": "large"
         },
+        {
+            "name": "large-v1",
+            "url": "large-v1"
+        },
         {
             "name": "large-v2",
             "url": "large-v2"
@@ -189,10 +193,55 @@
         }
     ],
     "ALMA": [
+        {
+            "name": "ALMA-7B-GPTQ/TheBloke",
+            "url": "TheBloke/ALMA-7B-GPTQ",
+            "type": "huggingface"
+        },
         {
             "name": "ALMA-13B-GPTQ/TheBloke",
             "url": "TheBloke/ALMA-13B-GPTQ",
-            "type": "huggingface"
+            "type": "huggingface"
+        },
+        {
+            "name": "ALMA-7B-GGUF-Q4_K_M/TheBloke",
+            "url": "TheBloke/ALMA-7B-GGUF",
+            "type": "huggingface",
+            "model_file": "alma-7b.Q4_K_M.gguf",
+            "tokenizer_url": "haoranxu/ALMA-7B"
+        },
+        {
+            "name": "ALMA-13B-GGUF-Q4_K_M/TheBloke",
+            "url": "TheBloke/ALMA-13B-GGUF",
+            "type": "huggingface",
+            "model_file": "alma-13b.Q4_K_M.gguf",
+            "tokenizer_url": "haoranxu/ALMA-13B"
+        },
+        {
+            "name": "ALMA-7B-ct2:int8_float16/avan",
+            "url": "avans06/ALMA-7B-ct2-int8_float16",
+            "type": "huggingface",
+            "tokenizer_url": "haoranxu/ALMA-7B"
+        },
+        {
+            "name": "ALMA-13B-ct2:int8_float16/avan",
+            "url": "avans06/ALMA-13B-ct2-int8_float16",
+            "type": "huggingface",
+            "tokenizer_url": "haoranxu/ALMA-13B"
+        },
+    ],
+    "madlad400": [
+        {
+            "name": "madlad400-3b-mt-ct2-int8_float16/SoybeanMilk",
+            "url": "SoybeanMilk/madlad400-3b-mt-ct2-int8_float16",
+            "type": "huggingface",
+            "tokenizer_url": "jbochi/madlad400-3b-mt"
+        },
+        {
+            "name": "madlad400-10b-mt-ct2-int8_float16/SoybeanMilk",
+            "url": "SoybeanMilk/madlad400-10b-mt-ct2-int8_float16",
+            "type": "huggingface",
+            "tokenizer_url": "jbochi/madlad400-10b-mt"
         },
     ]
 },
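A hedged sketch of how entries like these can be read back: config.json5 is JSON5, so a JSON5 parser such as the json5 package is assumed, and the "models" top-level key follows the app_config.models usage seen in app.py; the loader below is illustrative, not the app's actual code.

```python
import json5  # pip install json5

with open("config.json5", "r", encoding="utf-8") as f:
    config = json5.load(f)

for entry in config["models"]["madlad400"]:
    # "url" is the Hugging Face repo; "tokenizer_url" points at the original
    # repo whose tokenizer the ct2 conversion reuses.
    print(entry["name"], "->", entry["url"], "| tokenizer:", entry.get("tokenizer_url"))
```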