from gradio_huggingfacehub_search import HuggingfaceHubSearch
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
from llama_cpp_cuda_tensorcores import Llama
from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi
import matplotlib.pyplot as plt
from typing import Tuple, Optional
import pandas as pd
import gradio as gr
import duckdb
import requests
import instructor
import spaces
import enum
import os

from pydantic import BaseModel, Field

BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
# Seconds to wait for the datasets-server before giving up, so a hung HTTP
# request cannot block the UI worker indefinitely.
REQUEST_TIMEOUT = 30

view_name = "dataset_view"
hf_api = HfApi()
conn = duckdb.connect()

gpu_layers = int(os.environ.get("GPU_LAYERS", 0))
draft_pred_tokens = int(os.environ.get("DRAFT_PRED_TOKENS", 2))

repo_id = os.getenv("MODEL_REPO_ID", "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF")
model_file_name = os.getenv("MODEL_FILE_NAME", "Hermes-2-Pro-Llama-3-8B-Q4_K_M.gguf")

# Fetch the GGUF weights into ./models at import time; generate_query loads
# them from models/<model_file_name>.
hf_hub_download(
    repo_id=repo_id,
    filename=model_file_name,
    local_dir="./models",
)


class OutputTypes(str, enum.Enum):
    """Kinds of output the LLM may request for a query result."""

    TABLE = "table"
    BARCHART = "barchart"
    LINECHART = "linechart"


class SQLResponse(BaseModel):
    """Structured LLM response: the SQL to run plus optional chart hints."""

    sql: str
    visualization_type: Optional[OutputTypes] = Field(
        None, description="The type of visualization to display"
    )
    data_key: Optional[str] = Field(
        None,
        description="The column name from the sql query that contains the data for chart responses",
    )
    label_key: Optional[str] = Field(
        None,
        description="The column name from the sql query that contains the labels for chart responses",
    )


def get_dataset_ddl(dataset_id: str) -> str:
    """Point the DuckDB view at a dataset's first parquet file and return its DDL.

    Looks up the parquet files the Hugging Face datasets-server publishes for
    ``dataset_id``. It then (re)creates the view named ``view_name`` over the
    first file and renders the view's columns as a ``CREATE TABLE`` statement
    for the LLM prompt.

    Args:
        dataset_id: Hub dataset id, e.g. ``"gretelai/synthetic_text_to_sql"``.

    Returns:
        A ``CREATE TABLE`` statement listing each column's name and DuckDB type.

    Raises:
        requests.HTTPError: if the datasets-server request fails.
        ValueError: if the dataset has no parquet files, or the first one has
            no URL.
    """
    response = requests.get(
        f"{BASE_DATASETS_SERVER_URL}/parquet",
        params={"dataset": dataset_id},  # properly URL-encodes the id
        timeout=REQUEST_TIMEOUT,
    )
    response.raise_for_status()  # Check if the request was successful

    # Guard the empty case explicitly instead of letting [0] raise IndexError.
    parquet_files = response.json().get("parquet_files") or []
    if not parquet_files:
        raise ValueError(f"No parquet files found for dataset '{dataset_id}'.")
    first_parquet_url = parquet_files[0].get("url")
    if not first_parquet_url:
        raise ValueError("No valid URL found for the first parquet file.")

    # The URL is embedded in a SQL string literal, so double any single quotes
    # to keep it from terminating the literal early.
    escaped_url = first_parquet_url.replace("'", "''")
    conn.execute(
        f"CREATE OR REPLACE VIEW {view_name} as SELECT * FROM read_parquet('{escaped_url}');"
    )

    dataset_ddl = conn.execute(f"PRAGMA table_info('{view_name}');").fetchall()
    # Each table_info row holds the column name at index 1 and its type at index 2.
    column_data_types = ",\n\t".join(
        [f"{column[1]} {column[2]}" for column in dataset_ddl]
    )
    sql_ddl = """
    CREATE TABLE {} (
        {}
    );
    """.format(
        view_name, column_data_types
    )
    return sql_ddl


