import streamlit as st

# set_page_config must be the first Streamlit command executed by the script.
st.set_page_config(page_title='ITR', page_icon="🧊", layout='centered')
st.title("LCM-Independent for Pascal Dataset")

import faiss
import numpy as np
from PIL import Image
import json
import zipfile
import pandas as pd
import pickle
from transformers import AutoTokenizer, CLIPTextModelWithProjection
from sklearn.preprocessing import normalize, OneHotEncoder
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset


def _load_pickle(path):
    """Return the object pickled in the file at *path*, closing the file afterwards.

    NOTE(review): pickle.load can execute arbitrary code; only point this at
    trusted, locally produced files.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _make_dense_one_hot_encoder():
    """Return a OneHotEncoder whose transform() yields dense numpy arrays.

    scikit-learn 1.2 renamed ``sparse`` to ``sparse_output`` and 1.4 removed
    ``sparse`` entirely, so try the new keyword first and fall back to the old
    one on releases that do not recognise it.
    """
    try:
        return OneHotEncoder(sparse_output=False)
    except TypeError:
        return OneHotEncoder(sparse=False)


# Training split: CLIP image/text features and labels.
temp_d = _load_pickle('clip_train.pkl')
train_xv = temp_d['image'].astype(np.float64)  # image features (np.ndarray)
train_xt = temp_d['text'].astype(np.float64)  # text features (np.ndarray)
train_yv = temp_d['label']  # labels for the image side
train_yt = temp_d['label']  # labels for the text side (same source array)
ids = list(temp_d['ids'])  # image names, one per training item

# Test split.
temp_d = _load_pickle('clip_test.pkl')
test_xv = temp_d['image'].astype(np.float64)
test_xt = temp_d['text'].astype(np.float64)
test_yv = temp_d['label']
test_yt = temp_d['label']

test_xt_proj = np.load("test_text_proj.npy")

# One-hot encode the labels. The encoder is fitted on train + test so both
# splits share the same column order.
enc = _make_dense_one_hot_encoder()
enc.fit(np.concatenate((train_yt, test_yt)).reshape((-1, 1)))
train_yv = enc.transform(train_yv.reshape((-1, 1))).astype(np.float64)
test_yv = enc.transform(test_yv.reshape((-1, 1))).astype(np.float64)
train_yt = enc.transform(train_yt.reshape((-1, 1))).astype(np.float64)
test_yt = enc.transform(test_yt.reshape((-1, 1))).astype(np.float64)
torch.manual_seed(3074)


class txtModel(nn.Module):
    """Project a text feature vector into an ``out_features``-dim space.

    Architecture: Linear(in, 256) -> sigmoid -> Dropout(0.2) -> BatchNorm1d
    -> Linear(256, out) -> tanh. Attribute names (l1, bn1, dl2, l2) must stay
    as they are because they are the keys of the saved state_dict.
    """

    def __init__(self, in_features, out_features):
        super(txtModel, self).__init__()
        self.l1 = nn.Linear(in_features=in_features, out_features=256)
        self.bn1 = nn.BatchNorm1d(256)
        self.dl2 = nn.Dropout(p=0.2)
        self.l2 = nn.Linear(in_features=256, out_features=out_features)

    def forward(self, x):
        x = self.l1(x)
        x = torch.sigmoid(x)
        x = self.dl2(x)
        x = self.bn1(x)
        x = torch.tanh(self.l2(x))
        return x


class customDataset(Dataset):
    """Minimal Dataset over an indexable array: item ``idx`` is ``any_data[idx]``."""

    def __init__(self, any_data):
        self.any_data = any_data

    def __len__(self):
        return self.any_data.shape[0]

    def __getitem__(self, idx):
        return self.any_data[idx]


# Image catalogue: each CSV row maps an image file name to its class name.
image_map_name = 'pascal_dataset.csv'
df = pd.read_csv(image_map_name)
image_list = list(df['image'])
class_list = list(df['class'])

zip_path = "pascal_raw.zip"
zip_file = zipfile.ZipFile(zip_path)

# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so these loads repeat on each rerun; consider Streamlit's resource caching.
text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
text_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Dimension of the learned projection space that the FAISS index stores.
d = 32
# The index is built offline and read from disk; there is no need to create an
# empty one first.
text_index = faiss.read_index("text_index.index")

np.random.seed(3074)

# Class names by one-hot column (alphabetical). Assumed to match the
# OneHotEncoder's category order -- TODO confirm against the label encoding.
PASCAL_CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
                  'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
                  'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
                  'train', 'tvmonitor']


class model:
    """Text-to-image retrieval built on the trained text projection model.

    Args:
        L: output dimension of the text projection model.
        dataset: dataset tag; currently unused, kept for interface compatibility.
    """

    def __init__(self, L, dataset):
        self.txt_model_type = 'simple'
        self.L = L
        self.device = 'cpu'
        self.batch_size = 1
        self.SIGMA = 0.01
        self.txt_model = txtModel(train_xt.shape[1], L).to(self.device)
        self.mse_criterion = nn.MSELoss(reduction='mean')
        # map_location lets a checkpoint saved on a GPU load on this CPU device.
        self.text_state_dict = torch.load('text_checkpoint.pth', map_location=self.device)
        self.txt_model.load_state_dict(self.text_state_dict)

    def ffModelLoss(self, data, output, true_output, criterion, model_type):
        """Return the training/eval loss for one batch.

        'simple': criterion(output, true_output).
        'ae_middle': output is (emb, reconstruction); the loss adds a
        SIGMA-weighted reconstruction term to the embedding loss.

        Raises:
            ValueError: for an unknown model_type.
        """
        if model_type == 'simple':
            return criterion(output, true_output)
        if model_type == 'ae_middle':
            emb, reconstruction = output
            return self.SIGMA * criterion(reconstruction, data) + criterion(emb, true_output)
        raise ValueError(f"unknown model_type: {model_type!r}")

    def ffModelPred(self, output, model_type):
        """Return the batch predictions as a nested Python list.

        Raises:
            ValueError: for an unknown model_type.
        """
        if model_type == 'simple':
            return output.tolist()
        if model_type == 'ae_middle':
            emb, _reconstruction = output
            return emb.tolist()
        raise ValueError(f"unknown model_type: {model_type!r}")

    def infer(self, model, dataloader, criterion, B, modelLossFxn, model_type,
              predictionFxn, predictions=False, cal_loss=True):
        """Run ``model`` over ``dataloader`` in eval mode without gradients.

        Args:
            B: target rows aligned with the dataset order (only read if cal_loss).
            predictions: also collect predictions via ``predictionFxn``.
            cal_loss: accumulate loss via ``modelLossFxn``.

        Returns:
            ``running_loss / len(dataset)``, plus a numpy array of predictions
            when ``predictions`` is true.
        """
        model.eval()
        running_loss = 0.0
        preds = []
        # Slice B with the loader's real batch size so targets stay aligned.
        step = dataloader.batch_size or self.batch_size
        with torch.no_grad():
            for i, data in enumerate(dataloader):
                data = data.to(self.device)
                data = data.view(data.size(0), -1)
                output = model(data)
                if predictions:
                    preds += predictionFxn(output, model_type)
                if cal_loss:
                    start = i * step
                    true_output = torch.tensor(B[start:start + step, :]).to(self.device)
                    loss = modelLossFxn(data, output, true_output, criterion, model_type)
                    running_loss += loss.item()
        inference_loss = running_loss / len(dataloader.dataset)
        if predictions:
            return inference_loss, np.array(preds)
        return inference_loss

    def _encode_query(self, query, focussed_word):
        """Return the projected, L2-normalised (1, L) float32 query vector.

        The query and the focussed word are embedded with CLIP separately and
        concatenated, then passed through the text projection model.
        """
        inputs = text_tokenizer([query, focussed_word], padding=True, return_tensors="pt")
        with torch.no_grad():
            pair = text_model(**inputs).text_embeds.detach().numpy()
        query_vector = np.concatenate((pair[0], pair[1])).astype(np.float32).reshape(1, -1)
        self.test_xt_loader = DataLoader(customDataset(query_vector),
                                         batch_size=self.batch_size, shuffle=False)
        _, projected = self.infer(self.txt_model, self.test_xt_loader, self.mse_criterion,
                                  None, None, self.txt_model_type, self.ffModelPred,
                                  predictions=True, cal_loss=False)
        projected = np.ascontiguousarray(projected, dtype=np.float32)
        faiss.normalize_L2(projected)
        return projected

    def _rank_classes(self, query_vector, k):
        """Return the class indices seen among the k nearest neighbours, most frequent first."""
        # Probe every list. This presumably makes an IVF index search
        # exhaustively -- TODO confirm the index type.
        text_index.nprobe = text_index.ntotal
        _, I = text_index.search(query_vector, k)
        neighbor_ids = I[0]
        # FAISS pads with -1 when it has fewer than k results; -1 would
        # silently index the last training row.
        neighbor_ids = neighbor_ids[neighbor_ids >= 0]
        # Assumes index row i corresponds to training item i -- TODO confirm.
        neighbor_ys = train_yt[neighbor_ids]
        class_freq = (neighbor_ys > 0.5).sum(axis=0)
        # A stable sort breaks ties deterministically in favour of the lower class index.
        ranked_classes = np.argsort(-class_freq, kind='stable')
        return ranked_classes[:np.count_nonzero(class_freq)]

    def _show_class_images(self, class_name, num_images):
        """Render up to ``num_images`` images labelled ``class_name`` from the zip archive."""
        shown = 0
        for image_name, image_class in zip(image_list, class_list):
            if shown >= num_images:
                break
            if image_class != class_name:
                continue
            with zip_file.open("pascal_raw/images/dataset/" + image_name) as fh:
                image = Image.open(fh)
                image.load()  # read the pixels before the archive member is closed
            st.image(image, width=600)
            shown += 1

    def T2Isearch(self, query, focussed_word, k=50, num_images=5):
        """Predict the query's class via k-NN in the text index and display matching images.

        Args:
            query: free-text search query.
            focussed_word: extra word embedded alongside the query.
            k: number of nearest neighbours to retrieve.
            num_images: maximum number of images to display.
        """
        query_vector = self._encode_query(query, focussed_word)
        ranked_classes_after_knn = self._rank_classes(query_vector, k)
        if ranked_classes_after_knn.size == 0:
            st.warning("No matching class found for this query.")
            return
        class_ = PASCAL_CLASSES[ranked_classes_after_knn[0]]
        print(class_)
        self._show_class_images(class_, num_images)


query = st.text_input("Enter your search query here:")
Focussed_word = st.text_input("Enter your focussed word here:")
# The button is always rendered. The model (and its checkpoint) is only built
# when there is a query to run.
if st.button("Search") and query:
    LCM = model(d, "pascal")
    LCM.T2Isearch(query, Focussed_word)