import argparse
import json
import os
import traceback

import faiss
import faiss.contrib.torch_utils
import numpy as np
import opendatasets as od
import streamlit as st
import torch
import torchvision
import torchvision.transforms
from azure.storage.blob import BlobServiceClient
from efficientnet_pytorch import EfficientNet
from PIL import Image, ImageFile
from slugify import slugify
from streamlit_cropper import st_cropper
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
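# --- Configuration ---
# BATCH_SIZE trades embedding throughput against memory: lower it if the
# forward pass runs out of GPU memory. FEATURES documents the 1000-dim output
# of the pretrained EfficientNet classifier head used as the embedding below.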
BATCH_SIZE = 200
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ImageFile.LOAD_TRUNCATED_IMAGES = True  # Tolerate partially corrupted image files
FOLDER = "images/"
FEATURES = 1000
FILETYPES = [".png", ".jpg", ".jpeg", ".tiff", ".bmp", ".webp"]
LIBRARIES = [
    "https://www.kaggle.com/datasets/athota1/caltech101",
    "https://www.kaggle.com/datasets/gpiosenka/sports-classification",
    "https://www.kaggle.com/datasets/puneet6060/intel-image-classification",
    "https://www.kaggle.com/datasets/kkhandekar/image-dataset",
]
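# The Kaggle datasets above are downloaded into FOLDER; the FAISS index file
# name is derived from slugify(FOLDER), so changing FOLDER invalidates any
# previously built index.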
def dl_embeddings():
    """Download prebuilt embeddings in the production environment instead of creating them."""
    if os.path.isfile(f"{slugify(FOLDER)}.index"):
        print("Embeddings file already exists, skipping download")
        return
    # Connect to the Blob Storage account
    connect_str = st.secrets["connectionstring"]
    blob_service_client = BlobServiceClient.from_connection_string(connect_str)
    # Specify container and blob names
    container_name = "imagessearch"
    blob_name = f"{slugify(FOLDER)}.index"
    # Get a reference to the blob
    blob_client = blob_service_client.get_blob_client(
        container=container_name, blob=blob_name
    )
    # Download the binary data
    download_file_path = f"{slugify(FOLDER)}.index"  # Path to save the downloaded file
    with open(download_file_path, "wb") as download_file:
        download_file.write(blob_client.download_blob().readall())
    print(f"File downloaded to: {download_file_path}")
def load_dataset():
    """Write Kaggle credentials from Streamlit secrets, then download every dataset in LIBRARIES."""
    with open("kaggle.json", "w+") as f:
        json.dump(
            {
                "username": st.secrets["username"],
                "key": st.secrets["key"],
            },
            f,
        )
    for lib in LIBRARIES:
        od.download(
            lib,
            FOLDER,
        )
# Load a pre-trained image feature extractor model
def load_model():
    """Loads a pre-trained image feature extractor model."""
    print("Loading pretrained model...")
    model = EfficientNet.from_pretrained("efficientnet-b2")
    model.eval()  # Set model to evaluation mode
    return model
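# Note: calling the model directly returns the 1000-dim ImageNet logits, which
# this app uses as the image embedding (hence FEATURES = 1000). A common
# alternative would be model.extract_features(...) plus pooling for a richer
# convolutional embedding, at the cost of a different index dimensionality.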
# Get all file paths within a folder and its subfolders
def get_all_file_paths(folder_path):
    """Returns a sorted list of all image file paths within a folder and its subfolders."""
    file_paths = []
    for root, _, files in os.walk(folder_path):
        for file in files:
            if not file.lower().endswith(tuple(FILETYPES)):
                continue
            file_path = os.path.join(root, file)
            file_paths.append(file_path)
    print(f"Found {len(file_paths)} image files")
    return sorted(file_paths)
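# The sorted order matters: FAISS ids assigned at index-build time are simply
# positions in this list, and display_image later maps ids back to paths by
# calling this function again. Adding or removing files after the index is
# built would silently shift that mapping.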
# Load all the images from file paths
def load_images(file_paths):
    """Load all the images from file paths."""
    print("Loading images: ")
    images = list()
    for path in tqdm(file_paths):
        try:
            images.append(Image.open(path).resize([224, 224]))
        except Exception as e:
            print("error loading ", path, e)
    return images
def load_image(file_path):
    """Load a single image from a file path."""
    try:
        return Image.open(file_path).resize([224, 224])
    except Exception as e:
        print("Error loading ", file_path, e)
        # Return a blank placeholder so a single unreadable file does not
        # crash the whole DataLoader batch with a None entry.
        return Image.new("RGB", (224, 224))
# Function to preprocess images
def preprocess_image(image):
    """Preprocesses an image for feature extraction."""
    if image.mode != "RGB":
        # Convert grayscale, palette, RGBA, etc. to 3-channel RGB
        image = image.convert("RGB")
    preprocess = torchvision.transforms.Compose(
        [
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )
    return preprocess(image)
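# The Normalize mean/std are the standard ImageNet statistics the pretrained
# EfficientNet expects. CenterCrop(224) is effectively a no-op here because
# load_image already resizes everything to 224x224, but it keeps the pipeline
# safe for callers that pass larger images.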
class ImageLoader(Dataset):
    """Dataset that lazily loads and preprocesses one image per index."""

    def __init__(self, image_files, transform, load_image):
        self.transform = transform
        self.load_image = load_image
        self.image_files = image_files

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        return self.transform(self.load_image(self.image_files[index]))
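# Wrapping loading in a Dataset lets DataLoader handle batching; if disk I/O
# became the bottleneck, preprocessing could run in parallel worker processes
# by passing e.g. num_workers=os.cpu_count() to the DataLoader below (an
# optional tweak, not something the app currently does).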
# Extract features from a list of images
def extract_features(file_paths, model):
    """Extracts features from a list of images."""
    print("Extracting features:")
    loader = DataLoader(
        ImageLoader(file_paths, transform=preprocess_image, load_image=load_image),
        batch_size=BATCH_SIZE,
    )
    features = []
    model = model.to(DEVICE)
    with torch.no_grad():
        for images in tqdm(loader):
            images = images.to(DEVICE)
            features.append(model(images))
    return torch.cat(features)
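# The result is a (num_images, 1000) tensor on DEVICE, one row per file in the
# same order as file_paths.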
# Build a FAISS index for efficient similarity search
def build_faiss_index(features):
    """Builds a FAISS index for efficient similarity search."""
    print("Building faiss index:")
    f = features[0].shape[0]  # Feature dimensionality
    index = faiss.IndexIDMap(faiss.IndexFlatIP(f))
    features = features.cpu().detach().numpy()
    faiss.normalize_L2(features)
    index.add_with_ids(features, np.arange(len(features)))
    print("Built faiss index")
    return index
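# Because the vectors are L2-normalized before both indexing and querying, the
# inner-product scores returned by IndexFlatIP are cosine similarities in
# [-1, 1]. IndexFlatIP does exact brute-force search; for much larger
# collections an approximate IVF index could be swapped in, sketched here with
# illustrative (untuned) parameters:
#
#   quantizer = faiss.IndexFlatIP(f)
#   index = faiss.IndexIVFFlat(quantizer, f, 100, faiss.METRIC_INNER_PRODUCT)
#   index.train(features)              # IVF indexes must be trained first
#   index.add_with_ids(features, ids)  # IVF supports ids natively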
# Perform reverse image search
def search_similar_images(query_image, num_results):
    """Finds similar images based on a query image feature."""
    index = faiss.read_index(f"{slugify(FOLDER)}.index")
    model = load_model().to(DEVICE)
    # Extract features and search
    proc_image = preprocess_image(query_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        query_feature = model(proc_image)
    query_feature = query_feature.cpu().detach().numpy()
    faiss.normalize_L2(query_feature)
    distances, nearest_neighbors = index.search(
        query_feature,
        num_results,
    )
    return query_image, nearest_neighbors[0], distances[0]
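# index.search returns results sorted best-first, so distances[0] holds the
# cosine similarity of each match (higher means more similar) and
# nearest_neighbors[0] the corresponding FAISS ids.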
def save_embedding(folder=FOLDER):
    """Embed every image under folder and persist a FAISS index, unless one already exists."""
    if os.path.isfile(f"{slugify(folder)}.index"):
        print("Skipping recreating image embeddings")
        return
    print("Performing image embeddings")
    model = load_model()  # Load the model once
    file_paths = get_all_file_paths(folder_path=folder)
    features = extract_features(file_paths, model)
    index = build_faiss_index(features)
    faiss.write_index(index, f"{slugify(folder)}.index")
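# End-to-end embedding pipeline: scan the folder -> embed every image -> build
# the FAISS index -> persist it next to the app, where dl_embeddings and
# search_similar_images expect to find it.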
def display_image(idx, dist):
    """Render a matched image and its similarity score in the current Streamlit column."""
    file_paths = get_all_file_paths(folder_path=FOLDER)
    image = Image.open(file_paths[idx])
    st.image(image.resize([256, 256]))
    st.markdown("SimScore: " + str(round(dist, 2)))
if __name__ == "__main__":
    # Main app logic
    st.set_page_config(layout="wide")
    st.title("Reverse Image Search App")
    try:
        load_dataset()
        # Download prebuilt embeddings unless running in the development environment
        ap = argparse.ArgumentParser()
        ap.add_argument("--dev", action="store_true")
        if not ap.parse_args().dev:
            dl_embeddings()
        save_embedding(FOLDER)
        # File uploader
        uploaded_file = st.file_uploader(
            "Choose an image like a car, cat, dog, flower, fruits, bike, aeroplane, person",
            type=FILETYPES,
        )
        n_matches = st.slider(
            "Num of matches to be displayed", min_value=3, max_value=100, value=5
        )
        if uploaded_file is not None:
            query_image = Image.open(uploaded_file).resize([256, 256])
            cropped = st_cropper(query_image, default_coords=[10, 240, 10, 240])
            query_image, nearest_neighbors, distances = search_similar_images(
                cropped.resize([224, 224]), n_matches
            )
            st.subheader("Similar Images:")
            cols = st.columns(5)
            for i, (idx, dist) in enumerate(zip(nearest_neighbors, distances)):
                with cols[i % 5]:
                    # Display results
                    display_image(idx, dist)
        else:
            st.write("Please upload an image to start searching.")
    except Exception as e:
        traceback.print_exc()
        st.error("An error occurred: {}".format(e))
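# To run locally (assuming this file is saved as app.py), pass script flags
# after the "--" separator so Streamlit does not consume them:
#
#   streamlit run app.py -- --dev
#
# With --dev the app builds the index from the downloaded Kaggle images
# instead of fetching a prebuilt index from Azure Blob Storage.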