bean@DESKTOP-G2JAGVE committed on
Commit
612907b
·
1 Parent(s): ba05b10

Add application file

Browse files
Files changed (1) hide show
  1. imageSearching.py +105 -0
imageSearching.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import h5py
5
+ import faiss
6
+ import json
7
+ import hydra
8
+ import time
9
+ import random
10
+ from PIL import Image
11
+ import io
12
+ import pickle
13
+
14
def get_image(file, dataset_image_mask, processid_to_index, idx):
    """Open the encoded image stored at row ``idx`` as a PIL image.

    Args:
        file: Object indexable by ``"image"``; ``file["image"][idx]`` must
            yield an array that is cast to ``uint8`` (the padded encoded
            image bytes).
        dataset_image_mask: Indexable whose entry at ``idx`` is used as the
            number of leading bytes to keep from the padded buffer.
        processid_to_index: Not read by this function; kept so existing
            call sites keep working.
        idx: Row index into both ``file["image"]`` and
            ``dataset_image_mask``.

    Returns:
        The object returned by ``PIL.Image.open`` for the trimmed bytes.
    """
    padded_bytes = file["image"][idx].astype(np.uint8)
    # Drop the padding: only the first ``valid_length`` bytes are passed on.
    valid_length = dataset_image_mask[idx]
    encoded = io.BytesIO(padded_bytes[:valid_length])
    return Image.open(encoded)
21
+
22
def searchEmbeddings(id):
    """Look up the nearest neighbours of a sample and load their images.

    The embedding stored for ``id`` in ``id_emb_dict.pickle`` is used as the
    query against the FAISS index in ``image_index.index``. The returned row
    numbers are translated to IDs through ``indx_to_id.pickle`` and used to
    read the matching images from ``full5m/BIOSCAN_5M.hdf5``.

    Changes relative to the previous version:
      * the HDF5 file is closed when the call finishes (it used to leak one
        open handle per search), and the embeddings HDF5 file that was opened
        but never read is no longer opened;
      * the full ``processid`` column is no longer decoded on every call,
        since ``get_image`` never reads the resulting mapping;
      * ``image_mask`` is indexed lazily instead of being read in full;
      * the ``id`` parameter is no longer overwritten inside the loop.

    Args:
        id: Sample ID to search for (a key of ``id_emb_dict.pickle``). The
            name shadows the builtin but is kept for compatibility.

    Returns:
        A tuple ``(id_list, image1, ..., image10)``: the list of neighbour
        IDs followed by exactly ten images. An image slot is ``None`` when
        FAISS reports no neighbour for it (label ``-1``).

    Raises:
        gr.Error: If ``id`` is not a known sample ID.
    """
    # Must equal the number of gr.Image outputs wired to this callback.
    num_neighbors = 10

    # NOTE(review): both pickles and the index are re-read on every search;
    # consider loading them once at start-up if latency matters.
    with open("id_emb_dict.pickle", "rb") as f:
        id_to_emb_dict = pickle.load(f)
    with open("indx_to_id.pickle", "rb") as f:
        indx_to_id_dict = pickle.load(f)

    image_index = faiss.read_index("image_index.index")

    # Try the ID exactly as given first; fall back to a whitespace-trimmed
    # version so a pasted ID with stray spaces still resolves.
    query_key = id
    if query_key not in id_to_emb_dict and isinstance(query_key, str):
        query_key = query_key.strip()
    if query_key not in id_to_emb_dict:
        raise gr.Error(f"Unknown sample ID: {id!r}")

    # FAISS expects a 2-D, C-contiguous float32 array; atleast_2d is a no-op
    # when the stored embedding already has a leading batch axis.
    query = np.atleast_2d(
        np.ascontiguousarray(id_to_emb_dict[query_key], dtype=np.float32)
    )
    _, neighbors = image_index.search(query, num_neighbors)
    neighbor_rows = [int(row) for row in neighbors[0]]

    # FAISS pads the result with -1 when fewer than k neighbours exist.
    id_list = [indx_to_id_dict[row] for row in neighbor_rows if row >= 0]

    with h5py.File('full5m/BIOSCAN_5M.hdf5', "r", libver="latest") as dataset_file:
        all_keys = dataset_file['all_keys']
        image_mask = all_keys["image_mask"]
        # get_image copies the encoded bytes into a BytesIO, so the returned
        # images stay usable after the HDF5 file is closed. Its
        # processid_to_index argument is unused, hence None.
        images = [
            get_image(all_keys, image_mask, None, row) if row >= 0 else None
            for row in neighbor_rows
        ]

    return (id_list, *images)
76
+
77
# Gradio UI: a text box for the query ID, a text box listing the matches,
# and two rows of five image slots for the ten nearest neighbours.
# NOTE(review): the nesting of the image rows relative to the column is
# reconstructed; confirm it matches the intended layout.
with gr.Blocks() as demo:
    with gr.Column():
        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label="Closest 10 matches:")
        search_btn = gr.Button("Search")

    # Labels are passed as strings, which is what gr.Image documents for
    # ``label``; the displayed text is unchanged.
    with gr.Row():
        image1 = gr.Image(label="1")
        image2 = gr.Image(label="2")
        image3 = gr.Image(label="3")
        image4 = gr.Image(label="4")
        image5 = gr.Image(label="5")
    with gr.Row():
        image6 = gr.Image(label="6")
        image7 = gr.Image(label="7")
        image8 = gr.Image(label="8")
        image9 = gr.Image(label="9")
        image10 = gr.Image(label="10")

    # The output order must match the tuple returned by searchEmbeddings:
    # the ID list first, then the ten images.
    search_btn.click(
        fn=searchEmbeddings,
        inputs=process_id,
        outputs=[
            process_id_list,
            image1, image2, image3, image4, image5,
            image6, image7, image8, image9, image10,
        ],
    )


# Can't make functions with Images as inputs;
# there MUST be a way to format after.

# Presumably an example ID for manual testing: ARONZ671-20
demo.launch(share=True)