jeevitha-app commited on
Commit
6e629f4
·
verified ·
1 Parent(s): 0426dfa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +204 -91
app.py CHANGED
@@ -1,108 +1,221 @@
 
 
1
 
2
-
3
- import gradio as gr
4
  import joblib
5
- import shap
6
  import numpy as np
 
 
 
 
 
 
 
 
7
 
8
- # ------------------------------------------------------------
9
- # 1️⃣ Load pre-trained models and vectorizers
10
- # ------------------------------------------------------------
11
- eng_model = joblib.load("best_model.pkl")
12
- eng_vectorizer = joblib.load("tfidf_vectorizer.pkl")
13
-
14
- per_model = joblib.load("logistic_regression.pkl")
15
- per_vectorizer = joblib.load("tfidf_vectorizer_persian.pkl")
16
-
17
- # ------------------------------------------------------------
18
- # 2️⃣ Define class labels
19
- # ------------------------------------------------------------
20
- class_names = ["Negative", "Neutral", "Positive"]
21
-
22
- # ------------------------------------------------------------
23
- # 3️⃣ Prediction Function
24
- # ------------------------------------------------------------
25
- def predict_sentiment(text, language):
26
- if not text.strip():
27
- return "⚠️ Please enter some text!", None, None
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  if language == "English":
30
  model = eng_model
31
  vectorizer = eng_vectorizer
 
32
  else:
33
  model = per_model
34
  vectorizer = per_vectorizer
 
 
 
 
35
 
36
- # Vectorize input
37
  vec = vectorizer.transform([text])
38
  probs = model.predict_proba(vec)[0]
39
- pred_class = np.argmax(probs)
40
  label = class_names[pred_class]
