Didier commited on
Commit
706408b
1 Parent(s): b10cb1c

Add the M2M100 model from Facebook

Browse files
Files changed (2) hide show
  1. app.py +69 -120
  2. model_translation.py +63 -28
app.py CHANGED
@@ -7,25 +7,25 @@ Author: Didier Guillevic
7
  Date: 2024-09-07
8
  """
9
  import spaces
10
- import torch
11
  import gradio as gr
12
  import langdetect
 
13
 
14
  import logging
15
  logger = logging.getLogger(__name__)
16
  logging.basicConfig(level=logging.INFO)
17
 
 
 
 
18
  import model_translation as translation
19
- from model_translation import tokenizer_multilingual, model_multilingual
20
- from model_translation import tokenizer_m2m100, model_m2m100
21
 
22
- from deep_translator import GoogleTranslator
23
 
24
- from model_spacy import nlp_xx
 
 
 
25
 
26
- #
27
- # Translate given input text
28
- #
29
  def build_text_chunks(text, src_lang, sents_per_chunk):
30
  """
31
  Given a text:
@@ -71,18 +71,21 @@ def build_text_chunks(text, src_lang, sents_per_chunk):
71
 
72
  return chunks
73
 
74
- def translate_with_model(
75
- text, tokenizer, model, src_lang, sents_per_chunk,
76
- input_max_length=512, output_max_length=512):
77
-
78
- # Build text chunks (using sents_per_chunk and translation.max_words_per_chunk)
79
- chunks = build_text_chunks(text, src_lang, sents_per_chunk)
80
- logger.info(f"LANG: {src_lang}, TEXT: {text[:20]}, NB_CHUNKS: {len(chunks)}")
81
 
82
- # Translate chunks
 
 
 
 
 
 
 
 
 
 
 
83
  translated_chunks = []
84
  for chunk in chunks:
85
-
86
  # NOTE: The 'fa' (Persian) model has multiple target languages to choose from.
87
  # We need to specifiy the desired languages among: fra ita por ron spa
88
  # https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fa-itc
