""" | |
Main application file for the Image Evaluator tool. | |
This module integrates all components and provides a Gradio interface. | |
""" | |
import os | |
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
import torch | |
import glob | |
from PIL import Image | |
import json | |
import tempfile | |
import shutil | |
from datetime import datetime | |
# Import custom modules | |
from modules.metadata_extractor import MetadataExtractor | |
from modules.technical_metrics import TechnicalMetrics | |
from modules.aesthetic_metrics import AestheticMetrics | |
from modules.aggregator import ResultsAggregator | |
from modules.visualizer import Visualizer | |


class ImageEvaluator:
    """Main class for the Image Evaluator application."""

    def __init__(self):
        """Initialize the Image Evaluator."""
        self.results_dir = os.path.join(os.getcwd(), "results")
        os.makedirs(self.results_dir, exist_ok=True)

        # Initialize components
        self.metadata_extractor = MetadataExtractor()
        self.technical_metrics = TechnicalMetrics()
        self.aesthetic_metrics = AestheticMetrics()
        self.aggregator = ResultsAggregator()
        self.visualizer = Visualizer(self.results_dir)

        # Storage for results
        self.evaluation_results = {}
        self.metadata_cache = {}
        self.current_comparison = None

    def process_images(self, image_files, progress=None):
        """
        Process a list of image files and extract metadata.

        Args:
            image_files: list of image file paths
            progress: optional Gradio Progress object

        Returns:
            tuple: (metadata_by_model, metadata_by_prompt)
        """
        metadata_list = []
        total_files = len(image_files)

        for i, img_path in enumerate(image_files):
            # Safe progress update without accessing internal attributes
            if progress is not None:
                try:
                    progress((i + 1) / total_files, f"Processing image {i+1}/{total_files}")
                except Exception as e:
                    print(f"Progress update error (non-critical): {e}")

            # Extract metadata
            metadata = self.metadata_extractor.extract_metadata(img_path)
            metadata_list.append((img_path, metadata))

            # Cache metadata
            self.metadata_cache[img_path] = metadata

        # Group by model and prompt
        metadata_by_model = self.metadata_extractor.group_images_by_model(metadata_list)
        metadata_by_prompt = self.metadata_extractor.group_images_by_prompt(metadata_list)

        return metadata_by_model, metadata_by_prompt

    def evaluate_images(self, image_files, progress=None):
        """
        Evaluate a list of image files using all metrics.

        Args:
            image_files: list of image file paths
            progress: optional Gradio Progress object

        Returns:
            dict: evaluation results by image path
        """
        results = {}
        total_files = len(image_files)

        for i, img_path in enumerate(image_files):
            # Safe progress update without accessing internal attributes
            if progress is not None:
                try:
                    progress((i + 1) / total_files, f"Evaluating image {i+1}/{total_files}")
                except Exception as e:
                    print(f"Progress update error (non-critical): {e}")

            # Get metadata if available
            metadata = self.metadata_cache.get(img_path, {})
            prompt = metadata.get('prompt', '')

            # Calculate technical metrics
            tech_metrics = self.technical_metrics.calculate_all_metrics(img_path)

            # Calculate aesthetic metrics
            aesthetic_metrics = self.aesthetic_metrics.calculate_all_metrics(img_path, prompt)

            # Combine results
            combined_metrics = {**tech_metrics, **aesthetic_metrics}

            # Store results
            results[img_path] = combined_metrics

        return results
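
    # Illustrative shape of the dict returned by evaluate_images. The metric names
    # are the ones referenced elsewhere in this file; the values are made-up examples,
    # not output from the metric modules:
    #   {"/path/to/image.png": {"sharpness": 0.72, "noise": 0.08, "contrast": 0.55,
    #                           "color_harmony": 0.64, "aesthetic_score": 6.1,
    #                           "prompt_similarity": 0.31, ...}}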

    def compare_models(self, evaluation_results, metadata_by_model):
        """
        Compare different models based on evaluation results.

        Args:
            evaluation_results: dictionary with image paths as keys and metrics as values
            metadata_by_model: dictionary with model names as keys and lists of image paths as values

        Returns:
            tuple: (comparison_df, visualizations)
        """
        # Group results by model
        results_by_model = {}
        for model, image_paths in metadata_by_model.items():
            model_results = [evaluation_results[img] for img in image_paths if img in evaluation_results]
            results_by_model[model] = model_results

        # Compare models
        comparison = self.aggregator.compare_models(results_by_model)

        # Create comparison dataframe
        comparison_df = self.aggregator.create_comparison_dataframe(comparison)

        # Store current comparison
        self.current_comparison = comparison_df

        # Create visualizations
        visualizations = {}

        # Create heatmap
        heatmap_path = self.visualizer.plot_heatmap(comparison_df)
        visualizations['Model Comparison Heatmap'] = heatmap_path

        # Create radar chart for key metrics
        key_metrics = ['aesthetic_score', 'sharpness', 'noise', 'contrast', 'color_harmony', 'prompt_similarity']
        available_metrics = [m for m in key_metrics if m in comparison_df.columns]
        if available_metrics:
            radar_path = self.visualizer.plot_radar_chart(comparison_df, available_metrics)
            visualizations['Model Comparison Radar Chart'] = radar_path

        # Create bar charts for important metrics
        for metric in ['overall_score', 'aesthetic_score', 'prompt_similarity']:
            if metric in comparison_df.columns:
                bar_path = self.visualizer.plot_metric_comparison(comparison_df, metric)
                visualizations[f'{metric} Comparison'] = bar_path

        return comparison_df, visualizations

    def export_results(self, format='csv'):
        """
        Export current comparison results.

        Args:
            format: export format ('csv', 'excel', or 'html')

        Returns:
            str: path to exported file
        """
        if self.current_comparison is not None:
            return self.visualizer.export_comparison_table(self.current_comparison, format)
        return None

    def generate_report(self, comparison_df, visualizations):
        """
        Generate a comprehensive HTML report.

        Args:
            comparison_df: pandas DataFrame with comparison data
            visualizations: dictionary of visualization paths

        Returns:
            str: path to HTML report
        """
        metrics_list = comparison_df.columns.tolist()
        return self.visualizer.generate_html_report(comparison_df, visualizations, metrics_list)
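

# Example (illustrative, not executed): driving ImageEvaluator without the Gradio UI.
# The "images" directory is a placeholder; every call matches a method defined above.
#
#   evaluator = ImageEvaluator()
#   image_paths = sorted(glob.glob(os.path.join("images", "*.png")))
#   by_model, by_prompt = evaluator.process_images(image_paths)
#   results = evaluator.evaluate_images(image_paths)
#   comparison_df, visualizations = evaluator.compare_models(results, by_model)
#   report_path = evaluator.generate_report(comparison_df, visualizations)
#   csv_path = evaluator.export_results(format='csv')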


# Create Gradio interface
def create_interface():
    """Create and configure the Gradio interface."""
    # Initialize evaluator
    evaluator = ImageEvaluator()

    # Track state
    state = {
        'uploaded_images': [],
        'metadata_by_model': {},
        'metadata_by_prompt': {},
        'evaluation_results': {},
        'comparison_df': None,
        'visualizations': {},
        'report_path': None
    }

    def upload_images(files):
        """Handle image upload and processing."""
        if not files:
            return "No files uploaded. Please select at least one image."

        # Reset state
        state['uploaded_images'] = []
        state['metadata_by_model'] = {}
        state['metadata_by_prompt'] = {}
        state['evaluation_results'] = {}
        state['comparison_df'] = None
        state['visualizations'] = {}
        state['report_path'] = None

        # Process uploaded files. With type="filepath" the File component passes
        # plain path strings; fall back to .name for Gradio versions that pass
        # tempfile wrappers instead.
        image_paths = [f if isinstance(f, str) else f.name for f in files]
        state['uploaded_images'] = image_paths

        # Extract metadata and group images
        # Use a simple progress message instead of a Gradio Progress object
        print("Extracting metadata...")
        metadata_by_model, metadata_by_prompt = evaluator.process_images(image_paths)
        state['metadata_by_model'] = metadata_by_model
        state['metadata_by_prompt'] = metadata_by_prompt

        # Create model summary
        model_summary = []
        for model, images in metadata_by_model.items():
            model_summary.append(f"- {model}: {len(images)} images")

        # Create prompt summary
        prompt_summary = []
        for prompt, images in metadata_by_prompt.items():
            prompt_summary.append(f"- {prompt}: {len(images)} images")

        return (
            f"Processed {len(image_paths)} images.\n\n"
            f"Found {len(metadata_by_model)} models:\n" + "\n".join(model_summary) + "\n\n"
            f"Found {len(metadata_by_prompt)} unique prompts:\n" + "\n".join(prompt_summary)
        )

    def evaluate_images():
        """Evaluate all uploaded images."""
        if not state['uploaded_images']:
            return "No images uploaded. Please upload images first."

        # Evaluate images
        # Use a simple progress message instead of a Gradio Progress object
        print("Evaluating images...")
        evaluation_results = evaluator.evaluate_images(state['uploaded_images'])
        state['evaluation_results'] = evaluation_results

        return f"Evaluated {len(evaluation_results)} images with all metrics."

    def compare_models():
        """Compare models based on evaluation results."""
        if not state['evaluation_results'] or not state['metadata_by_model']:
            return "No evaluation results available. Please evaluate images first.", None, None

        # Compare models
        comparison_df, visualizations = evaluator.compare_models(
            state['evaluation_results'], state['metadata_by_model']
        )
        state['comparison_df'] = comparison_df
        state['visualizations'] = visualizations

        # Generate report
        report_path = evaluator.generate_report(comparison_df, visualizations)
        state['report_path'] = report_path

        # Get visualization paths
        heatmap_path = visualizations.get('Model Comparison Heatmap')
        radar_path = visualizations.get('Model Comparison Radar Chart')

        # Convert DataFrame to markdown for display
        df_markdown = comparison_df.to_markdown()

        return df_markdown, heatmap_path, radar_path

    def export_results(format):
        """Export results in the specified format."""
        if state['comparison_df'] is None:
            return "No comparison results available. Please compare models first."

        export_path = evaluator.export_results(format)
        if export_path:
            return f"Results exported to {export_path}"
        else:
            return "Failed to export results."

    def view_report():
        """View the generated HTML report."""
        if state['report_path'] and os.path.exists(state['report_path']):
            # gr.HTML expects markup, not a file path, so return the report contents
            with open(state['report_path'], 'r', encoding='utf-8') as f:
                return f.read()
        return "<p>No report available. Please compare models first.</p>"

    # Create interface
    with gr.Blocks(title="Image Model Evaluator") as interface:
        gr.Markdown("# Image Model Evaluator")
        gr.Markdown("Upload images generated by different AI models to compare their quality and performance.")

        with gr.Tab("Upload & Process"):
            with gr.Row():
                with gr.Column():
                    upload_input = gr.File(
                        label="Upload Images (PNG format)",
                        file_count="multiple",
                        type="filepath"  # Changed from 'file' to 'filepath'
                    )
                    upload_button = gr.Button("Process Uploaded Images")
                with gr.Column():
                    upload_output = gr.Textbox(
                        label="Processing Results",
                        lines=10,
                        interactive=False
                    )
            evaluate_button = gr.Button("Evaluate Images")
            evaluate_output = gr.Textbox(
                label="Evaluation Status",
                lines=2,
                interactive=False
            )

        with gr.Tab("Compare Models"):
            compare_button = gr.Button("Compare Models")
            with gr.Row():
                comparison_output = gr.Markdown(
                    label="Comparison Results"
                )
            with gr.Row():
                with gr.Column():
                    heatmap_output = gr.Image(
                        label="Model Comparison Heatmap",
                        interactive=False
                    )
                with gr.Column():
                    radar_output = gr.Image(
                        label="Model Comparison Radar Chart",
                        interactive=False
                    )

        with gr.Tab("Export & Report"):
            with gr.Row():
                with gr.Column():
                    export_format = gr.Radio(
                        label="Export Format",
                        choices=["csv", "excel", "html"],
                        value="csv"
                    )
                    export_button = gr.Button("Export Results")
                    export_output = gr.Textbox(
                        label="Export Status",
                        lines=2,
                        interactive=False
                    )
                with gr.Column():
                    report_button = gr.Button("View Full Report")
                    report_output = gr.HTML(
                        label="Full Report"
                    )

        # Set up event handlers
        upload_button.click(
            upload_images,
            inputs=[upload_input],
            outputs=[upload_output]
        )
        evaluate_button.click(
            evaluate_images,
            inputs=[],
            outputs=[evaluate_output]
        )
        compare_button.click(
            compare_models,
            inputs=[],
            outputs=[comparison_output, heatmap_output, radar_output]
        )
        export_button.click(
            export_results,
            inputs=[export_format],
            outputs=[export_output]
        )
        report_button.click(
            view_report,
            inputs=[],
            outputs=[report_output]
        )

    return interface


# Launch the application
if __name__ == "__main__":
    interface = create_interface()
    # No share=True needed on HuggingFace Spaces
    interface.launch()
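    # Note (assumption, not required by this app): for local debugging outside Spaces
    # one could pass explicit options, e.g. interface.launch(server_name="0.0.0.0",
    # server_port=7860); both are standard Gradio launch() parameters. The bare
    # launch() above is typically sufficient on Spaces, which supplies its own
    # host/port settings.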