import os
import re
import sys

import gradio as gr
import torch
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration

# Display name shown in the UI -> mBART-50 language code.
language_options = {
    '中文': 'zh_CN',
    '英语': 'en_XX',
    '越南语': 'vi_VN',
    '泰语': 'th_TH',
    '日语': 'ja_XX',
    '韩语': 'ko_KR',
}
languages = list(language_options.keys())

# Defined before the class so the translator never depends on a global that
# is only bound later in the module.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class MBartTranslator:
    """Simple interface for translating text with an MBart model.

    Defaults to the "facebook/mbart-large-50-many-to-many-mmt" checkpoint;
    a different MBart checkpoint can be selected via ``model_name``.
    Translation is restricted to the codes in ``supported_languages``.

    Attributes:
        supported_languages (list[str]): Language codes accepted by
            ``translate`` for both source and target.
        device (torch.device): Device the model and its inputs are placed on.
        model (MBartForConditionalGeneration): The MBart language model.
        tokenizer (MBart50TokenizerFast): The MBart tokenizer.
    """

    def __init__(self, model_name="facebook/mbart-large-50-many-to-many-mmt", src_lang=None, tgt_lang=None):
        """Load the model and tokenizer named by ``model_name``.

        Args:
            model_name: Checkpoint name passed to ``from_pretrained``.
            src_lang: Optional default source language code for the tokenizer.
            tgt_lang: Optional default target language code for the tokenizer.
        """
        self.supported_languages = [
            "ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI",
            "fr_XX", "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR",
            "lt_LT", "lv_LV", "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU",
            "si_LK", "tr_TR", "vi_VN", "zh_CN", "af_ZA", "az_AZ", "bn_IN",
            "fa_IR", "he_IL", "hr_HR", "id_ID", "ka_GE", "km_KH", "mk_MK",
            "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF", "pt_XX", "sv_SE",
            "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA", "ur_PK",
            "xh_ZA", "gl_ES", "sl_SI",
        ]
        # Remember the device so translate() moves inputs to the same place
        # the model lives on.
        self.device = device
        print("Building translator")
        print("Loading generator (this may take few minutes the first time as I need to download the model)")
        self.model = MBartForConditionalGeneration.from_pretrained(model_name).to(self.device)
        print("Loading tokenizer")
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name, src_lang=src_lang, tgt_lang=tgt_lang)
        print("Translator is ready")

    def translate(self, text: str, input_language: str, output_language: str) -> str:
        """Translate ``text`` from the input language to the output language.

        Args:
            text (str): The text to translate.
            input_language (str): The input language code (e.g. "hi_IN" for Hindi).
            output_language (str): The output language code (e.g. "en_XX" for English).

        Returns:
            str: The translated text.

        Raises:
            ValueError: If either language code is not in
                ``supported_languages``.
        """
        if input_language not in self.supported_languages:
            raise ValueError(f"Input language not supported. Supported languages: {self.supported_languages}")
        if output_language not in self.supported_languages:
            raise ValueError(f"Output language not supported. Supported languages: {self.supported_languages}")

        # The tokenizer prepends the source-language token based on src_lang.
        self.tokenizer.src_lang = input_language
        encoded_input = self.tokenizer(text, return_tensors="pt").to(self.device)
        # Forcing the first generated token to the target-language code is
        # what selects the output language.
        generated_tokens = self.model.generate(
            **encoded_input,
            forced_bos_token_id=self.tokenizer.lang_code_to_id[output_language],
        )
        translated_text = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        return translated_text[0]


translator = MBartTranslator()


def translate(src, dst, content):
    """Gradio callback: translate ``content`` between two UI language names.

    Args:
        src: Source language display name (a key of ``language_options``).
        dst: Target language display name (a key of ``language_options``).
        content: Text entered by the user.

    Returns:
        The translated text, or an empty string when ``content`` is empty or
        only whitespace (the model is not called in that case).
    """
    if not content or not content.strip():
        return ""
    return translator.translate(content, language_options[src], language_options[dst])


examples = [
    ['中文', '英语', '今天天气真不错!'],
    ['英语', '中文', "Life was a box of chocolates, you never know what you're gonna get."],
    ['中文', '泰语', '别放弃你的梦想,迟早有一天它会在你手里发光。'],
]

demo = gr.Interface(
    fn=translate,
    inputs=[
        gr.Dropdown(
            languages, label="源语言", value=languages[0], show_label=True
        ),
        gr.Dropdown(
            languages, label="目标语言", value=languages[1], show_label=True
        ),
        gr.Textbox(label='内容', placeholder='这里输入要翻译的内容', lines=5),
    ],
    outputs=[
        gr.Textbox(label='结果', lines=5),
    ],
    examples=examples,
)

# `launch(enable_queue=True)` is deprecated in Gradio 3.x and rejected by
# later releases; enabling the queue through `queue()` works on both.
demo.queue()
demo.launch()