File size: 1,739 Bytes
0d59440
cf9d9e0
 
d6c88ae
 
 
 
0d59440
d6c88ae
42f52e5
b54da20
0d59440
d6c88ae
e3bc95e
 
d6c88ae
e3bc95e
 
 
 
dec5315
 
 
 
 
 
b047033
 
dec5315
 
 
 
0d59440
dec5315
d6c88ae
489b7f2
 
 
 
d6c88ae
 
489b7f2
d6c88ae
 
 
 
e3bc95e
 
0a6c9ba
482cb2b
a848b0f
d6c88ae
 
 
 
f352d7e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import streamlit as st
# Page configuration is issued before any other Streamlit call in this script.
# NOTE(review): Streamlit documents set_page_config as having to be the first
# Streamlit command on a page -- confirm before reordering these lines.
st.set_page_config(page_title='T2I', page_icon="🧊", layout='centered')
# Page heading.
st.title("Text To Image Retrieval for KaggleX BPIOC Mentorship Program")
import torch
from transformers import AutoTokenizer, AutoModel
import faiss
import numpy as np
from PIL import Image
from sentence_transformers import SentenceTransformer
import json
import zipfile

# Load the caption table. Its keys are later used as image file names inside
# the zip archive and its values are the caption texts.
image_map_name = 'captions.json'

with open(image_map_name, 'r', encoding='utf-8') as f:
    caption_dict = json.load(f)

image_list = list(caption_dict.keys())
caption_list = list(caption_dict.values())

# The archive is kept open at module level because search() reads image
# members from it on demand.
zip_path = "Images.zip"
zip_file = zipfile.ZipFile(zip_path)

model_name = "sentence-transformers/all-distilroberta-v1"
# NOTE(review): `tokenizer` is not referenced anywhere in the visible code;
# SentenceTransformer tokenizes internally. Kept so the module-level name
# still exists -- confirm it is unused and drop it to save start-up time.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = SentenceTransformer(model_name)

# Caption embeddings are loaded from disk instead of being recomputed on every
# run (presumably produced offline via model.encode(caption_list) -- row i is
# assumed to correspond to image_list[i]; verify against the export script).
# FAISS only accepts C-contiguous float32 matrices, so coerce defensively;
# this is a no-op when the file already has that layout.
vectors = np.ascontiguousarray(
    np.load("./sbert_text_features.npy"), dtype=np.float32
)
vector_dimension = vectors.shape[1]
index = faiss.IndexFlatL2(vector_dimension)
# Normalise in place before indexing; on unit-length vectors, L2 distance
# ranks neighbours the same way cosine similarity does.
faiss.normalize_L2(vectors)
index.add(vectors)

def search(query, k=4):
    """Render the images whose captions are closest to *query*.

    The query is embedded with the module-level ``model``, L2-normalised in
    place and looked up in the module-level FAISS ``index``. Each returned row
    number selects a file name from ``image_list``; that member is read from
    ``zip_file`` under the ``Images/`` prefix and drawn with ``st.image``.

    Args:
        query: Free-text search string.
        k: Number of nearest neighbours to request from the index.

    Returns:
        None. Results are written directly to the Streamlit page.
    """
    # FAISS searches a 2-D batch, so wrap the single embedding in a batch of
    # one row before normalising it in place.
    query_vector = np.array([model.encode(query)])
    faiss.normalize_L2(query_vector)

    # The module-level index is a flat (exhaustive) index, so there is no
    # probe count to tune before searching.
    _, neighbour_rows = index.search(query_vector, k)

    for row in neighbour_rows[0]:
        # FAISS pads the result with -1 when fewer than k neighbours exist;
        # indexing image_list with -1 would silently show the last image.
        if row < 0:
            continue
        member_name = "Images/" + str(image_list[row])
        # Render inside the context manager so the zip member stream is
        # still readable while the image is decoded, then released.
        with zip_file.open(member_name) as image_data:
            image = Image.open(image_data)
            st.image(image, width=400)

# Search form: a text box plus a button. The lookup only runs on the rerun
# where the button was pressed and the text box is non-empty.
query = st.text_input("Enter your search query here:")
search_clicked = st.button("Search")
if search_clicked and query:
    search(query)