Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import torch.nn as nn | |
import clip | |
import pandas as pd | |
import hashlib | |
import numpy as np | |
import cv2 | |
import time | |
from PIL import Image | |
# MLP model definition | |
class MLP(nn.Module):
    """Regression head mapping an ``input_size``-dim embedding to one value.

    The stack is five Linear layers (input_size -> 1024 -> 128 -> 64 -> 16 -> 1)
    with Dropout(0.2), Dropout(0.2) and Dropout(0.1) after the first three.
    There are deliberately no activation functions between them.

    The position of every submodule inside ``self.layers`` determines the
    ``layers.<index>.weight`` / ``layers.<index>.bias`` state-dict keys.
    Keep the order exactly as built here, or existing checkpoints for this
    architecture will no longer load.
    """

    def __init__(self, input_size):
        super().__init__()
        # (in_features, out_features, dropout probability applied after, or None)
        spec = [
            (input_size, 1024, 0.2),
            (1024, 128, 0.2),
            (128, 64, 0.1),
            (64, 16, None),
            (16, 1, None),
        ]
        stack = []
        for fan_in, fan_out, p_drop in spec:
            stack.append(nn.Linear(fan_in, fan_out))
            if p_drop is not None:
                stack.append(nn.Dropout(p_drop))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack; the output's last dimension is 1."""
        return self.layers(x)
# Convert binary array to hexadecimal string | |
def binary_array_to_hex(arr):
    """Pack a boolean / 0-1 array into a zero-padded lowercase hex string.

    Elements are read in ``flatten()`` (row-major) order, and the first
    element becomes the most significant bit. The output is left-padded with
    '0' to one hex digit per 4 input bits, rounded up.
    """
    digits = "".join(map(str, 1 * arr.flatten()))
    n_hex = -(-len(digits) // 4)  # ceiling division
    return format(int(digits, 2), "x").rjust(n_hex, "0")
# Calculate perceptual hash of an image | |
def phash(image, hash_size=8, highfreq_factor=4):
    """Return the DCT-based perceptual hash of a PIL image as a hex string.

    The image is converted to greyscale and resized (Lanczos) to a square of
    ``hash_size * highfreq_factor`` pixels. A 2-D DCT (along axis 0, then
    axis 1) is taken, and the top-left ``hash_size x hash_size`` block of
    coefficients is thresholded against its own median. The resulting bit
    matrix is packed with ``binary_array_to_hex``.

    Raises:
        ValueError: if ``hash_size`` is less than 2.
    """
    if hash_size < 2:
        raise ValueError('Hash size must be greater than or equal to 2')
    # Imported lazily, so scipy is only needed once a hash is actually computed.
    import scipy.fftpack

    side = hash_size * highfreq_factor
    grey = image.convert('L')
    small = grey.resize((side, side), Image.Resampling.LANCZOS)
    pixels = np.asarray(small)

    coeffs = scipy.fftpack.dct(pixels, axis=0)
    coeffs = scipy.fftpack.dct(coeffs, axis=1)
    low = coeffs[:hash_size, :hash_size]
    return binary_array_to_hex(low > np.median(low))
# Convert NumPy types to Python built-in types | |
def convert_numpy_types(data):
    """Recursively replace NumPy values with built-in Python equivalents.

    NumPy booleans, integers and floats of any width (``float32``, ``int32``,
    ``uint8``, ...) become ``bool`` / ``int`` / ``float``, and NumPy arrays
    become (nested) lists. This lets the result be serialized as JSON.
    Dicts (values only; keys are left unchanged) and lists are walked
    recursively. Any other value is returned unchanged.
    """
    if isinstance(data, dict):
        return {key: convert_numpy_types(value) for key, value in data.items()}
    if isinstance(data, list):
        return [convert_numpy_types(item) for item in data]
    if isinstance(data, np.ndarray):
        # tolist() already yields built-in scalars at every nesting level.
        return data.tolist()
    if isinstance(data, (np.bool_, np.integer, np.floating)):
        # .item() returns the matching built-in type for every NumPy width,
        # not just float64/int64 as before.
        return data.item()
    return data
# Normalize tensor | |
def normalize(a: torch.Tensor, axis: int = -1, order: float = 2) -> torch.Tensor:
    """Scale ``a`` so each slice along ``axis`` has unit ``order``-norm.

    Slices whose norm is exactly zero are divided by 1 instead, so they stay
    zero rather than becoming NaN.

    Args:
        a: Input tensor.
        axis: Dimension along which the norm is taken.
        order: Norm order passed to ``torch.linalg.norm`` (2 = Euclidean).

    Returns:
        A new tensor with the same shape as ``a``.
    """
    l2 = torch.linalg.norm(a, dim=axis, ord=order, keepdim=True)
    # Replace zeros out-of-place: autograd saves the norm's output for its
    # backward pass, so an in-place write here would make backward() fail.
    l2 = l2.masked_fill(l2 == 0, 1)
    return a / l2
# Load pre-trained MLP model and CLIP model | |
# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Aesthetic head: CLIP ViT-L/14 image embeddings are 768-dimensional.
model = MLP(768)
pthpath = "https://huggingface.co/haor/aesthetics/resolve/main/sac%2Blogos%2Bava1-l14-linearMSE.pth"
model.load_state_dict(
    torch.hub.load_state_dict_from_url(pthpath, map_location=device)
)
model.to(device)  # nn.Module.to moves parameters in place
model.eval()  # inference mode: disables the Dropout layers

# CLIP backbone (model2) plus the preprocessing transform used on its inputs.
model2, preprocess = clip.load("ViT-L/14", device=device)
# Predict aesthetic score and other metrics of an image | |
def _to_grayscale(pixels):
    """Return a single-channel array from an HxW, HxWx3 (RGB) or HxWx4 (RGBA) array."""
    if pixels.ndim == 2:
        return pixels
    if pixels.shape[2] == 4:
        return cv2.cvtColor(pixels, cv2.COLOR_RGBA2GRAY)
    return cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)


