SumYin nobrowning commited on
Commit
534e8dd
0 Parent(s):

Duplicate from nobrowning/M2M

Browse files

Co-authored-by: Wenshu Geng <nobrowning@users.noreply.huggingface.co>

Files changed (5) hide show
  1. .gitattributes +27 -0
  2. README.md +13 -0
  3. app.py +198 -0
  4. languages.py +47 -0
  5. requirements.txt +4 -0
.gitattributes ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.onnx filter=lfs diff=lfs merge=lfs -text
13
+ *.ot filter=lfs diff=lfs merge=lfs -text
14
+ *.parquet filter=lfs diff=lfs merge=lfs -text
15
+ *.pb filter=lfs diff=lfs merge=lfs -text
16
+ *.pt filter=lfs diff=lfs merge=lfs -text
17
+ *.pth filter=lfs diff=lfs merge=lfs -text
18
+ *.rar filter=lfs diff=lfs merge=lfs -text
19
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
20
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
21
+ *.tflite filter=lfs diff=lfs merge=lfs -text
22
+ *.tgz filter=lfs diff=lfs merge=lfs -text
23
+ *.wasm filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: M2M
3
+ emoji: 💻
4
+ colorFrom: green
5
+ colorTo: gray
6
+ sdk: streamlit
7
+ sdk_version: 1.9.0
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: nobrowning/M2M
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import io
4
+ from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
+ from languages import LANGUANGE_MAP
7
+ import time
8
+ import json
9
+ from typing import List
10
+ import torch
11
+ import random
12
+ import logging
13
+
14
+ if torch.cuda.is_available():
15
+ device = torch.device("cuda:0")
16
+ else:
17
+ device = torch.device("cpu")
18
+ logging.warning("GPU not found, using CPU, translation will be very slow.")
19
+
20
+ st.cache(suppress_st_warning=True, allow_output_mutation=True)
21
+ st.set_page_config(page_title="M2M100 Translator")
22
+
23
+ lang_id = {
24
+ "Afrikaans": "af",
25
+ "Amharic": "am",
26
+ "Arabic": "ar",
27
+ "Asturian": "ast",
28
+ "Azerbaijani": "az",
29
+ "Bashkir": "ba",
30
+ "Belarusian": "be",
31
+ "Bulgarian": "bg",
32
+ "Bengali": "bn",
33
+ "Breton": "br",
34
+ "Bosnian": "bs",
35
+ "Catalan": "ca",
36
+ "Cebuano": "ceb",
37
+ "Czech": "cs",
38
+ "Welsh": "cy",
39
+ "Danish": "da",
40
+ "German": "de",
41
+ "Greeek": "el",
42
+ "English": "en",
43
+ "Spanish": "es",
44
+ "Estonian": "et",
45
+ "Persian": "fa",
46
+ "Fulah": "ff",
47
+ "Finnish": "fi",
48
+ "French": "fr",
49
+ "Western Frisian": "fy",
50
+ "Irish": "ga",
51
+ "Gaelic": "gd",
52
+ "Galician": "gl",
53
+ "Gujarati": "gu",
54
+ "Hausa": "ha",
55
+ "Hebrew": "he",
56
+ "Hindi": "hi",
57
+ "Croatian": "hr",
58
+ "Haitian": "ht",
59
+ "Hungarian": "hu",
60
+ "Armenian": "hy",
61
+ "Indonesian": "id",
62
+ "Igbo": "ig",
63
+ "Iloko": "ilo",
64
+ "Icelandic": "is",
65
+ "Italian": "it",
66
+ "Japanese": "ja",
67
+ "Javanese": "jv",
68
+ "Georgian": "ka",
69
+ "Kazakh": "kk",
70
+ "Central Khmer": "km",
71
+ "Kannada": "kn",
72
+ "Korean": "ko",
73
+ "Luxembourgish": "lb",
74
+ "Ganda": "lg",
75
+ "Lingala": "ln",
76
+ "Lao": "lo",
77
+ "Lithuanian": "lt",
78
+ "Latvian": "lv",
79
+ "Malagasy": "mg",
80
+ "Macedonian": "mk",
81
+ "Malayalam": "ml",
82
+ "Mongolian": "mn",
83
+ "Marathi": "mr",
84
+ "Malay": "ms",
85
+ "Burmese": "my",
86
+ "Nepali": "ne",
87
+ "Dutch": "nl",
88
+ "Norwegian": "no",
89
+ "Northern Sotho": "ns",
90
+ "Occitan": "oc",
91
+ "Oriya": "or",
92
+ "Panjabi": "pa",
93
+ "Polish": "pl",
94
+ "Pushto": "ps",
95
+ "Portuguese": "pt",
96
+ "Romanian": "ro",
97
+ "Russian": "ru",
98
+ "Sindhi": "sd",
99
+ "Sinhala": "si",
100
+ "Slovak": "sk",
101
+ "Slovenian": "sl",
102
+ "Somali": "so",
103
+ "Albanian": "sq",
104
+ "Serbian": "sr",
105
+ "Swati": "ss",
106
+ "Sundanese": "su",
107
+ "Swedish": "sv",
108
+ "Swahili": "sw",
109
+ "Tamil": "ta",
110
+ "Thai": "th",
111
+ "Tagalog": "tl",
112
+ "Tswana": "tn",
113
+ "Turkish": "tr",
114
+ "Ukrainian": "uk",
115
+ "Urdu": "ur",
116
+ "Uzbek": "uz",
117
+ "Vietnamese": "vi",
118
+ "Wolof": "wo",
119
+ "Xhosa": "xh",
120
+ "Yiddish": "yi",
121
+ "Yoruba": "yo",
122
+ "Chinese": "zh",
123
+ "Zulu": "zu",
124
+ }
125
+
126
+
127
+ @st.cache(suppress_st_warning=True, allow_output_mutation=True)
128
+ def load_model(
129
+ pretrained_model: str = "facebook/m2m100_1.2B",
130
+ cache_dir: str = "models/",
131
+ ):
132
+ tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
133
+ model = M2M100ForConditionalGeneration.from_pretrained(
134
+ pretrained_model, cache_dir=cache_dir
135
+ ).to(device)
136
+ model.eval()
137
+ return tokenizer, model
138
+
139
+
140
+ @st.cache(suppress_st_warning=True, allow_output_mutation=True)
141
+ def load_detection_model(
142
+ pretrained_model: str = "ivanlau/language-detection-fine-tuned-on-xlm-roberta-base",
143
+ cache_dir: str = "models/",
144
+ ):
145
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
146
+ model = AutoModelForSequenceClassification.from_pretrained(pretrained_model, cache_dir=cache_dir).to(device)
147
+ model.eval()
148
+ return tokenizer, model
149
+
150
+
151
+ st.title("M2M100 Translator")
152
+ st.write("M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many multilingual translation. It was introduced in this paper https://arxiv.org/abs/2010.11125 and first released in https://github.com/pytorch/fairseq/tree/master/examples/m2m_100 repository. The model that can directly translate between the 9,900 directions of 100 languages.\n")
153
+
154
+ st.write(" This demo uses the facebook/m2m100_1.2B model. For local inference see https://github.com/ikergarcia1996/Easy-Translate")
155
+
156
+
157
+ user_input: str = st.text_area(
158
+ "Input text",
159
+ height=200,
160
+ max_chars=5120,
161
+ )
162
+
163
+ target_lang = st.selectbox(label="Target language", options=list(lang_id.keys()))
164
+
165
+ if st.button("Run"):
166
+ time_start = time.time()
167
+ tokenizer, model = load_model()
168
+ de_tokenizer, de_model = load_detection_model()
169
+
170
+ with torch.no_grad():
171
+
172
+ tokenized_sentence = de_tokenizer(user_input, return_tensors='pt')
173
+ output = de_model(**tokenized_sentence)
174
+ de_predictions = torch.nn.functional.softmax(output.logits, dim=-1)
175
+ _, preds = torch.max(de_predictions, dim=-1)
176
+
177
+ lang_type = LANGUANGE_MAP[preds.item()]
178
+
179
+ if lang_type not in lang_id:
180
+ time_end = time.time()
181
+ st.success('Unsupported Language')
182
+ st.write(f"Computation time: {round((time_end-time_start),3)} segs")
183
+ else:
184
+ src_lang = lang_id[lang_type]
185
+ trg_lang = lang_id[target_lang]
186
+ tokenizer.src_lang = src_lang
187
+ encoded_input = tokenizer(user_input, return_tensors="pt").to(device)
188
+ generated_tokens = model.generate(
189
+ **encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
190
+ )
191
+ translated_text = tokenizer.batch_decode(
192
+ generated_tokens, skip_special_tokens=True
193
+ )[0]
194
+
195
+ time_end = time.time()
196
+ st.success(translated_text)
197
+
198
+ st.write(f"Computation time: {round((time_end-time_start),3)} segs")
languages.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LANGUANGE_MAP = {
2
+ 0: 'Arabic',
3
+ 1: 'Basque',
4
+ 2: 'Breton',
5
+ 3: 'Catalan',
6
+ 4: 'Chinese',
7
+ 5: 'Chinese',
8
+ 6: 'Chinese',
9
+ 7: 'Chuvash',
10
+ 8: 'Czech',
11
+ 9: 'Dhivehi',
12
+ 10: 'Dutch',
13
+ 11: 'English',
14
+ 12: 'Esperanto',
15
+ 13: 'Estonian',
16
+ 14: 'French',
17
+ 15: 'Frisian',
18
+ 16: 'Georgian',
19
+ 17: 'German',
20
+ 18: 'Greek',
21
+ 19: 'Hakha_Chin',
22
+ 20: 'Indonesian',
23
+ 21: 'Interlingua',
24
+ 22: 'Italian',
25
+ 23: 'Japanese',
26
+ 24: 'Kabyle',
27
+ 25: 'Kinyarwanda',
28
+ 26: 'Kyrgyz',
29
+ 27: 'Latvian',
30
+ 28: 'Maltese',
31
+ 29: 'Mongolian',
32
+ 30: 'Persian',
33
+ 31: 'Polish',
34
+ 32: 'Portuguese',
35
+ 33: 'Romanian',
36
+ 34: 'Romansh_Sursilvan',
37
+ 35: 'Russian',
38
+ 36: 'Sakha',
39
+ 37: 'Slovenian',
40
+ 38: 'Spanish',
41
+ 39: 'Swedish',
42
+ 40: 'Tamil',
43
+ 41: 'Tatar',
44
+ 42: 'Turkish',
45
+ 43: 'Ukranian',
46
+ 44: 'Welsh'
47
+ }
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ transformers
4
+ transformers[sentencepiece]