Instantaneous1 committed on
Commit
c456f09
0 Parent(s):

first commit

Browse files
Files changed (4) hide show
  1. .github/workflows/main.yaml +20 -0
  2. .gitignore +6 -0
  3. app.py +207 -0
  4. requirements.txt +7 -0
.github/workflows/main.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Mirrors the main branch of this repository to the Hugging Face Space.
name: Sync to Hugging Face hub
on:
  push:
    branches: [main]

  # to run this workflow manually from the Actions tab
  workflow_dispatch:

jobs:
  sync-to-hub:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
        with:
          # Full history is required so the push to the Space is not rejected
          # as a shallow update.
          fetch-depth: 0
          lfs: true
      - name: Push to hub
        env:
          # Supplied from the repository's Actions secrets.
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        # --force overwrites whatever is currently on the Space's main branch.
        run: git push --force https://Instantaneous1:$HF_TOKEN@huggingface.co/spaces/Instantaneous1/search-by-image main
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ env/
2
+ images/
3
+ __pycache__/
4
+ *.tree
5
+ secrets.toml
6
+ kaggle.json
app.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import os
4
+ import torchvision
5
+ from annoy import AnnoyIndex
6
+ from PIL import Image
7
+ import traceback
8
+ from tqdm import tqdm
9
+ from PIL import ImageFile
10
+ from slugify import slugify
11
+ import opendatasets as od
12
+ import json
13
+
14
# Ask Pillow to load truncated image files rather than raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Directory scanned for dataset images; its slug also names the saved
# Annoy index file ("<slug>.tree").
FOLDER = "images/"

# Number of trees handed to AnnoyIndex.build().
NUM_TREES = 100

# Default vector length used when re-opening the saved Annoy index.
# NOTE(review): presumably equals the model's output size - confirm.
FEATURES = 1000
18
+
19
+
20
@st.cache_resource
def load_dataset():
    """Download the Kaggle image dataset into ``FOLDER``.

    The Kaggle credentials are taken from Streamlit secrets (``username``
    and ``key``) and written to ``kaggle.json`` in the current working
    directory before ``opendatasets`` is asked to download the dataset.

    The function returns nothing; it is decorated with
    ``st.cache_resource`` so Streamlit can skip repeated calls.
    """
    # Read the secrets first so a missing secret fails before kaggle.json
    # is created or truncated.
    credentials = {
        "username": st.secrets["username"],
        "key": st.secrets["key"],
    }
    # The file is only written here, so plain write mode is sufficient.
    with open("kaggle.json", "w", encoding="utf-8") as f:
        json.dump(credentials, f)
    # Download into the same folder the rest of the app scans and indexes.
    od.download(
        "https://www.kaggle.com/datasets/kkhandekar/image-dataset",
        FOLDER,
    )
34
+
35
+
36
# Pre-trained image feature extractor, fetched through torch.hub
@st.cache_resource
def load_model():
    """Return the pre-trained ``nvidia_efficientnet_b0`` hub model in eval mode."""
    hub_repo = "NVIDIA/DeepLearningExamples:torchhub"
    extractor = torch.hub.load(hub_repo, "nvidia_efficientnet_b0", pretrained=True)
    # The model is only used for inference, so switch it to evaluation mode.
    extractor.eval()
    return extractor
47
+
48
+
49
# Get all image file paths within a folder and its subfolders
@st.cache_data
def get_all_file_paths(folder_path):
    """Return the image file paths under ``folder_path`` in a stable order.

    Only files whose lower-cased name ends with a known image extension
    are kept. Directories and file names are visited in sorted order,
    because the position of a path in the returned list is used as the
    item id in the Annoy index and is looked up again when results are
    displayed. An index built from an unsorted listing should be rebuilt.

    Args:
        folder_path: Root directory to walk recursively.

    Returns:
        A list of file paths, each joined from the walked directory and
        the file name.
    """
    image_suffixes = (".png", ".jpg", ".jpeg", ".tiff", ".bmp", ".gif")
    file_paths = []
    for root, dirs, files in os.walk(folder_path):
        # os.walk makes no ordering promise; sorting `dirs` in place fixes
        # the traversal order of the subdirectories as well.
        dirs.sort()
        file_paths.extend(
            os.path.join(root, name)
            for name in sorted(files)
            if name.lower().endswith(image_suffixes)
        )
    return file_paths
