MBartTranslator / app.py
wall-e-zz's picture
Update app.py
58f896d
raw
history blame contribute delete
No virus
4.83 kB
import re
import os
import sys
import torch
import gradio as gr
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
# Dropdown display names (Chinese UI labels) mapped to MBart-50 language codes.
language_options = {
    '中文': 'zh_CN',    # Chinese
    '英语': 'en_XX',    # English
    '越南语': 'vi_VN',  # Vietnamese
    '泰语': 'th_TH',    # Thai
    '日语': 'ja_XX',    # Japanese
    '韩语': 'ko_KR',    # Korean
}

# Display names in insertion order, used to populate the Gradio dropdowns
# (index 0 is the default source language, index 1 the default target).
languages = list(language_options)
class MBartTranslator:
    """Translate text between languages using an MBart-50 model.

    Wraps the "facebook/mbart-large-50-many-to-many-mmt" checkpoint by
    default; a different MBart-50 checkpoint name may be supplied instead.
    Translation between any pair of the 50 supported languages goes through
    the single ``translate`` method.

    NOTE(review): both ``__init__`` and ``translate`` reference the
    module-level ``device`` defined later in this file; it is bound before
    the module instantiates this class, but the class is not usable on its
    own without it.

    Attributes:
        supported_languages (list[str]): MBart-50 language codes accepted as
            source/target languages (kept as a list for backward compatibility).
        model (MBartForConditionalGeneration): The MBart language model.
        tokenizer (MBart50TokenizerFast): The MBart tokenizer.
    """

    def __init__(self, model_name="facebook/mbart-large-50-many-to-many-mmt", src_lang=None, tgt_lang=None):
        """Load the model and tokenizer.

        Args:
            model_name (str): Hugging Face checkpoint name of an MBart-50 model.
            src_lang (str | None): Optional default source language code for the tokenizer.
            tgt_lang (str | None): Optional default target language code for the tokenizer.
        """
        self.supported_languages = [
            "ar_AR",
            "cs_CZ",
            "de_DE",
            "en_XX",
            "es_XX",
            "et_EE",
            "fi_FI",
            "fr_XX",
            "gu_IN",
            "hi_IN",
            "it_IT",
            "ja_XX",
            "kk_KZ",
            "ko_KR",
            "lt_LT",
            "lv_LV",
            "my_MM",
            "ne_NP",
            "nl_XX",
            "ro_RO",
            "ru_RU",
            "si_LK",
            "tr_TR",
            "vi_VN",
            "zh_CN",
            "af_ZA",
            "az_AZ",
            "bn_IN",
            "fa_IR",
            "he_IL",
            "hr_HR",
            "id_ID",
            "ka_GE",
            "km_KH",
            "mk_MK",
            "ml_IN",
            "mn_MN",
            "mr_IN",
            "pl_PL",
            "ps_AF",
            "pt_XX",
            "sv_SE",
            "sw_KE",
            "ta_IN",
            "te_IN",
            "th_TH",
            "tl_XX",
            "uk_UA",
            "ur_PK",
            "xh_ZA",
            "gl_ES",
            "sl_SI",
        ]
        # Set view of the same codes for O(1) membership checks in translate();
        # the public list attribute above is preserved for callers.
        self._supported_set = set(self.supported_languages)
        print("Building translator")
        print("Loading generator (this may take few minutes the first time as I need to download the model)")
        self.model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)
        print("Loading tokenizer")
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang)
        print("Translator is ready")

    def translate(self, text: str, input_language: str, output_language: str) -> str:
        """Translate the given text from the input language to the output language.

        Args:
            text (str): The text to translate.
            input_language (str): The source language code (e.g. "hi_IN" for Hindi).
            output_language (str): The target language code (e.g. "en_XX" for English).

        Returns:
            str: The translated text.

        Raises:
            ValueError: If either language code is not in ``supported_languages``.
        """
        if input_language not in self._supported_set:
            raise ValueError(f"Input language not supported. Supported languages: {self.supported_languages}")
        if output_language not in self._supported_set:
            raise ValueError(f"Output language not supported. Supported languages: {self.supported_languages}")
        # MBart-50 tokenizers take the source language via the src_lang attribute;
        # the target is forced through the BOS token at generation time.
        self.tokenizer.src_lang = input_language
        encoded_input = self.tokenizer(text, return_tensors="pt").to(device)
        generated_tokens = self.model.generate(
            **encoded_input, forced_bos_token_id=self.tokenizer.lang_code_to_id[output_language]
        )
        translated_text = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        return translated_text[0]
# Prefer the first CUDA GPU when one is present; otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build the shared translator once at import time (may download the model
# checkpoint on the first run).
translator = MBartTranslator()
def translate(src, dst, content):
    """Gradio callback: translate ``content`` between two display languages.

    Args:
        src (str): Source language display name (a key of ``language_options``).
        dst (str): Target language display name (a key of ``language_options``).
        content (str): The text to translate.

    Returns:
        str: The translated text.
    """
    # Map the UI display names to MBart language codes before delegating.
    return translator.translate(content, language_options[src], language_options[dst])
# Example rows shown in the UI; each row is [source language, target language, text]
# and must match the order of the Interface inputs below.
examples=[
['中文', '英语', '今天天气真不错!'],
['英语', '中文', "Life was a box of chocolates, you never know what you're gonna get."],
['中文', '泰语', '别放弃你的梦想,迟早有一天它会在你手里发光。'],
]
# Gradio UI: two language dropdowns plus a text box, wired to translate().
# Labels are Chinese UI strings: 源语言 = "source language",
# 目标语言 = "target language", 内容 = "content", 结果 = "result".
demo = gr.Interface(
fn=translate,
inputs=[
gr.Dropdown(
# Defaults: source = first display name (中文), target = second (英语).
languages, label="源语言", value=languages[0], show_label=True
),
gr.Dropdown(
languages, label="目标语言", value=languages[1], show_label=True
),
gr.Textbox(label='内容', placeholder='这里输入要翻译的内容', lines=5)
],
outputs=[
gr.Textbox(label='结果', lines=5)
],
examples=examples
)
# NOTE(review): enable_queue is a legacy launch() argument; Gradio 4.x removed
# it in favour of demo.queue().launch() — confirm the pinned gradio version.
demo.launch(enable_queue=True)