# Viewer residue kept for reference: file size 3,954 bytes;
# commits b3195da, 22d5448, b3195da; original listing spanned lines 1-89.
import gradio as gr
from PIL import Image
import numpy as np
from scipy.fftpack import dct
from datasets import load_dataset
from PIL import Image
from multiprocessing import cpu_count
def perceptual_hash_color(image):
    """Compute a per-channel perceptual hash of an image.

    The image is converted to RGB and shrunk to 32x32. For each of the three
    channels a 2-D DCT is taken, the top-left 8x8 block of coefficients is
    kept, and each coefficient is compared against the median of that block
    (the DC term is excluded from the median).

    Args:
        image: A PIL image (it must provide ``convert`` and ``resize``).

    Returns:
        A 1-D integer NumPy array of 192 zeros and ones: the three 64-entry
        channel hashes concatenated in channel order.
    """
    # Force three channels so every input yields a hash of the same length.
    image = image.convert("RGB")
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the filter it
    # used to alias, so the resampling itself is unchanged.
    image = image.resize((32, 32), Image.LANCZOS)
    image_array = np.asarray(image)

    hashes = []
    for i in range(3):
        channel = image_array[:, :, i]
        # Separable 2-D DCT: transform along rows, then along columns.
        dct_coef = dct(dct(channel, axis=0), axis=1)
        # Keep only the low-frequency corner.
        dct_reduced_coef = dct_coef[:8, :8]
        # The DC term (index 0) is skipped when computing the median.
        median_coef_val = np.median(dct_reduced_coef.flatten()[1:])
        # Boolean mask -> 0/1 integers via the multiplication by 1.
        hashes.append((dct_reduced_coef >= median_coef_val).flatten() * 1)
    return np.concatenate(hashes)
def hamming_distance(array_1, array_2):
    """Count the positions at which two sequences hold different values.

    The sequences are walked in lockstep with ``zip``, so when their lengths
    differ the surplus tail of the longer one is not compared.
    """
    return sum(1 for left, right in zip(array_1, array_2) if left != right)
def search_closest_examples(hash_refs, img_dataset):
    """Find the dataset rows whose hashes are nearest to the query hashes.

    Every query hash is compared with every row's ``"hash"`` entry using
    ``hamming_distance``; the (at most) nine smallest distances overall are
    selected.

    Args:
        hash_refs: Iterable of query hashes.
        img_dataset: Dataset indexable by integer row, exposing ``num_rows``,
            whose rows carry a ``"hash"`` entry.

    Returns:
        A ``(indices, distances)`` pair of equal-length lists: the dataset
        row index of each selected match and the distance of that same
        match, in ascending distance order. A row can appear more than once
        when several query hashes select it.
    """
    num_rows = img_dataset.num_rows
    # Read the stored hashes once instead of once per query hash.
    dataset_hashes = [img_dataset[idx]["hash"] for idx in range(num_rows)]

    distances = []
    for hash_ref in hash_refs:
        distances.extend(
            hamming_distance(hash_ref, row_hash) for row_hash in dataset_hashes
        )

    # `distances` is laid out query after query, so flat position p refers
    # to dataset row p % num_rows.
    nearest = [pos.item() for pos in np.argsort(distances)[:9]]
    closests = [pos % num_rows for pos in nearest]
    # Look the distance up at the flat position that was actually selected;
    # indexing by row index would always report the first query's distance.
    return closests, [distances[pos] for pos in nearest]
def find_closest_images(images, img_dataset):
    """Hash the query image(s) and look up their nearest dataset rows.

    Args:
        images: A single image, or a list/tuple of images. Anything that is
            not a list or tuple is wrapped in a one-element list.
        img_dataset: Dataset forwarded to ``search_closest_examples``.

    Returns:
        The ``(indices, distances)`` pair produced by
        ``search_closest_examples``.
    """
    batch = images if isinstance(images, (list, tuple)) else [images]
    query_hashes = [perceptual_hash_color(picture) for picture in batch]
    matched_rows, matched_distances = search_closest_examples(query_hashes, img_dataset)
    return matched_rows, matched_distances
