# Page metadata that appears to have been captured from the hosting file view:
# haor's picture
# Update app.py
# 3e4553e verified
# raw
# history blame
# No virus
# 4.76 kB
import gradio as gr
import numpy as np
import tensorflow as tf
import logging
from PIL import Image
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input as resnet_preprocess
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as vgg_preprocess
import scipy.fftpack
import time
import clip
import torch
# Configure root logging at INFO so the timing messages emitted below are shown.
logging.basicConfig(level=logging.INFO)

# Load every backbone once, at import time, so all requests share them.
# Both Keras backbones use ImageNet weights without the fully-connected top,
# with global average pooling so each image yields one feature vector.
_KERAS_BACKBONE_OPTIONS = {'weights': 'imagenet', 'include_top': False, 'pooling': 'avg'}
resnet_model = ResNet50(**_KERAS_BACKBONE_OPTIONS)
vgg_model = VGG16(**_KERAS_BACKBONE_OPTIONS)
# CLIP ViT-B/32 plus its matching input transform, placed on the CPU.
clip_model, preprocess_clip = clip.load("ViT-B/32", device="cpu")
# Preprocess function
def preprocess_img(img_path, target_size=(224, 224), preprocess_func=resnet_preprocess):
start_time = time.time()
img = keras_image.load_img(img_path, target_size=target_size)
img_array = keras_image.img_to_array(img)
img_array = np.expand_dims(img_array, axis=0)
img_array = preprocess_func(img_array)
logging.info(f"Image preprocessed in {time.time() - start_time:.4f} seconds")
return img_array
# Feature extraction function
def extract_features(img_path, model, preprocess_func):
    """Run a Keras `model` on one image file and return its output as a 1-D array.

    Args:
        img_path: Path of the image file.
        model: Model whose ``predict`` is called on the preprocessed batch.
        preprocess_func: Input normalisation matching ``model``.
    """
    batch = preprocess_img(img_path, preprocess_func=preprocess_func)
    started = time.time()
    raw_output = model.predict(batch)
    elapsed = time.time() - started
    logging.info(f"Features extracted in {elapsed:.4f} seconds")
    return raw_output.flatten()
# Calculate cosine similarity
def cosine_similarity(vec1, vec2):
    """Return the cosine of the angle between two 1-D vectors.

    The result nominally lies in [-1, 1]. If either vector has zero norm the
    angle is undefined; 0.0 is returned instead of the NaN (and
    RuntimeWarning) that a plain division by zero would produce.
    """
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if norm_product == 0:
        return 0.0
    return np.dot(vec1, vec2) / norm_product
# pHash related functions
def phashstr(image, hash_size=8, highfreq_factor=4):
    """Compute a DCT-based perceptual hash (pHash) of a PIL image.

    The image is converted to greyscale and resized to a square of side
    ``hash_size * highfreq_factor``. A 2-D DCT is taken, and only the top-left
    ``hash_size x hash_size`` (lowest-frequency) coefficients are kept. Each
    kept coefficient becomes one bit: True when it exceeds their median.

    Returns:
        The bits packed into a hex string by _binary_array_to_hex.
    """
    side = hash_size * highfreq_factor
    grey = image.convert('L').resize((side, side), Image.Resampling.LANCZOS)
    # Separable 2-D DCT: transform columns first, then rows.
    coeffs = scipy.fftpack.dct(np.asarray(grey), axis=0)
    coeffs = scipy.fftpack.dct(coeffs, axis=1)
    low_band = coeffs[:hash_size, :hash_size]
    bits = low_band > np.median(low_band)
    return _binary_array_to_hex(bits.flatten())
def _binary_array_to_hex(arr):
    """Pack a sequence of truthy/falsy bits into a lowercase hex string.

    Bits are grouped into bytes of 8, least-significant bit first: element i
    sets bit ``i % 8`` of byte ``i // 8``. Each byte is written as two hex
    digits. When ``len(arr)`` is not a multiple of 8 (e.g. phashstr with
    ``hash_size=6`` gives 36 bits), the trailing partial group is emitted as
    a zero-padded byte instead of being silently dropped.
    """
    bits = [bool(v) for v in arr]
    hex_bytes = []
    for start in range(0, len(bits), 8):
        byte = sum(1 << offset for offset, bit in enumerate(bits[start:start + 8]) if bit)
        hex_bytes.append(f"{byte:02x}")
    return ''.join(hex_bytes)
