# browser-backend / app.py
# atwang's picture
# add code files
# db66d62
# raw
# history blame
# 4.02 kB
import gradio as gr
import torch
import numpy as np
import h5py
import faiss
from PIL import Image
import io
import pickle
import random
def getRandID():
    """Draw a random position and return the sample ID stored at it.

    Returns:
        tuple: ``(sample_id, position)`` where ``position`` is an integer
        drawn from ``range(396503)`` and ``sample_id`` is
        ``indx_to_id_dict[position]``.
    """
    # NOTE(review): 396503 is presumably the number of indexed samples;
    # confirm it matches the size of indx_to_id_dict.
    position = random.randrange(396503)
    sample_id = indx_to_id_dict[position]
    return sample_id, position
def chooseImageIndex(indexType):
    """Return the image FAISS index registered under ``indexType``.

    Args:
        indexType: One of ``"FlatIP(default)"``, ``"FlatL2"``,
            ``"HNSWFlat"``, ``"IVFFlat"`` or ``"LSH"``.

    Returns:
        The matching module-level image index, or ``None`` when
        ``indexType`` is not one of the names above.
    """
    # Each entry is a thunk so only the requested module-level index is
    # looked up, exactly as with a chain of comparisons.
    resolvers = {
        "FlatIP(default)": lambda: image_index_IP,
        "FlatL2": lambda: image_index_L2,
        "HNSWFlat": lambda: image_index_HNSW,
        "IVFFlat": lambda: image_index_IVF,
        "LSH": lambda: image_index_LSH,
    }
    resolve = resolvers.get(indexType)
    if resolve is None:
        return None
    return resolve()
def chooseDNAIndex(indexType):
    """Return the DNA FAISS index registered under ``indexType``.

    Args:
        indexType: One of ``"FlatIP(default)"``, ``"FlatL2"``,
            ``"HNSWFlat"``, ``"IVFFlat"`` or ``"LSH"``.

    Returns:
        The matching module-level DNA index, or ``None`` when
        ``indexType`` is not one of the names above.
    """
    # Each entry is a thunk so only the requested module-level index is
    # looked up, exactly as with a chain of comparisons.
    resolvers = {
        "FlatIP(default)": lambda: dna_index_IP,
        "FlatL2": lambda: dna_index_L2,
        "HNSWFlat": lambda: dna_index_HNSW,
        "IVFFlat": lambda: dna_index_IVF,
        "LSH": lambda: dna_index_LSH,
    }
    resolve = resolvers.get(indexType)
    if resolve is None:
        return None
    return resolve()
def searchEmbeddings(id, mod1, mod2, indexType):
    """Return the IDs of the samples nearest to a query sample.

    The embedding stored for ``id`` in modality ``mod1`` is used as the
    query against the FAISS index of modality ``mod2`` that is selected by
    ``indexType``.

    Args:
        id: Sample ID to search from; must be a key of the embedding
            dictionary for ``mod1``. (The name shadows the builtin but is
            kept so existing callers keep working.)
        mod1: Modality of the query embedding, ``"Image"`` or ``"DNA"``.
        mod2: Modality of the index that is searched, ``"Image"`` or
            ``"DNA"``.
        indexType: Index name understood by ``chooseImageIndex`` /
            ``chooseDNAIndex``.

    Returns:
        list: Up to 10 sample IDs, in the order FAISS returned them.

    Raises:
        gr.Error: If a modality is missing or unknown, if ``indexType``
            does not name an index, or if ``id`` has no stored embedding.
            Gradio shows the message of a ``gr.Error`` to the user.
    """
    num_neighbors = 10

    # Pick the embedding table for the query modality.
    if mod1 == "Image":
        embeddings = id_to_image_emb_dict
    elif mod1 == "DNA":
        embeddings = id_to_dna_emb_dict
    else:
        raise gr.Error('Choose "DNA" or "Image" for "Search From:".')

    # Pick the index for the target modality.
    if mod2 == "Image":
        index = chooseImageIndex(indexType)
    elif mod2 == "DNA":
        index = chooseDNAIndex(indexType)
    else:
        raise gr.Error('Choose "DNA" or "Image" for "Search To:".')
    if index is None:
        raise gr.Error(f"Unknown index type: {indexType!r}")

    try:
        embedding = embeddings[id]
    except KeyError:
        raise gr.Error(f"Unknown sample ID: {id!r}") from None

    # FAISS expects a C-contiguous float32 array of shape (n_queries, dim);
    # atleast_2d also accepts an embedding that was stored as a 1-D vector.
    query = np.ascontiguousarray(np.atleast_2d(embedding), dtype=np.float32)
    _, labels = index.search(query, num_neighbors)

    # FAISS pads the result with -1 when it finds fewer than num_neighbors
    # hits; those placeholders are not valid positions and are skipped.
    return [indx_to_id_dict[label] for label in labels[0] if label >= 0]
def _load_pickle(path):
    """Open ``path`` in binary mode and return the unpickled object."""
    # pickle can run arbitrary code while loading; only use trusted files.
    with open(path, "rb") as handle:
        return pickle.load(handle)


with gr.Blocks() as demo:
    # for hf: change all file paths, indx_to_id_dict as well

    # FAISS indexes over the image embeddings, one per selectable index type.
    image_index_IP = faiss.read_index("big_image_index_FlatIP.index")
    image_index_L2 = faiss.read_index("big_image_index_FlatL2.index")
    image_index_HNSW = faiss.read_index("big_image_index_HNSWFlat.index")
    image_index_IVF = faiss.read_index("big_image_index_IVFFlat.index")
    image_index_LSH = faiss.read_index("big_image_index_LSH.index")

    # FAISS indexes over the DNA embeddings, one per selectable index type.
    dna_index_IP = faiss.read_index("big_dna_index_FlatIP.index")
    dna_index_L2 = faiss.read_index("big_dna_index_FlatL2.index")
    dna_index_HNSW = faiss.read_index("big_dna_index_HNSWFlat.index")
    dna_index_IVF = faiss.read_index("big_dna_index_IVFFlat.index")
    dna_index_LSH = faiss.read_index("big_dna_index_LSH.index")

    # Pickled lookup tables; the first two are loaded here but are not
    # referenced by the functions defined in this file.
    dataset_processid_list = _load_pickle("dataset_processid_list.pickle")
    processid_to_index = _load_pickle("processid_to_index.pickle")
    # Maps a position returned by a FAISS search to a sample ID.
    indx_to_id_dict = _load_pickle("big_indx_to_id_dict.pickle")

    # Query embeddings for both modalities, keyed by sample ID.
    id_to_image_emb_dict = _load_pickle("big_id_to_image_emb_dict.pickle")
    id_to_dna_emb_dict = _load_pickle("big_id_to_dna_emb_dict.pickle")

    # Page layout: one row holding a random-ID picker and the search form.
    with gr.Column():
        with gr.Row():
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            with gr.Column():
                mod1 = gr.Radio(choices=["DNA", "Image"], label="Search From:")
                mod2 = gr.Radio(choices=["DNA", "Image"], label="Search To:")
                indexType = gr.Radio(
                    choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"],
                    label="Index:",
                    value="FlatIP(default)",
                )
                process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
                process_id_list = gr.Textbox(label="Closest 10 matches:")
                search_btn = gr.Button("Search")

    # Event wiring: each button calls its handler and fills the listed outputs.
    id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
    search_btn.click(
        fn=searchEmbeddings,
        inputs=[process_id, mod1, mod2, indexType],
        outputs=[process_id_list],
    )

demo.launch()