msabia's picture
Update app.py
4aa13f2 verified
raw
history blame
4.03 kB
import gradio as gr
import torch
import numpy as np
import h5py
import faiss
from PIL import Image
import io
import pickle
def get_image(file, dataset_image_mask, processid_to_index, idx):
# idx = processid_to_index[query_id]
image_enc_padded = file["image"][idx].astype(np.uint8)
enc_length = dataset_image_mask[idx]
image_enc = image_enc_padded[:enc_length]
image = Image.open(io.BytesIO(image_enc))
return image
def searchEmbeddings(id):
# get embeddings from file
embeddings_file = h5py.File('5m/extracted_features_of_all_keys.hdf5', 'r')
# variable and index initialization
dim = 768
count = 0
num_neighbors = 10
image_index = faiss.IndexFlatIP(dim)
# load dictionaries
with open("id_emb_dict.pickle", "rb") as f:
id_to_emb_dict = pickle.load(f)
with open("indx_to_id.pickle", "rb") as f:
indx_to_id_dict = pickle.load(f)
# get index
image_index = faiss.read_index("image_index.index")
# search for query
query = id_to_emb_dict[id]
query = query.astype(np.float32)
D, I = image_index.search(query, num_neighbors)
id_list = []
i = 1
for indx in I[0]:
id = indx_to_id_dict[indx]
id_list.append(id)
# get image data
dataset_hdf5_all_key = h5py.File('full5m/BIOSCAN_5M.hdf5', "r", libver="latest")['all_keys']
dataset_processid_list = [item.decode("utf-8") for item in dataset_hdf5_all_key["processid"][:]]
dataset_image_mask = dataset_hdf5_all_key["image_mask"][:]
processid_to_index = {pid: idx for idx, pid in enumerate(dataset_processid_list)}
image1 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][0])
image2 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][1])
image3 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][2])
image4 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][3])
image5 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][4])
image6 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][5])
image7 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][6])
image8 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][7])
image9 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][8])
image10 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][9])
# return id_list, id_list[0], id_list[1], id_list[2], id_list[3], id_list[4], id_list[5], id_list[6], id_list[7], id_list[8], id_list[9], image1, image2, image3, image4, image5, image6, image7, image8, image9, image10
# return id_list, indx_to_id_dict[I[0][0]], indx_to_id_dict[I[0][1]], indx_to_id_dict[I[0][2]], indx_to_id_dict[I[0][3]], indx_to_id_dict[I[0][4]], indx_to_id_dict[I[0][5]], indx_to_id_dict[I[0][6]], indx_to_id_dict[I[0][7]], indx_to_id_dict[I[0][8]], indx_to_id_dict[I[0][9]]
return id_list, image1, image2, image3, image4, image5, image6, image7, image8, image9, image10
with gr.Blocks() as demo:
with gr.Column():
process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
process_id_list = gr.Textbox(label="Closest 10 matches:" )
search_btn = gr.Button("Search")
with gr.Row():
image1 = gr.Image(label=1)
image2 = gr.Image(label=2)
image3 = gr.Image(label=3)
image4 = gr.Image(label=4)
image5 = gr.Image(label=5)
with gr.Row():
image6 = gr.Image(label=6)
image7 = gr.Image(label=7)
image8 = gr.Image(label=8)
image9 = gr.Image(label=9)
image10 = gr.Image(label=10)
search_btn.click(fn=searchEmbeddings, inputs=process_id,
outputs=[process_id_list, image1, image2, image3, image4, image5, image6, image7, image8, image9, image10])
# ARONZ671-20
demo.launch(share=True)