# FindMyArt / app.py
# Commit 568dec5 ("Working version", unverified) by sgugger.
# Header captured from the Hugging Face file page (raw | history | blame,
# "No virus" scan badge, 4.08 kB).
import gradio as gr
from PIL import Image
import numpy as np
from scipy.fftpack import dct
from datasets import load_dataset
from PIL import Image
from multiprocessing import cpu_count
def perceptual_hash_color(image):
    """Compute a per-channel DCT perceptual hash of a PIL image.

    The image is converted to RGB and downscaled to 32x32. For each of the
    three channels, a 2-D DCT is taken and the top-left 8x8 block of
    low-frequency coefficients is kept. Each coefficient becomes 1 if it is
    greater than or equal to the median of that block (the median is taken
    without the DC term at [0, 0]), otherwise 0.

    Args:
        image: A PIL image (any mode accepted by ``convert("RGB")``).

    Returns:
        A 1-D integer numpy array of 192 bits: 64 bits per channel, in
        R, G, B order.
    """
    # RGB rather than grayscale, so colour differences contribute to the hash.
    image = image.convert("RGB")
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    # and exists in every Pillow version.
    image = image.resize((32, 32), Image.LANCZOS)
    image_array = np.asarray(image)  # shape (32, 32, 3)
    hashes = []
    for channel_idx in range(3):
        channel = image_array[:, :, channel_idx]
        dct_coef = dct(dct(channel, axis=0), axis=1)  # separable 2-D DCT
        dct_reduced_coef = dct_coef[:8, :8]  # low-frequency 8x8 block
        # The DC term (proportional to the channel mean) is left out of the
        # median so overall brightness does not skew the threshold.
        median_coef_val = np.median(dct_reduced_coef.flatten()[1:])
        # 1 where the coefficient is >= the median, else 0.
        hashes.append((dct_reduced_coef >= median_coef_val).flatten().astype(int))
    return np.concatenate(hashes)
def hamming_distance(array_1, array_2):
    """Count the positions at which two sequences differ.

    Elements are paired with ``zip``. If the inputs have different lengths,
    only the overlapping prefix is compared.
    """
    return sum(1 for left, right in zip(array_1, array_2) if left != right)
def search_closest_examples(hash_refs, img_dataset, num_results=9):
    """Return the dataset rows whose hashes are closest to the reference hashes.

    Args:
        hash_refs: Iterable of reference hashes (e.g. from
            ``perceptual_hash_color``).
        img_dataset: Dataset with a ``"hash"`` column. It must support
            ``len()`` and column access ``img_dataset["hash"]``.
        num_results: Maximum number of matches to return (default 9).

    Returns:
        ``(indices, distances)``: row indices into ``img_dataset``, sorted by
        increasing Hamming distance, and the matching distance for each. When
        several reference hashes are given, the same row may appear once per
        reference.
    """
    # Fetch the hash column once. Indexing ``img_dataset[idx]`` per row would
    # also materialise every other column (e.g. the image) on every comparison.
    dataset_hashes = img_dataset["hash"]
    num_rows = len(img_dataset)
    distances = []
    for hash_ref in hash_refs:
        distances.extend(hamming_distance(hash_ref, row_hash) for row_hash in dataset_hashes)
    # ``distances`` holds one block of ``num_rows`` entries per reference hash,
    # so flat position p corresponds to dataset row ``p % num_rows``.
    # A stable sort keeps ties in dataset order.
    positions = np.argsort(distances, kind="stable")[:num_results].tolist()
    closests = [position % num_rows for position in positions]
    # Look up distances by flat position, not by row index: the row index
    # only addresses the first reference's block.
    return closests, [distances[position] for position in positions]
def find_closest_images(images, img_dataset):
    """Find the dataset images most similar to one or more query images.

    Args:
        images: A single PIL image, or a list/tuple of them.
        img_dataset: Dataset whose rows carry a precomputed ``"hash"`` column.

    Returns:
        ``(indices, distances)`` as produced by ``search_closest_examples``.
    """
    query_images = images if isinstance(images, (list, tuple)) else [images]
    query_hashes = list(map(perceptual_hash_color, query_images))
    indices, dists = search_closest_examples(query_hashes, img_dataset)
    return indices, dists
