Didier committed on
Commit
0c7be31
1 Parent(s): b929bff

Restructuring. Remove examples.

Browse files
Files changed (2) hide show
  1. app.py +51 -178
  2. model_translation.py +263 -1
app.py CHANGED
@@ -6,70 +6,17 @@ Description: Translate text...
6
  Author: Didier Guillevic
7
  Date: 2024-09-07
8
  """
9
- import spaces
10
- import gradio as gr
11
- import langdetect
12
- from typing import List
13
-
14
  import logging
15
  logger = logging.getLogger(__name__)
16
  logging.basicConfig(level=logging.INFO)
17
 
 
 
 
18
  from deep_translator import GoogleTranslator
19
  from model_spacy import nlp_xx
20
 
21
- import model_translation as translation
22
-
23
-
24
- def detect_language(text):
25
- lang = langdetect.detect(text)
26
- return lang
27
-
28
-
29
- def build_text_chunks(text, src_lang, sents_per_chunk):
30
- """
31
- Given a text:
32
- - Split the text into sentences.
33
- - Build text chunks:
34
- - Consider up to sents_per_chunk
35
- - Ensure that we do not exceed translation.max_words_per_chunk
36
- """
37
- # Split text into sentences...
38
- sentences = [
39
- sent.text.strip() for sent in nlp_xx(text).sents if sent.text.strip()]
40
- logger.info(f"LANG: {src_lang}, TEXT: {text[:20]}, NB_SENTS: {len(sentences)}")
41
-
42
- # Create text chunks of N sentences
43
- chunks = []
44
- chunk = ''
45
- chunk_nb_sentences = 0
46
- chunk_nb_words = 0
47
-
48
- for i in range(0, len(sentences)):
49
- # Get sentence
50
- sent = sentences[i]
51
- sent_nb_words = len(sent.split())
52
-
53
- # If chunk already 'full', save chunk, start new chunk
54
- if (
55
- (chunk_nb_words + sent_nb_words > translation.max_words_per_chunk) or
56
- (chunk_nb_sentences + 1 > sents_per_chunk)
57
- ):
58
- chunks.append(chunk)
59
- chunk = ''
60
- chunk_nb_sentences = 0
61
- chunk_nb_words = 0
62
-
63
- # Append sentence to current chunk. One sentence per line.
64
- chunk = (chunk + '\n' + sent) if chunk else sent
65
- chunk_nb_sentences += 1
66
- chunk_nb_words += sent_nb_words
67
-
68
- # Append last chunk
69
- if chunk:
70
- chunks.append(chunk)
71
-
72
- return chunks
73
 
74
 
75
  def translate_with_Helsinki(
@@ -106,82 +53,31 @@ def translate_with_Helsinki(
106
  return '\n'.join(translated_chunks)
107
 
108
 
109
- @spaces.GPU
110
- def translate_with_m2m100(
111
- chunks: List[str],
112
- src_lang: str,
113
- tgt_lang: str) -> str:
114
- """Translate with the m2m100 model
115
- """
116
- m2m100 = translation.ModelM2M100()
117
- m2m100.tokenizer.src_lang = src_lang
118
-
119
- translated_chunks = []
120
- for chunk in chunks:
121
- input_ids = m2m100.tokenizer(
122
- chunk, return_tensors="pt").input_ids.to(m2m100.model.device)
123
- outputs = m2m100.model.generate(
124
- input_ids=input_ids,
125
- forced_bos_token_id=m2m100.tokenizer.get_lang_id(tgt_lang))
126
- translated_chunk = m2m100.tokenizer.batch_decode(
127
- outputs, skip_special_tokens=True)[0]
128
- translated_chunks.append(translated_chunk)
129
-
130
- return '\n'.join(translated_chunks)
131
-
132
-
133
- @spaces.GPU
134
- def translate_with_MADLAD(
135
- chunks: List[str],
136
- tgt_lang: str,
137
- input_max_length: int=512,
138
- output_max_length: int=512) -> str:
139
- """Translate with Google MADLAD model
140
- """
141
- madlad = translation.ModelMADLAD()
142
-
143
- translated_chunks = []
144
- for chunk in chunks:
145
- input_text = f"<2{tgt_lang}> {chunk}"
146
- #logger.info(f" Translating: {input_text[:30]}")
147
- input_ids = madlad.tokenizer(
148
- input_text, return_tensors="pt", max_length=input_max_length,
149
- truncation=True, padding="longest").input_ids.to(madlad.model.device)
150
- outputs = madlad.model.generate(
151
- input_ids=input_ids, max_length=output_max_length)
152
- translated_chunk = madlad.tokenizer.decode(
153
- outputs[0], skip_special_tokens=True)
154
- translated_chunks.append(translated_chunk)
155
-
156
- return '\n'.join(translated_chunks)
157
-
158
-
159
  def translate_text(
160
  text: str,
161
- src_lang: str=None,
162
- sents_per_chunk: int=5,
163
- input_max_length: int=512,
164
- output_max_length: int=512):
165
- """
166
- Translate the given text into English (default "easy" language)
167
  """
168
- src_lang = src_lang if (src_lang and src_lang != "auto") else detect_language(text)
169
- tgt_lang = 'en' # Default "easy" language
170
-
171
- chunks = build_text_chunks(text, src_lang, sents_per_chunk)
172
- chunks = [text, ]
 
 
 
 
173
 
174
- #translated_text_Helsinki = translate_with_Helsinki(
175
- # chunks, src_lang, tgt_lang, input_max_length, output_max_length)
176
- #translated_text_m2m100 = translate_with_m2m100(chunks, src_lang, tgt_lang)
177
- translated_text_MADLAD = translate_with_MADLAD(chunks, tgt_lang)
178
  translated_text_google_translate = GoogleTranslator(
179
  source='auto', target='en').translate(text=text)
180
 
181
  return (
182
- #translated_text_Helsinki,
183
- #translated_text_m2m100,
184
- translated_text_MADLAD,
185
  translated_text_google_translate
186
  )
187
 
@@ -192,77 +88,54 @@ def translate_text(
192
  with gr.Blocks() as demo:
193
 
194
  gr.Markdown("""
