File size: 4,081 Bytes
b3195da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22d5448
b3195da
 
 
 
 
 
568dec5
b3195da
 
 
 
 
 
568dec5
 
 
b3195da
 
 
 
 
568dec5
b3195da
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gradio as gr
from PIL import Image
import numpy as np
from scipy.fftpack import dct
from datasets import load_dataset
from PIL import Image
from multiprocessing import cpu_count


def perceptual_hash_color(image):
    """Compute a per-channel perceptual hash of a colour image.

    The image is converted to RGB and downscaled to 32x32. For each of the
    three channels a 2-D DCT is computed, the top-left 8x8 block of
    coefficients is kept, and every coefficient of that block becomes one
    bit: 1 if it is greater than or equal to the median of the block (the
    DC term is left out of the median), otherwise 0.

    Args:
        image: A PIL image (the object must provide ``convert`` and
            ``resize``).

    Returns:
        A 1-D numpy array of 192 values (3 channels x 64 bits), each 0 or 1,
        ordered red bits, then green bits, then blue bits.
    """
    image = image.convert("RGB")  # Guarantee exactly three colour channels
    # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same
    # resampling filter and is available in old and new Pillow releases.
    image = image.resize((32, 32), Image.LANCZOS)  # Resize to 32x32
    image_array = np.asarray(image)  # Shape (32, 32, 3) after the steps above
    hashes = []
    for i in range(3):
        channel = image_array[:, :, i]
        dct_coef = dct(dct(channel, axis=0), axis=1)  # 2-D DCT, one axis at a time
        dct_reduced_coef = dct_coef[:8, :8]  # Retain top-left 8x8 DCT coefficients
        # Median of DCT coefficients excluding the DC term (0th term)
        median_coef_val = np.median(dct_reduced_coef.flatten()[1:])
        # Mask of all coefficients greater than or equal to the median;
        # "* 1" turns the boolean mask into integers
        hashes.append((dct_reduced_coef >= median_coef_val).flatten() * 1)
    return np.concatenate(hashes)

def hamming_distance(array_1, array_2):
    """Count the positions at which two sequences hold different values.

    The sequences are walked in lockstep; if one is longer than the other,
    the surplus elements are ignored.

    Args:
        array_1: First iterable of comparable elements.
        array_2: Second iterable of comparable elements.

    Returns:
        The number of paired positions whose elements are not equal.
    """
    mismatches = 0
    for left, right in zip(array_1, array_2):
        if left != right:
            mismatches += 1
    return mismatches

def search_closest_examples(hash_refs, img_dataset):
    """Find the dataset rows whose hashes are nearest to the reference hashes.

    Every reference hash is compared with every row's stored hash using the
    Hamming distance (number of differing positions). The up-to-9 smallest
    distances over all (reference, row) pairs are selected.

    Args:
        hash_refs: Iterable of reference hashes. Each must have the same
            length as the hashes stored in the dataset.
        img_dataset: Collection of rows supporting ``len()`` and integer
            indexing, where each row is a mapping with a ``"hash"`` entry.

    Returns:
        A tuple ``(indices, distances)``: ``indices`` holds the row index of
        each selected pair, closest first, and ``distances[k]`` is the
        Hamming distance of the pair that produced ``indices[k]``. With
        several reference hashes the same row index may appear more than
        once. Both lists are empty when there are no references or no rows.
    """
    num_rows = len(img_dataset)
    # Read each stored hash once instead of once per reference hash.
    row_hashes = [np.asarray(img_dataset[idx]["hash"]) for idx in range(num_rows)]

    # Flat layout: position = reference_number * num_rows + row_index.
    distances = []
    for hash_ref in hash_refs:
        ref = np.asarray(hash_ref)
        distances.extend(int(np.count_nonzero(ref != row)) for row in row_hashes)

    # A stable sort keeps equally distant candidates in a deterministic order.
    order = np.argsort(np.asarray(distances, dtype=np.int64), kind="stable")[:9]
    closests = [int(pos) % num_rows for pos in order]
    # Look the distance up with the flat position, not the row index:
    # indexing with the row index would always read the distances of the
    # first reference hash.
    return closests, [distances[int(pos)] for pos in order]

def find_closest_images(images, img_dataset):
    """Hash the query image(s) and look up the nearest rows of the dataset.

    Args:
        images: A single image, or a list/tuple of images. A single image is
            treated as a one-element batch.
        img_dataset: Dataset passed through to ``search_closest_examples``.

    Returns:
        The ``(indices, distances)`` pair produced by
        ``search_closest_examples`` for the hashes of ``images``.
    """
    batch = images if isinstance(images, (list, tuple)) else [images]
    reference_hashes = []
    for picture in batch:
        reference_hashes.append(perceptual_hash_color(picture))
    nearest, nearest_distances = search_closest_examples(reference_hashes, img_dataset)
    return nearest, nearest_distances

