import os
import sys
import json
import base64
import asyncio
import tempfile
import re
from io import BytesIO
from typing import List, Dict, Any, Optional, Tuple

import cv2
import numpy as np
import torch
import gradio as gr
from PIL import Image, PngImagePlugin, ExifTags
import matplotlib.pyplot as plt
import pandas as pd
from transformers import pipeline, AutoProcessor, AutoModelForImageClassification
from huggingface_hub import hf_hub_download

# Create necessary directories
os.makedirs('/tmp/image_evaluator_uploads', exist_ok=True)
os.makedirs('/tmp/image_evaluator_results', exist_ok=True)

#####################################
#         Model Definitions         #
#####################################

class MLP(torch.nn.Module):
    """A multi-layer perceptron for image feature regression."""

    def __init__(self, input_size: int, batch_norm: bool = True):
        super().__init__()
        self.input_size = input_size
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(self.input_size, 2048),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(2048) if batch_norm else torch.nn.Identity(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(2048, 512),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(512) if batch_norm else torch.nn.Identity(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(256) if batch_norm else torch.nn.Identity(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(128) if batch_norm else torch.nn.Identity(),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(128, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)
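
# Shape sketch (not part of the app flow): the MLP maps a 768-dim CLIP ViT-L/14
# image embedding to a single scalar score. The random tensor below is only a
# placeholder used to illustrate the expected input/output shapes.
#
#   mlp = MLP(input_size=768)
#   mlp.eval()
#   with torch.no_grad():
#       dummy_embedding = torch.randn(1, 768)   # stand-in for a CLIP image feature
#       score = mlp(dummy_embedding).item()     # single float, later clamped to 0-10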

class WaifuScorer:
    """WaifuScorer model that uses CLIP for feature extraction and a custom MLP for scoring."""

    def __init__(self, model_path: str = None, device: str = None, cache_dir: str = None, verbose: bool = False):
        self.verbose = verbose
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        self.dtype = torch.float32
        self.available = False
        try:
            # Try to import CLIP
            try:
                import clip
                self.clip_available = True
            except ImportError:
                print("CLIP not available, using alternative feature extractor")
                self.clip_available = False

            # Set default model path if not provided
            if model_path is None:
                model_path = "Eugeoter/waifu-scorer-v3/model.pth"
                if self.verbose:
                    print(f"Model path not provided. Using default: {model_path}")

            # Download model if not found locally
            if not os.path.isfile(model_path):
                try:
                    username, repo_id, model_name = model_path.split("/")
                    model_path = hf_hub_download(f"{username}/{repo_id}", model_name, cache_dir=cache_dir)
                except Exception as e:
                    print(f"Error downloading model: {e}")
                    # Fallback to local path
                    model_path = os.path.join(os.path.dirname(__file__), "models", "waifu_scorer_v3.pth")
                    if not os.path.exists(model_path):
                        os.makedirs(os.path.dirname(model_path), exist_ok=True)
                        # Create a dummy model for testing
                        self.mlp = MLP(input_size=768)
                        torch.save(self.mlp.state_dict(), model_path)

            if self.verbose:
                print(f"Loading WaifuScorer model from: {model_path}")

            # Initialize MLP model
            self.mlp = MLP(input_size=768)

            # Load state dict
            try:
                if model_path.endswith(".safetensors"):
                    try:
                        from safetensors.torch import load_file
                        state_dict = load_file(model_path)
                    except ImportError:
                        state_dict = torch.load(model_path, map_location=self.device)
                else:
                    state_dict = torch.load(model_path, map_location=self.device)
                self.mlp.load_state_dict(state_dict)
            except Exception as e:
                print(f"Error loading model state dict: {e}")
                # Keep the randomly initialized weights for testing
                pass

            self.mlp.to(self.device)
            self.mlp.eval()

            # Load CLIP model for image preprocessing and feature extraction
            if self.clip_available:
                self.clip_model, self.preprocess = clip.load("ViT-L/14", device=self.device)
            else:
                # Use alternative feature extractor
                from transformers import CLIPProcessor, CLIPModel
                self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
                self.preprocess = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
                self.clip_model.to(self.device)

            self.available = True
        except Exception as e:
            print(f"Unable to initialize WaifuScorer: {e}")
            self.available = False

    def __call__(self, images):
        if not self.available:
            return [5.0] * (len(images) if isinstance(images, list) else 1)  # Default score instead of None

        if isinstance(images, Image.Image):
            images = [images]
        n = len(images)
        # Ensure at least two images for CLIP model compatibility
        if n == 1:
            images = images * 2

        try:
            if self.clip_available:
                # Original CLIP processing
                image_tensors = [self.preprocess(img).unsqueeze(0) for img in images]
                image_batch = torch.cat(image_tensors).to(self.device)
                image_features = self.clip_model.encode_image(image_batch)
            else:
                # Alternative processing with Transformers CLIP
                inputs = self.preprocess(images=images, return_tensors="pt").to(self.device)
                image_features = self.clip_model.get_image_features(**inputs)

            # Normalize features
            norm = image_features.norm(2, dim=-1, keepdim=True)
            norm[norm == 0] = 1
            im_emb = (image_features / norm).to(device=self.device, dtype=self.dtype)

            predictions = self.mlp(im_emb)
            scores = predictions.clamp(0, 10).cpu().numpy().reshape(-1).tolist()
            return scores[:n]
        except Exception as e:
            print(f"Error in WaifuScorer inference: {e}")
            return [5.0] * n  # Default score instead of None
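
# Usage sketch ("sample.png" is a placeholder path): scoring depends on the optional
# `clip` package and on the Hugging Face checkpoint being reachable; when either is
# missing the scorer falls back to a default score of 5.0 per image.
#
#   scorer = WaifuScorer(verbose=True)
#   scores = scorer([Image.open("sample.png")])
#   print(scores)   # e.g. [6.42] on a 0-10 scale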

class AestheticPredictor:
    """Aesthetic Predictor using SigLIP or other models."""

    def __init__(self, model_name="SmilingWolf/aesthetic-predictor-v2-5", device=None):
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_name = model_name
        self.available = False
        try:
            print(f"Loading Aesthetic Predictor: {model_name}")
            self.processor = AutoProcessor.from_pretrained(model_name)
            self.model = AutoModelForImageClassification.from_pretrained(model_name)
            if torch.cuda.is_available() and self.device == 'cuda':
                self.model = self.model.to(torch.bfloat16).cuda()
            else:
                self.model = self.model.to(self.device)
            self.model.eval()
            self.available = True
        except Exception as e:
            print(f"Error loading Aesthetic Predictor: {e}")
            self.available = False

    def inference(self, images):
        if not self.available:
            return [5.0] * (len(images) if isinstance(images, list) else 1)  # Default score instead of None
        try:
            if isinstance(images, list):
                images_rgb = [img.convert("RGB") for img in images]
                pixel_values = self.processor(images=images_rgb, return_tensors="pt").pixel_values
                if torch.cuda.is_available() and self.device == 'cuda':
                    pixel_values = pixel_values.to(torch.bfloat16).cuda()
                else:
                    pixel_values = pixel_values.to(self.device)
                with torch.inference_mode():
                    scores = self.model(pixel_values).logits.squeeze().float().cpu().numpy()
                if scores.ndim == 0:
                    scores = np.array([scores])
                # Scale scores to 0-10 range
                scores = scores * 10.0
                return scores.tolist()
            else:
                return self.inference([images])[0]
        except Exception as e:
            print(f"Error in Aesthetic Predictor inference: {e}")
            if isinstance(images, list):
                return [5.0] * len(images)  # Default score instead of None
            else:
                return 5.0  # Default score instead of None

class AnimeAestheticEvaluator:
    """Anime Aesthetic Evaluator using ONNX model."""

    def __init__(self, model_path=None, device=None):
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        self.available = False
        try:
            import onnxruntime as rt

            # Set default model path if not provided
            if model_path is None:
                try:
                    model_path = hf_hub_download(repo_id="skytnt/anime-aesthetic", filename="model.onnx")
                except Exception as e:
                    print(f"Error downloading anime aesthetic model: {e}")
                    # Fallback to local path
                    model_path = os.path.join(os.path.dirname(__file__), "models", "anime_aesthetic.onnx")
                    if not os.path.exists(model_path):
                        print("Model not found and couldn't be downloaded")
                        self.available = False
                        return

            # Select provider based on device
            if self.device == 'cuda' and 'CUDAExecutionProvider' in rt.get_available_providers():
                providers = ['CUDAExecutionProvider']
            else:
                providers = ['CPUExecutionProvider']
            self.model = rt.InferenceSession(model_path, providers=providers)
            self.available = True
        except Exception as e:
            print(f"Error initializing Anime Aesthetic Evaluator: {e}")
            self.available = False

    def predict(self, images):
        if not self.available:
            return [5.0] * (len(images) if isinstance(images, list) else 1)  # Default score instead of None

        if isinstance(images, Image.Image):
            images = [images]

        try:
            results = []
            for img in images:
                img_np = np.array(img).astype(np.float32) / 255.0
                s = 768
                h, w = img_np.shape[:2]
                if h > w:
                    new_h, new_w = s, int(s * w / h)
                else:
                    new_h, new_w = int(s * h / w), s
                resized = cv2.resize(img_np, (new_w, new_h))

                # Center the resized image in a square canvas
                canvas = np.zeros((s, s, 3), dtype=np.float32)
                pad_h = (s - new_h) // 2
                pad_w = (s - new_w) // 2
                canvas[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = resized

                # Prepare input for model
                input_tensor = np.transpose(canvas, (2, 0, 1))[np.newaxis, :]

                # Run inference
                pred = self.model.run(None, {"img": input_tensor})[0].item()
                # Scale to 0-10
                pred = pred * 10.0
                results.append(pred)
            return results
        except Exception as e:
            print(f"Error in Anime Aesthetic prediction: {e}")
            return [5.0] * len(images)  # Default score instead of None

#####################################
#     Technical Evaluator Class     #
#####################################

class TechnicalEvaluator:
    """
    Evaluator for basic technical image quality metrics.
    Measures sharpness, noise, artifacts, and other technical aspects.
    """

    def __init__(self, config=None):
        self.config = config or {}
        self.config.setdefault('laplacian_ksize', 3)
        self.config.setdefault('blur_threshold', 100)
        self.config.setdefault('noise_threshold', 0.05)

    def evaluate(self, image_path_or_pil):
        """
        Evaluate technical aspects of an image.

        Args:
            image_path_or_pil: Path to the image file or PIL Image.

        Returns:
            dict: Dictionary containing technical evaluation scores.
        """
        try:
            # Load image
            if isinstance(image_path_or_pil, str):
                img = cv2.imread(image_path_or_pil)
                if img is None:
                    return {
                        'error': 'Failed to load image',
                        'overall_technical': 0.0
                    }
            else:
                # Convert PIL Image to OpenCV format
                img = cv2.cvtColor(np.array(image_path_or_pil), cv2.COLOR_RGB2BGR)

            # Convert to grayscale for some calculations
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

            # Calculate sharpness using Laplacian variance
            laplacian = cv2.Laplacian(gray, cv2.CV_64F, ksize=self.config['laplacian_ksize'])
            sharpness_score = np.var(laplacian) / 10000  # Normalize
            sharpness_score = min(1.0, sharpness_score)  # Cap at 1.0

            # Calculate noise level
            # Using a simple method based on standard deviation in smooth areas
            blur = cv2.GaussianBlur(gray, (11, 11), 0)
            diff = cv2.absdiff(gray, blur)
            noise_level = np.std(diff) / 255.0
            noise_score = 1.0 - min(1.0, noise_level / self.config['noise_threshold'])

            # Check for compression artifacts
            edges = cv2.Canny(gray, 100, 200)
            artifact_score = 1.0 - (np.count_nonzero(edges) / (gray.shape[0] * gray.shape[1]))
            artifact_score = max(0.0, min(1.0, artifact_score * 2))  # Adjust range

            # Calculate color range and saturation
            hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
            saturation = hsv[:, :, 1]
            saturation_score = np.mean(saturation) / 255.0

            # Calculate contrast
            min_val, max_val, _, _ = cv2.minMaxLoc(gray)
            contrast_score = (max_val - min_val) / 255.0

            # Calculate overall technical score (weighted average)
            overall_technical = (
                0.3 * sharpness_score +
                0.2 * noise_score +
                0.2 * artifact_score +
                0.15 * saturation_score +
                0.15 * contrast_score
            )

            # Scale to 0-10 range for consistency with other metrics
            return {
                'sharpness': float(sharpness_score * 10),
                'noise': float(noise_score * 10),
                'artifacts': float(artifact_score * 10),
                'saturation': float(saturation_score * 10),
                'contrast': float(contrast_score * 10),
                'overall_technical': float(overall_technical * 10)
            }
        except Exception as e:
            print(f"Error in technical evaluation: {e}")
            return {
                'error': str(e),
                'overall_technical': 5.0  # Default score instead of 0
            }

    def get_metadata(self):
        """
        Return metadata about this evaluator.

        Returns:
            dict: Dictionary containing metadata about the evaluator.
        """
        return {
            'id': 'technical',
            'name': 'Technical Metrics',
            'description': 'Evaluates basic technical aspects of image quality including sharpness, noise, artifacts, saturation, and contrast.',
            'version': '1.0',
            'metrics': [
                {'id': 'sharpness', 'name': 'Sharpness', 'description': 'Measures image clarity and detail'},
                {'id': 'noise', 'name': 'Noise', 'description': 'Measures absence of unwanted variations'},
                {'id': 'artifacts', 'name': 'Artifacts', 'description': 'Measures absence of compression artifacts'},
                {'id': 'saturation', 'name': 'Saturation', 'description': 'Measures color intensity'},
                {'id': 'contrast', 'name': 'Contrast', 'description': 'Measures difference between light and dark areas'},
                {'id': 'overall_technical', 'name': 'Overall Technical', 'description': 'Combined technical quality score'}
            ]
        }
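
# Usage sketch (illustrative only; "generated.png" is a placeholder path). All
# sub-scores and the weighted overall score are reported on a 0-10 scale.
#
#   tech = TechnicalEvaluator()
#   report = tech.evaluate("generated.png")
#   print(report['sharpness'], report['noise'], report['overall_technical'])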

#####################################
#     Aesthetic Evaluator Class     #
#####################################

class AestheticEvaluator:
    """
    Evaluator for aesthetic image quality.
    Uses a combination of rule-based metrics and ML models.
    """

    def __init__(self, config=None):
        self.config = config or {}
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Initialize aesthetic predictor
        try:
            self.aesthetic_predictor = AestheticPredictor(device=self.device)
        except Exception as e:
            print(f"Error initializing Aesthetic Predictor: {e}")
            self.aesthetic_predictor = None

        # Initialize aesthetic shadow model
        try:
            self.aesthetic_shadow = pipeline(
                "image-classification",
                model="NeoChen1024/aesthetic-shadow-v2-backup",
                device=self.device
            )
        except Exception as e:
            print(f"Error initializing Aesthetic Shadow: {e}")
            self.aesthetic_shadow = None

    def evaluate(self, image_path_or_pil):
        """
        Evaluate aesthetic aspects of an image.

        Args:
            image_path_or_pil: Path to the image file or PIL Image.

        Returns:
            dict: Dictionary containing aesthetic evaluation scores.
        """
        try:
            # Load image
            if isinstance(image_path_or_pil, str):
                img = Image.open(image_path_or_pil).convert("RGB")
            else:
                img = image_path_or_pil.convert("RGB")

            # Convert to numpy array for calculations
            img_np = np.array(img)

            # Calculate color harmony using standard deviation of colors
            r, g, b = img_np[:, :, 0], img_np[:, :, 1], img_np[:, :, 2]
            color_std = (np.std(r) + np.std(g) + np.std(b)) / 3
            color_harmony = min(1.0, color_std / 80.0)  # Normalize

            # Calculate composition score using rule of thirds
            h, w = img_np.shape[:2]
            third_h, third_w = h // 3, w // 3

            # Create a rule of thirds grid mask
            grid_mask = np.zeros((h, w))
            for i in range(1, 3):
                grid_mask[third_h * i - 5:third_h * i + 5, :] = 1
                grid_mask[:, third_w * i - 5:third_w * i + 5] = 1

            # Convert to grayscale for edge detection
            gray = np.mean(img_np, axis=2).astype(np.uint8)

            # Simple edge detection
            edges = np.abs(np.diff(gray, axis=0, prepend=0)) + np.abs(np.diff(gray, axis=1, prepend=0))
            edges = edges > 30  # Threshold

            # Calculate how many edges fall on the rule of thirds lines
            thirds_alignment = np.sum(edges * grid_mask) / max(1, np.sum(edges))
            composition_score = min(1.0, thirds_alignment * 3)  # Scale up for better distribution

            # Calculate visual interest using entropy
            hist_r = np.histogram(r, bins=256, range=(0, 256))[0] / (h * w)
            hist_g = np.histogram(g, bins=256, range=(0, 256))[0] / (h * w)
            hist_b = np.histogram(b, bins=256, range=(0, 256))[0] / (h * w)
            entropy_r = -np.sum(hist_r[hist_r > 0] * np.log2(hist_r[hist_r > 0]))
            entropy_g = -np.sum(hist_g[hist_g > 0] * np.log2(hist_g[hist_g > 0]))
            entropy_b = -np.sum(hist_b[hist_b > 0] * np.log2(hist_b[hist_b > 0]))
            entropy = (entropy_r + entropy_g + entropy_b) / 3
            visual_interest = min(1.0, entropy / 7.5)  # Normalize

            # Get ML model predictions
            aesthetic_predictor_score = 0.5  # Default value
            aesthetic_shadow_score = 0.5  # Default value

            if self.aesthetic_predictor and self.aesthetic_predictor.available:
                try:
                    aesthetic_predictor_score = self.aesthetic_predictor.inference(img) / 10.0  # Scale to 0-1
                except Exception as e:
                    print(f"Error in Aesthetic Predictor: {e}")

            if self.aesthetic_shadow:
                try:
                    shadow_result = self.aesthetic_shadow(img)
                    # Extract score from result
                    if isinstance(shadow_result, list) and len(shadow_result) > 0:
                        shadow_score = shadow_result[0]['score']
                        aesthetic_shadow_score = shadow_score
                except Exception as e:
                    print(f"Error in Aesthetic Shadow: {e}")

            # Calculate overall aesthetic score (weighted average)
            overall_aesthetic = (
                0.2 * color_harmony +
                0.15 * composition_score +
                0.15 * visual_interest +
                0.25 * aesthetic_predictor_score +
                0.25 * aesthetic_shadow_score
            )

            # Scale to 0-10 range for consistency with other metrics
            return {
                'color_harmony': float(color_harmony * 10),
                'composition': float(composition_score * 10),
                'visual_interest': float(visual_interest * 10),
                'aesthetic_predictor': float(aesthetic_predictor_score * 10),
                'aesthetic_shadow': float(aesthetic_shadow_score * 10),
                'overall_aesthetic': float(overall_aesthetic * 10)
            }
        except Exception as e:
            print(f"Error in aesthetic evaluation: {e}")
            return {
                'error': str(e),
                'overall_aesthetic': 5.0  # Default score instead of 0
            }

    def get_metadata(self):
        """
        Return metadata about this evaluator.

        Returns:
            dict: Dictionary containing metadata about the evaluator.
        """
        return {
            'id': 'aesthetic',
            'name': 'Aesthetic Assessment',
            'description': 'Evaluates aesthetic qualities of images including color harmony, composition, and visual interest.',
            'version': '1.0',
            'metrics': [
                {'id': 'color_harmony', 'name': 'Color Harmony', 'description': 'Measures how well colors work together'},
                {'id': 'composition', 'name': 'Composition', 'description': 'Measures adherence to compositional principles like rule of thirds'},
                {'id': 'visual_interest', 'name': 'Visual Interest', 'description': 'Measures how visually engaging the image is'},
                {'id': 'aesthetic_predictor', 'name': 'Aesthetic Predictor', 'description': 'Score from Aesthetic Predictor V2.5 model'},
                {'id': 'aesthetic_shadow', 'name': 'Aesthetic Shadow', 'description': 'Score from Aesthetic Shadow model'},
                {'id': 'overall_aesthetic', 'name': 'Overall Aesthetic', 'description': 'Combined aesthetic quality score'}
            ]
        }

#####################################
#       Anime Evaluator Class       #
#####################################

class AnimeEvaluator:
    """
    Specialized evaluator for anime-style images.
    Focuses on line quality, character design, style consistency, and other anime-specific attributes.
    """

    def __init__(self, config=None):
        self.config = config or {}
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Initialize anime aesthetic model
        try:
            self.anime_aesthetic = AnimeAestheticEvaluator(device=self.device)
        except Exception as e:
            print(f"Error initializing Anime Aesthetic: {e}")
            self.anime_aesthetic = None

        # Initialize waifu scorer
        try:
            self.waifu_scorer = WaifuScorer(device=self.device, verbose=True)
        except Exception as e:
            print(f"Error initializing Waifu Scorer: {e}")
            self.waifu_scorer = None

    def evaluate(self, image_path_or_pil):
        """
        Evaluate anime-specific aspects of an image.

        Args:
            image_path_or_pil: Path to the image file or PIL Image.

        Returns:
            dict: Dictionary containing anime-style evaluation scores.
        """
        try:
            # Load image
            if isinstance(image_path_or_pil, str):
                img = Image.open(image_path_or_pil).convert("RGB")
            else:
                img = image_path_or_pil.convert("RGB")
            img_np = np.array(img)

            # Line quality assessment
            gray = np.mean(img_np, axis=2).astype(np.uint8)

            # Calculate gradients for edge detection
            gx = np.abs(np.diff(gray, axis=1, prepend=0))
            gy = np.abs(np.diff(gray, axis=0, prepend=0))

            # Combine gradients
            edges = np.maximum(gx, gy)

            # Strong edges are characteristic of anime
            strong_edges = edges > 50
            edge_ratio = np.sum(strong_edges) / (gray.shape[0] * gray.shape[1])

            # Line quality score - anime typically has a higher proportion of strong edges
            line_quality = min(1.0, edge_ratio * 20)  # Scale appropriately

            # Color palette assessment
            pixels = img_np.reshape(-1, 3)
            sample_size = min(10000, pixels.shape[0])
            indices = np.random.choice(pixels.shape[0], sample_size, replace=False)
            sampled_pixels = pixels[indices]

            # Calculate color diversity (simplified)
            color_std = np.std(sampled_pixels, axis=0)
            color_diversity = np.mean(color_std) / 128.0  # Normalize

            # Anime often has a good balance of diversity but not excessive
            color_score = 1.0 - abs(color_diversity - 0.5) * 2  # Penalize too high or too low

            # Get ML model predictions
            anime_aesthetic_score = 0.5  # Default value
            waifu_score = 0.5  # Default value

            if self.anime_aesthetic and self.anime_aesthetic.available:
                try:
                    anime_scores = self.anime_aesthetic.predict([img])
                    anime_aesthetic_score = anime_scores[0] / 10.0  # Scale to 0-1
                except Exception as e:
                    print(f"Error in Anime Aesthetic: {e}")

            if self.waifu_scorer and self.waifu_scorer.available:
                try:
                    waifu_scores = self.waifu_scorer([img])
                    waifu_score = waifu_scores[0] / 10.0  # Scale to 0-1
                except Exception as e:
                    print(f"Error in Waifu Scorer: {e}")

            # Style consistency assessment
            hsv = np.array(img.convert('HSV'))
            saturation = hsv[:, :, 1]
            value = hsv[:, :, 2]

            # Calculate statistics
            sat_mean = np.mean(saturation) / 255.0
            val_mean = np.mean(value) / 255.0

            # Anime often has higher saturation and controlled brightness
            sat_score = 1.0 - abs(sat_mean - 0.7) * 2  # Ideal around 0.7
            val_score = 1.0 - abs(val_mean - 0.6) * 2  # Ideal around 0.6
            style_consistency = (sat_score + val_score) / 2

            # Overall anime score (weighted average)
            overall_anime = (
                0.2 * line_quality +
                0.15 * color_score +
                0.3 * waifu_score +
                0.2 * anime_aesthetic_score +
                0.15 * style_consistency
            )

            # Scale to 0-10 range for consistency with other metrics
            return {
                'line_quality': float(line_quality * 10),
                'color_palette': float(color_score * 10),
                'character_quality': float(waifu_score * 10),
                'anime_aesthetic': float(anime_aesthetic_score * 10),
                'style_consistency': float(style_consistency * 10),
                'overall_anime': float(overall_anime * 10)
            }
        except Exception as e:
            print(f"Error in anime evaluation: {e}")
            return {
                'error': str(e),
                'overall_anime': 5.0  # Default score instead of 0
            }

    def get_metadata(self):
        """
        Return metadata about this evaluator.

        Returns:
            dict: Dictionary containing metadata about the evaluator.
        """
        return {
            'id': 'anime_specialized',
            'name': 'Anime Style Evaluator',
            'description': 'Specialized evaluator for anime-style images, focusing on line quality, color palette, character design, and style consistency.',
            'version': '1.0',
            'metrics': [
                {'id': 'line_quality', 'name': 'Line Quality', 'description': 'Measures clarity and quality of line work'},
                {'id': 'color_palette', 'name': 'Color Palette', 'description': 'Evaluates color choices and harmony for anime style'},
                {'id': 'character_quality', 'name': 'Character Quality', 'description': 'Assesses character design and rendering using Waifu Scorer'},
                {'id': 'anime_aesthetic', 'name': 'Anime Aesthetic', 'description': 'Score from specialized anime aesthetic model'},
                {'id': 'style_consistency', 'name': 'Style Consistency', 'description': 'Measures adherence to anime style conventions'},
                {'id': 'overall_anime', 'name': 'Overall Anime Quality', 'description': 'Combined anime-specific quality score'}
            ]
        }

#####################################
#      Metadata Manager Class       #
#####################################

class MetadataManager:
    """
    Manager for extracting and parsing image metadata.
    """

    def __init__(self):
        pass

    def extract_metadata(self, image_path_or_pil):
        """
        Extract metadata from an image.

        Args:
            image_path_or_pil: Path to the image file or PIL Image.

        Returns:
            dict: Dictionary containing extracted metadata.
        """
        try:
            # Load image if path is provided
            if isinstance(image_path_or_pil, str):
                img = Image.open(image_path_or_pil)
            else:
                img = image_path_or_pil

            # Initialize metadata dictionary
            metadata = {
                'has_metadata': False,
                'prompt': None,
                'negative_prompt': None,
                'steps': None,
                'sampler': None,
                'cfg_scale': None,
                'seed': None,
                'size': None,
                'model': None,
                'raw_metadata': None
            }

            # Check for PNG info metadata (Stable Diffusion WebUI)
            if 'parameters' in img.info:
                metadata['has_metadata'] = True
                metadata['raw_metadata'] = img.info['parameters']

                # Parse parameters
                params = img.info['parameters']

                # Extract prompt and negative prompt
                neg_prompt_prefix = "Negative prompt:"
                if neg_prompt_prefix in params:
                    parts = params.split(neg_prompt_prefix, 1)
                    metadata['prompt'] = parts[0].strip()
                    rest = parts[1].strip()

                    # Find the next parameter after negative prompt
                    next_param_match = re.search(r'\n(Steps: |Sampler: |CFG scale: |Seed: |Size: |Model: )', rest)
                    if next_param_match:
                        neg_end = next_param_match.start()
                        metadata['negative_prompt'] = rest[:neg_end].strip()
                        rest = rest[neg_end:].strip()
                    else:
                        metadata['negative_prompt'] = rest
                else:
                    metadata['prompt'] = params

                # Extract other parameters
                for param in ['Steps', 'Sampler', 'CFG scale', 'Seed', 'Size', 'Model']:
                    param_match = re.search(rf'{param}: ([^,\n]+)', params)
                    if param_match:
                        param_key = param.lower().replace(' ', '_')
                        metadata[param_key] = param_match.group(1).strip()

            # Check for EXIF metadata
            elif hasattr(img, '_getexif') and img._getexif():
                exif = {
                    ExifTags.TAGS[k]: v
                    for k, v in img._getexif().items()
                    if k in ExifTags.TAGS
                }
                if 'ImageDescription' in exif and exif['ImageDescription']:
                    metadata['has_metadata'] = True
                    metadata['raw_metadata'] = exif['ImageDescription']

                    # Try to parse as JSON
                    try:
                        json_data = json.loads(exif['ImageDescription'])
                        if 'prompt' in json_data:
                            metadata['prompt'] = json_data['prompt']
                        if 'negative_prompt' in json_data:
                            metadata['negative_prompt'] = json_data['negative_prompt']

                        # Map other parameters
                        param_mapping = {
                            'steps': 'steps',
                            'sampler': 'sampler',
                            'cfg_scale': 'cfg_scale',
                            'seed': 'seed',
                            'width': 'width',
                            'height': 'height',
                            'model': 'model'
                        }
                        for json_key, meta_key in param_mapping.items():
                            if json_key in json_data:
                                metadata[meta_key] = json_data[json_key]

                        # Combine width and height for size
                        if 'width' in json_data and 'height' in json_data:
                            metadata['size'] = f"{json_data['width']}x{json_data['height']}"
                    except json.JSONDecodeError:
                        # Not JSON, try to parse as text
                        desc = exif['ImageDescription']
                        metadata['prompt'] = desc

            # If no metadata found but image has dimensions, add them
            if not metadata['size'] and hasattr(img, 'width') and hasattr(img, 'height'):
                metadata['size'] = f"{img.width}x{img.height}"

            return metadata
        except Exception as e:
            print(f"Error extracting metadata: {e}")
            return {
                'has_metadata': False,
                'error': str(e)
            }

    def update_metadata(self, image, new_metadata):
        """
        Update the metadata in an image.

        Args:
            image: PIL Image.
            new_metadata: New metadata string.

        Returns:
            PIL Image: Image with updated metadata.
        """
        if image:
            try:
                # Create a PngInfo object to store metadata
                pnginfo = PngImagePlugin.PngInfo()
                pnginfo.add_text("parameters", new_metadata)

                # Save the image to a BytesIO object with the updated metadata
                output_bytes = BytesIO()
                image.save(output_bytes, format="PNG", pnginfo=pnginfo)
                output_bytes.seek(0)

                # Re-open the image from the BytesIO object
                updated_image = Image.open(output_bytes)
                return updated_image
            except Exception as e:
                print(f"Error updating metadata: {e}")
                return image
        else:
            return None
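
# Parsing sketch: extract_metadata() understands the Stable Diffusion WebUI
# "parameters" text chunk embedded in PNGs. The string below is a made-up example,
# shown only to illustrate which fields the regexes pull out.
#
#   mm = MetadataManager()
#   demo = Image.new("RGB", (64, 64))
#   demo.info['parameters'] = (
#       "1girl, masterpiece\n"
#       "Negative prompt: lowres, blurry\n"
#       "Steps: 28, Sampler: Euler a, CFG scale: 7, Seed: 123456, Size: 512x768, Model: example"
#   )
#   meta = mm.extract_metadata(demo)
#   # meta['prompt'] == "1girl, masterpiece", meta['steps'] == "28", meta['size'] == "512x768"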

#####################################
#      Evaluator Manager Class      #
#####################################

class EvaluatorManager:
    """
    Manager class for handling multiple evaluators.
    Provides a unified interface for evaluating images with different metrics.
    """

    def __init__(self):
        """Initialize the evaluator manager with available evaluators."""
        self.evaluators = {}
        self.metadata_manager = MetadataManager()
        self._register_default_evaluators()

    def _register_default_evaluators(self):
        """Register the default set of evaluators."""
        self.register_evaluator(TechnicalEvaluator())
        self.register_evaluator(AestheticEvaluator())
        self.register_evaluator(AnimeEvaluator())

    def register_evaluator(self, evaluator):
        """
        Register a new evaluator.

        Args:
            evaluator: The evaluator to register.
        """
        metadata = evaluator.get_metadata()
        self.evaluators[metadata['id']] = evaluator

    def get_available_evaluators(self):
        """
        Get a list of available evaluators.

        Returns:
            list: List of evaluator metadata.
        """
        return [evaluator.get_metadata() for evaluator in self.evaluators.values()]

    def evaluate_image(self, image_path_or_pil, evaluator_ids=None):
        """
        Evaluate an image using specified evaluators.

        Args:
            image_path_or_pil: Path to the image file or PIL Image.
            evaluator_ids: List of evaluator IDs to use.
                If None, all available evaluators will be used.

        Returns:
            dict: Dictionary containing evaluation results from each evaluator.
        """
        # Check if image exists
        if isinstance(image_path_or_pil, str) and not os.path.exists(image_path_or_pil):
            return {'error': f'Image file not found: {image_path_or_pil}'}

        if evaluator_ids is None:
            evaluator_ids = list(self.evaluators.keys())

        results = {}

        # Extract metadata
        metadata = self.metadata_manager.extract_metadata(image_path_or_pil)
        results['metadata'] = metadata

        # Evaluate with each evaluator
        for evaluator_id in evaluator_ids:
            if evaluator_id in self.evaluators:
                results[evaluator_id] = self.evaluators[evaluator_id].evaluate(image_path_or_pil)
            else:
                results[evaluator_id] = {'error': f'Evaluator not found: {evaluator_id}'}

        return results

    def batch_evaluate_images(self, image_paths_or_pils, evaluator_ids=None):
        """
        Evaluate multiple images using specified evaluators.

        Args:
            image_paths_or_pils: List of paths to image files or PIL Images.
            evaluator_ids: List of evaluator IDs to use.
                If None, all available evaluators will be used.

        Returns:
            list: List of dictionaries containing evaluation results for each image.
        """
        return [self.evaluate_image(path_or_pil, evaluator_ids) for path_or_pil in image_paths_or_pils]

    def compare_models(self, model_results):
        """
        Compare different models based on evaluation results.

        Args:
            model_results: Dictionary mapping model names to their evaluation results.

        Returns:
            dict: Comparison results including rankings and best model.
        """
        if not model_results:
            return {'error': 'No model results provided for comparison'}

        # Calculate average scores for each model across all images and evaluators
        model_scores = {}
        for model_name, image_results in model_results.items():
            model_scores[model_name] = {
                'technical': 0.0,
                'aesthetic': 0.0,
                'anime_specialized': 0.0,
                'overall': 0.0
            }
            image_count = len(image_results)
            if image_count == 0:
                continue

            # Sum up scores across all images
            for image_id, evaluations in image_results.items():
                if 'technical' in evaluations and 'overall_technical' in evaluations['technical']:
                    model_scores[model_name]['technical'] += evaluations['technical']['overall_technical']
                if 'aesthetic' in evaluations and 'overall_aesthetic' in evaluations['aesthetic']:
                    model_scores[model_name]['aesthetic'] += evaluations['aesthetic']['overall_aesthetic']
                if 'anime_specialized' in evaluations and 'overall_anime' in evaluations['anime_specialized']:
                    model_scores[model_name]['anime_specialized'] += evaluations['anime_specialized']['overall_anime']

            # Calculate averages
            model_scores[model_name]['technical'] /= image_count
            model_scores[model_name]['aesthetic'] /= image_count
            model_scores[model_name]['anime_specialized'] /= image_count

            # Calculate overall score (weighted average of all metrics)
            model_scores[model_name]['overall'] = (
                0.3 * model_scores[model_name]['technical'] +
                0.4 * model_scores[model_name]['aesthetic'] +
                0.3 * model_scores[model_name]['anime_specialized']
            )

        # Rank models by overall score
        rankings = sorted(
            [(model, scores['overall']) for model, scores in model_scores.items()],
            key=lambda x: x[1],
            reverse=True
        )

        # Format rankings
        formatted_rankings = [
            {'rank': i + 1, 'model': model, 'score': score}
            for i, (model, score) in enumerate(rankings)
        ]

        # Determine best model
        best_model = rankings[0][0] if rankings else None

        # Format comparison metrics
        comparison_metrics = {
            'technical': {model: scores['technical'] for model, scores in model_scores.items()},
            'aesthetic': {model: scores['aesthetic'] for model, scores in model_scores.items()},
            'anime_specialized': {model: scores['anime_specialized'] for model, scores in model_scores.items()},
            'overall': {model: scores['overall'] for model, scores in model_scores.items()}
        }

        return {
            'best_model': best_model,
            'rankings': formatted_rankings,
            'comparison_metrics': comparison_metrics
        }
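
# Usage sketch (paths and model names below are placeholders). evaluate_image()
# returns one dict per evaluator plus a 'metadata' entry; compare_models() expects
# {model_name: {image_id: evaluation_dict}}, the shape the Gradio handlers build.
#
#   manager = EvaluatorManager()
#   single = manager.evaluate_image("sample.png", ['technical', 'aesthetic'])
#   comparison = manager.compare_models({
#       "model_a": {"img_0": single},
#       "model_b": {"img_0": manager.evaluate_image("other.png")},
#   })
#   # comparison['best_model'], comparison['rankings'], comparison['comparison_metrics']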

#####################################
#        Model Manager Class        #
#####################################

class ModelManager:
    """
    Manages model loading and processing requests using a queue.
    """

    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")

        # Initialize evaluator manager
        self.evaluator_manager = EvaluatorManager()

        # Initialize processing queue
        self.processing_queue = asyncio.Queue()
        self.worker_task = None

        # Create temp directory
        self.temp_dir = tempfile.mkdtemp()

    async def start_worker(self):
        """Start the background worker task."""
        if self.worker_task is None:
            self.worker_task = asyncio.create_task(self._worker())

    async def _worker(self):
        """Background worker to process image evaluation requests from the queue."""
        while True:
            request = await self.processing_queue.get()
            if request is None:  # Shutdown signal
                self.processing_queue.task_done()
                break
            try:
                results = await self._process_request(request)
                request['results_future'].set_result(results)  # Fulfill the future with results
            except Exception as e:
                request['results_future'].set_exception(e)  # Set exception if processing fails
            finally:
                self.processing_queue.task_done()

    async def submit_request(self, request_data):
        """Submit a new image processing request to the queue."""
        results_future = asyncio.Future()  # Future to hold the results
        request = {**request_data, 'results_future': results_future}
        await self.processing_queue.put(request)
        return await results_future  # Wait for and return results

    async def _process_request(self, request):
        """Process a single image evaluation request."""
        file_paths = request['file_paths']
        auto_batch = request['auto_batch']
        manual_batch_size = request['manual_batch_size']
        selected_evaluators = request['selected_evaluators']

        log_events = []
        images = []
        file_names = []
        valid_paths = []
        final_results = []

        # Prepare images and file names; keep paths aligned with successfully loaded images
        total_files = len(file_paths)
        log_events.append(f"Starting to load {total_files} images...")
        for f in file_paths:
            try:
                img = Image.open(f).convert("RGB")
                images.append(img)
                file_names.append(os.path.basename(f))
                valid_paths.append(f)
            except Exception as e:
                log_events.append(f"Error opening {f}: {e}")

        if not images:
            log_events.append("No valid images loaded.")
            return [], log_events, 0, manual_batch_size

        log_events.append("Images loaded. Determining batch size...")
        try:
            manual_batch_size = int(manual_batch_size) if manual_batch_size is not None else 1
        except ValueError:
            manual_batch_size = 1
            log_events.append("Invalid manual batch size. Defaulting to 1.")

        optimal_batch = self.auto_tune_batch_size(images) if auto_batch else manual_batch_size
        log_events.append(f"Using batch size: {optimal_batch}")

        total_images = len(images)
        for i in range(0, total_images, optimal_batch):
            batch_images = images[i:i + optimal_batch]
            batch_file_paths = valid_paths[i:i + optimal_batch]
            batch_file_names = file_names[i:i + optimal_batch]
            batch_index = i // optimal_batch + 1
            log_events.append(f"Processing batch {batch_index}: images {i + 1} to {min(i + optimal_batch, total_images)}")

            # Process each image in the batch
            for j, (img, img_path, img_name) in enumerate(zip(batch_images, batch_file_paths, batch_file_names)):
                # Evaluate image with selected evaluators
                evaluation_results = self.evaluator_manager.evaluate_image(img_path, selected_evaluators)

                # Extract metadata
                metadata = evaluation_results.get('metadata', {})

                # Calculate final score
                scores_to_average = []
                for evaluator_id in selected_evaluators:
                    if evaluator_id in evaluation_results:
                        if evaluator_id == 'technical' and 'overall_technical' in evaluation_results[evaluator_id]:
                            scores_to_average.append(evaluation_results[evaluator_id]['overall_technical'])
                        elif evaluator_id == 'aesthetic' and 'overall_aesthetic' in evaluation_results[evaluator_id]:
                            scores_to_average.append(evaluation_results[evaluator_id]['overall_aesthetic'])
                        elif evaluator_id == 'anime_specialized' and 'overall_anime' in evaluation_results[evaluator_id]:
                            scores_to_average.append(evaluation_results[evaluator_id]['overall_anime'])
                final_score = float(np.clip(np.mean(scores_to_average), 0.0, 10.0)) if scores_to_average else 5.0

                # Create thumbnail
                thumbnail = img.copy()
                thumbnail.thumbnail((200, 200))

                # Create result
                result = {
                    'file_name': img_name,
                    'file_path': img_path,
                    'img_data': self.image_to_base64(thumbnail),
                    'final_score': final_score,
                    'metadata': metadata,
                }

                # Add evaluator results
                for evaluator_id in selected_evaluators:
                    if evaluator_id in evaluation_results:
                        result[evaluator_id] = evaluation_results[evaluator_id]

                final_results.append(result)

        log_events.append("All images processed.")
        return final_results, log_events, 100, optimal_batch

    def image_to_base64(self, image: Image.Image) -> str:
        """Convert PIL Image to base64 encoded JPEG string."""
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    def auto_tune_batch_size(self, images: list) -> int:
        """Automatically determine the optimal batch size for processing."""
        # For simplicity, use a fixed batch size
        # In a real implementation, this would test different batch sizes
        return min(4, len(images))
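
# Async usage sketch (outside Gradio; the file list is a placeholder). submit_request()
# queues the work and awaits the worker's result tuple
# (results, log_events, progress, batch_size).
#
#   async def _demo():
#       mm = ModelManager()
#       await mm.start_worker()
#       return await mm.submit_request({
#           'file_paths': ["a.png", "b.png"],
#           'auto_batch': True,
#           'manual_batch_size': 4,
#           'selected_evaluators': ['technical', 'aesthetic'],
#       })
#
#   # results, log, progress, batch = asyncio.run(_demo())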

#####################################
#          Gradio Interface         #
#####################################

# Initialize evaluator manager and model manager
evaluator_manager = EvaluatorManager()
model_manager = ModelManager()

# Global variables to store uploaded images and results
uploaded_images = {}
evaluation_results = {}

def extract_metadata_from_image(image):
    """
    Extract metadata from an uploaded image.

    Args:
        image: Uploaded image.

    Returns:
        tuple: (image, metadata)
    """
    if image is None:
        return None, ""

    metadata_manager = MetadataManager()
    metadata = metadata_manager.extract_metadata(image)

    if metadata['has_metadata']:
        return image, metadata['raw_metadata'] or ""
    else:
        return image, "No metadata found in image."


def update_image_metadata(image, new_metadata):
    """
    Update metadata in an image.

    Args:
        image: Image to update.
        new_metadata: New metadata string.

    Returns:
        tuple: (updated_image, metadata)
    """
    if image is None:
        return None, ""

    metadata_manager = MetadataManager()
    updated_image = metadata_manager.update_metadata(image, new_metadata)
    return updated_image, new_metadata

def evaluate_images(images, model_name, selected_evaluators):
    """
    Evaluate uploaded images using selected evaluators.

    Args:
        images: List of uploaded image files.
        model_name: Name of the model that generated these images.
        selected_evaluators: List of evaluator IDs to use.

    Returns:
        str: Status message.
    """
    global uploaded_images, evaluation_results

    if not images:
        return "No images uploaded."
    if not model_name:
        model_name = "unknown_model"

    # Save uploaded images
    if model_name not in uploaded_images:
        uploaded_images[model_name] = []

    image_paths = []
    image_ids = []
    for img in images:
        # Save image to a temporary file, indexed by the running count for this model
        index = len(uploaded_images[model_name])
        img_path = f"/tmp/image_evaluator_uploads/{model_name}_{index}.png"
        os.makedirs(os.path.dirname(img_path), exist_ok=True)
        Image.open(img).save(img_path)

        # Add to uploaded images
        img_id = f"{model_name}_{index}"
        uploaded_images[model_name].append({
            'path': img_path,
            'id': img_id
        })
        image_paths.append(img_path)
        image_ids.append(img_id)

    # Evaluate images
    if not selected_evaluators:
        selected_evaluators = ['technical', 'aesthetic', 'anime_specialized']

    results = {}
    for img_id, img_path in zip(image_ids, image_paths):
        results[img_id] = evaluator_manager.evaluate_image(img_path, selected_evaluators)

    # Store results
    if model_name not in evaluation_results:
        evaluation_results[model_name] = {}
    evaluation_results[model_name].update(results)

    return f"Evaluated {len(images)} images for model '{model_name}'."

async def evaluate_images_async(images, model_name, selected_evaluators, auto_batch=True, batch_size=4):
    """
    Asynchronously evaluate uploaded images using selected evaluators.

    Args:
        images: List of uploaded image files.
        model_name: Name of the model that generated these images.
        selected_evaluators: List of evaluator IDs to use.
        auto_batch: Whether to automatically determine batch size.
        batch_size: Manual batch size if auto_batch is False.

    Returns:
        tuple: (results table HTML, log, progress, batch_size)
    """
    global evaluation_results

    if not images:
        return [], ["No images uploaded."], 0, batch_size
    if not model_name:
        model_name = "unknown_model"

    # Start worker if not already running
    await model_manager.start_worker()

    # Prepare request
    request_data = {
        'file_paths': images,
        'auto_batch': auto_batch,
        'manual_batch_size': batch_size,
        'selected_evaluators': selected_evaluators
    }

    # Submit request and wait for results
    results, log_events, progress, actual_batch_size = await model_manager.submit_request(request_data)

    # Store results in global variable
    if results:
        if model_name not in evaluation_results:
            evaluation_results[model_name] = {}
        for result in results:
            img_id = f"{model_name}_{os.path.basename(result['file_path'])}"
            evaluation_data = {
                'metadata': result.get('metadata', {}),
                'technical': result.get('technical', {}),
                'aesthetic': result.get('aesthetic', {}),
                'anime_specialized': result.get('anime_specialized', {})
            }
            evaluation_results[model_name][img_id] = evaluation_data

    # Create results table HTML
    results_table_html = create_results_table(results)

    return results_table_html, log_events, progress, actual_batch_size

def compare_models():
    """
    Compare models based on evaluation results.

    Returns:
        tuple: (comparison summary text, overall chart path, radar chart path)
    """
    global evaluation_results

    if not evaluation_results or len(evaluation_results) < 2:
        return "Need at least two models with evaluated images for comparison.", None, None

    # Compare models
    comparison = evaluator_manager.compare_models(evaluation_results)

    # Create comparison table
    models = list(evaluation_results.keys())
    metrics = ['technical', 'aesthetic', 'anime_specialized', 'overall']

    data = []
    for model in models:
        row = {'Model': model}
        for metric in metrics:
            if metric in comparison['comparison_metrics'] and model in comparison['comparison_metrics'][metric]:
                row[metric.capitalize()] = comparison['comparison_metrics'][metric][model]
            else:
                row[metric.capitalize()] = 0.0
        data.append(row)

    df = pd.DataFrame(data)

    # Add ranking information
    for rank_info in comparison['rankings']:
        if rank_info['model'] in df['Model'].values:
            df.loc[df['Model'] == rank_info['model'], 'Rank'] = rank_info['rank']

    # Sort by rank
    df = df.sort_values('Rank')

    # Create overall comparison chart
    plt.figure(figsize=(10, 6))
    overall_scores = [comparison['comparison_metrics']['overall'].get(model, 0) for model in models]
    bars = plt.bar(models, overall_scores, color='skyblue')

    # Add value labels on top of bars
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                 f'{height:.2f}', ha='center', va='bottom')

    plt.title('Overall Quality Scores by Model')
    plt.xlabel('Model')
    plt.ylabel('Score')
    plt.ylim(0, 10.5)
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    # Save the chart
    overall_chart_path = "/tmp/image_evaluator_results/overall_comparison.png"
    os.makedirs(os.path.dirname(overall_chart_path), exist_ok=True)
    plt.savefig(overall_chart_path)
    plt.close()

    # Create radar chart
    categories = [m.capitalize() for m in metrics[:-1]]  # Exclude 'overall'
    N = len(categories)

    # Create angles for each metric
    angles = [n / float(N) * 2 * np.pi for n in range(N)]
    angles += angles[:1]  # Close the loop

    plt.figure(figsize=(10, 10))
    ax = plt.subplot(111, polar=True)

    # Add lines for each model
    colors = plt.cm.tab10(np.linspace(0, 1, len(models)))
    for i, model in enumerate(models):
        values = [comparison['comparison_metrics'][metric].get(model, 0) for metric in metrics[:-1]]
        values += values[:1]  # Close the loop
        ax.plot(angles, values, linewidth=2, linestyle='solid', label=model, color=colors[i])
        ax.fill(angles, values, alpha=0.1, color=colors[i])

    # Set category labels
    plt.xticks(angles[:-1], categories)

    # Set y-axis limits
    ax.set_ylim(0, 10)

    # Add legend
    plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1))
    plt.title('Detailed Metrics Comparison by Model')

    # Save the chart
    radar_chart_path = "/tmp/image_evaluator_results/radar_comparison.png"
    plt.savefig(radar_chart_path)
    plt.close()

    # Create result message
    result_message = f"Best model: {comparison['best_model']}\n\nModel rankings:\n"
    for rank in comparison['rankings']:
        result_message += f"{rank['rank']}. {rank['model']} (score: {rank['score']:.2f})\n"

    return result_message, overall_chart_path, radar_chart_path

def create_results_table(results):
    """
    Create an HTML table with results and image previews.

    Args:
        results: List of evaluation results.

    Returns:
        str: HTML table.
    """
    if not results:
        return "No results to display."

    # Sort results by final score (descending)
    sorted_results = sorted(results, key=lambda x: x.get('final_score', 0), reverse=True)

    # Create HTML table
    html = """
    <style>
        .results-table {
            width: 100%;
            border-collapse: collapse;
            font-family: Arial, sans-serif;
        }
        .results-table th, .results-table td {
            border: 1px solid #ddd;
            padding: 8px;
            text-align: left;
        }
        .results-table th {
            background-color: #f2f2f2;
            position: sticky;
            top: 0;
        }
        .results-table tr:nth-child(even) {
            background-color: #f9f9f9;
        }
        .results-table tr:hover {
            background-color: #f1f1f1;
        }
        .image-preview {
            max-width: 150px;
            max-height: 150px;
        }
        .score {
            font-weight: bold;
        }
        .high-score {
            color: green;
        }
        .medium-score {
            color: orange;
        }
        .low-score {
            color: red;
        }
        .metadata-cell {
            max-width: 300px;
            overflow: hidden;
            text-overflow: ellipsis;
            white-space: nowrap;
        }
        .metadata-cell:hover {
            white-space: normal;
            overflow: visible;
        }
    </style>
    <table class="results-table">
        <thead>
            <tr>
                <th>Preview</th>
                <th>File Name</th>
                <th>Final Score</th>
                <th>Technical</th>
                <th>Aesthetic</th>
                <th>Anime</th>
                <th>Prompt</th>
            </tr>
        </thead>
        <tbody>
    """

    for result in sorted_results:
        # Determine score class
        score = result.get('final_score', 0)
        if score >= 7.5:
            score_class = "high-score"
        elif score >= 5:
            score_class = "medium-score"
        else:
            score_class = "low-score"

        # Get technical score
        technical_score = "N/A"
        if 'technical' in result and 'overall_technical' in result['technical']:
            technical_score = f"{result['technical']['overall_technical']:.2f}"

        # Get aesthetic score
        aesthetic_score = "N/A"
        if 'aesthetic' in result and 'overall_aesthetic' in result['aesthetic']:
            aesthetic_score = f"{result['aesthetic']['overall_aesthetic']:.2f}"

        # Get anime score
        anime_score = "N/A"
        if 'anime_specialized' in result and 'overall_anime' in result['anime_specialized']:
            anime_score = f"{result['anime_specialized']['overall_anime']:.2f}"

        # Get prompt from metadata
        prompt = "N/A"
        if 'metadata' in result and result['metadata'].get('prompt'):
            prompt = result['metadata']['prompt']

        # Add row to table
        html += f"""
            <tr>
                <td><img src="data:image/jpeg;base64,{result['img_data']}" class="image-preview"></td>
                <td>{result['file_name']}</td>
                <td class="score {score_class}">{score:.2f}</td>
                <td>{technical_score}</td>
                <td>{aesthetic_score}</td>
                <td>{anime_score}</td>
                <td class="metadata-cell">{prompt}</td>
            </tr>
        """

    html += """
        </tbody>
    </table>
    """

    return html
def export_results(format_type): | |
""" | |
Export evaluation results to file. | |
Args: | |
format_type: Export format ('csv', 'json', 'html', or 'markdown'). | |
Returns: | |
str: Path to exported file. | |
""" | |
global evaluation_results | |
if not evaluation_results: | |
return "No evaluation results to export." | |
# Create output directory | |
output_dir = "/tmp/image_evaluator_results" | |
os.makedirs(output_dir, exist_ok=True) | |
# Compare models if multiple models are available | |
if len(evaluation_results) >= 2: | |
comparison = evaluator_manager.compare_models(evaluation_results) | |
else: | |
comparison = None | |
# Create DataFrame for the results | |
models = list(evaluation_results.keys()) | |
metrics = ['technical', 'aesthetic', 'anime_specialized', 'overall'] | |
if comparison: | |
data = [] | |
for model in models: | |
row = {'Model': model} | |
for metric in metrics: | |
if metric in comparison['comparison_metrics'] and model in comparison['comparison_metrics'][metric]: | |
row[metric.capitalize()] = comparison['comparison_metrics'][metric][model] | |
else: | |
row[metric.capitalize()] = 0.0 | |
data.append(row) | |
df = pd.DataFrame(data) | |
# Add ranking information | |
for rank_info in comparison['rankings']: | |
if rank_info['model'] in df['Model'].values: | |
df.loc[df['Model'] == rank_info['model'], 'Rank'] = rank_info['rank'] | |
# Sort by rank | |
df = df.sort_values('Rank') | |
else: | |
# Single model, create detailed results | |
model = models[0] | |
data = [] | |
for img_id, results in evaluation_results[model].items(): | |
row = {'Image': img_id} | |
# Add metadata if available | |
if 'metadata' in results and results['metadata'].get('prompt'): | |
row['Prompt'] = results['metadata']['prompt'] | |
# Add evaluator results | |
for evaluator_id in ['technical', 'aesthetic', 'anime_specialized']: | |
if evaluator_id in results: | |
for metric, value in results[evaluator_id].items(): | |
if isinstance(value, (int, float)): | |
row[f"{evaluator_id}_{metric}"] = value | |
data.append(row) | |
df = pd.DataFrame(data) | |
# Export based on format | |
if format_type == 'csv': | |
output_path = os.path.join(output_dir, 'evaluation_results.csv') | |
df.to_csv(output_path, index=False) | |
elif format_type == 'json': | |
output_path = os.path.join(output_dir, 'evaluation_results.json') | |
if comparison: | |
export_data = { | |
'comparison': comparison, | |
'results': evaluation_results | |
} | |
else: | |
export_data = evaluation_results | |
with open(output_path, 'w') as f: | |
json.dump(export_data, f, indent=2) | |
elif format_type == 'html': | |
output_path = os.path.join(output_dir, 'evaluation_results.html') | |
# Create HTML with both table and visualizations | |
html_content = """ | |
<!DOCTYPE html> | |
<html> | |
<head> | |
<title>Image Evaluation Results</title> | |
<style> | |
body { font-family: Arial, sans-serif; margin: 20px; } | |
h1, h2 { color: #333; } | |
.container { margin-bottom: 30px; } | |
table { border-collapse: collapse; width: 100%; } | |
th, td { border: 1px solid #ddd; padding: 8px; text-align: left; } | |
th { background-color: #f2f2f2; } | |
tr:nth-child(even) { background-color: #f9f9f9; } | |
.chart { margin: 20px 0; max-width: 800px; } | |
.best-model { font-weight: bold; color: green; } | |
</style> | |
</head> | |
<body> | |
<h1>Image Evaluation Results</h1> | |
""" | |
if comparison: | |
html_content += f""" | |
<div class="container"> | |
<h2>Model Comparison</h2> | |
<p class="best-model">Best model: {comparison['best_model']}</p> | |
<table> | |
<tr> | |
<th>Rank</th> | |
<th>Model</th> | |
<th>Overall Score</th> | |
<th>Technical</th> | |
<th>Aesthetic</th> | |
<th>Anime</th> | |
</tr> | |
""" | |
for rank in comparison['rankings']: | |
model = rank['model'] | |
html_content += f""" | |
<tr> | |
<td>{rank['rank']}</td> | |
<td>{model}</td> | |
<td>{rank['score']:.2f}</td> | |
<td>{comparison['comparison_metrics']['technical'].get(model, 0):.2f}</td> | |
<td>{comparison['comparison_metrics']['aesthetic'].get(model, 0):.2f}</td> | |
<td>{comparison['comparison_metrics']['anime_specialized'].get(model, 0):.2f}</td> | |
</tr> | |
""" | |
html_content += """ | |
</table> | |
</div> | |
""" | |
# Add charts | |
html_content += """ | |
<div class="container"> | |
<h2>Visualizations</h2> | |
<div class="chart"> | |
<h3>Overall Scores</h3> | |
<img src="overall_comparison.png" alt="Overall Scores Chart"> | |
</div> | |
<div class="chart"> | |
<h3>Detailed Metrics</h3> | |
<img src="radar_comparison.png" alt="Radar Chart"> | |
</div> | |
</div> | |
""" | |
# Save charts | |
plt.figure(figsize=(10, 6)) | |
overall_scores = [comparison['comparison_metrics'].get('overall', {}).get(model, 0) for model in models]
bars = plt.bar(models, overall_scores, color='skyblue') | |
for bar in bars: | |
height = bar.get_height() | |
plt.text(bar.get_x() + bar.get_width()/2., height + 0.01, f'{height:.2f}', ha='center', va='bottom') | |
plt.title('Overall Quality Scores by Model') | |
plt.xlabel('Model') | |
plt.ylabel('Score') | |
plt.ylim(0, 10.5) | |
plt.grid(axis='y', linestyle='--', alpha=0.7) | |
plt.savefig(os.path.join(output_dir, 'overall_comparison.png')) | |
plt.close() | |
# Create radar chart | |
categories = [m.replace('_', ' ').title() for m in metrics[:-1]]  # readable axis labels, e.g. 'Anime Specialized'
N = len(categories) | |
angles = [n / float(N) * 2 * np.pi for n in range(N)] | |
angles += angles[:1] | |
plt.figure(figsize=(10, 10)) | |
ax = plt.subplot(111, polar=True) | |
colors = plt.cm.tab10(np.linspace(0, 1, len(models))) | |
for i, model in enumerate(models): | |
values = [comparison['comparison_metrics'].get(metric, {}).get(model, 0) for metric in metrics[:-1]]
values += values[:1] | |
ax.plot(angles, values, linewidth=2, linestyle='solid', label=model, color=colors[i]) | |
ax.fill(angles, values, alpha=0.1, color=colors[i]) | |
plt.xticks(angles[:-1], categories) | |
ax.set_ylim(0, 10) | |
plt.legend(loc='upper right', bbox_to_anchor=(0.1, 0.1)) | |
plt.title('Detailed Metrics Comparison by Model') | |
plt.savefig(os.path.join(output_dir, 'radar_comparison.png')) | |
plt.close() | |
# Add detailed results for each model | |
for model in models: | |
html_content += f""" | |
<div class="container"> | |
<h2>Detailed Results: {model}</h2> | |
<table> | |
<tr> | |
<th>Image</th> | |
<th>Technical</th> | |
<th>Aesthetic</th> | |
<th>Anime</th> | |
<th>Prompt</th> | |
</tr> | |
""" | |
for img_id, results in evaluation_results[model].items(): | |
technical = results.get('technical', {}).get('overall_technical', 'N/A') | |
aesthetic = results.get('aesthetic', {}).get('overall_aesthetic', 'N/A') | |
anime = results.get('anime_specialized', {}).get('overall_anime', 'N/A') | |
# Minimal HTML escaping so prompt text (e.g. "<lora:...>") cannot break the generated markup
prompt = str(results.get('metadata', {}).get('prompt', 'N/A')).replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')
if isinstance(technical, (int, float)): | |
technical = f"{technical:.2f}" | |
if isinstance(aesthetic, (int, float)): | |
aesthetic = f"{aesthetic:.2f}" | |
if isinstance(anime, (int, float)): | |
anime = f"{anime:.2f}" | |
html_content += f""" | |
<tr> | |
<td>{img_id}</td> | |
<td>{technical}</td> | |
<td>{aesthetic}</td> | |
<td>{anime}</td> | |
<td>{prompt}</td> | |
</tr> | |
""" | |
html_content += """ | |
</table> | |
</div> | |
""" | |
html_content += """ | |
</body> | |
</html> | |
""" | |
with open(output_path, 'w', encoding='utf-8') as f:
f.write(html_content) | |
elif format_type == 'markdown': | |
output_path = os.path.join(output_dir, 'evaluation_results.md') | |
md_content = "# Image Evaluation Results\n\n" | |
if comparison: | |
md_content += f"## Model Comparison\n\n**Best model: {comparison['best_model']}**\n\n" | |
md_content += "| Rank | Model | Overall Score | Technical | Aesthetic | Anime |\n" | |
md_content += "|------|-------|--------------|-----------|-----------|-------|\n" | |
for rank in comparison['rankings']: | |
model = rank['model'] | |
md_content += f"| {rank['rank']} | {model} | {rank['score']:.2f} | " | |
md_content += f"{comparison['comparison_metrics']['technical'].get(model, 0):.2f} | " | |
md_content += f"{comparison['comparison_metrics']['aesthetic'].get(model, 0):.2f} | " | |
md_content += f"{comparison['comparison_metrics']['anime_specialized'].get(model, 0):.2f} |\n" | |
md_content += "\n" | |
# Add detailed results for each model | |
for model in models: | |
md_content += f"## Detailed Results: {model}\n\n" | |
md_content += "| Image | Technical | Aesthetic | Anime | Prompt |\n" | |
md_content += "|-------|-----------|-----------|-------|--------|\n" | |
for img_id, results in evaluation_results[model].items(): | |
technical = results.get('technical', {}).get('overall_technical', 'N/A') | |
aesthetic = results.get('aesthetic', {}).get('overall_aesthetic', 'N/A') | |
anime = results.get('anime_specialized', {}).get('overall_anime', 'N/A') | |
prompt = results.get('metadata', {}).get('prompt', 'N/A') | |
if isinstance(technical, (int, float)): | |
technical = f"{technical:.2f}" | |
if isinstance(aesthetic, (int, float)): | |
aesthetic = f"{aesthetic:.2f}" | |
if isinstance(anime, (int, float)): | |
anime = f"{anime:.2f}" | |
# Escape pipes and truncate long prompts so they do not break the Markdown table
prompt = str(prompt).replace('|', '\\|')
prompt = prompt if len(prompt) <= 50 else prompt[:47] + "..."
md_content += f"| {img_id} | {technical} | {aesthetic} | {anime} | {prompt} |\n" | |
md_content += "\n" | |
with open(output_path, 'w', encoding='utf-8') as f:
f.write(md_content) | |
else: | |
return f"Unsupported format: {format_type}" | |
return output_path | |
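# --- Illustrative sketch (not part of the app's execution path) ---------------
# The radar chart built in the HTML export above closes each polygon by repeating
# the first angle/value pair. The helper below demonstrates that technique in
# isolation; the model names and scores are made-up placeholders, not real
# evaluation output.
def _radar_chart_sketch(output_path: str = '/tmp/image_evaluator_results/radar_sketch.png') -> str:
    """Render a minimal radar chart for two hypothetical models and return the file path."""
    sketch_metrics = ['Technical', 'Aesthetic', 'Anime Specialized']
    sketch_scores = {'model_a': [7.2, 6.8, 8.1], 'model_b': [6.5, 7.9, 7.0]}
    n = len(sketch_metrics)
    angles = [i / float(n) * 2 * np.pi for i in range(n)]
    angles += angles[:1]  # repeat the first angle so each polygon closes
    fig = plt.figure(figsize=(6, 6))
    ax = plt.subplot(111, polar=True)
    for name, scores in sketch_scores.items():
        values = scores + scores[:1]  # close the polygon to match the angles
        ax.plot(angles, values, linewidth=2, label=name)
        ax.fill(angles, values, alpha=0.1)
    plt.xticks(angles[:-1], sketch_metrics)
    ax.set_ylim(0, 10)
    plt.legend(loc='upper right')
    fig.savefig(output_path)
    plt.close(fig)
    return output_path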
def reset_data(): | |
"""Reset all uploaded images and evaluation results.""" | |
global uploaded_images, evaluation_results | |
uploaded_images = {} | |
evaluation_results = {} | |
return "All data has been reset." | |
def create_interface(): | |
"""Create Gradio interface.""" | |
# Get available evaluators | |
available_evaluators = evaluator_manager.get_available_evaluators() | |
evaluator_choices = [e['id'] for e in available_evaluators] | |
with gr.Blocks(title="Image Evaluator") as interface: | |
gr.Markdown("# Image Evaluator") | |
gr.Markdown("Tool for evaluating and comparing images generated by different AI models") | |
with gr.Tab("Upload & Evaluate"): | |
with gr.Row(): | |
with gr.Column(scale=1): | |
images_input = gr.File(file_count="multiple", label="Upload Images") | |
model_name_input = gr.Textbox(label="Model Name", placeholder="Enter model name") | |
evaluator_select = gr.CheckboxGroup(choices=evaluator_choices, label="Select Evaluators", value=evaluator_choices) | |
auto_batch = gr.Checkbox(label="Auto Batch Size", value=True) | |
batch_size = gr.Number(label="Batch Size (if Auto is off)", value=4, precision=0) | |
evaluate_button = gr.Button("Evaluate Images") | |
with gr.Column(scale=2): | |
with gr.Row(): | |
evaluation_output = gr.Textbox(label="Evaluation Status") | |
progress = gr.Number(label="Progress (%)", value=0, precision=0) | |
log_output = gr.Textbox(label="Processing Log", lines=10) | |
results_table = gr.HTML(label="Results Table") | |
with gr.Tab("Compare Models"): | |
with gr.Row(): | |
compare_button = gr.Button("Compare Models") | |
with gr.Row(): | |
with gr.Column(): | |
comparison_output = gr.Textbox(label="Comparison Results") | |
with gr.Column(): | |
overall_chart = gr.Image(label="Overall Scores") | |
radar_chart = gr.Image(label="Detailed Metrics") | |
with gr.Tab("Metadata Viewer"): | |
with gr.Row(): | |
with gr.Column(): | |
metadata_image_input = gr.Image(type="pil", label="Upload Image for Metadata") | |
with gr.Column(): | |
metadata_output = gr.Textbox(label="Image Metadata", lines=10) | |
with gr.Row(): | |
copy_metadata_button = gr.Button("Copy Metadata") | |
update_metadata_button = gr.Button("Update Metadata") | |
with gr.Tab("Export Results"): | |
with gr.Row(): | |
format_select = gr.Radio(choices=["csv", "json", "html", "markdown"], label="Export Format", value="html") | |
export_button = gr.Button("Export Results") | |
with gr.Row(): | |
export_output = gr.Textbox(label="Export Status") | |
with gr.Tab("Help"): | |
gr.Markdown(""" | |
## How to Use Image Evaluator | |
### Step 1: Upload Images | |
- Go to the "Upload & Evaluate" tab | |
- Upload images for a specific model | |
- Enter the model name | |
- Select which evaluators to use | |
- Click "Evaluate Images" | |
- Repeat for each model you want to compare | |
### Step 2: Compare Models | |
- Go to the "Compare Models" tab | |
- Click "Compare Models" to see results | |
- The best model will be highlighted | |
- View charts for visual comparison | |
### Step 3: View Metadata | |
- Go to the "Metadata Viewer" tab | |
- Upload an image to view its metadata | |
- Edit metadata if needed | |
### Step 4: Export Results | |
- Go to the "Export Results" tab | |
- Select export format (CSV, JSON, HTML, or Markdown) | |
- Click "Export Results" | |
- Download the exported file | |
### Available Metrics | |
#### Technical Metrics | |
- Sharpness: Measures image clarity and detail | |
- Noise: Measures absence of unwanted variations | |
- Artifacts: Measures absence of compression artifacts | |
- Saturation: Measures color intensity | |
- Contrast: Measures difference between light and dark areas | |
#### Aesthetic Metrics | |
- Color Harmony: Measures how well colors work together | |
- Composition: Measures adherence to compositional principles | |
- Visual Interest: Measures how visually engaging the image is | |
- Aesthetic Predictor: Score from the Aesthetic Predictor V2.5 model
- Aesthetic Shadow: Score from the Aesthetic Shadow model
#### Anime-Specific Metrics
- Line Quality: Measures clarity and quality of line work
- Color Palette: Evaluates color choices for anime style
- Character Quality: Assesses character design and rendering using Waifu Scorer
- Anime Aesthetic: Score from a specialized anime aesthetic model
- Style Consistency: Measures adherence to anime style conventions | |
""") | |
with gr.Row(): | |
reset_button = gr.Button("Reset All Data") | |
reset_output = gr.Textbox(label="Reset Status") | |
# Event handlers | |
# Gradio awaits async handlers natively, so pass the coroutine function directly;
# wrapping it in asyncio.create_task inside a sync lambda would fail outside a running
# event loop and would return a Task object instead of the output values.
evaluate_button.click(
fn=evaluate_images_async,
inputs=[images_input, model_name_input, evaluator_select, auto_batch, batch_size],
outputs=[results_table, log_output, progress, batch_size]
)
compare_button.click( | |
compare_models, | |
inputs=[], | |
outputs=[comparison_output, overall_chart, radar_chart] | |
) | |
metadata_image_input.change( | |
extract_metadata_from_image, | |
inputs=[metadata_image_input], | |
outputs=[metadata_image_input, metadata_output] | |
) | |
update_metadata_button.click( | |
update_image_metadata, | |
inputs=[metadata_image_input, metadata_output], | |
outputs=[metadata_image_input, metadata_output] | |
) | |
# Placeholder handler: it simply re-emits the metadata text; actual clipboard copying is left to the user/browser
copy_metadata_button.click(
lambda x: x, | |
inputs=[metadata_output], | |
outputs=[metadata_output] | |
) | |
export_button.click( | |
export_results, | |
inputs=[format_select], | |
outputs=[export_output] | |
) | |
reset_button.click( | |
reset_data, | |
inputs=[], | |
outputs=[reset_output] | |
) | |
return interface | |
# Create the interface; it is launched below when this file is run as a script
interface = create_interface() | |
if __name__ == "__main__": | |
# Bind to all interfaces so the app is reachable inside the Space container
interface.launch(server_name="0.0.0.0")
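# Optional variant (an assumption, not the app's required configuration): enable
# Gradio's request queue so long-running evaluations do not block other users.
# interface.queue(max_size=16).launch(server_name="0.0.0.0", server_port=7860)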