|
import html
import json
import logging
import os
import urllib.parse
from typing import Any

import gradio as gr
import requests
from gradio_huggingfacehub_search import HuggingfaceHubSearch
|
|
|
|
|
logger = logging.getLogger(__name__)

# NOTE(review): `example` is not referenced anywhere else in this module; it
# instantiates a search component at import time only to read its example
# value. Candidate for removal.
example = HuggingfaceHubSearch().example_value()


# Markdown rendered at the top of the page.
HEADER_CONTENT = (
    "# 🤗 Dataset DuckDB Query Chatbot\n\n"
    "This is a basic text-to-SQL tool that allows you to query datasets on the Hugging Face Hub. "
    "It's a fork of "
    "[davidberenstein1957/text-to-sql-hub-datasets](https://huggingface.co/spaces/davidberenstein1957/text-to-sql-hub-datasets) "
    "that adds chat capability and table name generation."
)

# Markdown rendered inside the collapsed "About/Help" accordion.
ABOUT_CONTENT = """
This space uses [Llama 3.1 70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct)
via [together.ai](https://together.ai).
It also uses the
[dataset-server API](https://redocly.github.io/redoc/?url=https://datasets-server.huggingface.co/openapi.json#operation/isValidDataset).

Query history is saved and given to the chat model so you can chat to refine your query as you go.

When the DuckDB modal is presented, you may need to click on the name of the
config/split at the bottom of the modal to get the table loaded for DuckDB's use.

Search for and select a dataset to begin.
"""

# System prompt for the text-to-SQL model. `get_system_prompt` fills in
# `{table_name}` and `{features}` for the selected config.
SYSTEM_PROMPT_TEMPLATE = (
    "You are a SQL query expert assistant that returns a DuckDB SQL query "
    "based on the user's natural language query and dataset features. "
    "You might need to use DuckDB functions for lists and aggregations, "
    "given the features. Only return the SQL query, no other text. The "
    "user may ask you to make various adjustments to the query. Each "
    "time, your response should only include the refined SQL query and "
    "nothing else.\n\n"
    "The table being queried is named: {table_name}.\n\n"
    "# Features\n"
    "{features}"
)
|
|
|
|
|
def get_iframe(hub_repo_id, sql_query=None):
    """Return HTML that embeds the Hub dataset viewer for *hub_repo_id*.

    Args:
        hub_repo_id: Dataset repo id, e.g. ``"user/name"``.
        sql_query: Optional SQL. When non-empty it is passed to the viewer as
            ``sql_console=true&sql=<percent-encoded query>``.

    Returns:
        An ``<iframe>`` HTML snippet.

    Raises:
        ValueError: If *hub_repo_id* is empty.
    """
    if not hub_repo_id:
        raise ValueError("Hub repo id is required")

    # The repo id comes from free-text search input and the SQL from an LLM:
    # percent-encode both for the URL, then HTML-escape the URL so neither can
    # break out of the quoted `src` attribute. `safe="/"` keeps "user/name".
    base_url = (
        "https://huggingface.co/datasets/"
        f"{urllib.parse.quote(hub_repo_id, safe='/')}/embed/viewer"
    )
    if sql_query:
        url = f"{base_url}?sql_console=true&sql={urllib.parse.quote(sql_query)}"
    else:
        url = base_url

    return f"""
<iframe
    src="{html.escape(url, quote=True)}"
    frameborder="0"
    width="100%"
    height="800px"
></iframe>
"""
|
|
|
|
|
def get_table_info(hub_repo_id):
    """Fetch the ``dataset_info`` of *hub_repo_id* from the datasets-server ``/info`` API.

    Returns:
        The ``dataset_info`` object serialized as a JSON string. Returns ``""``
        when it cannot be fetched or is missing, and shows a warning toast
        instead. The render function and ``query_dataset`` both treat an empty
        string as "no dataset info".
    """
    try:
        # `params=` URL-encodes the repo id; the timeout keeps a hung request
        # from blocking the UI event indefinitely.
        response = requests.get(
            "https://datasets-server.huggingface.co/info",
            params={"dataset": hub_repo_id},
            timeout=30,
        )
        response.raise_for_status()
        data = response.json().get("dataset_info")
    except (requests.RequestException, ValueError, AttributeError) as e:
        # gr.Error only takes effect when raised, and raising would leave the
        # previous dataset's card data in place. Warn and clear it instead.
        gr.Warning(f"Error getting column info: {e}")
        return ""

    if not data:
        gr.Warning(f"No dataset info available for {hub_repo_id!r}.")
        return ""
    return json.dumps(data)
