File size: 4,076 Bytes
a51fba7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import functools
import io
import json
import pickle
import random
import time

# Third-party imports keep their original relative order (torch/faiss import
# order can matter for shared OpenMP runtimes).
import gradio as gr
import torch
import numpy as np
import h5py
import faiss
import hydra
from PIL import Image

def get_image(file, dataset_image_mask, processid_to_index, idx):
    """Decode the encoded image stored at row ``idx`` of ``file["image"]``.

    Args:
        file: Container exposing an ``"image"`` dataset whose rows hold
            zero-padded encoded image bytes (presumably an HDF5 group).
        dataset_image_mask: Per-row count of valid bytes; the padded row is
            truncated to this length before decoding.
        processid_to_index: Not used by this function; kept for call
            compatibility.
        idx: Row index into both ``file["image"]`` and ``dataset_image_mask``.

    Returns:
        A ``PIL.Image.Image`` opened from the truncated bytes.
    """
    padded_bytes = file["image"][idx].astype(np.uint8)
    # Drop the trailing padding so PIL sees exactly the encoded payload.
    payload = padded_bytes[: dataset_image_mask[idx]]
    return Image.open(io.BytesIO(payload))

@functools.lru_cache(maxsize=1)
def _load_search_index():
    """Load the FAISS image index and its ID lookup tables once per process.

    Returns:
        ``(image_index, id_to_emb_dict, indx_to_id_dict)``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; these files must
    # come from a trusted source.
    with open("id_emb_dict.pickle", "rb") as f:
        id_to_emb_dict = pickle.load(f)
    with open("indx_to_id.pickle", "rb") as f:
        indx_to_id_dict = pickle.load(f)
    image_index = faiss.read_index("image_index.index")
    return image_index, id_to_emb_dict, indx_to_id_dict


@functools.lru_cache(maxsize=1)
def _load_image_dataset():
    """Open the BIOSCAN-5M image dataset and build its process-ID lookup once.

    Returns:
        ``(dataset_file, all_keys, image_mask, processid_to_index)``. The
        HDF5 file handle is returned so it stays referenced (and open) for the
        life of the process instead of being reopened and leaked per request.
    """
    dataset_file = h5py.File('full5m/BIOSCAN_5M.hdf5', "r", libver="latest")
    all_keys = dataset_file['all_keys']
    # Decoding every processid is expensive, which is why this is cached.
    processids = [item.decode("utf-8") for item in all_keys["processid"][:]]
    image_mask = all_keys["image_mask"][:]
    processid_to_index = {pid: row for row, pid in enumerate(processids)}
    return dataset_file, all_keys, image_mask, processid_to_index


def searchEmbeddings(id):
    """Find the 10 nearest neighbours of sample ``id`` and load their images.

    Args:
        id: Sample/process ID entered by the user; surrounding whitespace is
            ignored. Must be a key of ``id_emb_dict.pickle``.

    Returns:
        An 11-tuple ``(id_list, image1, ..., image10)`` matching the Gradio
        outputs: the neighbour IDs in rank order, followed by one PIL image
        per neighbour (``None`` for any slot FAISS could not fill).

    Raises:
        KeyError: If ``id`` has no stored embedding, or a neighbour ID is not
            present in the BIOSCAN-5M ``processid`` column.
    """
    num_neighbors = 10
    query_id = id.strip()

    image_index, id_to_emb_dict, indx_to_id_dict = _load_search_index()

    # FAISS needs a C-contiguous float32 matrix of shape (n_queries, dim);
    # atleast_2d lets a stored 1-D vector work as well as a (1, dim) row.
    query = np.atleast_2d(
        np.ascontiguousarray(id_to_emb_dict[query_id], dtype=np.float32)
    )
    _, neighbor_rows = image_index.search(query, num_neighbors)

    # FAISS pads missing results with -1 when it holds fewer than k vectors.
    id_list = [indx_to_id_dict[int(row)] for row in neighbor_rows[0] if row != -1]

    _, all_keys, image_mask, processid_to_index = _load_image_dataset()
    # Resolve each image through its process ID so the picture shown at a
    # rank always belongs to the ID listed at that rank, independent of the
    # row order the FAISS index was built in.
    images = [
        get_image(all_keys, image_mask, processid_to_index, processid_to_index[pid])
        for pid in id_list
    ]
    images += [None] * (num_neighbors - len(images))

    return (id_list, *images)

with gr.Blocks() as demo:
    # Query input, the textual list of matches, and the trigger button.
    with gr.Column():
        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label="Closest 10 matches:")
        search_btn = gr.Button("Search")

    # Ten result slots in two rows of five, labelled with their 1-based rank.
    # Labels are strings because gr.Image's `label` parameter is typed str.
    result_images = []
    for row_start in (1, 6):
        with gr.Row():
            for rank in range(row_start, row_start + 5):
                result_images.append(gr.Image(label=str(rank)))

    # Output order must match searchEmbeddings' return: id list, then images 1..10.
    search_btn.click(
        fn=searchEmbeddings,
        inputs=process_id,
        outputs=[process_id_list, *result_images],
    )

# Example query ID to try in the UI: ARONZ671-20
# Guarded so importing this module (e.g. `gradio app.py` reload mode) does not
# start a second server as a side effect.
if __name__ == "__main__":
    demo.launch(share=True)