File size: 1,739 Bytes
0d59440
cf9d9e0
 
d6c88ae
 
 
 
0d59440
d6c88ae
42f52e5
b54da20
0d59440
d6c88ae
e3bc95e
 
d6c88ae
e3bc95e
 
 
 
dec5315
 
 
 
 
 
b047033
 
dec5315
7318e38
dec5315
 
0d59440
dec5315
d6c88ae
489b7f2
 
 
 
d6c88ae
 
489b7f2
d6c88ae
 
 
 
e3bc95e
 
0a6c9ba
482cb2b
82496ef
d6c88ae
 
 
 
f352d7e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import streamlit as st
# Streamlit expects set_page_config to be issued before any other st.* command,
# so the page settings are applied right after importing streamlit.
st.set_page_config(
    layout='centered',
    page_icon="🧊",
    page_title='T2I',
)
APP_HEADING = "Text To Image Retrieval for KaggleX BPIOC Mentorship Program"
st.title(APP_HEADING)
import torch
from transformers import AutoTokenizer, AutoModel
import faiss
import numpy as np
from PIL import Image
from sentence_transformers import SentenceTransformer
import json
import zipfile

# Load the caption metadata, the image archive, the query encoder and the
# FAISS index used by search().
#
# Streamlit re-executes this whole script on every widget interaction, so the
# expensive loading is wrapped in st.cache_resource when the installed
# Streamlit provides it (the cached objects are shared across reruns and
# sessions; restart the app or clear the cache if the files change on disk).
# On older Streamlit versions the loader simply runs on every rerun, exactly
# as before.
image_map_name = 'captions.json'
zip_path = "Images.zip"
vectors_path = "./sbert_text_features.npy"
model_name = "sentence-transformers/all-distilroberta-v1"

_cache_resource = getattr(st, "cache_resource", None) or (lambda func: func)


@_cache_resource
def _load_resources(map_path, archive_path, embeddings_path, encoder_name):
    """Load everything the search needs.

    Args:
        map_path: JSON file mapping image ids to captions.
        archive_path: Zip archive holding the images under ``Images/``.
        embeddings_path: ``.npy`` file with one caption embedding per row.
        encoder_name: SentenceTransformer model used to encode queries.

    Returns:
        ``(caption_dict, zip_file, model, vectors, index)`` where ``vectors``
        are the L2-normalised float32 embeddings and ``index`` is a FAISS
        inner-product index built over them.
    """
    # Explicit UTF-8 so captions decode identically on every platform.
    with open(map_path, 'r', encoding='utf-8') as f:
        captions = json.load(f)

    # Kept open for the lifetime of the app; search() reads members from it.
    archive = zipfile.ZipFile(archive_path)
    encoder = SentenceTransformer(encoder_name)

    # FAISS only accepts C-contiguous float32 arrays; this is a no-op when the
    # file already stores float32.
    embeddings = np.ascontiguousarray(np.load(embeddings_path), dtype=np.float32)
    # Unit-length vectors make inner-product search equal to cosine similarity.
    faiss.normalize_L2(embeddings)
    flat_index = faiss.IndexFlatIP(embeddings.shape[1])
    flat_index.add(embeddings)
    return captions, archive, encoder, embeddings, flat_index


caption_dict, zip_file, model, vectors, index = _load_resources(
    image_map_name, zip_path, vectors_path, model_name
)
# Index row i is mapped back to image_list[i]; this presumes the embeddings
# were produced as model.encode(caption_list) in captions.json order — TODO
# confirm.
image_list = list(caption_dict.keys())
caption_list = list(caption_dict.values())
vector_dimension = vectors.shape[1]
# To regenerate the embeddings file: np.save(vectors_path, model.encode(caption_list))

def search(query, k=4):
    """Encode *query*, find the *k* most similar captions and show their images.

    Each hit's image is read from the module-level zip archive as
    ``Images/<image id>`` and rendered with ``st.image``.

    Args:
        query: Free-text search string.
        k: Maximum number of images to display.

    Returns:
        List of the image ids that were displayed, best match first.
    """
    # Batch of one, float32, so FAISS receives the 2-D array it requires.
    query_vector = np.array([model.encode(query)], dtype=np.float32)
    # Same normalisation as the indexed vectors -> inner product == cosine.
    faiss.normalize_L2(query_vector)

    _scores, ids = index.search(query_vector, k)

    shown = []
    for idx in ids[0]:
        # FAISS pads with -1 when the index holds fewer than k vectors;
        # without this check image_list[-1] would display the last image.
        if idx < 0:
            continue
        image_id = str(image_list[idx])
        member = "Images/" + image_id
        try:
            handle = zip_file.open(member)
        except KeyError:
            st.warning(f"Image not found in archive: {member}")
            continue
        with handle:
            image = Image.open(handle)
            # PIL decodes lazily; force the read before the member is closed.
            image.load()
        st.image(image, width=600)
        shown.append(image_id)
    return shown

# Search form: a text box followed by a button. A search runs only on the
# rerun in which the button was clicked, and only for a non-empty query.
query = st.text_input("Enter your search query here:")
search_clicked = st.button("Search")
if search_clicked and query:
    search(query)