"""
Module for visualizing image evaluation results and creating comparison tables.
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
import os
from PIL import Image
import base64


class Visualizer:
    """Class for visualizing image evaluation results."""
    
    def __init__(self, output_dir='./results'):
        """
        Initialize visualizer with output directory.
        
        Args:
            output_dir: directory to save visualization results
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        
        # Set up color schemes
        self.setup_colors()
    
    def setup_colors(self):
        """Set up color schemes for visualizations."""
        # Custom colormap for heatmaps
        self.cmap = LinearSegmentedColormap.from_list(
            'custom_cmap', ['#FF5E5B', '#FFED66', '#00CEFF', '#0089BA', '#008F7A'], N=256
        )
        
        # Color palette for bar charts
        self.palette = sns.color_palette("viridis", 10)
        
        # Set Seaborn style
        sns.set_style("whitegrid")
    
    def create_comparison_table(self, results_dict, metrics_list=None):
        """
        Create a comparison table from evaluation results.
        
        Args:
            results_dict: dictionary with model names as keys and evaluation results as values
            metrics_list: list of metrics to include in the table (if None, include all)
            
        Returns:
            pandas.DataFrame: comparison table
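
        Example (illustrative model and metric names, synthetic values):
            >>> viz = Visualizer(output_dir='./results')
            >>> table = viz.create_comparison_table(
            ...     {'model_a': {'clip_score': 7.9}, 'model_b': {'clip_score': 7.2}})
            >>> list(table.columns)
            ['clip_score']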
        """
        # Initialize empty dataframe
        df = pd.DataFrame()
        
        # Process each model's results
        for model_name, model_results in results_dict.items():
            # Create a row for this model
            model_row = {'Model': model_name}
            
            # Add metrics to the row
            for metric_name, metric_value in model_results.items():
                if metrics_list is None or metric_name in metrics_list:
                    # Format numeric values to 2 decimal places
                    if isinstance(metric_value, (int, float)):
                        model_row[metric_name] = round(metric_value, 2)
                    else:
                        model_row[metric_name] = metric_value
            
            # Append to dataframe
            df = pd.concat([df, pd.DataFrame([model_row])], ignore_index=True)
        
        # Set Model as index
        if not df.empty:
            df.set_index('Model', inplace=True)
        
        return df
    
    def plot_metric_comparison(self, df, metric_name, title=None, figsize=(10, 6)):
        """
        Create a bar chart comparing models on a specific metric.
        
        Args:
            df: pandas DataFrame with comparison data
            metric_name: name of the metric to plot
            title: optional custom title
            figsize: figure size as (width, height)
            
        Returns:
            str: path to saved figure
        """
        if metric_name not in df.columns:
            raise ValueError(f"Metric '{metric_name}' not found in dataframe")
        
        # Create figure
        plt.figure(figsize=figsize)
        
        # Create bar chart
        ax = sns.barplot(x=df.index, y=df[metric_name], palette=self.palette)
        
        # Set title and labels
        if title:
            plt.title(title, fontsize=14)
        else:
            plt.title(f"Model Comparison: {metric_name}", fontsize=14)
        
        plt.xlabel("Model", fontsize=12)
        plt.ylabel(metric_name, fontsize=12)
        
        # Rotate x-axis labels for better readability
        plt.xticks(rotation=45, ha='right')
        
        # Add value labels on top of bars
        for i, v in enumerate(df[metric_name]):
            ax.text(i, v + 0.1, str(round(v, 2)), ha='center')
        
        plt.tight_layout()
        
        # Save figure
        output_path = os.path.join(self.output_dir, f"{metric_name}_comparison.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        
        return output_path
    
    def plot_radar_chart(self, df, metrics_list, title=None, figsize=(10, 8)):
        """
        Create a radar chart comparing models across multiple metrics.
        
        Args:
            df: pandas DataFrame with comparison data
            metrics_list: list of metrics to include in the radar chart
            title: optional custom title
            figsize: figure size as (width, height)
            
        Returns:
            str: path to saved figure
        """
        # Filter metrics that exist in the dataframe
        metrics = [m for m in metrics_list if m in df.columns]
        
        if not metrics:
            raise ValueError("None of the specified metrics found in dataframe")
        
        # Number of metrics
        N = len(metrics)
        
        # Create figure
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, polar=True)
        
        # Compute angle for each metric
        angles = [n / float(N) * 2 * np.pi for n in range(N)]
        angles += angles[:1]  # Close the loop
        
        # Plot each model
        for i, model in enumerate(df.index):
            values = df.loc[model, metrics].values.flatten().tolist()
            values += values[:1]  # Close the loop
            
            # Plot values
            ax.plot(angles, values, linewidth=2, linestyle='solid', label=model, color=self.palette[i % len(self.palette)])
            ax.fill(angles, values, alpha=0.1, color=self.palette[i % len(self.palette)])
        
        # Set labels
        plt.xticks(angles[:-1], metrics, size=12)
        
        # Fix the radial axis; assumes metric values fall on a 0-10 scale
        ax.set_ylim(0, 10)
        
        # Add legend
        plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
        
        # Set title
        if title:
            plt.title(title, size=16, y=1.1)
        else:
            plt.title("Model Comparison Across Metrics", size=16, y=1.1)
        
        # Save figure
        output_path = os.path.join(self.output_dir, "radar_comparison.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        
        return output_path
    
    def plot_heatmap(self, df, title=None, figsize=(12, 8)):
        """
        Create a heatmap of all metrics across models.
        
        Args:
            df: pandas DataFrame with comparison data
            title: optional custom title
            figsize: figure size as (width, height)
            
        Returns:
            str: path to saved figure
        """
        # Create figure
        plt.figure(figsize=figsize)
        
        # Create heatmap
        ax = sns.heatmap(df, annot=True, cmap=self.cmap, fmt=".2f", linewidths=.5)
        
        # Set title
        if title:
            plt.title(title, fontsize=16)
        else:
            plt.title("Model Comparison Heatmap", fontsize=16)
        
        plt.tight_layout()
        
        # Save figure
        output_path = os.path.join(self.output_dir, "comparison_heatmap.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        
        return output_path
    
    def plot_prompt_performance(self, prompt_results, metric_name, top_n=5, figsize=(12, 8)):
        """
        Create a grouped bar chart showing model performance on different prompts.
        
        Args:
            prompt_results: dictionary with prompts as keys and model results as values
            metric_name: name of the metric to plot
            top_n: number of top prompts to include
            figsize: figure size as (width, height)
            
        Returns:
            str: path to saved figure
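
        Example input shape (illustrative prompt, model, and metric names):
            prompt_results = {
                'a red fox in snow': {
                    'model_a': {'clip_score': 7.9},
                    'model_b': {'clip_score': 7.1},
                },
            }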
        """
        # Create dataframe from results
        data = []
        for prompt, models_data in prompt_results.items():
            for model, metrics in models_data.items():
                if metric_name in metrics:
                    data.append({
                        'Prompt': prompt,
                        'Model': model,
                        metric_name: metrics[metric_name]
                    })
        
        df = pd.DataFrame(data)
        
        if df.empty:
            raise ValueError(f"No data found for metric '{metric_name}'")
        
        # Get top N prompts by average metric value
        top_prompts = df.groupby('Prompt')[metric_name].mean().nlargest(top_n).index.tolist()
        df_filtered = df[df['Prompt'].isin(top_prompts)]
        
        # Create figure
        plt.figure(figsize=figsize)
        
        # Create grouped bar chart
        ax = sns.barplot(x='Prompt', y=metric_name, hue='Model', data=df_filtered, palette=self.palette)
        
        # Set title and labels
        plt.title(f"Model Performance by Prompt: {metric_name}", fontsize=14)
        plt.xlabel("Prompt", fontsize=12)
        plt.ylabel(metric_name, fontsize=12)
        
        # Rotate x-axis labels for better readability
        plt.xticks(rotation=45, ha='right')
        
        # Adjust legend
        plt.legend(title="Model", bbox_to_anchor=(1.05, 1), loc='upper left')
        
        plt.tight_layout()
        
        # Save figure
        output_path = os.path.join(self.output_dir, f"prompt_performance_{metric_name}.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        
        return output_path
    
    def create_image_grid(self, image_paths, titles=None, cols=3, figsize=(15, 15)):
        """
        Create a grid of images for visual comparison.
        
        Args:
            image_paths: list of paths to images
            titles: optional list of titles for each image
            cols: number of columns in the grid
            figsize: figure size as (width, height)
            
        Returns:
            str: path to saved figure
        """
        # Calculate number of rows needed
        rows = (len(image_paths) + cols - 1) // cols
        
        # Create figure
        # squeeze=False guarantees a 2-D axes array even when rows == cols == 1,
        # so flatten() is always safe
        fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
        axes = axes.flatten()
        
        # Add each image to the grid
        for i, img_path in enumerate(image_paths):
            if i < len(axes):
                try:
                    img = Image.open(img_path)
                    axes[i].imshow(np.array(img))
                    
                    # Add title if provided
                    if titles and i < len(titles):
                        axes[i].set_title(titles[i])
                    
                    # Remove axis ticks
                    axes[i].set_xticks([])
                    axes[i].set_yticks([])
                except Exception as e:
                    print(f"Error loading image {img_path}: {e}")
                    axes[i].text(0.5, 0.5, "Error loading image", ha='center', va='center')
                    axes[i].set_xticks([])
                    axes[i].set_yticks([])
        
        # Hide unused subplots
        for j in range(len(image_paths), len(axes)):
            axes[j].axis('off')
        
        plt.tight_layout()
        
        # Save figure
        output_path = os.path.join(self.output_dir, "image_comparison_grid.png")
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()
        
        return output_path
    
    def export_comparison_table(self, df, format='csv'):
        """
        Export comparison table to file.
        
        Args:
            df: pandas DataFrame with comparison data
            format: export format ('csv', 'excel', or 'html')
            
        Returns:
            str: path to saved file
        """
        if format == 'csv':
            output_path = os.path.join(self.output_dir, "comparison_table.csv")
            df.to_csv(output_path)
        elif format == 'excel':
            output_path = os.path.join(self.output_dir, "comparison_table.xlsx")
            df.to_excel(output_path)
        elif format == 'html':
            output_path = os.path.join(self.output_dir, "comparison_table.html")
            df.to_html(output_path)
        else:
            raise ValueError(f"Unsupported format: {format}")
        
        return output_path
    
    def generate_html_report(self, comparison_table, image_paths, metrics_list):
        """
        Generate a comprehensive HTML report with all visualizations.
        
        Args:
            comparison_table: pandas DataFrame with comparison data
            image_paths: dictionary of generated visualization image paths
            metrics_list: list of metrics included in the analysis
            
        Returns:
            str: path to saved HTML report
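
        Example image_paths shape (illustrative title and path):
            {'CLIP Score Comparison': './results/clip_score_comparison.png'}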
        """
        # Create HTML content
        html_content = f"""
        <!DOCTYPE html>
        <html>
        <head>
            <title>Image Model Evaluation Report</title>
            <style>
                body {{
                    font-family: Arial, sans-serif;
                    line-height: 1.6;
                    margin: 0;
                    padding: 20px;
                    color: #333;
                }}
                h1, h2, h3 {{
                    color: #2c3e50;
                }}
                .container {{
                    max-width: 1200px;
                    margin: 0 auto;
                }}
                table {{
                    border-collapse: collapse;
                    width: 100%;
                    margin-bottom: 20px;
                }}
                th, td {{
                    border: 1px solid #ddd;
                    padding: 8px;
                    text-align: left;
                }}
                th {{
                    background-color: #f2f2f2;
                }}
                tr:nth-child(even) {{
                    background-color: #f9f9f9;
                }}
                .visualization {{
                    margin: 20px 0;
                    text-align: center;
                }}
                .visualization img {{
                    max-width: 100%;
                    height: auto;
                    box-shadow: 0 4px 8px rgba(0,0,0,0.1);
                }}
                .metrics-list {{
                    background-color: #f8f9fa;
                    padding: 15px;
                    border-radius: 5px;
                    margin-bottom: 20px;
                }}
            </style>
        </head>
        <body>
            <div class="container">
                <h1>Image Model Evaluation Report</h1>
                
                <h2>Metrics Overview</h2>
                <div class="metrics-list">
                    <h3>Metrics included in this analysis:</h3>
                    <ul>
        """
        
        # Add metrics list
        for metric in metrics_list:
            html_content += f"            <li><strong>{metric}</strong></li>\n"
        
        html_content += """
                    </ul>
                </div>
                
                <h2>Comparison Table</h2>
        """
        
        # Add comparison table
        html_content += comparison_table.to_html(classes="table table-striped")
        
        # Add visualizations
        html_content += """
                <h2>Visualizations</h2>
        """
        
        for title, img_path in image_paths.items():
            if os.path.exists(img_path):
                # Convert image to base64 for embedding
                with open(img_path, "rb") as img_file:
                    img_data = base64.b64encode(img_file.read()).decode('utf-8')
                
                html_content += f"""
                <div class="visualization">
                    <h3>{title}</h3>
                    <img src="data:image/png;base64,{img_data}" alt="{title}">
                </div>
                """
        
        # Close HTML
        html_content += """
            </div>
        </body>
        </html>
        """
        
        # Save HTML report
        output_path = os.path.join(self.output_dir, "evaluation_report.html")
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(html_content)
        
        return output_path
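

if __name__ == "__main__":
    # Minimal usage sketch: the models, metrics, and scores below are
    # synthetic placeholders, not results from a real evaluation run.
    demo_results = {
        'model_a': {'clip_score': 7.9, 'aesthetic': 6.4, 'prompt_fidelity': 8.1},
        'model_b': {'clip_score': 7.2, 'aesthetic': 7.0, 'prompt_fidelity': 7.5},
    }
    demo_metrics = ['clip_score', 'aesthetic', 'prompt_fidelity']

    viz = Visualizer(output_dir='./results')
    table = viz.create_comparison_table(demo_results)

    # Generate the individual charts, then embed them in one HTML report
    charts = {
        'CLIP Score Comparison': viz.plot_metric_comparison(table, 'clip_score'),
        'Radar Comparison': viz.plot_radar_chart(table, demo_metrics),
        'Metric Heatmap': viz.plot_heatmap(table),
    }
    report_path = viz.generate_html_report(table, charts, demo_metrics)
    print(f"Report written to {report_path}")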