SoybeanMilk committed on
Commit
c27a2cb
1 Parent(s): 90d4f1f

Delete translationModel.py

Files changed (1)
  1. translationModel.py +0 -259
translationModel.py DELETED
@@ -1,259 +0,0 @@
- import os
- import warnings
- import huggingface_hub
- import requests
- import torch
-
- import ctranslate2
- import transformers
-
- import re
-
- from typing import Optional
- from src.config import ModelConfig
- from src.translation.translationLangs import TranslationLang, get_lang_from_whisper_code
- from peft import PeftModel
-
- class TranslationModel:
-     def __init__(
-         self,
-         modelConfig: ModelConfig,
-         device: str = None,
-         whisperLang: TranslationLang = None,
-         translationLang: TranslationLang = None,
-         batchSize: int = 2,
-         noRepeatNgramSize: int = 3,
-         numBeams: int = 2,
-         downloadRoot: Optional[str] = None,
-         localFilesOnly: bool = False,
-         loadModel: bool = False,
-     ):
-         """Initializes the M2M100 / NLLB-200 / mT5 / ALMA translation model.
-
-         Args:
-           modelConfig: Config of the model to use (distilled-600M, distilled-1.3B,
-             1.3B, 3.3B...) or a path to a converted model directory. When a size is
-             configured, the converted model is downloaded from the Hugging Face Hub.
-           device: Device to use for computation (e.g. cpu, cuda, cuda:0, mps). If None,
-             cuda is selected when available, otherwise cpu.
-           whisperLang: Source language, as detected or selected in Whisper.
-           translationLang: Target language to translate into. If None, the model is left
-             unconfigured and no translation is performed.
-           batchSize: Batch size used during generation.
-           noRepeatNgramSize: Size of n-grams that must not be repeated in the output.
-           numBeams: Number of beams used for beam search.
-           downloadRoot: Directory where the models should be saved. If not set, the models
-             are saved in the standard Hugging Face cache directory.
-           localFilesOnly: If True, avoid downloading the file and return the path to the
-             local cached file if it exists.
-           loadModel: If True, load the model immediately instead of deferring to load_model().
-         """
-         self.modelConfig = modelConfig
-         self.whisperLang = whisperLang # self.translationLangWhisper = get_lang_from_whisper_code(whisperLang.code.lower() if whisperLang is not None else "en")
-         self.translationLang = translationLang
-
-         if translationLang is None:
-             return
-
-         self.batchSize = batchSize
-         self.noRepeatNgramSize = noRepeatNgramSize
-         self.numBeams = numBeams
-
-         if os.path.isdir(modelConfig.url):
-             self.modelPath = modelConfig.url
-         else:
-             self.modelPath = download_model(
-                 modelConfig,
-                 localFilesOnly=localFilesOnly,
-                 cacheDir=downloadRoot,
-             )
-
-         if device is None:
-             if torch.cuda.is_available():
-                 device = "cuda" if "ct2" in self.modelPath else "cuda:0"
-             else:
-                 device = "cpu"
-
-         self.device = device
-
-         if loadModel:
-             self.load_model()
-
-     def load_model(self):
-         print('\n\nLoading model: %s\n\n' % self.modelPath)
-         if "ct2" in self.modelPath:
-             if "nllb" in self.modelPath:
-                 self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.nllb.code)
-                 self.targetPrefix = [self.translationLang.nllb.code]
-             elif "m2m100" in self.modelPath:
-                 self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelConfig.tokenizer_url if self.modelConfig.tokenizer_url is not None and len(self.modelConfig.tokenizer_url) > 0 else self.modelPath, src_lang=self.whisperLang.m2m100.code)
-                 self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
-             self.transModel = ctranslate2.Translator(self.modelPath, compute_type="auto", device=self.device)
-
-         elif "mt5" in self.modelPath:
-             self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
-             self.transTokenizer = transformers.T5Tokenizer.from_pretrained(self.modelPath, legacy=False) # requires spiece.model
-             self.transModel = transformers.MT5ForConditionalGeneration.from_pretrained(self.modelPath)
-             self.transTranslator = transformers.pipeline('text2text-generation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer)
-         elif "ALMA" in self.modelPath:
-             self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.code + " to " + self.translationLang.whisper.code + ":" + self.whisperLang.whisper.code + ":"
-             self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath, use_fast=True)
-             self.transModel = transformers.AutoModelForCausalLM.from_pretrained(self.modelPath, device_map="auto", trust_remote_code=False, revision="main")
-             self.transTranslator = transformers.pipeline("text-generation", model=self.transModel, tokenizer=self.transTokenizer, batch_size=2, do_sample=True, temperature=0.7, top_p=0.95, top_k=40, repetition_penalty=1.1)
-         else:
-             self.transTokenizer = transformers.AutoTokenizer.from_pretrained(self.modelPath)
-             self.transModel = transformers.AutoModelForSeq2SeqLM.from_pretrained(self.modelPath)
-             if "m2m100" in self.modelPath:
-                 self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.m2m100.code, tgt_lang=self.translationLang.m2m100.code)
-             else: # NLLB
-                 self.transTranslator = transformers.pipeline('translation', model=self.transModel, device=self.device, tokenizer=self.transTokenizer, src_lang=self.whisperLang.nllb.code, tgt_lang=self.translationLang.nllb.code)
-
-     def release_vram(self):
-         try:
-             if torch.cuda.is_available():
-                 if "ct2" not in self.modelPath:
-                     device = torch.device("cpu")
-                     self.transModel.to(device)
-                 del self.transModel
-                 torch.cuda.empty_cache()
-             print("release vram end.")
-         except Exception as e:
-             print("Error releasing vram: " + str(e))
-
-
-     def translation(self, text: str, max_length: int = 400):
-         output = None
-         result = None
-         try:
-             if "ct2" in self.modelPath:
-                 source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(text))
-                 output = self.transModel.translate_batch([source], target_prefix=[self.targetPrefix], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
-                 target = output[0].hypotheses[0][1:]
-                 result = self.transTokenizer.decode(self.transTokenizer.convert_tokens_to_ids(target))
-             elif "mt5" in self.modelPath:
-                 output = self.transTranslator(self.mt5Prefix + text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams) #, num_return_sequences=2
-                 result = output[0]['generated_text']
-             elif "ALMA" in self.modelPath:
-                 output = self.transTranslator(self.ALMAPrefix + text + self.translationLang.whisper.code + ":", max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
-                 result = output[0]['generated_text']
-                 result = re.sub(rf'^(.*{self.translationLang.whisper.code}: )', '', result) # Remove the prompt from the result
-                 result = re.sub(rf'^(Translate this from .* to .*:)', '', result) # Remove the translation instruction
-                 return result.strip()
-             else: # M2M100 & NLLB
-                 output = self.transTranslator(text, max_length=max_length, batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, num_beams=self.numBeams)
-                 result = output[0]['translation_text']
-         except Exception as e:
-             print("Error translating text: " + str(e))
-
-         return result
-
-
- _MODELS = ["distilled-600M", "distilled-1.3B", "1.3B", "3.3B",
-            "ct2fast-nllb-200-distilled-1.3B-int8_float16",
-            "ct2fast-nllb-200-3.3B-int8_float16",
-            "nllb-200-3.3B-ct2-float16", "nllb-200-1.3B-ct2", "nllb-200-1.3B-ct2-int8", "nllb-200-1.3B-ct2-float16",
-            "nllb-200-distilled-1.3B-ct2", "nllb-200-distilled-1.3B-ct2-int8", "nllb-200-distilled-1.3B-ct2-float16",
-            "nllb-200-distilled-600M-ct2", "nllb-200-distilled-600M-ct2-int8", "nllb-200-distilled-600M-ct2-float16",
-            "m2m100_1.2B-ct2", "m2m100_418M-ct2", "m2m100-12B-ct2",
-            "m2m100_1.2B", "m2m100_418M",
-            "mt5-zh-ja-en-trimmed",
-            "mt5-zh-ja-en-trimmed-fine-tuned-v1",
-            "ALMA-13B-GPTQ"]
-
- def check_model_name(name):
-     return any(allowed_name in name for allowed_name in _MODELS)
-
- def download_model(
-     modelConfig: ModelConfig,
-     outputDir: Optional[str] = None,
-     localFilesOnly: bool = False,
-     cacheDir: Optional[str] = None,
- ):
-     """"download_model" is referenced from the "utils.py" script
-     of the "faster_whisper" project, authored by guillaumekln.
-
-     Downloads an NLLB-200 / M2M100 / mT5 / ALMA model from the Hugging Face Hub.
-
-     The model is downloaded from https://huggingface.co/facebook.
-
-     Args:
-       modelConfig: Config of the model to download (facebook/nllb-distilled-600M,
-         facebook/nllb-distilled-1.3B, facebook/nllb-1.3B, facebook/nllb-3.3B...).
-       outputDir: Directory where the model should be saved. If not set, the model is saved in
-         the cache directory.
-       localFilesOnly: If True, avoid downloading the file and return the path to the local
-         cached file if it exists.
-       cacheDir: Path to the folder where cached files are stored.
-
-     Returns:
-       The path to the downloaded model.
-
-     Raises:
-       ValueError: if the model name is invalid.
-     """
-     if not check_model_name(modelConfig.name):
-         raise ValueError(
-             "Invalid model name '%s', expected one of: %s" % (modelConfig.name, ", ".join(_MODELS))
-         )
-
-     repoId = modelConfig.url #"facebook/nllb-200-%s" %
-
-     allowPatterns = [
-         "config.json",
-         "generation_config.json",
-         "model.bin",
-         "pytorch_model.bin",
-         "pytorch_model.bin.index.json",
-         "pytorch_model-*.bin",
-         "pytorch_model-00001-of-00003.bin",
-         "pytorch_model-00002-of-00003.bin",
-         "pytorch_model-00003-of-00003.bin",
-         "sentencepiece.bpe.model",
-         "tokenizer.json",
-         "tokenizer_config.json",
-         "shared_vocabulary.txt",
-         "shared_vocabulary.json",
-         "special_tokens_map.json",
-         "spiece.model",
-         "vocab.json", #m2m100
-         "model.safetensors",
-         "quantize_config.json",
-         "tokenizer.model"
-     ]
-
-     kwargs = {
-         "local_files_only": localFilesOnly,
-         "allow_patterns": allowPatterns,
-         #"tqdm_class": disabled_tqdm,
-     }
-
-     if outputDir is not None:
-         kwargs["local_dir"] = outputDir
-         kwargs["local_dir_use_symlinks"] = False
-
-     if cacheDir is not None:
-         kwargs["cache_dir"] = cacheDir
-
-     try:
-         return huggingface_hub.snapshot_download(repoId, **kwargs)
-     except (
-         huggingface_hub.utils.HfHubHTTPError,
-         requests.exceptions.ConnectionError,
-     ) as exception:
-         warnings.warn(
-             "An error occurred while synchronizing the model %s from the Hugging Face Hub:\n%s"
-             % (repoId, exception)
-         )
-         warnings.warn(
-             "Trying to load the model directly from the local cache, if it exists."
-         )
-
-         kwargs["local_files_only"] = True
-         return huggingface_hub.snapshot_download(repoId, **kwargs)
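
For reference, below is a minimal sketch of how the deleted TranslationModel class was typically driven. The ModelConfig constructor arguments, the module path, and the concrete repo id used here are illustrative assumptions, not taken from this commit.

# Hypothetical usage sketch of the removed class; field names and repo id are assumptions.
from src.config import ModelConfig                                   # assumed to accept name/url keyword arguments
from src.translation.translationLangs import get_lang_from_whisper_code
from src.translation.translationModel import TranslationModel        # assumed module path

modelConfig = ModelConfig(name="nllb-200-distilled-600M-ct2",
                          url="nllb-200-distilled-600M-ct2")          # illustrative; any name matched by _MODELS

translator = TranslationModel(
    modelConfig=modelConfig,
    whisperLang=get_lang_from_whisper_code("ja"),       # source language detected by Whisper
    translationLang=get_lang_from_whisper_code("en"),   # target language
    batchSize=2,
    loadModel=True,                                      # load weights now instead of lazily
)

print(translator.translation("こんにちは、世界"))        # -> translated text, or None if generation failed
translator.release_vram()                                # move the model off the GPU and empty the CUDA cache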