shivangibithel committed
Commit 25ae722
1 Parent(s): 86ba518

Update app.py

Files changed (1): app.py (+58 -9)
app.py CHANGED
@@ -11,6 +11,9 @@ import pickle
 import pickletools
 from transformers import AutoTokenizer, CLIPTextModelWithProjection
 from sklearn.preprocessing import normalize, OneHotEncoder
+import torch.nn as nn
+import torch.nn.functional as F
+import torch
 
 # loading the train dataset
 with open('clip_train.pkl', 'rb') as f:
@@ -29,6 +32,10 @@ with open('clip_test.pkl', 'rb') as f:
 test_yv = temp_d['label']
 test_yt = temp_d['label']
 
+test_xt_proj = np.load("test_text_proj.npy")
+# test_xv_proj = np.load("test_image_proj.npy")
+
+# encoding the labels
 enc = OneHotEncoder(sparse=False)
 enc.fit(np.concatenate((train_yt, test_yt)).reshape((-1, 1)))
 train_yv = enc.transform(train_yv.reshape((-1, 1))).astype(np.float64)
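The hunk above swaps in precomputed 32-d text projections (test_text_proj.npy) and one-hot encodes the labels, fitting the encoder on train and test together so both splits share one category set. A minimal sketch of what that encoding step produces, using hypothetical integer labels (note that the sparse keyword was renamed to sparse_output in scikit-learn 1.2):

import numpy as np
from sklearn.preprocessing import OneHotEncoder

train_yt = np.array([0, 2, 1, 2])  # hypothetical integer class ids
test_yt = np.array([1, 0])

enc = OneHotEncoder(sparse=False)  # sparse_output=False on scikit-learn >= 1.2
enc.fit(np.concatenate((train_yt, test_yt)).reshape((-1, 1)))
print(enc.transform(test_yt.reshape((-1, 1))).astype(np.float64))
# [[0. 1. 0.]
#  [1. 0. 0.]]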
@@ -36,6 +43,45 @@ test_yv = enc.transform(test_yv.reshape((-1, 1))).astype(np.float64)
 train_yt = enc.transform(train_yt.reshape((-1, 1))).astype(np.float64)
 test_yt = enc.transform(test_yt.reshape((-1, 1))).astype(np.float64)
 
+# # Model structure
+# torch.manual_seed(3074)
+# class imgModel(nn.Module):
+#     def __init__(self, in_features, out_features):
+#         super(imgModel, self).__init__()
+#         self.l1 = nn.Linear(in_features=in_features, out_features=256)
+#         self.bn1 = nn.BatchNorm1d(256)
+#         self.dl1 = nn.Dropout(p=0.2)
+#         self.l2 = nn.Linear(in_features=256, out_features=out_features)
+
+#     def forward(self, x):
+#         x = self.l1(x)
+#         x = torch.sigmoid(x)
+#         x = self.dl1(x)
+#         x = self.bn1(x)
+
+#         x = self.l2(x)
+#         x = torch.tanh(x)
+#         return x
+
+torch.manual_seed(3074)
+class txtModel(nn.Module):
+    def __init__(self, in_features, out_features):
+        super(txtModel, self).__init__()
+        self.l1 = nn.Linear(in_features=in_features, out_features=256)
+        self.bn1 = nn.BatchNorm1d(256)
+        self.dl2 = nn.Dropout(p=0.2)
+        self.l2 = nn.Linear(in_features=256, out_features=out_features)
+
+    def forward(self, x):
+        # print(x[0].shape)
+        x = self.l1(x)
+        x = torch.sigmoid(x)
+        x = self.dl2(x)
+        x = self.bn1(x)
+        x = torch.tanh(self.l2(x))
+        # print(x[0].shape)
+        return x
+
 # Map the image ids to the corresponding image URLs
 image_map_name = 'pascal_dataset.csv'
 df = pd.read_csv(image_map_name)
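txtModel mirrors the commented-out imgModel: an embedding passes through a linear layer, sigmoid, dropout, and batch norm, then gets projected and tanh-squashed into the low-dimensional search space. The live-encoding path is disabled in the next hunk in favour of precomputed projections, but a hypothetical sketch of running it would look like this, assuming a saved checkpoint named txt_model.pth and 512 -> 32 dimensions (neither appears in this commit):

import torch

model = txtModel(in_features=512, out_features=32)  # assumed dimensions
model.load_state_dict(torch.load("txt_model.pth"))  # hypothetical checkpoint name
model.eval()  # disables dropout; BatchNorm1d uses running statistics

with torch.no_grad():
    clip_embeds = torch.randn(1, 512)          # stand-in for CLIP text_embeds
    query_vector = model(clip_embeds).numpy()  # shape (1, 32), bounded by tanh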
@@ -51,14 +97,17 @@ d = 32
 text_index = faiss.index_factory(d, "Flat", faiss.METRIC_INNER_PRODUCT)
 text_index = faiss.read_index("text_index.index")
 
-def T2Isearch(query, k=50):
+def T2Isearch(query, i, k=50):
     # Encode the text query
-    inputs = text_tokenizer([query], padding=True, return_tensors="pt")
-    outputs = text_model(**inputs)
-    query_embedding = outputs.text_embeds
-    query_vector = query_embedding.detach().numpy()
+    # inputs = text_tokenizer([query], padding=True, return_tensors="pt")
+    # outputs = text_model(**inputs)
+    # query_embedding = outputs.text_embeds
+    # query_vector = query_embedding.detach().numpy()
     # query_vector = np.concatenate((query_vector[0], query_vector[1]), dtype=np.float32)
-    query_vector = query_vector.reshape(1,512)
+    # query_vector = query_vector.reshape(1,512)
+
+    query_vector = test_xt_proj[i-1]
+    query_vector = query_vector.reshape(1,32)
     faiss.normalize_L2(query_vector)
     index.nprobe = index.ntotal
 
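In this hunk the query becomes a row of the precomputed projection matrix rather than a fresh CLIP encoding. Because the index was built with METRIC_INNER_PRODUCT, the normalize_L2 call turns the inner-product scores into cosine similarities, provided the indexed vectors were normalized the same way; the nprobe setting, incidentally, only affects IVF-style indexes, not a Flat one. A self-contained sketch of that search recipe on random stand-in data:

import faiss
import numpy as np

d = 32
db = np.random.rand(200, d).astype(np.float32)  # stand-in for the indexed projections
faiss.normalize_L2(db)                          # in-place row normalization

index = faiss.IndexFlatIP(d)  # equivalent to index_factory(d, "Flat", METRIC_INNER_PRODUCT)
index.add(db)

query = np.random.rand(1, d).astype(np.float32)
faiss.normalize_L2(query)
D, I = index.search(query, 50)  # D: cosine similarities, I: neighbor row ids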
 
@@ -66,7 +115,7 @@ def T2Isearch(query, k=50):
     D, I = text_index.search(query_vector, k)
 
     # get rank of all classes wrt to query
-    classes_all = []
+
     Y = train_yt
     neighbor_ys = Y[I]
     class_freq = np.zeros(Y.shape[1])
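With one-hot labels, Y[I] gathers the label vectors of the k retrieved neighbors, so summing them gives a per-class vote count, which is presumably how the classes are ranked for the query. A toy illustration of that gather-and-count step:

import numpy as np

Y = np.eye(3)[[0, 2, 1, 2, 0]]  # 5 indexed samples, 3 classes, one-hot
I = np.array([[1, 3, 4]])       # row ids returned for one query, k=3

neighbor_ys = Y[I]                          # shape (1, k, n_classes)
class_freq = neighbor_ys.sum(axis=(0, 1))  # votes per class -> [1. 0. 2.]
ranking = np.argsort(-class_freq)           # classes ordered by neighbor frequency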
@@ -98,7 +147,7 @@ def T2Isearch(query, k=50):
         if count == 5: break
 
 query = st.text_input("Enter your search query here:")
-
+i = st.text_input("Enter the index of test set from 1 - 200")
 if st.button("Search"):
     if query:
-        T2Isearch(query)
+        T2Isearch(query, i)
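One caveat with the new input widget: st.text_input returns a string, so the expression test_xt_proj[i-1] inside T2Isearch would raise a TypeError as written. A possible hardening, not part of this commit, using a numeric widget and an explicit cast:

import streamlit as st

# st.number_input enforces the 1-200 range and returns a number directly
i = st.number_input("Enter the index of test set from 1 - 200",
                    min_value=1, max_value=200, step=1)

if st.button("Search"):
    if query:
        T2Isearch(query, int(i))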
 