# browser-backend / app.py
# Last commit 07356cd (by atwang): "update code to download dataset files from separate repo"
import pickle
import random
import gradio as gr
import numpy as np
from data import load_indexes_local, load_indexes_hf, load_index_pickle
def getRandID():
    """Draw a uniformly random sample from the loaded ID mapping.

    Returns:
        Tuple ``(sample_id, position)`` where ``position`` is the integer key
        that was used to look ``sample_id`` up in ``index_to_id_dict``.
    """
    upper = len(index_to_id_dict)
    # Assumes the keys are the contiguous integers 0..upper-1 -- TODO confirm
    # against the pickle; a gap in the keys would surface here as a KeyError.
    position = random.randrange(0, upper)
    sample_id = index_to_id_dict[position]
    return sample_id, position
def get_image_index(indexType):
    """Return the loaded image-embedding index registered under ``indexType``.

    Args:
        indexType: Key into the module-level ``image_indexes`` mapping,
            e.g. ``"FlatIP(default)"``.

    Returns:
        The index object stored under that key.

    Raises:
        KeyError: If no image index was loaded under ``indexType``.
    """
    try:
        return image_indexes[indexType]
    except KeyError:
        # Re-raise with a readable message. `from None` suppresses the implicit
        # "during handling of the above exception" chain, which otherwise makes
        # this deliberate re-raise look like a bug in the handler.
        raise KeyError(f"Tried to load an image index that is not supported: {indexType}") from None
def get_dna_index(indexType):
    """Return the loaded DNA-embedding index registered under ``indexType``.

    Args:
        indexType: Key into the module-level ``dna_indexes`` mapping,
            e.g. ``"FlatIP(default)"``.

    Returns:
        The index object stored under that key.

    Raises:
        KeyError: If no DNA index was loaded under ``indexType``.
    """
    try:
        return dna_indexes[indexType]
    except KeyError:
        # Re-raise with a readable message. `from None` suppresses the implicit
        # "during handling of the above exception" chain, which otherwise makes
        # this deliberate re-raise look like a bug in the handler.
        raise KeyError(f"Tried to load a DNA index that is not supported: {indexType}") from None
def searchEmbeddings(id, key_type, query_type, index_type, num_results: int = 10):
    """Find the samples whose key embeddings are nearest to one sample's query embedding.

    The query vector is the stored image or DNA embedding of sample ``id``
    (reconstructed from the corresponding index). It is searched against the
    image or DNA key index, so cross-modal search (e.g. image query vs. DNA
    keys) works the same way as same-modal search.

    Args:
        id: Sample ID to query with; must be one of the values of
            ``index_to_id_dict``. Surrounding whitespace is ignored.
        key_type: ``"Image"`` or ``"DNA"`` -- which index to search in.
        query_type: ``"Image"`` or ``"DNA"`` -- which embedding of ``id`` to use.
        index_type: Name of the loaded index variant, e.g. ``"FlatIP(default)"``.
        num_results: Number of neighbours to request; must be >= 1.

    Returns:
        List of sample IDs in the order returned by the index search. It may
        be shorter than ``num_results`` when the index has fewer hits.

    Raises:
        KeyError: Unknown ``index_type`` or unknown sample ``id``.
        ValueError: Invalid ``query_type``/``key_type``, or ``num_results`` < 1.
    """
    image_index = get_image_index(index_type)
    dna_index = get_dna_index(index_type)

    # Pick the index whose stored vector for `id` becomes the query.
    if query_type == "Image":
        query_index = image_index
    elif query_type == "DNA":
        query_index = dna_index
    else:
        raise ValueError(f"Invalid query type: {query_type}")

    # Textbox input often carries stray whitespace from copy/paste.
    sample_id = id.strip() if isinstance(id, str) else id
    try:
        row = id_to_index_dict[sample_id]
    except KeyError:
        raise KeyError(f"Unknown sample ID: {sample_id!r}") from None

    # search() takes a 2-D float32 batch, hence the cast and the leading axis of length 1.
    query = query_index.reconstruct(row).astype(np.float32)
    query = np.expand_dims(query, axis=0)

    if key_type == "Image":
        key_index = image_index
    elif key_type == "DNA":
        key_index = dna_index
    else:
        raise ValueError(f"Invalid key type: {key_type}")

    # gr.Number may deliver None (cleared field) or a float; validate here
    # instead of letting the index search fail with an opaque error.
    if num_results is None or int(num_results) < 1:
        raise ValueError(f"Number of results must be a positive integer, got: {num_results}")
    k = int(num_results)

    _, neighbours = key_index.search(query, k)
    # Faiss pads the result row with -1 when fewer than k neighbours exist
    # (e.g. k > index.ntotal). -1 is not a real row, so drop it rather than
    # failing the ID lookup.
    return [index_to_id_dict[int(i)] for i in neighbours[0] if i >= 0]
with gr.Blocks() as demo:
    # Load the prebuilt search indexes and the row-index -> sample-ID mapping
    # from the dataset repo via the helpers imported from `data`. A `with`
    # block does not create a scope, so these names are module globals; the
    # callback functions defined above read them.
    HF_REPO_NAME = "bioscan-ml/bioscan-clibd"
    image_indexes = load_indexes_hf(
        {"FlatIP(default)": "bioscan_5m_image_IndexFlatIP.index"}, repo_name=HF_REPO_NAME
    )
    dna_indexes = load_indexes_hf(
        {"FlatIP(default)": "bioscan_5m_dna_IndexFlatIP.index"}, repo_name=HF_REPO_NAME
    )
    # NOTE(review): load_index_pickle presumably unpickles a file fetched from
    # the repo; unpickling can execute arbitrary code, so this is only safe
    # while that repo is trusted.
    index_to_id_dict = load_index_pickle("big_indx_to_id_dict.pickle", repo_name=HF_REPO_NAME)
    # Reverse mapping (sample ID -> row index) used to look up a query vector.
    id_to_index_dict = {v: k for k, v in index_to_id_dict.items()}

    # Offer only index variants that are actually loaded for BOTH modalities.
    # Previously "FlatL2", "HNSWFlat", "IVFFlat" and "LSH" were listed without
    # being loaded, so every search with them raised a KeyError. New variants
    # added to the load calls above now appear here automatically.
    available_index_types = [name for name in image_indexes if name in dna_indexes]

    with gr.Column():
        with gr.Row():
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            with gr.Column():
                query_type = gr.Radio(choices=["Image", "DNA"], label="Query:", value="Image")
                key_type = gr.Radio(choices=["Image", "DNA"], label="Key:", value="Image")
                index_type = gr.Radio(
                    choices=available_index_types, label="Index:", value="FlatIP(default)"
                )
                num_results = gr.Number(label="Number of Results:", value=10, precision=0)
        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label="Closest matches:")
        search_btn = gr.Button("Search")

    id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
    # Input order must match searchEmbeddings' positional parameters:
    # (id, key_type, query_type, index_type, num_results).
    search_btn.click(
        fn=searchEmbeddings,
        inputs=[process_id, key_type, query_type, index_type, num_results],
        outputs=[process_id_list],
    )

demo.launch()