195
- ## Text translation v0.0.2 (small paragraph, multilingual)
196
  """)
 
197
  input_text = gr.Textbox(
198
- lines=15,
199
  placeholder="Enter text to translate",
200
  label="Text to translate",
201
- render=False
202
  )
203
- #output_text_Helsinki = gr.Textbox(
204
- # lines=6,
205
- # label="Bilingual translation model (Helsinki NLP)",
206
- # render=False
207
- #)
208
- #output_text_m2m100 = gr.Textbox(
209
- # lines=6,
210
- # label="Facebook m2m100 (1.2B)",
211
- # render=False
212
- #)
213
- output_text_MADLAD = gr.Textbox(
214
  lines=6,
215
- label="Google MADLAD400 (3B)",
216
- render=False
217
  )
218
  output_text_google_translate = gr.Textbox(
219
  lines=6,
220
  label="Google Translate",
221
- render=False
222
  )
223
 
224
- # Extra (additional) input parameters
225
- sentences_per_chunk = gr.Slider(
226
- minimum=1, maximum=10, value=5, step=1,
227
- label="nb sentences per context",
228
- render=False
229
- )
230
- src_lang = gr.Radio(
231
- choices=["auto", "ar", "en", "fa", "fr", "he", "zh"], value="auto",
232
- label="Source language",
233
- render=False
234
- )
 
 
 
235
 
236
- # Examples
237
- examples = [
238
- ["ریچارد مور، رئیس سازمان مخفی اطلاعاتی بریتانیا (ام‌آی‌۶) در دیدار ویلیام برنز، رئیس سازمان اطلاعات مرکزی آمریکا (سیا) گفت همچنان احتمال اقدام ایران علیه اسرائیل در واکنش به ترور اسماعیل هنیه، رهبر حماس وجود دارد. آقای برنز نیز در این دیدار فاش کرد که در سال اول جنگ اوکراین، «خطر واقعی» وجود داشت که روسیه به استفاده از «تسلیحات هسته‌ای تاکتیکی» متوسل شود. این دو مقام امنیتی هشدار دادند که «نظم جهانی» از زمان جنگ سرد تا کنون تا این حد «در معرض تهدید» نبوده است.", "fa"],
239
- #["Clément Delangue est, avec Julien Chaumond et Thomas Wolf, l’un des trois Français cofondateurs de Hugging Face, une start-up d’intelligence artificielle (IA) de premier plan. Valorisée à 4,2 milliards d’euros après avoir levé près de 450 millions d’euros depuis sa création en 2016, cette société de droit américain est connue comme la plate-forme de référence où développeurs et entreprises publient des outils et des modèles pour faire de l’IA en open source, c’est-à-dire accessible gratuitement et modifiable.", "fr"],
240
- ["يُعد تفشي مرض جدري القردة قضية صحية عالمية خطيرة، ومن المهم محاولة منع انتشاره للحفاظ على سلامة الناس وتجنب العدوى. د. صموئيل بولاند، مدير الحوادث الخاصة بمرض الجدري في المكتب الإقليمي لمنظمة الصحة العالمية في أفريقيا، يتحدث من كينشاسا في جمهورية الكونغو الديمقراطية، ولديه بعض النصائح البسيطة التي يمكن للناس اتباعها لتقليل خطر انتشار المرض.", "ar"],
241
- ["張先生稱,奇瑞已經凖備在西班牙生產汽車,並決心採取「本地化」的方式進入歐洲市場。此外,他也否認該公司的出口受益於不公平補貼。奇瑞成立於1997年,是中國最大的汽車公司之一。它已經是中國最大的汽車出口商,並且制定了進一步擴張的野心勃勃的計劃。", "zh"],
242
- #["ברוכה הבאה, קיטי: בית הקפה החדש בלוס אנג'לס החתולה האהובה והחברים שלה מקבלים בית קפה משלהם בשדרות יוניברסל סיטי, שם תוכלו למצוא מגוון של פינוקים מתוקים – החל ממשקאות ועד עוגות", "he"],
243
- ]
244
-
245
- gr.Interface(
246
  fn=translate_text,
247
- inputs=[input_text, src_lang],
248
- outputs=[
249
- #output_text_Helsinki,
250
- #output_text_m2m100,
251
- output_text_MADLAD,
252
- output_text_google_translate,
253
- ],
254
- additional_inputs=[sentences_per_chunk,],
255
- #clear_btn=None, # Unfortunately, clear_btn also reset the additional inputs. Hence disabling for now.
256
- allow_flagging="never",
257
- examples=examples,
258
- cache_examples=False
259
  )
260
 
261
  with gr.Accordion("Documentation", open=False):
262
  gr.Markdown("""
