shivangibithel commited on
Commit
2d30c9d
·
1 Parent(s): 9217b31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -22
app.py CHANGED
@@ -15,28 +15,24 @@ from sklearn.preprocessing import normalize, OneHotEncoder
15
  # loading the train dataset
16
  with open('clip_train.pkl', 'rb') as f:
17
  temp_d = pickle.load(f)
18
- # train_xv = temp_d['image'].astype(np.float64) # Array of image features : np ndarray
19
- # train_xt = temp_d['text'].astype(np.float64) # Array of text features : np ndarray
20
- # train_yv = temp_d['label'] # Array of labels
21
  train_yt = temp_d['label'] # Array of labels
22
- # ids = list(temp_d['ids']) # image names == len(images)
23
-
24
- # train_yt = np.load("train_yt.npy")
25
 
26
  # loading the test dataset
27
  with open('clip_test.pkl', 'rb') as f:
28
  temp_d = pickle.load(f)
29
- # test_xv = temp_d['image'].astype(np.float64)
30
- # test_xt = temp_d['text'].astype(np.float64)
31
- # test_yv = temp_d['label']
32
  test_yt = temp_d['label']
33
 
34
- # test_xt = np.load("test_xt.npy")
35
-
36
  enc = OneHotEncoder(sparse=False)
37
  enc.fit(np.concatenate((train_yt, test_yt)).reshape((-1, 1)))
38
- # train_yv = enc.transform(self.train_yv.reshape((-1, 1))).astype(np.float64)
39
- # test_yv = enc.transform(self.test_yv.reshape((-1, 1))).astype(np.float64)
40
  train_yt = enc.transform(train_yt.reshape((-1, 1))).astype(np.float64)
41
  test_yt = enc.transform(test_yt.reshape((-1, 1))).astype(np.float64)
42
 
@@ -55,18 +51,15 @@ d = 32
55
  text_index = faiss.index_factory(d, "Flat", faiss.METRIC_INNER_PRODUCT)
56
  text_index = faiss.read_index("text_index.index")
57
 
58
- def T2Isearch(query,focussed_word, k=50):
59
  # Encode the text query
60
- inputs = text_tokenizer([query,focussed_word], padding=True, return_tensors="pt")
61
  outputs = text_model(**inputs)
62
  query_embedding = outputs.text_embeds
63
  query_vector = query_embedding.detach().numpy()
64
- query_vector = np.concatenate((query_vector[0], query_vector[1]), dtype=np.float32)
65
- query_vector = query_vector.reshape(1,1024)
66
  faiss.normalize_L2(query_vector)
67
- text_index.nprobe = text_index.ntotal
68
-
69
- # text_index.nprobe = 100
70
 
71
  # Search for the nearest neighbors in the FAISS text index
72
  D, I = text_index.search(query_vector, k)
@@ -104,7 +97,7 @@ def T2Isearch(query,focussed_word, k=50):
104
  if count == 5: break
105
 
106
  query = st.text_input("Enter your search query here:")
107
- focussed_word = st.text_input("Enter focussed word here")
108
  if st.button("Search"):
109
  if query:
110
- T2Isearch(query, focussed_word)
 
15
  # loading the train dataset
16
  with open('clip_train.pkl', 'rb') as f:
17
  temp_d = pickle.load(f)
18
+ train_xv = temp_d['image'].astype(np.float64) # Array of image features : np ndarray
19
+ train_xt = temp_d['text'].astype(np.float64) # Array of text features : np ndarray
20
+ train_yv = temp_d['label'] # Array of labels
21
  train_yt = temp_d['label'] # Array of labels
22
+ ids = list(temp_d['ids']) # image names == len(images)
 
 
23
 
24
  # loading the test dataset
25
  with open('clip_test.pkl', 'rb') as f:
26
  temp_d = pickle.load(f)
27
+ test_xv = temp_d['image'].astype(np.float64)
28
+ test_xt = temp_d['text'].astype(np.float64)
29
+ test_yv = temp_d['label']
30
  test_yt = temp_d['label']
31
 
 
 
32
  enc = OneHotEncoder(sparse=False)
33
  enc.fit(np.concatenate((train_yt, test_yt)).reshape((-1, 1)))
34
+ train_yv = enc.transform(train_yv.reshape((-1, 1))).astype(np.float64)
35
+ test_yv = enc.transform(test_yv.reshape((-1, 1))).astype(np.float64)
36
  train_yt = enc.transform(train_yt.reshape((-1, 1))).astype(np.float64)
37
  test_yt = enc.transform(test_yt.reshape((-1, 1))).astype(np.float64)
38
 
 
51
  text_index = faiss.index_factory(d, "Flat", faiss.METRIC_INNER_PRODUCT)
52
  text_index = faiss.read_index("text_index.index")
53
 
54
+ def T2Isearch(query, k=50):
55
  # Encode the text query
56
+ inputs = text_tokenizer([query], padding=True, return_tensors="pt")
57
  outputs = text_model(**inputs)
58
  query_embedding = outputs.text_embeds
59
  query_vector = query_embedding.detach().numpy()
60
+ query_vector = query_vector.reshape(1,512)
 
61
  faiss.normalize_L2(query_vector)
62
+ index.nprobe = index.ntotal
 
 
63
 
64
  # Search for the nearest neighbors in the FAISS text index
65
  D, I = text_index.search(query_vector, k)
 
97
  if count == 5: break
98
 
99
  query = st.text_input("Enter your search query here:")
100
+
101
  if st.button("Search"):
102
  if query:
103
+ T2Isearch(query)