Instantaneous1 committed on
Commit
415d5ea
1 Parent(s): e392687

batch process, faiss, gpu support, optimise

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +86 -39
  3. requirements.txt +4 -2
.gitignore CHANGED
@@ -2,5 +2,6 @@ env/
2
  images/
3
  __pycache__/
4
  *.tree
 
5
  secrets.toml
6
  kaggle.json
 
2
  images/
3
  __pycache__/
4
  *.tree
5
+ *.index
6
  secrets.toml
7
  kaggle.json
app.py CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
2
  import torch
3
  import os
4
  import torchvision
5
- from annoy import AnnoyIndex
6
  from PIL import Image
7
  import traceback
8
  from tqdm import tqdm
@@ -11,27 +11,41 @@ from slugify import slugify
11
  import opendatasets as od
12
  import json
13
  import argparse
 
 
 
 
 
 
14
 
15
-
 
16
  ImageFile.LOAD_TRUNCATED_IMAGES = True
17
  FOLDER = "images/"
18
  NUM_TREES = 100
19
  FEATURES = 1000
20
  FILETYPES = [".png", ".jpg", ".jpeg", ".tiff", ".bmp"]
21
-
22
- from azure.storage.blob import BlobServiceClient
 
 
 
 
23
 
24
 
25
  @st.cache_resource
26
  def dl_embeddings():
27
  """dl pretrained embeddings in production environment instead of creating"""
28
  # Connect to your Blob Storage account
 
 
 
29
  connect_str = st.secrets["connectionstring"]
30
  blob_service_client = BlobServiceClient.from_connection_string(connect_str)
31
 
32
  # Specify container and blob names
33
  container_name = "imagessearch"
34
- blob_name = f"{slugify(FOLDER)}.tree"
35
 
36
  # Get a reference to the blob
37
  blob_client = blob_service_client.get_blob_client(
@@ -39,7 +53,7 @@ def dl_embeddings():
39
  )
40
 
41
  # Download the binary data
42
- download_file_path = f"{slugify(FOLDER)}.tree" # Path to save the downloaded file
43
  with open(download_file_path, "wb") as download_file:
44
  download_file.write(blob_client.download_blob().readall())
45
 
@@ -56,16 +70,18 @@ def load_dataset():
56
  },
57
  f,
58
  )
59
- od.download(
60
- "https://www.kaggle.com/datasets/kkhandekar/image-dataset",
61
- "images/",
62
- )
 
63
 
64
 
65
  # Load a pre-trained image feature extractor model
66
  @st.cache_resource
67
  def load_model():
68
  """Loads a pre-trained image feature extractor model."""
 