263
- - Models: serving Facebook M2M100 and Google MADLAD models.
264
- - Basic: processing of long paragraph / document to be enhanced.
265
- - Most examples are copy/pasted from BBC news international web sites.
266
  """)
267
 
268
  if __name__ == "__main__":
 
6
  Author: Didier Guillevic
7
  Date: 2024-09-07
8
  """
 
 
 
 
 
9
  import logging
10
  logger = logging.getLogger(__name__)
11
  logging.basicConfig(level=logging.INFO)
12
 
13
+ import gradio as gr
14
+ import langdetect
15
+
16
  from deep_translator import GoogleTranslator
17
  from model_spacy import nlp_xx
18
 
19
+ import model_translation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
 
22
  def translate_with_Helsinki(
 
53
  return '\n'.join(translated_chunks)
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
def translate_text(
    text: str,
    src_lang: str,
    tgt_lang: str
) -> tuple[str, str]:
    """Translate the given text into English or French.

    Args:
        text: the text to translate.
        src_lang: ISO-639-1 code of the source language, or "auto"
            (or None/empty) to auto-detect the language.
        tgt_lang: ISO-639-1 code of the target language; falls back to
            'en' when not among the supported target languages.

    Returns:
        A pair (m2m100 translation, Google Translate translation).
    """
    # src_lang among the supported languages?
    # - make sure src_lang is not None (fall back to automatic detection)
    src_lang = src_lang if (src_lang and src_lang != "auto") else langdetect.detect(text)
    if src_lang not in model_translation.language_codes.values():
        # Bug fix: log through the module logger (not the root logger),
        # with lazy %-formatting.
        logger.error("Language detected %s not among supported language", src_lang)

    # tgt_lang: make sure it is not None. Default to 'en' if not set.
    if tgt_lang not in model_translation.tgt_language_codes.values():
        tgt_lang = 'en'

    # Translate with the (singleton) m2m100 model.
    m2m100 = model_translation.ModelM2M100()
    translated_text_m2m100 = m2m100.translate(text, src_lang, tgt_lang)

    # Bug fix: honor the requested target language instead of a
    # hard-coded 'en' (the UI also offers 'fr' as a target).
    translated_text_google_translate = GoogleTranslator(
        source='auto', target=tgt_lang).translate(text=text)

    return (
        translated_text_m2m100,
        translated_text_google_translate
    )
83
 
 
88
  with gr.Blocks() as demo:
89
 
90
  gr.Markdown("""
91
+ ## Text translation v0.0.3
92
  """)
93
+ # Input
94
  input_text = gr.Textbox(
95
+ lines=6,
96
  placeholder="Enter text to translate",
97
  label="Text to translate",
98
+ render=True
99
  )
100
+
101
+ # Output
102
+ output_text_m2m100 = gr.Textbox(
 
 
 
 
 
 
 
 
103
  lines=6,
104
+ label="Facebook m2m100 (1.2B)",
105
+ render=True
106
  )
107
  output_text_google_translate = gr.Textbox(
108
  lines=6,
109
  label="Google Translate",
110
+ render=True
111
  )
112
 
113
+ # Source and target languages
114
+ with gr.Row():
115
+ src_lang = gr.Radio(
116
+ choices=model_translation.language_codes.items(),
117
+ value="auto",
118
+ label="Source language",
119
+ render=True
120
+ )
121
+ tgt_lang = gr.Radio(
122
+ choices=model_translation.tgt_language_codes.items(),
123
+ value="en",
124
+ label="Target language",
125
+ render=True
126
+ )
127
 
128
+ # Submit button
129
+ translate_btn = gr.Button("Translate")
130
+ translate_btn.click(
 
 
 
 
 
 
 
131
  fn=translate_text,
132
+ inputs=[input_text, src_lang, tgt_lang],
133
+ outputs=[output_text_m2m100, output_text_google_translate]
 
 
 
 
 
 
 
 
 
 
134
  )
135
 
136
  with gr.Accordion("Documentation", open=False):
137
  gr.Markdown("""
138
+ - Models: serving Facebook M2M100 and Google Translate.
 
 
139
  """)
140
 
141
  if __name__ == "__main__":
model_translation.py CHANGED
@@ -7,17 +7,188 @@ Description:
7
  Author: Didier Guillevic
8
  Date: 2024-03-16
9
  """
 
 
 
 
 
10
 
11
  import torch
12
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
13
  from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
14
  from transformers import BitsAndBytesConfig
15
 
 
 
16
  quantization_config = BitsAndBytesConfig(
17
  load_in_8bit=True,
18
  llm_int8_threshold=200.0 # https://discuss.huggingface.co/t/correct-usage-of-bitsandbytesconfig/33809/5
19
  )
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  class Singleton(type):
22
  _instances = {}
23
  def __call__(cls, *args, **kwargs):
@@ -25,8 +196,11 @@ class Singleton(type):
25
  cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
26
  return cls._instances[cls]
27
 
 
28
  class ModelM2M100(metaclass=Singleton):
29
  """Loads an instance of the M2M100 model.
 
 
30
  """
31
  def __init__(self):
32
  self._model_name = "facebook/m2m100_1.2B"
@@ -35,10 +209,47 @@ class ModelM2M100(metaclass=Singleton):
35
  self._model_name,
36
  device_map="auto",
37
  torch_dtype=torch.float16,
38
- low_cpu_mem_usage=True
 
39
  )
40
  self._model = torch.compile(self._model)
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  @property
43
  def model_name(self):
44
  return self._model_name
@@ -51,11 +262,20 @@ class ModelM2M100(metaclass=Singleton):
51
  def model(self):
52
  return self._model
53
 
 
 
 
 
 
54
  class ModelMADLAD(metaclass=Singleton):
55
  """Loads an instance of the Google MADLAD model (3B).
 
 
56
  """
57
  def __init__(self):
58
  self._model_name = "google/madlad400-3b-mt"
 
 
59
  self._tokenizer = AutoTokenizer.from_pretrained(
60
  self.model_name, use_fast=True
61
  )
@@ -68,6 +288,44 @@ class ModelMADLAD(metaclass=Singleton):
68
  )
69
  self._model = torch.compile(self._model)
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  @property
72
  def model_name(self):
73
  return self._model_name
@@ -79,6 +337,10 @@ class ModelMADLAD(metaclass=Singleton):
79
  @property
80
  def model(self):
81
  return self._model
 
 
 
 
82
 
83
 
84
  # Bi-lingual individual models
 
7
  Author: Didier Guillevic
8
  Date: 2024-03-16
9
  """
10
+ import spaces
11
+
12
+ import logging
13
+ logger = logging.getLogger(__name__)
14
+ logging.basicConfig(level=logging.INFO)
15
 
16
  import torch
17
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
18
  from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
19
  from transformers import BitsAndBytesConfig
20
 
21
+ from model_spacy import nlp_xx as model_spacy
22
+
23
  quantization_config = BitsAndBytesConfig(
24
  load_in_8bit=True,
25
  llm_int8_threshold=200.0 # https://discuss.huggingface.co/t/correct-usage-of-bitsandbytesconfig/33809/5
26
  )
27
 
28
# The 100 languages supported by the facebook/m2m100_418M model
# https://huggingface.co/facebook/m2m100_418M
# plus the 'AUTOMATIC' option where we will use a language detector.
language_codes = {
    'AUTOMATIC': 'auto',
    'Afrikaans (af)': 'af',
    'Albanian (sq)': 'sq',
    'Amharic (am)': 'am',
    'Arabic (ar)': 'ar',
    'Armenian (hy)': 'hy',
    'Asturian (ast)': 'ast',
    'Azerbaijani (az)': 'az',
    'Bashkir (ba)': 'ba',
    'Belarusian (be)': 'be',
    'Bengali (bn)': 'bn',
    'Bosnian (bs)': 'bs',
    'Breton (br)': 'br',
    'Bulgarian (bg)': 'bg',
    'Burmese (my)': 'my',
    'Catalan; Valencian (ca)': 'ca',
    'Cebuano (ceb)': 'ceb',
    'Central Khmer (km)': 'km',
    'Chinese (zh)': 'zh',
    'Croatian (hr)': 'hr',
    'Czech (cs)': 'cs',
    'Danish (da)': 'da',
    'Dutch; Flemish (nl)': 'nl',
    'English (en)': 'en',
    'Estonian (et)': 'et',
    'Finnish (fi)': 'fi',
    'French (fr)': 'fr',
    'Fulah (ff)': 'ff',
    'Gaelic; Scottish Gaelic (gd)': 'gd',
    'Galician (gl)': 'gl',
    'Ganda (lg)': 'lg',
    'Georgian (ka)': 'ka',
    'German (de)': 'de',
    'Greek (el)': 'el',  # typo fix: label was 'Greeek (el)'
    'Gujarati (gu)': 'gu',
    'Haitian; Haitian Creole (ht)': 'ht',
    'Hausa (ha)': 'ha',
    'Hebrew (he)': 'he',
    'Hindi (hi)': 'hi',
    'Hungarian (hu)': 'hu',
    'Icelandic (is)': 'is',
    'Igbo (ig)': 'ig',
    'Iloko (ilo)': 'ilo',
    'Indonesian (id)': 'id',
    'Irish (ga)': 'ga',
    'Italian (it)': 'it',
    'Japanese (ja)': 'ja',
    'Javanese (jv)': 'jv',
    'Kannada (kn)': 'kn',
    'Kazakh (kk)': 'kk',
    'Korean (ko)': 'ko',
    'Lao (lo)': 'lo',
    'Latvian (lv)': 'lv',
    'Lingala (ln)': 'ln',
    'Lithuanian (lt)': 'lt',
    'Luxembourgish; Letzeburgesch (lb)': 'lb',
    'Macedonian (mk)': 'mk',
    'Malagasy (mg)': 'mg',
    'Malay (ms)': 'ms',
    'Malayalam (ml)': 'ml',
    'Marathi (mr)': 'mr',
    'Mongolian (mn)': 'mn',
    'Nepali (ne)': 'ne',
    'Northern Sotho (ns)': 'ns',
    'Norwegian (no)': 'no',
    'Occitan (post 1500) (oc)': 'oc',
    'Oriya (or)': 'or',
    'Panjabi; Punjabi (pa)': 'pa',
    'Persian (fa)': 'fa',
    'Polish (pl)': 'pl',
    'Portuguese (pt)': 'pt',
    'Pushto; Pashto (ps)': 'ps',
    'Romanian; Moldavian; Moldovan (ro)': 'ro',
    'Russian (ru)': 'ru',
    'Serbian (sr)': 'sr',
    'Sindhi (sd)': 'sd',
    'Sinhala; Sinhalese (si)': 'si',
    'Slovak (sk)': 'sk',
    'Slovenian (sl)': 'sl',
    'Somali (so)': 'so',
    'Spanish (es)': 'es',
    'Sundanese (su)': 'su',
    'Swahili (sw)': 'sw',
    'Swati (ss)': 'ss',
    'Swedish (sv)': 'sv',
    'Tagalog (tl)': 'tl',
    'Tamil (ta)': 'ta',
    'Thai (th)': 'th',
    'Tswana (tn)': 'tn',
    'Turkish (tr)': 'tr',
    'Ukrainian (uk)': 'uk',
    'Urdu (ur)': 'ur',
    'Uzbek (uz)': 'uz',
    'Vietnamese (vi)': 'vi',
    'Welsh (cy)': 'cy',
    'Western Frisian (fy)': 'fy',
    'Wolof (wo)': 'wo',
    'Xhosa (xh)': 'xh',
    'Yiddish (yi)': 'yi',
    'Yoruba (yo)': 'yo',
    'Zulu (zu)': 'zu'
}

# Target languages currently offered by the UI.
tgt_language_codes = {
    'English (en)': 'en',
    'French (fr)': 'fr'
}
139
+
140
+
141
def build_text_chunks(
    text: str,
    sents_per_chunk: int = 5,
    words_per_chunk: int = 200) -> list[str]:
    """Split a given text into chunks with at most sents_per_chunk sentences
    and words_per_chunk words.

    Given a text:
    - Split the text into sentences (with the multilingual spaCy model).
    - Build text chunks:
        - Consider up to sents_per_chunk sentences per chunk.
        - Ensure that we do not exceed words_per_chunk words per chunk
          (a single sentence longer than words_per_chunk still forms its
          own chunk rather than being split).

    Args:
        text: the text to split.
        sents_per_chunk: maximum number of sentences per chunk.
        words_per_chunk: maximum number of whitespace-delimited words per chunk.

    Returns:
        The list of chunks; within a chunk, sentences are newline-separated.
    """
    # Split text into sentences, dropping whitespace-only ones.
    sentences = [
        sent.text.strip() for sent in model_spacy(text).sents if sent.text.strip()
    ]
    logger.info(f"TEXT: {text[:25]}, NB_SENTS: {len(sentences)}")

    # Create text chunks of N sentences
    chunks = []
    chunk = ''
    chunk_nb_sentences = 0
    chunk_nb_words = 0

    for sent in sentences:
        sent_nb_words = len(sent.split())

        # If chunk already 'full', save chunk, start new chunk.
        # Bug fix: only flush a non-empty chunk — previously, a first
        # sentence longer than words_per_chunk appended an empty chunk.
        if chunk and (
            (chunk_nb_words + sent_nb_words > words_per_chunk) or
            (chunk_nb_sentences + 1 > sents_per_chunk)
        ):
            chunks.append(chunk)
            chunk = ''
            chunk_nb_sentences = 0
            chunk_nb_words = 0

        # Append sentence to current chunk. One sentence per line.
        chunk = (chunk + '\n' + sent) if chunk else sent
        chunk_nb_sentences += 1
        chunk_nb_words += sent_nb_words

    # Append last chunk
    if chunk:
        chunks.append(chunk)

    return chunks
190
+
191
+
192
  class Singleton(type):
193
  _instances = {}
194
  def __call__(cls, *args, **kwargs):
 
196
  cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
197
  return cls._instances[cls]
198
 
199
+
200
  class ModelM2M100(metaclass=Singleton):
201
  """Loads an instance of the M2M100 model.
202
+
203
+ Model: https://huggingface.co/facebook/m2m100_1.2B
204
  """
205
  def __init__(self):
206
  self._model_name = "facebook/m2m100_1.2B"
 
209
  self._model_name,
210
  device_map="auto",
211
  torch_dtype=torch.float16,
212
+ low_cpu_mem_usage=True,
213
+ quantization_config=quantization_config
214
  )
