Spaces:
Sleeping
Sleeping
| # utils/visualization.py (updated for maximum contrast) | |
| import matplotlib.pyplot as plt | |
| import matplotlib.colors as mcolors | |
| import base64 | |
| from io import BytesIO | |
| import numpy as np | |
# Punctuation trimmed from LIME feature words (and, as a lookup fallback, from
# tokenizer tokens) so that e.g. "movie!" and "movie" resolve to the same score.
_TOKEN_PUNCTUATION = '.,!?;:"()[]{}'

# Inline CSS shared by the wrapper <div> around the rendered token chips.
_CONTAINER_STYLE = (
    "font-family: monospace; line-height: 2; padding: 15px; "
    "border-radius: 5px; background-color: #f9f9f9; "
    "border: 1px solid #ddd; margin: 10px 0;"
)

# Color legend; each swatch is exactly a color _attribution_colors can return.
_LEGEND_HTML = (
    '<div style="margin-top: 10px; font-size: 12px; color: #666;">'
    '<span style="background-color: #cc0000; color: white; padding: 2px 6px; '
    'border-radius: 3px; margin-right: 5px;">Strong negative</span> '
    '<span style="background-color: #ff4444; color: white; padding: 2px 6px; '
    'border-radius: 3px; margin-right: 5px;">Weak negative</span> '
    '<span style="background-color: #0000cc; color: white; padding: 2px 6px; '
    'border-radius: 3px; margin-right: 5px;">Strong positive</span> '
    '<span style="background-color: #4444ff; color: white; padding: 2px 6px; '
    'border-radius: 3px;">Weak positive</span>'
    '</div>'
)


def _normalize_token(token):
    """Return the attribution lookup key for *token*.

    Removes WordPiece "##" continuation markers and surrounding whitespace and
    lower-cases, so explainer tokens and tokenizer tokens compare equal.
    """
    return str(token).replace("##", "").strip().lower()


def _collect_token_values(explanation, explainer_type):
    """Build a ``{normalized token: attribution}`` map from explainer output.

    LIME output is a sequence of ``(feature, weight)`` pairs; a multi-word
    feature spreads its weight evenly over its words. SHAP/Captum output is a
    sequence of dicts with ``'token'`` and ``'value'`` keys; other items are
    ignored. A later entry overwrites an earlier one with the same key. Empty
    explanations and unknown explainer types yield an empty dict.
    """
    token_values = {}
    if not explanation:
        return token_values
    if explainer_type == "LIME":
        for feature, weight in explanation:
            words = str(feature).split()
            for word in words:
                key = _normalize_token(word).strip(_TOKEN_PUNCTUATION)
                if key:
                    token_values[key] = weight / len(words)
    elif explainer_type in ("SHAP", "Captum"):
        for item in explanation:
            if isinstance(item, dict) and "token" in item and "value" in item:
                key = _normalize_token(item["token"])
                if key:
                    token_values[key] = item["value"]
    return token_values


def _attribution_colors(norm_value):
    """Return ``(background, text)`` CSS colors for a non-zero score in [-1, 1].

    Magnitudes above 0.5 get the darkest shade ("Strong" in the legend); every
    other non-zero magnitude gets the lighter "Weak" shade. Text is white.
    """
    strong = abs(norm_value) > 0.5
    if norm_value < 0:
        return ("#cc0000" if strong else "#ff4444"), "white"
    return ("#0000cc" if strong else "#4444ff"), "white"


def _neutral_span(label):
    """Return the gray, un-highlighted chip for an already HTML-escaped *label*."""
    return (
        '<span style="margin: 2px; padding: 4px 6px; display: inline-block; '
        'background-color: #f0f0f0; border: 1px solid #ccc; border-radius: 4px;">'
        f'{label}</span> '
    )


def _highlight_span(label, norm_value):
    """Return a bold chip for an already HTML-escaped *label*, colored by *norm_value*."""
    background, text_color = _attribution_colors(norm_value)
    return (
        f'<span style="background-color: {background}; color: {text_color}; '
        f'border: 2px solid {background}; margin: 2px; padding: 4px 6px; '
        'border-radius: 4px; display: inline-block; font-weight: bold;">'
        f'{label}</span> '
    )