@@ -91,92 +94,61 @@ def translate_with_model(
91
  chunk = ">>fra<< " + chunk
92
 
93
  inputs = tokenizer(
94
- chunk, return_tensors="pt",
95
- max_length=input_max_length,
96
  truncation=True, padding="longest").to(model.device)
97
-
98
- outputs = model.generate(
99
- **inputs,
100
- max_length=output_max_length)
101
-
102
  translated_chunk = tokenizer.batch_decode(
103
  outputs, skip_special_tokens=True)[0]
104
-
105
  #logger.info(f"Text: {chunk}")
106
  #logger.info(f"Translation: {translated_chunk}")
107
-
108
  translated_chunks.append(translated_chunk)
109
 
110
  return '\n'.join(translated_chunks)
111
 
112
 
113
- def detect_language(text):
114
- lang = langdetect.detect(text)
115
- return lang
116
-
117
-
118
- def translate_with_bilingual_model(
119
- text, src_lang, tgt_lang, sents_per_chunk
120
- ):
121
- """
122
- Translate with Helsinki bilingual models
123
- """
124
- if src_lang not in translation.src_langs:
125
- return (
126
- f"ISSUE: currently no model for language '{src_lang}'. "
127
- "If wrong language, please specify language."
128
- )
129
- logger.info(f"LANG: {src_lang}, TEXT: {text[:50]}...")
130
- tokenizer, model = translation.get_tokenizer_model_for_src_lang(src_lang)
131
- translated_text_bilingual_model = translate_with_model(
132
- text, tokenizer, model, src_lang, sents_per_chunk)
133
- return translated_text_bilingual_model
134
-
135
-
136
  @spaces.GPU
137
- def translate_with_m2m100_model(
138
- text: str,
139
  src_lang: str,
140
- tgt_lang: str,
141
- sents_per_chunk: int=5):
142
- """
143
- Translate with the m2m100 model
144
  """
145
- tokenizer_m2m100.src_lang = src_lang
146
- input_ids = tokenizer_m2m100(
147
- text, return_tensors="pt").input_ids.to(model_m2m100.device)
148
- outputs = model_m2m100.generate(
149
- input_ids=input_ids,
150
- forced_bos_token_id=tokenizer_m2m100.get_lang_id(tgt_lang))
151
- translated_text = tokenizer_m2m100.batch_decode(
152
- outputs[0], skip_special_tokens=True)
153
- return translated_text
 
 
 
 
 
 
154
 
155
 
156
  @spaces.GPU
157
- def translate_with_multilingual_model(
158
- text: str,
159
  tgt_lang: str,
160
- sents_per_chunk: int=5,
161
  input_max_length: int=512,
162
- output_max_length: int=512):
 
163
  """
164
- Translate the given text into English (default "easy" language)
165
- """
166
- chunks = build_text_chunks(text, None, sents_per_chunk)
167
- translated_chunks = []
168
 
169
  for chunk in chunks:
170
  input_text = f"<2{tgt_lang}> {text}"
171
- logger.info(f" Translating: {input_text[:30]}")
172
- input_ids = tokenizer_multilingual(
173
- input_text, return_tensors="pt",
174
- max_length=input_max_length,
175
- truncation=True, padding="longest").input_ids.to(
176
- model_multilingual.device)
177
- outputs = model_multilingual.generate(
178
  input_ids=input_ids, max_length=output_max_length)
179
- translated_chunk = tokenizer_multilingual.decode(
180
  outputs[0], skip_special_tokens=True)
181
  translated_chunks.append(translated_chunk)
182
 
@@ -195,37 +167,19 @@ def translate_text(
195
  src_lang = src_lang if (src_lang and src_lang != "auto") else detect_language(text)
196
  tgt_lang = 'en' # Default "easy" language
197
 
198
- #
199
- # Bilingual (Helsinki model)
200
- #
201
- translated_text_bilingual_model = translate_with_bilingual_model(
202
- text, src_lang, tgt_lang, sents_per_chunk
203
- )
204
-
205
- #
206
- # m2m100 model
207
- #
208
- translated_text_m2m100_model = translate_with_m2m100_model(
209
- text, src_lang, tgt_lang, sents_per_chunk
210
- )
211
 
212
- #
213
- # Multilingual model (Google MADLAD)
214
- #
215
-
216
- translated_text_multilingual_model = translate_with_multilingual_model(
217
- text, tgt_lang, sents_per_chunk, input_max_length, output_max_length)
218
-
219
- #
220
- # Google Translate
221
- #
222
  translated_text_google_translate = GoogleTranslator(
223
  source='auto', target='en').translate(text=text)
224
 
225
  return (
226
- translated_text_bilingual_model,
227
-
228
- translated_text_multilingual_model,
229
  translated_text_google_translate
230
  )
231
 
@@ -244,19 +198,19 @@ with gr.Blocks() as demo:
244
  label="Text to translate",
245
  render=False
246
  )
247
- output_text_bilingual_model = gr.Textbox(
248
  lines=6,
249
  label="Bilingual translation model (Helsinki NLP)",
250
  render=False
251
  )
252
- output_text_m2m100_model = gr.Textbox(
253
  lines=6,
254
  label="Facebook m2m100 translation model (**small**)",
255
  render=False
256
  )
257
- output_text_multilingual_model = gr.Textbox(
258
  lines=6,
259
- label="Multilingual translation model (**small** Google MADLAD)",
260
  render=False
261
  )
262
  output_text_google_translate = gr.Textbox(
@@ -272,7 +226,7 @@ with gr.Blocks() as demo:
272
  render=False
273
  )
274
  src_lang = gr.Radio(
275
- choices=["auto", "ar", "en", "fa", "fr", "he", "ja", "zh"], value="auto",
276
  label="Source language",
277
  render=False
278
  )
@@ -282,22 +236,17 @@ with gr.Blocks() as demo:
282
  ["ریچارد مور، رئیس سازمان مخفی اطلاعاتی بریتانیا (ام‌آی‌۶) در دیدار ویلیام برنز، رئیس سازمان اطلاعات مرکزی آمریکا (سیا) گفت همچنان احتمال اقدام ایران علیه اسرائیل در واکنش به ترور اسماعیل هنیه، رهبر حماس وجود دارد. آقای برنز نیز در این دیدار فاش کرد که در سال اول جنگ اوکراین، «خطر واقعی» وجود داشت که روسیه به استفاده از «تسلیحات هسته‌ای تاکتیکی» متوسل شود. این دو مقام امنیتی هشدار دادند که «نظم جهانی» از زمان جنگ سرد تا کنون تا این حد «در معرض تهدید» نبوده است.", "fa"],
283
  ["Clément Delangue est, avec Julien Chaumond et Thomas Wolf, l’un des trois Français cofondateurs de Hugging Face, une start-up d’intelligence artificielle (IA) de premier plan. Valorisée à 4,2 milliards d’euros après avoir levé près de 450 millions d’euros depuis sa création en 2016, cette société de droit américain est connue comme la plate-forme de référence où développeurs et entreprises publient des outils et des modèles pour faire de l’IA en open source, c’est-à-dire accessible gratuitement et modifiable.", "fr"],
284
  ["يُعد تفشي مرض جدري القردة قضية صحية عالمية خطيرة، ومن المهم محاولة منع انتشاره للحفاظ على سلامة الناس وتجنب العدوى. د. صموئيل بولاند، مدير الحوادث الخاصة بمرض الجدري في المكتب الإقليمي لمنظمة الصحة العالمية في أفريقيا، يتحدث من كينشاسا في جمهورية الكونغو الديمقراطية، ولديه بعض النصائح البسيطة التي يمكن للناس اتباعها لتقليل خطر انتشار المرض.", "ar"],
285
- ["【ワシントン=冨山優介】米ボーイングの新型宇宙船「スターライナー」は7日午前0時(日本時間7日午後1時)過ぎ、米ニューメキシコ州のホワイトサンズ宇宙港に着地し、地球に帰還した。スターライナーは米宇宙飛行士2人を乗せて6月に打ち上げられ、国際宇宙ステーション(ISS)に接続したが、機体のトラブルが解決できず、無人でISSから離脱した。", "ja"],
286
  ["張先生稱,奇瑞已經凖備在西班牙生產汽車,並決心採取「本地化」的方式進入歐洲市場。此外,他也否認該公司的出口受益於不公平補貼。奇瑞成立於1997年,是中國最大的汽車公司之一。它已經是中國最大的汽車出口商,並且制定了進一步擴張的野心勃勃的計劃。", "zh"],
287
  ["ברוכה הבאה, קיטי: בית הקפה החדש בלוס אנג'לס החתולה האהובה והחברים שלה מקבלים בית קפה משלהם בשדרות יוניברסל סיטי, שם תוכלו למצוא מגוון של פינוקים מתוקים – החל ממשקאות ועד עוגות", "he"],
288
  ]
