Caleb Fahlgren committed on
Commit
13e0d1b
β€’
1 Parent(s): 44cb622

add llm for generating sql

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ . filter=lfs diff=lfs merge=lfs -text
37
+ Hermes-2-Pro-Llama-3-8B-Q8_0.gguf filter=lfs diff=lfs merge=lfs -text
Hermes-2-Pro-Llama-3-8B-Q8_0.gguf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d138388cfda04d185a68eaf2396cf7a5cfa87d038a20896817a9b7cf1806f532
3
+ size 8541050176
app.py CHANGED
@@ -4,16 +4,36 @@ import pandas as pd
4
  import gradio as gr
5
  import duckdb
6
  import requests
 
 
 
 
7
 
8
  BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
 
9
 
10
  hf_api = HfApi()
11
  conn = duckdb.connect()
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- def get_dataset_ddl(dataset_id: str) -> pd.DataFrame:
15
- view_name = "dataset_view"
16
 
 
 
17
  response = requests.get(f"{BASE_DATASETS_SERVER_URL}/parquet?dataset={dataset_id}")
18
  response.raise_for_status() # Check if the request was successful
19
 
@@ -43,24 +63,61 @@ CREATE TABLE {} (
43
  return sql_ddl
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  with gr.Blocks() as demo:
47
  gr.Markdown("# Query your HF Datasets with Natural Language πŸ“ˆπŸ“Š")
48
- dataset_name = HuggingfaceHubSearch(
49
  label="Hub Dataset ID",
50
  placeholder="Find your favorite dataset...",
51
  search_type="dataset",
52
  value="jamescalam/world-cities-geo",
53
  )
54
- query_input = gr.Textbox("", label="Ask anything...")
55
 
56
  btn = gr.Button("Ask πŸͺ„")
57
- df = gr.DataFrame(datatype="markdown")
58
- ddl = gr.Text("")
 
59
 
60
  btn.click(
61
- get_dataset_ddl,
62
- inputs=[dataset_name],
63
- outputs=[ddl],
64
  )
65
 
66
 
 
4
  import gradio as gr
5
  import duckdb
6
  import requests
7
+ import llama_cpp
8
+ import instructor
9
+
10
+ from pydantic import BaseModel
11
 
12
  BASE_DATASETS_SERVER_URL = "https://datasets-server.huggingface.co"
13
+ view_name = "dataset_view"
14
 
15
  hf_api = HfApi()
16
  conn = duckdb.connect()
17
 
18
+ llama = llama_cpp.Llama(
19
+ model_path="Hermes-2-Pro-Llama-3-8B-Q8_0.gguf",
20
+ n_gpu_layers=-1,
21
+ chat_format="chatml",
22
+ n_ctx=2048,
23
+ verbose=False,
24
+ )
25
+
26
+ create = instructor.patch(
27
+ create=llama.create_chat_completion_openai_v1,
28
+ mode=instructor.Mode.JSON_SCHEMA,
29
+ )
30
+
31
 
32
class SQLResponse(BaseModel):
    """Structured output schema for the LLM: instructor constrains the model
    to emit JSON matching this shape, i.e. {"sql": "..."}."""

    # The generated DuckDB SQL query as a plain string.
    sql: str
34
 
35
+
36
+ def get_dataset_ddl(dataset_id: str) -> str:
37
  response = requests.get(f"{BASE_DATASETS_SERVER_URL}/parquet?dataset={dataset_id}")
38
  response.raise_for_status() # Check if the request was successful
39
 
 
63
  return sql_ddl
64
 
65
 
66
def generate_sql(dataset_id: str, query: str) -> str:
    """Generate a DuckDB SQL query answering a natural-language question.

    Fetches the dataset's DDL (which defines the `dataset_view` table), embeds
    it in a system prompt, and asks the local Hermes model — via instructor's
    JSON-schema mode — for a structured `SQLResponse`.

    Args:
        dataset_id: Hugging Face Hub dataset id, e.g. "jamescalam/world-cities-geo".
        query: The user's natural-language question.

    Returns:
        The generated SQL string.
    """
    ddl = get_dataset_ddl(dataset_id)

    # Build the prompt without leading indentation or stray blank lines so the
    # model sees a clean instruction (a triple-quoted f-string inside a function
    # would otherwise carry the source indentation into every request).
    system_prompt = (
        "You are an expert SQL assistant with access to the following DuckDB Table:\n"
        "\n"
        f"```sql\n{ddl}\n```\n"
        "\n"
        "Please assist the user by writing a SQL query that answers the user's question.\n"
    )

    # instructor validates/coerces the chat completion into SQLResponse.
    resp: SQLResponse = create(
        model="Hermes-2-Pro-Llama-3-8B",
        messages=[
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": query,
            },
        ],
        response_model=SQLResponse,
    )

    return resp.sql
92
+
93
+
94
def query_dataset(dataset_id: str, query: str) -> tuple[pd.DataFrame, str]:
    """Answer a natural-language question about a dataset with a SQL query.

    Generates SQL via the LLM, executes it on the module-level DuckDB
    connection, and returns both the result and the SQL for display.

    Args:
        dataset_id: Hugging Face Hub dataset id.
        query: The user's natural-language question.

    Returns:
        A tuple of (result DataFrame, generated SQL rendered as a fenced
        markdown code block).
    """
    sql_query = generate_sql(dataset_id, query)
    # NOTE(review): sql_query comes from an LLM and is executed directly;
    # acceptable for a demo Space, but there is no sanitization here.
    df = conn.execute(sql_query).fetchdf()

    # The closing ``` must be on its own line to form a valid fenced code
    # block; the previous format string fused it onto the last SQL line.
    markdown_output = f"```sql\n{sql_query}\n```"
    return df, markdown_output
