nroggendorff committed
Commit d018987
1 Parent(s): e4032ff

I'm so confused

Files changed (1):
  1. translation.py +27 -105
translation.py CHANGED
@@ -23,131 +23,56 @@ LANGUAGES = {
     "Татар | Tatar | Татарский": "tat_Cyrl",
     "Тыва | Тувинский | Tuvan ": "tyv_Cyrl",
 }
-L1 = "rus_Cyrl"
-L2 = "eng_Latn"
-
+L1, L2 = "rus_Cyrl", "eng_Latn"
 
 def get_non_printing_char_replacer(replace_by: str = " ") -> tp.Callable[[str], str]:
-    non_printable_map = {
-        ord(c): replace_by
-        for c in (chr(i) for i in range(sys.maxunicode + 1))
-        # same as \p{C} in perl
-        # see https://www.unicode.org/reports/tr44/#General_Category_Values
-        if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
-    }
-
-    def replace_non_printing_char(line) -> str:
-        return line.translate(non_printable_map)
-
-    return replace_non_printing_char
-
+    # Build the map once here, not inside the lambda, so each call only pays for str.translate.
+    non_printable_map = {ord(c): replace_by for c in (chr(i) for i in range(sys.maxunicode + 1)) if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}}
+    return lambda line: line.translate(non_printable_map)
 
 class TextPreprocessor:
-    """
-    Mimic the text preprocessing made for the NLLB model.
-    This code is adapted from the Stopes repo of the NLLB team:
-    https://github.com/facebookresearch/stopes/blob/main/stopes/pipelines/monolingual/monolingual_line_processor.py#L214
-    """
-
     def __init__(self, lang="en"):
         self.mpn = MosesPunctNormalizer(lang=lang)
-        self.mpn.substitutions = [
-            (re.compile(r), sub) for r, sub in self.mpn.substitutions
-        ]
+        self.mpn.substitutions = [(re.compile(r), sub) for r, sub in self.mpn.substitutions]
         self.replace_nonprint = get_non_printing_char_replacer(" ")
 
     def __call__(self, text: str) -> str:
-        clean = self.mpn.normalize(text)
-        clean = self.replace_nonprint(clean)
-        # replace 𝓕𝔯𝔞𝔫𝔠𝔢𝔰𝔠𝔞 by Francesca
-        clean = unicodedata.normalize("NFKC", clean)
-        return clean
-
+        return unicodedata.normalize("NFKC", self.replace_nonprint(self.mpn.normalize(text)))
 
 def sentenize_with_fillers(text, splitter, fix_double_space=True, ignore_errors=False):
-    """Apply a sentence splitter and return the sentences and all separators before and after them"""
     if fix_double_space:
         text = re.sub(" +", " ", text)
     sentences = splitter.split(text)
-    fillers = []
-    i = 0
-    for sentence in sentences:
-        start_idx = text.find(sentence, i)
-        if ignore_errors and start_idx == -1:
-            # print(f"sent not found after {i}: `{sentence}`")
-            start_idx = i + 1
-        assert start_idx != -1, f"sent not found after {i}: `{sentence}`"
-        fillers.append(text[i:start_idx])
-        i = start_idx + len(sentence)
-    fillers.append(text[i:])
+    # Track a running character offset; enumerate()'s list index is not a text position.
+    fillers, i = [], 0
+    for s in sentences:
+        j = text.find(s, i)
+        if j == -1:
+            assert ignore_errors, f"sent not found after {i}: `{s}`"
+            j = i + 1
+        fillers.append(text[i:j])
+        i = j + len(s)
+    fillers.append(text[i:])
     return sentences, fillers
 
-
 class Translator:
     def __init__(self):
         self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_URL, low_cpu_mem_usage=True)
-        if torch.cuda.is_available():
-            self.model.cuda()
         self.tokenizer = NllbTokenizer.from_pretrained(MODEL_URL)
-
         self.splitter = SentenceSplitter("ru")
         self.preprocessor = TextPreprocessor()
-
         self.languages = LANGUAGES
 
-    def translate(
-        self,
-        text,
-        src_lang=L1,
-        tgt_lang=L2,
-        max_length="auto",
-        num_beams=4,
-        by_sentence=True,
-        preprocess=True,
-        **kwargs,
-    ):
-        """Translate a text sentence by sentence, preserving the fillers around the sentences."""
-        if by_sentence:
-            sents, fillers = sentenize_with_fillers(
-                text, splitter=self.splitter, ignore_errors=True
-            )
-        else:
-            sents = [text]
-            fillers = ["", ""]
-        if preprocess:
-            sents = [self.preprocessor(sent) for sent in sents]
-        results = []
-        for sent, sep in zip(sents, fillers):
-            results.append(sep)
-            results.append(
-                self.translate_single(
-                    sent,
-                    src_lang=src_lang,
-                    tgt_lang=tgt_lang,
-                    max_length=max_length,
-                    num_beams=num_beams,
-                    **kwargs,
-                )
-            )
-        results.append(fillers[-1])
-        return "".join(results)
+    def translate(self, text, src_lang=L1, tgt_lang=L2, max_length="auto", num_beams=4, by_sentence=True, preprocess=True, **kwargs):
+        sents, fillers = (sentenize_with_fillers(text, self.splitter, ignore_errors=True) if by_sentence else ([text], ["", ""]))
+        sents = [self.preprocessor(sent) for sent in sents] if preprocess else sents
+        results = [sep + self.translate_single(sent, src_lang, tgt_lang, max_length, num_beams, **kwargs) for sent, sep in zip(sents, fillers)]
+        return "".join(results + [fillers[-1]])
 
-    def translate_single(
-        self,
-        text,
-        src_lang=L1,
-        tgt_lang=L2,
-        max_length="auto",
-        num_beams=4,
-        n_out=None,
-        **kwargs,
-    ):
+    def translate_single(self, text, src_lang=L1, tgt_lang=L2, max_length="auto", num_beams=4, n_out=None, **kwargs):
         self.tokenizer.src_lang = src_lang
-        encoded = self.tokenizer(
-            text, return_tensors="pt", truncation=True, max_length=512
-        )
-        if max_length == "auto":
-            max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
+        encoded = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
+        max_length = int(32 + 2.0 * encoded.input_ids.shape[1]) if max_length == "auto" else max_length
         generated_tokens = self.model.generate(
             **encoded.to(self.model.device),
             forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
@@ -156,13 +81,10 @@ class Translator:
             num_return_sequences=n_out or 1,
             **kwargs,
         )
-        out = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
-        if isinstance(text, str) and n_out is None:
-            return out[0]
-        return out
-
+        out = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+        return out[0] if isinstance(text, str) and n_out is None else out
 
 if __name__ == "__main__":
     print("Initializing a translator to pre-download models...")
     translator = Translator()
-    print("Initialization successful!")
+    print("Initialization successful!")
 