Spaces:
Runtime error
Runtime error
File size: 5,238 Bytes
00a6112 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import html
import io

import gradio as gr
import lime
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from datasets import load_dataset
from lime.lime_text import LimeTextExplainer
from PIL import Image
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
# ---------------------------------------------------------------------------
# One-time setup, run at import: fetch the IMDB reviews, fit a TF-IDF +
# logistic-regression sentiment classifier, and build the two explainers
# that explain_text() uses for every request.
# ---------------------------------------------------------------------------

# IMDB movie reviews from the Hugging Face hub; each split has 'text' and 'label' columns.
dataset = load_dataset('imdb')

# Materialise the review strings and their labels of each split as plain lists.
text_train = list(dataset['train']['text'])
y_train = list(dataset['train']['label'])
text_test = list(dataset['test']['text'])
y_test = list(dataset['test']['label'])

# TF-IDF features over at most 5000 terms, English stop words removed.
# The vocabulary/IDF are learned on the full training split before it is divided below.
vectorizer = TfidfVectorizer(max_features=5000, stop_words='english')
X_train = vectorizer.fit_transform(text_train)
X_test = vectorizer.transform(text_test)

# Hold out 20% of the training rows; the classifier is fitted on the remaining 80%.
X_train, X_val, y_train, y_val = train_test_split(
    X_train,
    y_train,
    random_state=42,
    test_size=0.2,
)

# fit() returns the estimator itself, so construction and training chain together.
model = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# LIME works on raw text; class index 0 -> 'Negative', 1 -> 'Positive'.
lime_explainer = LimeTextExplainer(class_names=['Negative', 'Positive'])

# SHAP's linear explainer uses the (reduced) training matrix as its background data.
shap_explainer = shap.LinearExplainer(model, X_train)
def highlight_text_shap(text, word_importances, feature_names, max_num_features):
    """Return *text* as HTML with its most SHAP-influential words coloured.

    Only vocabulary features that occur as whole words of *text* (after
    stripping non-alphanumerics and lowercasing) are considered. The
    *max_num_features* of them with the largest absolute importance are
    wrapped in a red (importance > 0) or blue (importance <= 0) span.
    Every word is HTML-escaped because *text* is user input that ends up in
    a ``gr.HTML`` component.

    Args:
        text: Raw input text.
        word_importances: Importances indexed in parallel with *feature_names*.
        feature_names: Vocabulary terms, e.g. ``vectorizer.get_feature_names_out()``.
        max_num_features: Maximum number of distinct words to highlight.

    Returns:
        The escaped words joined by single spaces, with highlighted words in spans.
    """
    words = text.split()
    # Same normalisation used for lookup below: keep alphanumerics, lowercase.
    cleaned_words = [''.join(filter(str.isalnum, word)).lower() for word in words]
    present = set(cleaned_words)

    # Whole-word membership, not substring search: a substring test would let
    # e.g. "art" (found inside "start") take a top-k slot it can never highlight.
    candidates = [
        (feature, word_importances[idx])
        for idx, feature in enumerate(feature_names)
        if feature in present
    ]
    candidates.sort(key=lambda item: abs(item[1]), reverse=True)
    top_words = dict(candidates[:max_num_features])

    pieces = []
    for word, key in zip(words, cleaned_words):
        safe_word = html.escape(word)
        if key in top_words:
            color = 'red' if top_words[key] > 0 else 'blue'
            pieces.append(f'<span style="color:{color}">{safe_word}</span>')
        else:
            pieces.append(safe_word)
    return ' '.join(pieces)


