AriNubar committed on
Commit
e8c3b4c
1 Parent(s): 650e5db

improve translation speed

Files changed (1):
translation.py +37 -16
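In short: instead of looping over sentences and calling `translate_single` once per sentence, the updated `translate` hands all segmented sentences to a new `translate_batch`, which runs a single padded `generate` call. Below is a minimal, self-contained sketch of that batching pattern; the checkpoint name and the `hye_Armn` target code are stand-ins (the app's own `hyw_Armn` code exists only in its customized tokenizer), so this is illustrative rather than the repo's exact code.

# Illustrative sketch, not this repo's code: translating N sentences in one
# padded batch amortizes per-call generate() overhead versus N separate calls.
from transformers import NllbTokenizer, AutoModelForSeq2SeqLM

MODEL_NAME = "facebook/nllb-200-distilled-600M"  # stand-in for the app's checkpoint
tokenizer = NllbTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

sents = ["Hello world!", "How are you?"]  # pre-segmented sentences
tokenizer.src_lang = "eng_Latn"           # source language used when encoding

# One batched call; the target language is forced as the first decoded token.
inputs = tokenizer(sents, return_tensors="pt", padding=True,
                   truncation=True, max_length=256)
tokens = model.generate(**inputs, num_beams=4,
                        forced_bos_token_id=tokenizer.lang_code_to_id["hye_Armn"])
print(tokenizer.batch_decode(tokens, skip_special_tokens=True))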
translation.py CHANGED
@@ -1,3 +1,5 @@
+# -*- coding: utf-8 -*-
+
 import os
 import re
 import sys
@@ -7,6 +9,9 @@ import torch
 import pysbd
 from transformers import NllbTokenizer, AutoModelForSeq2SeqLM
 import unicodedata
+import time
+
+
 
 #hy_segmenter = pysbd.Segmenter(language="hy", clean=False) not needed
 
@@ -117,8 +122,8 @@ class Translator:
         self.tokenizer = NllbTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
         init_tokenizer(self.tokenizer)
 
-        self.hyw_splitter = pysbd.Segmenter(language="hy", clean=False)
-        self.eng_splitter = pysbd.Segmenter(language="en", clean=False)
+        self.hyw_splitter = pysbd.Segmenter(language="hy", clean=True)
+        self.eng_splitter = pysbd.Segmenter(language="en", clean=True)
         self.languages = LANGUAGES
 
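The switch to `clean=True` means pysbd normalizes the text before splitting, so the original inter-sentence whitespace is no longer recoverable; that is consistent with `translate` (below) dropping the filler-tracking `sentenize_with_fillers` helper. A small illustration of the segmenter call, with indicative output:

# Quick pysbd illustration (output indicative): clean=True normalizes the
# input before splitting, so original inter-sentence spacing is not kept.
import pysbd

seg = pysbd.Segmenter(language="en", clean=True)
print(seg.segment("Hello world!   How are you?"))
# -> ['Hello world!', 'How are you?']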
@@ -138,6 +143,7 @@ class Translator:
         )
         if max_length == "auto":
             max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
+
         generated_tokens = self.model.generate(
             **encoded.to(self.model.device),
             forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
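For context, the untouched `"auto"` branch above budgets generation at 32 tokens plus twice the source length; worked through with illustrative numbers:

# Worked example of the "auto" cap in translate_single (numbers illustrative).
n_src_tokens = 10                          # encoded.input_ids.shape[1]
max_length = int(32 + 2.0 * n_src_tokens)  # 32 + 20 = 52 generated tokens at most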
@@ -155,36 +161,51 @@
     def translate(self, text: str,
                   src_lang: str,
                   tgt_lang: str,
-                  max_length="auto",
+                  max_length=256,
                   num_beams=4,
                   by_sentence=True,
                   clean=True,
                   **kwargs):
 
         if by_sentence:
-            if src_lang =="eng_Latn":
-                sents, fillers = sentenize_with_fillers(text, self.eng_splitter, ignore_errors=True)
+            if src_lang == "eng_Latn":
+                sents = self.eng_splitter.segment(text)
             elif src_lang == "hyw_Armn":
-                sents, fillers = sentenize_with_fillers(text, self.hyw_splitter, ignore_errors=True)
-
-        else:
-            sents = [text]
-            fillers = ["", ""]
+                sents = self.hyw_splitter.segment(text)
 
         if clean:
             sents = [clean_text(sent, src_lang) for sent in sents]
 
-        results = []
-        for sent, sep in zip(sents, fillers):
-            results.append(sep)
-            results.append(self.translate_single(sent, src_lang, tgt_lang, max_length, num_beams, **kwargs))
-
-        results.append(fillers[-1])
+
+        if len(sents) > 1:
+            results = self.translate_batch(sents, src_lang, tgt_lang, num_beams=num_beams, max_length=max_length, **kwargs)
+        else:
+            results = self.translate_single(sents, src_lang, tgt_lang, max_length=max_length, num_beams=num_beams, **kwargs)
 
         return " ".join(results)
 
+    def translate_batch(self, texts, src_lang, tgt_lang, num_beams=4, max_length=256, **kwargs):
+        self.tokenizer.src_lang = src_lang
+
+        if torch.cuda.is_available():
+            inputs = self.tokenizer(texts, return_tensors="pt", max_length=max_length, padding=True, truncation=True).input_ids.to("cuda")
+            translated_tokens = self.model.generate(inputs, num_beams=num_beams, forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang])
+        else:
+            inputs = self.tokenizer(texts, return_tensors="pt", max_length=max_length, padding=True, truncation=True)
+            translated_tokens = self.model.generate(**inputs, num_beams=num_beams, forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang])
+        return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)
+
 if __name__ == "__main__":
+
     print("Initializing translator...")
     translator = Translator()
     print("Translator initialized.")
-    print(translator.translate("Hello, world!", "eng_Latn", "hyw_Armn"))
+
+    start_time = time.time()
+    print(translator.translate("Hello world!", "eng_Latn", "hyw_Armn"))
+    print("Time elapsed: ", time.time() - start_time)
+
+    start_time = time.time()
+    print(translator.translate("I am the greatest translator! Do not fuck with me!", "eng_Latn", "hyw_Armn"))
+    print("Time elapsed: ", time.time() - start_time)
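A reading note on the new `translate_batch`: on the CUDA path only `input_ids` is moved to the GPU, so the padding attention mask is not handed to `generate`, whereas `translate_single` moves the whole `BatchEncoding` via `encoded.to(self.model.device)`. Purely as a hypothetical sketch (assuming it were pasted into the `Translator` class above), the batch path written in that same style:

# Hypothetical variant, not the committed code: moving the whole BatchEncoding
# keeps the attention mask alongside input_ids, mirroring translate_single,
# and works unchanged on CPU or GPU.
def translate_batch(self, texts, src_lang, tgt_lang, num_beams=4, max_length=256, **kwargs):
    self.tokenizer.src_lang = src_lang
    inputs = self.tokenizer(texts, return_tensors="pt", padding=True,
                            truncation=True, max_length=max_length).to(self.model.device)
    translated_tokens = self.model.generate(
        **inputs, num_beams=num_beams,
        forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang])
    return self.tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)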