# Hugging Face Space application file (file size: 4,780 bytes).
# Revisions listed by the file viewer: 6deebea, e247857.
import gradio as gr
import torch
import numpy as np
import h5py
import faiss
from PIL import Image
import io
import pickle
def get_image(image1, image2, dataset_image_mask, processid_to_index, idx,
              split_index=162834):
    """Decode and return the dataset image stored at global position ``idx``.

    The encoded images are spread over two containers: positions below
    ``split_index`` are read from ``image1``; all other positions are read
    from ``image2`` at ``idx - split_index``.  Each entry is a padded buffer,
    and ``dataset_image_mask[idx]`` gives the number of leading elements that
    make up the actual encoded image.

    Args:
        image1: Container holding the padded encodings for the first part of
            the dataset (indexable, entries expose ``.astype``).
        image2: Container holding the padded encodings for the remainder.
        dataset_image_mask: Maps a global position to the encoded length.
        processid_to_index: Unused; kept so existing callers keep working.
        idx: Global position of the requested image.
        split_index: First global position that lives in ``image2``.
            Defaults to 162834, the value previously hard-coded here.

    Returns:
        The ``PIL.Image.Image`` opened from the trimmed encoded bytes.
    """
    if idx < split_index:
        padded = image1[idx]
    else:
        # Positions at or past the split are stored relative to image2.
        padded = image2[idx - split_index]

    padded = padded.astype(np.uint8)
    enc_length = dataset_image_mask[idx]
    # Drop the padding so only the real encoded image reaches the decoder.
    encoded = padded[:enc_length]
    return Image.open(io.BytesIO(encoded))
# Lookup tables and FAISS index, filled on first use so that the files are
# read from disk once instead of on every search request.
_SEARCH_RESOURCES = {}


def _load_search_resources():
    """Return ``(id_to_emb, indx_to_id, index)``, loading them on first call."""
    if not _SEARCH_RESOURCES:
        # NOTE: pickle executes arbitrary code while loading; this is only
        # acceptable because these files are shipped with the application.
        with open("id_emb_dict.pickle", "rb") as f:
            id_to_emb = pickle.load(f)
        with open("indx_to_id.pickle", "rb") as f:
            indx_to_id = pickle.load(f)
        index = faiss.read_index("image_index.index")
        # Populate the cache only after every load succeeded.
        _SEARCH_RESOURCES.update(
            id_to_emb=id_to_emb, indx_to_id=indx_to_id, index=index
        )
    return (
        _SEARCH_RESOURCES["id_to_emb"],
        _SEARCH_RESOURCES["indx_to_id"],
        _SEARCH_RESOURCES["index"],
    )


def searchEmbeddings(id):
    """Find the stored samples closest to the sample identified by ``id``.

    The embedding registered for ``id`` is used as the query against the
    prebuilt FAISS index ``image_index.index``.  The matching positions are
    translated back to sample IDs and their images are decoded with
    ``get_image`` from the module-level dataset tables
    (``dataset_image1``, ``dataset_image2``, ``dataset_image_mask``,
    ``processid_to_index``).

    Args:
        id: Sample ID to look up (the parameter name shadows the builtin and
            is kept only for interface compatibility).

    Returns:
        A tuple of eleven items: the list of matched sample IDs followed by
        ten images.  Image slots for which the index returned no neighbour
        are ``None``.

    Raises:
        gr.Error: If ``id`` has no stored embedding.
    """
    num_neighbors = 10
    id_to_emb_dict, indx_to_id_dict, image_index = _load_search_resources()

    try:
        query = id_to_emb_dict[id]
    except KeyError:
        raise gr.Error(f"Unknown sample ID: {id!r}") from None

    # FAISS expects a C-contiguous float32 matrix of shape (n_queries, dim);
    # atleast_2d leaves an already 2-D embedding untouched.
    query = np.ascontiguousarray(np.atleast_2d(query), dtype=np.float32)
    _, neighbors = image_index.search(query, num_neighbors)

    # FAISS pads the result with -1 when fewer than num_neighbors vectors are
    # available; a negative value must not be used as a dataset position.
    hits = [int(n) for n in neighbors[0] if n >= 0]

    id_list = [indx_to_id_dict[n] for n in hits]
    images = [
        get_image(
            dataset_image1,
            dataset_image2,
            dataset_image_mask,
            processid_to_index,
            n,
        )
        for n in hits
    ]
    # Always hand back one value per output widget.
    images.extend([None] * (num_neighbors - len(images)))

    return (id_list, *images)
def _read_pickle(path):
    """Load and return the object pickled in the file at ``path``."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


with gr.Blocks() as demo:
    # Dataset tables; searchEmbeddings reads several of these as globals.
    dataset_processid_list = _read_pickle("dataset_processid_list.pickle")
    dataset_image_mask = _read_pickle("dataset_image_mask.pickle")
    processid_to_index = _read_pickle("processid_to_index.pickle")
    dataset_image1 = _read_pickle("dataset_image1.pickle")
    dataset_image2 = _read_pickle("dataset_image2.pickle")

    # Query input, textual result and the trigger button.
    with gr.Column():
        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label="Closest 10 matches:")
        search_btn = gr.Button("Search")

    # Two rows of five image widgets, labelled 1 through 10.
    with gr.Row():
        image1, image2, image3, image4, image5 = (
            gr.Image(label=n) for n in range(1, 6)
        )
    with gr.Row():
        image6, image7, image8, image9, image10 = (
            gr.Image(label=n) for n in range(6, 11)
        )

    # One click fills the match list and all ten image slots.
    search_btn.click(
        fn=searchEmbeddings,
        inputs=process_id,
        outputs=[
            process_id_list,
            image1,
            image2,
            image3,
            image4,
            image5,
            image6,
            image7,
            image8,
            image9,
            image10,
        ],
    )
    # ARONZ671-20  (presumably a sample ID kept for manual testing)

demo.launch()