Spaces:
Runtime error
Runtime error
refine key and query types
Browse files
app.py
CHANGED
@@ -49,7 +49,7 @@ def chooseDNAIndex(indexType):
|
|
49 |
return dna_index_LSH
|
50 |
|
51 |
|
52 |
-
def searchEmbeddings(id,
|
53 |
# variable and index initialization
|
54 |
dim = 768
|
55 |
count = 0
|
@@ -58,15 +58,15 @@ def searchEmbeddings(id, mod1, mod2, indexType):
|
|
58 |
index = faiss.IndexFlatIP(dim)
|
59 |
|
60 |
# get index
|
61 |
-
if
|
62 |
-
index = chooseImageIndex(
|
63 |
-
elif
|
64 |
-
index = chooseDNAIndex(
|
65 |
|
66 |
# search for query
|
67 |
-
if
|
68 |
query = id_to_image_emb_dict[id]
|
69 |
-
elif
|
70 |
query = id_to_dna_emb_dict[id]
|
71 |
query = query.astype(np.float32)
|
72 |
D, I = index.search(query, num_neighbors)
|
@@ -118,8 +118,8 @@ with gr.Blocks() as demo:
|
|
118 |
rand_id_indx = gr.Textbox(label="Index:")
|
119 |
id_btn = gr.Button("Get Random ID")
|
120 |
with gr.Column():
|
121 |
-
key_type = gr.Radio(choices=["
|
122 |
-
query_type = gr.Radio(choices=["
|
123 |
|
124 |
index_type = gr.Radio(
|
125 |
choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)"
|
|
|
49 |
return dna_index_LSH
|
50 |
|
51 |
|
52 |
+
def searchEmbeddings(id, key_type, query_type, index_type):
|
53 |
# variable and index initialization
|
54 |
dim = 768
|
55 |
count = 0
|
|
|
58 |
index = faiss.IndexFlatIP(dim)
|
59 |
|
60 |
# get index
|
61 |
+
if query_type == "Image":
|
62 |
+
index = chooseImageIndex(index_type)
|
63 |
+
elif query_type == "DNA":
|
64 |
+
index = chooseDNAIndex(index_type)
|
65 |
|
66 |
# search for query
|
67 |
+
if key_type == "Image":
|
68 |
query = id_to_image_emb_dict[id]
|
69 |
+
elif key_type == "DNA":
|
70 |
query = id_to_dna_emb_dict[id]
|
71 |
query = query.astype(np.float32)
|
72 |
D, I = index.search(query, num_neighbors)
|
|
|
118 |
rand_id_indx = gr.Textbox(label="Index:")
|
119 |
id_btn = gr.Button("Get Random ID")
|
120 |
with gr.Column():
|
121 |
+
key_type = gr.Radio(choices=["Image", "DNA"], label="Search From:", value="Image")
|
122 |
+
query_type = gr.Radio(choices=["Image", "DNA"], label="Search To:", value="Image")
|
123 |
|
124 |
index_type = gr.Radio(
|
125 |
choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)"
|