FakeAgreement / app.py
himanshubeniwal's picture
Update app.py
a4d6934 verified
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
# Load the model and tokenizer once, at import time. translate_text() below
# reads `tokenizer` and `model` as module-level globals, so every request
# reuses these same objects instead of reloading them.
# The same checkpoint identifier is handed to both from_pretrained() calls so
# the tokenizer and the model come from one source.
model_name = "himanshubeniwal/opus-mt-en-ro-finetuned-ro-to-en-agreement"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
def translate_text(
    text,
    max_length=200,
    num_beams=4,
    num_return_sequences=1,
    temperature=1.0,
    repetition_penalty=1.0
):
    """
    Translate Romanian text to English using the fine-tuned model.

    Args:
        text: Romanian source text. ``None``, empty or whitespace-only input
            returns an empty string without running the model.
        max_length: Upper bound on the length of each generated sequence,
            forwarded to ``model.generate``.
        num_beams: Beam width. It is raised to ``num_return_sequences`` when
            smaller, because ``generate`` cannot return more sequences than
            it has beams.
        num_return_sequences: Number of candidate translations to return.
        temperature: Sampling temperature. Sampling is enabled only when this
            differs from 1.0; at the default of 1.0 decoding stays
            deterministic (plain beam search), as before.
        repetition_penalty: Forwarded to ``model.generate`` unchanged.

    Returns:
        The translation as a single string; when several candidates are
        requested they are separated by a blank line. If anything fails, a
        string starting with ``"Error in translation: "`` is returned instead
        of raising.
    """
    # Nothing to translate: skip the (expensive) model call entirely.
    if text is None or not str(text).strip():
        return ""

    try:
        # UI sliders may deliver whole numbers as floats; generate() wants ints.
        max_length = int(max_length)
        num_return_sequences = max(1, int(num_return_sequences))
        # generate() raises when num_return_sequences exceeds num_beams, which
        # the UI allows (e.g. 1 beam, 3 translations) -- widen the beam instead.
        num_beams = max(int(num_beams), num_return_sequences)

        generation_kwargs = {
            "max_length": max_length,
            "num_beams": num_beams,
            "num_return_sequences": num_return_sequences,
            "repetition_penalty": repetition_penalty,
            # early_stopping only has meaning for beam-based decoding.
            "early_stopping": num_beams > 1,
        }
        # temperature is ignored unless sampling is on. Enable sampling only
        # when the caller moved away from the neutral value, so the default
        # call keeps its deterministic output.
        if abs(float(temperature) - 1.0) > 1e-6:
            generation_kwargs["do_sample"] = True
            generation_kwargs["temperature"] = float(temperature)

        # Encode the input text
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

        # Generate translation; **inputs forwards the attention mask together
        # with the input ids.
        outputs = model.generate(**inputs, **generation_kwargs)

        # Decode every returned sequence in one call
        translations = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # A single translation joins to itself, so no special case is needed.
        return "\n\n".join(translations)
    except Exception as e:
        # UI boundary: show the failure in the output box rather than crash.
        return f"Error in translation: {str(e)}"
# Create the Gradio interface
def _build_interface():
    """Assemble the Gradio UI that drives ``translate_text``.

    Input components are created in the order of ``translate_text``'s
    parameters (text, max length, beams, number of translations, temperature,
    repetition penalty); every example row lists its values in that same
    order.
    """
    # One row per slider: (minimum, maximum, value, step, label).
    slider_specs = (
        (10, 400, 200, 10, "Max Length"),
        (1, 8, 4, 1, "Number of Beams"),
        (1, 3, 1, 1, "Number of Translations"),
        (0.1, 2.0, 1.0, 0.1, "Temperature"),
        (1.0, 2.0, 1.0, 0.1, "Repetition Penalty"),
    )
    controls = [gr.Textbox(lines=5, label="Enter Romanian text")]
    controls += [
        gr.Slider(minimum=low, maximum=high, value=start, step=step, label=caption)
        for low, high, start, step, caption in slider_specs
    ]
    result_box = gr.Textbox(lines=5, label="English Translation")

    # NOTE(review): the "2," before the second example looks like a typo for
    # "2." -- kept verbatim so the rendered description is unchanged.
    blurb = (
        "Translate Romanian text to English using a fine-tuned OPUS-MT model. "
        "Examples include common phrases like: \n "
        "1. 'Guvernul dumneavoastră are un acord cu Japonia.' "
        "(Your Government has an agreement with Japan.) \n "
        "2,'Instanța a spus că au un acord.' "
        "(The court said they had an agreement.)"
    )

    # Each example uses the default slider settings.
    sample_rows = [
        [sentence, 200, 4, 1, 1.0, 1.0]
        for sentence in (
            "Guvernul dumneavoastră are un acord cu Japonia.",
            "Instanța a spus că au un acord.",
        )
    ]

    return gr.Interface(
        fn=translate_text,
        inputs=controls,
        outputs=result_box,
        title="Romanian to English Translation Model",
        description=blurb,
        examples=sample_rows,
    )


iface = _build_interface()

# Launch the interface only when run as a script, not when imported
if __name__ == "__main__":
    iface.launch()