# demo-banglat5 / app.py
# Sanzana Lora — commit 39a2488 ("Create app.py"), 2.82 kB
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import gradio as gr
# Load the fine-tuned seq2seq checkpoints used by each task.
def _load_seq2seq(checkpoint, **tokenizer_kwargs):
    """Load a seq2seq model and its tokenizer from the same *checkpoint*.

    Using one checkpoint string for both avoids the model and tokenizer
    drifting apart when a checkpoint name is edited.

    Args:
        checkpoint: Hugging Face Hub model id.
        **tokenizer_kwargs: Extra keyword arguments forwarded to
            ``AutoTokenizer.from_pretrained`` (e.g. ``use_fast=False``).

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, **tokenizer_kwargs)
    return model, tokenizer


# The BanglaT5 model cards load their tokenizer with use_fast=False; the
# mT5 XL-Sum model card uses the default tokenizer.
translation_model_en_bn, translation_tokenizer_en_bn = _load_seq2seq(
    "csebuetnlp/banglat5_nmt_en_bn", use_fast=False
)
translation_model_bn_en, translation_tokenizer_bn_en = _load_seq2seq(
    "csebuetnlp/banglat5_nmt_bn_en", use_fast=False
)
summarization_model, summarization_tokenizer = _load_seq2seq(
    "csebuetnlp/mT5_multilingual_XLSum"
)
paraphrase_model, paraphrase_tokenizer = _load_seq2seq(
    "csebuetnlp/banglat5_banglaparaphrase", use_fast=False
)
# Function to perform machine translation
def translate_text(input_text, direction="en_bn", max_new_tokens=256):
    """Translate *input_text* with one of the BanglaT5 NMT checkpoints.

    Args:
        input_text: Text to translate.
        direction: ``"en_bn"`` (English -> Bangla; the default, matching the
            previous behaviour) or ``"bn_en"`` (Bangla -> English).
        max_new_tokens: Upper bound on generated tokens. Set explicitly so the
            output length does not depend on ``generate()``'s default.

    Returns:
        The decoded translation, or ``""`` when the input is empty/blank.

    Raises:
        ValueError: If *direction* is not ``"en_bn"`` or ``"bn_en"``.
    """
    if direction == "en_bn":
        model, tokenizer = translation_model_en_bn, translation_tokenizer_en_bn
    elif direction == "bn_en":
        model, tokenizer = translation_model_bn_en, translation_tokenizer_bn_en
    else:
        raise ValueError(f"Unsupported translation direction: {direction!r}")

    # Nothing to translate: skip generation instead of decoding noise.
    if input_text is None or not input_text.strip():
        return ""

    # The model card's usage example feeds the raw sentence; a "translate: "
    # prefix is not part of this checkpoint's input format and would itself
    # be translated.
    inputs = tokenizer(input_text, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Function to perform summarization
def summarize_text(
    input_text,
    max_input_tokens=512,
    max_summary_tokens=84,
    num_beams=4,
):
    """Summarize *input_text* with the mT5 XL-Sum checkpoint.

    The defaults follow the preprocessing and generation settings shown in
    the ``csebuetnlp/mT5_multilingual_XLSum`` model card.

    Args:
        input_text: Article text to summarize.
        max_input_tokens: Inputs longer than this are truncated before
            encoding, so very long pastes do not blow up memory/latency.
        max_summary_tokens: ``max_length`` passed to ``generate()``.
        num_beams: Beam width passed to ``generate()``.

    Returns:
        The decoded summary, or ``""`` when the input is empty/blank.
    """
    if input_text is None or not input_text.strip():
        return ""

    # Collapse newlines and runs of whitespace into single spaces, mirroring
    # the whitespace handler in the model card.
    text = " ".join(input_text.split())

    # No "summarize: " prefix: the XL-Sum checkpoint takes the raw article.
    inputs = summarization_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_tokens,
    )
    outputs = summarization_model.generate(
        **inputs,
        max_length=max_summary_tokens,
        no_repeat_ngram_size=2,
        num_beams=num_beams,
    )
    return summarization_tokenizer.decode(outputs[0], skip_special_tokens=True)
# Function to perform paraphrasing
def paraphrase_text(input_text, max_new_tokens=256):
    """Paraphrase *input_text* with the BanglaT5 paraphrase checkpoint.

    Args:
        input_text: Sentence to paraphrase.
        max_new_tokens: Upper bound on generated tokens. Set explicitly so the
            output length does not depend on ``generate()``'s default.

    Returns:
        The decoded paraphrase, or ``""`` when the input is empty/blank.
    """
    if input_text is None or not input_text.strip():
        return ""

    # The model card's usage example feeds the raw sentence; a
    # "paraphrase: " prefix is not part of this checkpoint's input format.
    inputs = paraphrase_tokenizer(input_text, return_tensors="pt")
    outputs = paraphrase_model.generate(**inputs, max_new_tokens=max_new_tokens)
    return paraphrase_tokenizer.decode(outputs[0], skip_special_tokens=True)
# Gradio interface: the selected task and the input text are both inputs of a
# single dispatcher, so switching tasks needs no rewiring of the interface.
TASK_FUNCTIONS = {
    "Translate": translate_text,
    "Summarize": summarize_text,
    "Paraphrase": paraphrase_text,
}


def run_task(task, input_text):
    """Run the function registered for *task* on *input_text*.

    Args:
        task: One of the keys of ``TASK_FUNCTIONS`` (the dropdown choices).
        input_text: Text entered by the user.

    Returns:
        The selected task function's output string.

    Raises:
        gr.Error: If *task* is not a known task name; Gradio shows the
            message in the UI instead of a stack trace.
    """
    handler = TASK_FUNCTIONS.get(task)
    if handler is None:
        raise gr.Error(f"Unknown task: {task!r}")
    return handler(input_text)


# Dropdown widget to select the task. It is passed to the Interface as an
# input, so its current value reaches run_task on every submission.
task_selector = gr.Dropdown(
    choices=list(TASK_FUNCTIONS),
    value="Translate",
    label="Select Task",
)

# live=True is intentionally not used: it would rerun model generation on
# every input change (each keystroke). Users press Submit instead.
iface = gr.Interface(
    fn=run_task,
    inputs=[task_selector, gr.Textbox(lines=5, label="Input Text")],
    outputs=gr.Textbox(label="Output Text"),
)

# Launch the Gradio app when executed as a script (e.g. `python app.py`).
if __name__ == "__main__":
    iface.launch()