289
-
290
- outputs = gr.Row(
291
-
292
- )
293
 
294
  gr.Interface(
295
  fn=translate_text,
296
  inputs=[input_text, src_lang,],
297
  outputs=[
298
- output_text_bilingual_model,
299
- output_text_multilingual_model,
300
- output_text_m2m100_model,
301
  output_text_google_translate,
302
  ],
303
  additional_inputs=[sentences_per_chunk,],
@@ -309,7 +258,7 @@ with gr.Blocks() as demo:
309
 
310
  with gr.Accordion("Documentation", open=False):
311
  gr.Markdown("""
312
- - Models: serving bilingual models from Helsinki NLP and multilingual model from Google MADLAD.
313
  - Basic: processing of long paragraph / document to be enhanced.
314
  - Most examples are copy/pasted from BBC news international web sites.
315
  """)
 
7
  Date: 2024-09-07
8
  """
9
  import spaces
 
10
  import gradio as gr
11
  import langdetect
12
+ from typing import List
13
 
14
  import logging
15
  logger = logging.getLogger(__name__)
16
  logging.basicConfig(level=logging.INFO)
17
 
18
+ from deep_translator import GoogleTranslator
19
+ from model_spacy import nlp_xx
20
+
21
  import model_translation as translation
 
 
22
 
 
23
 
24
+ def detect_language(text):
25
+ lang = langdetect.detect(text)
26
+ return lang
27
+
28
 
 
 
 
29
  def build_text_chunks(text, src_lang, sents_per_chunk):
30
  """
31
  Given a text:
 
71
 
72
  return chunks
73
 
 
 
 
 
 
 
 
74
 
75
+ def translate_with_Helsinki(
76
+ chunks, src_lang, tgt_lang, input_max_length, output_max_length) -> str:
77
+ """Translate the chunks with the Helsinki model
78
+ """
79
+ if src_lang not in translation.src_langs:
80
+ return (
81
+ f"ISSUE: currently no model for language '{src_lang}'. "
82
+ "If wrong language, please specify language."
83
+ )
84
+ logger.info(f"LANG: {src_lang}, TEXT: {chunks[0][:50]}...")
85
+ tokenizer, model = translation.get_tokenizer_model_for_src_lang(src_lang)
86
+
87
  translated_chunks = []
88
  for chunk in chunks:
 
89
  # NOTE: The 'fa' (Persian) model has multiple target languages to choose from.
90
  # We need to specifiy the desired languages among: fra ita por ron spa
91
  # https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fa-itc
 
94
  chunk = ">>fra<< " + chunk
95
 
96
  inputs = tokenizer(
97
+ chunk, return_tensors="pt", max_length=input_max_length,
 
98
  truncation=True, padding="longest").to(model.device)
99
+ outputs = model.generate(**inputs, max_length=output_max_length)
 
 
 
 
100
  translated_chunk = tokenizer.batch_decode(
101
  outputs, skip_special_tokens=True)[0]
 
102
  #logger.info(f"Text: {chunk}")
103
  #logger.info(f"Translation: {translated_chunk}")
 
104
  translated_chunks.append(translated_chunk)
105
 
106
  return '\n'.join(translated_chunks)
107
 
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  @spaces.GPU
110
+ def translate_with_m2m100(
111
+ chunks: List[str],
112
  src_lang: str,
113
+ tgt_lang: str) -> str:
114
+ """Translate with the m2m100 model
 
 
115
  """
116
+ m2m100 = translation.ModelM2M100()
117
+ m2m100.tokenizer.src_lang = src_lang
118
+
119
+ translated_chunks = []
120
+ for chunk in chunks:
121
+ input_ids = m2m100.tokenizer(
122
+ chunk, return_tensors="pt").input_ids.to(m2m100.model.device)
123
+ outputs = m2m100.model.generate(
124
+ input_ids=input_ids,
125
+ forced_bos_token_id=m2m100.tokenizer.get_lang_id(tgt_lang))
126
+ translated_chunk = m2m100.tokenizer.batch_decode(
127
+ outputs, skip_special_tokens=True)[0]
128
+ translated_chunks.append(translated_chunk)
129
+
130
+ return '\n'.join(translated_chunks)
131
 
132
 
133
  @spaces.GPU
134
+ def translate_with_MADLAD(
135
+ chunks: List[str],
136
  tgt_lang: str,
 
137
  input_max_length: int=512,
138
+ output_max_length: int=512) -> str:
139
+ """Translate with Google MADLAD model
140
  """
141
+ madlad = translation.ModelMADLAD()
 
 
 
142
 
143
  for chunk in chunks:
144
  input_text = f"<2{tgt_lang}> {text}"
145
+ #logger.info(f" Translating: {input_text[:30]}")
146
+ input_ids = madlad.tokenizer(
147
+ input_text, return_tensors="pt", max_length=input_max_length,
148
+ truncation=True, padding="longest").input_ids.to(madlad.model.device)
149
+ outputs = madlad.model.generate(
 
 
150
  input_ids=input_ids, max_length=output_max_length)
151
+ translated_chunk = madlad.tokenizer.decode(
152
  outputs[0], skip_special_tokens=True)
153
  translated_chunks.append(translated_chunk)
154
 
 
167
  src_lang = src_lang if (src_lang and src_lang != "auto") else detect_language(text)
168
  tgt_lang = 'en' # Default "easy" language
169
 
170
+ chunks = build_text_chunks(text, src_lang, sents_per_chunk)
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
+ translated_text_Helsinki = translate_with_Helsinki(
173
+ chunks, src_lang, tgt_lang, input_max_length, output_max_length)
174
+ translated_text_m2m100 = translate_with_m2m100(chunks, src_lang, tgt_lang)
175
+ translated_text_MADLAD = translate_with_MADLAD(chunks, tgt_lang)
 
 
 
 
 
 
176
  translated_text_google_translate = GoogleTranslator(
177
  source='auto', target='en').translate(text=text)
178
 
179
  return (
180
+ translated_text_Helsinki,
181
+ translated_text_m2m100,
182
+ translated_text_MADLAD,
183
  translated_text_google_translate
184
  )
185
 
 
198
  label="Text to translate",
199
  render=False
200
  )
