sayakpaul HF staff commited on
Commit
3579efb
1 Parent(s): 388fe78

add: initial files.

Browse files
Files changed (4) hide show
  1. app.py +69 -0
  2. lsh.pickle +3 -0
  3. requirements.txt +3 -0
  4. similarity_utils.py +175 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Thanks to Freddy Boulton (https://github.com/freddyaboulton) for helping with this.
3
+ """
4
+
5
+
6
+ import pickle
7
+
8
+ import gradio as gr
9
+ from datasets import load_dataset
10
+ from transformers import AutoModel
11
+
12
+ from similarity_utils import BuildLSHTable
13
+
14
+ seed = 42
15
+
16
+ # Only runs once when the script is first run.
17
+ with open("lsh.pickle", "rb") as handle:
18
+ loaded_lsh = pickle.load(handle)
19
+
20
+ # Load model for computing embeddings.
21
+ model_ckpt = "nateraw/vit-base-beans"
22
+ model = AutoModel.from_pretrained(model_ckpt)
23
+ lsh_builder = BuildLSHTable(model)
24
+ lsh_builder.lsh = loaded_lsh
25
+
26
+ # Candidate images.
27
+ dataset = load_dataset("beans")
28
+ candidate_dataset = dataset["train"].shuffle(seed=seed)
29
+
30
+
31
+ def query(image, top_k):
32
+ results = lsh_builder.query(image)
33
+
34
+ # Should be a list of string file paths for gr.Gallery to work
35
+ images = []
36
+ # List of labels for each image in the gallery
37
+ labels = []
38
+
39
+ candidates = []
40
+ overlaps = []
41
+
42
+ for idx, r in enumerate(sorted(results, key=results.get, reverse=True)):
43
+ if idx == top_k:
44
+ break
45
+ image_id, label = r.split("_")[0], r.split("_")[1]
46
+ candidates.append(candidate_dataset[int(image_id)]["image"])
47
+ labels.append(label)
48
+ overlaps.append(results[r])
49
+
50
+ candidates.insert(0, image)
51
+ labels.insert(0, label)
52
+
53
+ for i, candidate in enumerate(candidates):
54
+ filename = f"{i}.png"
55
+ candidate.save(filename)
56
+ images.append(filename)
57
+
58
+ # The gallery component can be a list of tuples, where the first element is a path to a file
59
+ # and the second element is an optional caption for that image
60
+ return list(zip(images, labels))
61
+
62
+
63
+ # You can set the type of gr.Image to be PIL, numpy or str (filepath)
64
+ # Not sure what the best for this demo is.
65
+ gr.Interface(
66
+ query,
67
+ inputs=[gr.Image(), gr.Slider(value=5, minimum=1, maximum=10, step=1)],
68
+ outputs=gr.Gallery(),
69
+ ).launch()
lsh.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:caa1727832f2279a4026b03b9f17638ff4a4deffa0a28586e74db59332dce732
3
+ size 136667
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ transformers==4.25.1
2
+ datasets==2.7.1
3
+ numpy==1.21.6
similarity_utils.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union
2
+
3
+ import datasets
4
+ import numpy as np
5
+ import torch
6
+ import torchvision.transforms as T
7
+ from PIL import Image
8
+ from tqdm.auto import tqdm
9
+ from transformers import AutoFeatureExtractor, AutoModel
10
+
11
+ seed = 42
12
+ hash_size = 8
13
+ hidden_dim = 768 # ViT-base
14
+ np.random.seed(seed)
15
+
16
+
17
+ # Device.
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+
20
+ # Load model for computing embeddings..
21
+ model_ckpt = "nateraw/vit-base-beans"
22
+ extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
23
+
24
+ # Data transformation chain.
25
+ transformation_chain = T.Compose(
26
+ [
27
+ # We first resize the input image to 256x256 and then we take center crop.
28
+ T.Resize(int((256 / 224) * extractor.size["height"])),
29
+ T.CenterCrop(extractor.size["height"]),
30
+ T.ToTensor(),
31
+ T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
32
+ ]
33
+ )
34
+
35
+
36
+ # Define random vectors to project with.
37
+ random_vectors = np.random.randn(hash_size, hidden_dim).T
38
+
39
+
40
+ def hash_func(embedding, random_vectors=random_vectors):
41
+ """Randomly projects the embeddings and then computes bit-wise hashes."""
42
+ if not isinstance(embedding, np.ndarray):
43
+ embedding = np.array(embedding)
44
+ if len(embedding.shape) < 2:
45
+ embedding = np.expand_dims(embedding, 0)
46
+
47
+ # Random projection.
48
+ bools = np.dot(embedding, random_vectors) > 0
49
+ return [bool2int(bool_vec) for bool_vec in bools]
50
+
51
+
52
+ def bool2int(x):
53
+ y = 0
54
+ for i, j in enumerate(x):
55
+ if j:
56
+ y += 1 << i
57
+ return y
58
+
59
+
60
+ def compute_hash(model: Union[torch.nn.Module, str]):
61
+ """Computes hash on a given dataset."""
62
+ device = model.device
63
+
64
+ def pp(example_batch):
65
+ # Prepare the input images for the model.
66
+ image_batch = example_batch["image"]
67
+ image_batch_transformed = torch.stack(
68
+ [transformation_chain(image) for image in image_batch]
69
+ )
70
+ new_batch = {"pixel_values": image_batch_transformed.to(device)}
71
+
72
+ # Compute embeddings and pool them i.e., take the representations from the [CLS]
73
+ # token.
74
+ with torch.no_grad():
75
+ embeddings = model(**new_batch).last_hidden_state[:, 0].cpu().numpy()
76
+
77
+ # Compute hashes for the batch of images.
78
+ hashes = [hash_func(embeddings[i]) for i in range(len(embeddings))]
79
+ example_batch["hashes"] = hashes
80
+ return example_batch
81
+
82
+ return pp
83
+
84
+
85
+ class Table:
86
+ def __init__(self, hash_size: int):
87
+ self.table = {}
88
+ self.hash_size = hash_size
89
+
90
+ def add(self, id: int, hashes: List[int], label: int):
91
+ # Create a unique indentifier.
92
+ entry = {"id_label": str(id) + "_" + str(label)}
93
+
94
+ # Add the hash values to the current table.
95
+ for h in hashes:
96
+ if h in self.table:
97
+ self.table[h].append(entry)
98
+ else:
99
+ self.table[h] = [entry]
100
+
101
+ def query(self, hashes: List[int]):
102
+ results = []
103
+
104
+ # Loop over the query hashes and determine if they exist in
105
+ # the current table.
106
+ for h in hashes:
107
+ if h in self.table:
108
+ results.extend(self.table[h])
109
+ return results
110
+
111
+
112
+ class LSH:
113
+ def __init__(self, hash_size, num_tables):
114
+ self.num_tables = num_tables
115
+ self.tables = []
116
+ for i in range(self.num_tables):
117
+ self.tables.append(Table(hash_size))
118
+
119
+ def add(self, id: int, hash: List[int], label: int):
120
+ for table in self.tables:
121
+ table.add(id, hash, label)
122
+
123
+ def query(self, hashes: List[int]):
124
+ results = []
125
+ for table in self.tables:
126
+ results.extend(table.query(hashes))
127
+ return results
128
+
129
+
130
+ class BuildLSHTable:
131
+ def __init__(
132
+ self,
133
+ model: Union[torch.nn.Module, None],
134
+ batch_size: int = 48,
135
+ hash_size: int = hash_size,
136
+ dim: int = hidden_dim,
137
+ num_tables: int = 10,
138
+ ):
139
+ self.hash_size = hash_size
140
+ self.dim = dim
141
+ self.num_tables = num_tables
142
+ self.lsh = LSH(self.hash_size, self.num_tables)
143
+
144
+ self.batch_size = batch_size
145
+ self.hash_fn = compute_hash(model.to(device))
146
+
147
+ def build(self, ds: datasets.DatasetDict):
148
+ dataset_hashed = ds.map(self.hash_fn, batched=True, batch_size=self.batch_size)
149
+
150
+ for id in tqdm(range(len(dataset_hashed))):
151
+ hash, label = dataset_hashed[id]["hashes"], dataset_hashed[id]["labels"]
152
+ self.lsh.add(id, hash, label)
153
+
154
+ def query(self, image, verbose=True):
155
+ if isinstance(image, str):
156
+ image = Image.open(image).convert("RGB")
157
+
158
+ # Compute the hashes of the query image and fetch the results.
159
+ example_batch = dict(image=[image])
160
+ hashes = self.hash_fn(example_batch)["hashes"][0]
161
+
162
+ results = self.lsh.query(hashes)
163
+ if verbose:
164
+ print("Matches:", len(results))
165
+
166
+ # Calculate Jaccard index to quantify the similarity.
167
+ counts = {}
168
+ for r in results:
169
+ if r["id_label"] in counts:
170
+ counts[r["id_label"]] += 1
171
+ else:
172
+ counts[r["id_label"]] = 1
173
+ for k in counts:
174
+ counts[k] = float(counts[k]) / self.dim
175
+ return counts