Spaces:
Runtime error
Runtime error
File size: 3,277 Bytes
db66d62 07356cd db66d62 b383c02 db66d62 1f788b3 db66d62 b383c02 1f788b3 07356cd db66d62 b383c02 1f788b3 07356cd db66d62 1f788b3 db66d62 51e3825 1f788b3 51e3825 1f788b3 db66d62 51e3825 1f788b3 51e3825 1f788b3 db66d62 1f788b3 db66d62 1f788b3 b383c02 1f788b3 db66d62 b383c02 db66d62 07356cd 1f788b3 db66d62 07356cd db66d62 73ff1bb b383c02 1f788b3 db66d62 07356cd b383c02 db66d62 1f788b3 b383c02 db66d62 b383c02 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import pickle
import random
import gradio as gr
import numpy as np
from data import load_indexes_local, load_indexes_hf, load_index_pickle
def getRandID():
    """Pick a random entry from the module-level ``index_to_id_dict``.

    An integer position is drawn uniformly from
    ``range(len(index_to_id_dict))`` and then used as a key into the mapping.
    NOTE(review): this presumes the mapping's keys are exactly the integers
    ``0 .. len - 1`` -- confirm against how the pickle is built.

    Returns:
        tuple: ``(sample_id, position)`` -- the value stored under the drawn
        position, followed by the position itself.
    """
    position = random.randrange(len(index_to_id_dict))
    sample_id = index_to_id_dict[position]
    return sample_id, position
def get_image_index(indexType):
    """Return the image index registered under ``indexType``.

    Args:
        indexType: Name of the index; used as a key into the module-level
            ``image_indexes`` mapping.

    Returns:
        The object stored in ``image_indexes`` under ``indexType``.

    Raises:
        KeyError: If ``indexType`` is not a key of ``image_indexes``. The
            original lookup failure is attached as ``__cause__``.
    """
    try:
        return image_indexes[indexType]
    except KeyError as err:
        # Chain explicitly so the traceback reads as one failure with a cause,
        # not as a second error raised while handling the first.
        raise KeyError(f"Tried to load an image index that is not supported: {indexType}") from err
def get_dna_index(indexType):
    """Return the DNA index registered under ``indexType``.

    Args:
        indexType: Name of the index; used as a key into the module-level
            ``dna_indexes`` mapping.

    Returns:
        The object stored in ``dna_indexes`` under ``indexType``.

    Raises:
        KeyError: If ``indexType`` is not a key of ``dna_indexes``. The
            original lookup failure is attached as ``__cause__``.
    """
    try:
        return dna_indexes[indexType]
    except KeyError as err:
        # Chain explicitly so the traceback reads as one failure with a cause,
        # not as a second error raised while handling the first.
        raise KeyError(f"Tried to load a DNA index that is not supported: {indexType}") from err
