# LCMI_T2I / app.py
import zipfile
import pickle  # retained for the commented-out .pkl loaders below

import faiss
import numpy as np
import pandas as pd
from PIL import Image
from transformers import AutoTokenizer, CLIPTextModelWithProjection

import streamlit as st

st.set_page_config(page_title='ITR', page_icon="🧊", layout='centered')
st.title("LCM-Independent for Pascal Dataset")
# Loading the train dataset (commented out; only the label matrix is needed here).
# with open('clip_train.pkl', 'rb') as f:
#     temp_d = pickle.load(f)
# train_xv = temp_d['image'].astype(np.float64)  # image features: np.ndarray
# train_xt = temp_d['text'].astype(np.float64)   # text features: np.ndarray
# train_yv = temp_d['label']                     # label matrix
# train_yt = temp_d['label']                     # label matrix
# ids = list(temp_d['ids'])                      # image names, len == number of images
train_yt = np.load("train_yt.npy")
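# Assumption: train_yt is a multi-hot label matrix of shape (n_train, n_classes),
# where train_yt[i, c] == 1 iff training sample i is annotated with class c.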
# Loading the test dataset (commented out; only the text features are used here).
# with open('clip_test.pkl', 'rb') as f:
#     temp_d = pickle.load(f)
# test_xv = temp_d['image'].astype(np.float64)
# test_xt = temp_d['text'].astype(np.float64)
# test_yv = temp_d['label']
# test_yt = temp_d['label']
test_xt = np.load("test_xt.npy")
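# test_xt holds precomputed CLIP text features for the test split; it is unused in
# the live app but kept for the commented-out debugging path in T2Isearch below.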
# Load the mapping from Pascal image file names to their class labels.
image_map_name = 'pascal_dataset.csv'
df = pd.read_csv(image_map_name)
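# Assumption: the CSV provides an 'image' column (file name inside the zip archive)
# and a 'class' column (Pascal VOC class name), as indexed below.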
image_list = list(df['image'])
class_list = list(df['class'])
zip_path = "pascal_raw.zip"
zip_file = zipfile.ZipFile(zip_path)
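# Images are read lazily from the zip archive at display time rather than being
# extracted to disk first.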
text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
text_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
# Load the prebuilt FAISS text index from disk. The original code first built an
# empty inner-product index with faiss.index_factory(1024, "Flat", ...) and then
# discarded the result of read_index; only the loaded index is actually needed.
text_index = faiss.read_index("text_index.index")
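# Hedged sanity check: the loaded index dimensionality must match the projected
# CLIP text embedding size (512 for clip-vit-base-patch32) for the search to work.
# assert text_index.d == text_model.config.projection_dim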
def T2Isearch(query, k=50):
    # Encode the text query with the CLIP text encoder.
    inputs = text_tokenizer([query], padding=True, return_tensors="pt")
    outputs = text_model(**inputs)
    query_embedding = outputs.text_embeds
    query_vector = query_embedding.detach().numpy()
    # query_vector = test_xt[0:1]  # debugging alternative: a precomputed test feature
    faiss.normalize_L2(query_vector)
    # The original code set text_index.nprobe = 100, but nprobe only applies to IVF
    # indexes; a flat index is searched exhaustively, so that line is dropped.
    # Search for the k nearest neighbours in the FAISS text index.
    D, I = text_index.search(query_vector, k)
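    # D: inner-product scores, shape (1, k); I: indices of the k nearest training
    # samples, shape (1, k).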
    # Rank all classes by how often they occur among the retrieved neighbours.
    Y = train_yt
    neighbor_ys = Y[I[0]]  # label vectors of the k neighbours, shape (k, n_classes)
    class_freq = np.zeros(Y.shape[1])
    for neighbor_y in neighbor_ys:
        classes = np.where(neighbor_y > 0.5)[0]
        for _class in classes:
            class_freq[_class] += 1
    # Number of distinct classes seen among the neighbours.
    count = np.count_nonzero(class_freq)
    ranked_classes = np.argsort(-class_freq)  # all labels, most frequent first: predicted label ranking for the query
    ranked_classes_after_knn = ranked_classes[:count]  # top labels actually present in the kNN result
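    # Design note: this is a plain majority vote over the neighbours' multi-hot
    # labels; only the single most frequent class is used to select images below.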
    lis = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
           'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
           'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
    # Class indices are taken to be 0-based and aligned with the Pascal VOC order
    # above; the original `ranked_classes_after_knn[0] - 1` would map class 0 to
    # 'tvmonitor', which looks like an off-by-one.
    class_ = lis[ranked_classes_after_knn[0]]
    # Show every archived image whose annotated class matches the predicted class.
    for i in range(len(image_list)):
        if class_list[i] == class_:
            image_name = image_list[i]
            image_data = zip_file.open("pascal_raw/images/dataset/" + image_name)
            image = Image.open(image_data)
            st.image(image, width=600)
query = st.text_input("Enter your search query here:")
if st.button("Search"):
    if query:
        T2Isearch(query)
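# To run locally (assuming train_yt.npy, test_xt.npy, pascal_dataset.csv,
# pascal_raw.zip and text_index.index are in the working directory):
#   streamlit run app.py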