import streamlit as st
st.set_page_config(page_title='ITR', page_icon="🧊", layout='centered')
st.title("LCM-Independent for Pascal Dataset")
import faiss
import numpy as np
from PIL import Image
import json
import pandas as pd
import zipfile
import pickle
from transformers import AutoTokenizer, CLIPTextModelWithProjection
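# Pipeline (as implemented below): load precomputed CLIP features for the Pascal
# train/test splits, search a prebuilt FAISS text index for the query's nearest
# neighbours, vote over the neighbours' labels to pick a class, and display the
# raw images of that class from the zipped dataset.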
# loading the train dataset
with open('clip_train.pkl', 'rb') as f:
    temp_d = pickle.load(f)
    # train_xv = temp_d['image'].astype(np.float64)  # Array of image features : np ndarray
    # train_xt = temp_d['text'].astype(np.float64)   # Array of text features : np ndarray
    # train_yv = temp_d['label']                     # Array of labels
    train_yt = temp_d['label']  # Array of labels
    # ids = list(temp_d['ids'])                      # image names == len(images)
# loading the test dataset
with open('clip_test.pkl', 'rb') as f:
    temp_d = pickle.load(f)
    # test_xv = temp_d['image'].astype(np.float64)
    test_xt = temp_d['text'].astype(np.float64)
    # test_yv = temp_d['label']
    # test_yt = temp_d['label']
# Map the image ids to the corresponding image URLs
image_map_name = 'pascal_dataset.csv'
df = pd.read_csv(image_map_name)
image_list = list(df['image'])
class_list = list(df['class'])
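# pascal_dataset.csv maps each image filename ('image' column) to its Pascal VOC
# class name ('class' column), used below to pick images of the predicted class.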
zip_path = "pascal_raw.zip"
zip_file = zipfile.ZipFile(zip_path)
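# pascal_raw.zip is expected to contain the raw images under
# "pascal_raw/images/dataset/<image_name>" (the path opened in T2Isearch below).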
# text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
# text_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_index = faiss.read_index("text_index.index")
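# text_index is a prebuilt FAISS index; it is assumed to hold the L2-normalised
# CLIP text embeddings of the training captions, in the same row order as train_yt
# (see the Y[I[0]] lookup inside T2Isearch).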
def T2Isearch(query, k=50):
    # Encode the text query
    # inputs = text_tokenizer([query], padding=True, return_tensors="pt")
    # outputs = text_model(**inputs)
    # query_embedding = outputs.text_embeds
    # NOTE: the CLIP text encoder above is disabled; a precomputed test-set
    # embedding is used in place of the typed query.
    query_embedding = test_xt[0]
    query_vector = np.array([query_embedding])
    faiss.normalize_L2(query_vector)
    # text_index.nprobe = index.ntotal
    text_index.nprobe = 100

    # Search for the nearest neighbors in the FAISS text index
    D, I = text_index.search(query_vector, k)

    # Rank all classes w.r.t. the query by voting over the labels of the
    # k nearest training neighbours. I has shape (1, k), so take its first
    # row to index the label matrix.
    Y = train_yt
    neighbor_ys = Y[I[0]]
    class_freq = np.zeros(Y.shape[1])
    for neighbor_y in neighbor_ys:
        classes = np.where(neighbor_y > 0.5)[0]
        for _class in classes:
            class_freq[_class] += 1

    count = 0
    for i in range(len(class_freq)):
        if class_freq[i] > 0:
            count += 1

    ranked_classes = np.argsort(-class_freq)  # chosen order of pivots -- predicted sequence of all labels for the query
    ranked_classes_after_knn = ranked_classes[:count]  # predicted sequence of top labels after knn search

    lis = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
           'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
           'train', 'tvmonitor']
    class_ = lis[ranked_classes_after_knn[0] - 1]

    # Map the image ids to the corresponding image files and display every
    # image belonging to the top-ranked class
    for i in range(len(image_list)):
        if class_list[i] == class_:
            image_name = image_list[i]
            image_data = zip_file.open("pascal_raw/images/dataset/" + image_name)
            image = Image.open(image_data)
            st.image(image, width=600)
query = st.text_input("Enter your search query here:")
if st.button("Search"):
    if query:
        T2Isearch(query)