msabia committed on
Commit
89816ab
·
verified ·
1 Parent(s): cb889d6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import h5py
5
+ import faiss
6
+ from PIL import Image
7
+ import io
8
+ import pickle
9
+ import random
10
+
11
def getRandID():
    """Draw a random position and return the sample ID stored at it.

    Returns:
        A ``(sample_id, position)`` pair: ``position`` is the integer that
        was drawn and ``sample_id`` is ``indx_to_id_dict[position]``.
    """
    # Exclusive upper bound of the draw. Presumably the number of entries in
    # ``indx_to_id_dict`` -- TODO confirm against the pickled mapping.
    position_count = 396503
    position = random.randrange(position_count)
    sample_id = indx_to_id_dict[position]
    return sample_id, position
14
+
15
def chooseImageIndex(indexType):
    """Map an index-type label to the matching preloaded image index.

    Args:
        indexType: Label of the index to use: "FlatIP(default)", "FlatL2",
            "HNSWFlat", "IVFFlat" or "LSH".

    Returns:
        The corresponding module-level ``image_index_*`` object, or ``None``
        when the label is not one of the five values above.
    """
    if indexType == "FlatIP(default)":
        return image_index_IP
    if indexType == "FlatL2":
        return image_index_L2
    if indexType == "HNSWFlat":
        return image_index_HNSW
    if indexType == "IVFFlat":
        return image_index_IVF
    if indexType == "LSH":
        return image_index_LSH
    # Unrecognised label: callers receive None.
    return None
26
+
27
def chooseDNAIndex(indexType):
    """Map an index-type label to the matching preloaded DNA index.

    Args:
        indexType: Label of the index to use: "FlatIP(default)", "FlatL2",
            "HNSWFlat", "IVFFlat" or "LSH".

    Returns:
        The corresponding module-level ``dna_index_*`` object, or ``None``
        when the label is not one of the five values above.
    """
    if indexType == "FlatIP(default)":
        return dna_index_IP
    if indexType == "FlatL2":
        return dna_index_L2
    if indexType == "HNSWFlat":
        return dna_index_HNSW
    if indexType == "IVFFlat":
        return dna_index_IVF
    if indexType == "LSH":
        return dna_index_LSH
    # Unrecognised label: callers receive None.
    return None
38
+
39
+
40
+
41
def searchEmbeddings(id, mod1, mod2, indexType):
    """Return the sample IDs closest to a query sample across modalities.

    The embedding of sample ``id`` is taken from the ``mod1`` modality and
    searched for in the ``mod2`` index of the requested type.

    Args:
        id: Sample ID used as the query; must be a key of the ``mod1``
            embedding dict. (The name shadows the builtin but is kept so the
            signature stays unchanged for callers.)
        mod1: Modality the query embedding comes from: "Image" or "DNA".
        mod2: Modality of the index that is searched: "Image" or "DNA".
        indexType: Index label accepted by ``chooseImageIndex`` /
            ``chooseDNAIndex``.

    Returns:
        A list of at most 10 sample IDs, in the order FAISS returned them.

    Raises:
        gr.Error: If either modality is missing or unknown, the index type
            is unknown, or ``id`` has no embedding in the ``mod1`` modality.
    """
    num_neighbors = 10

    # Resolve the index to search. An unset "Search To:" radio arrives as a
    # value that matches neither branch; report it instead of searching an
    # empty placeholder index.
    if mod2 == "Image":
        index = chooseImageIndex(indexType)
    elif mod2 == "DNA":
        index = chooseDNAIndex(indexType)
    else:
        raise gr.Error('Select "DNA" or "Image" under "Search To:".')
    if index is None:
        # The chooser returns None for labels it does not know.
        raise gr.Error(f"Unknown index type: {indexType!r}")

    # Resolve the embedding table the query is read from.
    if mod1 == "Image":
        embeddings = id_to_image_emb_dict
    elif mod1 == "DNA":
        embeddings = id_to_dna_emb_dict
    else:
        raise gr.Error('Select "DNA" or "Image" under "Search From:".')

    try:
        query = embeddings[id]
    except KeyError:
        raise gr.Error(f"No {mod1} embedding found for ID {id!r}.") from None

    # FAISS searches a 2-D float32 array of shape (n_queries, dim);
    # atleast_2d leaves an embedding that is already stored as a row untouched.
    query = np.atleast_2d(np.asarray(query, dtype=np.float32))
    _, neighbor_rows = index.search(query, num_neighbors)

    # FAISS pads the label row with -1 when it finds fewer than k neighbours;
    # those are not real row positions and must not be looked up.
    return [indx_to_id_dict[row] for row in neighbor_rows[0] if row >= 0]
71
+
72
with gr.Blocks() as demo:

    # for hf: change all file paths, indx_to_id_dict as well

    # NOTE(review): the scraped source lost its indentation; the nesting of
    # the statements below is reconstructed -- confirm against the original.

    # Image-side FAISS indexes, one per selectable index type.
    image_index_IP = faiss.read_index("big_image_index_FlatIP.index")
    image_index_L2 = faiss.read_index("big_image_index_FlatL2.index")
    image_index_HNSW = faiss.read_index("big_image_index_HNSWFlat.index")
    image_index_IVF = faiss.read_index("big_image_index_IVFFlat.index")
    image_index_LSH = faiss.read_index("big_image_index_LSH.index")

    # DNA-side FAISS indexes, one per selectable index type.
    dna_index_IP = faiss.read_index("big_dna_index_FlatIP.index")
    dna_index_L2 = faiss.read_index("big_dna_index_FlatL2.index")
    dna_index_HNSW = faiss.read_index("big_dna_index_HNSWFlat.index")
    dna_index_IVF = faiss.read_index("big_dna_index_IVFFlat.index")
    dna_index_LSH = faiss.read_index("big_dna_index_LSH.index")

    # Pickled lookup tables. The first two are loaded but not referenced by
    # any function in this file.
    with open("dataset_processid_list.pickle", "rb") as f:
        dataset_processid_list = pickle.load(f)
    with open("processid_to_index.pickle", "rb") as f:
        processid_to_index = pickle.load(f)
    # Row position -> sample ID, used by getRandID and searchEmbeddings.
    with open("big_indx_to_id_dict.pickle", "rb") as f:
        indx_to_id_dict = pickle.load(f)

    # Sample ID -> embedding, one table per query modality.
    with open("big_id_to_image_emb_dict.pickle", "rb") as f:
        id_to_image_emb_dict = pickle.load(f)
    with open("big_id_to_dna_emb_dict.pickle", "rb") as f:
        id_to_dna_emb_dict = pickle.load(f)

    with gr.Column():
        with gr.Row():
            # Left column: random-ID helper.
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            # Right column: query and target modality selectors.
            with gr.Column():
                mod1 = gr.Radio(choices=["DNA", "Image"], label="Search From:")
                mod2 = gr.Radio(choices=["DNA", "Image"], label="Search To:")

        indexType = gr.Radio(
            choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"],
            label="Index:",
            value="FlatIP(default)",
        )
        process_id = gr.Textbox(
            label="ID:",
            info="Enter a sample ID to search for",
        )
        process_id_list = gr.Textbox(label="Closest 10 matches:")
        search_btn = gr.Button("Search")

        # Wire the buttons to their handlers.
        id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])

        search_btn.click(
            fn=searchEmbeddings,
            inputs=[process_id, mod1, mod2, indexType],
            outputs=[process_id_list],
        )


demo.launch()