# Hugging Face Space (status when this page was captured: Sleeping)
import pandas as pd | |
import gradio as gr | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
import torch | |
import os | |
import re | |
# Load the tokenizer and model once, at import time.
model_name = "google/flan-t5-base"
# Optional access token read from the HF_TOKEN environment variable
# (set as a secret in the Hugging Face Space settings); None when unset.
hf_token = os.environ.get("HF_TOKEN")
# `token` is the current keyword; `use_auth_token` is deprecated.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=hf_token)

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def generate_prompt(original, translation):
    """Build the scoring prompt for one source/translation pair.

    Args:
        original: Source-language text; interpolated with ``str()`` formatting.
        translation: Candidate translation; interpolated the same way.

    Returns:
        A prompt string that asks for a quality score from 0 (poor) to
        1 (excellent) and ends with ``"Score:"``.
    """
    return (
        "Rate the quality of this translation from 0 (poor) to 1 (excellent). "
        "Only respond with a number.\n\n"
        f"Source: {original}\n"
        f"Translation: {translation}\n"
        "Score:"
    )
def predict_scores(file):
    """Score every row of a tab-separated file with the loaded model.

    Args:
        file: Either a filesystem path (``str`` or ``os.PathLike``) or an
            object whose ``.name`` attribute holds the path of the uploaded
            file. The file must be tab-separated with ``original`` and
            ``translation`` columns.

    Returns:
        The parsed ``DataFrame`` with an added ``predicted_score`` column.
        Each value is the parsed score clamped to ``[0, 1]``, or ``-1`` when
        no score could be parsed from the model output.

    Raises:
        KeyError: If the ``original`` or ``translation`` column is missing.
    """
    # Plain paths are used as-is; anything else is expected to expose `.name`.
    # (Checking the type first matters: `Path.name` is only the basename.)
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    df = pd.read_csv(path, sep="\t")

    # A standalone number starting with 0 or 1. The lookbehind stops the
    # fractional digit of an out-of-range number ("3.0", "2.1") from being
    # taken as the score; the lookahead stops "0.5x" from backtracking to "0".
    score_pattern = re.compile(r"(?<![\w.])([01](?:\.\d+)?)\b(?!\.\d)")

    scores = []
    for original, translation in zip(df["original"], df["translation"]):
        prompt = generate_prompt(original, translation)

        # Tokenize, generate a short answer, and decode it back to text.
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        outputs = model.generate(**inputs, max_new_tokens=10)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Debug print (optional)
        print("Response:", response)

        match = score_pattern.search(response)
        if match:
            # Clamp: values such as "1.5" still match and are capped at 1.
            score_val = max(0.0, min(float(match.group(1)), 1.0))
        else:
            score_val = -1.0  # fallback if model output is invalid
        scores.append(score_val)

    df["predicted_score"] = scores
    return df
# Gradio UI: one file upload in, one table of scored rows out.
upload_input = gr.File(label="Upload dev.tsv")
scored_table = gr.Dataframe(label="QE Output with Predicted Score")

iface = gr.Interface(
    title="MT QE with FLAN-T5-Base",
    description="Upload a dev.tsv file with columns: 'original' and 'translation'.",
    fn=predict_scores,
    inputs=upload_input,
    outputs=scored_table,
)

# Start the web app.
iface.launch()