def searchEmbeddings(id, key_type, query_type, index_type, num_results: int = 10):
    """Find the sample IDs whose embeddings are closest to a given sample.

    The query vector is reconstructed from the ``query_type`` index at the
    position of ``id``; the nearest neighbours are then searched for in the
    ``key_type`` index.

    Args:
        id: Sample ID to search for; must be a key of the module-level
            ``id_to_index_dict``. Surrounding whitespace on a string ID is
            tolerated. (The name shadows the builtin but is kept because it
            is part of the function's signature.)
        key_type: ``"Image"`` or ``"DNA"`` -- which index is searched.
        query_type: ``"Image"`` or ``"DNA"`` -- which index the query vector
            is reconstructed from.
        index_type: Index name passed to ``get_image_index`` and
            ``get_dna_index``.
        num_results: Number of neighbours to request; must convert to an
            integer >= 1.

    Returns:
        list: IDs of the closest matches in the order returned by the search.
        May hold fewer than ``num_results`` entries when the index cannot
        supply that many neighbours.

    Raises:
        KeyError: If ``index_type`` is not loaded, or ``id`` is unknown.
        ValueError: If ``query_type`` or ``key_type`` is not ``"Image"`` /
            ``"DNA"``, or ``num_results`` is not a positive integer.
    """
    # Both indexes are resolved up front (as before) so that an unsupported
    # index type fails regardless of which modalities were selected.
    index_by_modality = {
        "Image": get_image_index(index_type),
        "DNA": get_dna_index(index_type),
    }

    # Validate the cheap arguments before doing any vector work.
    if query_type not in index_by_modality:
        raise ValueError(f"Invalid query type: {query_type}")
    if key_type not in index_by_modality:
        raise ValueError(f"Invalid key type: {key_type}")

    try:
        k = int(num_results)
    except (TypeError, ValueError):
        raise ValueError(f"Invalid number of results: {num_results}") from None
    if k < 1:
        raise ValueError(f"Invalid number of results: {num_results}")

    # Exact match first; fall back to the whitespace-stripped form because the
    # ID typically arrives from a free-text input box.
    position = id_to_index_dict.get(id)
    if position is None and isinstance(id, str):
        position = id_to_index_dict.get(id.strip())
    if position is None:
        raise KeyError(f"Unknown sample ID: {id!r}")

    # The search call expects a 2-D float32 batch, here a batch of one.
    query = index_by_modality[query_type].reconstruct(position)
    query = np.expand_dims(query.astype(np.float32), axis=0)

    _, neighbour_positions = index_by_modality[key_type].search(query, k)

    # Negative labels are skipped: FAISS-style indexes pad the result with -1
    # when fewer than k neighbours exist, and -1 is not a valid position.
    # NOTE(review): assumes these are FAISS indexes -- confirm in the loader.
    return [index_to_id_dict[int(pos)] for pos in neighbour_positions[0] if pos >= 0]
# Gradio UI: pick a random sample ID, or search for the nearest neighbours of
# a given sample ID across the image / DNA embedding indexes.
# NOTE(review): the source's indentation was lost, so the nesting of the
# layout containers below is a reconstruction -- confirm against the intended
# page layout.
with gr.Blocks() as demo:
    # for hf: change all file paths, indx_to_id_dict as well
    # Load indexes. These names are module-level globals read by getRandID,
    # get_image_index, get_dna_index and searchEmbeddings.
    image_indexes = load_indexes_hf(
        {"FlatIP(default)": "bioscan_5m_image_IndexFlatIP.index"}, repo_name="bioscan-ml/bioscan-clibd"
    )
    dna_indexes = load_indexes_hf(
        {"FlatIP(default)": "bioscan_5m_dna_IndexFlatIP.index"}, repo_name="bioscan-ml/bioscan-clibd"
    )
    # NOTE(review): presumably this unpickles a file fetched from the hub;
    # only point it at a trusted repository.
    index_to_id_dict = load_index_pickle("big_indx_to_id_dict.pickle", repo_name="bioscan-ml/bioscan-clibd")
    # Reverse mapping (sample ID -> position) used to locate the query vector.
    id_to_index_dict = {v: k for k, v in index_to_id_dict.items()}

    # Offer only index types that were loaded for BOTH modalities; any other
    # choice would make every search fail with a KeyError.
    available_index_types = [name for name in image_indexes if name in dna_indexes]

    with gr.Column():
        with gr.Row():
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            with gr.Column():
                query_type = gr.Radio(choices=["Image", "DNA"], label="Query:", value="Image")
                key_type = gr.Radio(choices=["Image", "DNA"], label="Key:", value="Image")

        index_type = gr.Radio(choices=available_index_types, label="Index:", value="FlatIP(default)")
        # precision=0 makes the component deliver an integer.
        num_results = gr.Number(label="Number of Results:", value=10, precision=0)

        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label="Closest matches:")
        search_btn = gr.Button("Search")

    id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
    # The input order must match searchEmbeddings' positional parameters:
    # (id, key_type, query_type, index_type, num_results).
    search_btn.click(
        fn=searchEmbeddings,
        inputs=[process_id, key_type, query_type, index_type, num_results],
        outputs=[process_id_list],
    )

demo.launch()
|