SoybeanMilk committed on
Commit 4dcbce8
1 Parent(s): 7282235

Add madlad400 support.

Files changed (1)
  1. src/translation/translationModel.py +204 -55
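
The patch routes MADLAD-400 through CTranslate2 and selects the target language by prepending a "<2xx>" tag (for example "<2ja> " for Japanese) to the source text, mirroring the self.madlad400Prefix it introduces below. A minimal standalone sketch of that flow (not the project code; the model directory is a placeholder for a CTranslate2-converted MADLAD-400 checkpoint that also contains its tokenizer files):

import ctranslate2
import transformers

# Placeholder path: a CTranslate2-converted MADLAD-400 checkpoint whose
# directory also holds the SentencePiece tokenizer files.
modelDir = "/path/to/madlad400-ct2"

transTokenizer = transformers.AutoTokenizer.from_pretrained(modelDir)
transModel = ctranslate2.Translator(modelDir, compute_type="auto", device="cpu")

text = "Hello, how are you?"
# MADLAD-400 picks the target language via a "<2xx>" prefix token.
source = transTokenizer.convert_ids_to_tokens(transTokenizer.encode("<2ja> " + text))
output = transModel.translate_batch([source], beam_size=3, no_repeat_ngram_size=3)
target = output[0].hypotheses[0]
print(transTokenizer.decode(transTokenizer.convert_tokens_to_ids(target)))
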
src/translation/translationModel.py CHANGED
@@ -3,17 +3,14 @@ import warnings
 import huggingface_hub
 import requests
 import torch
-
 import ctranslate2
 import transformers
-
-import re
+import traceback

 from typing import Optional
 from src.config import ModelConfig
 from src.translation.translationLangs import TranslationLang, get_lang_from_whisper_code

-
 class TranslationModel:
     def __init__(
         self,
@@ -68,7 +65,7 @@ class TranslationModel:
         if os.path.isdir(modelConfig.url):
             self.modelPath = modelConfig.url
         else:
-            self.modelPath = download_model(
+            self.modelPath = modelConfig.url if getattr(modelConfig, "model_file", None) is not None else download_model(
                 modelConfig,
                 localFilesOnly=localFilesOnly,
                 cacheDir=downloadRoot,
@@ -86,85 +83,233 @@ class TranslationModel:
         self.load_model()

     def load_model(self):
-        print('\n\nLoading model: %s\n\n' % self.modelPath)
-        if "ct2" in self.modelPath:
-            if "nllb" in self.modelPath:
-                self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.nllb.code)
-                self.targetPrefix = [self.translationLang.nllb.code]
-            elif "m2m100" in self.modelPath:
-                self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.m2m100.code)
-                self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
-            self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
+        """
+        [from_pretrained]
+        low_cpu_mem_usage (bool, optional)
+            Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. This is an experimental feature and subject to change at any moment.
+
+        [transformers.AutoTokenizer.from_pretrained]
+        use_fast (bool, optional, defaults to True):
+            Use a fast Rust-based tokenizer if it is supported for a given model.
+            If a fast tokenizer is not available for a given model, a normal Python-based tokenizer is returned instead.
+
+        [transformers.AutoModelForCausalLM.from_pretrained]
+        device_map (str or Dict[str, Union[int, str, torch.device]], optional):
+            Sent directly as model_kwargs (just a simpler shortcut). When the accelerate library is present,
+            set device_map="auto" to compute the most optimized device_map automatically.
+        revision (str, optional, defaults to "main"):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id,
+            since we use a git-based system for storing models and other artifacts on huggingface.co,
+            so revision can be any identifier allowed by git.
+        code_revision (str, optional, defaults to "main")
+            The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model.
+            It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co,
+            so revision can be any identifier allowed by git.
+        trust_remote_code (bool, optional, defaults to False):
+            Whether or not to allow for custom models defined on the Hub in their own modeling files.
+            This option should only be set to True for repositories you trust and in which you have read the code,
+            as it will execute code present on the Hub on your local machine.
+
+        [transformers.pipeline "text-generation"]
+        do_sample:
+            if set to True, this parameter enables decoding strategies such as multinomial sampling,
+            beam-search multinomial sampling, Top-K sampling and Top-p sampling.
+            All these strategies select the next token from the probability distribution
+            over the entire vocabulary with various strategy-specific adjustments.
+        temperature (float, optional, defaults to 1.0):
+            The value used to modulate the next token probabilities.
+        top_k (int, optional, defaults to 50):
+            The number of highest probability vocabulary tokens to keep for top-k-filtering.
+        top_p (float, optional, defaults to 1.0):
+            If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+        repetition_penalty (float, optional, defaults to 1.0)
+            The parameter for repetition penalty. 1.0 means no penalty. See this paper for more details.
+
+        [transformers.GPTQConfig]
+        use_exllama (bool, optional):
+            Whether to use the exllama backend. Defaults to True if unset. Only works with bits = 4.
+
+        [ExLlama]
+        ExLlama is a Python/C++/CUDA implementation of the Llama model that is designed for faster inference with 4-bit GPTQ weights (check out these benchmarks).
+        The ExLlama kernel is activated by default when you create a [GPTQConfig] object.
+        To boost inference speed even further, use the ExLlamaV2 kernels by configuring the exllama_config parameter.
+        The ExLlama kernels are only supported when the entire model is on the GPU.
+        If you're doing inference on a CPU with AutoGPTQ (version > 0.4.2), then you'll need to disable the ExLlama kernel.
+        This overwrites the attributes related to the ExLlama kernels in the quantization config of the config.json file.
+        https://github.com/huggingface/transformers/blob/main/docs/source/en/quantization.md#exllama
+
+        [ctransformers]
+        gpu_layers
+            means the number of layers to run on GPU. Depending on how much GPU memory is available you can increase gpu_layers. Start with a larger value gpu_layers=100 and if it runs out of memory, try smaller values.
+            To run some of the model layers on GPU, set the `gpu_layers` parameter.
+            https://github.com/marella/ctransformers/issues/68
+        """
+        try:
+            print('\n\nLoading model: %s\n\n' % self.modelPath)
+            if "ct2" in self.modelPath:
+                if any(name in self.modelPath for name in ["nllb", "m2m100"]):
+                    if "nllb" in self.modelPath:
+                        self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.nllb.code)
+                        self.targetPrefix = [self.translationLang.nllb.code]
+                    elif "m2m100" in self.modelPath:
+                        self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.m2m100.code)
+                        self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
+                    self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
+                elif "ALMA" in self.modelPath:
+                    self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath)
+                    self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.names[0] + " to " + self.translationLang.whisper.names[0] + ":\n" + self.whisperLang.whisper.names[0] + ": "
+                    self.transModel = ctranslate2.Generator(self.modelPath, compute_type="auto", device=self.device)
+                elif "madlad400" in self.modelPath:
+                    self.madlad400Prefix = "<2" + self.translationLang.whisper.code + "> "
+                    self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.m2m100.code)
+                    self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
+            elif "mt5" in self.modelPath:
+                self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
+                self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
+                self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath, low_cpu_mem_usage=True)
+                self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
+            elif "ALMA" in self.modelPath:
+                self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.names[0] + " to " + self.translationLang.whisper.names[0] + ":\n" + self.whisperLang.whisper.names[0] + ": "
+                if "GPTQ" in self.modelPath:
+                    self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath, use_fast=True)
+                    if self.device == "cpu":
+                        # GPTQ has poor CPU support, so running it on the CPU is strongly discouraged.
+                        # Set torch_dtype=torch.float32 to prevent the exception "addmm_impl_cpu_ not implemented for 'Half'."
+                        transModelConfig = transformers.AutoConfig.from_pretrained(self.modelPath)
+                        transModelConfig.quantization_config["use_exllama"] = False
+                        self.transModel = transformers.AutoModelForCausalLM.from_pretrained(self.modelPath, device_map="auto", low_cpu_mem_usage=True, trust_remote_code=False, revision=self.modelConfig.revision, config=transModelConfig, torch_dtype=torch.float32)
+                    else:
+                        # transModelConfig.quantization_config["exllama_config"] = {"version":2} # After configuring to use ExLlamaV2, VRAM cannot be effectively released, which may be an issue. Temporarily not adopting the V2 version.
+                        self.transModel = transformers.AutoModelForCausalLM.from_pretrained(self.modelPath, device_map="auto", low_cpu_mem_usage=True, trust_remote_code=False, revision=self.modelConfig.revision)
+                elif "GGUF" in self.modelPath:
+                    import ctransformers
+                    self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url)
+                    if self.device == "cpu":
+                        self.transModel = ctransformers.AutoModelForCausalLM.from_pretrained(self.modelPath, hf=True, model_file=self.modelConfig.model_file)
+                    else:
+                        self.transModel = ctransformers.AutoModelForCausalLM.from_pretrained(self.modelPath, hf=True, model_file=self.modelConfig.model_file, gpu_layers=50)
+                self.transTranslator = transformers.pipeline("text-generation", model=self.transModel, tokenizer=self.transTokenizer, do_sample=True, temperature=0.7, top_k=40, top_p=0.95, repetition_penalty=1.1)
+            else:
+                self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
+                self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
+                if "m2m100" in self.modelPath:
+                    self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.m2m100.code, tgt_lang=self.translationLang.m2m100.code)
+                else: #NLLB
+                    self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.nllb.code, tgt_lang=self.translationLang.nllb.code)

-        elif "mt5" in self.modelPath:
-            self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
-            self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
-            self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath)
-            self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
-        elif "ALMA" in self.modelPath:
-            self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.code + " to " + self.translationLang.whisper.code + ":" + self.whisperLang.whisper.code + ":"
-            self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath, use_fast=True)
-            self.transModel = transformers.AutoModelForCausalLM.from_pretrained(self.modelPath, device_map="auto", trust_remote_code=False, revision="main")
-            self.transTranslator = transformers.pipeline("text-generation", model=self.transModel, tokenizer=self.transTokenizer, batch_size=2, do_sample=True, temperature=0.7, top_p=0.95, top_k=40, repetition_penalty=1.1)
-        else:
-            self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
-            self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
-            if "m2m100" in self.modelPath:
-                self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.m2m100.code, tgt_lang=self.translationLang.m2m100.code)
-            else: #NLLB
-                self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.nllb.code, tgt_lang=self.translationLang.nllb.code)
+        except Exception as e:
+            self.release_vram()
+            raise e
+

     def release_vram(self):
         try:
             if torch.cuda.is_available():
                 if "ct2" not in self.modelPath:
-                    device = torch.device("cpu")
-                    self.transModel.to(device)
-                    del self.transModel
-                    torch.cuda.empty_cache()
+                    try:
+                        if getattr(self, "transModel", None) is not None:
+                            device = torch.device("cpu")
+                            self.transModel.to(device)
+                    except Exception as e:
+                        print(traceback.format_exc())
+                        print("\tself.transModel.to cpu, error: " + str(e))
+                    if getattr(self, "transTranslator", None) is not None:
+                        del self.transTranslator
+                if "ct2" in self.modelPath:
+                    if getattr(self, "transModel", None) is not None and getattr(self.transModel, "unload_model", None) is not None:
+                        self.transModel.unload_model()
+
+                if getattr(self, "transTokenizer", None) is not None:
+                    del self.transTokenizer
+                if getattr(self, "transModel", None) is not None:
+                    del self.transModel
+                try:
+                    torch.cuda.empty_cache()
+                except Exception as e:
+                    print(traceback.format_exc())
+                    print("\tcuda empty cache, error: " + str(e))
+                import gc
+                gc.collect()
             print("release vram end.")
         except Exception as e:
+            print(traceback.format_exc())
             print("Error release vram: " + str(e))


     def translation(self, text: str, max_length: int = 400):
+        """
+        [ctranslate2]
+        max_batch_size:
+            The maximum batch size. If the number of inputs is greater than max_batch_size,
+            the inputs are sorted by length and split by chunks of max_batch_size examples
+            so that the number of padding positions is minimized.
+        no_repeat_ngram_size:
+            Prevent repetitions of ngrams with this size (set 0 to disable).
+        beam_size:
+            Beam size (1 for greedy search).
+
+        [ctranslate2.Generator.generate_batch]
+        sampling_temperature:
+            Sampling temperature to generate more random samples.
+        sampling_topk:
+            Randomly sample predictions from the top K candidates.
+        sampling_topp:
+            Keep the most probable tokens whose cumulative probability exceeds this value.
+        repetition_penalty:
+            Penalty applied to the score of previously generated tokens (set > 1 to penalize).
+        include_prompt_in_result:
+            Include the start_tokens in the result.
+            If include_prompt_in_result is True (the default), the decoding loop is constrained to generate the start tokens that are then included in the result.
+            If include_prompt_in_result is False, the start tokens are forwarded in the decoder at once to initialize its state (i.e. the KV cache for Transformer models).
+            For variable-length inputs, only the tokens up to the minimum length in the batch are forwarded at once. The remaining tokens are generated in the decoding loop with constrained decoding.
+
+        [transformers.TextGenerationPipeline.__call__]
+        return_full_text (bool, optional, defaults to True):
+            If set to False only added text is returned, otherwise the full text is returned. Only meaningful if return_text is set to True.
+        """
         output = None
         result = None
         try:
             if "ct2" in self.modelPath:
-                source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(text))
-                output = self.transModel.translate_batch([source], target_prefix=[self.targetPrefix], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
-                target = output[0].hypotheses[0][1:]
-                result = self.transTokenizer.decode(self.transTokenizer.convert_tokens_to_ids(target))
+                if any(name in self.modelPath for name in ["nllb", "m2m100"]):
+                    source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(text))
+                    output = self.transModel.translate_batch([source], target_prefix=[self.targetPrefix], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
+                    target = output[0].hypotheses[0][1:]
+                    result = self.transTokenizer.decode(self.transTokenizer.convert_tokens_to_ids(target))
+                elif "ALMA" in self.modelPath:
+                    source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(self.ALMAPrefix + text + "\n" + self.translationLang.whisper.names[0] + ": "))
+                    output = self.transModel.generate_batch([source], max_length=max_length, max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams, sampling_temperature=0.7, sampling_topp=0.9, repetition_penalty=1.1, include_prompt_in_result=False) #, sampling_topk=40
+                    target = output[0]
+                    result = self.transTokenizer.decode(target.sequences_ids[0])
+                elif "madlad400" in self.modelPath:
+                    source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(self.madlad400Prefix + text))
+                    output = self.transModel.translate_batch([source], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
+                    target = output[0].hypotheses[0]
+                    result = self.transTokenizer.decode(self.transTokenizer.convert_tokens_to_ids(target))
             elif "mt5" in self.modelPath:
                 output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
                 result = output[0]['generated_text']
             elif "ALMA" in self.modelPath:
-                output = self.transTranslator(self.ALMAPrefix + text + self.translationLang.whisper.code + ":", max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
+                if "GPTQ" in self.modelPath:
+                    output = self.transTranslator(self.ALMAPrefix + text + "\n" + self.translationLang.whisper.names[0] + ": ", max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams, return_full_text=False)
+                elif "GGUF" in self.modelPath:
+                    output = self.transTranslator(self.ALMAPrefix + text + "\n" + self.translationLang.whisper.names[0] + ": ", max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams, return_full_text=False)
                 result = output[0]['generated_text']
-                result = re.sub(rf'^(.*{self.translationLang.whisper.code}: )', '', result) # Remove the prompt from the result
-                result = re.sub(rf'^(Translate this from .* to .*:)', '', result) # Remove the translation instruction
-                return result.strip()
             else: #M2M100 & NLLB
                 output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
                 result = output[0]['translation_text']
         except Exception as e:
+            print(traceback.format_exc())
             print("Error translation text: " + str(e))

         return result


-_MODELS = ["distilled-600M", "distilled-1.3B", "1.3B", "3.3B",
-           "ct2fast-nllb-200-distilled-1.3B-int8_float16",
-           "ct2fast-nllb-200-3.3B-int8_float16",
-           "nllb-200-3.3B-ct2-float16", "nllb-200-1.3B-ct2", "nllb-200-1.3B-ct2-int8", "nllb-200-1.3B-ct2-float16",
-           "nllb-200-distilled-1.3B-ct2", "nllb-200-distilled-1.3B-ct2-int8", "nllb-200-distilled-1.3B-ct2-float16",
-           "nllb-200-distilled-600M-ct2", "nllb-200-distilled-600M-ct2-int8", "nllb-200-distilled-600M-ct2-float16",
-           "m2m100_1.2B-ct2", "m2m100_418M-ct2", "m2m100-12B-ct2",
-           "m2m100_1.2B", "m2m100_418M",
-           "mt5-zh-ja-en-trimmed",
-           "mt5-zh-ja-en-trimmed-fine-tuned-v1",
-           "ALMA-13B-GPTQ"]
+_MODELS = ["nllb-200",
+           "m2m100",
+           "mt5",
+           "ALMA",
+           "madlad400"]

 def check_model_name(name):
     return any(allowed_name in name for allowed_name in _MODELS)
@@ -224,7 +369,8 @@ def download_model(
         "vocab.json", #m2m100
         "model.safetensors",
         "quantize_config.json",
-        "tokenizer.model"
+        "tokenizer.model",
+        "vocabulary.json"
     ]

     kwargs = {
@@ -232,6 +378,9 @@
         "allow_patterns": allowPatterns,
         #"tqdm_class": disabled_tqdm,
     }
+
+    if modelConfig.revision is not None:
+        kwargs["revision"] = modelConfig.revision

     if outputDir is not None:
         kwargs["local_dir"] = outputDir
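
For readers unfamiliar with the ctranslate2.Generator API used in the new ALMA branch, here is a minimal standalone sketch of a prompt-based generate_batch call with include_prompt_in_result=False (the model path and prompt are illustrative placeholders; the real code builds the prompt from self.ALMAPrefix and the detected languages):

import ctranslate2
import transformers

# Placeholder path: a CTranslate2-converted ALMA checkpoint plus its tokenizer files.
modelDir = "/path/to/ALMA-ct2"

transTokenizer = transformers.AutoTokenizer.from_pretrained(modelDir)
transModel = ctranslate2.Generator(modelDir, compute_type="auto", device="cpu")

# Same prompt template the patch builds in self.ALMAPrefix.
prompt = "Translate this from English to German:\nEnglish: Hello, world.\nGerman: "
source = transTokenizer.convert_ids_to_tokens(transTokenizer.encode(prompt))

# include_prompt_in_result=False returns only the newly generated tokens,
# so the decoded result comes back without the prompt text.
output = transModel.generate_batch([source], max_length=256, sampling_temperature=0.7,
                                   sampling_topp=0.9, repetition_penalty=1.1,
                                   include_prompt_in_result=False)
print(transTokenizer.decode(output[0].sequences_ids[0]))
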