# Streamlit Space: LCM-Independent text-to-image retrieval demo for the Pascal dataset.
import streamlit as st
# NOTE: st.set_page_config must be the first Streamlit command executed in the
# script, which is why page setup happens before the remaining imports.
st.set_page_config(page_title='ITR', page_icon="🧊", layout='centered')
st.title("LCM-Independent for Pascal Dataset")
import faiss
import numpy as np
from PIL import Image
import json
import zipfile
import pandas as pd
import pickle
from transformers import AutoTokenizer, CLIPTextModelWithProjection
from sklearn.preprocessing import normalize, OneHotEncoder
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
# ---- Load precomputed CLIP features -----------------------------------------
# Train split: image/text feature matrices plus labels and image ids.
with open('clip_train.pkl', 'rb') as f:
    temp_d = pickle.load(f)
train_xv = temp_d['image'].astype(np.float64)  # image features : np ndarray
train_xt = temp_d['text'].astype(np.float64)   # text features : np ndarray
train_yv = temp_d['label']                     # labels for the image view
train_yt = temp_d['label']                     # labels for the text view
ids = list(temp_d['ids'])                      # image names == len(images)
# Test split.
with open('clip_test.pkl', 'rb') as f:
    temp_d = pickle.load(f)
test_xv = temp_d['image'].astype(np.float64)
test_xt = temp_d['text'].astype(np.float64)
test_yv = temp_d['label']
test_yt = temp_d['label']
# Precomputed projections of the test text features into the common space.
test_xt_proj = np.load("test_text_proj.npy")
# test_xv_proj = np.load("test_image_proj.npy")
# ---- One-hot encode the labels ----------------------------------------------
# The `sparse` keyword was renamed to `sparse_output` in scikit-learn 1.2 and
# removed in 1.4; try the new name first and fall back for older versions.
try:
    enc = OneHotEncoder(sparse_output=False)
except TypeError:
    enc = OneHotEncoder(sparse=False)
# Fit on the union of train+test labels so both splits share one class order.
enc.fit(np.concatenate((train_yt, test_yt)).reshape((-1, 1)))
train_yv = enc.transform(train_yv.reshape((-1, 1))).astype(np.float64)
test_yv = enc.transform(test_yv.reshape((-1, 1))).astype(np.float64)
train_yt = enc.transform(train_yt.reshape((-1, 1))).astype(np.float64)
test_yt = enc.transform(test_yt.reshape((-1, 1))).astype(np.float64)
# # Model structure
# torch.manual_seed(3074)
# class imgModel(nn.Module):
# def __init__(self, in_features, out_features):
# super(imgModel, self).__init__()
# self.l1 = nn.Linear(in_features=in_features, out_features=256)
# self.bn1 = nn.BatchNorm1d(256)
# self.dl1 = nn.Dropout(p=0.2)
# self.l2 = nn.Linear(in_features=256, out_features=out_features)
# def forward(self, x):
# x = self.l1(x)
# x = torch.sigmoid(x)
# x = self.dl1(x)
# x = self.bn1(x)
# x = self.l2(x)
# x = torch.tanh(x)
# return x
torch.manual_seed(3074)


class txtModel(nn.Module):
    """Two-layer projection head mapping CLIP text features into the
    learned common embedding space (sigmoid hidden layer, tanh output).
    """

    def __init__(self, in_features, out_features):
        super(txtModel, self).__init__()
        # NOTE: attribute names (l1, bn1, dl2, l2) and creation order must stay
        # fixed — they are the keys of the serialized checkpoint loaded into
        # this module, and creation order determines RNG consumption.
        self.l1 = nn.Linear(in_features=in_features, out_features=256)
        self.bn1 = nn.BatchNorm1d(256)
        self.dl2 = nn.Dropout(p=0.2)
        self.l2 = nn.Linear(in_features=256, out_features=out_features)

    def forward(self, x):
        # linear -> sigmoid -> dropout -> batch-norm -> linear -> tanh
        hidden = torch.sigmoid(self.l1(x))
        hidden = self.bn1(self.dl2(hidden))
        return torch.tanh(self.l2(hidden))
class customDataset(Dataset):
    """Bare-bones map-style Dataset over any object exposing ``shape`` and
    integer indexing (e.g. a numpy array)."""

    def __init__(self, any_data):
        self._data = any_data

    def __len__(self):
        # First dimension is the number of samples.
        return self._data.shape[0]

    def __getitem__(self, idx):
        return self._data[idx]
