# NOTE(review): the original paste began with "Spaces: / Sleeping / Sleeping" —
# residue from a scraped Hugging Face Spaces page header, not program content.
"""
SHAP Text Explainer — Word-level attribution for text classification
Course: 215 AI Safety ch8
"""
import numpy as np
import torch
import gradio as gr
from transformers import pipeline

# Load the sentiment model once at import time.
# top_k=None is the documented replacement for the deprecated
# return_all_scores=True: the pipeline returns a score dict for EVERY
# label instead of only the top prediction.
classifier = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    top_k=None,
)

# Label set of the SST-2 fine-tuned DistilBERT head.
LABEL_NAMES = ["NEGATIVE", "POSITIVE"]
def simple_word_attribution(text: str):
    """
    Compute word-level attribution using the leave-one-out (LOO) method.

    Each word is removed in turn and the text is re-classified; the drop in
    the predicted class's probability is that word's attribution. Faster and
    more reliable than full SHAP on CPU.

    Args:
        text: Sentence to explain (whitespace-tokenized).

    Returns:
        Tuple of (highlighted HTML + legend, markdown details, label->score
        dict). All three are empty ("", "", {}) for blank input.
    """
    if not text.strip():
        return "", "", {}

    # Baseline prediction on the full text.
    base_result = classifier(text)[0]
    base_scores = {r["label"]: r["score"] for r in base_result}
    pred_label = max(base_scores, key=base_scores.get)
    pred_score = base_scores[pred_label]

    words = text.split()
    if len(words) == 0:
        return "", "", {}

    # LOO attribution: re-score the text with each word removed.
    attributions = []
    for i in range(len(words)):
        masked = " ".join(words[:i] + words[i + 1 :])
        if not masked.strip():
            # Removing the only word leaves nothing to classify.
            attributions.append(0.0)
            continue
        result = classifier(masked)[0]
        masked_scores = {r["label"]: r["score"] for r in result}
        # Attribution = how much removing this word changes the predicted class score
        diff = base_scores[pred_label] - masked_scores[pred_label]
        attributions.append(diff)

    # Normalize to [-1, 1] for display; guard against an all-zero vector.
    max_abs = max(abs(a) for a in attributions) if attributions else 1.0
    if max_abs == 0:
        max_abs = 1.0

    # Build highlighted HTML. Alpha channel encodes strength; the original
    # also computed an `intensity` value here but never used it (dead code,
    # removed).
    html_parts = []
    for word, attr in zip(words, attributions):
        norm_attr = attr / max_abs  # -1 to 1
        if norm_attr > 0:
            # Pushes toward prediction (red = positive contribution)
            bg = f"rgba(239, 68, 68, {abs(norm_attr) * 0.6})"
        else:
            # Pushes against prediction (blue = negative contribution)
            bg = f"rgba(59, 130, 246, {abs(norm_attr) * 0.6})"
        html_parts.append(
            f'<span style="background:{bg};padding:2px 4px;margin:1px;'
            f'border-radius:3px;display:inline-block;">{word}</span>'
        )
    highlighted_html = (
        '<div style="font-size:16px;line-height:2;padding:10px;">'
        + " ".join(html_parts)
        + "</div>"
    )

    # Legend explaining the red/blue color coding.
    legend = (
        '<div style="margin-top:10px;font-size:13px;">'
        '<span style="background:rgba(239,68,68,0.5);padding:2px 8px;border-radius:3px;">Red</span>'
        f" = pushes toward {pred_label} "
        '<span style="background:rgba(59,130,246,0.5);padding:2px 8px;border-radius:3px;">Blue</span>'
        f" = pushes against {pred_label} "
        "(intensity = strength)"
        "</div>"
    )

    # Prediction summary as a markdown table.
    pred_info = (
        f"**Prediction: {pred_label}** ({pred_score:.1%})\n\n"
        f"| Label | Score |\n|---|---|\n"
    )
    for r in base_result:
        pred_info += f"| {r['label']} | {r['score']:.1%} |\n"

    # Attribution table: top 15 words by absolute impact.
    pred_info += "\n**Word attributions (leave-one-out):**\n\n"
    pred_info += "| Word | Attribution | Effect |\n|---|---|---|\n"
    sorted_attr = sorted(
        zip(words, attributions), key=lambda x: abs(x[1]), reverse=True
    )
    for word, attr in sorted_attr[:15]:
        direction = "supports" if attr > 0 else "opposes"
        pred_info += f"| {word} | {attr:+.4f} | {direction} |\n"

    return highlighted_html + legend, pred_info, base_scores
# --- Gradio UI --------------------------------------------------------------
with gr.Blocks(title="SHAP Text Explainer") as demo:
    gr.Markdown(
        "# SHAP Text Explainer\n"
        "Enter text to see which words contribute most to the sentiment prediction.\n"
        "Uses leave-one-out attribution (similar to SHAP) for word-level explanations.\n"
        "*Course: 215 AI Safety ch8 — Explainability*"
    )

    with gr.Row():
        with gr.Column():
            input_box = gr.Textbox(
                label="Input Text",
                placeholder="Enter a sentence to analyze...",
                lines=3,
            )
            explain_btn = gr.Button("Explain", variant="primary")
            gr.Markdown(
                "*Note: Each word is removed one at a time to measure its impact. "
                "This takes a few seconds for longer texts.*"
            )
        with gr.Column():
            html_out = gr.HTML(label="Word Attribution")
            details_out = gr.Markdown()

    def _explain(text):
        # The UI shows only the HTML and the markdown details; the raw
        # score dict returned third is dropped here.
        html, details, _scores = simple_word_attribution(text)
        return html, details

    explain_btn.click(_explain, [input_box], [html_out, details_out])

    gr.Examples(
        examples=[
            "This movie was absolutely fantastic! The acting was superb and the plot kept me engaged throughout.",
            "The food was terrible and the service was even worse. I will never go back to this restaurant.",
            "The product works okay but nothing special. It does what it says but I expected more for the price.",
            "I love how this book combines beautiful writing with deep philosophical insights.",
            "The flight was delayed by 3 hours and the airline offered no compensation or explanation.",
        ],
        inputs=[input_box],
    )

if __name__ == "__main__":
    demo.launch()