215
  self._model = torch.compile(self._model)
216
 
217
+ @spaces.GPU
218
+ def translate(
219
+ self,
220
+ text: str,
221
+ src_lang: str,
222
+ tgt_lang: str,
223
+ chunk_text: bool=True,
224
+ sents_per_chunk: int=5,
225
+ words_per_chunk: int=200
226
+ ) -> str:
227
+ """Translate the given text from src_lang to tgt_lang.
228
+
229
+ The text will be split into chunks to ensure the chunks fit into the
230
+ model input_max_length (usually 512 tokens).
231
+ """
232
+ chunks = [text,]
233
+ if chunk_text:
234
+ chunks = build_text_chunks(text, sents_per_chunk, words_per_chunk)
235
+
236
+ self._tokenizer.src_lang = src_lang
237
+
238
+ translated_chunks = []
239
+ for chunk in chunks:
240
+ input_ids = self._tokenizer(
241
+ chunk,
242
+ return_tensors="pt").input_ids.to(self._model.device)
243
+ outputs = self._model.generate(
244
+ input_ids=input_ids,
245
+ forced_bos_token_id=self._tokenizer.get_lang_id(tgt_lang))
246
+ translated_chunk = self._tokenizer.batch_decode(
247
+ outputs,
248
+ skip_special_tokens=True)[0]
249
+ translated_chunks.append(translated_chunk)
250
+
251
+ return '\n'.join(translated_chunks)
252
+
253
  @property