69
  model = torch.hub.load(
70
  "NVIDIA/DeepLearningExamples:torchhub",
71
  "nvidia_efficientnet_b0",
@@ -104,9 +120,19 @@ def load_images(file_paths):
104
  return images
105
 
106
 
 
 
 
 
 
 
 
 
 
107
  # Function to preprocess images
108
  def preprocess_image(image):
109
  """Preprocesses an image for feature extraction."""
 
110
  if image.mode == "RGB": # Already has 3 channels
111
  pass # No need to modify
112
  elif image.mode == "L": # Grayscale image
@@ -128,57 +154,77 @@ def preprocess_image(image):
128
  return preprocess(image)
129
 
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  # Extract features from a list of images
132
- def extract_features(images, model):
133
  """Extracts features from a list of images."""
134
  print("Extracting features:")
 
 
 
 
135
  features = []
136
- for image in images:
137
- with torch.no_grad():
138
- feature = model(preprocess_image(image).unsqueeze(0)).squeeze(0)
139
- features.append(feature.numpy())
140
- return features
 
141
 
142
 
143
  # Build an Annoy index for efficient similarity search
144
  def build_annoy_index(features):
145
  """Builds an Annoy index for efficient similarity search."""
146
- print("Building annoy index:")
147
  f = features[0].shape[0] # Feature dimensionality
148
- t = AnnoyIndex(f, "angular") # Use angular distance for image features
149
- for i, feature in tqdm(enumerate(features)):
150
- t.add_item(i, feature)
151
- t.build(NUM_TREES) # Adjust num_trees for accuracy vs. speed trade-off
152
- return t
 
153
 
154
 
155
  # Perform reverse image search
156
- def search_similar_images(uploaded_file, f=FEATURES, num_results=5):
157
  """Finds similar images based on a query image feature."""
158
- index = AnnoyIndex(f, "angular")
159
- index.load(f"{slugify(FOLDER)}.tree")
160
- query_image = Image.open(uploaded_file)
161
- model = load_model()
162
  # Extract features and search
163
- query_feature = (
164
- model(preprocess_image(query_image).unsqueeze(0)).squeeze(0).detach().numpy()
165
- )
166
- nearest_neighbors, distances = index.get_nns_by_vector(
167
- query_feature, num_results, include_distances=True
 
168
  )
169
- return query_image, nearest_neighbors, distances
170
 
171
 
172
  @st.cache_data
173
  def save_embedding(folder=FOLDER):
174
- if os.path.isfile(f"{slugify(FOLDER)}.tree"):
 
175
  return
 
176
  model = load_model() # Load the model once
177
  file_paths = get_all_file_paths(folder_path=folder)
178
- images = load_images(file_paths)
179
- features = extract_features(images, model)
180
  index = build_annoy_index(features)
181
- index.save(f"{slugify(FOLDER)}.tree")
182
 
183
 
184
  def display_image(idx, dist):
@@ -214,11 +260,12 @@ if __name__ == "__main__":
214
  )
215
 
216
  if uploaded_file is not None:
 
 
217
  query_image, nearest_neighbors, distances = search_similar_images(
218
- uploaded_file, num_results=n_matches
219
  )
220
 
221
- st.image(query_image.resize([256, 256]), caption="Query Image", width=200)
222
  st.subheader("Similar Images:")
223
  cols = st.columns([1] * 5)
224
  for i, (idx, dist) in enumerate(
 
2
  import torch
3
  import os
4
  import torchvision
5
+ import faiss
6
  from PIL import Image
7
  import traceback
8
  from tqdm import tqdm
 
11
  import opendatasets as od
12
  import json
13
  import argparse
14
+ from streamlit_cropper import st_cropper
15
+ from azure.storage.blob import BlobServiceClient
16
+ from torch.utils.data import Dataset, DataLoader
17
+ import torchvision.transforms
18
+ import numpy as np
19
+ import faiss.contrib.torch_utils
20
 
21
+ BATCH_SIZE = 200
22
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
  ImageFile.LOAD_TRUNCATED_IMAGES = True
24
  FOLDER = "images/"
25
  NUM_TREES = 100
26
  FEATURES = 1000
27
  FILETYPES = [".png", ".jpg", ".jpeg", ".tiff", ".bmp"]
28
+ LIBRARIES = [
29
+ "https://www.kaggle.com/datasets/athota1/caltech101",
30
+ "https://www.kaggle.com/datasets/gpiosenka/sports-classification",
31
+ "https://www.kaggle.com/datasets/puneet6060/intel-image-classification",
32
+ "https://www.kaggle.com/datasets/kkhandekar/image-dataset",
33
+ ]
34
 
35
 
36
  @st.cache_resource
37
  def dl_embeddings():
38
  """dl pretrained embeddings in production environment instead of creating"""
39
  # Connect to your Blob Storage account
40
+ if os.path.isfile(f"{slugify(FOLDER)}.index"):
41
+ print("Embeddings files already exists, skip download")
42
+ return
43
  connect_str = st.secrets["connectionstring"]
44
  blob_service_client = BlobServiceClient.from_connection_string(connect_str)
45
 
46
  # Specify container and blob names
47
  container_name = "imagessearch"
48
+ blob_name = f"{slugify(FOLDER)}.index"
49
 
50
  # Get a reference to the blob
51
  blob_client = blob_service_client.get_blob_client(
 
53
  )
54
 
55
  # Download the binary data
56
+ download_file_path = f"{slugify(FOLDER)}.index" # Path to save the downloaded file
57
  with open(download_file_path, "wb") as download_file:
58
  download_file.write(blob_client.download_blob().readall())
59
 
 
70
  },
