SoybeanMilk committed on
Commit
73779c4
1 Parent(s): e57c738

Upload translationModel.py

Browse files
Files changed (1) hide show
  1. src/translation/translationModel.py +20 -1
src/translation/translationModel.py CHANGED
@@ -7,9 +7,12 @@ import torch
7
  import ctranslate2
8
  import transformers
9
 
 
 
10
  from typing import Optional
11
  from src.config import ModelConfig
12
  from src.translation.translationLangs import TranslationLang, get_lang_from_whisper_code
 
13
 
14
  class TranslationModel:
15
  def __init__(
@@ -92,11 +95,17 @@ class TranslationModel:
92
  self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.m2m100.code)
93
  self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
94
  self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
 
95
  elif "mt5" in self.modelPath:
96
  self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
97
  self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
98
  self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath)
99
  self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
 
 
 
 
 
100
  else:
101
  self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
102
  self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
@@ -130,6 +139,12 @@ class TranslationModel:
130
  elif "mt5" in self.modelPath:
131
  output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
132
  result = output[0]['generated_text']
 
 
 
 
 
 
133
  else: #M2M100 & NLLB
134
  output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
135
  result = output[0]['translation_text']
@@ -148,7 +163,8 @@ _MODELS = ["distilled-600M", "distilled-1.3B", "1.3B", "3.3B",
148
  "m2m100_1.2B-ct2", "m2m100_418M-ct2", "m2m100-12B-ct2",
149
  "m2m100_1.2B", "m2m100_418M",
150
  "mt5-zh-ja-en-trimmed",
151
- "mt5-zh-ja-en-trimmed-fine-tuned-v1"]
 
152
 
153
  def check_model_name(name):
154
  return any(allowed_name in name for allowed_name in _MODELS)
@@ -206,6 +222,9 @@ def download_model(
206
  "special_tokens_map.json",
207
  "spiece.model",
208
  "vocab.json", #m2m100
 
 
 
209
  ]
210
 
211
  kwargs = {
 
7
  import ctranslate2
8
  import transformers
9
 
10
+ import re
11
+
12
  from typing import Optional
13
  from src.config import ModelConfig
14
  from src.translation.translationLangs import TranslationLang, get_lang_from_whisper_code
15
+ from peft import PeftModel
16
 
17
  class TranslationModel:
18
  def __init__(
 
95
  self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.m2m100.code)
96
  self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
97
  self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
98
+
99
  elif "mt5" in self.modelPath:
100
  self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
101
  self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) #requires spiece.model
102
  self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath)
103
  self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
104
+ elif "ALMA" in self.modelPath:
105
+ self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.code + " to " + self.translationLang.whisper.code + ":" + self.whisperLang.whisper.code + ":"
106
+ self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath, use_fast=True)
107
+ self.transModel = transformers.AutoModelForCausalLM.from_pretrained(self.modelPath, device_map="auto", trust_remote_code=False, revision="main")
108
+ self.transTranslator = transformers.pipeline("text-generation", model=self.transModel, tokenizer=self.transTokenizer, batch_size=2, do_sample=True, temperature=0.7, top_p=0.95, top_k=40, repetition_penalty=1.1)
109
  else:
110
  self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
111
  self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
 
139
  elif "mt5" in self.modelPath:
140
  output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
141
  result = output[0]['generated_text']
142
+ elif "ALMA" in self.modelPath:
143
+ output = self.transTranslator(self.ALMAPrefix + text + self.translationLang.whisper.code + ":", max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
144
+ result = output[0]['generated_text']
145
+ result = re.sub(rf'^(.*{self.translationLang.whisper.code}: )', '', result) # Remove the prompt from the result
146
+ result = re.sub(rf'^(Translate this from .* to .*:)', '', result) # Remove the translation instruction
147
+ return result.strip()
148
  else: #M2M100 & NLLB
149
  output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
150
  result = output[0]['translation_text']
 
163
  "m2m100_1.2B-ct2", "m2m100_418M-ct2", "m2m100-12B-ct2",
164
  "m2m100_1.2B", "m2m100_418M",
165
  "mt5-zh-ja-en-trimmed",
166
+ "mt5-zh-ja-en-trimmed-fine-tuned-v1",
167
+ "ALMA-13B-GPTQ"]
168
 
169
  def check_model_name(name):
170
  return any(allowed_name in name for allowed_name in _MODELS)
 
222
  "special_tokens_map.json",
223
  "spiece.model",
224
  "vocab.json", #m2m100
225
+ "model.safetensors",
226
+ "quantize_config.json",
227
+ "tokenizer.model"
228
  ]
229
 
230
  kwargs = {