|
|
|
|
|
def get_table_name( |
|
config: str | None, |
|
split: str | None, |
|
config_choices: list[str], |
|
split_choices: list[str], |
|
): |
|
if len(config_choices) > 0 and config is None: |
|
config = config_choices[0] |
|
if len(split_choices) > 0 and split is None: |
|
split = split_choices[0] |
|
|
|
if len(config_choices) > 1 and len(split_choices) > 1: |
|
base_name = f"{config}_{split}" |
|
elif len(config_choices) >= 1 and len(split_choices) <= 1: |
|
base_name = config |
|
else: |
|
base_name = split |
|
|
|
def replace_char(c): |
|
if c.isalnum(): |
|
return c |
|
if c in ["-", "_", "/"]: |
|
return "_" |
|
return "" |
|
|
|
table_name = "".join(replace_char(c) for c in base_name) |
|
if table_name[0].isdigit(): |
|
table_name = f"_{table_name}" |
|
return table_name.lower() |
|
|
|
|
|
def get_system_prompt(
    card_data: dict[str, Any],
    config: str | None,
    split: str | None,
):
    """Build the system prompt for the text-to-SQL model.

    Args:
        card_data: Parsed ``dataset_info`` mapping config name -> config info.
            Each config's info must contain a ``"features"`` entry.
        config: Selected config name. When unset, the first config is used,
            which is the same default ``get_table_name`` applies.
        split: Selected split name. It is only used to derive the table name.

    Returns:
        ``SYSTEM_PROMPT_TEMPLATE`` filled with the table name and features.

    Raises:
        ValueError: If the resolved config is not present in *card_data*.
    """
    config_choices = get_config_choices(card_data)
    split_choices = get_split_choices(card_data)

    table_name = get_table_name(config, split, config_choices, split_choices)

    # Resolve the config the same way the table name was resolved, so the
    # features describe the table the prompt names (and `None` never reaches
    # the dict lookup below).
    if not config and config_choices:
        config = config_choices[0]
    if config not in card_data:
        raise ValueError(f"Unknown config {config!r}; expected one of {config_choices}.")
    features = card_data[config]["features"]

    return SYSTEM_PROMPT_TEMPLATE.format(
        table_name=table_name,
        features=features,
    )
|
|
|
|
|
def get_config_choices(card_data: dict[str, Any]) -> list[str]:
    """List the config names in *card_data*, preserving its key order.

    *card_data* is keyed by config name, so its keys are the configs.
    """
    # Iterating a dict yields its keys in insertion order.
    return list(card_data)
|
|
|
|
|
def get_split_choices(card_data: dict[str, Any]) -> list[str]:
    """Return the union of split names across all configs, in first-seen order.

    A dict (not a set) collects the names so the order, and therefore the
    default split later picked with ``[0]``, is deterministic. Set iteration
    order for strings varies between interpreter runs because of hash
    randomization.

    Configs without a ``"splits"`` entry, or with a null one, contribute
    nothing.
    """
    seen: dict[str, None] = {}
    for config_info in card_data.values():
        seen.update(dict.fromkeys(config_info.get("splits") or {}))
    return list(seen)
|
|
|
|
|
def query_dataset(hub_repo_id, card_data, query, config, split, history):
    """Ask the LLM for a DuckDB query and return the updated UI state.

    Args:
        hub_repo_id: Dataset repo id selected in the search box.
        card_data: JSON string produced by ``get_table_info``. It may be empty
            or ``"null"`` when no dataset info is available.
        query: Natural-language request from the user.
        config: Selected config name.
        split: Selected split name.
        history: Chatbot history as ``(user, assistant)`` pairs. It is appended
            to in place and returned.

    Returns:
        ``(sql_query, iframe_html, history, query_textbox_value)``. The last
        item is always ``""``, which clears the query textbox.

    Raises:
        gr.Error: If the API key is missing or the LLM request fails.
    """
    parsed_card_data = json.loads(card_data) if card_data and card_data.strip() else None
    if not parsed_card_data:
        # Nothing to build a prompt from: show the dataset (if any) and reset
        # the history.
        if hub_repo_id:
            iframe = get_iframe(hub_repo_id)
        else:
            iframe = "<p>No dataset selected.</p>"
        return "", iframe, [], ""

    if history is None:
        history = []

    system_prompt = get_system_prompt(parsed_card_data, config, split)
    messages = _build_chat_messages(system_prompt, history, query)
    duck_query = _sanitize_duck_query(_request_sql_completion(messages))

    history.append((query, duck_query))
    return duck_query, get_iframe(hub_repo_id, duck_query), history, ""


def _build_chat_messages(system_prompt, history, query):
    """Return chat-completion messages: the system prompt, each past turn, then *query*."""
    messages = [{"role": "system", "content": system_prompt}]
    for user, assistant in history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": query})
    return messages