def _sync_device():
    """Wait for queued CUDA kernels, so wall-clock timings include the real work."""
    if device == "cuda":
        torch.cuda.synchronize()


def _timed(label, fn):
    """Call ``fn()``, print "<label> took <seconds> seconds", and return its result."""
    start = time.perf_counter()
    result = fn()
    _sync_device()
    print(f"{label} took {time.perf_counter() - start} seconds")
    return result


def predict(image):
    """Score an image with the CLIP aesthetic head and compute image metrics.

    Args:
        image: Pixel array delivered by ``gr.Image(type="numpy")``, or
            ``None`` when nothing was uploaded. Presumably HxWx3 uint8 RGB
            (TODO confirm against the Gradio version in use); 2-D greyscale
            and HxWx4 RGBA arrays are also accepted.

    Returns:
        dict of built-in Python values:
            clip_aesthetic: MLP output for the L2-normalized CLIP embedding.
            phash: perceptual hash as a hex string.
            md5, sha1: hex digests of the decoded pixel buffer
                (``Image.tobytes()``), not of the uploaded file's bytes.
            laplacian_variance: variance of the Laplacian of the greyscale image.

    Raises:
        gr.Error: if ``image`` is ``None``.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    pil_image = Image.fromarray(image)

    # Model-free metrics.
    gray = _to_grayscale(np.array(pil_image))
    laplacian_variance = cv2.Laplacian(gray, cv2.CV_64F).var()
    phash_value = phash(pil_image)
    raw_pixels = pil_image.tobytes()  # serialize once, hash twice
    md5 = hashlib.md5(raw_pixels).hexdigest()
    sha1 = hashlib.sha1(raw_pixels).hexdigest()

    # CLIP embedding -> L2 normalization -> aesthetic MLP.
    inputs = preprocess(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        raw_emb = _timed("Encoding image", lambda: model2.encode_image(inputs))
        emb = _timed("Normalizing image", lambda: normalize(raw_emb).float())
        prediction = _timed("Making prediction", lambda: model(emb).item())

    result = {
        "clip_aesthetic": prediction,
        "phash": phash_value,
        "md5": md5,
        "sha1": sha1,
        "laplacian_variance": laplacian_variance,
    }
    return convert_numpy_types(result)
# Create web interface using Gradio | |
# Gradio web UI: one image in, the metrics dict from predict() out as JSON.
title = "CLIP Aesthetic Score"
description = (
    "Upload an image to predict its aesthetic score using the CLIP model "
    "and calculate other image metrics."
)

demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.JSON(label="Result"),
    title=title,
    description=description,
    examples=[["example1.jpg"], ["example2.jpg"]],
)
# Bind to all interfaces so the server is reachable from outside the container.
demo.launch(server_name="0.0.0.0")