def explain_text(input_text):
    """Classify *input_text* and explain the prediction with LIME and SHAP.

    Relies on the module-level ``vectorizer``, ``model``, ``lime_explainer``
    and ``shap_explainer`` and on ``fig_to_nparray``.

    Returns:
        A 5-tuple matching the Gradio outputs:
        - the predicted label name ('Positive' or 'Negative');
        - the LIME bar-chart image as a NumPy array;
        - the SHAP bar-chart image as a NumPy array;
        - the LIME HTML explanation, wrapped in an iframe;
        - the input text as HTML with its top SHAP words coloured.

        Input with no alphanumeric characters yields ``('', None, None, '', '')``,
        since LIME's ``explain_instance`` cannot perturb a document with no words.
    """
    # Number of features shown by LIME, the SHAP chart and the highlighter.
    max_num_features = 10

    if not input_text or not any(ch.isalnum() for ch in input_text):
        return '', None, None, '', ''

    # Predict label
    input_vector = vectorizer.transform([input_text])
    predicted_label = model.predict(input_vector)[0]
    label_name = 'Positive' if predicted_label == 1 else 'Negative'

    # LIME explanation: LIME perturbs the raw text, so vectorise inside the callback.
    def predict_proba_for_lime(texts):
        return model.predict_proba(vectorizer.transform(texts))

    lime_exp = lime_explainer.explain_instance(
        input_text, predict_proba_for_lime, num_features=max_num_features
    )
    lime_fig = lime_exp.as_pyplot_figure()
    try:
        lime_img = fig_to_nparray(lime_fig)
    finally:
        # pyplot keeps every figure alive until closed; without this each request leaks one.
        plt.close(lime_fig)

    # as_html() draws its chart with embedded <script> tags. Browsers do not run
    # scripts inserted via innerHTML, so render the document inside an iframe instead.
    lime_html = (
        f'<iframe srcdoc="{html.escape(lime_exp.as_html())}" '
        'style="width:100%;height:500px;border:none;"></iframe>'
    )

    # SHAP explanation; ravel() guards against a (1, n_features) row being returned.
    shap_values = np.ravel(shap_explainer.shap_values(input_vector)[0])
    feature_names = vectorizer.get_feature_names_out()

    # Rank by magnitude so strong negative contributions are shown too (colours encode sign).
    shap_df = (
        pd.DataFrame({'Feature': feature_names, 'SHAP Value': shap_values})
        .sort_values(by='SHAP Value', key=np.abs, ascending=False)
        .head(max_num_features)
    )

    # Use an explicit figure rather than global pyplot state (plt.gcf), and always close it.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        colors = ['red' if val > 0 else 'blue' for val in shap_df['SHAP Value']]
        ax.barh(shap_df['Feature'], shap_df['SHAP Value'], color=colors)
        ax.invert_yaxis()  # barh draws the first row at the bottom; put the largest on top
        ax.set_xlabel('SHAP Value')
        ax.set_title(f'Top {max_num_features} Feature Importance')
        fig.tight_layout()
        shap_fig = fig_to_nparray(fig)
    finally:
        plt.close(fig)

    # Highlight the text based on SHAP values
    shap_highlighted_text = highlight_text_shap(
        input_text, shap_values, feature_names, max_num_features
    )
    return label_name, lime_img, shap_fig, lime_html, shap_highlighted_text
def fig_to_nparray(fig):
    """Render a Matplotlib figure to PNG and return its pixels as a NumPy array.

    The in-memory PNG buffer and the decoded PIL image are released before
    returning. The figure itself is left open; callers that no longer need
    it should close it (e.g. ``plt.close(fig)``).

    Args:
        fig: Object with a Matplotlib-style ``savefig(file, format=...)`` method.

    Returns:
        ``np.ndarray`` of the decoded PNG pixels. For Matplotlib PNGs this is
        presumably ``(height, width, 4)`` RGBA; confirm if a caller depends on it.
    """
    with io.BytesIO() as buf:
        fig.savefig(buf, format='png')
        buf.seek(0)
        # np.array() forces PIL to decode the pixels before the image/buffer are closed.
        with Image.open(buf) as img:
            return np.array(img)
# Web UI: one free-text input feeding explain_text, whose five return values
# map positionally onto the five output components listed below.
iface = gr.Interface(
    title="LIME and SHAP Explanations",
    description="Enter a text sample to see its prediction and explanations using LIME and SHAP.",
    fn=explain_text,
    inputs=gr.Textbox(placeholder="Enter your text here...", lines=2),
    outputs=[
        gr.Label(label="Predicted Label"),                    # 'Positive' / 'Negative'
        gr.Image(label="LIME Explanation", type="numpy"),     # LIME bar chart
        gr.Image(label="SHAP Explanation", type="numpy"),     # SHAP bar chart
        gr.HTML(label="LIME Highlighted Text Explanation"),   # LIME HTML report
        gr.HTML(label="SHAP Highlighted Text Explanation"),   # text with SHAP colouring
    ],
)

# Start serving the app.
iface.launch()