Commit 415d5ea
Parent(s): e392687

batch process, faiss, gpu support, optimise

Files changed:
- .gitignore +1 -0
- app.py +86 -39
- requirements.txt +4 -2
.gitignore
CHANGED
@@ -2,5 +2,6 @@ env/
 images/
 __pycache__/
 *.tree
+*.index
 secrets.toml
 kaggle.json
app.py
CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
 import torch
 import os
 import torchvision
-
+import faiss
 from PIL import Image
 import traceback
 from tqdm import tqdm
@@ -11,27 +11,41 @@ from slugify import slugify
 import opendatasets as od
 import json
 import argparse
+from streamlit_cropper import st_cropper
+from azure.storage.blob import BlobServiceClient
+from torch.utils.data import Dataset, DataLoader
+import torchvision.transforms
+import numpy as np
+import faiss.contrib.torch_utils

-
+BATCH_SIZE = 200
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 FOLDER = "images/"
 NUM_TREES = 100
 FEATURES = 1000
 FILETYPES = [".png", ".jpg", ".jpeg", ".tiff", ".bmp"]
-
-
+LIBRARIES = [
+    "https://www.kaggle.com/datasets/athota1/caltech101",
+    "https://www.kaggle.com/datasets/gpiosenka/sports-classification",
+    "https://www.kaggle.com/datasets/puneet6060/intel-image-classification",
+    "https://www.kaggle.com/datasets/kkhandekar/image-dataset",
+]


 @st.cache_resource
 def dl_embeddings():
     """dl pretrained embeddings in production environment instead of creating"""
     # Connect to your Blob Storage account
+    if os.path.isfile(f"{slugify(FOLDER)}.index"):
+        print("Embeddings file already exists, skipping download")
+        return
     connect_str = st.secrets["connectionstring"]
     blob_service_client = BlobServiceClient.from_connection_string(connect_str)

     # Specify container and blob names
     container_name = "imagessearch"
-    blob_name = f"{slugify(FOLDER)}.tree"
+    blob_name = f"{slugify(FOLDER)}.index"

     # Get a reference to the blob
     blob_client = blob_service_client.get_blob_client(
@@ -39,7 +53,7 @@ def dl_embeddings():
     )

     # Download the binary data
-    download_file_path = f"{slugify(FOLDER)}.tree"  # Path to save the downloaded file
+    download_file_path = f"{slugify(FOLDER)}.index"  # Path to save the downloaded file
     with open(download_file_path, "wb") as download_file:
         download_file.write(blob_client.download_blob().readall())

@@ -56,16 +70,18 @@ def load_dataset():
         },
         f,
     )
-
-
-
-
+    for lib in LIBRARIES:
+        od.download(
+            lib,
+            "images/",
+        )


 # Load a pre-trained image feature extractor model
 @st.cache_resource
 def load_model():
     """Loads a pre-trained image feature extractor model."""
+    print("Loading pretrained model...")
     model = torch.hub.load(
         "NVIDIA/DeepLearningExamples:torchhub",
         "nvidia_efficientnet_b0",
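Note: load_dataset now pulls all four Kaggle datasets in LIBRARIES into images/. opendatasets authenticates against the Kaggle API, so first-time downloads will prompt for credentials unless a kaggle.json token is present, which is presumably why kaggle.json already sits in .gitignore.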
@@ -104,9 +120,19 @@ def load_images(file_paths):
     return images


+def load_image(file_path):
+    """Load a single image from a file path."""
+    try:
+        image = Image.open(file_path).resize([224, 224])
+        return image
+    except BaseException as e:
+        print("Error loading ", file_path, e)
+
+
 # Function to preprocess images
 def preprocess_image(image):
     """Preprocesses an image for feature extraction."""
+
     if image.mode == "RGB":  # Already has 3 channels
         pass  # No need to modify
     elif image.mode == "L":  # Grayscale image
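Review note: load_image swallows read errors and implicitly returns None, so a single unreadable file makes ImageLoader hand None to preprocess_image and aborts the whole extraction run. A minimal hardening sketch; the blank-image fallback is a suggestion, not part of the commit:

    def load_image(file_path):
        # Fall back to a blank RGB frame so batch shapes stay consistent.
        try:
            return Image.open(file_path).resize([224, 224])
        except Exception as e:  # Exception, not BaseException: lets Ctrl-C through
            print("Error loading ", file_path, e)
            return Image.new("RGB", (224, 224))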
@@ -128,57 +154,77 @@ def preprocess_image(image):
     return preprocess(image)


+class ImageLoader(Dataset):
+    def __init__(self, image_files, transform, load_image):
+        self.transform = transform
+        self.load_image = load_image
+        self.image_files = image_files
+
+    def __len__(self):
+        return len(self.image_files)
+
+    def __getitem__(self, index):
+        return self.transform(self.load_image(self.image_files[index]))
+
+
 # Extract features from a list of images
-def extract_features(images, model):
+def extract_features(file_paths, model):
     """Extracts features from a list of images."""
     print("Extracting features:")
+    loader = DataLoader(
+        ImageLoader(file_paths, transform=preprocess_image, load_image=load_image),
+        batch_size=BATCH_SIZE,
+    )
     features = []
-
-
-
-
-
+    model = model.to(DEVICE)
+    with torch.no_grad():
+        for batch_idx, images in enumerate(tqdm(loader)):
+            images = images.to(DEVICE)
+            features.append(model(images))
+    return torch.cat(features)


 # Build an Annoy index for efficient similarity search
 def build_annoy_index(features):
     """Builds an Annoy index for efficient similarity search."""
-    print("Building
+    print("Building faiss index:")
     f = features[0].shape[0]  # Feature dimensionality
-
-
-
-
-
+    index = faiss.IndexIDMap(faiss.IndexFlatIP(f))
+    index.add_with_ids(
+        features.cpu().detach().numpy(), np.array(range(len(features)))
+    )  # Adjust num_trees for accuracy vs. speed trade-off
+    print("built faiss index:")
+    return index


 # Perform reverse image search
-def search_similar_images(uploaded_file, num_results, f=FEATURES):
+def search_similar_images(query_image, num_results, f=FEATURES):
     """Finds similar images based on a query image feature."""
-    index =
-
-    query_image = Image.open(uploaded_file)
-    model = load_model()
+    index = faiss.read_index(f"{slugify(FOLDER)}.index")
+    model = load_model().to(DEVICE)
     # Extract features and search
-
-
-    )
-
-        query_feature,
+    proc_image = preprocess_image(query_image).unsqueeze(0).to(DEVICE)
+    query_feature = model(proc_image)
+    query_feature = query_feature.cpu().detach().numpy()
+    distances, nearest_neighbors = index.search(
+        query_feature,
+        num_results,
     )
-    return query_image, nearest_neighbors, distances
+    return query_image, nearest_neighbors[0], distances[0]


 @st.cache_data
 def save_embedding(folder=FOLDER):
-    if os.path.isfile(f"{slugify(FOLDER)}.tree"):
+    if os.path.isfile(f"{slugify(FOLDER)}.index"):
+        print("skipping recreating image embeddings")
         return
+    print("Performing image embeddings")
     model = load_model()  # Load the model once
     file_paths = get_all_file_paths(folder_path=folder)
-    images = load_images(file_paths)
-    features = extract_features(images, model)
+    # images = load_images(file_paths)
+    features = extract_features(file_paths, model)
     index = build_annoy_index(features)
-
+    faiss.write_index(index, f"{slugify(FOLDER)}.index")


 def display_image(idx, dist):
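Review note: faiss.IndexFlatIP ranks by raw inner product, so feature vectors with larger norms dominate the results. If cosine similarity is the intent, L2-normalize the features before indexing and before searching. A sketch under that assumption, reusing names from this file; it is not part of the commit:

    feats = features.cpu().detach().numpy()
    faiss.normalize_L2(feats)  # in-place row-wise normalization
    index = faiss.IndexIDMap(faiss.IndexFlatIP(feats.shape[1]))
    index.add_with_ids(feats, np.arange(len(feats)))
    # ...later, at query time:
    faiss.normalize_L2(query_feature)  # normalize the query identically
    distances, nearest_neighbors = index.search(query_feature, num_results)

A related nit: build_annoy_index now builds a FAISS index, so the function name, the num_trees comment, and the now-unused NUM_TREES constant are leftovers from the Annoy implementation.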
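A further note: with faiss.contrib.torch_utils imported, the index methods accept torch tensors directly, which makes the manual .cpu().detach().numpy() hop in search_similar_images avoidable; and unlike extract_features, the query-time forward pass runs with autograd enabled. A combined sketch for the body of search_similar_images, assuming a CPU-resident index so the tensor is moved off the GPU first:

    with torch.no_grad():  # no gradients needed for a lookup
        query_feature = model(proc_image)  # shape (1, FEATURES)
    distances, nearest_neighbors = index.search(query_feature.cpu(), num_results)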
@@ -214,11 +260,12 @@ if __name__ == "__main__":
|
|
214 |
)
|
215 |
|
216 |
if uploaded_file is not None:
|
|
|
|
|
217 |
query_image, nearest_neighbors, distances = search_similar_images(
|
218 |
-
|
219 |
)
|
220 |
|
221 |
-
st.image(query_image.resize([256, 256]), caption="Query Image", width=200)
|
222 |
st.subheader("Similar Images:")
|
223 |
cols = st.columns([1] * 5)
|
224 |
for i, (idx, dist) in enumerate(
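Note on the UI change: st_cropper overlays an interactive crop box on the uploaded image and returns the selected region as a PIL image, so the search now runs on a user-chosen crop and the old st.image preview of the query became redundant.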
requirements.txt
CHANGED
@@ -1,8 +1,10 @@
-annoy
+faiss-cpu
+faiss-gpu
 torch
 torchvision
 streamlit
 tqdm
 python-slugify
 opendatasets
-azure-storage-blob
+azure-storage-blob
+streamlit-cropper
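Packaging note: faiss-cpu and faiss-gpu both provide the same faiss package, and installing the two side by side in one environment is a known source of conflicts; the usual practice is to pin faiss-cpu for CPU-only deployments and swap in faiss-gpu only where CUDA is actually available.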