""" | |
Aesthetic metrics for image quality assessment using AI models. | |
These metrics evaluate subjective aspects of images like aesthetic appeal, composition, etc. | |
""" | |
import torch | |
import numpy as np | |
from PIL import Image | |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, CLIPProcessor, CLIPModel | |
import torchvision.transforms as transforms | |
class AestheticMetrics:
    """Class for computing aesthetic image quality metrics using AI models."""

    def __init__(self):
        """Initialize models for aesthetic evaluation."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._initialize_models()

    def _initialize_models(self):
        """Initialize all required models."""
        # Initialize CLIP model for text-image similarity using transformers
        try:
            self.clip_model_name = "openai/clip-vit-base-patch32"
            self.clip_processor = CLIPProcessor.from_pretrained(self.clip_model_name)
            self.clip_model = CLIPModel.from_pretrained(self.clip_model_name)
            self.clip_model.to(self.device)
            self.clip_model.eval()  # inference only: disable dropout etc.
            self.clip_loaded = True
        except Exception as e:
            print(f"Warning: Could not load CLIP model: {e}")
            self.clip_loaded = False
        # Initialize the aesthetic predictor (cafeai/cafe_aesthetic, an image
        # classifier fine-tuned for aesthetic rating)
        try:
            self.aesthetic_model_name = "cafeai/cafe_aesthetic"
            self.aesthetic_extractor = AutoFeatureExtractor.from_pretrained(self.aesthetic_model_name)
            self.aesthetic_model = AutoModelForImageClassification.from_pretrained(self.aesthetic_model_name)
            self.aesthetic_model.to(self.device)
            self.aesthetic_model.eval()
            self.aesthetic_loaded = True
        except Exception as e:
            print(f"Warning: Could not load aesthetic model: {e}")
            self.aesthetic_loaded = False
        # Standard ImageNet preprocessing, kept available for custom pipelines
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    def calculate_aesthetic_score(self, image_path):
        """
        Calculate aesthetic score using a pre-trained model.

        Args:
            image_path: path to the image file

        Returns:
            float: aesthetic score between 0 and 10
        """
        if not self.aesthetic_loaded:
            return 5.0  # Default middle score if model not loaded
        try:
            image = Image.open(image_path).convert('RGB')
            inputs = self.aesthetic_extractor(images=image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                outputs = self.aesthetic_model(**inputs)
            # Get predicted class probabilities
            probs = torch.nn.functional.softmax(outputs.logits, dim=1)
            # Compute the expected class index and rescale it to 0-10.
            # Deriving the weights from the logits shape avoids a shape
            # mismatch when the classifier has fewer than 10 classes
            # (cafe_aesthetic, for instance, is a binary classifier).
            num_classes = probs.shape[1]
            score_weights = torch.arange(num_classes, device=self.device, dtype=torch.float)
            expected_class = torch.sum(probs * score_weights).item()
            aesthetic_score = 10.0 * expected_class / max(num_classes - 1, 1)
            return aesthetic_score
        except Exception as e:
            print(f"Error calculating aesthetic score: {e}")
            return 5.0
    def calculate_composition_score(self, image_path):
        """
        Estimate composition quality using rule-of-thirds and symmetry analysis.

        Args:
            image_path: path to the image file

        Returns:
            float: composition score between 0 and 10
        """
        try:
            # Load image
            image = Image.open(image_path).convert('RGB')
            img_array = np.array(image)
            # Calculate rule-of-thirds score
            h, w = img_array.shape[:2]
            third_h, third_w = h // 3, w // 3
            # Define rule-of-thirds points (the four grid intersections)
            thirds_points = [
                (third_w, third_h), (2 * third_w, third_h),
                (third_w, 2 * third_h), (2 * third_w, 2 * third_h)
            ]
            # Build a gradient-magnitude edge map to locate important elements.
            # Work in float: diffing a uint8 array wraps around on negative
            # values and would corrupt the edge magnitudes.
            gray = np.mean(img_array, axis=2)
            edges = np.abs(np.diff(gray, axis=0, append=0)) + np.abs(np.diff(gray, axis=1, append=0))
            # Score edge concentration in a 100x100 window around each thirds point
            thirds_score = 0
            for px, py in thirds_points:
                region = edges[max(0, py - 50):min(h, py + 50), max(0, px - 50):min(w, px + 50)]
                thirds_score += np.mean(region)
            # Normalize score to the 0-10 range
            thirds_score = min(10, thirds_score / 100)
            # Calculate horizontal symmetry score
            flipped = np.fliplr(img_array)
            symmetry_diff = np.mean(np.abs(img_array.astype(float) - flipped.astype(float)))
            symmetry_score = 10 * (1 - symmetry_diff / 255)
            # Combine scores (weighted average)
            composition_score = 0.7 * thirds_score + 0.3 * symmetry_score
            return min(10, max(0, composition_score))
        except Exception as e:
            print(f"Error calculating composition score: {e}")
            return 5.0
    def calculate_color_harmony(self, image_path):
        """
        Calculate color harmony score based on color theory.

        Args:
            image_path: path to the image file

        Returns:
            float: color harmony score between 0 and 10
        """
        try:
            # Load image and convert to HSV for color analysis
            image = Image.open(image_path).convert('RGB')
            hsv = np.array(image.convert('HSV'))
            # Extract the hue channel and build a 36-bin histogram
            # (PIL stores hue as 0-255, so each bin spans ~10 degrees)
            hue = hsv[:, :, 0].flatten()
            hist, _ = np.histogram(hue, bins=36, range=(0, 256))
            hist = hist / np.sum(hist)
            # Calculate entropy of the hue distribution
            entropy = -np.sum(hist * np.log2(hist + 1e-10))
            # Calculate complementary color usage (bins 180 degrees apart)
            complementary_score = 0
            for i in range(18):
                complementary_i = (i + 18) % 36
                complementary_score += min(hist[i], hist[complementary_i])
            # Calculate analogous color usage (adjacent bins)
            analogous_score = 0
            for i in range(36):
                analogous_i1 = (i + 1) % 36
                analogous_i2 = (i + 35) % 36
                analogous_score += min(hist[i], max(hist[analogous_i1], hist[analogous_i2]))
            # Calculate saturation variance as a measure of color interest
            saturation = hsv[:, :, 1].flatten()
            saturation_variance = np.var(saturation)
            # Combine metrics into final score
            harmony_score = (
                3 * (1 - min(1, entropy / 5)) +         # Lower entropy is better for harmony
                3 * complementary_score +               # Complementary colors
                2 * analogous_score +                   # Analogous colors
                2 * min(1, saturation_variance / 2000)  # Saturation variance
            )
            return min(10, max(0, harmony_score))
        except Exception as e:
            print(f"Error calculating color harmony: {e}")
            return 5.0
    def calculate_prompt_similarity(self, image_path, prompt):
        """
        Calculate similarity between image and text prompt using CLIP.

        Args:
            image_path: path to the image file
            prompt: text prompt used to generate the image

        Returns:
            float: similarity score between 0 and 10
        """
        if not self.clip_loaded or not prompt:
            return 5.0  # Default middle score if model not loaded or no prompt
        try:
            # Load image
            image = Image.open(image_path).convert('RGB')
            # Process inputs with CLIP processor
            inputs = self.clip_processor(
                text=[prompt],
                images=image,
                return_tensors="pt",
                padding=True
            ).to(self.device)
            # Calculate similarity
            with torch.no_grad():
                outputs = self.clip_model(**inputs)
                logits_per_image = outputs.logits_per_image
                similarity = logits_per_image.item()
            # Convert to a 0-10 scale. CLIP logits are cosine similarities
            # scaled by the model's learned temperature (~100), so they
            # typically fall in the 0-100 range.
            return min(10, max(0, similarity / 10))
        except Exception as e:
            print(f"Error calculating prompt similarity: {e}")
            return 5.0
    def calculate_all_metrics(self, image_path, prompt=None):
        """
        Calculate all aesthetic metrics for an image.

        Args:
            image_path: path to the image file
            prompt: optional text prompt used to generate the image

        Returns:
            dict: dictionary with all metric scores
        """
        metrics = {
            'aesthetic_score': self.calculate_aesthetic_score(image_path),
            'composition_score': self.calculate_composition_score(image_path),
            'color_harmony': self.calculate_color_harmony(image_path),
        }
        # Add prompt similarity if prompt is provided
        if prompt:
            metrics['prompt_similarity'] = self.calculate_prompt_similarity(image_path, prompt)
        return metrics
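
# Minimal usage sketch, assuming "sample.jpg" is a hypothetical local image
# and the models above can be downloaded on first use. Each model-backed
# metric falls back to a neutral 5.0 when its model is unavailable, so the
# numpy-based metrics still produce real scores without network access.
if __name__ == "__main__":
    metrics = AestheticMetrics()
    scores = metrics.calculate_all_metrics("sample.jpg", prompt="a misty mountain lake at sunrise")
    for name, value in scores.items():
        print(f"{name}: {value:.2f}")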