201
+ output_text_Helsinki = gr.Textbox(
202
  lines=6,
203
  label="Bilingual translation model (Helsinki NLP)",
204
  render=False
205
  )
206
+ output_text_m2m100 = gr.Textbox(
207
  lines=6,
208
  label="Facebook m2m100 translation model (**small**)",
209
  render=False
210
  )
211
+ output_text_MADLAD = gr.Textbox(
212
  lines=6,
213
+ label="Google MADLAD translation model (**small**)",
214
  render=False
215
  )
216
  output_text_google_translate = gr.Textbox(
 
226
  render=False
227
  )
228
  src_lang = gr.Radio(
229
+ choices=["auto", "ar", "en", "fa", "fr", "he", "zh"], value="auto",
230
  label="Source language",
231
  render=False
232
  )
 
236
  ["ریچارد مور، رئیس سازمان مخفی اطلاعاتی بریتانیا (ام‌آی‌۶) در دیدار ویلیام برنز، رئیس سازمان اطلاعات مرکزی آمریکا (سیا) گفت همچنان احتمال اقدام ایران علیه اسرائیل در واکنش به ترور اسماعیل هنیه، رهبر حماس وجود دارد. آقای برنز نیز در این دیدار فاش کرد که در سال اول جنگ اوکراین، «خطر واقعی» وجود داشت که روسیه به استفاده از «تسلیحات هسته‌ای تاکتیکی» متوسل شود. این دو مقام امنیتی هشدار دادند که «نظم جهانی» از زمان جنگ سرد تا کنون تا این حد «در معرض تهدید» نبوده است.", "fa"],
