cointegrated committed
Commit f44876d
1 Parent(s): 4d80410

the first commit

Files changed (5)
  1. .gitignore +1 -0
  2. README.md +5 -5
  3. app.py +55 -0
  4. requirements.txt +7 -0
  5. translation.py +184 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ .idea
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
- title: Nllb Rus Myv V2023 Demo
- emoji: 📊
- colorFrom: green
- colorTo: pink
+ title: Nllb Rus Tyv V1 Demo
+ emoji: 🚀
+ colorFrom: blue
+ colorTo: gray
  sdk: gradio
- sdk_version: 4.1.1
+ sdk_version: 3.46.1
  app_file: app.py
  pinned: false
  ---
app.py ADDED
@@ -0,0 +1,55 @@
+ import gradio as gr
+
+
+ from translation import Translator, LANGUAGES
+ LANGUAGES_LIST = list(LANGUAGES.keys())
+
+
+ def translate_wrapper(text, src, trg, by_sentence=True, preprocess=True, random=False, num_beams=4):
+     src_lang = LANGUAGES.get(src)
+     tgt_lang = LANGUAGES.get(trg)
+     # if src == trg:
+     #     return 'Please choose two different languages'
+     result = translator.translate(
+         text=text,
+         src_lang=src_lang,
+         tgt_lang=tgt_lang,
+         do_sample=random,
+         num_beams=int(num_beams),
+         by_sentence=by_sentence,
+         preprocess=preprocess,
+     )
+     return result
+
+
+ article = """
+ This is an NLLB-200-600M model fine-tuned for translation between the Russian and Tyvan (Tuvan) languages,
+ using the data from https://tyvan.ru/.
+
+ This model is described in https://cointegrated.medium.com/a37fc706b865.
+
+ If you want to host it on your own backend, consider running this dockerized app: https://github.com/slone-nlp/nllb-docker-demo.
+ """
+
+
+ interface = gr.Interface(
+     translate_wrapper,
+     [
+         gr.Textbox(label="Text", lines=2, placeholder='text to translate'),
+         gr.Dropdown(LANGUAGES_LIST, type="value", label='source language', value=LANGUAGES_LIST[0]),
+         gr.Dropdown(LANGUAGES_LIST, type="value", label='target language', value=LANGUAGES_LIST[1]),
+         gr.Checkbox(label="by sentence", value=True),
+         gr.Checkbox(label="text preprocessing", value=True),
+         gr.Checkbox(label="randomize", value=False),
+         gr.Dropdown([1, 2, 3, 4, 5], label="number of beams", value=4),
+     ],
+     "text",
+     title='Tyvan-Russian translation',
+     article=article,
+ )
+
+
+ if __name__ == '__main__':
+     translator = Translator()
+
+     interface.launch()
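
A note on how the pieces fit together: translate_wrapper maps each dropdown label to an NLLB language code via LANGUAGES, and it relies on the global translator that is created only when app.py runs as the main module, just before interface.launch(). Below is a minimal sketch of the same call path without the Gradio UI; the sample sentence is illustrative, and it assumes the checkpoint referenced in translation.py downloads successfully and the 8-bit loading in Translator.__init__ works in the local environment.

# hypothetical smoke test, not part of the committed files
from translation import Translator, LANGUAGES

translator = Translator()           # downloads and loads the model on first run
labels = list(LANGUAGES.keys())     # the same labels the two dropdowns display
print(translator.translate(
    "Привет, мир!",                 # illustrative input
    src_lang=LANGUAGES[labels[0]],  # "rus_Cyrl"
    tgt_lang=LANGUAGES[labels[1]],
    num_beams=4,
))
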
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ transformers==4.33
+ sentencepiece
+ gradio>=3.18.0
+ torch
+ sentence-splitter==1.4
+ sacremoses==0.0.45
+ accelerate==0.23
translation.py ADDED
@@ -0,0 +1,184 @@
+ import re
+ import sys
+ import typing as tp
+ import unicodedata
+
+ import torch
+ from sacremoses import MosesPunctNormalizer
+ from sentence_splitter import SentenceSplitter
+ from transformers import AutoModelForSeq2SeqLM, NllbTokenizer
+
+ MODEL_URL = "slone/nllb-rus-myv-v1-extvoc"
+ LANGUAGES = {
+     "Рузонь | Русский | Russian": "rus_Cyrl",
+     "Эрзянь | Эрзянский | Erzya": "myv_Cyrl",
+ }
+ L1 = "rus_Cyrl"
+ L2 = "myv_Cyrl"
+
+
+ def get_non_printing_char_replacer(replace_by: str = " ") -> tp.Callable[[str], str]:
+     non_printable_map = {
+         ord(c): replace_by
+         for c in (chr(i) for i in range(sys.maxunicode + 1))
+         # same as \p{C} in perl
+         # see https://www.unicode.org/reports/tr44/#General_Category_Values
+         if unicodedata.category(c) in {"C", "Cc", "Cf", "Cs", "Co", "Cn"}
+     }
+
+     def replace_non_printing_char(line) -> str:
+         return line.translate(non_printable_map)
+
+     return replace_non_printing_char
+
+
+ class TextPreprocessor:
+     """
+     Mimic the text preprocessing made for the NLLB model.
+     This code is adapted from the Stopes repo of the NLLB team:
+     https://github.com/facebookresearch/stopes/blob/main/stopes/pipelines/monolingual/monolingual_line_processor.py#L214
+     """
+
+     def __init__(self, lang="en"):
+         self.mpn = MosesPunctNormalizer(lang=lang)
+         self.mpn.substitutions = [
+             (re.compile(r), sub) for r, sub in self.mpn.substitutions
+         ]
+         self.replace_nonprint = get_non_printing_char_replacer(" ")
+
+     def __call__(self, text: str) -> str:
+         clean = self.mpn.normalize(text)
+         clean = self.replace_nonprint(clean)
+         # replace 𝓕𝔯𝔞𝔫𝔠𝔢𝔰𝔠𝔞 by Francesca
+         clean = unicodedata.normalize("NFKC", clean)
+         return clean
+
+
+ def fix_tokenizer(tokenizer, new_lang=L2):
+     """Add a new language token to the tokenizer vocabulary
+     (this should be done each time after its initialization)
+     """
+     old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
+     tokenizer.lang_code_to_id[new_lang] = old_len - 1
+     tokenizer.id_to_lang_code[old_len - 1] = new_lang
+     # always move "mask" to the last position
+     tokenizer.fairseq_tokens_to_ids["<mask>"] = (
+         len(tokenizer.sp_model)
+         + len(tokenizer.lang_code_to_id)
+         + tokenizer.fairseq_offset
+     )
+
+     tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
+     tokenizer.fairseq_ids_to_tokens = {
+         v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()
+     }
+     if new_lang not in tokenizer._additional_special_tokens:
+         tokenizer._additional_special_tokens.append(new_lang)
+     # clear the added token encoder; otherwise a new token may end up there by mistake
+     tokenizer.added_tokens_encoder = {}
+     tokenizer.added_tokens_decoder = {}
+
+
+ def sentenize_with_fillers(text, splitter, fix_double_space=True, ignore_errors=False):
+     """Apply a sentence splitter and return the sentences and all separators before and after them"""
+     if fix_double_space:
+         text = re.sub(" +", " ", text)
+     sentences = splitter.split(text)
+     fillers = []
+     i = 0
+     for sentence in sentences:
+         start_idx = text.find(sentence, i)
+         if ignore_errors and start_idx == -1:
+             # print(f"sent not found after {i}: `{sentence}`")
+             start_idx = i + 1
+         assert start_idx != -1, f"sent not found after {i}: `{sentence}`"
+         fillers.append(text[i:start_idx])
+         i = start_idx + len(sentence)
+     fillers.append(text[i:])
+     return sentences, fillers
+
+
+ class Translator:
+     def __init__(self):
+         self.model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_URL, low_cpu_mem_usage=False, load_in_8bit=True)
+         if torch.cuda.is_available():
+             self.model.cuda()
+         self.tokenizer = NllbTokenizer.from_pretrained(MODEL_URL)
+         fix_tokenizer(self.tokenizer)
+
+         self.splitter = SentenceSplitter("ru")
+         self.preprocessor = TextPreprocessor()
+
+         self.languages = LANGUAGES
+
+     def translate(
+         self,
+         text,
+         src_lang=L1,
+         tgt_lang=L2,
+         max_length="auto",
+         num_beams=4,
+         by_sentence=True,
+         preprocess=True,
+         **kwargs,
+     ):
+         """Translate a text sentence by sentence, preserving the fillers around the sentences."""
+         if by_sentence:
+             sents, fillers = sentenize_with_fillers(
+                 text, splitter=self.splitter, ignore_errors=True
+             )
+         else:
+             sents = [text]
+             fillers = ["", ""]
+         if preprocess:
+             sents = [self.preprocessor(sent) for sent in sents]
+         results = []
+         for sent, sep in zip(sents, fillers):
+             results.append(sep)
+             results.append(
+                 self.translate_single(
+                     sent,
+                     src_lang=src_lang,
+                     tgt_lang=tgt_lang,
+                     max_length=max_length,
+                     num_beams=num_beams,
+                     **kwargs,
+                 )
+             )
+         results.append(fillers[-1])
+         return "".join(results)
+
+     def translate_single(
+         self,
+         text,
+         src_lang=L1,
+         tgt_lang=L2,
+         max_length="auto",
+         num_beams=4,
+         n_out=None,
+         **kwargs,
+     ):
+         self.tokenizer.src_lang = src_lang
+         encoded = self.tokenizer(
+             text, return_tensors="pt", truncation=True, max_length=512
+         )
+         if max_length == "auto":
+             max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
+         generated_tokens = self.model.generate(
+             **encoded.to(self.model.device),
+             forced_bos_token_id=self.tokenizer.lang_code_to_id[tgt_lang],
+             max_length=max_length,
+             num_beams=num_beams,
+             num_return_sequences=n_out or 1,
+             **kwargs,
+         )
+         out = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+         if isinstance(text, str) and n_out is None:
+             return out[0]
+         return out
+
+
+ if __name__ == "__main__":
+     print("Initializing a translator to pre-download models...")
+     translator = Translator()
+     print("Initialization successful!")