Spaces:
Sleeping
Sleeping
File size: 4,076 Bytes
a51fba7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import gradio as gr
import torch
import numpy as np
import h5py
import faiss
import json
import hydra
import time
import random
from PIL import Image
import io
import pickle
def get_image(file, dataset_image_mask, processid_to_index, idx):
    """Decode and return the PIL image stored at row *idx* of the HDF5 dataset.

    Image bytes are stored zero-padded to a fixed width; ``dataset_image_mask[idx]``
    holds the true encoded byte length.  ``processid_to_index`` is accepted for
    interface compatibility but is not used here.
    """
    padded = file["image"][idx].astype(np.uint8)
    true_length = dataset_image_mask[idx]
    raw_bytes = padded[:true_length]
    return Image.open(io.BytesIO(raw_bytes))
def searchEmbeddings(id):
    """Return the 10 nearest-neighbour matches for a query sample ID.

    Parameters
    ----------
    id : str
        Process ID of the query sample; key into the pickled
        ID -> embedding dictionary.  (Name kept for caller compatibility,
        although it shadows the builtin ``id``.)

    Returns
    -------
    tuple
        ``(id_list, image1, ..., image10)`` — the list of the 10 closest
        process IDs followed by the 10 matching PIL images, matching the
        Gradio ``outputs`` list wired up below.
    """
    num_neighbors = 10

    # Load the lookup tables: query ID -> embedding, FAISS row -> process ID.
    with open("id_emb_dict.pickle", "rb") as f:
        id_to_emb_dict = pickle.load(f)
    with open("indx_to_id.pickle", "rb") as f:
        indx_to_id_dict = pickle.load(f)

    # Load the prebuilt index and search with the query embedding.
    # FAISS requires float32; the stored embedding is assumed to already be
    # shaped (1, dim) — TODO confirm against the pickle writer.
    image_index = faiss.read_index("image_index.index")
    query = id_to_emb_dict[id].astype(np.float32)
    D, I = image_index.search(query, num_neighbors)

    # Map FAISS row indices back to process IDs.
    id_list = [indx_to_id_dict[indx] for indx in I[0]]

    # Decode the image for each hit.  The file handle is closed when the
    # block exits (the original leaked it, and also opened an unused
    # '5m/extracted_features_of_all_keys.hdf5' handle that is removed here).
    with h5py.File('full5m/BIOSCAN_5M.hdf5', "r", libver="latest") as hdf5_file:
        dataset = hdf5_file['all_keys']
        dataset_image_mask = dataset["image_mask"][:]
        dataset_processid_list = [item.decode("utf-8") for item in dataset["processid"][:]]
        processid_to_index = {pid: i for i, pid in enumerate(dataset_processid_list)}
        images = [get_image(dataset, dataset_image_mask, processid_to_index, row)
                  for row in I[0]]

    return (id_list, *images)
# Build the Gradio UI: a query box, a results textbox, a search button,
# and a 2x5 grid of result images.
with gr.Blocks() as demo:
    with gr.Column():
        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label="Closest 10 matches:")
        search_btn = gr.Button("Search")
    with gr.Row():
        top_images = [gr.Image(label=n) for n in range(1, 6)]
    with gr.Row():
        bottom_images = [gr.Image(label=n) for n in range(6, 11)]

    # searchEmbeddings returns (id_list, image1..image10) in exactly this order.
    search_btn.click(fn=searchEmbeddings, inputs=process_id,
                     outputs=[process_id_list, *top_images, *bottom_images])

# Example query ID: ARONZ671-20
demo.launch(share=True)