""" | |
Module for visualizing image evaluation results and creating comparison tables. | |
""" | |
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from matplotlib.colors import LinearSegmentedColormap | |
import os | |
import io | |
from PIL import Image | |
import base64 | |

class Visualizer:
    """Class for visualizing image evaluation results."""

    def __init__(self, output_dir='./results'):
        """
        Initialize visualizer with output directory.

        Args:
            output_dir: directory to save visualization results
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        # Set up color schemes
        self.setup_colors()

    def setup_colors(self):
        """Set up color schemes for visualizations."""
        # Custom colormap for heatmaps
        self.cmap = LinearSegmentedColormap.from_list(
            'custom_cmap', ['#FF5E5B', '#FFED66', '#00CEFF', '#0089BA', '#008F7A'], N=256
        )
        # Color palette for bar charts
        self.palette = sns.color_palette("viridis", 10)
        # Set Seaborn style
        sns.set_style("whitegrid")
    def create_comparison_table(self, results_dict, metrics_list=None):
        """
        Create a comparison table from evaluation results.

        Args:
            results_dict: dictionary with model names as keys and evaluation results as values
            metrics_list: list of metrics to include in the table (if None, include all)

        Returns:
            pandas.DataFrame: comparison table
        """
        # Initialize empty dataframe
        df = pd.DataFrame()
        # Process each model's results
        for model_name, model_results in results_dict.items():
            # Create a row for this model
            model_row = {'Model': model_name}
            # Add metrics to the row
            for metric_name, metric_value in model_results.items():
                if metrics_list is None or metric_name in metrics_list:
                    # Format numeric values to 2 decimal places
                    if isinstance(metric_value, (int, float)):
                        model_row[metric_name] = round(metric_value, 2)
                    else:
                        model_row[metric_name] = metric_value
            # Append to dataframe
            df = pd.concat([df, pd.DataFrame([model_row])], ignore_index=True)
        # Set Model as index
        if not df.empty:
            df.set_index('Model', inplace=True)
        return df
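
    # Illustrative shape of ``results_dict`` expected by create_comparison_table
    # (the model names, metric names, and scores are hypothetical, not part of
    # this module):
    #   {"model_a": {"clip_score": 0.31, "fid": 12.4},
    #    "model_b": {"clip_score": 0.28, "fid": 15.9}}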

    def plot_metric_comparison(self, df, metric_name, title=None, figsize=(10, 6)):
        """
        Create a bar chart comparing models on a specific metric.

        Args:
            df: pandas DataFrame with comparison data
            metric_name: name of the metric to plot
            title: optional custom title
            figsize: figure size as (width, height)

        Returns:
            str: path to saved figure
        """
        if metric_name not in df.columns:
            raise ValueError(f"Metric '{metric_name}' not found in dataframe")
        # Create figure
        plt.figure(figsize=figsize)
        # Create bar chart
        ax = sns.barplot(x=df.index, y=df[metric_name], palette=self.palette)
        # Set title and labels
        if title:
            plt.title(title, fontsize=14)
        else:
            plt.title(f"Model Comparison: {metric_name}", fontsize=14)
        plt.xlabel("Model", fontsize=12)
        plt.ylabel(metric_name, fontsize=12)
        # Rotate x-axis labels for better readability
        plt.xticks(rotation=45, ha='right')
        # Add value labels on top of bars
        for i, v in enumerate(df[metric_name]):
            ax.text(i, v + 0.1, str(round(v, 2)), ha='center')
        plt.tight_layout()
        # Save figure
        output_path = os.path.join(self.output_dir, f"{metric_name}_comparison.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        return output_path

    def plot_radar_chart(self, df, metrics_list, title=None, figsize=(10, 8)):
        """
        Create a radar chart comparing models across multiple metrics.

        Args:
            df: pandas DataFrame with comparison data
            metrics_list: list of metrics to include in the radar chart
            title: optional custom title
            figsize: figure size as (width, height)

        Returns:
            str: path to saved figure
        """
        # Filter metrics that exist in the dataframe
        metrics = [m for m in metrics_list if m in df.columns]
        if not metrics:
            raise ValueError("None of the specified metrics found in dataframe")
        # Number of metrics
        N = len(metrics)
        # Create figure
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, polar=True)
        # Compute angle for each metric
        angles = [n / float(N) * 2 * np.pi for n in range(N)]
        angles += angles[:1]  # Close the loop
        # Plot each model
        for i, model in enumerate(df.index):
            values = df.loc[model, metrics].values.flatten().tolist()
            values += values[:1]  # Close the loop
            # Plot values
            ax.plot(angles, values, linewidth=2, linestyle='solid', label=model,
                    color=self.palette[i % len(self.palette)])
            ax.fill(angles, values, alpha=0.1, color=self.palette[i % len(self.palette)])
        # Set labels
        plt.xticks(angles[:-1], metrics, size=12)
        # Set y-axis limits
        ax.set_ylim(0, 10)
        # Add legend
        plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
        # Set title
        if title:
            plt.title(title, size=16, y=1.1)
        else:
            plt.title("Model Comparison Across Metrics", size=16, y=1.1)
        # Save figure
        output_path = os.path.join(self.output_dir, "radar_comparison.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        return output_path

    def plot_heatmap(self, df, title=None, figsize=(12, 8)):
        """
        Create a heatmap of all metrics across models.

        Args:
            df: pandas DataFrame with comparison data
            title: optional custom title
            figsize: figure size as (width, height)

        Returns:
            str: path to saved figure
        """
        # Create figure
        plt.figure(figsize=figsize)
        # Create heatmap
        ax = sns.heatmap(df, annot=True, cmap=self.cmap, fmt=".2f", linewidths=.5)
        # Set title
        if title:
            plt.title(title, fontsize=16)
        else:
            plt.title("Model Comparison Heatmap", fontsize=16)
        plt.tight_layout()
        # Save figure
        output_path = os.path.join(self.output_dir, "comparison_heatmap.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        return output_path

    def plot_prompt_performance(self, prompt_results, metric_name, top_n=5, figsize=(12, 8)):
        """
        Create a grouped bar chart showing model performance on different prompts.

        Args:
            prompt_results: dictionary with prompts as keys and model results as values
            metric_name: name of the metric to plot
            top_n: number of top prompts to include
            figsize: figure size as (width, height)

        Returns:
            str: path to saved figure
        """
        # Create dataframe from results
        data = []
        for prompt, models_data in prompt_results.items():
            for model, metrics in models_data.items():
                if metric_name in metrics:
                    data.append({
                        'Prompt': prompt,
                        'Model': model,
                        metric_name: metrics[metric_name]
                    })
        df = pd.DataFrame(data)
        if df.empty:
            raise ValueError(f"No data found for metric '{metric_name}'")
        # Get top N prompts by average metric value
        top_prompts = df.groupby('Prompt')[metric_name].mean().nlargest(top_n).index.tolist()
        df_filtered = df[df['Prompt'].isin(top_prompts)]
        # Create figure
        plt.figure(figsize=figsize)
        # Create grouped bar chart
        ax = sns.barplot(x='Prompt', y=metric_name, hue='Model', data=df_filtered, palette=self.palette)
        # Set title and labels
        plt.title(f"Model Performance by Prompt: {metric_name}", fontsize=14)
        plt.xlabel("Prompt", fontsize=12)
        plt.ylabel(metric_name, fontsize=12)
        # Rotate x-axis labels for better readability
        plt.xticks(rotation=45, ha='right')
        # Adjust legend
        plt.legend(title="Model", bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        # Save figure
        output_path = os.path.join(self.output_dir, f"prompt_performance_{metric_name}.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        return output_path
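
    # Illustrative shape of ``prompt_results`` expected by plot_prompt_performance
    # (the prompt text, model names, and scores are hypothetical):
    #   {"a cat on a skateboard": {"model_a": {"clip_score": 0.33},
    #                              "model_b": {"clip_score": 0.29}}}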

    def create_image_grid(self, image_paths, titles=None, cols=3, figsize=(15, 15)):
        """
        Create a grid of images for visual comparison.

        Args:
            image_paths: list of paths to images
            titles: optional list of titles for each image
            cols: number of columns in the grid
            figsize: figure size as (width, height)

        Returns:
            str: path to saved figure
        """
        # Calculate number of rows needed
        rows = (len(image_paths) + cols - 1) // cols
        # Create figure; squeeze=False keeps axes as a 2D array even for a single row/column
        fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
        axes = axes.flatten()
        # Add each image to the grid
        for i, img_path in enumerate(image_paths):
            if i < len(axes):
                try:
                    img = Image.open(img_path)
                    axes[i].imshow(np.array(img))
                    # Add title if provided
                    if titles and i < len(titles):
                        axes[i].set_title(titles[i])
                    # Remove axis ticks
                    axes[i].set_xticks([])
                    axes[i].set_yticks([])
                except Exception as e:
                    print(f"Error loading image {img_path}: {e}")
                    axes[i].text(0.5, 0.5, "Error loading image", ha='center', va='center')
                    axes[i].set_xticks([])
                    axes[i].set_yticks([])
        # Hide unused subplots
        for j in range(len(image_paths), len(axes)):
            axes[j].axis('off')
        plt.tight_layout()
        # Save figure
        output_path = os.path.join(self.output_dir, "image_comparison_grid.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        return output_path

    def export_comparison_table(self, df, format='csv'):
        """
        Export comparison table to file.

        Args:
            df: pandas DataFrame with comparison data
            format: export format ('csv', 'excel', or 'html')

        Returns:
            str: path to saved file
        """
        if format == 'csv':
            output_path = os.path.join(self.output_dir, "comparison_table.csv")
            df.to_csv(output_path)
        elif format == 'excel':
            output_path = os.path.join(self.output_dir, "comparison_table.xlsx")
            df.to_excel(output_path)
        elif format == 'html':
            output_path = os.path.join(self.output_dir, "comparison_table.html")
            df.to_html(output_path)
        else:
            raise ValueError(f"Unsupported format: {format}")
        return output_path

    def generate_html_report(self, comparison_table, image_paths, metrics_list):
        """
        Generate a comprehensive HTML report with all visualizations.

        Args:
            comparison_table: pandas DataFrame with comparison data
            image_paths: dictionary of generated visualization image paths
            metrics_list: list of metrics included in the analysis

        Returns:
            str: path to saved HTML report
        """
        # Create HTML content
        html_content = f"""
        <!DOCTYPE html>
        <html>
        <head>
            <title>Image Model Evaluation Report</title>
            <style>
                body {{
                    font-family: Arial, sans-serif;
                    line-height: 1.6;
                    margin: 0;
                    padding: 20px;
                    color: #333;
                }}
                h1, h2, h3 {{
                    color: #2c3e50;
                }}
                .container {{
                    max-width: 1200px;
                    margin: 0 auto;
                }}
                table {{
                    border-collapse: collapse;
                    width: 100%;
                    margin-bottom: 20px;
                }}
                th, td {{
                    border: 1px solid #ddd;
                    padding: 8px;
                    text-align: left;
                }}
                th {{
                    background-color: #f2f2f2;
                }}
                tr:nth-child(even) {{
                    background-color: #f9f9f9;
                }}
                .visualization {{
                    margin: 20px 0;
                    text-align: center;
                }}
                .visualization img {{
                    max-width: 100%;
                    height: auto;
                    box-shadow: 0 4px 8px rgba(0,0,0,0.1);
                }}
                .metrics-list {{
                    background-color: #f8f9fa;
                    padding: 15px;
                    border-radius: 5px;
                    margin-bottom: 20px;
                }}
            </style>
        </head>
        <body>
            <div class="container">
                <h1>Image Model Evaluation Report</h1>
                <h2>Metrics Overview</h2>
                <div class="metrics-list">
                    <h3>Metrics included in this analysis:</h3>
                    <ul>
        """
        # Add metrics list
        for metric in metrics_list:
            html_content += f"            <li><strong>{metric}</strong></li>\n"
        html_content += """
                    </ul>
                </div>
                <h2>Comparison Table</h2>
        """
        # Add comparison table
        html_content += comparison_table.to_html(classes="table table-striped")
        # Add visualizations
        html_content += """
                <h2>Visualizations</h2>
        """
        for title, img_path in image_paths.items():
            if os.path.exists(img_path):
                # Convert image to base64 for embedding
                with open(img_path, "rb") as img_file:
                    img_data = base64.b64encode(img_file.read()).decode('utf-8')
                html_content += f"""
                <div class="visualization">
                    <h3>{title}</h3>
                    <img src="data:image/png;base64,{img_data}" alt="{title}">
                </div>
                """
        # Close HTML
        html_content += """
            </div>
        </body>
        </html>
        """
        # Save HTML report
        output_path = os.path.join(self.output_dir, "evaluation_report.html")
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(html_content)
        return output_path
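

# Minimal usage sketch (not part of the original module): the model names,
# metric names, and scores below are hypothetical and only illustrate how the
# Visualizer class above could be driven end to end.
if __name__ == "__main__":
    example_results = {
        "model_a": {"clip_score": 0.31, "aesthetic": 6.2},
        "model_b": {"clip_score": 0.28, "aesthetic": 5.9},
    }
    viz = Visualizer(output_dir="./results")
    table = viz.create_comparison_table(example_results)
    print(table)
    # Each plotting helper returns the path of the PNG it wrote to output_dir.
    bar_path = viz.plot_metric_comparison(table, "clip_score")
    heatmap_path = viz.plot_heatmap(table)
    report_path = viz.generate_html_report(
        table,
        {"CLIP score comparison": bar_path, "Metric heatmap": heatmap_path},
        ["clip_score", "aesthetic"],
    )
    print(f"Report written to {report_path}")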