def create_visualization(text, explanation, tokenizer, explainer_type):
    """Render *text* as HTML token chips colored by attribution score.

    Args:
        text: Input text, split with ``tokenizer.tokenize``.
        explanation: Explainer output: ``(feature, weight)`` pairs for LIME, or
            dicts with ``'token'``/``'value'`` keys for SHAP and Captum.
        tokenizer: Any object with a ``tokenize(text)`` method returning tokens.
        explainer_type: ``"LIME"``, ``"SHAP"`` or ``"Captum"``.

    Returns:
        An HTML string followed by a color legend. Negative scores are red and
        positive scores blue, with the darker shade for magnitudes above half
        of the largest one. Tokens with no (or a zero) score are gray. When no
        scores can be extracted, the gray chips are shown with a notice. Any
        exception is printed and returned as an HTML error message rather
        than raised.
    """
    from html import escape

    try:
        tokens = tokenizer.tokenize(text)
        token_values = _collect_token_values(explanation, explainer_type)

        if not token_values:
            chips = "".join(
                _neutral_span(escape(str(token).replace("##", ""))) for token in tokens
            )
            return (
                f'<div style="{_CONTAINER_STYLE} color: #666;">'
                "<i>Explanation data not available. Showing tokenized text.</i><br>"
                f"{chips}</div>"
            )

        # Scale scores into [-1, 1] by the largest magnitude so shading is relative.
        max_abs_value = max(abs(v) for v in token_values.values())
        normalized = {
            k: (v / max_abs_value if max_abs_value > 0 else 0)
            for k, v in token_values.items()
        }

        parts = [f'<div style="{_CONTAINER_STYLE}">']
        for token in tokens:
            # Tokens are embedded in HTML and may contain <, > or &; escape them.
            label = escape(str(token).replace("##", ""))
            key = _normalize_token(token)
            if key not in normalized:
                # LIME keys were stored punctuation-stripped; retry the same way.
                key = key.strip(_TOKEN_PUNCTUATION)
            norm_value = normalized.get(key, 0)
            if norm_value == 0:
                parts.append(_neutral_span(label))
            else:
                parts.append(_highlight_span(label, norm_value))
        parts.append("</div>")
        parts.append(_LEGEND_HTML)
        return "".join(parts)
    except Exception as e:
        # UI boundary: report the failure as HTML instead of breaking the page.
        print(f"Visualization error: {e}")
        return f'<div style="color: red; padding: 10px;">Error creating visualization: {escape(str(e))}</div>'
