flutter-painter committed on
Commit
e30bf3c
•
1 Parent(s): 442feca

Create translation.py

Files changed (1)
  1. translation.py +182 -0
translation.py ADDED
@@ -0,0 +1,182 @@
+ import re
+ import sys
+ import typing as tp
+ import unicodedata
+
+ import torch
+ from sacremoses import MosesPunctNormalizer
+ from sentence_splitter import SentenceSplitter
+ from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
+
+ MODEL_URL = "flutter-painter/nllb-fra-fuf-v2"
+ LANGUAGES = {
+     "French": "fra_Latn",
+     "Fula": "fuf_Latn",
+ }
+
+
+ def get_non_printing_char_replacer(replace_by: str = " ") -> tp.Callable[[str], str]:
+     non_printable_map = {
+         ord(c): replace_by
+         for c in (chr(i) for i in range(sys.maxunicode + 1))
+         # same as \p{C} in perl
+         # see https://www.unicode.org/reports/tr44/#General_Category_Values
+         if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
+     }
+
+     def replace_non_printing_char(line) -> str:
+         return line.translate(non_printable_map)
+
+     return replace_non_printing_char
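+
+ # Illustrative check (not in the runtime path): U+200B ZERO WIDTH SPACE has
+ # Unicode category Cf, so the replacer maps it to the replacement string:
+ #   get_non_printing_char_replacer(" ")("a\u200bb") == "a b"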
+
+
+ class TextPreprocessor:
+     """
+     Mimic the text preprocessing made for the NLLB model.
+     This code is adapted from the Stopes repo of the NLLB team:
+     https://github.com/facebookresearch/stopes/blob/main/stopes/pipelines/monolingual/monolingual_line_processor.py#L214
+     """
+
+     def __init__(self, lang="en"):
+         self.mpn = MosesPunctNormalizer(lang=lang)
+         # pre-compile the Moses substitution patterns once for speed
+         self.mpn.substitutions = [
+             (re.compile(r), sub) for r, sub in self.mpn.substitutions
+         ]
+         self.replace_nonprint = get_non_printing_char_replacer(" ")
+
+     def __call__(self, text: str) -> str:
+         clean = self.mpn.normalize(text)
+         clean = self.replace_nonprint(clean)
+         # NFKC folds fancy Unicode variants, e.g. replace 𝓕𝔯𝔞𝔫𝔠𝔢𝔰𝔠𝔞 by Francesca
+         clean = unicodedata.normalize("NFKC", clean)
+         return clean
+
+
+ def fix_tokenizer(tokenizer, new_lang="fuf_Latn"):
+     """Add a new language token to the tokenizer vocabulary
+     (this should be done each time the tokenizer is initialized).
+     """
+     old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
+     tokenizer.lang_code_to_id[new_lang] = old_len - 1
+     tokenizer.id_to_lang_code[old_len - 1] = new_lang
+     # always move "mask" to the last position
+     tokenizer.fairseq_tokens_to_ids["<mask>"] = (
+         len(tokenizer.sp_model)
+         + len(tokenizer.lang_code_to_id)
+         + tokenizer.fairseq_offset
+     )
+
+     tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
+     tokenizer.fairseq_ids_to_tokens = {
+         v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()
+     }
+     if new_lang not in tokenizer._additional_special_tokens:
+         tokenizer._additional_special_tokens.append(new_lang)
+     # clear the added token encoder; otherwise a new token may end up there by mistake
+     tokenizer.added_tokens_encoder = {}
+     tokenizer.added_tokens_decoder = {}
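+
+ # Hedged example of how this hook is wired up for this checkpoint
+ # ("fuf_Latn" matches the LANGUAGES mapping above; not part of the runtime path):
+ #   tok = NllbTokenizer.from_pretrained(MODEL_URL)
+ #   fix_tokenizer(tok, new_lang="fuf_Latn")
+ #   assert "fuf_Latn" in tok.lang_code_to_id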
+
+
+ def sentenize_with_fillers(text, splitter, fix_double_space=True, ignore_errors=False):
+     """Apply a sentence splitter and return the sentences and all separators before and after them"""
+     if fix_double_space:
+         text = re.sub(" +", " ", text)
+     sentences = splitter.split(text)
+     fillers = []
+     i = 0
+     for sentence in sentences:
+         start_idx = text.find(sentence, i)
+         if ignore_errors and start_idx == -1:
+             # print(f"sent not found after {i}: `{sentence}`")
+             start_idx = i + 1
+         assert start_idx != -1, f"sent not found after {i}: `{sentence}`"
+         fillers.append(text[i:start_idx])
+         i = start_idx + len(sentence)
+     fillers.append(text[i:])
+     return sentences, fillers
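+
+ # Worked example (illustrative): with splitter = SentenceSplitter("fr"),
+ #   sentenize_with_fillers("Bonjour. Ça va ?", splitter)
+ # returns (["Bonjour.", "Ça va ?"], ["", " ", ""]); interleaving the fillers
+ # with the translated sentences then reproduces the original spacing exactly.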
+
+
+ class Translator:
+     def __init__(self):
+         self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_URL)
+         if torch.cuda.is_available():
+             self.model.cuda()
+         self.tokenizer = NllbTokenizer.from_pretrained(MODEL_URL)
+         # register the Fula language token used by this French-Fula checkpoint
+         fix_tokenizer(self.tokenizer, new_lang="fuf_Latn")
+
+         # French is the source side of this model
+         self.splitter = SentenceSplitter("fr")
+         self.preprocessor = TextPreprocessor(lang="fr")
+
+         self.languages = LANGUAGES
+
+     def translate(
+         self,
+         text,
+         src_lang="fra_Latn",
+         tgt_lang="fuf_Latn",
+         max_length="auto",
+         num_beams=4,
+         by_sentence=True,
+         preprocess=True,
+         **kwargs,
+     ):
+         """Translate a text sentence by sentence, preserving the fillers around the sentences."""
+         if by_sentence:
+             sents, fillers = sentenize_with_fillers(
+                 text, splitter=self.splitter, ignore_errors=True
+             )
+         else:
+             sents = [text]
+             fillers = ["", ""]
+         if preprocess:
+             sents = [self.preprocessor(sent) for sent in sents]
+         results = []
+         for sent, sep in zip(sents, fillers):
+             results.append(sep)
+             results.append(
+                 self.translate_single(
+                     sent,
+                     src_lang=src_lang,
+                     tgt_lang=tgt_lang,
+                     max_length=max_length,
+                     num_beams=num_beams,
+                     **kwargs,
+                 )
+             )
+         results.append(fillers[-1])
+         return "".join(results)
+
+     def translate_single(
+         self,
+         text,
+         src_lang="fra_Latn",
+         tgt_lang="fuf_Latn",
+         max_length="auto",
+         num_beams=4,
+         n_out=None,
+         **kwargs,
+     ):
+         self.tokenizer.src_lang = src_lang
+         encoded = self.tokenizer(
+             text, return_tensors="pt", truncation=True, max_length=512
+         )
+         if max_length == "auto":
+             # heuristic budget: a fixed margin plus twice the source length in tokens
+             max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
+         generated_tokens = self.model.generate(
+             **encoded.to(self.model.device),
+             forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
+             max_length=max_length,
+             num_beams=num_beams,
+             num_return_sequences=n_out or 1,
+             **kwargs,
+         )
+         out = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+         if isinstance(text, str) and n_out is None:
+             return out[0]
+         return out
+
+
+ if __name__ == "__main__":
+     print("Initializing a translator to pre-download models...")
+     translator = Translator()
+     print("Initialization successful!")