awacke1 committed on
Commit
7fde12b
1 Parent(s): 8d9fd8c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -0
app.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import lru_cache
2
+
3
+ import duckdb
4
+ import gradio as gr
5
+ import polars as pl
6
+ from datasets import load_dataset
7
+ from gradio_huggingfacehub_search import HuggingfaceHubSearch
8
+ from model2vec import StaticModel
9
+
10
# Static embedding model shared by dataset vectorization and query encoding.
# It is loaded once, at import time, from the Hugging Face Hub
# (in this case the potion-base-8M model).
model_name = "minishlab/potion-base-8M"
model = StaticModel.from_pretrained(model_name)
15
+
16
+
17
+ def get_iframe(hub_repo_id):
18
+ if not hub_repo_id:
19
+ raise ValueError("Hub repo id is required")
20
+ url = f"https://huggingface.co/datasets/{hub_repo_id}/embed/viewer"
21
+ iframe = f"""
22
+ <iframe
23
+ src="{url}"
24
+ frameborder="0"
25
+ width="100%"
26
+ height="600px"
27
+ ></iframe>
28
+ """
29
+ return iframe
30
+
31
+
32
def load_dataset_from_hub(hub_repo_id: str):
    """Show a "loading" toast, then load the dataset for ``hub_repo_id``.

    The loaded dataset is discarded and nothing is returned; presumably the
    call exists so later loads of the same repo are served from the local
    cache (verify against the ``datasets`` caching behaviour).

    Args:
        hub_repo_id: Dataset repository id on the Hugging Face Hub.
    """
    gr.Info(message="Loading dataset...")
    load_dataset(hub_repo_id)
35
+
36
+
37
def get_columns(hub_repo_id: str, split: str):
    """Return a visible column dropdown for one split of a Hub dataset.

    Loads the dataset, selects ``split``, and lists that split's column
    names as the dropdown choices with the first column preselected.

    Args:
        hub_repo_id: Dataset repository id on the Hugging Face Hub.
        split: Name of the split whose columns should be offered.

    Returns:
        A ``gr.Dropdown`` update labelled "Select a column".
    """
    column_names = load_dataset(hub_repo_id)[split].column_names
    default_column = column_names[0]
    return gr.Dropdown(
        label="Select a column",
        choices=column_names,
        value=default_column,
        visible=True,
    )
46
+
47
+
48
def get_splits(hub_repo_id: str):
    """Return a visible split dropdown for a Hub dataset.

    Loads the dataset and offers its keys (the split names) as choices, with
    the first one preselected.

    Args:
        hub_repo_id: Dataset repository id on the Hugging Face Hub.

    Returns:
        A ``gr.Dropdown`` update labelled "Select a split".
    """
    dataset = load_dataset(hub_repo_id)
    split_names = [*dataset.keys()]
    first_split = split_names[0]
    return gr.Dropdown(
        label="Select a split",
        choices=split_names,
        value=first_split,
        visible=True,
    )
54
+
55
+
56
@lru_cache
def vectorize_dataset(hub_repo_id: str, split: str, column: str):
    """Embed every value of one column of a dataset split.

    Results are memoised per ``(hub_repo_id, split, column)`` by
    ``lru_cache``, so the toast and the encoding only run on a cache miss.

    Args:
        hub_repo_id: Dataset repository id on the Hugging Face Hub.
        split: Name of the split to embed.
        column: Name of the column whose values are cast to strings and
            encoded.

    Returns:
        Whatever ``model.encode`` returns for the column's values.
    """
    gr.Info("Vectorizing dataset...")
    frame = load_dataset(hub_repo_id)[split].to_polars()
    texts = frame[column].cast(str)
    return model.encode(texts, max_length=512)
