Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,108 +1,221 @@
|
|
|
|
|
|
|
|
| 1 |
|
| 2 |
-
|
| 3 |
-
import gradio as gr
|
| 4 |
import joblib
|
| 5 |
-
import shap
|
| 6 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
|
| 9 |
-
# 1️⃣ Load pre-trained models and vectorizers
|
| 10 |
-
# ------------------------------------------------------------
|
| 11 |
-
eng_model = joblib.load("best_model.pkl")
|
| 12 |
-
eng_vectorizer = joblib.load("tfidf_vectorizer.pkl")
|
| 13 |
-
|
| 14 |
-
per_model = joblib.load("logistic_regression.pkl")
|
| 15 |
-
per_vectorizer = joblib.load("tfidf_vectorizer_persian.pkl")
|
| 16 |
-
|
| 17 |
-
# ------------------------------------------------------------
|
| 18 |
-
# 2️⃣ Define class labels
|
| 19 |
-
# ------------------------------------------------------------
|
| 20 |
-
class_names = ["Negative", "Neutral", "Positive"]
|
| 21 |
-
|
| 22 |
-
# ------------------------------------------------------------
|
| 23 |
-
# 3️⃣ Prediction Function
|
| 24 |
-
# ------------------------------------------------------------
|
| 25 |
-
def predict_sentiment(text, language):
|
| 26 |
-
if not text.strip():
|
| 27 |
-
return "⚠️ Please enter some text!", None, None
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
if language == "English":
|
| 30 |
model = eng_model
|
| 31 |
vectorizer = eng_vectorizer
|
|
|
|
| 32 |
else:
|
| 33 |
model = per_model
|
| 34 |
vectorizer = per_vectorizer
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
-
#
|
| 37 |
vec = vectorizer.transform([text])
|
| 38 |
probs = model.predict_proba(vec)[0]
|
| 39 |
-
pred_class = np.argmax(probs)
|
| 40 |
label = class_names[pred_class]
|
| 41 |
-
confidence = probs[pred_class]
|
| 42 |
-
|
| 43 |
-
#
|
| 44 |
-
explainer
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
-
|
| 62 |
-
# ------------------------------------------------------------
|
| 63 |
-
# 4️⃣ Gradio Interface
|
| 64 |
-
# ------------------------------------------------------------
|
| 65 |
-
def gradio_ui(text, language):
|
| 66 |
-
pred, probs, interp = predict_sentiment(text, language)
|
| 67 |
-
|
| 68 |
-
if not probs:
|
| 69 |
-
return pred, None, None
|
| 70 |
-
|
| 71 |
-
# Confidence Bar Plot
|
| 72 |
-
bar_plot = {cls: float(p) for cls, p in zip(class_names, probs)}
|
| 73 |
-
|
| 74 |
-
# Word Contribution Table
|
| 75 |
-
word_table = None
|
| 76 |
-
if interp:
|
| 77 |
-
word_table = {
|
| 78 |
-
"Word": interp["words"],
|
| 79 |
-
"SHAP Impact": interp["contributions"]
|
| 80 |
-
}
|
| 81 |
-
|
| 82 |
-
return pred, bar_plot, word_table
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
# ------------------------------------------------------------
|
| 86 |
-
# 5️⃣ Gradio Layout
|
| 87 |
-
# ------------------------------------------------------------
|
| 88 |
-
with gr.Blocks(theme=gr.themes.Soft()) as interface:
|
| 89 |
-
gr.Markdown("
|
| 90 |
-
Select language, enter text, and view predictions with interpretable SHAP insights!")
|
| 91 |
-
|
| 92 |
-
lang = gr.Radio(["English", "Persian"], label="Select Dataset", value="English")
|
| 93 |
-
text = gr.Textbox(label="Enter your text here", placeholder="Type an English or Persian comment...")
|
| 94 |
-
|
| 95 |
-
output_pred = gr.Markdown(label="Prediction")
|
| 96 |
-
output_bar = gr.BarPlot(label="Confidence per Class")
|
| 97 |
-
output_table = gr.Dataframe(label="Top Influential Words", headers=["Word", "SHAP Impact"])
|
| 98 |
-
|
| 99 |
-
btn = gr.Button("🔍 Analyze Sentiment")
|
| 100 |
-
|
| 101 |
-
btn.click(fn=gradio_ui, inputs=[text, lang], outputs=[output_pred, output_bar, output_table])
|
| 102 |
-
|
| 103 |
-
# ------------------------------------------------------------
|
| 104 |
-
# 6️⃣ Launch App
|
| 105 |
-
# ------------------------------------------------------------
|
| 106 |
if __name__ == "__main__":
|
| 107 |
-
|
| 108 |
-
|
|
|
|
| 1 |
+
# app.py
|
| 2 |
+
# Gradio app: English + Persian sentiment with SHAP-based interpretability and word highlighting
|
| 3 |
|
|
|
|
|
|
|
| 4 |
import joblib
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import shap
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import io
|
| 10 |
+
import base64
|
| 11 |
+
import html
|
| 12 |
+
from typing import Tuple, Dict, List
|
| 13 |
+
import math
|
| 14 |
|
| 15 |
+
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
+
# --------- Load models (replace filenames if you used different names) ----------
|
| 18 |
+
ENG_MODEL_PATH = "models/english_lr_model.pkl"
|
| 19 |
+
ENG_VEC_PATH = "models/english_vectorizer.pkl"
|
| 20 |
+
PER_MODEL_PATH = "models/persian_lr_model.pkl"
|
| 21 |
+
PER_VEC_PATH = "models/persian_vectorizer.pkl"
|
| 22 |
+
|
| 23 |
+
eng_model = joblib.load(ENG_MODEL_PATH)
|
| 24 |
+
eng_vectorizer = joblib.load(ENG_VEC_PATH)
|
| 25 |
+
|
| 26 |
+
per_model = joblib.load(PER_MODEL_PATH)
|
| 27 |
+
per_vectorizer = joblib.load(PER_VEC_PATH)
|
| 28 |
+
|
| 29 |
+
CLASS_NAMES_EN = ["Negative", "Neutral", "Positive"]
|
| 30 |
+
CLASS_NAMES_PER = ["منفی", "خنثی", "مثبت"]
|
| 31 |
+
|
| 32 |
+
# --------- Utility: create bar data for gradio BarPlot ----------
|
| 33 |
+
def probs_to_bar(probs: List[float], lang: str):
|
| 34 |
+
names = CLASS_NAMES_EN if lang == "English" else CLASS_NAMES_PER
|
| 35 |
+
return {names[i]: float(probs[i]) for i in range(len(probs))}
|
| 36 |
+
|
| 37 |
+
# --------- Utility: create HTML highlight from SHAP values ----------
|
| 38 |
+
def make_html_highlight(original_text: str,
|
| 39 |
+
feature_names: np.ndarray,
|
| 40 |
+
shap_values_feature: np.ndarray,
|
| 41 |
+
vectorizer_vocab: dict,
|
| 42 |
+
max_display: int = 30) -> str:
|
| 43 |
+
"""
|
| 44 |
+
Simple token-level highlighting:
|
| 45 |
+
- Tokenize by whitespace (preserves original punctuation).
|
| 46 |
+
- For each token, attempt to map token.lower() to the vectorizer vocab;
|
| 47 |
+
if found, get SHAP impact for that feature name.
|
| 48 |
+
- Color red for positive contribution, blue for negative.
|
| 49 |
+
Returns an HTML-safe string.
|
| 50 |
+
"""
|
| 51 |
+
# Build mapping word -> shap value if present in vocabulary
|
| 52 |
+
# vectorizer_vocab maps token -> idx in feature_names
|
| 53 |
+
token_to_shap = {}
|
| 54 |
+
for idx, fname in enumerate(feature_names):
|
| 55 |
+
# Often fname is the token/ngram itself
|
| 56 |
+
token_to_shap[fname] = shap_values_feature[idx]
|
| 57 |
+
|
| 58 |
+
# Tokenize (simple)
|
| 59 |
+
tokens = original_text.split()
|
| 60 |
+
# Compute max magnitude for scaling opacity
|
| 61 |
+
mags = []
|
| 62 |
+
for t in tokens:
|
| 63 |
+
key = t.lower()
|
| 64 |
+
val = None
|
| 65 |
+
# Try several common variants: exact, lower, strip punctuation from ends
|
| 66 |
+
if key in vectorizer_vocab:
|
| 67 |
+
val = shap_values_feature[vectorizer_vocab[key]]
|
| 68 |
+
else:
|
| 69 |
+
key2 = ''.join(ch for ch in key if ch.isalnum())
|
| 70 |
+
if key2 in vectorizer_vocab:
|
| 71 |
+
val = shap_values_feature[vectorizer_vocab[key2]]
|
| 72 |
+
mags.append(abs(val) if val is not None else 0.0)
|
| 73 |
+
max_mag = max(mags) if mags else 1.0
|
| 74 |
+
if max_mag == 0:
|
| 75 |
+
max_mag = 1.0
|
| 76 |
+
|
| 77 |
+
# Build HTML with span coloring
|
| 78 |
+
html_tokens = []
|
| 79 |
+
for t in tokens:
|
| 80 |
+
display = html.escape(t)
|
| 81 |
+
key = t.lower()
|
| 82 |
+
val = None
|
| 83 |
+
if key in vectorizer_vocab:
|
| 84 |
+
val = shap_values_feature[vectorizer_vocab[key]]
|
| 85 |
+
else:
|
| 86 |
+
key2 = ''.join(ch for ch in key if ch.isalnum())
|
| 87 |
+
if key2 in vectorizer_vocab:
|
| 88 |
+
val = shap_values_feature[vectorizer_vocab[key2]]
|
| 89 |
+
if val is None or abs(val) < 1e-6:
|
| 90 |
+
html_tokens.append(f"<span style='padding:2px'>{display}</span>")
|
| 91 |
+
else:
|
| 92 |
+
sign = "pos" if val > 0 else "neg"
|
| 93 |
+
mag = min(1.0, abs(val) / max_mag) # scale 0..1
|
| 94 |
+
opacity = 0.15 + 0.85 * mag # avoid fully transparent
|
| 95 |
+
color = f"rgba(220,20,60,{opacity})" if sign == "pos" else f"rgba(30,144,255,{opacity})"
|
| 96 |
+
border = "1px solid rgba(0,0,0,0.04)"
|
| 97 |
+
html_tokens.append(
|
| 98 |
+
f"<span style='background:{color};padding:2px;margin:1px;border-radius:4px;display:inline-block;{border}'>"
|
| 99 |
+
f"{display}</span>"
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
highlighted_html = "<div style='line-height:1.6;font-size:16px'>" + " ".join(html_tokens) + "</div>"
|
| 103 |
+
return highlighted_html
|
| 104 |
+
|
| 105 |
+
# --------- Core function: predict + interpret ----------
|
| 106 |
+
def explain_and_predict(text: str, language: str):
|
| 107 |
+
text = text or ""
|
| 108 |
if language == "English":
|
| 109 |
model = eng_model
|
| 110 |
vectorizer = eng_vectorizer
|
| 111 |
+
class_names = CLASS_NAMES_EN
|
| 112 |
else:
|
| 113 |
model = per_model
|
| 114 |
vectorizer = per_vectorizer
|
| 115 |
+
class_names = CLASS_NAMES_PER
|
| 116 |
+
|
| 117 |
+
if text.strip() == "":
|
| 118 |
+
return "⚠️ Please enter text.", {}, {"Word": [], "SHAP Impact": []}, "<i>No input</i>"
|
| 119 |
|
| 120 |
+
# vectorize
|
| 121 |
vec = vectorizer.transform([text])
|
| 122 |
probs = model.predict_proba(vec)[0]
|
| 123 |
+
pred_class = int(np.argmax(probs))
|
| 124 |
label = class_names[pred_class]
|
| 125 |
+
confidence = float(probs[pred_class])
|
| 126 |
+
|
| 127 |
+
# Build SHAP explainer on a small background (use small subset via dummy background)
|
| 128 |
+
# NOTE: building explainer can be slow; in Spaces you can build once at import
|
| 129 |
+
# For robustness we build a simple LinearExplainer on vector space
|
| 130 |
+
# Use small dense sample from training if available - here use vectorizer vocabulary size fallback
|
| 131 |
+
# Convert to dense for LinearExplainer
|
| 132 |
+
try:
|
| 133 |
+
# Use a small background of zeros (cheap) — LinearExplainer can accept arrays
|
| 134 |
+
background = np.zeros((1, vec.shape[1]))
|
| 135 |
+
explainer = shap.LinearExplainer(model, background, feature_names=vectorizer.get_feature_names_out())
|
| 136 |
+
# compute shap on the numeric vector
|
| 137 |
+
vec_dense = vec.toarray()
|
| 138 |
+
shap_vals = explainer(vec_dense) # returns shap.Explanation
|
| 139 |
+
except Exception:
|
| 140 |
+
# fallback: use PermutationExplainer on numeric input (slower)
|
| 141 |
+
explainer = shap.Explainer(model.predict_proba, vec)
|
| 142 |
+
shap_vals = explainer(vec)
|
| 143 |
+
|
| 144 |
+
# shap_vals.values shape: (n_outputs, n_features) OR Explanation with values (n_features, n_classes)
|
| 145 |
+
# Normalize to feature vector for chosen class
|
| 146 |
+
# shap_vals may be multi-output: shap_vals.values => (n_samples, n_features, n_classes) or similar
|
| 147 |
+
try:
|
| 148 |
+
# preferred shape: shap_vals.values -> (1, n_features, n_classes)
|
| 149 |
+
values = shap_vals.values # ND array
|
| 150 |
+
if values.ndim == 3:
|
| 151 |
+
# pick sample 0, class pred_class
|
| 152 |
+
shap_per_feature = values[0, :, pred_class]
|
| 153 |
+
elif values.ndim == 2:
|
| 154 |
+
# shape (n_samples, n_features) for single class models — take sample 0
|
| 155 |
+
shap_per_feature = values[0, :]
|
| 156 |
+
else:
|
| 157 |
+
# try to flatten
|
| 158 |
+
shap_per_feature = np.ravel(values)[0:vec.shape[1]]
|
| 159 |
+
except Exception:
|
| 160 |
+
# Last resort: try shap_vals[0].values
|
| 161 |
+
try:
|
| 162 |
+
shap_per_feature = shap_vals[0].values[:, pred_class]
|
| 163 |
+
except Exception:
|
| 164 |
+
shap_per_feature = np.zeros(vec.shape[1])
|
| 165 |
+
|
| 166 |
+
# Feature names & vocab
|
| 167 |
+
feature_names = np.array(vectorizer.get_feature_names_out())
|
| 168 |
+
vocab = {k: v for k, v in (getattr(vectorizer, "vocabulary_", {})).items()}
|
| 169 |
+
|
| 170 |
+
# Build top contributing words list (pairs)
|
| 171 |
+
# shap_per_feature length must match len(feature_names)
|
| 172 |
+
if len(shap_per_feature) != len(feature_names):
|
| 173 |
+
# try to align by vectorizer.vocabulary_
|
| 174 |
+
full_shap = np.zeros(len(feature_names))
|
| 175 |
+
# if shap_per_feature smaller, attempt to use indices from vocab
|
| 176 |
+
min_len = min(len(shap_per_feature), len(full_shap))
|
| 177 |
+
full_shap[:min_len] = shap_per_feature[:min_len]
|
| 178 |
+
shap_per_feature = full_shap
|
| 179 |
+
|
| 180 |
+
# Top positive and negative features
|
| 181 |
+
n = 10
|
| 182 |
+
idx_sorted = np.argsort(-np.abs(shap_per_feature))
|
| 183 |
+
top_idx = idx_sorted[:n]
|
| 184 |
+
top_words = feature_names[top_idx].tolist()
|
| 185 |
+
top_contribs = shap_per_feature[top_idx].tolist()
|
| 186 |
+
|
| 187 |
+
# Build word table for display
|
| 188 |
+
word_table = {"Word": top_words, "SHAP Impact": top_contribs}
|
| 189 |
+
|
| 190 |
+
# Build highlight HTML (token-level approx using unigram mapping)
|
| 191 |
+
highlight_html = make_html_highlight(text, feature_names, shap_per_feature, vocab)
|
| 192 |
+
|
| 193 |
+
# Return: label string, probabilities dict, table dict, html highlight
|
| 194 |
+
return f"🎯 **{label}** (confidence: {confidence:.2f})", probs_to_bar(probs.tolist(), language), word_table, highlight_html
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# --------- Gradio UI build ----------
|
| 198 |
+
with gr.Blocks() as demo:
|
| 199 |
+
gr.Markdown("## 🌍 Multilingual Sentiment Analysis (English 🇬🇧 & Persian 🇮🇷) — Interpretable")
|
| 200 |
+
with gr.Row():
|
| 201 |
+
language = gr.Radio(["English", "Persian"], value="English", label="Choose language")
|
| 202 |
+
text_input = gr.Textbox(lines=4, placeholder="Type comment here...", label="Input text")
|
| 203 |
+
with gr.Row():
|
| 204 |
+
btn = gr.Button("Analyze")
|
| 205 |
+
with gr.Row():
|
| 206 |
+
pred_out = gr.Markdown()
|
| 207 |
+
with gr.Row():
|
| 208 |
+
bar = gr.BarPlot(label="Class probabilities")
|
| 209 |
+
table = gr.Dataframe(headers=["Word", "SHAP Impact"], label="Top contributing words")
|
| 210 |
+
with gr.Row():
|
| 211 |
+
html_out = gr.HTML(label="Word-level Highlight (red = pushes toward prediction, blue = pushes away)")
|
| 212 |
+
|
| 213 |
+
def run(text, lang):
|
| 214 |
+
label, probs, word_table, html_highlight = explain_and_predict(text, lang)
|
| 215 |
+
# format outputs for gradio
|
| 216 |
+
return label, probs, pd.DataFrame(word_table), html_highlight
|
| 217 |
+
|
| 218 |
+
btn.click(fn=run, inputs=[text_input, language], outputs=[pred_out, bar, table, html_out])
|
| 219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
if __name__ == "__main__":
|
| 221 |
+
demo.launch(server_name="0.0.0.0", share=True)
|
|
|