def compute_hash_from_image(img):
    """Compute a 64-bit grayscale perceptual hash of an image.

    The image is converted to single-channel grayscale, downscaled to 32x32
    and transformed with a 2-D DCT. The top-left 8x8 block of coefficients
    is kept and each coefficient becomes one bit: 1 if it is greater than or
    equal to the median of the block (DC term left out of the median),
    otherwise 0.

    Args:
        img: A PIL image (the object must provide ``convert`` and
            ``resize``).

    Returns:
        A 1-D numpy array of 64 values, each 0 or 1.
    """
    img = img.convert("L")  # Convert to grayscale
    # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same
    # resampling filter and is available in old and new Pillow releases.
    img = img.resize((32, 32), Image.LANCZOS)  # Resize to 32x32
    img_array = np.asarray(img)  # Convert to numpy array
    dct_coef = dct(dct(img_array, axis=0), axis=1)  # 2-D DCT, one axis at a time
    dct_reduced_coef = dct_coef[:8, :8]  # Retain top-left 8x8 DCT coefficients
    # Median of DCT coefficients excluding the DC term (0th term)
    median_coef_val = np.median(dct_reduced_coef.flatten()[1:])
    # Mask of all coefficients greater than or equal to the median; "* 1"
    # turns the boolean mask into integers. Named hash_bits so the builtin
    # hash() is not shadowed.
    hash_bits = (dct_reduced_coef >= median_coef_val).flatten() * 1
    return hash_bits


def process_dataset(dataset_name, dataset_split, dataset_column_image):
    """Load one split of a dataset and attach a perceptual hash to every row.

    Args:
        dataset_name: Name passed to ``load_dataset``.
        dataset_split: Key of the split to pick from the loaded dataset.
        dataset_column_image: Name of the column whose value is handed to
            ``perceptual_hash_color``.

    Returns:
        The selected split with an additional ``"hash"`` column.
    """
    # Use half of the reported CPUs for hashing, but never fewer than one.
    workers = max(cpu_count() // 2, 1)

    def _with_hash(row):
        row["hash"] = perceptual_hash_color(row[dataset_column_image])
        return row

    split_data = load_dataset(dataset_name)[dataset_split]
    # Compute the hash of every image in the split
    return split_data.map(_with_hash, num_proc=workers)
    

def compute(dataset_name, dataset_split, dataset_column_image, img):
    """Return the dataset images most similar to the query image.

    Args:
        dataset_name: Name of the dataset to load.
        dataset_split: Split of the dataset to search.
        dataset_column_image: Name of the column that holds the images.
        img: Query image, or ``None`` when no image was provided.

    Returns:
        A list with the image-column values of the closest rows, closest
        first. The list is empty when ``img`` is ``None``.
    """
    # Without a query image there is nothing to compare; bail out before the
    # dataset is loaded and hashed instead of failing later on None.
    if img is None:
        return []

    img_dataset = process_dataset(dataset_name, dataset_split, dataset_column_image)
    # Only the row indices are needed here; the distances are not displayed.
    closest_idx, _ = find_closest_images(img, img_dataset)
    return [img_dataset[i][dataset_column_image] for i in closest_idx]


# Gradio UI: the left column collects the dataset coordinates and the query
# image, the right column shows the gallery that `compute` fills.
with gr.Blocks() as demo:
    gr.Markdown("# Find if your images are in a public dataset!")
    with gr.Row():
        # Input column
        with gr.Column(scale=1, min_width=600):
            dataset_name = gr.Textbox(label="Enter the name of a dataset containing images", value="huggan/few-shot-pokemon")
            dataset_split = gr.Textbox(label="Enter the split of this dataset to consider", value="train")
            dataset_column_image = gr.Textbox(label="Enter the name of the column of this dataset that contains images", value="image")
            # type="pil" so that `compute` receives a PIL image
            img = gr.Image(label="Input your image that will be compared against images of the dataset", type="pil")
            # NOTE(review): `.style(...)` looks like the Gradio 3.x styling API,
            # which later releases dropped — confirm against the pinned version.
            btn = gr.Button("Find").style(full_width=True)

        # Output column
        with gr.Column(scale=2, min_width=600):
            gallery_similar = gr.Gallery(label="similar images")
            gallery_similar.style(grid=3)
        
    # Clicking the button calls `compute` with the four inputs, in this order,
    # and sends its return value to the gallery.
    event = btn.click(compute, [dataset_name, dataset_split, dataset_column_image, img], gallery_similar)


# Start the app when this file is executed (runs at import time as well,
# since there is no __main__ guard).
demo.launch()