63
+
64
+
65
def run_query(hub_repo_id: str, query: str, split: str, column: str):
    """Return the five rows of a dataset split closest to ``query``.

    The chosen column is embedded via ``vectorize_dataset``, the embeddings
    are attached to the split as an ``embeddings`` column, and DuckDB orders
    the rows by cosine distance to the encoded query.

    Args:
        hub_repo_id: Dataset repository id on the Hugging Face Hub.
        query: Free-text search query to encode.
        split: Name of the split to search.
        column: Name of the column whose embeddings are compared.

    Returns:
        A visible ``gr.Dataframe`` holding the top five rows (all columns,
        including ``embeddings``), smallest cosine distance first.

    Raises:
        gr.Error: If encoding the query or running the SQL fails; the
            original exception is chained as the cause.
    """
    embeddings = vectorize_dataset(hub_repo_id, split, column)
    ds = load_dataset(hub_repo_id)
    # NOTE: this local must stay named ``df``; the SQL below refers to it by
    # name and DuckDB looks the table up among the Python variables in scope.
    df = ds[split].to_polars()
    df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
    try:
        vector = model.encode(query)
        # Take the array width from the query vector instead of hard-coding
        # 256, so the cast keeps matching if ``model`` is swapped for one
        # with a different embedding size.
        dim = len(vector)
        df_results = duckdb.sql(
            query=f"""
            SELECT *
            FROM df
            ORDER BY array_cosine_distance(embeddings, {vector.tolist()}::FLOAT[{dim}])
            LIMIT 5
            """
        ).to_df()
        return gr.Dataframe(df_results, visible=True)
    except Exception as e:
        # Surface the failure in the UI while keeping the traceback chain.
        raise gr.Error(f"Error running query: {e}") from e
83
+
84
+
85
def hide_components():
    """Return updates that hide all five dependent widgets.

    The order is: two dropdowns, a textbox, a button, a dataframe. It must
    line up with the ``outputs`` list of the event that uses this function.
    """
    hidden_updates = [
        gr.Dropdown(visible=False),
        gr.Dropdown(visible=False),
        gr.Textbox(visible=False),
        gr.Button(visible=False),
        gr.Dataframe(visible=False),
    ]
    return hidden_updates
93
+
94
+
95
def partial_hide_components():
    """Return updates that hide a textbox, a button and a dataframe.

    The order must line up with the ``outputs`` list of the event that uses
    this function.
    """
    hidden_updates = [
        gr.Textbox(visible=False),
        gr.Button(visible=False),
        gr.Dataframe(visible=False),
    ]
    return hidden_updates
101
+
102
+
103
def show_components():
    """Return updates that reveal the "Query" textbox and "Search" button."""
    query_box = gr.Textbox(visible=True, label="Query")
    search_button = gr.Button(visible=True, value="Search")
    return [query_box, search_button]
108
+
109
+
110
# UI definition: a Hub dataset search box, an embedded dataset viewer, split
# and column pickers, a query box with a search button, and a results table.
# Everything except the search box and the viewer starts hidden and is
# revealed step by step by the event handlers below.
with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1>Vector Search any Hugging Face Dataset</h1>
        <p>
            This app allows you to vector search any Hugging Face dataset.
            You can search for the nearest neighbors of a query vector, or
            perform a similarity search on a dataframe.
        </p>
        """
    )
    with gr.Row():
        with gr.Column():
            search_in = HuggingfaceHubSearch(
                label="Search Huggingface Hub",
                # The component searches datasets (see search_type), so the
                # placeholder says "datasets" rather than "models".
                placeholder="Search for datasets on Huggingface",
                search_type="dataset",
                # Spelling kept as in the original call; presumably it matches
                # the component's keyword name - verify before renaming.
                sumbit_on_select=True,
            )
    with gr.Row():
        search_out = gr.HTML(label="Search Results")

    with gr.Row():
        split_dropdown = gr.Dropdown(label="Select a split", visible=False)
        column_dropdown = gr.Dropdown(label="Select a column", visible=False)
    with gr.Row():
        query_input = gr.Textbox(label="Query", visible=False)

    btn_run = gr.Button("Search", visible=False)

    results_output = gr.Dataframe(label="Results", visible=False)

    # Choosing a dataset: show its viewer, load it, hide every dependent
    # widget, then fill the split dropdown and the column dropdown.
    search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
        fn=load_dataset_from_hub,
        inputs=search_in,
        show_progress=True,
    ).then(
        fn=hide_components,
        outputs=[split_dropdown, column_dropdown, query_input, btn_run, results_output],
    ).then(fn=get_splits, inputs=search_in, outputs=split_dropdown).then(
        fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
    )

    # Changing the split refreshes the list of columns.
    split_dropdown.change(
        fn=get_columns, inputs=[search_in, split_dropdown], outputs=column_dropdown
    )

    # Changing the column hides the query widgets and any old results, then
    # shows the query box and search button again.
    column_dropdown.change(
        fn=partial_hide_components,
        outputs=[query_input, btn_run, results_output],
    ).then(fn=show_components, outputs=[query_input, btn_run])

    # Run the vector search and display the result table.
    btn_run.click(
        fn=run_query,
        inputs=[search_in, query_input, split_dropdown, column_dropdown],
        outputs=results_output,
    )

demo.launch()