atwang commited on
Commit
51e3825
1 Parent(s): 73ff1bb

refine key and query types

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -49,7 +49,7 @@ def chooseDNAIndex(indexType):
49
  return dna_index_LSH
50
 
51
 
52
- def searchEmbeddings(id, mod1, mod2, indexType):
53
  # variable and index initialization
54
  dim = 768
55
  count = 0
@@ -58,15 +58,15 @@ def searchEmbeddings(id, mod1, mod2, indexType):
58
  index = faiss.IndexFlatIP(dim)
59
 
60
  # get index
61
- if mod2 == "Image":
62
- index = chooseImageIndex(indexType)
63
- elif mod2 == "DNA":
64
- index = chooseDNAIndex(indexType)
65
 
66
  # search for query
67
- if mod1 == "Image":
68
  query = id_to_image_emb_dict[id]
69
- elif mod1 == "DNA":
70
  query = id_to_dna_emb_dict[id]
71
  query = query.astype(np.float32)
72
  D, I = index.search(query, num_neighbors)
@@ -118,8 +118,8 @@ with gr.Blocks() as demo:
118
  rand_id_indx = gr.Textbox(label="Index:")
119
  id_btn = gr.Button("Get Random ID")
120
  with gr.Column():
121
- key_type = gr.Radio(choices=["DNA", "Image"], label="Search From:")
122
- query_type = gr.Radio(choices=["DNA", "Image"], label="Search To:")
123
 
124
  index_type = gr.Radio(
125
  choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)"
 
49
  return dna_index_LSH
50
 
51
 
52
+ def searchEmbeddings(id, key_type, query_type, index_type):
53
  # variable and index initialization
54
  dim = 768
55
  count = 0
 
58
  index = faiss.IndexFlatIP(dim)
59
 
60
  # get index
61
+ if query_type == "Image":
62
+ index = chooseImageIndex(index_type)
63
+ elif query_type == "DNA":
64
+ index = chooseDNAIndex(index_type)
65
 
66
  # search for query
67
+ if key_type == "Image":
68
  query = id_to_image_emb_dict[id]
69
+ elif key_type == "DNA":
70
  query = id_to_dna_emb_dict[id]
71
  query = query.astype(np.float32)
72
  D, I = index.search(query, num_neighbors)
 
118
  rand_id_indx = gr.Textbox(label="Index:")
119
  id_btn = gr.Button("Get Random ID")
120
  with gr.Column():
121
+ key_type = gr.Radio(choices=["Image", "DNA"], label="Search From:", value="Image")
122
+ query_type = gr.Radio(choices=["Image", "DNA"], label="Search To:", value="Image")
123
 
124
  index_type = gr.Radio(
125
  choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)"