shivangibithel commited on
Commit
90c5c7f
1 Parent(s): 67cbd57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -55,7 +55,7 @@ d = 1024
55
  text_index = faiss.index_factory(d, "Flat", faiss.METRIC_INNER_PRODUCT)
56
  faiss.read_index("text_index.index")
57
 
58
- def T2Isearch(query,focussed_word, k=50):
59
  # Encode the text query
60
  inputs = text_tokenizer([query,focussed_word], padding=True, return_tensors="pt")
61
  outputs = text_model(**inputs)
@@ -63,22 +63,25 @@ def T2Isearch(query,focussed_word, k=50):
63
  query_vector = query_embedding.detach().numpy()
64
  query_vector = np.concatenate((query_vector[0], query_vector[1]), dtype=np.float32)
65
  query_vector = query_vector.reshape(1,1024)
66
- # query_vector = test_xt[0]
67
- # query_vector = np.array([query_embedding])
68
  faiss.normalize_L2(query_vector)
69
- # text_index.nprobe = index.ntotal
70
- text_index.nprobe = 100
 
71
 
72
  # Search for the nearest neighbors in the FAISS text index
73
  D, I = text_index.search(query_vector, k)
 
 
74
 
75
  # get rank of all classes wrt to query
76
  classes_all = []
77
  Y = train_yt
78
  neighbor_ys = Y[I]
79
  class_freq = np.zeros(Y.shape[1])
 
80
  for neighbor_y in neighbor_ys:
81
  classes = np.where(neighbor_y > 0.5)[0]
 
82
  for _class in classes:
83
  class_freq[_class] += 1
84
 
@@ -91,6 +94,7 @@ def T2Isearch(query,focussed_word, k=50):
91
 
92
  lis = ['aeroplane', 'bicycle','bird','boat','bottle','bus','car','cat','chair','cow','diningtable','dog','horse','motorbike','person','pottedplant','sheep','sofa','train','tvmonitor']
93
  class_ = lis[ranked_classes_after_knn[0]-1]
 
94
 
95
  # Map the image ids to the corresponding image URLs
96
  for i in range(len(image_list)):
 
55
  text_index = faiss.index_factory(d, "Flat", faiss.METRIC_INNER_PRODUCT)
56
  faiss.read_index("text_index.index")
57
 
58
+ def T2Isearch(query,focussed_word, k=5):
59
  # Encode the text query
60
  inputs = text_tokenizer([query,focussed_word], padding=True, return_tensors="pt")
61
  outputs = text_model(**inputs)
 
63
  query_vector = query_embedding.detach().numpy()
64
  query_vector = np.concatenate((query_vector[0], query_vector[1]), dtype=np.float32)
65
  query_vector = query_vector.reshape(1,1024)
 
 
66
  faiss.normalize_L2(query_vector)
67
+ text_index.nprobe = text_index.ntotal
68
+
69
+ # text_index.nprobe = 100
70
 
71
  # Search for the nearest neighbors in the FAISS text index
72
  D, I = text_index.search(query_vector, k)
73
+ print(D)
74
+ print(I)
75
 
76
  # get rank of all classes wrt to query
77
  classes_all = []
78
  Y = train_yt
79
  neighbor_ys = Y[I]
80
  class_freq = np.zeros(Y.shape[1])
81
+
82
  for neighbor_y in neighbor_ys:
83
  classes = np.where(neighbor_y > 0.5)[0]
84
+ print(classes)
85
  for _class in classes:
86
  class_freq[_class] += 1
87
 
 
94
 
95
  lis = ['aeroplane', 'bicycle','bird','boat','bottle','bus','car','cat','chair','cow','diningtable','dog','horse','motorbike','person','pottedplant','sheep','sofa','train','tvmonitor']
96
  class_ = lis[ranked_classes_after_knn[0]-1]
97
+ print(class_)
98
 
99
  # Map the image ids to the corresponding image URLs
100
  for i in range(len(image_list)):