import streamlit as st
import faiss
import numpy as np
from PIL import Image
import json
import zipfile
import pandas as pd
import pickle
from transformers import AutoTokenizer, CLIPTextModelWithProjection
from sklearn.preprocessing import normalize, OneHotEncoder
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

st.set_page_config(page_title='ITR', page_icon="🧊", layout='centered')
st.title("LCM-Independent for Pascal Dataset")
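# Text-to-image retrieval demo: the query is embedded with CLIP, projected
# into a 32-d shared space by a small learned MLP (txtModel), and matched
# against a prebuilt FAISS index whose entries align row-for-row with the
# training labels (train_yt). The neighbours' labels vote for a class, and
# images of that class are displayed.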
# Load the train dataset
with open('clip_train.pkl', 'rb') as f:
    temp_d = pickle.load(f)
train_xv = temp_d['image'].astype(np.float64)  # image features, np.ndarray
train_xt = temp_d['text'].astype(np.float64)   # text features, np.ndarray
train_yv = temp_d['label']                     # labels
train_yt = temp_d['label']                     # labels
ids = list(temp_d['ids'])                      # image names, len == len(images)
# Load the test dataset
with open('clip_test.pkl', 'rb') as f:
    temp_d = pickle.load(f)
test_xv = temp_d['image'].astype(np.float64)
test_xt = temp_d['text'].astype(np.float64)
test_yv = temp_d['label']
test_yt = temp_d['label']
test_xt_proj = np.load("test_text_proj.npy")
# test_xv_proj = np.load("test_image_proj.npy")
# One-hot encode the labels
# (`sparse_output` replaced `sparse` in scikit-learn 1.2; use sparse=False on
# older versions)
enc = OneHotEncoder(sparse_output=False)
enc.fit(np.concatenate((train_yt, test_yt)).reshape((-1, 1)))
train_yv = enc.transform(train_yv.reshape((-1, 1))).astype(np.float64)
test_yv = enc.transform(test_yv.reshape((-1, 1))).astype(np.float64)
train_yt = enc.transform(train_yt.reshape((-1, 1))).astype(np.float64)
test_yt = enc.transform(test_yt.reshape((-1, 1))).astype(np.float64)
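# After encoding, each label row is a 20-dim one-hot indicator
# (Pascal VOC has 20 classes), so the *_yv / *_yt arrays have shape
# (n_samples, 20).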
# # Model structure (image branch, unused in this demo)
# torch.manual_seed(3074)
# class imgModel(nn.Module):
#     def __init__(self, in_features, out_features):
#         super(imgModel, self).__init__()
#         self.l1 = nn.Linear(in_features=in_features, out_features=256)
#         self.bn1 = nn.BatchNorm1d(256)
#         self.dl1 = nn.Dropout(p=0.2)
#         self.l2 = nn.Linear(in_features=256, out_features=out_features)
#
#     def forward(self, x):
#         x = self.l1(x)
#         x = torch.sigmoid(x)
#         x = self.dl1(x)
#         x = self.bn1(x)
#         x = self.l2(x)
#         x = torch.tanh(x)
#         return x
torch.manual_seed(3074)

class txtModel(nn.Module):
    """Projects concatenated CLIP text features (1024-d here) through a
    256-unit hidden layer into the shared out_features-dim retrieval space."""
    def __init__(self, in_features, out_features):
        super(txtModel, self).__init__()
        self.l1 = nn.Linear(in_features=in_features, out_features=256)
        self.bn1 = nn.BatchNorm1d(256)
        self.dl2 = nn.Dropout(p=0.2)
        self.l2 = nn.Linear(in_features=256, out_features=out_features)

    def forward(self, x):
        x = self.l1(x)
        x = torch.sigmoid(x)
        x = self.dl2(x)
        x = self.bn1(x)
        x = torch.tanh(self.l2(x))
        return x
class customDataset(Dataset):
    """Minimal Dataset wrapper around an array of feature vectors."""
    def __init__(self, any_data):
        self.any_data = any_data

    def __len__(self):
        return self.any_data.shape[0]

    def __getitem__(self, idx):
        return self.any_data[idx]
# Map the image ids to the corresponding image files in the raw-data archive
image_map_name = 'pascal_dataset.csv'
df = pd.read_csv(image_map_name)
image_list = list(df['image'])
class_list = list(df['class'])
zip_path = "pascal_raw.zip"
zip_file = zipfile.ZipFile(zip_path)
text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
text_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
d = 32  # dimensionality of the shared embedding space
# Load the prebuilt text index; read_index replaces the empty Flat
# inner-product index created on the previous line
text_index = faiss.index_factory(d, "Flat", faiss.METRIC_INNER_PRODUCT)
text_index = faiss.read_index("text_index.index")
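# The index was declared with METRIC_INNER_PRODUCT, so scores are dot
# products; together with the L2 normalisation applied to the query in
# T2Isearch below, this amounts to cosine-similarity search (assuming the
# vectors stored in text_index.index were normalised the same way).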
np.random.seed(3074)

class model:
    def __init__(self, L, dataset):
        self.txt_model_type = 'simple'
        self.L = 32
        self.device = 'cpu'
        self.batch_size = 1
        self.SIGMA = 0.01
        self.txt_model = txtModel(train_xt.shape[1], L).to(self.device)
        self.mse_criterion = nn.MSELoss(reduction='mean')
        # image_state_dict = torch.load(dir_path + '/image_checkpoint.pth')
        self.text_state_dict = torch.load('text_checkpoint.pth')
        # img_model.load_state_dict(image_state_dict)
        self.txt_model.load_state_dict(self.text_state_dict)
    def ffModelLoss(self, data, output, true_output, criterion, model_type):
        if model_type == 'simple':
            return criterion(output, true_output)
        elif model_type == 'ae_middle':
            # autoencoder variant: weighted reconstruction loss + embedding loss
            emb, reconstruction = output
            return self.SIGMA * criterion(reconstruction, data) + criterion(emb, true_output)

    def ffModelPred(self, output, model_type):
        if model_type == 'simple':
            return output.tolist()
        elif model_type == 'ae_middle':
            emb, reconstruction = output
            return emb.tolist()
    def infer(self, model, dataloader, criterion, B, modelLossFxn, model_type, predictionFxn, predictions=False, cal_loss=True):
        """Runs the model over a dataloader; optionally collects predictions
        and/or accumulates the loss against target matrix B."""
        model.eval()
        running_loss = 0.0
        preds = []
        with torch.no_grad():
            for i, data in enumerate(dataloader):
                data = data.to(self.device)
                data = data.view(data.size(0), -1)
                output = model(data)
                if predictions:
                    preds += predictionFxn(output, model_type)
                if cal_loss:
                    true_output = torch.tensor(B[i*self.batch_size:(i+1)*self.batch_size, :]).to(self.device)
                    loss = modelLossFxn(data, output, true_output, criterion, model_type)
                    running_loss += loss.item()
        inference_loss = running_loss / len(dataloader.dataset)
        if predictions:
            return inference_loss, np.array(preds)
        return inference_loss
    def T2Isearch(self, query, focussed_word, i=0, k=50):
        # Encode the query and the focussed word with CLIP (512-d each)
        inputs = text_tokenizer([query, focussed_word], padding=True, return_tensors="pt")
        outputs = text_model(**inputs)
        query_embedding = outputs.text_embeds
        query_vector = query_embedding.detach().numpy()
        # Concatenate the two 512-d projections into a single 1024-d input
        query_vector = np.concatenate((query_vector[0], query_vector[1]), dtype=np.float32)
        query_vector = query_vector.reshape(1, 1024)
        query_vector = customDataset(query_vector)
        self.test_xt_loader = DataLoader(query_vector, batch_size=1, shuffle=False)
        # Project into the learned 32-d space with the trained text encoder
        _, query_vector = self.infer(self.txt_model, self.test_xt_loader, self.mse_criterion,
                                     None, None, self.txt_model_type, self.ffModelPred, True, False)
        query_vector = query_vector.astype(np.float32)
        # Alternative: use the precomputed projection for test query i
        # query_vector = test_xt_proj[i-1].astype(np.float32)
        # query_vector = query_vector.reshape(1, 32)
        faiss.normalize_L2(query_vector)
        # nprobe only applies to IVF indexes; setting it to ntotal forces an
        # exhaustive scan if the stored index happens to be IVF
        text_index.nprobe = text_index.ntotal
        # Search for the k nearest neighbours in the FAISS text index
        D, I = text_index.search(query_vector, k)
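        # D holds the inner-product scores for the k neighbours, I their row
        # indices; these indices are used below to look up the neighbours'
        # one-hot labels in train_yt.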
        # Rank all classes w.r.t. the query: each neighbour votes for the
        # classes marked in its one-hot label row
        Y = train_yt
        neighbor_ys = Y[I[0]]
        class_freq = np.zeros(Y.shape[1])
        for neighbor_y in neighbor_ys:
            classes = np.where(neighbor_y > 0.5)[0]
            for _class in classes:
                class_freq[_class] += 1
        count = int((class_freq > 0).sum())  # classes receiving at least one vote
        ranked_classes = np.argsort(-class_freq)           # all labels, most-voted first
        ranked_classes_after_knn = ranked_classes[:count]  # labels that actually appeared among the k neighbours
        lis = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
               'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
               'train', 'tvmonitor']
        class_ = lis[ranked_classes_after_knn[0]]
        print(class_)  # log the predicted class to the console
        # Display the first five images belonging to the predicted class
        count = 0
        for i in range(len(image_list)):
            if class_list[i] == class_:
                count += 1
                image_name = image_list[i]
                image_data = zip_file.open("pascal_raw/images/dataset/" + image_name)
                image = Image.open(image_data)
                st.image(image, width=600)
            if count == 5:
                break
query = st.text_input("Enter your search query here:")
focussed_word = st.text_input("Enter your focussed word here:")
if st.button("Search"):
    LCM = model(d, "pascal")
    if query:
        # The index argument i is only needed by the commented-out
        # precomputed-projection path, so the default of 0 is used here
        # (the original call referenced an undefined variable `ind`)
        LCM.T2Isearch(query, focussed_word)
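# Launch locally with `streamlit run <this file>`; on Hugging Face Spaces the
# entry point is typically app.py.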