71
  f,
72
  )
73
+ for lib in LIBRARIES:
74
+ od.download(
75
+ lib,
76
+ "images/",
77
+ )
78
 
79
 
80
  # Load a pre-trained image feature extractor model
81
  @st.cache_resource
82
  def load_model():
83
  """Loads a pre-trained image feature extractor model."""
84
+ print("Loading pretrained model...")
85
  model = torch.hub.load(
86
  "NVIDIA/DeepLearningExamples:torchhub",
87
  "nvidia_efficientnet_b0",
 
120
  return images
121
 
122
 
123
def load_image(file_path):
    """Load a single image from disk and resize it to 224x224.

    Args:
        file_path: Path of the image file to open.

    Returns:
        The resized PIL image, or ``None`` when the file could not be opened
        or decoded (the error is printed, not raised).
    """
    try:
        # ``resize`` returns a new, fully loaded image, so the file handle
        # opened by ``Image.open`` can be released as soon as it is done.
        with Image.open(file_path) as source:
            return source.resize([224, 224])
    except Exception as e:
        # ``Exception`` rather than ``BaseException`` so that Ctrl-C
        # (KeyboardInterrupt) and SystemExit still stop a long indexing run.
        print("Error loading ", file_path, e)
        return None
130
+
131
+
132
  # Function to preprocess images
133
  def preprocess_image(image):
134
  """Preprocesses an image for feature extraction."""
135
+
136
  if image.mode == "RGB": # Already has 3 channels
137
  pass # No need to modify
138
  elif image.mode == "L": # Grayscale image
 
154
  return preprocess(image)
155
 
156
 
157
class ImageLoader(Dataset):
    """Map-style dataset that loads and transforms one image per index.

    Args:
        image_files: Sequence of image file paths.
        transform: Callable applied to each loaded image; its result is the
            dataset item.
        load_image: Callable that takes a path and returns the loaded image,
            or ``None`` when loading failed.
    """

    def __init__(self, image_files, transform, load_image):
        self.transform = transform
        self.load_image = load_image
        self.image_files = image_files

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.image_files)

    def __getitem__(self, index):
        """Load and transform the image at ``index``.

        Raises:
            ValueError: If the loader returned ``None`` for this path.
        """
        path = self.image_files[index]
        image = self.load_image(path)
        if image is None:
            # Fail with the offending path instead of letting the transform
            # trip over ``None``. NOTE(review): silently skipping the item
            # would presumably shift positions relative to the path list —
            # confirm before changing this to a skip.
            raise ValueError(f"Could not load image: {path}")
        return self.transform(image)
168
+
169
+
170
  # Extract features from a list of images
171
def extract_features(file_paths, model):
    """Run ``model`` over every image in ``file_paths``, one batch at a time.

    Args:
        file_paths: Sequence of image file paths; each is loaded with
            ``load_image`` and prepared with ``preprocess_image``.
        model: Torch module used as the feature extractor; it is moved to
            ``DEVICE`` before inference.

    Returns:
        One CPU tensor holding the model outputs of all batches concatenated
        along dim 0, in loader order (the loader is not shuffled).
    """
    print("Extracting features:")
    loader = DataLoader(
        ImageLoader(file_paths, transform=preprocess_image, load_image=load_image),
        batch_size=BATCH_SIZE,
    )
    model = model.to(DEVICE)
    features = []
    with torch.no_grad():
        for images in tqdm(loader):
            # Move each batch result off the accelerator right away so device
            # memory holds one batch at a time, not the whole dataset.
            features.append(model(images.to(DEVICE)).cpu())
    return torch.cat(features)
