File size: 3,277 Bytes
db66d62
 
07356cd
 
 
 
 
db66d62
b383c02
db66d62
1f788b3
 
db66d62
b383c02
1f788b3
07356cd
 
 
 
db66d62
b383c02
1f788b3
07356cd
 
 
 
db66d62
 
1f788b3
 
 
db66d62
 
51e3825
1f788b3
51e3825
1f788b3
 
 
 
 
db66d62
 
51e3825
1f788b3
51e3825
1f788b3
 
 
db66d62
1f788b3
 
 
db66d62
1f788b3
 
b383c02
1f788b3
db66d62
b383c02
db66d62
 
 
 
 
07356cd
 
 
 
 
 
 
1f788b3
db66d62
 
 
 
 
 
 
 
07356cd
 
db66d62
73ff1bb
b383c02
 
1f788b3
 
db66d62
07356cd
b383c02
db66d62
 
1f788b3
 
 
 
 
b383c02
db66d62
b383c02
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import pickle
import random

import gradio as gr
import numpy as np

from data import load_indexes_local, load_indexes_hf, load_index_pickle


def getRandID():
    """Pick a random row of the loaded index and return its sample ID.

    Returns:
        A ``(sample_id, row)`` tuple. ``row`` is drawn uniformly from
        ``range(len(index_to_id_dict))`` and ``sample_id`` is
        ``index_to_id_dict[row]``.
    """
    # NOTE(review): assumes the pickle's keys are exactly the rows
    # 0..len-1 — TODO confirm, otherwise some rows can raise KeyError.
    row_count = len(index_to_id_dict)
    row = random.randrange(0, row_count)
    sample_id = index_to_id_dict[row]
    return sample_id, row


def get_image_index(indexType):
    """Return the loaded image index registered under ``indexType``.

    Args:
        indexType: Index-type name as offered in the UI, e.g. "FlatIP(default)".

    Returns:
        The object stored in ``image_indexes`` under that name.

    Raises:
        KeyError: If no image index was loaded under that name. The message
            lists the names that are available, and the original KeyError is
            kept as ``__cause__``.
    """
    try:
        return image_indexes[indexType]
    except KeyError as err:
        available = ", ".join(map(str, image_indexes))
        raise KeyError(
            f"Tried to load an image index that is not supported: {indexType} "
            f"(available: {available})"
        ) from err


def get_dna_index(indexType):
    """Return the loaded DNA index registered under ``indexType``.

    Args:
        indexType: Index-type name as offered in the UI, e.g. "FlatIP(default)".

    Returns:
        The object stored in ``dna_indexes`` under that name.

    Raises:
        KeyError: If no DNA index was loaded under that name. The message
            lists the names that are available, and the original KeyError is
            kept as ``__cause__``.
    """
    try:
        return dna_indexes[indexType]
    except KeyError as err:
        available = ", ".join(map(str, dna_indexes))
        raise KeyError(
            f"Tried to load a DNA index that is not supported: {indexType} "
            f"(available: {available})"
        ) from err