def compute_hash_from_image(img):
    """Compute a grayscale perceptual hash of an image.

    The image is converted to single-channel luminance ("L" mode), shrunk to
    32x32 and passed through a 2-D DCT. The top-left 8x8 block of
    coefficients is kept and each coefficient is compared against the median
    of that block (the DC term is excluded from the median).

    Args:
        img: A PIL image (it must provide ``convert`` and ``resize``).

    Returns:
        A 1-D integer NumPy array of 64 zeros and ones.
    """
    img = img.convert("L")
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the filter it
    # used to alias, so the resampling itself is unchanged.
    img = img.resize((32, 32), Image.LANCZOS)
    img_array = np.asarray(img)
    # Separable 2-D DCT: transform along rows, then along columns.
    dct_coef = dct(dct(img_array, axis=0), axis=1)
    # Keep only the low-frequency corner.
    dct_reduced_coef = dct_coef[:8, :8]
    # The DC term (index 0) is skipped when computing the median.
    median_coef_val = np.median(dct_reduced_coef.flatten()[1:])
    # Named `image_hash` rather than `hash` to avoid shadowing the builtin.
    image_hash = (dct_reduced_coef >= median_coef_val).flatten() * 1
    return image_hash
def process_dataset(dataset_name, dataset_split, dataset_column_image):
    """Load one split of a dataset and attach a perceptual hash to each row.

    Args:
        dataset_name: Name passed to ``load_dataset``.
        dataset_split: Key of the split to select from the loaded dataset.
        dataset_column_image: Name of the column holding the images.

    Returns:
        The mapped dataset, whose rows carry an additional ``"hash"`` entry
        computed by ``perceptual_hash_color``.
    """

    def _attach_hash(row):
        # Store the colour hash of this row's image under the "hash" key.
        row["hash"] = perceptual_hash_color(row[dataset_column_image])
        return row

    selected_split = load_dataset(dataset_name)[dataset_split]
    # Use half of the reported CPUs, but never fewer than one worker.
    worker_count = max(cpu_count() // 2, 1)
    return selected_split.map(_attach_hash, num_proc=worker_count)
def compute(dataset_name, dataset_split, dataset_column_image, img):
    """Return the dataset images most similar to the query image.

    Args:
        dataset_name: Name of the dataset to search.
        dataset_split: Split of that dataset to consider.
        dataset_column_image: Name of the column holding the images.
        img: Query image (or list/tuple of images).

    Returns:
        A list with the ``dataset_column_image`` value of each closest row,
        in the order reported by ``find_closest_images``.
    """
    img_dataset = process_dataset(dataset_name, dataset_split, dataset_column_image)
    closest_idx, _ = find_closest_images(img, img_dataset)
    # Return only the image column: a full row is a dict (it also carries
    # the "hash" entry), which is not an image the gallery output can show.
    return [img_dataset[i][dataset_column_image] for i in closest_idx]
# Gradio UI: the left column collects the dataset coordinates and the query
# image, the right column shows the matches returned by `compute`.
with gr.Blocks() as demo:
    gr.Markdown("# Find if your images are in a public dataset!")
    with gr.Row():
        with gr.Column(scale=1, min_width=600):
            dataset_name = gr.Textbox(label="Enter the name of a dataset containing images")
            dataset_split = gr.Textbox(label="Enter the split of this dataset to consider")
            dataset_column_image = gr.Textbox(label="Enter the name of the column of this dataset that contains images")
            # type="pil" so `compute` receives a PIL image it can hash directly.
            img = gr.Image(label="Input your image that will be compared against images of the dataset", type="pil")
            btn = gr.Button("Find").style(full_width=True)
        with gr.Column(scale=2, min_width=600):
            gallery_similar = gr.Gallery(label="similar images")
    # Inputs are passed to `compute` positionally in this order.
    event = btn.click(compute, [dataset_name, dataset_split, dataset_column_image, img], gallery_similar)

# NOTE(review): launch is left unguarded to keep the module's behaviour as-is.
# `process_dataset` maps with num_proc > 1 — confirm whether an
# `if __name__ == "__main__":` guard is needed on spawn-based platforms.
demo.launch()