Spaces:
Sleeping
Sleeping
Commit
·
e8762f9
1
Parent(s):
50167d4
Add support for the ALMA model.
Browse files
- app.py +19 -0
- config.json5 +7 -0
- src/config.py +2 -2
- src/translation/translationModel.py +18 -1
- src/utils.py +1 -1
app.py
CHANGED
|
@@ -231,6 +231,8 @@ class WhisperTranscriber:
|
|
| 231 |
nllbLangName: str = decodeOptions.pop("nllbLangName")
|
| 232 |
mt5ModelName: str = decodeOptions.pop("mt5ModelName")
|
| 233 |
mt5LangName: str = decodeOptions.pop("mt5LangName")
|
|
|
|
|
|
|
| 234 |
|
| 235 |
translationBatchSize: int = decodeOptions.pop("translationBatchSize")
|
| 236 |
translationNoRepeatNgramSize: int = decodeOptions.pop("translationNoRepeatNgramSize")
|
|
@@ -337,6 +339,10 @@ class WhisperTranscriber:
|
|
| 337 |
selectedModelName = mt5ModelName if mt5ModelName is not None and len(mt5ModelName) > 0 else "mt5-zh-ja-en-trimmed/K024"
|
| 338 |
selectedModel = next((modelConfig for modelConfig in self.app_config.models["mt5"] if modelConfig.name == selectedModelName), None)
|
| 339 |
translationLang = get_lang_from_m2m100_name(mt5LangName)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
|
| 341 |
if translationLang is not None:
|
| 342 |
translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=whisperLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams)
|
|
@@ -828,6 +834,7 @@ def create_ui(app_config: ApplicationConfig):
|
|
| 828 |
nllb_models = app_config.get_model_names("nllb")
|
| 829 |
m2m100_models = app_config.get_model_names("m2m100")
|
| 830 |
mt5_models = app_config.get_model_names("mt5")
|
|
|
|
| 831 |
|
| 832 |
common_whisper_inputs = lambda : {
|
| 833 |
gr.Dropdown(label="Whisper - Model (for audio)", choices=whisper_models, value=app_config.default_model_name, elem_id="whisperModelName"),
|
|
@@ -845,6 +852,10 @@ def create_ui(app_config: ApplicationConfig):
|
|
| 845 |
gr.Dropdown(label="MT5 - Model (for translate)", choices=mt5_models, elem_id="mt5ModelName"),
|
| 846 |
gr.Dropdown(label="MT5 - Language", choices=sorted(get_lang_m2m100_names(["en", "ja", "zh"])), elem_id="mt5LangName"),
|
| 847 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 848 |
|
| 849 |
common_translation_inputs = lambda : {
|
| 850 |
gr.Number(label="Translation - Batch Size", precision=0, value=app_config.translation_batch_size, elem_id="translationBatchSize"),
|
|
@@ -905,9 +916,13 @@ def create_ui(app_config: ApplicationConfig):
|
|
| 905 |
with gr.Tab(label="MT5") as simpleMT5Tab:
|
| 906 |
with gr.Row():
|
| 907 |
simpleInputDict.update(common_mt5_inputs())
|
|
|
|
|
|
|
|
|
|
| 908 |
simpleM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [simpleTranslateInput] )
|
| 909 |
simpleNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [simpleTranslateInput] )
|
| 910 |
simpleMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [simpleTranslateInput] )
|
|
|
|
| 911 |
with gr.Column():
|
| 912 |
with gr.Tab(label="URL") as simpleUrlTab:
|
| 913 |
simpleInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
|
|
@@ -964,9 +979,13 @@ def create_ui(app_config: ApplicationConfig):
|
|
| 964 |
with gr.Tab(label="MT5") as fullMT5Tab:
|
| 965 |
with gr.Row():
|
| 966 |
fullInputDict.update(common_mt5_inputs())
|
|
|
|
|
|
|
|
|
|
| 967 |
fullM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [fullTranslateInput] )
|
| 968 |
fullNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [fullTranslateInput] )
|
| 969 |
fullMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [fullTranslateInput] )
|
|
|
|
| 970 |
with gr.Column():
|
| 971 |
with gr.Tab(label="URL") as fullUrlTab:
|
| 972 |
fullInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
|
|
|
|
| 231 |
nllbLangName: str = decodeOptions.pop("nllbLangName")
|
| 232 |
mt5ModelName: str = decodeOptions.pop("mt5ModelName")
|
| 233 |
mt5LangName: str = decodeOptions.pop("mt5LangName")
|
| 234 |
+
ALMAModelName: str = decodeOptions.pop("ALMAModelName")
|
| 235 |
+
ALMALangName: str = decodeOptions.pop("ALMALangName")
|
| 236 |
|
| 237 |
translationBatchSize: int = decodeOptions.pop("translationBatchSize")
|
| 238 |
translationNoRepeatNgramSize: int = decodeOptions.pop("translationNoRepeatNgramSize")
|
|
|
|
| 339 |
selectedModelName = mt5ModelName if mt5ModelName is not None and len(mt5ModelName) > 0 else "mt5-zh-ja-en-trimmed/K024"
|
| 340 |
selectedModel = next((modelConfig for modelConfig in self.app_config.models["mt5"] if modelConfig.name == selectedModelName), None)
|
| 341 |
translationLang = get_lang_from_m2m100_name(mt5LangName)
|
| 342 |
+
elif translateInput == "ALMA" and ALMALangName is not None and len(ALMALangName) > 0:
|
| 343 |
+
selectedModelName = ALMAModelName if ALMAModelName is not None and len(ALMAModelName) > 0 else "ALMA-13B-GPTQ/TheBloke"
|
| 344 |
+
selectedModel = next((modelConfig for modelConfig in self.app_config.models["ALMA"] if modelConfig.name == selectedModelName), None)
|
| 345 |
+
translationLang = get_lang_from_m2m100_name(ALMALangName)
|
| 346 |
|
| 347 |
if translationLang is not None:
|
| 348 |
translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=whisperLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams)
|
|
|
|
| 834 |
nllb_models = app_config.get_model_names("nllb")
|
| 835 |
m2m100_models = app_config.get_model_names("m2m100")
|
| 836 |
mt5_models = app_config.get_model_names("mt5")
|
| 837 |
+
ALMA_models = app_config.get_model_names("ALMA")
|
| 838 |
|
| 839 |
common_whisper_inputs = lambda : {
|
| 840 |
gr.Dropdown(label="Whisper - Model (for audio)", choices=whisper_models, value=app_config.default_model_name, elem_id="whisperModelName"),
|
|
|
|
| 852 |
gr.Dropdown(label="MT5 - Model (for translate)", choices=mt5_models, elem_id="mt5ModelName"),
|
| 853 |
gr.Dropdown(label="MT5 - Language", choices=sorted(get_lang_m2m100_names(["en", "ja", "zh"])), elem_id="mt5LangName"),
|
| 854 |
}
|
| 855 |
+
common_ALMA_inputs = lambda : {
|
| 856 |
+
gr.Dropdown(label="ALMA - Model (for translate)", choices=ALMA_models, elem_id="ALMAModelName"),
|
| 857 |
+
gr.Dropdown(label="ALMA - Language", choices=sorted(get_lang_m2m100_names(["en", "ja", "zh"])), elem_id="ALMALangName"),
|
| 858 |
+
}
|
| 859 |
|
| 860 |
common_translation_inputs = lambda : {
|
| 861 |
gr.Number(label="Translation - Batch Size", precision=0, value=app_config.translation_batch_size, elem_id="translationBatchSize"),
|
|
|
|
| 916 |
with gr.Tab(label="MT5") as simpleMT5Tab:
|
| 917 |
with gr.Row():
|
| 918 |
simpleInputDict.update(common_mt5_inputs())
|
| 919 |
+
with gr.Tab(label="ALMA") as simpleALMATab:
|
| 920 |
+
with gr.Row():
|
| 921 |
+
simpleInputDict.update(common_ALMA_inputs())
|
| 922 |
simpleM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [simpleTranslateInput] )
|
| 923 |
simpleNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [simpleTranslateInput] )
|
| 924 |
simpleMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [simpleTranslateInput] )
|
| 925 |
+
simpleALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [simpleTranslateInput] )
|
| 926 |
with gr.Column():
|
| 927 |
with gr.Tab(label="URL") as simpleUrlTab:
|
| 928 |
simpleInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
|
|
|
|
| 979 |
with gr.Tab(label="MT5") as fullMT5Tab:
|
| 980 |
with gr.Row():
|
| 981 |
fullInputDict.update(common_mt5_inputs())
|
| 982 |
+
with gr.Tab(label="ALMA") as fullALMATab:
|
| 983 |
+
with gr.Row():
|
| 984 |
+
fullInputDict.update(common_ALMA_inputs())
|
| 985 |
fullM2M100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [fullTranslateInput] )
|
| 986 |
fullNllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [fullTranslateInput] )
|
| 987 |
fullMT5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [fullTranslateInput] )
|
| 988 |
+
fullALMATab.select(fn=lambda: "ALMA", inputs = [], outputs= [fullTranslateInput] )
|
| 989 |
with gr.Column():
|
| 990 |
with gr.Tab(label="URL") as fullUrlTab:
|
| 991 |
fullInputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
|
config.json5
CHANGED
|
@@ -191,6 +191,13 @@
|
|
| 191 |
"url": "engmatic-earth/mt5-zh-ja-en-trimmed-fine-tuned-v1",
|
| 192 |
"type": "huggingface"
|
| 193 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
]
|
| 195 |
},
|
| 196 |
// Configuration options that will be used if they are not specified in the command line arguments.
|
|
|
|
| 191 |
"url": "engmatic-earth/mt5-zh-ja-en-trimmed-fine-tuned-v1",
|
| 192 |
"type": "huggingface"
|
| 193 |
}
|
| 194 |
+
],
|
| 195 |
+
"ALMA": [
|
| 196 |
+
{
|
| 197 |
+
"name": "ALMA-13B-GPTQ/TheBloke",
|
| 198 |
+
"url": "TheBloke/ALMA-13B-GPTQ",
|
| 199 |
+
"type": "huggingface",
|
| 200 |
+
},
|
| 201 |
]
|
| 202 |
},
|
| 203 |
// Configuration options that will be used if they are not specified in the command line arguments.
|
src/config.py
CHANGED
|
@@ -43,7 +43,7 @@ class VadInitialPromptMode(Enum):
|
|
| 43 |
return None
|
| 44 |
|
| 45 |
class ApplicationConfig:
|
| 46 |
-
def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5"], List[ModelConfig]],
|
| 47 |
input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
|
| 48 |
queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
|
| 49 |
whisper_implementation: str = "whisper", default_model_name: str = "medium",
|
|
@@ -169,7 +169,7 @@ class ApplicationConfig:
|
|
| 169 |
# Load using json5
|
| 170 |
data = json5.load(f)
|
| 171 |
data_models = data.pop("models", [])
|
| 172 |
-
models: Dict[Literal["whisper", "m2m100", "nllb", "mt5"], List[ModelConfig]] = {
|
| 173 |
key: [ModelConfig(**item) for item in value]
|
| 174 |
for key, value in data_models.items()
|
| 175 |
}
|
|
|
|
| 43 |
return None
|
| 44 |
|
| 45 |
class ApplicationConfig:
|
| 46 |
+
def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA"], List[ModelConfig]],
|
| 47 |
input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
|
| 48 |
queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
|
| 49 |
whisper_implementation: str = "whisper", default_model_name: str = "medium",
|
|
|
|
| 169 |
# Load using json5
|
| 170 |
data = json5.load(f)
|
| 171 |
data_models = data.pop("models", [])
|
| 172 |
+
models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA"], List[ModelConfig]] = {
|
| 173 |
key: [ModelConfig(**item) for item in value]
|
| 174 |
for key, value in data_models.items()
|
| 175 |
}
|
src/translation/translationModel.py
CHANGED
|
@@ -7,6 +7,8 @@ import torch
|
|
| 7 |
import ctranslate2
|
| 8 |
import transformers
|
| 9 |
|
|
|
|
|
|
|
| 10 |
from typing import Optional
|
| 11 |
from src.config import ModelConfig
|
| 12 |
from src.translation.translationLangs import TranslationLang, get_lang_from_whisper_code
|
|
@@ -97,6 +99,11 @@ class TranslationModel:
|
|
| 97 |
self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
|
| 98 |
self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath)
|
| 99 |
self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
else:
|
| 101 |
self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
|
| 102 |
self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
|
|
@@ -130,6 +137,12 @@ class TranslationModel:
|
|
| 130 |
elif "mt5" in self.modelPath:
|
| 131 |
output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
|
| 132 |
result = output[0]['generated_text']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
else: #M2M100 & NLLB
|
| 134 |
output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
|
| 135 |
result = output[0]['translation_text']
|
|
@@ -148,7 +161,8 @@ _MODELS = ["distilled-600M", "distilled-1.3B", "1.3B", "3.3B",
|
|
| 148 |
"m2m100_1.2B-ct2", "m2m100_418M-ct2", "m2m100-12B-ct2",
|
| 149 |
"m2m100_1.2B", "m2m100_418M",
|
| 150 |
"mt5-zh-ja-en-trimmed",
|
| 151 |
-
"mt5-zh-ja-en-trimmed-fine-tuned-v1"
|
|
|
|
| 152 |
|
| 153 |
def check_model_name(name):
|
| 154 |
return any(allowed_name in name for allowed_name in _MODELS)
|
|
@@ -206,6 +220,9 @@ def download_model(
|
|
| 206 |
"special_tokens_map.json",
|
| 207 |
"spiece.model",
|
| 208 |
"vocab.json", #m2m100
|
|
|
|
|
|
|
|
|
|
| 209 |
]
|
| 210 |
|
| 211 |
kwargs = {
|
|
|
|
| 7 |
import ctranslate2
|
| 8 |
import transformers
|
| 9 |
|
| 10 |
+
import re
|
| 11 |
+
|
| 12 |
from typing import Optional
|
| 13 |
from src.config import ModelConfig
|
| 14 |
from src.translation.translationLangs import TranslationLang, get_lang_from_whisper_code
|
|
|
|
| 99 |
self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
|
| 100 |
self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath)
|
| 101 |
self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
|
| 102 |
+
elif "ALMA" in self.modelPath:
|
| 103 |
+
self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.code + " to " + self.translationLang.whisper.code + ":" + self.whisperLang.whisper.code + ":"
|
| 104 |
+
self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath, use_fast=True)
|
| 105 |
+
self.transModel = transformers.AutoModelForCausalLM.from_pretrained(self.modelPath, device_map="auto", trust_remote_code=False, revision="main")
|
| 106 |
+
self.transTranslator = transformers.pipeline("text-generation", model=self.transModel, tokenizer=self.transTokenizer, batch_size=2, do_sample=True, temperature=0.7, top_p=0.95, top_k=40, repetition_penalty=1.1)
|
| 107 |
else:
|
| 108 |
self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
|
| 109 |
self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
|
|
|
|
| 137 |
elif "mt5" in self.modelPath:
|
| 138 |
output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
|
| 139 |
result = output[0]['generated_text']
|
| 140 |
+
elif "ALMA" in self.modelPath:
|
| 141 |
+
output = self.transTranslator(self.ALMAPrefix + text + self.translationLang.whisper.code + ":", max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
|
| 142 |
+
result = output[0]['generated_text']
|
| 143 |
+
result = re.sub(rf'^(.*{self.translationLang.whisper.code}: )', '', result) # Remove the prompt from the result
|
| 144 |
+
result = re.sub(rf'^(Translate this from .* to .*:)', '', result) # Remove the translation instruction
|
| 145 |
+
return result.strip()
|
| 146 |
else: #M2M100 & NLLB
|
| 147 |
output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
|
| 148 |
result = output[0]['translation_text']
|
|
|
|
| 161 |
"m2m100_1.2B-ct2", "m2m100_418M-ct2", "m2m100-12B-ct2",
|
| 162 |
"m2m100_1.2B", "m2m100_418M",
|
| 163 |
"mt5-zh-ja-en-trimmed",
|
| 164 |
+
"mt5-zh-ja-en-trimmed-fine-tuned-v1",
|
| 165 |
+
"ALMA-13B-GPTQ"]
|
| 166 |
|
| 167 |
def check_model_name(name):
|
| 168 |
return any(allowed_name in name for allowed_name in _MODELS)
|
|
|
|
| 220 |
"special_tokens_map.json",
|
| 221 |
"spiece.model",
|
| 222 |
"vocab.json", #m2m100
|
| 223 |
+
"model.safetensors",
|
| 224 |
+
"quantize_config.json",
|
| 225 |
+
"tokenizer.model"
|
| 226 |
]
|
| 227 |
|
| 228 |
kwargs = {
|
src/utils.py
CHANGED
|
@@ -130,7 +130,7 @@ def write_srt_original(transcript: Iterator[dict], file: TextIO,
|
|
| 130 |
flush=True,
|
| 131 |
)
|
| 132 |
|
| 133 |
-
if original is not None: print(f"{original}",
|
| 134 |
file=file,
|
| 135 |
flush=True)
|
| 136 |
|
|
|
|
| 130 |
flush=True,
|
| 131 |
)
|
| 132 |
|
| 133 |
+
if original is not None: print(f"{original}\n",
|
| 134 |
file=file,
|
| 135 |
flush=True)
|
| 136 |
|