Spaces
SoybeanMilk committed · Commit 73779c4 · 1 Parent(s): e57c738
Upload translationModel.py
src/translation/translationModel.py
CHANGED
@@ -7,9 +7,12 @@ import torch
 import ctranslate2
 import transformers
 
+import re
+
 from typing import Optional
 from src.config import ModelConfig
 from src.translation.translationLangs import TranslationLang, get_lang_from_whisper_code
+from peft import PeftModel
 
 class TranslationModel:
     def __init__(
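The hunk above adds two imports: `re` supports the prompt clean-up in `translate()` further down, while `PeftModel` is imported but not exercised in any hunk shown in this commit. For context, a minimal sketch of how a PEFT adapter is usually attached to a base model (illustrative only; the model id and adapter path below are hypothetical placeholders, not values from this repo):

import transformers
from peft import PeftModel

# Hypothetical ids for illustration; this commit only adds the import.
base = transformers.AutoModelForCausalLM.from_pretrained("base-model-id")
model = PeftModel.from_pretrained(base, "path/to/lora-adapter")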
@@ -92,11 +95,17 @@ class TranslationModel:
             self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.m2m100.code)
             self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
             self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
+
         elif "mt5" in self.modelPath:
             self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
             self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
             self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath)
             self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
+        elif "ALMA" in self.modelPath:
+            self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.code + " to " + self.translationLang.whisper.code + ":" + self.whisperLang.whisper.code + ":"
+            self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath, use_fast=True)
+            self.transModel = transformers.AutoModelForCausalLM.from_pretrained(self.modelPath, device_map="auto", trust_remote_code=False, revision="main")
+            self.transTranslator = transformers.pipeline("text-generation", model=self.transModel, tokenizer=self.transTokenizer, batch_size=2, do_sample=True, temperature=0.7, top_p=0.95, top_k=40, repetition_penalty=1.1)
         else:
             self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
             self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
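Unlike the CTranslate2 and seq2seq branches, the new ALMA branch treats translation as plain text generation: a causal LM is loaded with `AutoModelForCausalLM` and driven through a `text-generation` pipeline with a string prompt. A sketch of the prompt this code assembles, using example Whisper codes (upstream ALMA examples format the prompt with full language names and newlines; this commit concatenates the short codes directly):

# Example values only; the real codes come from whisperLang/translationLang.
src_code, tgt_code = "zh", "en"
text = "你好世界"
prefix = "Translate this from " + src_code + " to " + tgt_code + ":" + src_code + ":"
prompt = prefix + text + tgt_code + ":"
print(prompt)  # Translate this from zh to en:zh:你好世界en: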
@@ -130,6 +139,12 @@
         elif "mt5" in self.modelPath:
             output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
             result = output[0]['generated_text']
+        elif "ALMA" in self.modelPath:
+            output = self.transTranslator(self.ALMAPrefix + text + self.translationLang.whisper.code + ":", max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
+            result = output[0]['generated_text']
+            result = re.sub(rf'^(.*{self.translationLang.whisper.code}: )', '', result) # Remove the prompt from the result
+            result = re.sub(rf'^(Translate this from .* to .*:)', '', result) # Remove the translation instruction
+            return result.strip()
         else: #M2M100 & NLLB
             output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
             result = output[0]['translation_text']
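A `text-generation` pipeline returns the prompt together with the continuation, so the two `re.sub` calls above strip the echoed prompt before the result is returned. A self-contained sketch of the same clean-up on example strings:

import re

tgt_code = "en"  # example Whisper code
generated = "Translate this from zh to en:zh:你好世界en: Hello world"
result = re.sub(rf'^(.*{tgt_code}: )', '', generated)             # drop everything through "en: "
result = re.sub(r'^(Translate this from .* to .*:)', '', result)  # fallback if the first pattern missed
print(result.strip())  # Hello world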
@@ -148,7 +163,8 @@ _MODELS = ["distilled-600M", "distilled-1.3B", "1.3B", "3.3B",
            "m2m100_1.2B-ct2", "m2m100_418M-ct2", "m2m100-12B-ct2",
            "m2m100_1.2B", "m2m100_418M",
            "mt5-zh-ja-en-trimmed",
-           "mt5-zh-ja-en-trimmed-fine-tuned-v1"]
+           "mt5-zh-ja-en-trimmed-fine-tuned-v1",
+           "ALMA-13B-GPTQ"]
 
 def check_model_name(name):
     return any(allowed_name in name for allowed_name in _MODELS)
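`check_model_name` does substring matching against the whitelist, so any repo id that contains an allowed name passes (example ids for illustration):

check_model_name("TheBloke/ALMA-13B-GPTQ")                  # True: contains "ALMA-13B-GPTQ"
check_model_name("my-org/mt5-zh-ja-en-trimmed-fine-tuned-v1")  # True (hypothetical repo id)
check_model_name("gpt2")                                    # False: no whitelisted substring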
@@ -206,6 +222,9 @@ def download_model(
         "special_tokens_map.json",
         "spiece.model",
         "vocab.json", #m2m100
+        "model.safetensors",
+        "quantize_config.json",
+        "tokenizer.model"
     ]
 
     kwargs = {
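The three new filenames cover GPTQ checkpoints: `model.safetensors` holds the quantized weights, `quantize_config.json` the GPTQ settings, and `tokenizer.model` the sentencepiece model used by LLaMA-family tokenizers. Assuming the surrounding `download_model` passes this list as allow patterns to `huggingface_hub.snapshot_download` (the list's consumer is outside this hunk), the effect would be:

# Hedged sketch: the repo id is an example and the allow_patterns usage is assumed from context.
import huggingface_hub

allow_patterns = ["config.json", "model.safetensors", "quantize_config.json", "tokenizer.model"]
path = huggingface_hub.snapshot_download("TheBloke/ALMA-13B-GPTQ", allow_patterns=allow_patterns)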
|