sgugger commited on
Commit
b3195da
1 Parent(s): 8d8a971

Add first app

Browse files
Files changed (1) hide show
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import numpy as np
4
+ from scipy.fftpack import dct
5
+ from datasets import load_dataset
6
+ from PIL import Image
7
+ from multiprocessing import cpu_count
8
+
9
+
10
+ def perceptual_hash_color(image):
11
+ image = image.convert("RGB") # Convert to grayscale
12
+ image = image.resize((32, 32), Image.ANTIALIAS) # Resize to 32x32
13
+ image_array = np.asarray(image) # Convert to numpy array
14
+ hashes = []
15
+ for i in range(3):
16
+ channel = image_array[:, :, i]
17
+ dct_coef = dct(dct(channel, axis=0), axis=1) # Compute DCT
18
+ dct_reduced_coef = dct_coef[:8, :8] # Retain top-left 8x8 DCT coefficients
19
+ # Median of DCT coefficients excluding the DC term (0th term)
20
+ median_coef_val = np.median(np.ndarray.flatten(dct_reduced_coef)[1:])
21
+ # Mask of all coefficients greater than median of coefficients
22
+ hashes.append((dct_reduced_coef >= median_coef_val).flatten() * 1)
23
+ return np.concatenate(hashes)
24
+
25
+ def hamming_distance(array_1, array_2):
26
+ return len([1 for el_1, el_2 in zip(array_1, array_2) if el_1 != el_2])
27
+
28
+ def search_closest_examples(hash_refs, img_dataset):
29
+ distances = []
30
+ for hash_ref in hash_refs:
31
+ distances.extend([hamming_distance(hash_ref, img_dataset[idx]["hash"]) for idx in range(img_dataset.num_rows)])
32
+ closests = [i.item() % len(img_dataset) for i in np.argsort(distances)[:9]]
33
+ return closests, [distances[c] for c in closests]
34
+
35
+ def find_closest_images(images, img_dataset):
36
+ if not isinstance(images, (list, tuple)):
37
+ images = [images]
38
+ hashes = [perceptual_hash_color(img) for img in images]
39
+ closest_idx, distances = search_closest_examples(hashes, img_dataset)
40
+ return closest_idx, distances
41
+
42
+ def compute_hash_from_image(img):
43
+ img = img.convert("L") # Convert to grayscale
44
+ img = img.resize((32, 32), Image.ANTIALIAS) # Resize to 32x32
45
+ img_array = np.asarray(img) # Convert to numpy array
46
+ dct_coef = dct(dct(img_array, axis=0), axis=1) # Compute DCT
47
+ dct_reduced_coef = dct_coef[:8, :8] # Retain top-left 8x8 DCT coefficients
48
+ # Median of DCT coefficients excluding the DC term (0th term)
49
+ median_coef_val = np.median(np.ndarray.flatten(dct_reduced_coef)[1:])
50
+ # Mask of all coefficients greater than median of coefficients
51
+ hash = (dct_reduced_coef >= median_coef_val).flatten() * 1
52
+ return hash
53
+
54
+
55
+ def process_dataset(dataset_name, dataset_split, dataset_column_image):
56
+ img_dataset = load_dataset(dataset_name)[dataset_split]
57
+
58
+ def add_hash(example):
59
+ example["hash"] = perceptual_hash_color(example[dataset_column_image])
60
+ return example
61
+
62
+ # Compute hash of every image in the dataset
63
+ img_dataset = img_dataset.map(add_hash, num_proc=4)
64
+ return img_dataset
65
+
66
+
67
+ def compute(dataset_name, dataset_split, dataset_column_image, img):
68
+ img_dataset = process_dataset(dataset_name, dataset_split, dataset_column_image)
69
+ closest_idx, distances = find_closest_images(img, img_dataset)
70
+ return [img_dataset[i] for i in closest_idx]
71
+
72
+
73
+ with gr.Blocks() as demo:
74
+ gr.Markdown("# Find if your images are in a public dataset!")
75
+ with gr.Row():
76
+ with gr.Column(scale=1, min_width=600):
77
+ dataset_name = gr.Textbox(label="Enter the name of a dataset containing images")
78
+ dataset_split = gr.Textbox(label="Enter the split of this dataset to consider")
79
+ dataset_column_image = gr.Textbox(label="Enter the name of the column of this dataset that contains images")
80
+ img = gr.Image(label="Input your image that will be compared against images of the dataset", type="pil")
81
+ btn = gr.Button("Find").style(full_width=True)
82
+
83
+ with gr.Column(scale=2, min_width=600):
84
+ gallery_similar = gr.Gallery(label="similar images")
85
+
86
+ event = btn.click(compute, [dataset_name, dataset_split, dataset_column_image, img], gallery_similar)
87
+
88
+
89
+ demo.launch()