237
  ["Clément Delangue est, avec Julien Chaumond et Thomas Wolf, l’un des trois Français cofondateurs de Hugging Face, une start-up d’intelligence artificielle (IA) de premier plan. Valorisée à 4,2 milliards d’euros après avoir levé près de 450 millions d’euros depuis sa création en 2016, cette société de droit américain est connue comme la plate-forme de référence où développeurs et entreprises publient des outils et des modèles pour faire de l’IA en open source, c’est-à-dire accessible gratuitement et modifiable.", "fr"],
238
  ["يُعد تفشي مرض جدري القردة قضية صحية عالمية خطيرة، ومن المهم محاولة منع انتشاره للحفاظ على سلامة الناس وتجنب العدوى. د. صموئيل بولاند، مدير الحوادث الخاصة بمرض الجدري في المكتب الإقليمي لمنظمة الصحة العالمية في أفريقيا، يتحدث من كينشاسا في جمهورية الكونغو الديمقراطية، ولديه بعض النصائح البسيطة التي يمكن للناس اتباعها لتقليل خطر انتشار المرض.", "ar"],
 
239
  ["張先生稱,奇瑞已經凖備在西班牙生產汽車,並決心採取「本地化」的方式進入歐洲市場。此外,他也否認該公司的出口受益於不公平補貼。奇瑞成立於1997年,是中國最大的汽車公司之一。它已經是中國最大的汽車出口商,並且制定了進一步擴張的野心勃勃的計劃。", "zh"],
240
  ["ברוכה הבאה, קיטי: בית הקפה החדש בלוס אנג'לס החתולה האהובה והחברים שלה מקבלים בית קפה משלהם בשדרות יוניברסל סיטי, שם תוכלו למצוא מגוון של פינוקים מתוקים – החל ממשקאות ועד עוגות", "he"],
241
  ]
 
 
 
 
242
 
