davidberenstein1957 HF staff committed on
Commit
ce0d80f
1 Parent(s): bcbb85b

feat: set casting to str by default

Browse files
Files changed (1) hide show
  1. app.py +15 -5
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import duckdb
2
  import gradio as gr
3
  import polars as pl
@@ -56,18 +58,24 @@ def vectorize_dataset(split: str, column: str):
56
  global df
57
  global ds
58
  df = ds[split].to_polars()
59
- embeddings = model.encode(df[column], max_length=512)
60
- df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
61
 
62
 
63
- def run_query(query: str):
64
  global df
 
 
 
 
 
 
65
  vector = model.encode(query)
66
  df_results = duckdb.sql(
67
  query=f"""
68
  SELECT *
69
  FROM df
70
- ORDER BY array_cosine_distance(embeddings, {vector.tolist()}::FLOAT[256])
71
  LIMIT 5
72
  """
73
  ).to_df()
@@ -152,6 +160,8 @@ with gr.Blocks() as demo:
152
  fn=show_components, outputs=[query_input, btn_run]
153
  )
154
 
155
- btn_run.click(fn=run_query, inputs=query_input, outputs=results_output)
 
 
156
 
157
  demo.launch()
 
1
+ import time
2
+
3
  import duckdb
4
  import gradio as gr
5
  import polars as pl
 
58
  global df
59
  global ds
60
  df = ds[split].to_polars()
61
+ embeddings = model.encode(df[column].cast(str), max_length=512)
62
+ df = df.with_columns(pl.Series(embeddings).alias(f"{column}_embeddings"))
63
 
64
 
65
+ def run_query(query: str, column: str):
66
  global df
67
+ while f"{column}_embeddings" not in df.columns:
68
+ sleeper = 5
69
+ gr.Info(
70
+ f"Waiting for vectorization to complete... ({sleeper}s)", duration=sleeper
71
+ )
72
+ time.sleep(sleeper)
73
  vector = model.encode(query)
74
  df_results = duckdb.sql(
75
  query=f"""
76
  SELECT *
77
  FROM df
78
+ ORDER BY array_cosine_distance({column}_embeddings, {vector.tolist()}::FLOAT[256])
79
  LIMIT 5
80
  """
81
  ).to_df()
 
160
  fn=show_components, outputs=[query_input, btn_run]
161
  )
162
 
163
+ btn_run.click(
164
+ fn=run_query, inputs=[query_input, column_dropdown], outputs=results_output
165
+ )
166
 
167
  demo.launch()