def create_attribution_plot(explanation, method_name, top_k=15):
    """Render the strongest token/feature attributions as a horizontal bar chart.

    Args:
        explanation: For ``method_name == "LIME"``, a sequence of
            ``(feature, weight)`` pairs; otherwise a sequence of dicts with
            ``'token'`` and ``'value'`` keys (items missing either key are skipped).
        method_name: Explainer name; selects the input format and appears in
            the chart title.
        top_k: Maximum number of bars, chosen by largest absolute score.

    Returns:
        An HTML ``<img>`` tag embedding the chart as a base64 PNG, a short HTML
        message when there is nothing to plot, or an HTML error message if
        plotting fails (the error is also printed).
    """
    from html import escape

    fig = None
    try:
        if not explanation:
            return "<p>No explanation data available</p>"

        if method_name == "LIME":
            pairs = [(str(item[0]), float(item[1])) for item in explanation]
            title = f'Top Feature Attributions ({method_name})'
        else:
            # Keep each label with its own score, so a malformed item cannot
            # shift later labels onto the wrong bars.
            pairs = [
                (str(item['token']), float(item['value']))
                for item in explanation
                if isinstance(item, dict) and 'token' in item and 'value' in item
            ]
            title = f'Top Token Attributions ({method_name})'

        # Stable sort: input that is already ranked by magnitude keeps its order.
        pairs = sorted(pairs, key=lambda pair: abs(pair[1]), reverse=True)[:top_k]
        if not pairs:
            return "<p>No valid explanation data available for plotting</p>"

        features = [label for label, _ in pairs]
        scores = [score for _, score in pairs]

        fig, ax = plt.subplots(figsize=(12, 6))
        colors = ['#ff6b6b' if score < 0 else '#4ecdc4' for score in scores]
        y_pos = np.arange(len(features))
        bars = ax.barh(y_pos, scores, color=colors, alpha=0.8, edgecolor='black', linewidth=0.5)

        ax.set_yticks(y_pos)
        ax.set_yticklabels(features, fontsize=10)
        ax.set_xlabel('Attribution Score', fontsize=12, fontweight='bold')
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.axvline(x=0, color='black', linestyle='-', alpha=0.5, linewidth=1)
        ax.grid(True, alpha=0.3, axis='x')

        # Offset value labels away from the zero line by 1% of the largest magnitude.
        pad = 0.01 * (max(abs(score) for score in scores) or 1.0)
        for bar, score in zip(bars, scores):
            width = bar.get_width()
            ax.text(width + (pad if width >= 0 else -pad),
                    bar.get_y() + bar.get_height() / 2,
                    f'{score:.4f}', ha='left' if width >= 0 else 'right', va='center',
                    fontsize=9, fontweight='bold')

        ax.set_facecolor('#f8f9fa')
        fig.patch.set_facecolor('#f8f9fa')
        fig.tight_layout()

        # Save this figure explicitly instead of pyplot's global "current" figure.
        buf = BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight', facecolor=fig.get_facecolor())
        img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
        return f'<img src="data:image/png;base64,{img_str}" style="max-width: 100%; border: 1px solid #ddd; border-radius: 5px;">'
    except Exception as e:
        print(f"Plot error: {e}")
        return f'<div style="color: red; padding: 10px;">Error creating plot: {escape(str(e))}</div>'
    finally:
        # Release the figure on every path, including failures mid-plot.
        if fig is not None:
            plt.close(fig)
| # Class-probability bar chart, returned as an embedded PNG <img> tag | |
def create_confidence_chart(probabilities, class_names=None):
    """Render class probabilities as a vertical bar chart embedded in HTML.

    Args:
        probabilities: Sequence of per-class probabilities, one bar each. The
            y-axis is fixed to [0, 1].
        class_names: Optional tick labels, one per probability. Defaults to
            "Class 0", "Class 1", ...

    Returns:
        An HTML ``<img>`` tag embedding the chart as a base64 PNG, or an HTML
        error message if plotting fails (the error is also printed).
    """
    from html import escape

    fig = None
    try:
        count = len(probabilities)
        if class_names is None:
            class_names = [f"Class {i}" for i in range(count)]

        fig, ax = plt.subplots(figsize=(10, 6))
        palette = ['#ff6b6b', '#4ecdc4', '#45b7af', '#556270']
        # One color per bar, cycling the palette when there are more classes than colors.
        colors = [palette[i % len(palette)] for i in range(count)]
        bars = ax.bar(range(count), probabilities, color=colors,
                      alpha=0.8, edgecolor='black', linewidth=1)

        ax.set_xlabel('Classes', fontsize=12, fontweight='bold')
        ax.set_ylabel('Probability', fontsize=12, fontweight='bold')
        ax.set_title('Class Probability Distribution', fontsize=14, fontweight='bold')
        ax.set_xticks(range(count))
        ax.set_xticklabels(class_names, rotation=45, ha='right')
        ax.set_ylim(0, 1)

        # Print each probability just above its bar.
        for bar, prob in zip(bars, probabilities):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                    f'{float(prob):.3f}', ha='center', va='bottom', fontweight='bold')

        ax.grid(True, alpha=0.3, axis='y')
        ax.set_facecolor('#f8f9fa')
        fig.patch.set_facecolor('#f8f9fa')
        fig.tight_layout()

        # Save this figure explicitly instead of pyplot's global "current" figure.
        buf = BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight', facecolor=fig.get_facecolor())
        img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
        return f'<img src="data:image/png;base64,{img_str}" style="max-width: 100%; border: 1px solid #ddd; border-radius: 5px;">'
    except Exception as e:
        print(f"Confidence chart error: {e}")
        return f'<div style="color: red; padding: 10px;">Error creating confidence chart: {escape(str(e))}</div>'
    finally:
        # Release the figure on every path, including failures mid-plot.
        if fig is not None:
            plt.close(fig)