msabia commited on
Commit
7b85781
1 Parent(s): 612907b

Create imageSearching.py

Browse files
Files changed (1) hide show
  1. imageSearching.py +0 -105
imageSearching.py CHANGED
@@ -1,105 +0,0 @@
1
- import gradio as gr
2
- import torch
3
- import numpy as np
4
- import h5py
5
- import faiss
6
- import json
7
- import hydra
8
- import time
9
- import random
10
- from PIL import Image
11
- import io
12
- import pickle
13
-
14
- def get_image(file, dataset_image_mask, processid_to_index, idx):
15
- # idx = processid_to_index[query_id]
16
- image_enc_padded = file["image"][idx].astype(np.uint8)
17
- enc_length = dataset_image_mask[idx]
18
- image_enc = image_enc_padded[:enc_length]
19
- image = Image.open(io.BytesIO(image_enc))
20
- return image
21
-
22
- def searchEmbeddings(id):
23
- # get embeddings from file
24
- embeddings_file = h5py.File('5m/extracted_features_of_all_keys.hdf5', 'r')
25
-
26
- # variable and index initialization
27
- dim = 768
28
- count = 0
29
- num_neighbors = 10
30
-
31
- image_index = faiss.IndexFlatIP(dim)
32
-
33
- # load dictionaries
34
- with open("id_emb_dict.pickle", "rb") as f:
35
- id_to_emb_dict = pickle.load(f)
36
- with open("indx_to_id.pickle", "rb") as f:
37
- indx_to_id_dict = pickle.load(f)
38
-
39
- # get index
40
- image_index = faiss.read_index("image_index.index")
41
-
42
- # search for query
43
- query = id_to_emb_dict[id]
44
- query = query.astype(np.float32)
45
- D, I = image_index.search(query, num_neighbors)
46
-
47
- id_list = []
48
- # need to convert I to id
49
- i = 1
50
- for indx in I[0]:
51
- id = indx_to_id_dict[indx]
52
- id_list.append(id)
53
- # id_list.append(str(i) + ": " + id)
54
- # i += 1
55
-
56
- # get image data
57
- dataset_hdf5_all_key = h5py.File('full5m/BIOSCAN_5M.hdf5', "r", libver="latest")['all_keys']
58
- dataset_processid_list = [item.decode("utf-8") for item in dataset_hdf5_all_key["processid"][:]]
59
- dataset_image_mask = dataset_hdf5_all_key["image_mask"][:]
60
- processid_to_index = {pid: idx for idx, pid in enumerate(dataset_processid_list)}
61
-
62
- image1 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][0])
63
- image2 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][1])
64
- image3 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][2])
65
- image4 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][3])
66
- image5 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][4])
67
- image6 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][5])
68
- image7 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][6])
69
- image8 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][7])
70
- image9 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][8])
71
- image10 = get_image(dataset_hdf5_all_key, dataset_image_mask, processid_to_index, I[0][9])
72
-
73
- # return id_list, id_list[0], id_list[1], id_list[2], id_list[3], id_list[4], id_list[5], id_list[6], id_list[7], id_list[8], id_list[9], image1, image2, image3, image4, image5, image6, image7, image8, image9, image10
74
- # return id_list, indx_to_id_dict[I[0][0]], indx_to_id_dict[I[0][1]], indx_to_id_dict[I[0][2]], indx_to_id_dict[I[0][3]], indx_to_id_dict[I[0][4]], indx_to_id_dict[I[0][5]], indx_to_id_dict[I[0][6]], indx_to_id_dict[I[0][7]], indx_to_id_dict[I[0][8]], indx_to_id_dict[I[0][9]]
75
- return id_list, image1, image2, image3, image4, image5, image6, image7, image8, image9, image10
76
-
77
- with gr.Blocks() as demo:
78
- with gr.Column():
79
- process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
80
- process_id_list = gr.Textbox(label="Closest 10 matches:" )
81
- search_btn = gr.Button("Search")
82
-
83
- with gr.Row():
84
- image1 = gr.Image(label=1)
85
- image2 = gr.Image(label=2)
86
- image3 = gr.Image(label=3)
87
- image4 = gr.Image(label=4)
88
- image5 = gr.Image(label=5)
89
- with gr.Row():
90
- image6 = gr.Image(label=6)
91
- image7 = gr.Image(label=7)
92
- image8 = gr.Image(label=8)
93
- image9 = gr.Image(label=9)
94
- image10 = gr.Image(label=10)
95
-
96
- search_btn.click(fn=searchEmbeddings, inputs=process_id,
97
- outputs=[process_id_list, image1, image2, image3, image4, image5, image6, image7, image8, image9, image10])
98
-
99
-
100
- # cant make functions with Image's as inputs
101
- # MUST be a way to format after
102
-
103
-
104
- # ARONZ671-20
105
- demo.launch(share=True)