"""Streamlit app: free-text search over the images in ``Images.zip``.

A query is embedded with a SentenceTransformer model and compared against
precomputed text embeddings loaded from ``sbert_text_features.npy`` using a
flat FAISS index. Row ``i`` of that array is looked up as key ``i`` of
``captions.json``, and the image with that name is read from ``Images.zip``
and displayed.
"""
import io
import json
import zipfile
from io import BytesIO  # noqa: F401  (kept from the original file)
from zipfile import ZipFile  # noqa: F401  (kept from the original file)

import faiss
import numpy as np
import streamlit as st
import torch  # noqa: F401  (kept from the original file)
import wget  # noqa: F401  (kept from the original file)
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer  # noqa: F401  (kept from the original file)

# Must be the first Streamlit command the script issues.
st.set_page_config(page_title='Image Search App', layout='wide')

# from huggingface_hub import hf_hub_download
# hf_hub_download(repo_id="shivangibithel/Flickr8k", filename="Images.zip")

model_name = "sentence-transformers/all-distilroberta-v1"
zip_path = "Images.zip"
features_path = "./sbert_text_features.npy"
image_map_name = 'captions.json'


def _cache_resource(func):
    """Cache *func*'s return value across Streamlit reruns.

    Streamlit re-executes the whole script on every interaction, so without
    caching the model, index and zip handle would be rebuilt each time.
    Uses ``st.cache_resource`` when available and falls back to the legacy
    ``st.cache(allow_output_mutation=True)`` on older Streamlit releases.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_resource
def load_model(name):
    """Return the SentenceTransformer used to embed search queries."""
    return SentenceTransformer(name)


@_cache_resource
def load_index(path):
    """Build a flat L2 FAISS index over the L2-normalised vectors in *path*.

    With unit-length vectors, ranking by L2 distance matches ranking by
    cosine similarity.
    """
    # normalize_L2 works in place and requires contiguous float32 data.
    vectors = np.ascontiguousarray(np.load(path), dtype=np.float32)
    faiss.normalize_L2(vectors)
    flat_index = faiss.IndexFlatL2(vectors.shape[1])
    flat_index.add(vectors)
    return flat_index


@_cache_resource
def load_zip(path):
    """Open the image archive once and keep the handle for later reads."""
    return zipfile.ZipFile(path)


@_cache_resource
def load_captions(path):
    """Load the JSON mapping whose keys are image names inside the archive."""
    with open(path, 'r') as f:
        return json.load(f)


model = load_model(model_name)
index = load_index(features_path)
zip_file = load_zip(zip_path)
caption_dict = load_captions(image_map_name)
# Position i in these lists is assumed to correspond to row i of the feature
# array (relies on captions.json preserving the embedding order) - TODO confirm.
image_list = list(caption_dict.keys())
caption_list = list(caption_dict.values())


def search(query, k=5):
    """Display the (up to) *k* images whose embeddings best match *query*."""
    query_vector = np.asarray([model.encode(query)], dtype=np.float32)
    faiss.normalize_L2(query_vector)

    # A flat index is exhaustive; no nprobe-style tuning applies.
    _distances, ids = index.search(query_vector, k)

    for i in ids[0]:
        # FAISS pads with -1 when the index holds fewer than k vectors;
        # image_list[-1] would silently show the wrong image.
        if i < 0:
            continue
        image_id = str(image_list[i])
        image = Image.open(io.BytesIO(zip_file.read(image_id)))
        st.image(image, caption=image_id, width=200)


st.title("Image Search App")
query = st.text_input("Enter your search query here:")
if st.button("Search"):
    if query.strip():
        search(query)
    else:
        st.warning("Please enter a search query.")

# Rendering the entire archive is expensive, so it is opt-in rather than
# running on every rerun as before.
if st.sidebar.checkbox("Browse all images", value=False):
    for image_name in zip_file.namelist():
        if image_name.endswith("/"):  # directory entry, not an image
            continue
        image = Image.open(io.BytesIO(zip_file.read(image_name)))
        st.image(image, caption=image_name)