Explainability_Sandbox / utils /visualization.py
dyra1222's picture
new added changes
fb007f1
# utils/visualization.py (updated for maximum contrast)
import base64
import html
from io import BytesIO

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
# Background colours per attribution sign as (intensity threshold, colour),
# strongest first.  A token gets the first colour whose threshold its
# intensity exceeds, otherwise the fallback colour.
_NEGATIVE_TIERS = ((0.8, "#cc0000"), (0.6, "#ff4444"), (0.4, "#ff8888"))
_POSITIVE_TIERS = ((0.8, "#0000cc"), (0.6, "#4444ff"), (0.4, "#8888ff"))
_NEGATIVE_FALLBACK = "#ffcccc"
_POSITIVE_FALLBACK = "#ccccff"

# Inline style shared by every token that has no attribution value.
_NEUTRAL_SPAN_STYLE = (
    "margin: 2px; padding: 4px 6px; display: inline-block; "
    "background-color: #f0f0f0; border: 1px solid #ccc; border-radius: 4px;"
)


def _collect_token_values(explanation, explainer_type):
    """Flatten an explainer's output into a ``{lowercased token: value}`` dict."""
    token_values = {}
    if not explanation:
        return token_values

    if explainer_type == "LIME":
        # LIME: iterable of (feature, weight); a feature may span several
        # whitespace-separated words, so its weight is shared equally.
        for feature, weight in explanation:
            feature_tokens = feature.split()
            for raw in feature_tokens:
                clean = raw.strip('.,!?;:"()[]{}').lower()
                if clean:
                    token_values[clean] = weight / len(feature_tokens)
    elif explainer_type in ("SHAP", "Captum"):
        # SHAP / Captum: iterable of dicts carrying 'token' and 'value'.
        for item in explanation:
            if isinstance(item, dict) and 'token' in item and 'value' in item:
                token_values[item['token'].lower()] = item['value']
    return token_values


def _token_colors(value, norm_value):
    """Return ``(background, text_color)`` for one attributed token."""
    if value < 0:
        tiers, background = _NEGATIVE_TIERS, _NEGATIVE_FALLBACK
    else:
        tiers, background = _POSITIVE_TIERS, _POSITIVE_FALLBACK

    # The 0.6 floor keeps even small attributions clearly visible.
    intensity = min(1.0, 0.6 + 0.4 * abs(norm_value))
    for threshold, color in tiers:
        if intensity > threshold:
            background = color
            break
    text_color = "white" if intensity > 0.5 else "black"
    return background, text_color


def _neutral_span(token):
    """Render one token without attribution colouring (text is HTML-escaped)."""
    shown = html.escape(token.replace("##", ""))
    return f'<span style="{_NEUTRAL_SPAN_STYLE}">{shown}</span> '


