shivangibithel committed
Commit e8299fb
1 Parent(s): cde3cdd

Update app.py

Files changed (1)
  1. app.py +121 -48
app.py CHANGED
@@ -8,12 +8,13 @@ import json
import zipfile
import pandas as pd
import pickle
- import pickletools
from transformers import AutoTokenizer, CLIPTextModelWithProjection
from sklearn.preprocessing import normalize, OneHotEncoder
import torch.nn as nn
import torch.nn.functional as F
import torch
+ from torch.utils.data import DataLoader
+ from torch.utils.data import Dataset

# loading the train dataset
with open('clip_train.pkl', 'rb') as f:
@@ -82,6 +83,16 @@ class txtModel(nn.Module):
        # print(x[0].shape)
        return x

+ class customDataset(Dataset):
+     def __init__(self, any_data):
+         self.any_data = any_data
+
+     def __len__(self):
+         return self.any_data.shape[0]
+
+     def __getitem__(self, idx):
+         return self.any_data[idx]
+
# Map the image ids to the corresponding image URLs
image_map_name = 'pascal_dataset.csv'
df = pd.read_csv(image_map_name)
@@ -97,55 +108,117 @@ d = 32
text_index = faiss.index_factory(d, "Flat", faiss.METRIC_INNER_PRODUCT)
text_index = faiss.read_index("text_index.index")

- def T2Isearch(query, i, k=50):
-     # Encode the text query
-     # inputs = text_tokenizer([query], padding=True, return_tensors="pt")
-     # outputs = text_model(**inputs)
-     # query_embedding = outputs.text_embeds
-     # query_vector = query_embedding.detach().numpy()
-     # query_vector = np.concatenate((query_vector[0], query_vector[1]), dtype=np.float32)
-     # query_vector = query_vector.reshape(1, 512)
-
-     query_vector = test_xt_proj[i-1].astype(np.float32)
-     query_vector = query_vector.reshape(1, 32)
-     faiss.normalize_L2(query_vector)
-     text_index.nprobe = text_index.ntotal
+ np.random.seed(3074)
+ class model:
+     def __init__(self, L, dataset):
+         self.txt_model_type = 'simple'
+         self.L = 32
+         self.device = 'cpu'
+         self.batch_size = 1
+         self.SIGMA = 0.01
+         self.txt_model = txtModel(train_xt.shape[1], L).to(self.device)
+         self.mse_criterion = nn.MSELoss(reduction='mean')
+         # image_state_dict = torch.load(dir_path + '/image_checkpoint.pth')
+         self.text_state_dict = torch.load('text_checkpoint.pth')
+
+         # img_model.load_state_dict(image_state_dict)
+         self.txt_model.load_state_dict(self.text_state_dict)
+
+     def ffModelLoss(self, data, output, true_output, criterion, model_type):
+         if model_type == 'simple':
+             return criterion(output, true_output)
+         elif model_type == 'ae_middle':
+             emb, reconstruction = output
+             return self.SIGMA * criterion(reconstruction, data) + criterion(emb, true_output)
+
+     def ffModelPred(self, output, model_type):
+         if model_type == 'simple':
+             return output.tolist()
+         elif model_type == 'ae_middle':
+             emb, reconstruction = output
+             return emb.tolist()
+
+     def infer(self, model, dataloader, criterion, B, modelLossFxn, model_type, predictionFxn, predictions=False, cal_loss=True):
+         model.eval()
+         running_loss = 0.0
+         preds = []
+
+         with torch.no_grad():
+             for i, data in enumerate(dataloader):
+                 data = data.to(self.device)
+                 data = data.view(data.size(0), -1)
+                 output = model(data)
+                 if predictions: preds += predictionFxn(output, model_type)
+                 if cal_loss:
+                     true_output = torch.tensor(B[i*self.batch_size:(i+1)*self.batch_size, :]).to(self.device)
+                     loss = modelLossFxn(data, output, true_output, criterion, model_type)
+                     running_loss += loss.item()
+         inference_loss = running_loss / len(dataloader.dataset)
+
+         if predictions: return inference_loss, np.array(preds)
+         else: return inference_loss
+

-     # Search for the nearest neighbors in the FAISS text index
-     D, I = text_index.search(query_vector, k)
-
-     # get rank of all classes wrt the query
-     Y = train_yt
-     neighbor_ys = Y[I[0]]
-     class_freq = np.zeros(Y.shape[1])
-     for neighbor_y in neighbor_ys:
-         classes = np.where(neighbor_y > 0.5)[0]
-         for _class in classes:
-             class_freq[_class] += 1
-
-     count = 0
-     for i in range(len(class_freq)):
-         if class_freq[i] > 0:
-             count += 1
-     ranked_classes = np.argsort(-class_freq)  # chosen order of pivots -- predicted sequence of all labels for the query
-     ranked_classes_after_knn = ranked_classes[:count]  # predicted sequence of top labels after knn search
-
-     lis = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
-     class_ = lis[ranked_classes_after_knn[0]]
-
-     # Map the image ids to the corresponding image URLs
-     count = 0
-     for i in range(len(image_list)):
-         if class_list[i] == class_:
-             count += 1
-             image_name = image_list[i]
-             image_data = zip_file.open("pascal_raw/images/dataset/" + image_name)
-             image = Image.open(image_data)
-             st.image(image, width=600)
-             if count == 5: break
+     def T2Isearch(self, query, focussed_word, k=50):
+         # Encode the text query and the focussed word with CLIP
+         inputs = text_tokenizer([query, focussed_word], padding=True, return_tensors="pt")
+         outputs = text_model(**inputs)
+         query_embedding = outputs.text_embeds
+         query_vector = query_embedding.detach().numpy()
+         query_vector = np.concatenate((query_vector[0], query_vector[1]), dtype=np.float32)
+         query_vector = query_vector.reshape(1, 1024)
+         query_vector = customDataset(query_vector)
+         self.test_xt_loader = DataLoader(query_vector, batch_size=1, shuffle=False)
+         _, query_vector = self.infer(self.txt_model, self.test_xt_loader, self.mse_criterion, \
+                                      None, None, self.txt_model_type, self.ffModelPred, True, False)
+
+         query_vector = query_vector.astype(np.float32)
+
+         # give this input to the learned encoder
+
+         # query_vector = test_xt_proj[i-1].astype(np.float32)
+         # query_vector = query_vector.reshape(1, 32)
+
+         faiss.normalize_L2(query_vector)
+         text_index.nprobe = text_index.ntotal
+
+         # Search for the nearest neighbors in the FAISS text index
+         D, I = text_index.search(query_vector, k)
+
+         # get rank of all classes wrt the query
+         Y = train_yt
+         neighbor_ys = Y[I[0]]
+         class_freq = np.zeros(Y.shape[1])
+         for neighbor_y in neighbor_ys:
+             classes = np.where(neighbor_y > 0.5)[0]
+             for _class in classes:
+                 class_freq[_class] += 1
+
+         count = 0
+         for i in range(len(class_freq)):
+             if class_freq[i] > 0:
+                 count += 1
+         ranked_classes = np.argsort(-class_freq)  # chosen order of pivots -- predicted sequence of all labels for the query
+         ranked_classes_after_knn = ranked_classes[:count]  # predicted sequence of top labels after knn search
+
+         lis = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
+         class_ = lis[ranked_classes_after_knn[0]]
+         print(class_)
+
+         # Map the image ids to the corresponding image URLs
+         count = 0
+         for i in range(len(image_list)):
+             if class_list[i] == class_:
+                 count += 1
+                 image_name = image_list[i]
+                 image_data = zip_file.open("pascal_raw/images/dataset/" + image_name)
+                 image = Image.open(image_data)
+                 st.image(image, width=600)
+                 if count == 5: break

query = st.text_input("Enter your search query here:")
- i = st.text_input("Enter the index of test set from 1 - 200")
+ Focussed_word = st.text_input("Enter your focussed word here:")
if st.button("Search"):
+     LCM = model(d, "pascal")
    if query:
-         T2Isearch(query, int(i))
+         LCM.T2Isearch(query, Focussed_word)
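
Note on the new query path: T2Isearch now composes four stages, CLIP-encoding the query and the focussed word, concatenating the two 512-d projected embeddings into one 1024-d vector, mapping it into the 32-d index space with the trained txtModel, and L2-normalizing for inner-product search. A minimal sketch of the encoding stage follows; it assumes the "openai/clip-vit-base-patch32" checkpoint and illustrative prompt strings, since the checkpoint app.py actually loads is not shown in this diff:

from transformers import AutoTokenizer, CLIPTextModelWithProjection
import numpy as np

# assumed checkpoint; app.py's actual tokenizer/model setup is outside this diff
tok = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
enc = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")

# two prompts in one batch: the full query and the focussed word
inputs = tok(["a dog running on grass", "dog"], padding=True, return_tensors="pt")
embeds = enc(**inputs).text_embeds.detach().numpy()  # shape (2, 512) projected embeddings

# concatenate query + focussed-word embeddings, as T2Isearch does
query_vec = np.concatenate((embeds[0], embeds[1]), dtype=np.float32).reshape(1, 1024)
print(query_vec.shape)  # (1, 1024), the input size the 1024->32 txtModel expects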
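After the FAISS search, T2Isearch does not display the neighbors themselves; it votes over their multi-hot PASCAL VOC label vectors and keeps every class that received at least one vote, ranked by vote count. A self-contained sketch of that voting step on synthetic data (knn_class_vote, xb, and yb are illustrative names, not from app.py):

import faiss
import numpy as np

def knn_class_vote(query_vector, index, train_labels, k=50):
    # retrieve the k nearest training items, then count label votes per class
    _, I = index.search(query_vector, k)           # I: (1, k) neighbor ids
    neighbor_ys = train_labels[I[0]]               # (k, n_classes) multi-hot rows
    class_freq = (neighbor_ys > 0.5).sum(axis=0)   # votes per class
    ranked = np.argsort(-class_freq)               # classes by descending votes
    return ranked[class_freq[ranked] > 0]          # drop classes with zero votes

d = 32
rng = np.random.default_rng(3074)
xb = rng.standard_normal((100, d)).astype(np.float32)   # stand-in train embeddings
yb = (rng.random((100, 20)) > 0.9).astype(np.float32)   # stand-in multi-hot labels
faiss.normalize_L2(xb)
index = faiss.IndexFlatIP(d)                             # flat inner-product index
index.add(xb)

q = rng.standard_normal((1, d)).astype(np.float32)
faiss.normalize_L2(q)
print(knn_class_vote(q, index, yb, k=10))                # ranked class ids

The first entry of the returned ranking plays the role of ranked_classes_after_knn[0] in the diff: app.py looks it up in the 20-name PASCAL class list and shows the first five images of that class.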