jusancp99 committed
Commit: b34681a
Parent: a86cb64

Create similarity_utils.py

Files changed (1): similarity_utils.py +175 -0
similarity_utils.py ADDED
@@ -0,0 +1,175 @@
from typing import List

import datasets
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from tqdm.auto import tqdm
from transformers import AutoFeatureExtractor, AutoModel

seed = 42
hash_size = 8
hidden_dim = 768  # ViT-base hidden size.
np.random.seed(seed)

# Device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the feature extractor that matches the embedding model.
model_ckpt = "nateraw/vit-base-beans"
extractor = AutoFeatureExtractor.from_pretrained(model_ckpt)
23
+
24
+ # Data transformation chain.
25
+ transformation_chain = T.Compose(
26
+ [
27
+ # We first resize the input image to 256x256 and then we take center crop.
28
+ T.Resize(int((256 / 224) * extractor.size["height"])),
29
+ T.CenterCrop(extractor.size["height"]),
30
+ T.ToTensor(),
31
+ T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
32
+ ]
33
+ )
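
# Illustrative (not part of the original commit): for this checkpoint,
# extractor.size["height"] is 224, so an RGB image of any size becomes a
# (3, 224, 224) tensor, e.g.
#   transformation_chain(Image.new("RGB", (500, 400))).shape  # torch.Size([3, 224, 224])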

# Define random vectors to project with.
random_vectors = np.random.randn(hash_size, hidden_dim).T


def hash_func(embedding, random_vectors=random_vectors):
    """Randomly projects the embeddings and then computes bit-wise hashes."""
    if not isinstance(embedding, np.ndarray):
        embedding = np.array(embedding)
    if len(embedding.shape) < 2:
        embedding = np.expand_dims(embedding, 0)

    # Random projection.
    bools = np.dot(embedding, random_vectors) > 0
    return [bool2int(bool_vec) for bool_vec in bools]
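
# Illustrative (an assumption about shapes, not part of the original commit):
#   hash_func(np.random.randn(hidden_dim))     # -> [h] with 0 <= h < 2**hash_size
#   hash_func(np.random.randn(4, hidden_dim))  # -> a list of 4 such integers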


def bool2int(x):
    """Packs a boolean vector into an integer, treating index i as bit i."""
    y = 0
    for i, j in enumerate(x):
        if j:
            y += 1 << i
    return y
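
# For example: bool2int([True, False, True]) == 1 * 2**0 + 1 * 2**2 == 5.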


def compute_hash(model: torch.nn.Module):
    """Returns a `map`-compatible function that hashes a batch of images."""
    device = model.device

    def pp(example_batch):
        # Prepare the input images for the model.
        image_batch = example_batch["image"]
        image_batch_transformed = torch.stack(
            [transformation_chain(image) for image in image_batch]
        )
        new_batch = {"pixel_values": image_batch_transformed.to(device)}

        # Compute embeddings and pool them, i.e., take the representations
        # from the [CLS] token.
        with torch.no_grad():
            embeddings = model(**new_batch).last_hidden_state[:, 0].cpu().numpy()

        # Compute hashes for the batch of images.
        hashes = [hash_func(embeddings[i]) for i in range(len(embeddings))]
        example_batch["hashes"] = hashes
        return example_batch

    return pp
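
# A minimal sketch of using `compute_hash` on its own (assumes `some_dataset`
# is a `datasets.Dataset` with an "image" column; `BuildLSHTable` below wires
# this up for you):
#
#   model = AutoModel.from_pretrained(model_ckpt).to(device)
#   pp = compute_hash(model)
#   hashed_ds = some_dataset.map(pp, batched=True, batch_size=48)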


class Table:
    def __init__(self, hash_size: int):
        self.table = {}
        self.hash_size = hash_size

    def add(self, id: int, hashes: List[int], label: int):
        # Create a unique identifier.
        entry = {"id_label": str(id) + "_" + str(label)}

        # Add the hash values to the current table.
        for h in hashes:
            if h in self.table:
                self.table[h].append(entry)
            else:
                self.table[h] = [entry]

    def query(self, hashes: List[int]):
        results = []

        # Loop over the query hashes and determine if they exist in
        # the current table.
        for h in hashes:
            if h in self.table:
                results.extend(self.table[h])
        return results
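
# Illustrative bucket layout after a couple of `add` calls (hypothetical hash
# values):
#
#   t = Table(hash_size)
#   t.add(0, [201, 54], label=1)
#   t.add(7, [201], label=2)
#   t.table  # {201: [{"id_label": "0_1"}, {"id_label": "7_2"}], 54: [{"id_label": "0_1"}]}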


class LSH:
    def __init__(self, hash_size, num_tables):
        self.num_tables = num_tables
        self.tables = []
        for i in range(self.num_tables):
            self.tables.append(Table(hash_size))

    def add(self, id: int, hash: List[int], label: int):
        for table in self.tables:
            table.add(id, hash, label)

    def query(self, hashes: List[int]):
        results = []
        for table in self.tables:
            results.extend(table.query(hashes))
        return results
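
# Note: `add` writes the same hashes into every table and `query` aggregates
# across all of them, so each match is repeated num_tables times, e.g.:
#
#   lsh = LSH(hash_size, num_tables=10)
#   lsh.add(0, [201], label=1)
#   len(lsh.query([201]))  # 10 -- one hit per table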


class BuildLSHTable:
    def __init__(
        self,
        model: torch.nn.Module,
        batch_size: int = 48,
        hash_size: int = hash_size,
        dim: int = hidden_dim,
        num_tables: int = 10,
    ):
        self.hash_size = hash_size
        self.dim = dim
        self.num_tables = num_tables
        self.lsh = LSH(self.hash_size, self.num_tables)

        self.batch_size = batch_size
        self.hash_fn = compute_hash(model.to(device))

    def build(self, ds: datasets.Dataset):
        dataset_hashed = ds.map(self.hash_fn, batched=True, batch_size=self.batch_size)

        for id in tqdm(range(len(dataset_hashed))):
            hash, label = dataset_hashed[id]["hashes"], dataset_hashed[id]["labels"]
            self.lsh.add(id, hash, label)

    def query(self, image, verbose=True):
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")

        # Compute the hashes of the query image and fetch the results.
        example_batch = dict(image=[image])
        hashes = self.hash_fn(example_batch)["hashes"][0]

        results = self.lsh.query(hashes)
        if verbose:
            print("Matches:", len(results))

        # Count how often each candidate matched and normalize the counts
        # to obtain a similarity score.
        counts = {}
        for r in results:
            if r["id_label"] in counts:
                counts[r["id_label"]] += 1
            else:
                counts[r["id_label"]] = 1
        for k in counts:
            counts[k] = float(counts[k]) / self.dim
        return counts
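

if __name__ == "__main__":
    # Usage sketch (not part of the original 175-line file): build an LSH table
    # over the "beans" train split -- the dataset this checkpoint was fine-tuned
    # on -- and query it with a validation image.
    from datasets import load_dataset

    dataset = load_dataset("beans")
    model = AutoModel.from_pretrained(model_ckpt)

    lsh_builder = BuildLSHTable(model)
    lsh_builder.build(dataset["train"])

    candidate_image = dataset["validation"][0]["image"]
    similarity_scores = lsh_builder.query(candidate_image)
    print(sorted(similarity_scores.items(), key=lambda kv: kv[1], reverse=True)[:5])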