def _request_sql_completion(messages):
    """Send *messages* to the together.ai chat endpoint and return the reply text.

    Raises:
        gr.Error: If ``API_KEY_TOGETHER_AI`` is unset, the request fails, or
            the response lacks ``choices[0].message.content``.
    """
    api_key = os.environ.get("API_KEY_TOGETHER_AI", "").strip()
    if not api_key:
        raise gr.Error("API_KEY_TOGETHER_AI is not set; cannot query the LLM.")

    try:
        response = requests.post(
            "https://api.together.xyz/v1/chat/completions",
            json=dict(
                model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
                messages=messages,
                max_tokens=1000,
            ),
            headers={"Authorization": f"Bearer {api_key}"},
            # Generous bound for up to 1000 generated tokens; prevents hanging forever.
            timeout=120,
        )
    except requests.RequestException as e:
        raise gr.Error(f"Could not query LLM for suggestion: {e}") from e

    if response.status_code != 200:
        logger.warning(response.text)

    try:
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except (requests.HTTPError, ValueError, KeyError, IndexError, TypeError) as e:
        # gr.Error must be *raised* for Gradio to display it and stop the handler.
        raise gr.Error(f"Could not query LLM for suggestion: {e}") from e
|
|
|
|
|
def _sanitize_duck_query(duck_query: str) -> str: |
|
|
|
|
|
|
|
|
|
|
|
if "```" not in duck_query: |
|
return duck_query |
|
start_idx = duck_query.index("```") + len("```") |
|
end_idx = duck_query.rindex("```") |
|
duck_query = duck_query[start_idx:end_idx] |
|
if duck_query.startswith("sql\n"): |
|
duck_query = duck_query.replace("sql\n", "", 1) |
|
return duck_query |
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown(HEADER_CONTENT)
    with gr.Accordion("About/Help", open=False):
        gr.Markdown(ABOUT_CONTENT)
    with gr.Row():
        search_in = HuggingfaceHubSearch(
            label="Search Hugging Face Hub",
            placeholder="Search for datasets on Hugging Face",
            search_type="dataset",
            # NOTE(review): "sumbit" (sic) is left untouched on purpose; it is
            # presumably the keyword HuggingfaceHubSearch accepts. Verify
            # against the component's signature before renaming.
            sumbit_on_select=True,
        )
    with gr.Row():
        show_btn = gr.Button("Show Dataset")
    with gr.Row():
        # Hidden: receives the suggested SQL from query_dataset.
        sql_out = gr.Code(
            label="DuckDB SQL Query",
            interactive=True,
            language="sql",
            lines=1,
            visible=False,
        )
    with gr.Row():
        # Hidden: holds the JSON string produced by get_table_info. It is the
        # input of the @gr.render function below, so updating it re-renders.
        card_data = gr.Code(label="Card data", language="json", visible=False)

    @gr.render(inputs=[card_data])
    def show_config_split_choices(data):
        """Render config/split pickers, the query chat and the viewer for *data*."""
        try:
            data = json.loads(data.strip())
            config_choices = get_config_choices(data)
            split_choices = get_split_choices(data)
        except Exception:
            # Empty, invalid or non-dict card data (e.g. before any dataset is
            # loaded): fall back to empty pickers instead of failing the render.
            config_choices = []
            split_choices = []

        initial_config = config_choices[0] if config_choices else None
        initial_split = split_choices[0] if split_choices else None
        with gr.Row():
            with gr.Column():
                config_selection = gr.Dropdown(
                    label="Config Name", choices=config_choices, value=initial_config
                )
            with gr.Column():
                split_selection = gr.Dropdown(
                    label="Split Name", choices=split_choices, value=initial_split
                )

        with gr.Accordion("Query Suggestion History.", open=False) as accordion:
            chatbot = gr.Chatbot(height=200, layout="bubble")
        with gr.Row():
            query = gr.Textbox(
                label="Query Description",
                placeholder="Enter a natural language query to generate SQL",
            )
        with gr.Row():
            with gr.Column():
                query_btn = gr.Button("Get Suggested Query")
            with gr.Column():
                gr.ClearButton([query, chatbot], value="Reset Query History")
        with gr.Row():
            search_out = gr.HTML(label="Search Results")

        # Show the viewer first, then fetch the dataset info that drives this render.
        gr.on(
            [show_btn.click, search_in.submit],
            fn=get_iframe,
            inputs=[search_in],
            outputs=[search_out],
        ).then(
            fn=get_table_info,
            inputs=[search_in],
            outputs=[card_data],
        )

        # The button and Enter in the textbox both run the query, so both also
        # open the history accordion.
        query_events = [query_btn.click, query.submit]
        gr.on(
            query_events,
            fn=query_dataset,
            inputs=[
                search_in,
                card_data,
                query,
                config_selection,
                split_selection,
                chatbot,
            ],
            outputs=[sql_out, search_out, chatbot, query],
        )
        gr.on(query_events, fn=lambda: gr.update(open=True), outputs=[accordion])


if __name__ == "__main__":
    demo.launch()
|
|