import gradio as gr
import torch
import torch.nn as nn
import clip
import pandas as pd
import hashlib
import numpy as np
import cv2
import time
from PIL import Image
# MLP model definition
class MLP(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_size, 1024),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            nn.Linear(16, 1),
        )

    def forward(self, x):
        return self.layers(x)
# Convert a binary array to a hexadecimal string
def binary_array_to_hex(arr):
    bit_string = ''.join(str(b) for b in 1 * arr.flatten())
    width = int(np.ceil(len(bit_string) / 4))
    return '{:0>{width}x}'.format(int(bit_string, 2), width=width)
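# Worked example (illustrative): an array [1, 0, 1, 1, 0, 0, 1, 0] becomes the bit
# string "10110010", width = ceil(8 / 4) = 2, and int("10110010", 2) = 178 formats
# as "b2". With the default hash_size=8 below, the 64-bit hash becomes 16 hex chars.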
# Calculate the perceptual hash of an image
def phash(image, hash_size=8, highfreq_factor=4):
    if hash_size < 2:
        raise ValueError('Hash size must be greater than or equal to 2')
    import scipy.fftpack
    img_size = hash_size * highfreq_factor
    image = image.convert('L').resize((img_size, img_size), Image.Resampling.LANCZOS)
    pixels = np.asarray(image)
    dct = scipy.fftpack.dct(scipy.fftpack.dct(pixels, axis=0), axis=1)
    dctlowfreq = dct[:hash_size, :hash_size]
    med = np.median(dctlowfreq)
    diff = dctlowfreq > med
    return binary_array_to_hex(diff)
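# The hex digest returned by phash is mainly useful for comparing images: two
# visually similar images should differ in only a few bits. A minimal sketch of
# such a comparison (illustrative only, not used by this app; the helper name is
# an assumption, not part of the original code):
def phash_hamming_distance(hex1, hex2):
    # Interpret both hex digests as integers and count the differing bits
    return bin(int(hex1, 16) ^ int(hex2, 16)).count('1')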
# Convert NumPy types to Python built-in types
def convert_numpy_types(data):
    if isinstance(data, dict):
        return {key: convert_numpy_types(value) for key, value in data.items()}
    elif isinstance(data, list):
        return [convert_numpy_types(item) for item in data]
    elif isinstance(data, np.float64):
        return float(data)
    elif isinstance(data, np.int64):
        return int(data)
    else:
        return data
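# Example (illustrative): convert_numpy_types({"v": np.float64(1.5), "n": np.int64(3)})
# returns {"v": 1.5, "n": 3}, which gr.JSON can serialize as plain JSON.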
# Normalize tensor
def normalize(a, axis=-1, order=2):
    l2 = torch.linalg.norm(a, dim=axis, ord=order, keepdim=True)
    l2[l2 == 0] = 1
    return a / l2
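# A minimal worked example (illustrative only): normalize(torch.tensor([[3.0, 4.0]]))
# yields tensor([[0.6, 0.8]]), i.e. each row is scaled to unit L2 norm before being
# fed to the aesthetic head.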
# Load the pre-trained MLP model and CLIP model
model = MLP(768)  # CLIP embedding dim is 768 for CLIP ViT-L/14
pthpath = "https://huggingface.co/haor/aesthetics/resolve/main/sac%2Blogos%2Bava1-l14-linearMSE.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"
model.load_state_dict(torch.hub.load_state_dict_from_url(pthpath, map_location=device))
model.to(device).eval()
model2, preprocess = clip.load("ViT-L/14", device=device)
# Predict the aesthetic score and other metrics of an image
def predict(image):
    # Preprocess image
    image = Image.fromarray(image)
    image_np = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
    laplacian_variance = cv2.Laplacian(image_np, cv2.CV_64F).var()
    phash_value = phash(image)
    md5 = hashlib.md5(image.tobytes()).hexdigest()
    sha1 = hashlib.sha1(image.tobytes()).hexdigest()
    inputs = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # Extract image features using the CLIP model
        start_time = time.time()
        img_emb = model2.encode_image(inputs)
        end_time = time.time()
        print(f"Encoding image took {end_time - start_time} seconds")
        # Normalize image features
        start_time = time.time()
        img_emb = normalize(img_emb).float()
        end_time = time.time()
        print(f"Normalizing image took {end_time - start_time} seconds")
        # Predict aesthetic score using the MLP model
        start_time = time.time()
        prediction = model(img_emb).item()
        end_time = time.time()
        print(f"Making prediction took {end_time - start_time} seconds")
    # Return prediction results
    result = {
        "clip_aesthetic": prediction,
        "phash": phash_value,
        "md5": md5,
        "sha1": sha1,
        "laplacian_variance": laplacian_variance
    }
    return convert_numpy_types(result)
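# For quick local testing outside Gradio, predict() can be called directly on a
# decoded RGB array, matching the gr.Image(type="numpy") input. A minimal sketch
# (the helper name is an assumption, not part of the original app):
def predict_from_path(path):
    # Load the file as an RGB NumPy array and run the full metric pipeline
    return predict(np.array(Image.open(path).convert("RGB")))

# Example usage: print(predict_from_path("example1.jpg"))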
# Create the web interface using Gradio
title = "CLIP Aesthetic Score"
description = "Upload an image to predict its aesthetic score using the CLIP model and calculate other image metrics."
gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.JSON(label="Result"),
    title=title,
    description=description,
    examples=[["example1.jpg"], ["example2.jpg"]]
).launch(server_name='0.0.0.0')