def searchEmbeddings(id, key_type, query_type, index_type, num_results: int = 10):
    """Return the sample IDs whose embeddings are nearest to the one for ``id``.

    The query vector is reconstructed from the image or DNA index (chosen by
    ``query_type``). It is then searched against the image or DNA index
    (chosen by ``key_type``), both taken from the index set named
    ``index_type``.

    Args:
        id: Sample ID to query with; looked up in ``id_to_index_dict``.
            Surrounding whitespace is tolerated.
        key_type: "Image" or "DNA"; which index is searched.
        query_type: "Image" or "DNA"; which index the query is reconstructed from.
        index_type: Name of the loaded index set, e.g. "FlatIP(default)".
        num_results: Maximum number of neighbours to return; must be >= 1.

    Returns:
        List of neighbour sample IDs, in the order returned by ``index.search``.
        It can be shorter than ``num_results`` when the index returns fewer
        valid hits.

    Raises:
        gr.Error: If ``id`` is unknown or ``num_results`` is not a positive
            integer (shown to the user in the Gradio UI).
        ValueError: If ``query_type`` or ``key_type`` is not "Image" or "DNA".
        KeyError: If ``index_type`` is not a loaded index set.
    """
    image_index = get_image_index(index_type)
    dna_index = get_dna_index(index_type)

    # Index the query embedding is reconstructed from.
    if query_type == "Image":
        query_index = image_index
    elif query_type == "DNA":
        query_index = dna_index
    else:
        raise ValueError(f"Invalid query type: {query_type}")

    # Index that is searched for neighbours.
    if key_type == "Image":
        search_index = image_index
    elif key_type == "DNA":
        search_index = dna_index
    else:
        raise ValueError(f"Invalid key type: {key_type}")

    # The ID comes from a free-text box. Retry once with whitespace stripped,
    # and report unknown IDs in the UI instead of failing with a bare KeyError.
    sample_id = id
    if sample_id not in id_to_index_dict and isinstance(sample_id, str):
        sample_id = sample_id.strip()
    if sample_id not in id_to_index_dict:
        raise gr.Error(f"Unknown sample ID: {sample_id!r}")

    # gr.Number may hand over None (cleared box) or a non-integral value,
    # and the search needs a positive integer k.
    bad_k_msg = f"Number of results must be a positive integer, got {num_results!r}"
    try:
        k = int(num_results)
    except (TypeError, ValueError, OverflowError):
        raise gr.Error(bad_k_msg) from None
    if k < 1:
        raise gr.Error(bad_k_msg)

    query = query_index.reconstruct(id_to_index_dict[sample_id])
    query = np.expand_dims(query.astype(np.float32), axis=0)  # batch of one query

    _, neighbours = search_index.search(query, k)

    # FAISS labels unfilled result slots with -1 (e.g. when k exceeds the
    # number of stored vectors); those are not rows and must be skipped.
    return [index_to_id_dict[row] for row in neighbours[0] if row >= 0]


# ---------------------------------------------------------------------------
# Data loading. These module-level names are read as globals by getRandID,
# get_image_index, get_dna_index and searchEmbeddings.
# NOTE: when switching between local files and the Hugging Face Hub, update
# every file path below, including the index-to-ID pickle.
# ---------------------------------------------------------------------------
REPO_NAME = "bioscan-ml/bioscan-clibd"

# Index-type name (as shown in the UI) -> index file name inside REPO_NAME.
IMAGE_INDEX_FILES = {"FlatIP(default)": "bioscan_5m_image_IndexFlatIP.index"}
DNA_INDEX_FILES = {"FlatIP(default)": "bioscan_5m_dna_IndexFlatIP.index"}

image_indexes = load_indexes_hf(IMAGE_INDEX_FILES, repo_name=REPO_NAME)
dna_indexes = load_indexes_hf(DNA_INDEX_FILES, repo_name=REPO_NAME)
index_to_id_dict = load_index_pickle("big_indx_to_id_dict.pickle", repo_name=REPO_NAME)
# Reverse lookup: sample ID -> row in the index.
id_to_index_dict = {v: k for k, v in index_to_id_dict.items()}

# Only offer index types that were loaded for BOTH modalities. searchEmbeddings
# fetches both, so any other choice would fail with a KeyError.
INDEX_TYPES = [name for name in IMAGE_INDEX_FILES if name in DNA_INDEX_FILES]


# ---------------------------------------------------------------------------
# User interface.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            with gr.Column():
                query_type = gr.Radio(choices=["Image", "DNA"], label="Query:", value="Image")
                key_type = gr.Radio(choices=["Image", "DNA"], label="Key:", value="Image")

        index_type = gr.Radio(choices=INDEX_TYPES, label="Index:", value="FlatIP(default)")
        num_results = gr.Number(label="Number of Results:", value=10, precision=0)

        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label="Closest matches:")
        search_btn = gr.Button("Search")

    # Event wiring (must stay inside the Blocks context).
    id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
    search_btn.click(
        fn=searchEmbeddings,
        inputs=[process_id, key_type, query_type, index_type, num_results],
        outputs=[process_id_list],
    )


demo.launch()