100
+
101
+
102
# Gradio UI: pick a Hub dataset, type a question, get a DataFrame + the SQL used.
with gr.Blocks() as demo:
    gr.Markdown("# Query your HF Datasets with Natural Language πŸ“ˆπŸ“Š")
    # Dataset picker backed by Hub search; defaults to a small demo dataset.
    dataset_id = HuggingfaceHubSearch(
        label="Hub Dataset ID",
        placeholder="Find your favorite dataset...",
        search_type="dataset",
        value="jamescalam/world-cities-geo",
    )
    # Free-form natural-language question from the user.
    user_query = gr.Textbox("", label="Ask anything...")

    btn = gr.Button("Ask πŸͺ„")

    # Outputs: query results and the generated SQL rendered as markdown.
    df = gr.DataFrame()
    sql_query = gr.Markdown(label="Output SQL Query")

    # Clicking the button runs NL -> SQL generation, then DuckDB execution.
    btn.click(
        query_dataset,
        inputs=[dataset_id, user_query],
        outputs=[df, sql_query],
    )
122
 
123
 
requirements.txt CHANGED
@@ -1,7 +1,10 @@
1
  aiofiles==23.2.1
 
 
2
  altair==5.3.0
3
  annotated-types==0.7.0
4
  anyio==4.4.0
 
5
  attrs==23.2.0
6
  certifi==2024.6.2
7
  charset-normalizer==3.3.2
@@ -9,7 +12,9 @@ click==8.1.7
9
  contourpy==1.2.1
10
  cycler==0.12.1
11
  diskcache==5.6.3
 
12
  dnspython==2.6.1
 
13
  duckdb==1.0.0
14
  email_validator==2.1.1
15
  exceptiongroup==1.2.1
@@ -18,6 +23,7 @@ fastapi-cli==0.0.4
18
  ffmpy==0.3.2
19
  filelock==3.14.0
20
  fonttools==4.53.0
 
21
  fsspec==2024.6.0
22
  gradio==4.32.2
23
  gradio_client==0.17.0
@@ -29,6 +35,7 @@ httpx==0.27.0
29
  huggingface-hub==0.23.2
30
  idna==3.7
31
  importlib_resources==6.4.0
 
32
  Jinja2==3.1.4
33
  jsonschema==4.22.0
34
  jsonschema-specifications==2023.12.1
@@ -39,8 +46,10 @@ MarkupSafe==2.1.5
39
  matplotlib==3.9.0
40
  mdurl==0.1.2
41
  mpmath==1.3.0
 
42
  networkx==3.3
43
  numpy==1.26.4
 
44
  orjson==3.10.3
45
  packaging==24.0
46
  pandas==2.2.2
@@ -68,6 +77,7 @@ sniffio==1.3.1
68
  spaces==0.28.3
69
  starlette==0.37.2
70
  sympy==1.12.1
 
71
  tomlkit==0.12.0
72
  toolz==0.12.1
73
  torch==2.3.0
@@ -81,3 +91,4 @@ uvicorn==0.30.1
81
  uvloop==0.19.0
82
  watchfiles==0.22.0
83
  websockets==11.0.3
 
 
1
  aiofiles==23.2.1
2
+ aiohttp==3.9.5
3
+ aiosignal==1.3.1
4
  altair==5.3.0
5
  annotated-types==0.7.0
6
  anyio==4.4.0
7
+ async-timeout==4.0.3
8
  attrs==23.2.0
9
  certifi==2024.6.2
10
  charset-normalizer==3.3.2
 
12
  contourpy==1.2.1
13
  cycler==0.12.1
14
  diskcache==5.6.3
15
+ distro==1.9.0
16
  dnspython==2.6.1
17
+ docstring_parser==0.16
18
  duckdb==1.0.0
19
  email_validator==2.1.1
20
  exceptiongroup==1.2.1
 
23
  ffmpy==0.3.2
24
  filelock==3.14.0
25
  fonttools==4.53.0
26
+ frozenlist==1.4.1
27
  fsspec==2024.6.0
28
  gradio==4.32.2
29
  gradio_client==0.17.0
 
35
  huggingface-hub==0.23.2
36
  idna==3.7
37
  importlib_resources==6.4.0
38
+ instructor==1.3.2
39
  Jinja2==3.1.4
40
  jsonschema==4.22.0
41
  jsonschema-specifications==2023.12.1
 
46
  matplotlib==3.9.0
47
  mdurl==0.1.2
48
  mpmath==1.3.0
49
+ multidict==6.0.5
50
  networkx==3.3
51
  numpy==1.26.4
52
+ openai==1.31.0
53
  orjson==3.10.3
54
  packaging==24.0
55
  pandas==2.2.2
 
77
  spaces==0.28.3
78
  starlette==0.37.2
79
  sympy==1.12.1
80
+ tenacity==8.3.0
81
  tomlkit==0.12.0
82
  toolz==0.12.1
83
  torch==2.3.0
 
91
  uvloop==0.19.0
92
  watchfiles==0.22.0
93
  websockets==11.0.3
94
+ yarl==1.9.4