185
 
186
 
187
  # Build an Annoy index for efficient similarity search
188
  def build_annoy_index(features):
189
  """Builds an Annoy index for efficient similarity search."""
190
+ print("Building faiss index:")
191
  f = features[0].shape[0] # Feature dimensionality
192
+ index = faiss.IndexIDMap(faiss.IndexFlatIP(f))
193
+ index.add_with_ids(
194
+ features.cpu().detach().numpy(), np.array(range(len(features)))
195
+ ) # Adjust num_trees for accuracy vs. speed trade-off
196
+ print("built faiss index:")
197
+ return index
198
 
199
 
200
  # Perform reverse image search
201
def search_similar_images(query_image, num_results, f=FEATURES):
    """Find the indexed images most similar to ``query_image``.

    Args:
        query_image: PIL image to search with.
        num_results: Number of neighbours to request from the index.
        f: Unused; kept so callers that still pass it keep working.

    Returns:
        Tuple ``(query_image, nearest_neighbors, distances)``: the input
        image unchanged, then two 1-D arrays of equal length holding the
        matched ids and their scores in the order faiss returned them.
        Result slots faiss could not fill are dropped, so the arrays may be
        shorter than ``num_results``.
    """
    index = faiss.read_index(f"{slugify(FOLDER)}.index")
    model = load_model().to(DEVICE)
    # Extract features and search
    proc_image = preprocess_image(query_image).unsqueeze(0).to(DEVICE)
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        query_feature = model(proc_image)
    query_feature = query_feature.cpu().detach().numpy()
    distances, nearest_neighbors = index.search(query_feature, num_results)
    # faiss pads the result with id -1 when the index holds fewer than
    # num_results vectors; those are not real matches.
    found = nearest_neighbors[0] != -1
    return query_image, nearest_neighbors[0][found], distances[0][found]
214
 
215
 
216
  @st.cache_data
217
  def save_embedding(folder=FOLDER):
218
+ if os.path.isfile(f"{slugify(FOLDER)}.index"):
219
+ print("skipping recreating image embeddings")
220
  return
221
+ print("Performing image embeddings")
222
  model = load_model() # Load the model once
223
  file_paths = get_all_file_paths(folder_path=folder)
224
+ # images = load_images(file_paths)
225
+ features = extract_features(file_paths, model)
226
  index = build_annoy_index(features)
227
+ faiss.write_index(index, f"{slugify(FOLDER)}.index")
228
 
229
 
230
  def display_image(idx, dist):
 
260
  )
261
 
262
  if uploaded_file is not None:
263
+ query_image = Image.open(uploaded_file)
264
+ cropped = st_cropper(query_image)
265
  query_image, nearest_neighbors, distances = search_similar_images(
266
+ cropped.resize([256, 256]), n_matches
267
  )
268
 
 
269
  st.subheader("Similar Images:")
270
  cols = st.columns([1] * 5)
271
  for i, (idx, dist) in enumerate(
requirements.txt CHANGED
@@ -1,8 +1,10 @@
1
- annoy
 
2
  torch
3
  torchvision
4
  streamlit
5
  tqdm
6
  python-slugify
7
  opendatasets
8
- azure-storage-blob
 
 
1
# faiss-cpu and faiss-gpu both provide the same top-level `faiss` module, so
# only one of them should be installed. The index built in app.py is a plain
# IndexFlatIP fed from CPU numpy arrays, so the CPU wheel is kept here; swap
# the two lines if a GPU build of faiss is actually required.
faiss-cpu
# faiss-gpu
torch
torchvision
streamlit
tqdm
python-slugify
opendatasets
azure-storage-blob
streamlit-cropper