243
  gr.Interface(
244
  fn=translate_text,
245
  inputs=[input_text, src_lang,],
246
  outputs=[
247
+ output_text_Helsinki,
248
+ output_text_m2m100,
249
+ output_text_MADLAD,
250
  output_text_google_translate,
251
  ],
252
  additional_inputs=[sentences_per_chunk,],
 
258
 
259
  with gr.Accordion("Documentation", open=False):
260
  gr.Markdown("""
261
+ - Models: serving Helsinki NLP, Facebook M2M100 and Google MADLAD models.
262
  - Basic: processing of long paragraph / document to be enhanced.
263
  - Most examples are copy/pasted from BBC news international web sites.
264
  """)
model_translation.py CHANGED
@@ -10,7 +10,70 @@ Date: 2024-03-16
10
 
11
  import torch
12
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 
 
 
 
 
 
 
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  src_langs = set(["ar", "en", "fa", "fr", "he", "ja", "zh"])
15
  model_names = {
16
  "ar": "Helsinki-NLP/opus-mt-ar-en",
@@ -18,7 +81,6 @@ model_names = {
18
  "fa": "Helsinki-NLP/opus-mt-tc-big-fa-itc",
19
  "fr": "Helsinki-NLP/opus-mt-fr-en",
20
  "he": "Helsinki-NLP/opus-mt-tc-big-he-en",
21
- "ja": "Helsinki-NLP/opus-mt-jap-en",
22
  "zh": "Helsinki-NLP/opus-mt-zh-en",
23
  }
24
 
@@ -57,30 +119,3 @@ def get_tokenizer_model_for_src_lang(src_lang: str) -> (AutoTokenizer, AutoModel
57
  # - Let's set to some number of words somewhat lower than that threshold
58
  # - e.g. 200 words
59
  max_words_per_chunk = 200
60
-
61
- #
62
- # Multilingual language pairs
63
- #
64
- from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
65
-
66
- model_name_m2m100 = "facebook/m2m100_418M"
67
- tokenizer_m2m100 = M2M100Tokenizer.from_pretrained(model_name_m2m100)
68
- model_m2m100 = M2M100ForConditionalGeneration.from_pretrained(
69
- model_name_m2m100,
70
- device_map="auto",
71
- torch_dtype=torch.float16,
72
- low_cpu_mem_usage=True
73
- )
74
-
75
- #
76
- # Multilingual translation model
77
- #
78
- model_MADLAD_name = "google/madlad400-3b-mt"
79
- #model_MADLAD_name = "google/madlad400-7b-mt-bt"
80
- tokenizer_multilingual = AutoTokenizer.from_pretrained(model_MADLAD_name, use_fast=True)
81
- model_multilingual = AutoModelForSeq2SeqLM.from_pretrained(
82
- model_MADLAD_name,
83
- device_map="auto",
84
- torch_dtype=torch.float16,
85
- low_cpu_mem_usage=True
86
- )
 
10
 
11
  import torch
12
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
13
+ from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
14
+
15
+
16
+ class Singleton(type):
17
+ _instances = {}
18
+ def __call__(cls, *args, **kwargs):
19
+ if cls not in cls._instances:
20
+ cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
21
+ return cls._instances[cls]
22
 
23
+ class ModelM2M100(metaclass=Singleton):
24
+ """Loads an instance of the M2M100 model (418M).
25
+ """
26
+ def __init__(self):
27
+ self._model_name = "facebook/m2m100_418M"
28
+ self._tokenizer = M2M100Tokenizer.from_pretrained(self._model_name)
29
+ self._model = M2M100ForConditionalGeneration.from_pretrained(
30
+ self._model_name,
31
+ device_map="auto",
32
+ torch_dtype=torch.float16,
33
+ low_cpu_mem_usage=True
34
+ )
35
+
36
+ @property
37
+ def model_name(self):
38
+ return self._model_name
39
+
40
+ @property
41
+ def tokenizer(self):
42
+ return self._tokenizer
43
+
44
+ @property
45
+ def model(self):
46
+ return self._model
47
+
48
+ class ModelMADLAD(metaclass=Singleton):
49
+ """Loads an instance of the Google MADLAD model (3B).
50
+ """
51
+ def __init__(self, model_name):
52
+ self._model_name = "google/madlad400-3b-mt"
53
+ self._tokenizer = AutoTokenizer.from_pretrained(
54
+ self.model_name, use_fast=True
55
+ )
56
+ self._model = AutoModelForSeq2SeqLM.from_pretrained(
57
+ self._model_name,
58
+ device_map="auto",
59
+ torch_dtype=torch.float16,
60
+ low_cpu_mem_usage=True
61
+ )
62
+
63
+ @property
64
+ def model_name(self):
65
+ return self._model_name
66
+
67
+ @property
68
+ def tokenizer(self):
69
+ return self._tokenizer
70
+
71
+ @property
72
+ def model(self):
73
+ return self._model
74
+
75
+
76
+ # Bi-lingual individual models
77
  src_langs = set(["ar", "en", "fa", "fr", "he", "ja", "zh"])
78
  model_names = {
79
  "ar": "Helsinki-NLP/opus-mt-ar-en",
 
81
  "fa": "Helsinki-NLP/opus-mt-tc-big-fa-itc",
82
  "fr": "Helsinki-NLP/opus-mt-fr-en",
83
  "he": "Helsinki-NLP/opus-mt-tc-big-he-en",
 
84
  "zh": "Helsinki-NLP/opus-mt-zh-en",
85
  }
86
 
 
119
  # - Let's set to some number of words somewhat lower than that threshold
120
  # - e.g. 200 words
121
  max_words_per_chunk = 200