julien-c's picture
julien-c HF staff
Hook full-text search
a00452d verified
raw
history blame
2.57 kB
# Inspired by https://huggingface.co/spaces/asoria/duckdb-parquet-demo
import gradio as gr
import duckdb
import pandas as pd
import requests
DATASETS_SERVER_ENDPOINT = "https://datasets-server.huggingface.co"
PARQUET_REVISION="refs/convert/parquet"
EXAMPLE_DATASET_NAME = "LLMs/Alpaca-ShareGPT"
def get_parquet_urls(dataset: str) -> list[str]:
    """Return the parquet file URLs for the first split of *dataset*.

    Queries the datasets-server API for the dataset's splits, then fetches
    the parquet file listing for the first split's config and keeps only the
    files belonging to that split.

    Raises:
        requests.HTTPError: if either API call returns an error status.
        RuntimeError: if the dataset has no splits or no parquet files.
    """
    splits_response = requests.get(
        f"{DATASETS_SERVER_ENDPOINT}/splits?dataset={dataset}", timeout=60
    )
    # Fail early with a clear HTTP error instead of crashing on None below.
    splits_response.raise_for_status()
    splits = splits_response.json().get("splits")
    if not splits:
        raise RuntimeError(f"No splits found for dataset {dataset!r}")
    split = splits[0]
    parquet_response = requests.get(
        f"{DATASETS_SERVER_ENDPOINT}/parquet?dataset={dataset}&config={split['config']}",
        timeout=60,
    )
    parquet_response.raise_for_status()
    parquet_files = parquet_response.json()["parquet_files"]
    urls = [content["url"] for content in parquet_files if content["split"] == split["split"]]
    if not urls:
        raise RuntimeError(f"No parquet files found for dataset {dataset!r}")
    return urls
def run_command(query: str) -> pd.DataFrame:
    """Run a BM25 full-text search for *query* against the `data` table.

    Returns a DataFrame of matching rows ranked by relevance, or a
    one-column error DataFrame if the query fails (e.g. the FTS index
    has not been created yet).
    """
    try:
        # ORDER BY score DESC: a higher BM25 score means a more relevant
        # match, so the best results must come first (ascending order —
        # the previous behavior — showed the worst matches on top).
        result = duckdb.execute(
            "SELECT fts_main_data.match_bm25(id, ?) AS score, id, instruction, input, output "
            "FROM data WHERE score IS NOT NULL ORDER BY score DESC;",
            [query],
        )
        print("Ok")
    except Exception as error:
        print(f"Error: {str(error)}")
        return pd.DataFrame({"Error": [f"❌ {str(error)}"]})
    print(result)
    return result.df()
def import_data():
    """Load the example dataset's first parquet file into a `data` table
    and build a DuckDB full-text-search index over it.

    Must run before `run_command` is called, since the search queries
    the `data` table and its FTS index.
    """
    parquet_url = get_parquet_urls(EXAMPLE_DATASET_NAME)[0]
    print("parquet_url", parquet_url)
    # Full-text search needs a unique document id column, so generate one
    # from a sequence while importing the parquet file.
    # I'm very rusty in SQL so it's very possible there are simpler ways.
    duckdb.sql("CREATE SEQUENCE serial START 1;")
    duckdb.sql(f"CREATE TABLE data AS SELECT nextval('serial') AS id, * FROM '{parquet_url}';")
    # Index every column ('*'), keyed on the generated id.
    duckdb.sql("PRAGMA create_fts_index('data', 'id', '*');")
    duckdb.sql("DESCRIBE SELECT * FROM data").show()
# Gradio UI: a query box wired to run_command, with results in a DataFrame.
with gr.Blocks() as demo:
    gr.Markdown(" ## Full-text search using DuckDB on top of datasets-server Parquet files 🐀")
    # NOTE(review): this selector is display-only for now — it is not wired
    # into run_command, which always searches the pre-loaded example dataset.
    gr.CheckboxGroup(
        label="Dataset",
        choices=["LLMs/Alpaca-ShareGPT"],
        # CheckboxGroup is multi-select: its value is a list of selections.
        value=["LLMs/Alpaca-ShareGPT"],
        info="Dataset to query",
    )  # stray trailing comma removed: it silently made this line a 1-tuple
    query = gr.Textbox(label="query", placeholder="Full-text search...")
    run_button = gr.Button("Run")
    gr.Markdown("### Result")
    cached_responses_table = gr.DataFrame()
    run_button.click(run_command, inputs=[query], outputs=cached_responses_table)
if __name__ == "__main__":
    # Load + index the dataset once at startup, then serve the Gradio UI.
    import_data()
    demo.launch()