SathishSKN's picture
Update app.py
3b9fdf4 verified
import torch
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
import gradio as gr
class MBartTranslator:
    """
    Simple text-translation wrapper around an mBART-50 language model.

    By default this loads "facebook/mbart-large-50-many-to-many-mmt", a
    pretrained many-to-many translation checkpoint from Facebook, and only
    allows translation between the language codes in ``supported_languages``.

    Attributes:
        supported_languages (list[str]): mBART-50 language codes accepted as
            the input or output language.
        src_lang (str | None): Default input language used by ``translate``
            when the caller does not pass one.
        tgt_lang (str | None): Default output language used by ``translate``
            when the caller does not pass one.
        model (MBartForConditionalGeneration): The mBART language model.
        tokenizer (MBart50TokenizerFast): The mBART-50 tokenizer.
    """

    def __init__(self, model_name="facebook/mbart-large-50-many-to-many-mmt", src_lang=None, tgt_lang=None):
        """
        Load the tokenizer and model for ``model_name``.

        Args:
            model_name (str): Model id or path passed to ``from_pretrained``.
            src_lang (str | None): Optional default input language code.
            tgt_lang (str | None): Optional default output language code.
        """
        self.supported_languages = [
            "en_XX",  # English
            "hi_IN",  # Hindi
            "ml_IN",  # Malayalam
            "ta_IN",  # Tamil
            "te_IN",  # Telugu
        ]
        # These used to be accepted and then ignored; keep them as defaults for translate().
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        print("Building translator")
        print("Loading tokenizer")
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
        print("Loading generator")
        self.model = MBartForConditionalGeneration.from_pretrained(model_name)

    def translate(self, text: str, input_language: "str | None" = None, output_language: "str | None" = None) -> str:
        """
        Translate the given text from the input language to the output language.

        Args:
            text (str): The text to translate.
            input_language (str | None): Input language code (e.g. "en_XX" for
                English). Defaults to ``self.src_lang`` when omitted.
            output_language (str | None): Output language code (e.g. "ta_IN"
                for Tamil). Defaults to ``self.tgt_lang`` when omitted.

        Returns:
            str: The translated text, or "" when ``text`` is empty or blank.

        Raises:
            ValueError: If either language code is missing or not in
                ``supported_languages``.
        """
        if input_language is None:
            input_language = self.src_lang
        if output_language is None:
            output_language = self.tgt_lang
        if input_language not in self.supported_languages:
            raise ValueError(f"Input language not supported. Supported languages: {self.supported_languages}")
        if output_language not in self.supported_languages:
            raise ValueError(f"Output language not supported. Supported languages: {self.supported_languages}")
        # Nothing to translate: skip the model call instead of generating from an empty prompt.
        if not text or not text.strip():
            return ""
        # The tokenizer tags the input with the source-language code, so it is set on every call.
        self.tokenizer.src_lang = input_language
        encoded_input = self.tokenizer(text, return_tensors="pt")
        # Forcing the first generated token to the target language code selects the output language.
        # convert_tokens_to_ids is the generic tokenizer API; it does not rely on the
        # lang_code_to_id mapping, which is not available on every tokenizer/transformers version.
        target_token_id = self.tokenizer.convert_tokens_to_ids(output_language)
        generated_tokens = self.model.generate(**encoded_input, forced_bos_token_id=target_token_id)
        return self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
# One shared translator, built at import time so the model is loaded before the UI serves requests.
translator = MBartTranslator()


def translate_text(source_lang, target_lang, text):
    """
    Gradio callback: translate ``text`` from ``source_lang`` to ``target_lang``.

    The argument order matches the order of the UI inputs (source dropdown,
    target dropdown, text box); it is remapped onto the translator's
    ``translate(text, input_language, output_language)`` signature.
    """
    return translator.translate(
        text=text,
        input_language=source_lang,
        output_language=target_lang,
    )
# Gradio UI: source/target language dropdowns plus the input text, wired to translate_text
# (inputs are listed in the same order as translate_text's parameters).
# The dropdown choices come from the translator itself so the UI can never offer a
# language that MBartTranslator.translate would reject.
translation_interface = gr.Interface(
    fn=translate_text,
    inputs=[
        gr.Dropdown(choices=list(translator.supported_languages), label="Source Language"),
        gr.Dropdown(choices=list(translator.supported_languages), label="Target Language"),
        gr.Textbox(label="Text to translate"),
    ],
    outputs=gr.Textbox(label="Translated text"),
)
translation_interface.launch()