Commit e8299fb
Parent(s): cde3cdd
Update app.py
app.py CHANGED
@@ -8,12 +8,13 @@ import json
 import zipfile
 import pandas as pd
 import pickle
-import pickletools
 from transformers import AutoTokenizer, CLIPTextModelWithProjection
 from sklearn.preprocessing import normalize, OneHotEncoder
 import torch.nn as nn
 import torch.nn.functional as F
 import torch
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
 
 # loading the train dataset
 with open('clip_train.pkl', 'rb') as f:
@@ -82,6 +83,16 @@ class txtModel(nn.Module):
         # print(x[0].shape)
         return x
 
+class customDataset(Dataset):
+    def __init__(self, any_data):
+        self.any_data = any_data
+
+    def __len__(self):
+        return self.any_data.shape[0]
+
+    def __getitem__(self, idx):
+        return self.any_data[idx]
+
 # Map the image ids to the corresponding image URLs
 image_map_name = 'pascal_dataset.csv'
 df = pd.read_csv(image_map_name)
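The new `customDataset` is a minimal indexable wrapper whose only job is to let a numpy array ride through the torch `DataLoader` machinery used by the inference code further down. A quick sketch of the pairing, with a toy array standing in for real CLIP embeddings:

```python
import numpy as np
from torch.utils.data import DataLoader

vectors = np.random.rand(4, 1024).astype(np.float32)  # four toy 1024-d text embeddings
loader = DataLoader(customDataset(vectors), batch_size=1, shuffle=False)

for batch in loader:
    # default collation converts each numpy row into a torch tensor
    print(batch.shape)  # torch.Size([1, 1024])
```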
@@ -97,55 +108,117 @@ d = 32
 text_index = faiss.index_factory(d, "Flat", faiss.METRIC_INNER_PRODUCT)
 text_index = faiss.read_index("text_index.index")
 
-… [13 deleted lines; content not preserved in the page capture]
+np.random.seed(3074)
+class model:
+    def __init__(self, L, dataset):
+        self.txt_model_type = 'simple'
+        self.L = 32
+        self.device = 'cpu'
+        self.batch_size = 1
+        self.SIGMA = 0.01
+        self.txt_model = txtModel(train_xt.shape[1], L).to(self.device)
+        self.mse_criterion = nn.MSELoss(reduction='mean')
+        # image_state_dict = torch.load(dir_path + '/image_checkpoint.pth')
+        self.text_state_dict = torch.load('text_checkpoint.pth')
+
+        # img_model.load_state_dict(image_state_dict)
+        self.txt_model.load_state_dict(self.text_state_dict)
+
+    def ffModelLoss(self, data, output, true_output, criterion, model_type):
+        if model_type == 'simple':
+            return criterion(output, true_output)
+        elif model_type == 'ae_middle':
+            emb, reconstruction = output
+            return self.SIGMA * criterion(reconstruction, data) + criterion(emb, true_output)
+
+    def ffModelPred(self, output, model_type):
+        if model_type == 'simple':
+            return output.tolist()
+        elif model_type == 'ae_middle':
+            emb, reconstruction = output
+            return emb.tolist()
+
+    def infer(self, model, dataloader, criterion, B, modelLossFxn, model_type, predictionFxn, predictions=False, cal_loss=True):
+        model.eval()
+        running_loss = 0.0
+        preds = []
+
+        with torch.no_grad():
+            for i, data in enumerate(dataloader):
+                data = data.to(self.device)
+                data = data.view(data.size(0), -1)
+                output = model(data)
+                if predictions: preds += predictionFxn(output, model_type)
+                if cal_loss:
+                    true_output = torch.tensor(B[i*self.batch_size:(i+1)*self.batch_size, :]).to(self.device)
+                    loss = modelLossFxn(data, output, true_output, criterion, model_type)
+                    running_loss += loss.item()
+        inference_loss = running_loss / len(dataloader.dataset)
+
+        if predictions: return inference_loss, np.array(preds)
+        else: return inference_loss
+
 
-… [32 deleted lines; content not preserved in the page capture]
+    def T2Isearch(self, query, focussed_word, i, k=50):
+        # Encode the text query
+        inputs = text_tokenizer([query, focussed_word], padding=True, return_tensors="pt")
+        outputs = text_model(**inputs)
+        query_embedding = outputs.text_embeds
+        query_vector = query_embedding.detach().numpy()
+        query_vector = np.concatenate((query_vector[0], query_vector[1]), dtype=np.float32)
+        query_vector = query_vector.reshape(1, 1024)
+        query_vector = customDataset(query_vector)
+        self.test_xt_loader = DataLoader(query_vector, batch_size=1, shuffle=False)
+        _, query_vector = self.infer(self.txt_model, self.test_xt_loader, self.mse_criterion, \
+            None, None, self.txt_model_type, self.ffModelPred, True, False)
+
+        query_vector = query_vector.astype(np.float32)
+
+        # give this input to the learned encoder
+
+        # query_vector = test_xt_proj[i-1].astype(np.float32)
+        # query_vector = query_vector.reshape(1, 32)
+
+        faiss.normalize_L2(query_vector)
+        text_index.nprobe = text_index.ntotal
+
+        # Search for the nearest neighbors in the FAISS text index
+        D, I = text_index.search(query_vector, k)
+
+        # get the rank of all classes w.r.t. the query
+        Y = train_yt
+        neighbor_ys = Y[I[0]]
+        class_freq = np.zeros(Y.shape[1])
+        for neighbor_y in neighbor_ys:
+            classes = np.where(neighbor_y > 0.5)[0]
+            for _class in classes:
+                class_freq[_class] += 1
+
+        count = 0
+        for i in range(len(class_freq)):
+            if class_freq[i] > 0:
+                count += 1
+        ranked_classes = np.argsort(-class_freq)  # chosen order of pivots -- predicted sequence of all labels for the query
+        ranked_classes_after_knn = ranked_classes[:count]  # predicted sequence of top labels after knn search
+
+        lis = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
+        class_ = lis[ranked_classes_after_knn[0]]
+        print(class_)
+
+        # Map the image ids to the corresponding image URLs
+        count = 0
+        for i in range(len(image_list)):
+            if class_list[i] == class_:
+                count += 1
+                image_name = image_list[i]
+                image_data = zip_file.open("pascal_raw/images/dataset/" + image_name)
+                image = Image.open(image_data)
+                st.image(image, width=600)
+                if count == 5: break
 
 query = st.text_input("Enter your search query here:")
-
+Focussed_word = st.text_input("Enter your focussed word here:")
 if st.button("Search"):
+    LCM = model(d, "pascal")
     if query:
-        T2Isearch(query, int(
+        LCM.T2Isearch(query, Focussed_word, int(ind))
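The core of the new `T2Isearch` is a frequency vote: the k nearest text vectors are looked up in the multi-hot label matrix `train_yt`, each class is credited once per neighbour that carries it, and classes are ranked by vote count. The same step in isolation, with toy stand-ins for `train_yt[I[0]]` (all values illustrative):

```python
import numpy as np

# five retrieved neighbours over four classes, multi-hot as in train_yt
neighbor_ys = np.array([[1, 0, 0, 1],
                        [1, 0, 0, 0],
                        [0, 0, 1, 0],
                        [1, 0, 0, 0],
                        [0, 0, 1, 0]])

class_freq = (neighbor_ys > 0.5).sum(axis=0)             # [3, 0, 2, 1]
ranked_classes = np.argsort(-class_freq)                 # [0, 2, 3, 1]
top_after_knn = ranked_classes[:np.count_nonzero(class_freq)]
print(top_after_knn)  # [0 2 3] -- class 0 wins, so its images are displayed first
```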
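One wiring note on the new UI block: `model(d, "pascal")` is constructed inside the button handler, so `text_checkpoint.pth` is reloaded on every click, and `ind` in the final call is not defined in any visible hunk. A common Streamlit pattern is to build the model once and cache it; a minimal sketch assuming the surrounding app.py names (`get_model` is a hypothetical helper, and the literal `0` stands in for the unresolved `int(ind)`):

```python
import streamlit as st

@st.cache_resource  # built once per process, reused across reruns and clicks
def get_model():
    # `model` and `d` are defined earlier in app.py; "pascal" matches the commit's call site
    return model(d, "pascal")

query = st.text_input("Enter your search query here:")
Focussed_word = st.text_input("Enter your focussed word here:")
if st.button("Search"):
    if query:
        get_model().T2Isearch(query, Focussed_word, 0)  # hypothetical index in place of int(ind)
```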
|