Spaces:
Running
Running
improve translation speed
Browse files- translation.py +37 -16
translation.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import os
|
2 |
import re
|
3 |
import sys
|
@@ -7,6 +9,9 @@ import torch
|
|
7 |
import pysbd
|
8 |
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM
|
9 |
import unicodedata
|
|
|
|
|
|
|
10 |
|
11 |
#hy_segmenter = pysbd.Segmenter(language="hy", clean=False) not needed
|
12 |
|
@@ -117,8 +122,8 @@ class Translator:
|
|
117 |
self.tokenizer = NllbTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
|
118 |
init_tokenizer(self.tokenizer)
|
119 |
|
120 |
-
self.hyw_splitter = pysbd.Segmenter(language="hy", clean=
|
121 |
-
self.eng_splitter = pysbd.Segmenter(language="en", clean=
|
122 |
self.languages = LANGUAGES
|
123 |
|
124 |
|
@@ -138,6 +143,7 @@ class Translator:
|
|
138 |
)
|
139 |
if max_length == "auto":
|
140 |
max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
|
|
|
141 |
generated_tokens = self.model.generate(
|
142 |
**encoded.to(self.model.device),
|
143 |
forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
|
@@ -155,36 +161,51 @@ class Translator:
|
|
155 |
def translate(self, text: str,
|
156 |
src_lang: str,
|
157 |
tgt_lang: str,
|
158 |
-
max_length=
|
159 |
num_beams=4,
|
160 |
by_sentence=True,
|
161 |
clean=True,
|
162 |
**kwargs):
|
163 |
|
164 |
if by_sentence:
|
165 |
-
if src_lang =="eng_Latn":
|
166 |
-
sents
|
167 |
elif src_lang == "hyw_Armn":
|
168 |
-
sents
|
169 |
|
170 |
-
else:
|
171 |
-
sents = [text]
|
172 |
-
fillers = ["", ""]
|
173 |
|
174 |
if clean:
|
175 |
sents = [clean_text(sent, src_lang) for sent in sents]
|
176 |
|
177 |
-
|
178 |
-
|
179 |
-
results.
|
180 |
-
|
181 |
-
|
182 |
-
results.append(fillers[-1])
|
183 |
|
184 |
return " ".join(results)
|
185 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
if __name__ == "__main__":
|
|
|
187 |
print("Initializing translator...")
|
188 |
translator = Translator()
|
189 |
print("Translator initialized.")
|
190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
import os
|
4 |
import re
|
5 |
import sys
|
|
|
9 |
import pysbd
|
10 |
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM
|
11 |
import unicodedata
|
12 |
+
import time
|
13 |
+
|
14 |
+
|
15 |
|
16 |
#hy_segmenter = pysbd.Segmenter(language="hy", clean=False) not needed
|
17 |
|
|
|
122 |
self.tokenizer = NllbTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
|
123 |
init_tokenizer(self.tokenizer)
|
124 |
|
125 |
+
self.hyw_splitter = pysbd.Segmenter(language="hy", clean=True)
|
126 |
+
self.eng_splitter = pysbd.Segmenter(language="en", clean=True)
|
127 |
self.languages = LANGUAGES
|
128 |
|
129 |
|
|
|
143 |
)
|
144 |
if max_length == "auto":
|
145 |
max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
|
146 |
+
|
147 |
generated_tokens = self.model.generate(
|
148 |
**encoded.to(self.model.device),
|
149 |
forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
|
|
|
161 |
def translate(self, text: str,
              src_lang: str,
              tgt_lang: str,
              max_length=256,
              num_beams=4,
              by_sentence=True,
              clean=True,
              **kwargs):
    """Translate *text* from ``src_lang`` to ``tgt_lang``.

    Args:
        text: Input text (may contain multiple sentences).
        src_lang: NLLB language code of the input (e.g. ``"eng_Latn"``).
        tgt_lang: NLLB language code of the output (e.g. ``"hyw_Armn"``).
        max_length: Tokenizer/generation length cap forwarded to the
            translation helpers.
        num_beams: Beam-search width forwarded to generation.
        by_sentence: If True, segment the text with the language-specific
            pysbd splitter and translate sentence by sentence.
        clean: If True, normalize each segment with ``clean_text`` first.
        **kwargs: Extra options forwarded to the translation helpers.

    Returns:
        The translated text, sentence translations joined with spaces.
    """
    # Fallback: treat the whole input as one segment. This guards against
    # an unbound `sents` (NameError) when by_sentence is False or when
    # src_lang matches neither splitter below.
    sents = [text]
    if by_sentence:
        if src_lang == "eng_Latn":
            sents = self.eng_splitter.segment(text)
        elif src_lang == "hyw_Armn":
            sents = self.hyw_splitter.segment(text)

    if clean:
        sents = [clean_text(sent, src_lang) for sent in sents]

    # Batch all sentences through a single generate() call when there is
    # more than one; otherwise use the single-sentence path.
    if len(sents) > 1:
        results = self.translate_batch(sents, src_lang, tgt_lang,
                                       num_beams=num_beams,
                                       max_length=max_length, **kwargs)
    else:
        results = self.translate_single(sents, src_lang, tgt_lang,
                                        max_length=max_length,
                                        num_beams=num_beams, **kwargs)

    return " ".join(results)
|
187 |
|
188 |
+
def translate_batch(self, texts, src_lang, tgt_lang, num_beams=4, max_length=256, **kwargs):
    """Translate a list of sentences in one batched ``generate`` call.

    Args:
        texts: List of source-language sentences.
        src_lang: NLLB language code of the inputs; set on the tokenizer
            so it emits the right language prefix tokens.
        tgt_lang: NLLB language code used as the forced BOS token.
        num_beams: Beam-search width.
        max_length: Tokenizer truncation limit per sentence.
        **kwargs: Extra options forwarded to ``model.generate``.

    Returns:
        List of decoded translations, one per input sentence.
    """
    self.tokenizer.src_lang = src_lang

    # Keep the FULL tokenizer output: with padding=True the attention_mask
    # is required so generation ignores pad tokens. The previous CUDA
    # branch passed only .input_ids, silently dropping the mask and
    # degrading output for padded batch members. Using self.model.device
    # also removes the duplicated CPU/GPU branches.
    inputs = self.tokenizer(
        texts,
        return_tensors="pt",
        max_length=max_length,
        padding=True,
        truncation=True,
    ).to(self.model.device)

    translated_tokens = self.model.generate(
        **inputs,
        num_beams=num_beams,
        forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
        **kwargs,
    )
    return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
198 |
+
|
199 |
if __name__ == "__main__":

    print("Initializing translator...")
    translator = Translator()
    print("Translator initialized.")

    # Smoke-test: translate a short and a longer English input to Western
    # Armenian, timing each call to gauge translation speed.
    samples = [
        "Hello world!",
        "I am the greatest translator! Do not fuck with me!",
    ]
    for sample in samples:
        started = time.time()
        print(translator.translate(sample, "eng_Latn", "hyw_Armn"))
        print("Time elapsed: ", time.time() - started)