# Source: Space by gopichandra — commit 01dbd9a ("Update app.py")
import html
import logging
import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Path to the saved model directory inside the Space. It can be overridden with
# the CAMPAIGN_MODEL_DIR environment variable (e.g. for local runs); the
# default is the original hard-coded path.
model_dir = os.environ.get(
    "CAMPAIGN_MODEL_DIR",
    "./campaign_bert_model/campaign_bert_model/campaign-bert-model",
)

# Load tokenizer and model once at import time. On failure, abort startup and
# chain the original exception so its full traceback stays in the logs.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.eval()  # switch to evaluation mode (e.g. disables dropout) for inference
except Exception as e:
    raise RuntimeError(f"❌ Failed to load model or tokenizer: {e}") from e
# Lookup from model output class index to a (tone, template) pair.
# Position in the list below is the class index (0..4), one entry per class
# the classifier emits.
class_map = dict(
    enumerate(
        [
            ("Informative", "Here are plan details tailored for your interest."),
            ("Excited", "Great news! You’re eligible for our premium plans!"),
            ("Neutral", "Explore various insurance options with us."),
            ("Persuasive", "Take the first step to secure your future today."),
            ("Empathetic", "We understand your needs—here’s how we can help."),
        ]
    )
)
def predict(text):
    """Classify a client message and suggest a matching campaign template.

    Args:
        text: Message typed into the Gradio textbox. ``None`` or
            whitespace-only input is treated as empty.

    Returns:
        A ``(prediction_html, confidence_html)`` tuple feeding the two
        ``gr.HTML`` outputs. On empty input, an unexpected model output, or
        any exception, the first element is a red warning/error message and
        the second is ``""``.
    """
    try:
        # Guard against None as well as blank text: `None.strip()` would
        # otherwise surface as an opaque "'NoneType' object ..." error.
        if text is None or not text.strip():
            return "<h3 style='color:red'>⚠️ Please enter a message.</h3>", ""

        # Tokenize input; a single string yields a batch of one, hence the
        # `probs[0]` indexing below.
        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

        # Run model inference without building an autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits

        # Checked separately: formatting `logits.shape` when logits is None
        # would itself raise and hide the real problem.
        if logits is None:
            return "<h3 style='color:red'>❌ Invalid model output: no logits</h3>", ""
        # Only a (batch, num_classes) tensor whose class count matches
        # class_map can be mapped to a tone; a different rank would otherwise
        # raise IndexError on shape[1].
        if logits.dim() != 2 or logits.shape[1] != len(class_map):
            return f"<h3 style='color:red'>❌ Invalid model output shape: {logits.shape}</h3>", ""

        probs = torch.softmax(logits, dim=1)
        pred_class = torch.argmax(probs, dim=1).item()
        confidence = probs[0][pred_class].item()
        tone, template = class_map[pred_class]
        return (
            f"<h3 style='color:green'>Tone: {tone}</h3><p>📨 Suggested Campaign Message:<br><b>{template}</b></p>",
            f"<p>Confidence: <b>{confidence:.2%}</b></p>",
        )
    except Exception as e:
        # UI boundary: keep the traceback in the logs, and escape the message
        # because it is interpolated into markup rendered by gr.HTML.
        logging.getLogger(__name__).exception("Prediction failed")
        return f"<h3 style='color:red'>Error: {html.escape(str(e), quote=False)}</h3>", ""
# Gradio UI: one free-text input and two HTML outputs, matching the
# (prediction_html, confidence_html) tuple returned by predict().
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Client Message", placeholder="E.g. I want to know about child education plans"),
    outputs=[
        gr.HTML(label="Prediction"),
        gr.HTML(label="Confidence"),
    ],
    title="📢 Campaign Personalizer",
    description="Predicts message tone and template using a fine-tuned BERT model with 5 classes.",
    allow_flagging="never",
)

# Start the server only when executed as a script (e.g. `python app.py`), so
# importing this module (tests, notebooks) does not launch the UI.
if __name__ == "__main__":
    iface.launch()