# Hugging Face Spaces page header captured with this file: "Spaces: Sleeping Sleeping"
import json
import pickle
import pickletools
import zipfile

import faiss
import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image
from transformers import AutoTokenizer, CLIPTextModelWithProjection

# Page configuration is issued before any other st.* call, as Streamlit requires.
# NOTE(review): the icon string "π§" looks like a mis-decoded (mojibake) emoji;
# confirm which page icon was intended.
st.set_page_config(layout="centered", page_title="ITR", page_icon="π§")
st.title("LCM-Independent for Pascal Dataset")
# Precomputed dataset arrays (.npy). They were originally extracted from
# clip_train.pkl / clip_test.pkl, whose dicts held 'image', 'text', 'label'
# and 'ids' entries; only the arrays needed at runtime are loaded here.

# Training-split label matrix. T2Isearch indexes its rows with FAISS neighbour
# ids, so its row order must match the vectors stored in text_index.index.
train_yt = np.load("train_yt.npy")

# Test-split text features.
# NOTE(review): not referenced by the visible search code; kept in case other
# code relies on the module-level name.
test_xt = np.load("test_xt.npy")
# Image metadata: T2Isearch treats the 'image' column as file names inside the
# zip archive and compares the 'class' column against Pascal class names.
image_map_name = "pascal_dataset.csv"
df = pd.read_csv(image_map_name)
image_list, class_list = (list(df[column]) for column in ("image", "class"))

# Raw-image archive, opened once at module level and read on demand by
# T2Isearch.
zip_path = "pascal_raw.zip"
zip_file = zipfile.ZipFile(zip_path)
# CLIP text encoder and tokenizer used to embed search queries.
# NOTE(review): Streamlit re-executes this script on every interaction, so these
# loads repeat each time; consider caching them if the installed Streamlit
# version offers st.cache_resource.
text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
text_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Pre-built FAISS index over the training text embeddings. faiss.read_index
# returns a new index object; it must be assigned, otherwise every search runs
# against an empty index. The metric and dimensionality come from the file.
text_index = faiss.read_index("text_index.index")

# Vector dimensionality of the stored index (was hard-coded to 1024, which does
# not have to match the file or the CLIP text projection size).
d = text_index.d
# Pascal class names; assumed to follow the column order of the train_yt label
# matrix (after any extra leading column) — TODO confirm against the dataset.
_PASCAL_CLASSES = (
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
)

# Folder inside pascal_raw.zip that holds the image files.
_IMAGE_DIR_IN_ZIP = "pascal_raw/images/dataset/"


def _embed_query(query):
    """Encode ``query`` with the CLIP text model into an L2-normalised row.

    Returns a C-contiguous float32 array of shape (1, dim), the layout that
    ``faiss.normalize_L2`` and ``Index.search`` require.
    """
    # truncation=True keeps long queries within CLIP's maximum token length.
    inputs = text_tokenizer([query], padding=True, truncation=True, return_tensors="pt")
    outputs = text_model(**inputs)
    # text_embeds carries the autograd graph; detach before converting to NumPy.
    embedding = outputs.text_embeds.detach().cpu().numpy()
    query_vector = np.ascontiguousarray(embedding, dtype=np.float32)
    faiss.normalize_L2(query_vector)  # in place
    return query_vector


def _rank_classes(labels, neighbor_ids):
    """Return the label columns present among the neighbours, most frequent first.

    labels: 2-D array, one row per indexed item; entries > 0.5 mark a class.
    neighbor_ids: 1-D array of row ids from FAISS, where -1 means "no result".
    Only columns that occur at least once are returned.
    """
    valid_ids = neighbor_ids[neighbor_ids >= 0]
    # Count, per class column, how many neighbours carry that label.
    class_freq = (labels[valid_ids] > 0.5).sum(axis=0)
    ranked = np.argsort(-class_freq, kind="stable")
    return ranked[: np.count_nonzero(class_freq)]


def T2Isearch(query, k=50):
    """Predict the class of a text query by k-NN voting and display its images.

    The query is embedded with CLIP and searched against ``text_index``. The
    labels of the ``k`` nearest training items vote for a class. Every image in
    ``image_list`` whose class matches the winner is rendered with
    ``st.image``. Problems (dimension mismatch, no labelled neighbours, unknown
    label column) are reported in the UI instead of raising.

    Args:
        query: free-text search string.
        k: number of nearest neighbours to retrieve from the index.
    """
    query_vector = _embed_query(query)
    if query_vector.shape[1] != text_index.d:
        st.error(
            f"Query embedding has dimension {query_vector.shape[1]}, "
            f"but the FAISS index expects {text_index.d}."
        )
        return

    # nprobe only exists on IVF-style indexes; a flat index has no such knob.
    if hasattr(text_index, "nprobe"):
        text_index.nprobe = 100

    _, neighbor_ids = text_index.search(query_vector, k)
    # search() returns one row of ids per query; there is exactly one query.
    ranked_classes = _rank_classes(train_yt, neighbor_ids[0])
    if ranked_classes.size == 0:
        st.warning("No labelled neighbours were found for this query.")
        return

    # A label matrix with one extra leading column (e.g. background) is shifted
    # by one; a 20-column matrix maps column 0 directly to "aeroplane".
    offset = train_yt.shape[1] - len(_PASCAL_CLASSES)
    class_idx = int(ranked_classes[0]) - offset
    if not 0 <= class_idx < len(_PASCAL_CLASSES):
        st.warning("The top-ranked label does not correspond to a Pascal class.")
        return
    class_ = _PASCAL_CLASSES[class_idx]

    for image_name, image_class in zip(image_list, class_list):
        if image_class != class_:
            continue
        with zip_file.open(_IMAGE_DIR_IN_ZIP + image_name) as image_file:
            image = Image.open(image_file)
            image.load()  # PIL reads lazily; pull pixels before the member closes
        st.image(image, width=600)
# Search UI: a search runs only when the button was pressed in this rerun and
# the text box is non-empty.
query = st.text_input("Enter your search query here:")
search_clicked = st.button("Search")
if search_clicked and query:
    T2Isearch(query)