@spaces.GPU(duration=120)
def generate_query(ddl: str, query: str) -> dict:
    """Ask the local LLM for a SQL query (and chart hints) answering ``query``.

    Args:
        ddl: ``CREATE TABLE`` statement describing the queryable view.
        query: The user's natural-language question.

    Returns:
        The ``SQLResponse`` fields as a dict: ``sql``, ``visualization_type``,
        ``data_key`` and ``label_key``.
    """
    # NOTE(review): the model is loaded on every call inside the GPU-decorated
    # function; presumably required by spaces.GPU's per-call GPU allocation —
    # confirm before caching it at module level.
    llama = Llama(
        model_path=f"models/{model_file_name}",
        n_gpu_layers=gpu_layers,
        chat_format="chatml",
        draft_model=LlamaPromptLookupDecoding(num_pred_tokens=draft_pred_tokens),
        logits_all=True,
        n_ctx=2048,
        verbose=True,
    )
    # Wrap the OpenAI-compatible completion so instructor can coerce the reply
    # into an SQLResponse via JSON-schema constrained output.
    create = instructor.patch(
        create=llama.create_chat_completion_openai_v1,
        mode=instructor.Mode.JSON_SCHEMA,
    )

    system_prompt = f"""
    You are an expert SQL assistant with access to the following PostgreSQL Table:

    ```sql
    {ddl.strip()}
    ```

    Please assist the user by writing a SQL query that answers the user's question.
    """

    print("Calling LLM with system prompt: ", system_prompt, query)

    resp: SQLResponse = create(
        model="Hermes-2-Pro-Llama-3-8B",
        messages=[
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": query,
            },
        ],
        response_model=SQLResponse,
        # Sampling temperature is a completion-time option. Passed to the
        # Llama() constructor it was absorbed by **kwargs and ignored.
        temperature=0.1,
    )

    print("Received Response: ", resp)
    return resp.model_dump()


def _plot_dataframe(
    df: pd.DataFrame, kind: str, label_key: Optional[str], data_key: Optional[str]
) -> plt.Figure:
    """Render ``df`` as a ``kind`` ("line"/"bar") chart on a standalone Figure."""
    # Build the Figure directly rather than through pyplot: pyplot keeps every
    # figure it creates alive until plt.close(), which leaks one figure per
    # request in a long-running server and shares global state across requests.
    fig = plt.Figure()
    ax = fig.add_subplot()
    df.plot(kind=kind, x=label_key, y=data_key, ax=ax)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    fig.tight_layout()
    return fig


def query_dataset(
    dataset_id: str, query: str
) -> Tuple[pd.DataFrame, str, Optional[plt.Figure]]:
    """Answer a natural-language question about a Hub dataset.

    Builds the dataset view, asks the LLM for SQL, runs that SQL in DuckDB and
    optionally charts the result.

    Args:
        dataset_id: Hub dataset id to query.
        query: The user's natural-language question.

    Returns:
        ``(result_frame, sql_as_markdown, figure_or_None)``.
    """
    ddl = get_dataset_ddl(dataset_id)
    response = generate_query(ddl, query)
    sql = response.get("sql")

    print("Querying Parquet...")
    # NOTE(review): this executes model-generated SQL verbatim. DuckDB can read
    # and write local files and URLs, so consider restricting the connection
    # (e.g. read-only / disabled external access) if the app is public.
    df = conn.execute(sql).fetchdf()
    markdown_output = f"""```sql\n{sql}\n```"""

    if df.empty:
        return df, markdown_output, None

    label_key = response.get("label_key")
    data_key = response.get("data_key")
    viz_type = response.get("visualization_type")

    # Drop chart keys the model invented that are not real result columns.
    if label_key and label_key not in df.columns:
        label_key = None
    if data_key and data_key not in df.columns:
        data_key = None

    chart_kinds = {OutputTypes.LINECHART: "line", OutputTypes.BARCHART: "bar"}
    plot = None
    kind = chart_kinds.get(viz_type)
    if kind is not None:
        try:
            plot = _plot_dataframe(df, kind, label_key, data_key)
        except (TypeError, ValueError) as err:
            # e.g. pandas' "no numeric data to plot": keep the table result
            # instead of failing the whole request over an unusable chart.
            print("Could not render chart: ", err)
            plot = None

    return df, markdown_output, plot


with gr.Blocks() as demo:
    gr.Markdown("# Query your HF Datasets with Natural Language 📈📊")
    dataset_id = HuggingfaceHubSearch(
        label="Hub Dataset ID",
        placeholder="Find your favorite dataset...",
        search_type="dataset",
        value="gretelai/synthetic_text_to_sql",
    )
    user_query = gr.Textbox("", label="Ask anything...")
    examples = [
        ["Show me a preview of the data"],
        ["Show me something interesting"],
        ["Which row has longest description length?"],
        ["find the average length of sql query context"],
    ]
    gr.Examples(examples=examples, inputs=[user_query], outputs=[])

    btn = gr.Button("Ask 🪄")
    sql_query = gr.Markdown(label="Output SQL Query")
    df = gr.DataFrame()
    plot = gr.Plot()

    btn.click(
        query_dataset,
        inputs=[dataset_id, user_query],
        outputs=[df, sql_query, plot],
    )


if __name__ == "__main__":
    demo.launch(
        show_error=True,
        quiet=False,
        debug=True,
    )