Restructuring. Remove examples.
Files changed:
- app.py: +51 -178
- model_translation.py: +263 -1
app.py
CHANGED
@@ -6,70 +6,17 @@ Description: Translate text...
 Author: Didier Guillevic
 Date: 2024-09-07
 """
-import spaces
-import gradio as gr
-import langdetect
-from typing import List
-
 import logging
 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)

+import gradio as gr
+import langdetect
+
 from deep_translator import GoogleTranslator
 from model_spacy import nlp_xx

-import model_translation
-
-
-def detect_language(text):
-    lang = langdetect.detect(text)
-    return lang
-
-
-def build_text_chunks(text, src_lang, sents_per_chunk):
-    """
-    Given a text:
-        - Split the text into sentences.
-        - Build text chunks:
-            - Consider up to sents_per_chunk
-            - Ensure that we do not exceed translation.max_words_per_chunk
-    """
-    # Split text into sentences...
-    sentences = [
-        sent.text.strip() for sent in nlp_xx(text).sents if sent.text.strip()]
-    logger.info(f"LANG: {src_lang}, TEXT: {text[:20]}, NB_SENTS: {len(sentences)}")
-
-    # Create text chunks of N sentences
-    chunks = []
-    chunk = ''
-    chunk_nb_sentences = 0
-    chunk_nb_words = 0
-
-    for i in range(0, len(sentences)):
-        # Get sentence
-        sent = sentences[i]
-        sent_nb_words = len(sent.split())
-
-        # If chunk already 'full', save chunk, start new chunk
-        if (
-            (chunk_nb_words + sent_nb_words > translation.max_words_per_chunk) or
-            (chunk_nb_sentences + 1 > sents_per_chunk)
-        ):
-            chunks.append(chunk)
-            chunk = ''
-            chunk_nb_sentences = 0
-            chunk_nb_words = 0
-
-        # Append sentence to current chunk. One sentence per line.
-        chunk = (chunk + '\n' + sent) if chunk else sent
-        chunk_nb_sentences += 1
-        chunk_nb_words += sent_nb_words
-
-    # Append last chunk
-    if chunk:
-        chunks.append(chunk)
-
-    return chunks
+import model_translation


 def translate_with_Helsinki(
@@ -106,82 +53,31 @@ def translate_with_Helsinki(
     return '\n'.join(translated_chunks)


-@spaces.GPU
-def translate_with_m2m100(
-        chunks: List[str],
-        src_lang: str,
-        tgt_lang: str) -> str:
-    """Translate with the m2m100 model
-    """
-    m2m100 = translation.ModelM2M100()
-    m2m100.tokenizer.src_lang = src_lang
-
-    translated_chunks = []
-    for chunk in chunks:
-        input_ids = m2m100.tokenizer(
-            chunk, return_tensors="pt").input_ids.to(m2m100.model.device)
-        outputs = m2m100.model.generate(
-            input_ids=input_ids,
-            forced_bos_token_id=m2m100.tokenizer.get_lang_id(tgt_lang))
-        translated_chunk = m2m100.tokenizer.batch_decode(
-            outputs, skip_special_tokens=True)[0]
-        translated_chunks.append(translated_chunk)
-
-    return '\n'.join(translated_chunks)
-
-
-@spaces.GPU
-def translate_with_MADLAD(
-        chunks: List[str],
-        tgt_lang: str,
-        input_max_length: int=512,
-        output_max_length: int=512) -> str:
-    """Translate with Google MADLAD model
-    """
-    madlad = translation.ModelMADLAD()
-
-    translated_chunks = []
-    for chunk in chunks:
-        input_text = f"<2{tgt_lang}> {chunk}"
-        #logger.info(f"  Translating: {input_text[:30]}")
-        input_ids = madlad.tokenizer(
-            input_text, return_tensors="pt", max_length=input_max_length,
-            truncation=True, padding="longest").input_ids.to(madlad.model.device)
-        outputs = madlad.model.generate(
-            input_ids=input_ids, max_length=output_max_length)
-        translated_chunk = madlad.tokenizer.decode(
-            outputs[0], skip_special_tokens=True)
-        translated_chunks.append(translated_chunk)
-
-    return '\n'.join(translated_chunks)
-
-
 def translate_text(
         text: str,
-        src_lang: str
-
-
-
-    """
-    Translate the given text into English (default "easy" language)
+        src_lang: str,
+        tgt_lang: str
+) -> str:
+    """Translate the given text into English or French
     """
-
-
-
-
-
+    # src_lang among the supported languages?
+    # - make sure src_lang is not None
+    src_lang = src_lang if (src_lang and src_lang != "auto") else langdetect.detect(text)
+    if src_lang not in model_translation.language_codes.values():
+        logging.error(f"Language detected {src_lang} not among supported language")
+
+    # tgt_lang: make sure it is not None. Default to 'en' if not set.
+    if tgt_lang not in model_translation.tgt_language_codes.values():
+        tgt_lang = 'en'

-    #
-
-
-    translated_text_MADLAD = translate_with_MADLAD(chunks, tgt_lang)
+    # translate
+    m2m100 = model_translation.ModelM2M100()
+    translated_text_m2m100 = m2m100.translate(text, src_lang, tgt_lang)
     translated_text_google_translate = GoogleTranslator(
         source='auto', target='en').translate(text=text)

     return (
-
-        #translated_text_m2m100,
-        translated_text_MADLAD,
+        translated_text_m2m100,
         translated_text_google_translate
     )

@@ -192,77 +88,54 @@ def translate_text(
 with gr.Blocks() as demo:

     gr.Markdown("""
-    ## Text translation v0.0.
+    ## Text translation v0.0.3
     """)
+    # Input
     input_text = gr.Textbox(
-        lines=
+        lines=6,
         placeholder="Enter text to translate",
         label="Text to translate",
-        render=
+        render=True
     )
-
-    #
-
-    #    render=False
-    #)
-    #output_text_m2m100 = gr.Textbox(
-    #    lines=6,
-    #    label="Facebook m2m100 (1.2B)",
-    #    render=False
-    #)
-    output_text_MADLAD = gr.Textbox(
+
+    # Output
+    output_text_m2m100 = gr.Textbox(
         lines=6,
-        label="
-        render=
+        label="Facebook m2m100 (1.2B)",
+        render=True
     )
     output_text_google_translate = gr.Textbox(
         lines=6,
         label="Google Translate",
-        render=
+        render=True
     )

-    #
-
-
-
-
-
-
-
-
-
-
+    # Source and target languages
+    with gr.Row():
+        src_lang = gr.Radio(
+            choices=model_translation.language_codes.items(),
+            value="auto",
+            label="Source language",
+            render=True
+        )
+        tgt_lang = gr.Radio(
+            choices=model_translation.tgt_language_codes.items(),
+            value="en",
+            label="Target language",
+            render=True
+        )

-    #
-
-
-        #["Clément Delangue est, avec Julien Chaumond et Thomas Wolf, l’un des trois Français cofondateurs de Hugging Face, une start-up d’intelligence artificielle (IA) de premier plan. Valorisée à 4,2 milliards d’euros après avoir levé près de 450 millions d’euros depuis sa création en 2016, cette société de droit américain est connue comme la plate-forme de référence où développeurs et entreprises publient des outils et des modèles pour faire de l’IA en open source, c’est-à-dire accessible gratuitement et modifiable.", "fr"],
-        ["يُعد تفشي مرض جدري القردة قضية صحية عالمية خطيرة، ومن المهم محاولة منع انتشاره للحفاظ على سلامة الناس وتجنب العدوى. د. صموئيل بولاند، مدير الحوادث الخاصة بمرض الجدري في المكتب الإقليمي لمنظمة الصحة العالمية في أفريقيا، يتحدث من كينشاسا في جمهورية الكونغو الديمقراطية، ولديه بعض النصائح البسيطة التي يمكن للناس اتباعها لتقليل خطر انتشار المرض.", "ar"],
-        ["張先生稱,奇瑞已經凖備在西班牙生產汽車,並決心採取「本地化」的方式進入歐洲市場。此外,他也否認該公司的出口受益於不公平補貼。奇瑞成立於1997年,是中國最大的汽車公司之一。它已經是中國最大的汽車出口商,並且制定了進一步擴張的野心勃勃的計劃。", "zh"],
-        #["ברוכה הבאה, קיטי: בית הקפה החדש בלוס אנג'לס החתולה האהובה והחברים שלה מקבלים בית קפה משלהם בשדרות יוניברסל סיטי, שם תוכלו למצוא מגוון של פינוקים מתוקים – החל ממשקאות ועד עוגות", "he"],
-    ]
-
-    gr.Interface(
+    # Submit button
+    translate_btn = gr.Button("Translate")
+    translate_btn.click(
         fn=translate_text,
-        inputs=[input_text, src_lang],
-        outputs=[
-            #output_text_Helsinki,
-            #output_text_m2m100,
-            output_text_MADLAD,
-            output_text_google_translate,
-        ],
-        additional_inputs=[sentences_per_chunk,],
-        #clear_btn=None,  # Unfortunately, clear_btn also reset the additional inputs. Hence disabling for now.
-        allow_flagging="never",
-        examples=examples,
-        cache_examples=False
+        inputs=[input_text, src_lang, tgt_lang],
+        outputs=[output_text_m2m100, output_text_google_translate]
     )

     with gr.Accordion("Documentation", open=False):
         gr.Markdown("""
-        - Models: serving Facebook M2M100 and Google
-        - Basic: processing of long paragraph / document to be enhanced.
-        - Most examples are copy/pasted from BBC news international web sites.
+        - Models: serving Facebook M2M100 and Google Translate.
         """)

 if __name__ == "__main__":
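For orientation, the restructured app.py replaces the old gr.Interface (with its baked-in examples) by a plain gr.Button wired to translate_text. Below is a minimal stand-alone sketch of that wiring, assuming Gradio 4.x, where gr.Radio accepts (label, value) pairs so the handler receives the short code ('auto', 'en', ...) rather than the display label. stub_translate and the abridged dicts are placeholders for illustration, not the app's model code.

# Hypothetical stand-alone sketch of the new Blocks wiring in app.py.
# stub_translate stands in for the real, model-backed translate_text.
import gradio as gr

# Abridged stand-ins for model_translation.language_codes / tgt_language_codes.
language_codes = {'AUTOMATIC': 'auto', 'English (en)': 'en', 'French (fr)': 'fr'}
tgt_language_codes = {'English (en)': 'en', 'French (fr)': 'fr'}

def stub_translate(text: str, src_lang: str, tgt_lang: str) -> tuple[str, str]:
    # The real handler returns (M2M100 translation, Google translation).
    return f"[{src_lang}->{tgt_lang}] {text}", f"[google] {text}"

with gr.Blocks() as demo:
    input_text = gr.Textbox(lines=6, label="Text to translate")
    output_text_m2m100 = gr.Textbox(lines=6, label="Facebook m2m100 (1.2B)")
    output_text_google = gr.Textbox(lines=6, label="Google Translate")
    with gr.Row():
        # (label, value) pairs: the click handler receives the value, e.g. 'auto'.
        src_lang = gr.Radio(choices=list(language_codes.items()),
                            value="auto", label="Source language")
        tgt_lang = gr.Radio(choices=list(tgt_language_codes.items()),
                            value="en", label="Target language")
    translate_btn = gr.Button("Translate")
    translate_btn.click(
        fn=stub_translate,
        inputs=[input_text, src_lang, tgt_lang],
        outputs=[output_text_m2m100, output_text_google],
    )

if __name__ == "__main__":
    demo.launch()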
model_translation.py
CHANGED
@@ -7,17 +7,188 @@ Description:
 Author: Didier Guillevic
 Date: 2024-03-16
 """
+import spaces
+
+import logging
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)

 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
 from transformers import BitsAndBytesConfig

+from model_spacy import nlp_xx as model_spacy
+
 quantization_config = BitsAndBytesConfig(
     load_in_8bit=True,
     llm_int8_threshold=200.0 # https://discuss.huggingface.co/t/correct-usage-of-bitsandbytesconfig/33809/5
 )

+# The 100 languages supported by the facebook/m2m100_418M model
+# https://huggingface.co/facebook/m2m100_418M
+# plus the 'AUTOMATIC' option where we will use a language detector.
+language_codes = {
+    'AUTOMATIC': 'auto',
+    'Afrikaans (af)': 'af',
+    'Albanian (sq)': 'sq',
+    'Amharic (am)': 'am',
+    'Arabic (ar)': 'ar',
+    'Armenian (hy)': 'hy',
+    'Asturian (ast)': 'ast',
+    'Azerbaijani (az)': 'az',
+    'Bashkir (ba)': 'ba',
+    'Belarusian (be)': 'be',
+    'Bengali (bn)': 'bn',
+    'Bosnian (bs)': 'bs',
+    'Breton (br)': 'br',
+    'Bulgarian (bg)': 'bg',
+    'Burmese (my)': 'my',
+    'Catalan; Valencian (ca)': 'ca',
+    'Cebuano (ceb)': 'ceb',
+    'Central Khmer (km)': 'km',
+    'Chinese (zh)': 'zh',
+    'Croatian (hr)': 'hr',
+    'Czech (cs)': 'cs',
+    'Danish (da)': 'da',
+    'Dutch; Flemish (nl)': 'nl',
+    'English (en)': 'en',
+    'Estonian (et)': 'et',
+    'Finnish (fi)': 'fi',
+    'French (fr)': 'fr',
+    'Fulah (ff)': 'ff',
+    'Gaelic; Scottish Gaelic (gd)': 'gd',
+    'Galician (gl)': 'gl',
+    'Ganda (lg)': 'lg',
+    'Georgian (ka)': 'ka',
+    'German (de)': 'de',
+    'Greeek (el)': 'el',
+    'Gujarati (gu)': 'gu',
+    'Haitian; Haitian Creole (ht)': 'ht',
+    'Hausa (ha)': 'ha',
+    'Hebrew (he)': 'he',
+    'Hindi (hi)': 'hi',
+    'Hungarian (hu)': 'hu',
+    'Icelandic (is)': 'is',
+    'Igbo (ig)': 'ig',
+    'Iloko (ilo)': 'ilo',
+    'Indonesian (id)': 'id',
+    'Irish (ga)': 'ga',
+    'Italian (it)': 'it',
+    'Japanese (ja)': 'ja',
+    'Javanese (jv)': 'jv',
+    'Kannada (kn)': 'kn',
+    'Kazakh (kk)': 'kk',
+    'Korean (ko)': 'ko',
+    'Lao (lo)': 'lo',
+    'Latvian (lv)': 'lv',
+    'Lingala (ln)': 'ln',
+    'Lithuanian (lt)': 'lt',
+    'Luxembourgish; Letzeburgesch (lb)': 'lb',
+    'Macedonian (mk)': 'mk',
+    'Malagasy (mg)': 'mg',
+    'Malay (ms)': 'ms',
+    'Malayalam (ml)': 'ml',
+    'Marathi (mr)': 'mr',
+    'Mongolian (mn)': 'mn',
+    'Nepali (ne)': 'ne',
+    'Northern Sotho (ns)': 'ns',
+    'Norwegian (no)': 'no',
+    'Occitan (post 1500) (oc)': 'oc',
+    'Oriya (or)': 'or',
+    'Panjabi; Punjabi (pa)': 'pa',
+    'Persian (fa)': 'fa',
+    'Polish (pl)': 'pl',
+    'Portuguese (pt)': 'pt',
+    'Pushto; Pashto (ps)': 'ps',
+    'Romanian; Moldavian; Moldovan (ro)': 'ro',
+    'Russian (ru)': 'ru',
+    'Serbian (sr)': 'sr',
+    'Sindhi (sd)': 'sd',
+    'Sinhala; Sinhalese (si)': 'si',
+    'Slovak (sk)': 'sk',
+    'Slovenian (sl)': 'sl',
+    'Somali (so)': 'so',
+    'Spanish (es)': 'es',
+    'Sundanese (su)': 'su',
+    'Swahili (sw)': 'sw',
+    'Swati (ss)': 'ss',
+    'Swedish (sv)': 'sv',
+    'Tagalog (tl)': 'tl',
+    'Tamil (ta)': 'ta',
+    'Thai (th)': 'th',
+    'Tswana (tn)': 'tn',
+    'Turkish (tr)': 'tr',
+    'Ukrainian (uk)': 'uk',
+    'Urdu (ur)': 'ur',
+    'Uzbek (uz)': 'uz',
+    'Vietnamese (vi)': 'vi',
+    'Welsh (cy)': 'cy',
+    'Western Frisian (fy)': 'fy',
+    'Wolof (wo)': 'wo',
+    'Xhosa (xh)': 'xh',
+    'Yiddish (yi)': 'yi',
+    'Yoruba (yo)': 'yo',
+    'Zulu (zu)': 'zu'
+}
+
+tgt_language_codes = {
+    'English (en)': 'en',
+    'French (fr)': 'fr'
+}
+
+
+def build_text_chunks(
+        text: str,
+        sents_per_chunk: int=5,
+        words_per_chunk=200) -> list[str]:
+    """Split a given text into chunks with at most sents_per_chnks and words_per_chunk
+
+    Given a text:
+        - Split the text into sentences.
+        - Build text chunks:
+            - Consider up to sents_per_chunk
+            - Ensure that we do not exceed words_per_chunk
+    """
+    # Split text into sentences...
+    sentences = [
+        sent.text.strip() for sent in model_spacy(text).sents if sent.text.strip()
+    ]
+    logger.info(f"TEXT: {text[:25]}, NB_SENTS: {len(sentences)}")
+
+    # Create text chunks of N sentences
+    chunks = []
+    chunk = ''
+    chunk_nb_sentences = 0
+    chunk_nb_words = 0
+
+    for i in range(0, len(sentences)):
+        # Get sentence
+        sent = sentences[i]
+        sent_nb_words = len(sent.split())
+
+        # If chunk already 'full', save chunk, start new chunk
+        if (
+            (chunk_nb_words + sent_nb_words > words_per_chunk) or
+            (chunk_nb_sentences + 1 > sents_per_chunk)
+        ):
+            chunks.append(chunk)
+            chunk = ''
+            chunk_nb_sentences = 0
+            chunk_nb_words = 0
+
+        # Append sentence to current chunk. One sentence per line.
+        chunk = (chunk + '\n' + sent) if chunk else sent
+        chunk_nb_sentences += 1
+        chunk_nb_words += sent_nb_words
+
+    # Append last chunk
+    if chunk:
+        chunks.append(chunk)
+
+    return chunks
+
+
 class Singleton(type):
     _instances = {}
     def __call__(cls, *args, **kwargs):
@@ -25,8 +196,11 @@ class Singleton(type):
             cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
         return cls._instances[cls]

+
 class ModelM2M100(metaclass=Singleton):
     """Loads an instance of the M2M100 model.
+
+    Model: https://huggingface.co/facebook/m2m100_1.2B
     """
     def __init__(self):
         self._model_name = "facebook/m2m100_1.2B"
@@ -35,10 +209,47 @@ class ModelM2M100(metaclass=Singleton):
             self._model_name,
             device_map="auto",
             torch_dtype=torch.float16,
-            low_cpu_mem_usage=True
+            low_cpu_mem_usage=True,
+            quantization_config=quantization_config
         )
         self._model = torch.compile(self._model)

+    @spaces.GPU
+    def translate(
+            self,
+            text: str,
+            src_lang: str,
+            tgt_lang: str,
+            chunk_text: bool=True,
+            sents_per_chunk: int=5,
+            words_per_chunk: int=200
+    ) -> str:
+        """Translate the given text from src_lang to tgt_lang.
+
+        The text will be split into chunks to ensure the chunks fit into the
+        model input_max_length (usually 512 tokens).
+        """
+        chunks = [text,]
+        if chunk_text:
+            chunks = build_text_chunks(text, sents_per_chunk, words_per_chunk)
+
+        self._tokenizer.src_lang = src_lang
+
+        translated_chunks = []
+        for chunk in chunks:
+            input_ids = self._tokenizer(
+                chunk,
+                return_tensors="pt").input_ids.to(self._model.device)
+            outputs = self._model.generate(
+                input_ids=input_ids,
+                forced_bos_token_id=self._tokenizer.get_lang_id(tgt_lang))
+            translated_chunk = self._tokenizer.batch_decode(
+                outputs,
+                skip_special_tokens=True)[0]
+            translated_chunks.append(translated_chunk)
+
+        return '\n'.join(translated_chunks)
+
     @property
     def model_name(self):
         return self._model_name
@@ -51,11 +262,20 @@ class ModelM2M100(metaclass=Singleton):
     def model(self):
         return self._model

+    @property
+    def device(self):
+        return self._model.device
+
+
 class ModelMADLAD(metaclass=Singleton):
     """Loads an instance of the Google MADLAD model (3B).
+
+    Model: https://huggingface.co/google/madlad400-3b-mt
     """
     def __init__(self):
         self._model_name = "google/madlad400-3b-mt"
+        self._input_max_length = 512   # config.json n_positions
+        self._output_max_length = 512  # config.json n_positions
         self._tokenizer = AutoTokenizer.from_pretrained(
             self.model_name, use_fast=True
         )
@@ -68,6 +288,44 @@ class ModelMADLAD(metaclass=Singleton):
         )
         self._model = torch.compile(self._model)

+    @spaces.GPU
+    def translate(
+            self,
+            text: str,
+            tgt_lang: str,
+            chunk_text: True,
+            sents_per_chunk: int=5,
+            words_per_chunk: int=5
+    ) -> str:
+        """Translate given text into the target language.
+
+        The text will be split into chunks to ensure the chunks fit into the
+        model input_max_length (usually 512 tokens).
+        """
+        chunks = [text,]
+        if chunk_text:
+            chunks = build_text_chunks(text, sents_per_chunk, words_per_chunk)
+
+        translated_chunks = []
+        for chunk in chunks:
+            input_text = f"<2{tgt_lang}> {chunk}"
+            logger.info(f"  Translating: {input_text[:50]}")
+            input_ids = self._tokenizer(
+                input_text,
+                return_tensors="pt",
+                max_length=self._input_max_length,
+                truncation=True,
+                padding="longest").input_ids.to(self._model.device)
+            outputs = self._model.generate(
+                input_ids=input_ids,
+                max_length=self._output_max_length)
+            translated_chunk = self._tokenizer.decode(
+                outputs[0],
+                skip_special_tokens=True)
+            translated_chunks.append(translated_chunk)
+
+        return '\n'.join(translated_chunks)
+
     @property
     def model_name(self):
         return self._model_name
@@ -79,6 +337,10 @@ class ModelMADLAD(metaclass=Singleton):
     @property
     def model(self):
         return self._model
+
+    @property
+    def device(self):
+        return self._model.device


 # Bi-lingual individual models
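To make the new chunking and model-caching behaviour concrete, here is a small self-contained sketch (illustrative only, not repository code). It reproduces the accumulation loop of build_text_chunks with a naive period-based splitter standing in for the spaCy sentence segmenter (model_spacy), and shows the Singleton metaclass returning the same wrapper instance on every construction, which is what lets app.py call model_translation.ModelM2M100() inside the request handler without reloading weights. Names with a _demo suffix are hypothetical.

# Hypothetical sketch of the chunking loop and Singleton behaviour from
# model_translation.py; a naive splitter replaces spaCy segmentation.

def build_text_chunks_demo(text: str, sents_per_chunk: int = 5,
                           words_per_chunk: int = 200) -> list[str]:
    # Stand-in for: [sent.text.strip() for sent in model_spacy(text).sents ...]
    sentences = [s.strip() for s in text.split('.') if s.strip()]

    chunks, chunk, chunk_nb_sentences, chunk_nb_words = [], '', 0, 0
    for sent in sentences:
        sent_nb_words = len(sent.split())
        # Flush the current chunk once either budget would be exceeded.
        if (chunk_nb_words + sent_nb_words > words_per_chunk or
                chunk_nb_sentences + 1 > sents_per_chunk):
            chunks.append(chunk)
            chunk, chunk_nb_sentences, chunk_nb_words = '', 0, 0
        chunk = (chunk + '\n' + sent) if chunk else sent  # one sentence per line
        chunk_nb_sentences += 1
        chunk_nb_words += sent_nb_words
    if chunk:
        chunks.append(chunk)
    return chunks


class Singleton(type):
    _instances = {}
    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]


class ModelDemo(metaclass=Singleton):
    """Stand-in for ModelM2M100 / ModelMADLAD: constructed once, then reused."""
    def __init__(self):
        print("loading model weights (printed only once)")


if __name__ == "__main__":
    text = "One two. Three four. Five six. Seven eight. Nine."
    print(build_text_chunks_demo(text, sents_per_chunk=2))
    # -> ['One two\nThree four', 'Five six\nSeven eight', 'Nine']
    assert ModelDemo() is ModelDemo()   # same cached instance both times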