# ---- Load dataset metadata, raw images, CLIP encoder and FAISS index --------
# Map the image ids to the corresponding image URLs.
image_map_name = 'pascal_dataset.csv'
df = pd.read_csv(image_map_name)
image_list = list(df['image'])
class_list = list(df['class'])
# Raw Pascal images are read on demand out of this archive.
zip_path = "pascal_raw.zip"
zip_file = zipfile.ZipFile(zip_path)
# Frozen CLIP text encoder + tokenizer used to embed incoming queries.
text_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
text_tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
d = 32  # dimensionality of the learned common embedding space
# Load the prebuilt FAISS index of projected text features. (The previous
# `faiss.index_factory(...)` call here was dead code — its result was
# immediately overwritten by read_index — so it has been removed.)
text_index = faiss.read_index("text_index.index")
np.random.seed(3074)
class model:
    """Text-to-image retrieval engine.

    Projects a CLIP text query into the learned common space with the trained
    ``txtModel`` head, ranks classes by k-NN votes in the prebuilt FAISS text
    index, then renders images of the top-ranked class via Streamlit.
    """

    def __init__(self, L, dataset):
        # L: dimensionality of the common embedding space.
        # dataset: dataset identifier (currently unused; kept for API parity).
        self.txt_model_type = 'simple'
        # BUG FIX: self.L was hard-coded to 32, silently ignoring the L
        # argument; the only call site passes 32, so this is compatible.
        self.L = L
        self.device = 'cpu'
        self.batch_size = 1
        self.SIGMA = 0.01  # weight of the reconstruction term for 'ae_middle'
        self.txt_model = txtModel(train_xt.shape[1], L).to(self.device)
        self.mse_criterion = nn.MSELoss(reduction='mean')
        # image_state_dict = torch.load(dir_path +'/image_checkpoint.pth')
        self.text_state_dict = torch.load('text_checkpoint.pth')
        # img_model.load_state_dict(image_state_dict)
        self.txt_model.load_state_dict(self.text_state_dict)

    def ffModelLoss(self, data, output, true_output, criterion, model_type):
        """Loss for a feed-forward projection model.

        'simple'    -> criterion(output, true_output)
        'ae_middle' -> embedding loss + SIGMA-weighted reconstruction loss
        """
        if model_type == 'simple':
            return criterion(output, true_output)
        elif model_type == 'ae_middle':
            emb, reconstruction = output
            return self.SIGMA*criterion(reconstruction, data) + criterion(emb, true_output)

    def ffModelPred(self, output, model_type):
        """Convert raw model output into a plain list of predictions."""
        if model_type == 'simple':
            return output.tolist()
        elif model_type == 'ae_middle':
            emb, reconstruction = output
            return emb.tolist()

    def infer(self, model, dataloader, criterion, B, modelLossFxn, model_type, predictionFxn, predictions=False, cal_loss=True):
        """Run `model` over `dataloader` in eval mode with gradients disabled.

        B is the matrix of targets, sliced per batch when cal_loss is True.
        Returns (loss, predictions ndarray) when predictions=True, else loss.
        """
        model.eval()
        running_loss = 0.0
        preds = []
        with torch.no_grad():
            for i, data in enumerate(dataloader):
                data = data.to(self.device)
                data = data.view(data.size(0), -1)
                output = model(data)
                if predictions: preds += predictionFxn(output, model_type)
                if cal_loss:
                    true_output = torch.tensor(B[i*self.batch_size:(i+1)*self.batch_size, :]).to(self.device)
                    loss = modelLossFxn(data, output, true_output, criterion, model_type)
                    running_loss += loss.item()
        inference_loss = running_loss/len(dataloader.dataset)
        if predictions: return inference_loss, np.array(preds)
        else: return inference_loss

    def T2Isearch(self, query, focussed_word, k=50):
        """Search for `query` (+ `focussed_word`) and display top-class images."""
        # Encode the query and the focussed word with the CLIP text encoder.
        inputs = text_tokenizer([query, focussed_word], padding=True, return_tensors="pt")
        outputs = text_model(**inputs)
        query_embedding = outputs.text_embeds
        query_vector = query_embedding.detach().numpy()
        # Concatenate the two 512-d CLIP embeddings into one 1024-d vector.
        query_vector = np.concatenate((query_vector[0], query_vector[1]), dtype=np.float32)
        query_vector = query_vector.reshape(1, 1024)
        query_vector = customDataset(query_vector)
        self.test_xt_loader = DataLoader(query_vector, batch_size=1, shuffle=False)
        # Project into the learned common space with the trained text head.
        _, query_vector = self.infer(self.txt_model, self.test_xt_loader, self.mse_criterion,
                                     None, None, self.txt_model_type, self.ffModelPred, True, False)
        query_vector = query_vector.astype(np.float32)
        # query_vector = test_xt_proj[i-1].astype(np.float32)
        # query_vector = query_vector.reshape(1,32)
        faiss.normalize_L2(query_vector)
        # Flat index: nprobe has no effect here, set only for IVF compatibility.
        text_index.nprobe = text_index.ntotal
        # Search for the nearest neighbors in the FAISS text index.
        D, I = text_index.search(query_vector, k)
        # Rank classes by how often they occur among the k nearest neighbours.
        Y = train_yt
        neighbor_ys = Y[I[0]]
        class_freq = np.zeros(Y.shape[1])
        for neighbor_y in neighbor_ys:
            classes = np.where(neighbor_y > 0.5)[0]
            for _class in classes:
                class_freq[_class] += 1
        count = 0
        for i in range(len(class_freq)):
            if class_freq[i] > 0:
                count += 1
        ranked_classes = np.argsort(-class_freq)  # predicted order of all labels
        ranked_classes_after_knn = ranked_classes[:count]  # labels actually seen in the k-NN
        lis = ['aeroplane', 'bicycle','bird','boat','bottle','bus','car','cat','chair','cow','diningtable','dog','horse','motorbike','person','pottedplant','sheep','sofa','train','tvmonitor']
        class_ = lis[ranked_classes_after_knn[0]]
        print(class_)
        # Show the first five images of the predicted class from the archive.
        count = 0
        for i in range(len(image_list)):
            if class_list[i] == class_:
                count += 1
                image_name = image_list[i]
                image_data = zip_file.open("pascal_raw/images/dataset/" + image_name)
                image = Image.open(image_data)
                st.image(image, width=600)
                if count == 5: break
# ---- Streamlit UI entry point -----------------------------------------------
query = st.text_input("Enter your search query here:")
focussed_word = st.text_input("Enter your focussed word here:")
if st.button("Search"):
    # Build the retrieval model lazily, only when a search is requested.
    # (A trailing '|' page-scrape artifact that broke this line was removed.)
    LCM = model(d, "pascal")
    if query:
        LCM.T2Isearch(query, focussed_word)