# Gradio Space: translate text between English and four Indian languages using MBart-50.
import torch | |
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration | |
import gradio as gr | |
class MBartTranslator:
    """Text translation between a fixed set of languages using MBart-50.

    Wraps the "facebook/mbart-large-50-many-to-many-mmt" checkpoint (a
    pretrained many-to-many multilingual model from Facebook) and restricts
    the selectable languages to a small supported subset.

    Attributes:
        supported_languages (list[str]): Language codes accepted as input or output.
        tokenizer (MBart50TokenizerFast): The MBart tokenizer.
        model (MBartForConditionalGeneration): The MBart language model.
    """

    def __init__(self, model_name="facebook/mbart-large-50-many-to-many-mmt", src_lang=None, tgt_lang=None):
        # NOTE(review): src_lang/tgt_lang are accepted but never used; they are
        # kept in the signature for backward compatibility with existing callers.
        self.supported_languages = [
            "en_XX",  # English
            "hi_IN",  # Hindi
            "ml_IN",  # Malayalam
            "ta_IN",  # Tamil
            "te_IN",  # Telugu
        ]
        print("Building translator")
        print("Loading tokenizer")
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
        print("Loading generator")  # fixed typo: was "Laoding generator"
        self.model = MBartForConditionalGeneration.from_pretrained(model_name)

    def translate(self, text: str, input_language: str, output_language: str) -> str:
        """Translate ``text`` from ``input_language`` to ``output_language``.

        Args:
            text (str): The text to translate.
            input_language (str): The input language code (e.g. "en_XX" for English).
            output_language (str): The output language code (e.g. "ta_IN" for Tamil).

        Returns:
            str: The translated text.

        Raises:
            ValueError: If either language code is not in ``supported_languages``.
        """
        if input_language not in self.supported_languages:
            raise ValueError(f"Input language not supported. Supported languages: {self.supported_languages}")
        if output_language not in self.supported_languages:
            raise ValueError(f"Output language not supported. Supported languages: {self.supported_languages}")
        # MBart requires the source language to be set on the tokenizer, and the
        # target language to be forced as the first generated (BOS) token.
        self.tokenizer.src_lang = input_language
        encoded_input = self.tokenizer(text, return_tensors='pt')
        generated_tokens = self.model.generate(
            **encoded_input, forced_bos_token_id=self.tokenizer.lang_code_to_id[output_language]
        )
        translated_text = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        return translated_text[0]
# Single shared translator, created at import time; this downloads/loads the
# MBart-50 checkpoint, so startup is slow on first run.
translator = MBartTranslator()
def translate_text(source_lang, target_lang, text):
    """Gradio callback: translate ``text`` from ``source_lang`` to ``target_lang``."""
    translated = translator.translate(text, source_lang, target_lang)
    return translated
# Build and launch the Gradio UI. Reuse the translator's supported-language
# list so the dropdown choices cannot drift out of sync with
# MBartTranslator.supported_languages (previously the codes were hard-coded
# twice, once per dropdown).
translation_interface = gr.Interface(
    fn=translate_text,
    inputs=[
        gr.Dropdown(choices=translator.supported_languages, label="Source Language"),
        gr.Dropdown(choices=translator.supported_languages, label="Target Language"),
        gr.Textbox(label="Text to translate"),
    ],
    outputs=gr.Textbox(label="Translated text"),
)

translation_interface.launch()