63
+
64
+
65
# Load all the images from file paths
@st.cache_data
def load_images(file_paths):
    """Open every path and return the images resized to 224x224.

    Loading is best effort: a file that cannot be opened or resized is
    reported on stdout and skipped.

    NOTE(review): a skipped file makes the returned list shorter than
    ``file_paths``, so positions after it no longer line up with the path
    list that ``display_image`` indexes by Annoy item id.

    Args:
        file_paths: Iterable of image file paths.

    Returns:
        A list of resized PIL images, in input order, without the
        failed files.
    """
    print("Loading images: ")
    images = []
    for path in tqdm(file_paths):
        try:
            # resize() returns a new image, so the source file can be closed
            # right away instead of staying open until garbage collection.
            with Image.open(path) as source:
                images.append(source.resize([224, 224]))
        except Exception as e:
            # Exception (not BaseException) keeps Ctrl-C and SystemExit working.
            print("error loading ", path, e)
    return images
77
+
78
+
79
# Function to preprocess images
def preprocess_image(image):
    """Convert a PIL image into a normalized tensor for feature extraction.

    Any image whose mode is not ``"RGB"`` is converted to RGB first. The
    image is then center-cropped to 224x224, turned into a tensor and
    normalized with the mean/std values below.

    Args:
        image: A PIL image.

    Returns:
        The result of the torchvision transform pipeline applied to the
        RGB image.
    """
    # Grayscale, palette, alpha and every other non-RGB mode all need the
    # same conversion, so one check covers them.
    if image.mode != "RGB":
        image = image.convert("RGB")
    preprocess = torchvision.transforms.Compose(
        [
            # No Resize step: load_images already resizes dataset images to
            # 224x224, so CenterCrop only enforces the final size.
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    return preprocess(image)
101
+
102
+
103
# Turn a list of images into a list of feature vectors
def extract_features(images, model):
    """Run each image through ``model`` and return the outputs as NumPy arrays.

    Every image is preprocessed, given a leading batch axis for the forward
    pass, and the batch axis is removed again from the model output.
    """
    print("Extracting features:")
    # Gradients are never needed here, so the whole pass runs under no_grad.
    with torch.no_grad():
        return [
            model(preprocess_image(img).unsqueeze(0)).squeeze(0).numpy()
            for img in images
        ]
113
+
114
+
115
# Build an Annoy index for efficient similarity search
def build_annoy_index(features):
    """Build an Annoy index (angular distance) over ``features``.

    Item ids are the positions of the vectors in ``features``.

    Args:
        features: Non-empty sequence of 1-D feature vectors; the length of
            the first vector sets the index dimensionality.

    Returns:
        The built ``AnnoyIndex``.

    Raises:
        ValueError: If ``features`` is empty.
    """
    print("Building annoy index:")
    if len(features) == 0:
        # Without at least one vector the dimensionality is unknown.
        raise ValueError("Cannot build an Annoy index: no features were provided.")
    dim = features[0].shape[0]  # Feature dimensionality
    index = AnnoyIndex(dim, "angular")  # Use angular distance for image features
    # Wrapping the sequence (not the enumerate object) lets tqdm show a total.
    for i, feature in enumerate(tqdm(features)):
        index.add_item(i, feature)
    index.build(NUM_TREES)  # Adjust NUM_TREES for accuracy vs. speed trade-off
    return index
125
+
126
+
127
# Perform reverse image search
def search_similar_images(uploaded_file, f=FEATURES, num_results=5):
    """Find the indexed images most similar to an uploaded query image.

    Args:
        uploaded_file: File path or file-like object accepted by
            ``Image.open``.
        f: Vector length of the saved Annoy index.
        num_results: Number of neighbours to return.

    Returns:
        A tuple ``(query_image, nearest_neighbors, distances)``:
        ``query_image`` is the opened, unresized PIL image; the other two
        are the item ids and distances returned by Annoy.
    """
    index = AnnoyIndex(f, "angular")
    index.load(f"{slugify(FOLDER)}.tree")
    query_image = Image.open(uploaded_file)
    model = load_model()
    # Indexed images are resized to 224x224 when loaded; resize the query the
    # same way so CenterCrop does not see just a central patch of a large
    # upload. The original image is kept for display by the caller.
    model_input = preprocess_image(query_image.resize([224, 224])).unsqueeze(0)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        query_feature = model(model_input).squeeze(0).numpy()
    nearest_neighbors, distances = index.get_nns_by_vector(
        query_feature, num_results, include_distances=True
    )
    return query_image, nearest_neighbors, distances
142
+
143
+
144
@st.cache_data
def save_embedding(folder=FOLDER):
    """Build and save the Annoy index for ``folder`` unless it already exists.

    The index file is named ``"<slug of folder>.tree"`` in the current
    working directory. If that file is present nothing is recomputed.

    Note that ``search_similar_images`` always loads the index named after
    ``FOLDER``, so only the default folder's index is actually searched.

    Args:
        folder: Directory whose images are embedded and indexed.
    """
    # Name the index after the folder that is actually being indexed.
    index_path = f"{slugify(folder)}.tree"
    if os.path.isfile(index_path):
        return
    model = load_model()  # Load the model once
    file_paths = get_all_file_paths(folder_path=folder)
    images = load_images(file_paths)
    features = extract_features(images, model)
    index = build_annoy_index(features)
    index.save(index_path)
154
+
155
+
156
def display_image(idx, dist):
    """Render one search hit: a 256x256 thumbnail plus its score line.

    ``idx`` is used as a position in the list returned by
    ``get_all_file_paths(FOLDER)``; ``dist`` is rounded to two decimals
    for display.
    """
    match_path = get_all_file_paths(folder_path=FOLDER)[idx]
    thumbnail = Image.open(match_path).resize([256, 256])
    st.image(thumbnail)
    st.markdown(f"SimScore: -{round(dist, 2)}")
    # Debug aid: st.markdown(match_path) shows which file was matched.
162
+
163
+
164
if __name__ == "__main__":
    # Page chrome
    st.set_page_config(layout="wide")
    st.title("Reverse Image Search App")

    try:
        # Make sure the dataset and its Annoy index exist before searching.
        load_dataset()
        save_embedding(FOLDER)

        # Query inputs
        uploaded_file = st.file_uploader(
            "Choose an image like a car, cat, dog, flower, fruits, bike, aeroplane, person"
        )
        n_matches = st.slider(
            "Num of matches to be displayed", min_value=3, max_value=100, value=5
        )

        if uploaded_file is None:
            st.write("Please upload an image to start searching.")
        else:
            query_image, nearest_neighbors, distances = search_similar_images(
                uploaded_file, num_results=n_matches
            )

            st.image(query_image.resize([256, 256]), caption="Query Image", width=200)
            st.subheader("Similar Images:")

            # Results are laid out in a five-column grid, filled row by row.
            cols = st.columns([1] * 5)
            for i, (idx, dist) in enumerate(zip(nearest_neighbors, distances)):
                with cols[i % 5]:
                    display_image(idx, dist)

    except Exception as e:
        # Log the failure to the console and surface it in the page.
        traceback.print_exc()
        print(e)
        st.error("An error occurred: {}".format(e))
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ annoy
2
+ torch
3
+ torchvision
4
+ streamlit
5
+ tqdm
6
+ python-slugify
7
+ opendatasets