atwang commited on
Commit
1f788b3
1 Parent(s): 51e3825

update app to use indexes directly to get embeddings

Browse files
app.py CHANGED
@@ -11,11 +11,11 @@ import click
11
 
12
 
13
  def getRandID():
14
- indx = random.randrange(0, 396503)
15
- return indx_to_id_dict[indx], indx
16
 
17
 
18
- def chooseImageIndex(indexType):
19
  if indexType == "FlatIP(default)":
20
  return image_index_IP
21
  elif indexType == "FlatL2":
@@ -32,7 +32,7 @@ def chooseImageIndex(indexType):
32
  return image_index_LSH
33
 
34
 
35
- def chooseDNAIndex(indexType):
36
  if indexType == "FlatIP(default)":
37
  return dna_index_IP
38
  elif indexType == "FlatL2":
@@ -49,35 +49,36 @@ def chooseDNAIndex(indexType):
49
  return dna_index_LSH
50
 
51
 
52
- def searchEmbeddings(id, key_type, query_type, index_type):
53
- # variable and index initialization
54
- dim = 768
55
- count = 0
56
- num_neighbors = 10
57
-
58
- index = faiss.IndexFlatIP(dim)
59
 
60
  # get index
61
  if query_type == "Image":
62
- index = chooseImageIndex(index_type)
63
  elif query_type == "DNA":
64
- index = chooseDNAIndex(index_type)
 
 
 
 
65
 
66
  # search for query
67
  if key_type == "Image":
68
- query = id_to_image_emb_dict[id]
69
  elif key_type == "DNA":
70
- query = id_to_dna_emb_dict[id]
71
- query = query.astype(np.float32)
72
- D, I = index.search(query, num_neighbors)
73
 
74
- id_list = []
75
- i = 1
 
76
  for indx in I[0]:
77
- id = indx_to_id_dict[indx]
78
- id_list.append(id)
79
 
80
- return id_list
81
 
82
 
83
  with gr.Blocks() as demo:
@@ -102,14 +103,8 @@ with gr.Blocks() as demo:
102
  # with open("processid_to_index.pickle", "rb") as f:
103
  # processid_to_index = pickle.load(f)
104
  with open("big_indx_to_id_dict.pickle", "rb") as f:
105
- indx_to_id_dict = pickle.load(f)
106
-
107
- # initialize both possible dicts
108
- with open("big_id_to_image_emb_dict.pickle", "rb") as f:
109
- id_to_image_emb_dict = pickle.load(f)
110
- # with open("big_id_to_dna_emb_dict.pickle", "rb") as f:
111
- # id_to_dna_emb_dict = pickle.load(f)
112
- id_to_dna_emb_dict = None
113
 
114
  with gr.Column():
115
  with gr.Row():
@@ -124,12 +119,18 @@ with gr.Blocks() as demo:
124
  index_type = gr.Radio(
125
  choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)"
126
  )
 
 
127
  process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
128
  process_id_list = gr.Textbox(label="Closest 10 matches:")
129
  search_btn = gr.Button("Search")
130
  id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
131
 
132
- search_btn.click(fn=searchEmbeddings, inputs=[process_id, key_type, query_type, index_type], outputs=[process_id_list])
 
 
 
 
133
 
134
 
135
  demo.launch()
 
11
 
12
 
13
  def getRandID():
14
+ indx = random.randrange(0, len(index_to_id_dict))
15
+ return index_to_id_dict[indx], indx
16
 
17
 
18
+ def get_image_index(indexType):
19
  if indexType == "FlatIP(default)":
20
  return image_index_IP
21
  elif indexType == "FlatL2":
 
32
  return image_index_LSH
33
 
34
 
35
+ def get_dna_index(indexType):
36
  if indexType == "FlatIP(default)":
37
  return dna_index_IP
38
  elif indexType == "FlatL2":
 
49
  return dna_index_LSH
50
 
51
 
52
+ def searchEmbeddings(id, key_type, query_type, index_type, num_results: int = 10):
53
+ image_index = get_image_index(index_type)
54
+ dna_index = get_dna_index(index_type)
 
 
 
 
55
 
56
  # get index
57
  if query_type == "Image":
58
+ query = image_index.reconstruct(id_to_index_dict[id])
59
  elif query_type == "DNA":
60
+ query = dna_index.reconstruct(id_to_index_dict[id])
61
+ else:
62
+ raise ValueError(f"Invalid query type: {query_type}")
63
+ query = query.astype(np.float32)
64
+ query = np.expand_dims(query, axis=0)
65
 
66
  # search for query
67
  if key_type == "Image":
68
+ index = image_index
69
  elif key_type == "DNA":
70
+ index = dna_index
71
+ else:
72
+ raise ValueError(f"Invalid key type: {key_type}")
73
 
74
+ _, I = index.search(query, num_results)
75
+
76
+ closest_ids = []
77
  for indx in I[0]:
78
+ id = index_to_id_dict[indx]
79
+ closest_ids.append(id)
80
 
81
+ return closest_ids
82
 
83
 
84
  with gr.Blocks() as demo:
 
103
  # with open("processid_to_index.pickle", "rb") as f:
104
  # processid_to_index = pickle.load(f)
105
  with open("big_indx_to_id_dict.pickle", "rb") as f:
106
+ index_to_id_dict = pickle.load(f)
107
+ id_to_index_dict = {v: k for k, v in index_to_id_dict.items()}
 
 
 
 
 
 
108
 
109
  with gr.Column():
110
  with gr.Row():
 
119
  index_type = gr.Radio(
120
  choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)"
121
  )
122
+ num_results = gr.Number(label="Number of Results:", value=10, precision=0)
123
+
124
  process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
125
  process_id_list = gr.Textbox(label="Closest 10 matches:")
126
  search_btn = gr.Button("Search")
127
  id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
128
 
129
+ search_btn.click(
130
+ fn=searchEmbeddings,
131
+ inputs=[process_id, key_type, query_type, index_type, num_results],
132
+ outputs=[process_id_list],
133
+ )
134
 
135
 
136
  demo.launch()
big_id_to_image_emb_dict.pickle DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4fb3f21f2d38a91cb2cad8f40449f31c12d481944d93e9c61def2d3e8e6b78eb
3
- size 274402415
 
 
 
 
big_indx_to_id_dict.pickle CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a192bfb968d669f59ad0c5438751f1585094a67d2130b80c56db9731d4406e10
3
- size 7861755
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee0a9044e054f640b704247a2fa2e74219180b78ded6ba07f551bfc222657fc5
3
+ size 885457