def create_visualization(text, explanation, tokenizer, explainer_type):
    """Create an HTML visualization of token attributions with high contrast.

    Args:
        text: Input text; it is split with ``tokenizer.tokenize(text)``.
        explanation: Explainer output. For ``"LIME"`` an iterable of
            ``(feature, weight)`` pairs; for ``"SHAP"`` / ``"Captum"`` an
            iterable of dicts with ``'token'`` and ``'value'`` keys. Any
            other explainer type, or a falsy value, yields the neutral view.
        tokenizer: Object exposing ``tokenize(text)`` that returns string
            tokens. A ``"##"`` marker inside a token is removed before
            matching and display.
        explainer_type: ``"LIME"``, ``"SHAP"`` or ``"Captum"``.

    Returns:
        An HTML fragment (``str``). Tokens are matched case-insensitively
        against the explanation; negative values are drawn in red, other
        values in blue, unmatched tokens in grey. All token text and error
        text is HTML-escaped. On any exception the error is printed and a
        red error ``<div>`` is returned instead of raising.
    """
    try:
        tokens = tokenizer.tokenize(text)
        token_values = _collect_token_values(explanation, explainer_type)

        # No usable explanation: show the bare tokenization.
        if not token_values:
            parts = [
                '\n<div style="font-family: monospace; line-height: 2; padding: 15px;\n'
                'border-radius: 5px; background-color: #f9f9f9;\n'
                'border: 1px solid #ddd; margin: 10px 0; color: #666;">\n'
                '<i>Explanation data not available. Showing tokenized text.</i><br>\n'
            ]
            parts.extend(_neutral_span(token) for token in tokens)
            parts.append('</div>')
            return "".join(parts)

        # Scale by the largest magnitude so values land in [-1, 1].
        max_abs_value = max(abs(v) for v in token_values.values())
        if max_abs_value > 0:
            normalized_values = {k: v / max_abs_value for k, v in token_values.items()}
        else:
            normalized_values = dict.fromkeys(token_values, 0)

        parts = [
            '\n<div style="font-family: monospace; line-height: 2; padding: 15px;\n'
            'border-radius: 5px; background-color: #f9f9f9;\n'
            'border: 1px solid #ddd; margin: 10px 0;">\n'
        ]
        for token in tokens:
            key = token.replace('##', '').lower()
            if key not in normalized_values:
                parts.append(_neutral_span(token))
                continue
            color, text_color = _token_colors(token_values[key], normalized_values[key])
            shown = html.escape(token.replace("##", ""))
            parts.append(
                f'<span style="background-color: {color}; color: {text_color}; '
                f'border: 2px solid {color}; margin: 2px; padding: 4px 6px; '
                f'border-radius: 4px; display: inline-block; font-weight: bold;">'
                f'{shown}</span> '
            )
        parts.append('</div>')

        # Legend swatches come from the same tier tables as the tokens.
        # Because intensity never drops below 0.6, "weak" tokens are drawn
        # with the second tier, so that is the colour the legend shows.
        legend_entries = (
            (_NEGATIVE_TIERS[0][1], "Strong negative", "margin-right: 5px;"),
            (_NEGATIVE_TIERS[1][1], "Weak negative", "margin-right: 5px;"),
            (_POSITIVE_TIERS[0][1], "Strong positive", "margin-right: 5px;"),
            (_POSITIVE_TIERS[1][1], "Weak positive", ""),
        )
        parts.append('\n<div style="margin-top: 10px; font-size: 12px; color: #666;">\n')
        for swatch, label, spacing in legend_entries:
            style = (
                f"background-color: {swatch}; color: white; "
                f"padding: 2px 6px; border-radius: 3px; {spacing}"
            ).strip()
            parts.append(f'<span style="{style}">{label}</span>\n')
        parts.append('</div>\n')
        return "".join(parts)
    except Exception as e:
        print(f"Visualization error: {e}")
        return (
            '<div style="color: red; padding: 10px;">'
            f'Error creating visualization: {html.escape(str(e))}</div>'
        )
def create_attribution_plot(explanation, method_name):
    """Create a horizontal bar chart of attributions as an inline HTML image.

    Args:
        explanation: For ``method_name == "LIME"`` a sequence of
            ``(feature, weight)`` pairs; for any other method a sequence of
            dicts with ``'token'`` and ``'value'`` keys. Only the first 15
            usable entries are plotted, in the order given.
        method_name: Name of the explainer; shown in the chart title.

    Returns:
        An ``<img>`` tag embedding the chart as a base64 PNG, a short
        ``<p>`` message when there is nothing to plot, or a red error
        ``<div>`` (with the message HTML-escaped) if plotting fails.
        The function does not raise; errors are printed.
    """
    max_bars = 15
    fig = None
    try:
        if not explanation:
            return "<p>No explanation data available</p>"

        if method_name == "LIME":
            pairs = [(item[0], item[1]) for item in explanation]
            title = f'Top Feature Attributions ({method_name})'
        else:
            # Require BOTH keys on the same item so that labels and scores
            # stay aligned even when some entries are incomplete.
            pairs = [
                (item['token'], item['value'])
                for item in explanation
                if isinstance(item, dict) and 'token' in item and 'value' in item
            ]
            title = f'Top Token Attributions ({method_name})'

        pairs = pairs[:max_bars]
        if not pairs:
            return "<p>No valid explanation data available for plotting</p>"

        features = [label for label, _ in pairs]
        scores = [score for _, score in pairs]

        fig, ax = plt.subplots(figsize=(12, 6))

        # Red bars for negative scores, teal for the rest.
        colors = ['#ff6b6b' if score < 0 else '#4ecdc4' for score in scores]
        y_pos = np.arange(len(features))
        bars = ax.barh(y_pos, scores, color=colors, alpha=0.8,
                       edgecolor='black', linewidth=0.5)

        ax.set_yticks(y_pos)
        ax.set_yticklabels(features, fontsize=10)
        ax.set_xlabel('Attribution Score', fontsize=12, fontweight='bold')
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.axvline(x=0, color='black', linestyle='-', alpha=0.5, linewidth=1)
        ax.grid(True, alpha=0.3, axis='x')

        # Nudge each value label 1% of the extreme score away from its bar.
        positive_offset = 0.01 * max(scores)
        negative_offset = 0.01 * min(scores)
        for bar, score in zip(bars, scores):
            width = bar.get_width()
            points_right = width >= 0
            offset = positive_offset if points_right else negative_offset
            ax.text(width + offset, bar.get_y() + bar.get_height() / 2,
                    f'{score:.4f}', ha='left' if points_right else 'right',
                    va='center', fontsize=9, fontweight='bold')

        ax.set_facecolor('#f8f9fa')
        fig.patch.set_facecolor('#f8f9fa')
        fig.tight_layout()

        # Render through the figure itself rather than pyplot's global
        # "current figure", which another caller could have changed.
        buf = BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight',
                    facecolor=fig.get_facecolor())
        img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
        return f'<img src="data:image/png;base64,{img_str}" style="max-width: 100%; border: 1px solid #ddd; border-radius: 5px;">'
    except Exception as e:
        print(f"Plot error: {e}")
        return f'<div style="color: red; padding: 10px;">Error creating plot: {html.escape(str(e))}</div>'
    finally:
        # Always release the figure, including when plotting raised.
        if fig is not None:
            plt.close(fig)
