import gradio as gr
import torch
import numpy as np
import h5py
import faiss
from PIL import Image
import io
import pickle
import random

# Number of rows in the FAISS indexes / keys in indx_to_id_dict.
# NOTE(review): hard-coded; presumably equals len(indx_to_id_dict) — confirm.
NUM_SAMPLES = 396503
# How many nearest neighbours a search returns.
NUM_NEIGHBORS = 10
# Modalities offered by the "Search From" / "Search To" radios.
MODALITIES = ("DNA", "Image")


def getRandID():
    """Pick a uniformly random sample.

    Returns:
        (sample_id, index): the ID stored for a random FAISS row index in
        ``indx_to_id_dict``, and that row index itself.
    """
    indx = random.randrange(0, NUM_SAMPLES)
    return indx_to_id_dict[indx], indx


def chooseImageIndex(indexType):
    """Return the preloaded image FAISS index for the UI label *indexType*.

    Returns None when the label is not one of the known index types.
    """
    return {
        "FlatIP(default)": image_index_IP,
        "FlatL2": image_index_L2,
        "HNSWFlat": image_index_HNSW,
        "IVFFlat": image_index_IVF,
        "LSH": image_index_LSH,
    }.get(indexType)


def chooseDNAIndex(indexType):
    """Return the preloaded DNA FAISS index for the UI label *indexType*.

    Returns None when the label is not one of the known index types.
    """
    return {
        "FlatIP(default)": dna_index_IP,
        "FlatL2": dna_index_L2,
        "HNSWFlat": dna_index_HNSW,
        "IVFFlat": dna_index_IVF,
        "LSH": dna_index_LSH,
    }.get(indexType)


def searchEmbeddings(id, mod1, mod2, indexType):
    """Find the nearest neighbours of one sample, possibly across modalities.

    Args:
        id: sample ID entered by the user; used as the key into the ``mod1``
            embedding dict (surrounding whitespace is ignored).
        mod1: "DNA" or "Image" — which embedding of the sample is the query.
        mod2: "DNA" or "Image" — which family of FAISS indexes is searched.
        indexType: UI label selecting the FAISS index variant.

    Returns:
        List of up to NUM_NEIGHBORS sample IDs in the order FAISS ranks them.

    Raises:
        gr.Error: when a modality is not selected, the index type is unknown,
            or the ID is not in the embedding dict — Gradio shows the message
            in the UI instead of failing with a stack trace.
    """
    if mod1 not in MODALITIES or mod2 not in MODALITIES:
        raise gr.Error("Please select both 'Search From' and 'Search To'.")

    index = chooseImageIndex(indexType) if mod2 == "Image" else chooseDNAIndex(indexType)
    if index is None:
        raise gr.Error(f"Unknown index type: {indexType!r}")

    emb_dict = id_to_image_emb_dict if mod1 == "Image" else id_to_dna_emb_dict
    key = (id or "").strip()
    try:
        query = emb_dict[key]
    except KeyError:
        raise gr.Error(f"Unknown sample ID: {key!r}") from None

    # faiss.Index.search expects a C-contiguous float32 matrix of shape
    # (n_queries, dim); atleast_2d accepts an embedding stored as (dim,) or
    # (1, dim) — NOTE(review): stored shape not visible here, confirm.
    query = np.ascontiguousarray(np.atleast_2d(query), dtype=np.float32)
    _, labels = index.search(query, NUM_NEIGHBORS)

    # FAISS pads with -1 when it finds fewer than k neighbours (possible with
    # IVF/HNSW); -1 has no entry in indx_to_id_dict, so skip it.
    return [indx_to_id_dict[int(label)] for label in labels[0] if label >= 0]


def _load_pickle(path):
    """Load and return the single pickled object stored at *path*.

    pickle can execute arbitrary code while loading: only use this with the
    trusted data files shipped alongside the app.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


# for hf: change all file paths, indx_to_id_dict as well
# load indexes
image_index_IP = faiss.read_index("big_image_index_FlatIP.index")
image_index_L2 = faiss.read_index("big_image_index_FlatL2.index")
image_index_HNSW = faiss.read_index("big_image_index_HNSWFlat.index")
image_index_IVF = faiss.read_index("big_image_index_IVFFlat.index")
image_index_LSH = faiss.read_index("big_image_index_LSH.index")
dna_index_IP = faiss.read_index("big_dna_index_FlatIP.index")
dna_index_L2 = faiss.read_index("big_dna_index_FlatL2.index")
dna_index_HNSW = faiss.read_index("big_dna_index_HNSWFlat.index")
dna_index_IVF = faiss.read_index("big_dna_index_IVFFlat.index")
dna_index_LSH = faiss.read_index("big_dna_index_LSH.index")

dataset_processid_list = _load_pickle("dataset_processid_list.pickle")
processid_to_index = _load_pickle("processid_to_index.pickle")
indx_to_id_dict = _load_pickle("big_indx_to_id_dict.pickle")

# initialize both possible dicts
id_to_image_emb_dict = _load_pickle("big_id_to_image_emb_dict.pickle")
id_to_dna_emb_dict = _load_pickle("big_id_to_dna_emb_dict.pickle")

with gr.Blocks() as demo:
    # NOTE(review): the original indentation was lost; this nesting is a
    # reconstruction (two side-by-side columns above the search box) — confirm.
    with gr.Column():
        with gr.Row():
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            with gr.Column():
                mod1 = gr.Radio(choices=["DNA", "Image"], label="Search From:")
                mod2 = gr.Radio(choices=["DNA", "Image"], label="Search To:")
                indexType = gr.Radio(
                    choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"],
                    label="Index:",
                    value="FlatIP(default)",
                )
        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label=f"Closest {NUM_NEIGHBORS} matches:")
        search_btn = gr.Button("Search")

    # Event listeners must be attached inside the Blocks context.
    id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
    search_btn.click(
        fn=searchEmbeddings,
        inputs=[process_id, mod1, mod2, indexType],
        outputs=[process_id_list],
    )

demo.launch()