254
  def model_name(self):
255
  return self._model_name
 
262
  def model(self):
263
  return self._model
264
 
265
+ @property
266
+ def device(self):
267
+ return self._model.device
268
+
269
+
270
  class ModelMADLAD(metaclass=Singleton):
271
  """Loads an instance of the Google MADLAD model (3B).
272
+
273
+ Model: https://huggingface.co/google/madlad400-3b-mt
274
  """
275
  def __init__(self):
276
  self._model_name = "google/madlad400-3b-mt"
277
+ self._input_max_length = 512 # config.json n_positions
278
+ self._output_max_length = 512 # config.json n_positions
279
  self._tokenizer = AutoTokenizer.from_pretrained(
280
  self.model_name, use_fast=True
281
  )
 
288
  )
289
  self._model = torch.compile(self._model)
290
 
291
+ @spaces.GPU
292
+ def translate(
293
+ self,
294
+ text: str,
295
+ tgt_lang: str,
296
+ chunk_text: True,
297
+ sents_per_chunk: int=5,
298
+ words_per_chunk: int=5
299
+ ) -> str:
300
+ """Translate given text into the target language.
301
+
302
+ The text will be split into chunks to ensure the chunks fit into the
303
+ model input_max_length (usually 512 tokens).
304
+ """
305
+ chunks = [text,]
306
+ if chunk_text:
307
+ chunks = build_text_chunks(text, sents_per_chunk, words_per_chunk)
308
+
309
+ translated_chunks = []
310
+ for chunk in chunks:
311
+ input_text = f"<2{tgt_lang}> {chunk}"
312
+ logger.info(f" Translating: {input_text[:50]}")
313
+ input_ids = self._tokenizer(
314
+ input_text,
315
+ return_tensors="pt",
316
+ max_length=self._input_max_length,
317
+ truncation=True,
318
+ padding="longest").input_ids.to(self._model.device)
319
+ outputs = self._model.generate(
320
+ input_ids=input_ids,
321
+ max_length=self._output_max_length)
322
+ translated_chunk = self._tokenizer.decode(
323
+ outputs[0],
324
+ skip_special_tokens=True)
325
+ translated_chunks.append(translated_chunk)
326
+
327
+ return '\n'.join(translated_chunks)
328
+
329
  @property
330
  def model_name(self):
331
  return self._model_name
 
337
  @property
338
  def model(self):
339
  return self._model
340
+
341
+ @property
342
+ def device(self):
343
+ return self._model.device
344
 
345
 
346
  # Bi-lingual individual models