""" Module for visualizing image evaluation results and creating comparison tables. """ import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from matplotlib.colors import LinearSegmentedColormap import os import io from PIL import Image import base64 class Visualizer: """Class for visualizing image evaluation results.""" def __init__(self, output_dir='./results'): """ Initialize visualizer with output directory. Args: output_dir: directory to save visualization results """ self.output_dir = output_dir os.makedirs(output_dir, exist_ok=True) # Set up color schemes self.setup_colors() def setup_colors(self): """Set up color schemes for visualizations.""" # Custom colormap for heatmaps self.cmap = LinearSegmentedColormap.from_list( 'custom_cmap', ['#FF5E5B', '#FFED66', '#00CEFF', '#0089BA', '#008F7A'], N=256 ) # Color palette for bar charts self.palette = sns.color_palette("viridis", 10) # Set Seaborn style sns.set_style("whitegrid") def create_comparison_table(self, results_dict, metrics_list=None): """ Create a comparison table from evaluation results. Args: results_dict: dictionary with model names as keys and evaluation results as values metrics_list: list of metrics to include in the table (if None, include all) Returns: pandas.DataFrame: comparison table """ # Initialize empty dataframe df = pd.DataFrame() # Process each model's results for model_name, model_results in results_dict.items(): # Create a row for this model model_row = {'Model': model_name} # Add metrics to the row for metric_name, metric_value in model_results.items(): if metrics_list is None or metric_name in metrics_list: # Format numeric values to 2 decimal places if isinstance(metric_value, (int, float)): model_row[metric_name] = round(metric_value, 2) else: model_row[metric_name] = metric_value # Append to dataframe df = pd.concat([df, pd.DataFrame([model_row])], ignore_index=True) # Set Model as index if not df.empty: df.set_index('Model', inplace=True) return df def plot_metric_comparison(self, df, metric_name, title=None, figsize=(10, 6)): """ Create a bar chart comparing models on a specific metric. Args: df: pandas DataFrame with comparison data metric_name: name of the metric to plot title: optional custom title figsize: figure size as (width, height) Returns: str: path to saved figure """ if metric_name not in df.columns: raise ValueError(f"Metric '{metric_name}' not found in dataframe") # Create figure plt.figure(figsize=figsize) # Create bar chart ax = sns.barplot(x=df.index, y=df[metric_name], palette=self.palette) # Set title and labels if title: plt.title(title, fontsize=14) else: plt.title(f"Model Comparison: {metric_name}", fontsize=14) plt.xlabel("Model", fontsize=12) plt.ylabel(metric_name, fontsize=12) # Rotate x-axis labels for better readability plt.xticks(rotation=45, ha='right') # Add value labels on top of bars for i, v in enumerate(df[metric_name]): ax.text(i, v + 0.1, str(round(v, 2)), ha='center') plt.tight_layout() # Save figure output_path = os.path.join(self.output_dir, f"{metric_name}_comparison.png") plt.savefig(output_path, dpi=300, bbox_inches='tight') plt.close() return output_path def plot_radar_chart(self, df, metrics_list, title=None, figsize=(10, 8)): """ Create a radar chart comparing models across multiple metrics. 
    def plot_radar_chart(self, df, metrics_list, title=None, figsize=(10, 8)):
        """
        Create a radar chart comparing models across multiple metrics.

        Args:
            df: pandas DataFrame with comparison data
            metrics_list: list of metrics to include in the radar chart
            title: optional custom title
            figsize: figure size as (width, height)

        Returns:
            str: path to saved figure
        """
        # Filter metrics that exist in the dataframe
        metrics = [m for m in metrics_list if m in df.columns]

        if not metrics:
            raise ValueError("None of the specified metrics found in dataframe")

        # Number of metrics
        N = len(metrics)

        # Create figure
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, polar=True)

        # Compute angle for each metric
        angles = [n / float(N) * 2 * np.pi for n in range(N)]
        angles += angles[:1]  # Close the loop

        # Plot each model
        for i, model in enumerate(df.index):
            values = df.loc[model, metrics].values.flatten().tolist()
            values += values[:1]  # Close the loop

            # Plot values
            ax.plot(angles, values, linewidth=2, linestyle='solid', label=model,
                    color=self.palette[i % len(self.palette)])
            ax.fill(angles, values, alpha=0.1, color=self.palette[i % len(self.palette)])

        # Set labels
        plt.xticks(angles[:-1], metrics, size=12)

        # Set y-axis limits
        ax.set_ylim(0, 10)

        # Add legend
        plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))

        # Set title
        if title:
            plt.title(title, size=16, y=1.1)
        else:
            plt.title("Model Comparison Across Metrics", size=16, y=1.1)

        # Save figure
        output_path = os.path.join(self.output_dir, "radar_comparison.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()

        return output_path

    def plot_heatmap(self, df, title=None, figsize=(12, 8)):
        """
        Create a heatmap of all metrics across models.

        Args:
            df: pandas DataFrame with comparison data
            title: optional custom title
            figsize: figure size as (width, height)

        Returns:
            str: path to saved figure
        """
        # Create figure
        plt.figure(figsize=figsize)

        # Create heatmap
        ax = sns.heatmap(df, annot=True, cmap=self.cmap, fmt=".2f", linewidths=.5)

        # Set title
        if title:
            plt.title(title, fontsize=16)
        else:
            plt.title("Model Comparison Heatmap", fontsize=16)

        plt.tight_layout()

        # Save figure
        output_path = os.path.join(self.output_dir, "comparison_heatmap.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()

        return output_path
    def plot_prompt_performance(self, prompt_results, metric_name, top_n=5, figsize=(12, 8)):
        """
        Create a grouped bar chart showing model performance on different prompts.

        Args:
            prompt_results: dictionary with prompts as keys and model results as values
            metric_name: name of the metric to plot
            top_n: number of top prompts to include
            figsize: figure size as (width, height)

        Returns:
            str: path to saved figure
        """
        # Create dataframe from results
        data = []
        for prompt, models_data in prompt_results.items():
            for model, metrics in models_data.items():
                if metric_name in metrics:
                    data.append({
                        'Prompt': prompt,
                        'Model': model,
                        metric_name: metrics[metric_name]
                    })

        df = pd.DataFrame(data)

        if df.empty:
            raise ValueError(f"No data found for metric '{metric_name}'")

        # Get top N prompts by average metric value
        top_prompts = df.groupby('Prompt')[metric_name].mean().nlargest(top_n).index.tolist()
        df_filtered = df[df['Prompt'].isin(top_prompts)]

        # Create figure
        plt.figure(figsize=figsize)

        # Create grouped bar chart
        ax = sns.barplot(x='Prompt', y=metric_name, hue='Model', data=df_filtered, palette=self.palette)

        # Set title and labels
        plt.title(f"Model Performance by Prompt: {metric_name}", fontsize=14)
        plt.xlabel("Prompt", fontsize=12)
        plt.ylabel(metric_name, fontsize=12)

        # Rotate x-axis labels for better readability
        plt.xticks(rotation=45, ha='right')

        # Adjust legend
        plt.legend(title="Model", bbox_to_anchor=(1.05, 1), loc='upper left')

        plt.tight_layout()

        # Save figure
        output_path = os.path.join(self.output_dir, f"prompt_performance_{metric_name}.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()

        return output_path

    def create_image_grid(self, image_paths, titles=None, cols=3, figsize=(15, 15)):
        """
        Create a grid of images for visual comparison.

        Args:
            image_paths: list of paths to images
            titles: optional list of titles for each image
            cols: number of columns in the grid
            figsize: figure size as (width, height)

        Returns:
            str: path to saved figure
        """
        # Calculate number of rows needed
        rows = (len(image_paths) + cols - 1) // cols

        # Create figure (squeeze=False so axes is always an array, even for a 1x1 grid)
        fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
        axes = axes.flatten()

        # Add each image to the grid
        for i, img_path in enumerate(image_paths):
            if i < len(axes):
                try:
                    img = Image.open(img_path)
                    axes[i].imshow(np.array(img))

                    # Add title if provided
                    if titles and i < len(titles):
                        axes[i].set_title(titles[i])

                    # Remove axis ticks
                    axes[i].set_xticks([])
                    axes[i].set_yticks([])
                except Exception as e:
                    print(f"Error loading image {img_path}: {e}")
                    axes[i].text(0.5, 0.5, "Error loading image", ha='center', va='center')
                    axes[i].set_xticks([])
                    axes[i].set_yticks([])

        # Hide unused subplots
        for j in range(len(image_paths), len(axes)):
            axes[j].axis('off')

        plt.tight_layout()

        # Save figure
        output_path = os.path.join(self.output_dir, "image_comparison_grid.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()

        return output_path

    def export_comparison_table(self, df, format='csv'):
        """
        Export comparison table to file.

        Args:
            df: pandas DataFrame with comparison data
            format: export format ('csv', 'excel', or 'html')

        Returns:
            str: path to saved file
        """
        if format == 'csv':
            output_path = os.path.join(self.output_dir, "comparison_table.csv")
            df.to_csv(output_path)
        elif format == 'excel':
            output_path = os.path.join(self.output_dir, "comparison_table.xlsx")
            df.to_excel(output_path)
        elif format == 'html':
            output_path = os.path.join(self.output_dir, "comparison_table.html")
            df.to_html(output_path)
        else:
            raise ValueError(f"Unsupported format: {format}")

        return output_path
    def generate_html_report(self, comparison_table, image_paths, metrics_list):
        """
        Generate a comprehensive HTML report with all visualizations.

        Args:
            comparison_table: pandas DataFrame with comparison data
            image_paths: dictionary of generated visualization image paths
            metrics_list: list of metrics included in the analysis

        Returns:
            str: path to saved HTML report
        """
        # Create HTML content
        html_content = f"""<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>Image Model Evaluation Report</title>
</head>
<body>
    <h1>Image Model Evaluation Report</h1>

    <h2>Metrics Overview</h2>
    <p>Metrics included in this analysis: {', '.join(metrics_list)}</p>

    <h2>Comparison Table</h2>
"""

        # Add comparison table
        html_content += comparison_table.to_html(classes="table table-striped")

        # Add visualizations
        html_content += """
    <h2>Visualizations</h2>
"""

        for title, img_path in image_paths.items():
            if os.path.exists(img_path):
                # Convert image to base64 for embedding
                with open(img_path, "rb") as img_file:
                    img_data = base64.b64encode(img_file.read()).decode('utf-8')

                html_content += f"""
    <h3>{title}</h3>
    <img src="data:image/png;base64,{img_data}" alt="{title}">
"""

        # Close HTML
        html_content += """
</body>
</html>
"""

        # Save HTML report
        output_path = os.path.join(self.output_dir, "evaluation_report.html")
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(html_content)

        return output_path