# utils/visualization.py (add this function)
def create_confidence_chart(probabilities, class_names=None):
    """Create a bar chart of class probabilities as an inline HTML image.

    Args:
        probabilities: Sized sequence of per-class probabilities; the
            y-axis is fixed to the range 0..1.
        class_names: Optional tick labels, one per probability. Defaults
            to ``"Class 0"``, ``"Class 1"``, ...

    Returns:
        An ``<img>`` tag embedding the chart as a base64 PNG, a short
        ``<p>`` message when ``probabilities`` is empty, or a red error
        ``<div>`` (with the message HTML-escaped) if charting fails.
        The function does not raise; errors are printed.
    """
    palette = ['#ff6b6b', '#4ecdc4', '#45b7af', '#556270']
    fig = None
    try:
        n_classes = len(probabilities)
        if n_classes == 0:
            return "<p>No probability data available</p>"
        if class_names is None:
            class_names = [f"Class {i}" for i in range(n_classes)]

        fig, ax = plt.subplots(figsize=(10, 6))
        positions = range(n_classes)
        bars = ax.bar(positions, probabilities,
                      color=palette[:n_classes],
                      alpha=0.8, edgecolor='black', linewidth=1)

        ax.set_xlabel('Classes', fontsize=12, fontweight='bold')
        ax.set_ylabel('Probability', fontsize=12, fontweight='bold')
        ax.set_title('Class Probability Distribution', fontsize=14, fontweight='bold')
        ax.set_xticks(positions)
        ax.set_xticklabels(class_names, rotation=45, ha='right')
        ax.set_ylim(0, 1)

        # Print each probability just above its bar.
        for bar, prob in zip(bars, probabilities):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                    f'{prob:.3f}', ha='center', va='bottom', fontweight='bold')

        ax.grid(True, alpha=0.3, axis='y')
        ax.set_facecolor('#f8f9fa')
        fig.patch.set_facecolor('#f8f9fa')
        fig.tight_layout()

        # Render through the figure itself rather than pyplot's global
        # "current figure", which another caller could have changed.
        buf = BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight',
                    facecolor=fig.get_facecolor())
        img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
        return f'<img src="data:image/png;base64,{img_str}" style="max-width: 100%; border: 1px solid #ddd; border-radius: 5px;">'
    except Exception as e:
        print(f"Confidence chart error: {e}")
        return f'<div style="color: red; padding: 10px;">Error creating confidence chart: {html.escape(str(e))}</div>'
    finally:
        # Always release the figure, including when charting raised.
        if fig is not None:
            plt.close(fig)