|
import html
import logging
import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
|
|
|
# Directory containing the fine-tuned tokenizer and sequence-classification
# model. It can be overridden with the CAMPAIGN_MODEL_DIR environment variable.
# The default is a relative path, so it resolves against the current working
# directory.
model_dir = os.environ.get(
    "CAMPAIGN_MODEL_DIR",
    "./campaign_bert_model/campaign_bert_model/campaign-bert-model",
)

# Load once at import time so every prediction reuses the same objects.
# Loading failures abort startup; the original exception is chained so its
# traceback is preserved.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    # Evaluation mode (e.g. disables dropout) for deterministic inference.
    model.eval()
except Exception as e:
    raise RuntimeError(f"❌ Failed to load model or tokenizer: {e}") from e
|
|
|
|
|
# Lookup from the model's predicted class index to a (tone label, suggested
# campaign message) pair. Keys are assigned 0..4 in list order.
# NOTE(review): this order presumably matches the label ids used when the
# model was fine-tuned; confirm it against the training config.
class_map = dict(
    enumerate(
        [
            ("Informative", "Here are plan details tailored for your interest."),
            ("Excited", "Great news! You’re eligible for our premium plans!"),
            ("Neutral", "Explore various insurance options with us."),
            ("Persuasive", "Take the first step to secure your future today."),
            ("Empathetic", "We understand your needs—here’s how we can help."),
        ]
    )
)
|
|
|
def predict(text):
    """Classify the tone of a client message and suggest a campaign message.

    Runs the module-level ``tokenizer`` and ``model`` on ``text``, picks the
    highest-probability class, and looks up its tone label and message
    template in ``class_map``.

    Args:
        text: Client message from the UI textbox. ``None`` or a
            whitespace-only string is treated as empty input.

    Returns:
        A ``(prediction_html, confidence_html)`` tuple feeding the two
        ``gr.HTML`` outputs. On empty input, an unexpected logits shape, or
        any caught exception, the first item is a red error banner and the
        second is ``""``.
    """
    try:
        # None can arrive (e.g. via the API); treat it like an empty message
        # instead of failing on ``None.strip()``.
        if text is None or not text.strip():
            return "<h3 style='color:red'>⚠️ Please enter a message.</h3>", ""

        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = model(**inputs)

        logits = outputs.logits

        # Class scores must be a 2-D tensor whose dim 1 has one entry per
        # class_map label. Anything else means the checkpoint and the label
        # table disagree. Report None explicitly instead of dereferencing it.
        if logits is None or logits.dim() != 2 or logits.shape[1] != len(class_map):
            shape = None if logits is None else logits.shape
            return f"<h3 style='color:red'>❌ Invalid model output shape: {shape}</h3>", ""

        # Only one message was tokenized, so row 0 holds its probabilities.
        row = torch.softmax(logits, dim=1)[0]
        pred_class = int(torch.argmax(row).item())
        confidence = row[pred_class].item()

        tone, template = class_map[pred_class]

        return (
            f"<h3 style='color:green'>Tone: {tone}</h3>"
            f"<p>📨 Suggested Campaign Message:<br><b>{template}</b></p>",
            f"<p>Confidence: <b>{confidence:.2%}</b></p>",
        )

    except Exception as e:
        # UI boundary: keep the app responsive. Still log the traceback for
        # operators, and HTML-escape the message, because exception text is
        # arbitrary and gets rendered by gr.HTML.
        logging.getLogger(__name__).exception("Prediction failed")
        return f"<h3 style='color:red'>Error: {html.escape(str(e))}</h3>", ""
|
|
|
|
|
# Gradio UI: one text input and two HTML outputs. The outputs match the
# (prediction_html, confidence_html) tuple that predict() returns.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(
        label="Client Message",
        placeholder="E.g. I want to know about child education plans",
    ),
    outputs=[
        gr.HTML(label="Prediction"),
        gr.HTML(label="Confidence"),
    ],
    title="📢 Campaign Personalizer",
    # Derive the class count from class_map so the text cannot drift from
    # the label table.
    description=(
        "Predicts message tone and template using a fine-tuned BERT model "
        f"with {len(class_map)} classes."
    ),
    # NOTE(review): newer Gradio releases deprecate allow_flagging in favour
    # of flagging_mode. Check the installed version before upgrading.
    allow_flagging="never",
)

# Start the web server only when this file runs as a script. Importing the
# module (e.g. to reuse or test predict) then has no server side effect.
if __name__ == "__main__":
    iface.launch()
|
|