def hamming_distance(hash1, hash2):
    """Return the bit-level Hamming distance between two hex-encoded hashes.

    Each hex digit encodes 4 bits, so the result ranges from 0 to
    ``4 * len(hash1)``. This matches the ``len(hash) * 4`` bit count that
    compare_images passes to hamming_to_similarity. Counting differing hex
    characters instead would cap the distance at a quarter of that range.

    Raises:
        ValueError: if the hashes differ in length or contain non-hex characters.
    """
    if len(hash1) != len(hash2):
        raise ValueError("Hashes must be of the same length")
    # XOR the 4-bit value of each digit pair and count the differing bits.
    return sum(bin(int(d1, 16) ^ int(d2, 16)).count('1') for d1, d2 in zip(hash1, hash2))
def hamming_to_similarity(distance, hash_length):
    """Convert a Hamming distance into a similarity percentage.

    ``hash_length`` is the total number of compared positions. A distance of
    0 maps to 100.0, and a distance equal to ``hash_length`` maps to 0.0.
    """
    matching_fraction = 1 - distance / hash_length
    return matching_fraction * 100
# CLIP related functions
def extract_clip_features(image_path, model, preprocess):
    """Encode one image file with a CLIP model and return a flat NumPy vector.

    Args:
        image_path: Path of the image file.
        model: CLIP model exposing ``encode_image``.
        preprocess: Input transform paired with ``model`` (as returned by
            ``clip.load`` in this file).
    """
    # Close the file explicitly. Pillow may otherwise keep the handle open,
    # e.g. for multi-frame formats.
    with Image.open(image_path) as img:
        pixel_batch = preprocess(img).unsqueeze(0).to("cpu")
    with torch.no_grad():  # inference only: no autograd bookkeeping needed
        features = model.encode_image(pixel_batch)
    return features.cpu().numpy().flatten()
# pHash helper
def _phash_similarity(image1, image2):
    """Return the pHash similarity (percent) between two image files."""
    # Context managers release both file handles once the hashes are computed.
    with Image.open(image1) as img1, Image.open(image2) as img2:
        hash1 = phashstr(img1)
        hash2 = phashstr(img2)
    distance = hamming_distance(hash1, hash2)
    # Each hex character of the hash string encodes 4 bits.
    return hamming_to_similarity(distance, len(hash1) * 4)


# Main function
def compare_images(image1, image2, method):
    """Compare two image files and return their similarity under `method`.

    Args:
        image1, image2: File paths of the images (Gradio ``type="filepath"``
            inputs, which are None when nothing was uploaded).
        method: One of 'pHash', 'ResNet50', 'VGG16', 'CLIP'.

    Returns:
        For 'pHash', a percentage from hamming_to_similarity. For the other
        methods, the cosine similarity of the two extracted feature vectors.

    Raises:
        gr.Error: if an image is missing or `method` is not a known option, so
            the UI shows a message instead of an empty result or a traceback.
    """
    if image1 is None or image2 is None:
        raise gr.Error("Please upload both images.")
    start_time = time.time()
    if method == 'pHash':
        similarity = _phash_similarity(image1, image2)
    elif method == 'ResNet50':
        similarity = cosine_similarity(
            extract_features(image1, resnet_model, resnet_preprocess),
            extract_features(image2, resnet_model, resnet_preprocess),
        )
    elif method == 'VGG16':
        similarity = cosine_similarity(
            extract_features(image1, vgg_model, vgg_preprocess),
            extract_features(image2, vgg_model, vgg_preprocess),
        )
    elif method == 'CLIP':
        similarity = cosine_similarity(
            extract_clip_features(image1, clip_model, preprocess_clip),
            extract_clip_features(image2, clip_model, preprocess_clip),
        )
    else:
        raise gr.Error(
            f"Unknown comparison method: {method!r}. "
            "Please select one of pHash, ResNet50, VGG16 or CLIP."
        )
    logging.info(f"Image comparison using {method} completed in {time.time() - start_time:.4f} seconds")
    return similarity
# Gradio interface: two image uploads plus a method selector, wired to
# compare_images; the similarity is shown as text.
demo = gr.Interface(
    fn=compare_images,
    inputs=[
        gr.Image(type="filepath", label="Upload First Image"),
        gr.Image(type="filepath", label="Upload Second Image"),
        # Pre-select a method so a run never starts with method=None.
        gr.Radio(["pHash", "ResNet50", "VGG16", "CLIP"], value="pHash", label="Select Comparison Method"),
    ],
    outputs=gr.Textbox(label="Similarity"),
    title="Image Similarity Comparison",
    description="Upload two images and select the comparison method.",
    # Each example row supplies one value per input component, so choosing an
    # example also selects a comparison method.
    examples=[
        ["Snipaste_2024-05-31_16-18-31.jpg", "Snipaste_2024-05-31_16-18-52.jpg", "pHash"],
        ["example1.png", "example2.png", "CLIP"],
    ],
)
demo.launch()