def compute_hash_from_image(img):
    """Compute a grayscale DCT perceptual hash of a PIL image.

    The image is converted to grayscale and downscaled to 32x32. A 2-D DCT is
    taken and the top-left 8x8 block of low-frequency coefficients is kept.
    Each coefficient becomes 1 if it is greater than or equal to the median
    of that block (the median is taken without the DC term at [0, 0]),
    otherwise 0.

    Args:
        img: A PIL image (any mode accepted by ``convert("L")``).

    Returns:
        A 1-D integer numpy array of 64 bits.
    """
    img = img.convert("L")  # grayscale
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    # and exists in every Pillow version.
    img = img.resize((32, 32), Image.LANCZOS)
    img_array = np.asarray(img)  # shape (32, 32)
    dct_coef = dct(dct(img_array, axis=0), axis=1)  # separable 2-D DCT
    dct_reduced_coef = dct_coef[:8, :8]  # low-frequency 8x8 block
    # The DC term (proportional to the image mean) is left out of the median
    # so overall brightness does not skew the threshold.
    median_coef_val = np.median(dct_reduced_coef.flatten()[1:])
    # 1 where the coefficient is >= the median, else 0.
    bits = (dct_reduced_coef >= median_coef_val).flatten().astype(int)
    return bits
def process_dataset(dataset_name, dataset_split, dataset_column_image):
    """Load one split of a dataset and attach a perceptual hash to each row.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``.
        dataset_split: Name of the split to keep (e.g. ``"train"``).
        dataset_column_image: Column holding the images to hash.

    Returns:
        The selected split with an extra ``"hash"`` column computed by
        ``perceptual_hash_color``.
    """
    all_splits = load_dataset(dataset_name)
    split = all_splits[dataset_split]

    def add_hash(example):
        example["hash"] = perceptual_hash_color(example[dataset_column_image])
        return example

    # Spread the per-image resize + DCT work over half of the available
    # cores, always using at least one worker.
    workers = max(cpu_count() // 2, 1)
    return split.map(add_hash, num_proc=workers)
def compute(dataset_name, dataset_split, dataset_column_image, img):
    """Gradio callback: return the dataset images closest to ``img``.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``.
        dataset_split: Split to search.
        dataset_column_image: Column containing the images.
        img: Query PIL image from the ``gr.Image`` input. This is presumably
            ``None`` when "Find" is clicked without uploading an image
            (Gradio behaviour; confirm for the pinned version).

    Returns:
        A list of dataset images, closest first. The list is empty when no
        query image was given.
    """
    # Bail out before downloading and hashing the whole dataset: without an
    # image, perceptual_hash_color would fail on ``None`` anyway.
    if img is None:
        return []
    img_dataset = process_dataset(dataset_name, dataset_split, dataset_column_image)
    closest_idx, _ = find_closest_images(img, img_dataset)
    return [img_dataset[i][dataset_column_image] for i in closest_idx]
# Gradio UI: dataset/split/column inputs and a query image on the left,
# a gallery of the closest dataset images on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Find if your images are in a public dataset!")
    with gr.Row():
        # Left column: which dataset to search, plus the query image.
        with gr.Column(scale=1, min_width=600):
            dataset_name = gr.Textbox(label="Enter the name of a dataset containing images", value="huggan/few-shot-pokemon")
            dataset_split = gr.Textbox(label="Enter the split of this dataset to consider", value="train")
            dataset_column_image = gr.Textbox(label="Enter the name of the column of this dataset that contains images", value="image")
            # type="pil" so compute() receives a PIL image it can .convert()/.resize().
            img = gr.Image(label="Input your image that will be compared against images of the dataset", type="pil")
            # NOTE(review): Component.style() is deprecated in later Gradio 3.x
            # releases and removed in Gradio 4; confirm the pinned Gradio
            # version before upgrading.
            btn = gr.Button("Find").style(full_width=True)
        # Right column: the search results.
        with gr.Column(scale=2, min_width=600):
            gallery_similar = gr.Gallery(label="similar images")
            # grid=3: presumably 3 images per row (Gradio 3.x Gallery.style).
            gallery_similar.style(grid=3)
    # The inputs are passed to compute() positionally, in this order. Its
    # return value (a list of images) fills the gallery.
    event = btn.click(compute, [dataset_name, dataset_split, dataset_column_image, img], gallery_similar)
# Launch the app as soon as this module is executed.
demo.launch()