41
- confidence = probs[pred_class]
42
-
43
- # Interpret top words with SHAP
44
- explainer = shap.Explainer(model, vectorizer.transform(["sample"]))
45
- feature_names = vectorizer.get_feature_names_out()
46
- shap_values = explainer(vec)
47
-
48
- # top contributing words
49
- shap_vals = shap_values[0].values[:, pred_class]
50
- top_indices = np.argsort(-np.abs(shap_vals))[:10]
51
- top_words = [feature_names[i] for i in top_indices]
52
- top_contribs = shap_vals[top_indices]
53
-
54
- interpretation = {
55
- "words": top_words,
56
- "contributions": top_contribs.tolist()
57
- }
58
-
59
- return f"🎯 **{label}** (confidence: {confidence:.2f})", probs.tolist(), interpretation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
-
62
- # ------------------------------------------------------------
63
- # 4️⃣ Gradio Interface
64
- # ------------------------------------------------------------
65
- def gradio_ui(text, language):
66
- pred, probs, interp = predict_sentiment(text, language)
67
-
68
- if not probs:
69
- return pred, None, None
70
-
71
- # Confidence Bar Plot
72
- bar_plot = {cls: float(p) for cls, p in zip(class_names, probs)}
73
-
74
- # Word Contribution Table
75
- word_table = None
76
- if interp:
77
- word_table = {
78
- "Word": interp["words"],
79
- "SHAP Impact": interp["contributions"]
80
- }
81
-
82
- return pred, bar_plot, word_table
83
-
84
-
85
- # ------------------------------------------------------------
86
- # 5️⃣ Gradio Layout
87
- # ------------------------------------------------------------
88
- with gr.Blocks(theme=gr.themes.Soft()) as interface:
89
- gr.Markdown("
90
- Select language, enter text, and view predictions with interpretable SHAP insights!")
91
-
92
- lang = gr.Radio(["English", "Persian"], label="Select Dataset", value="English")
93
- text = gr.Textbox(label="Enter your text here", placeholder="Type an English or Persian comment...")
94
-
95
- output_pred = gr.Markdown(label="Prediction")
96
- output_bar = gr.BarPlot(label="Confidence per Class")
97
- output_table = gr.Dataframe(label="Top Influential Words", headers=["Word", "SHAP Impact"])
98
-
99
- btn = gr.Button("🔍 Analyze Sentiment")
100
-
101
- btn.click(fn=gradio_ui, inputs=[text, lang], outputs=[output_pred, output_bar, output_table])
102
-
103
- # ------------------------------------------------------------
104
- # 6️⃣ Launch App
105
- # ------------------------------------------------------------
106
  if __name__ == "__main__":
107
- interface.launch(share=True)
108
-
 
1
+ # app.py
2
+ # Gradio app: English + Persian sentiment with SHAP-based interpretability and word highlighting
3
 
 
 
4
  import joblib
 
5
  import numpy as np
6
+ import pandas as pd
7
+ import shap
8
+ import matplotlib.pyplot as plt
9
+ import io
10
+ import base64
11
+ import html
12
+ from typing import Tuple, Dict, List
13
+ import math
14
 
15
+ import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ # --------- Load models (replace filenames if you used different names) ----------
18
+ ENG_MODEL_PATH = "models/english_lr_model.pkl"
19
+ ENG_VEC_PATH = "models/english_vectorizer.pkl"
20
+ PER_MODEL_PATH = "models/persian_lr_model.pkl"
21
+ PER_VEC_PATH = "models/persian_vectorizer.pkl"
22
+
23
+ eng_model = joblib.load(ENG_MODEL_PATH)
24
+ eng_vectorizer = joblib.load(ENG_VEC_PATH)
25
+
26
+ per_model = joblib.load(PER_MODEL_PATH)
27
+ per_vectorizer = joblib.load(PER_VEC_PATH)
28
+
29
+ CLASS_NAMES_EN = ["Negative", "Neutral", "Positive"]
30
+ CLASS_NAMES_PER = ["منفی", "خنثی", "مثبت"]
31
+
32
+ # --------- Utility: create bar data for gradio BarPlot ----------
33
+ def probs_to_bar(probs: List[float], lang: str):
34
+ names = CLASS_NAMES_EN if lang == "English" else CLASS_NAMES_PER
35
+ return {names[i]: float(probs[i]) for i in range(len(probs))}
36
+
37
+ # --------- Utility: create HTML highlight from SHAP values ----------
38
+ def make_html_highlight(original_text: str,
39
+ feature_names: np.ndarray,
40
+ shap_values_feature: np.ndarray,
41
+ vectorizer_vocab: dict,
42
+ max_display: int = 30) -> str:
43
+ """
44
+ Simple token-level highlighting:
45
+ - Tokenize by whitespace (preserves original punctuation).
46
+ - For each token, attempt to map token.lower() to the vectorizer vocab;
47
+ if found, get SHAP impact for that feature name.
48
+ - Color red for positive contribution, blue for negative.
49
+ Returns an HTML-safe string.
50
+ """
51
+ # Build mapping word -> shap value if present in vocabulary
52
+ # vectorizer_vocab maps token -> idx in feature_names
53
+ token_to_shap = {}
54
+ for idx, fname in enumerate(feature_names):
55
+ # Often fname is the token/ngram itself
56
+ token_to_shap[fname] = shap_values_feature[idx]
57
+
58
+ # Tokenize (simple)
59
+ tokens = original_text.split()
60
+ # Compute max magnitude for scaling opacity
61
+ mags = []
62
+ for t in tokens:
63
+ key = t.lower()
64
+ val = None
65
+ # Try several common variants: exact, lower, strip punctuation from ends
66
+ if key in vectorizer_vocab:
67
+ val = shap_values_feature[vectorizer_vocab[key]]
68
+ else:
69
+ key2 = ''.join(ch for ch in key if ch.isalnum())
70
+ if key2 in vectorizer_vocab:
71
+ val = shap_values_feature[vectorizer_vocab[key2]]
72
+ mags.append(abs(val) if val is not None else 0.0)
73
+ max_mag = max(mags) if mags else 1.0
74
+ if max_mag == 0:
75
+ max_mag = 1.0
76
+
77
+ # Build HTML with span coloring
78
+ html_tokens = []
79
+ for t in tokens:
80
+ display = html.escape(t)
81
+ key = t.lower()
82
+ val = None
83
+ if key in vectorizer_vocab:
84
+ val = shap_values_feature[vectorizer_vocab[key]]
85
+ else:
86
+ key2 = ''.join(ch for ch in key if ch.isalnum())
87
+ if key2 in vectorizer_vocab:
88
+ val = shap_values_feature[vectorizer_vocab[key2]]
89
+ if val is None or abs(val) < 1e-6:
90
+ html_tokens.append(f"<span style='padding:2px'>{display}</span>")
91
+ else:
92
+ sign = "pos" if val > 0 else "neg"
93
+ mag = min(1.0, abs(val) / max_mag) # scale 0..1
94
+ opacity = 0.15 + 0.85 * mag # avoid fully transparent
95
+ color = f"rgba(220,20,60,{opacity})" if sign == "pos" else f"rgba(30,144,255,{opacity})"
96
+ border = "1px solid rgba(0,0,0,0.04)"
97
+ html_tokens.append(
98
+ f"<span style='background:{color};padding:2px;margin:1px;border-radius:4px;display:inline-block;{border}'>"
99
+ f"{display}</span>"
100
+ )
101
+
102
+ highlighted_html = "<div style='line-height:1.6;font-size:16px'>" + " ".join(html_tokens) + "</div>"
103
+ return highlighted_html
104
+
105
+ # --------- Core function: predict + interpret ----------
106
+ def explain_and_predict(text: str, language: str):
107
+ text = text or ""
108
  if language == "English":
109
  model = eng_model
110
  vectorizer = eng_vectorizer
111
+ class_names = CLASS_NAMES_EN
112
  else:
113
  model = per_model
114
  vectorizer = per_vectorizer
115
+ class_names = CLASS_NAMES_PER
116
+
117
+ if text.strip() == "":
118
+ return "⚠️ Please enter text.", {}, {"Word": [], "SHAP Impact": []}, "<i>No input</i>"
119
 
120
+ # vectorize
121
  vec = vectorizer.transform([text])
122
  probs = model.predict_proba(vec)[0]
123
+ pred_class = int(np.argmax(probs))
124
  label = class_names[pred_class]
125
+ confidence = float(probs[pred_class])
126
+
127
+ # Build SHAP explainer on a small background (use small subset via dummy background)
128
+ # NOTE: building explainer can be slow; in Spaces you can build once at import
129
+ # For robustness we build a simple LinearExplainer on vector space
130
+ # Use small dense sample from training if available - here use vectorizer vocabulary size fallback
131
+ # Convert to dense for LinearExplainer
132
+ try:
133
+ # Use a small background of zeros (cheap) — LinearExplainer can accept arrays
134
+ background = np.zeros((1, vec.shape[1]))
135
+ explainer = shap.LinearExplainer(model, background, feature_names=vectorizer.get_feature_names_out())
136
+ # compute shap on the numeric vector
137
+ vec_dense = vec.toarray()
138
+ shap_vals = explainer(vec_dense) # returns shap.Explanation
139
+ except Exception:
140
+ # fallback: use PermutationExplainer on numeric input (slower)
141
+ explainer = shap.Explainer(model.predict_proba, vec)
142
+ shap_vals = explainer(vec)
143
+
144
+ # shap_vals.values shape: (n_outputs, n_features) OR Explanation with values (n_features, n_classes)
145
+ # Normalize to feature vector for chosen class
146
+ # shap_vals may be multi-output: shap_vals.values => (n_samples, n_features, n_classes) or similar
147
+ try:
148
+ # preferred shape: shap_vals.values -> (1, n_features, n_classes)
149
+ values = shap_vals.values # ND array
150
+ if values.ndim == 3:
151
+ # pick sample 0, class pred_class
152
+ shap_per_feature = values[0, :, pred_class]
153
+ elif values.ndim == 2:
154
+ # shape (n_samples, n_features) for single class models — take sample 0
155
+ shap_per_feature = values[0, :]
156
+ else:
157
+ # try to flatten
158
+ shap_per_feature = np.ravel(values)[0:vec.shape[1]]
159
+ except Exception:
160
+ # Last resort: try shap_vals[0].values
161
+ try:
162
+ shap_per_feature = shap_vals[0].values[:, pred_class]
163
+ except Exception:
164
+ shap_per_feature = np.zeros(vec.shape[1])
165
+
166
+ # Feature names & vocab
167
+ feature_names = np.array(vectorizer.get_feature_names_out())
168
+ vocab = {k: v for k, v in (getattr(vectorizer, "vocabulary_", {})).items()}
169
+
170
+ # Build top contributing words list (pairs)
171
+ # shap_per_feature length must match len(feature_names)
172
+ if len(shap_per_feature) != len(feature_names):
173
+ # try to align by vectorizer.vocabulary_
174
+ full_shap = np.zeros(len(feature_names))
175
+ # if shap_per_feature smaller, attempt to use indices from vocab
176
+ min_len = min(len(shap_per_feature), len(full_shap))
177
+ full_shap[:min_len] = shap_per_feature[:min_len]
178
+ shap_per_feature = full_shap
179
+
180
+ # Top positive and negative features
181
+ n = 10
182
+ idx_sorted = np.argsort(-np.abs(shap_per_feature))
183
+ top_idx = idx_sorted[:n]
184
+ top_words = feature_names[top_idx].tolist()
185
+ top_contribs = shap_per_feature[top_idx].tolist()
186
+
187
+ # Build word table for display
188
+ word_table = {"Word": top_words, "SHAP Impact": top_contribs}
189
+
190
+ # Build highlight HTML (token-level approx using unigram mapping)
191
+ highlight_html = make_html_highlight(text, feature_names, shap_per_feature, vocab)
192
+
193
+ # Return: label string, probabilities dict, table dict, html highlight
194
+ return f"🎯 **{label}** (confidence: {confidence:.2f})", probs_to_bar(probs.tolist(), language), word_table, highlight_html
195
+
196
+
197
+ # --------- Gradio UI build ----------
198
+ with gr.Blocks() as demo:
199
+ gr.Markdown("## 🌍 Multilingual Sentiment Analysis (English 🇬🇧 & Persian 🇮🇷) — Interpretable")
200
+ with gr.Row():
201
+ language = gr.Radio(["English", "Persian"], value="English", label="Choose language")
202
+ text_input = gr.Textbox(lines=4, placeholder="Type comment here...", label="Input text")
203
+ with gr.Row():
204
+ btn = gr.Button("Analyze")
205
+ with gr.Row():
206
+ pred_out = gr.Markdown()
207
+ with gr.Row():
208
+ bar = gr.BarPlot(label="Class probabilities")
209
+ table = gr.Dataframe(headers=["Word", "SHAP Impact"], label="Top contributing words")
210
+ with gr.Row():
211
+ html_out = gr.HTML(label="Word-level Highlight (red = pushes toward prediction, blue = pushes away)")
212
+
213
+ def run(text, lang):
214
+ label, probs, word_table, html_highlight = explain_and_predict(text, lang)
215
+ # format outputs for gradio
216
+ return label, probs, pd.DataFrame(word_table), html_highlight
217
+
218
+ btn.click(fn=run, inputs=[text_input, language], outputs=[pred_out, bar, table, html_out])
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  if __name__ == "__main__":
221
+ demo.launch(server_name="0.0.0.0", share=True)