Caleb Fahlgren committed
Commit e8c1c43
1 Parent(s): 7247642

make model parameters more dynamic w env variables

Files changed (2):
  1. Hermes-2-Pro-Llama-3-8B-Q8_0.gguf +0 -3
  2. app.py +25 -11
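
In effect, the change swaps the repo-local GGUF file and hard-coded llama.cpp settings for four environment variables (GPU_LAYERS, DRAFT_PRED_TOKENS, MODEL_REPO_ID, MODEL_FILE_NAME). A minimal sketch of the new knobs and their defaults, taken from the diff below; the override shown in the trailing comment (smaller GPU, Q6_K quantization, launching with `python app.py`) is a hypothetical example, not part of the commit:

import os

# New configuration knobs; defaults mirror app.py below.
gpu_layers = int(os.environ.get("GPU_LAYERS", 81))               # feeds llama.cpp n_gpu_layers
draft_pred_tokens = int(os.environ.get("DRAFT_PRED_TOKENS", 2))  # feeds LlamaPromptLookupDecoding num_pred_tokens
repo_id = os.getenv("MODEL_REPO_ID", "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF")
model_file_name = os.getenv("MODEL_FILE_NAME", "Hermes-2-Pro-Llama-3-8B-Q8_0.gguf")

# Hypothetical override before launching the Space, e.g. on a smaller GPU:
#   GPU_LAYERS=35 MODEL_FILE_NAME=Hermes-2-Pro-Llama-3-8B-Q6_K.gguf python app.py
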
Hermes-2-Pro-Llama-3-8B-Q8_0.gguf DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d138388cfda04d185a68eaf2396cf7a5cfa87d038a20896817a9b7cf1806f532
-size 8541050176

app.py CHANGED
@@ -1,5 +1,6 @@
 from gradio_huggingfacehub_search import HuggingfaceHubSearch
 from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
+from huggingface_hub import hf_hub_download
 from huggingface_hub import HfApi
 import matplotlib.pyplot as plt
 from typing import Tuple, Optional
@@ -11,6 +12,7 @@ import llama_cpp
 import instructor
 import spaces
 import enum
+import os
 
 from pydantic import BaseModel, Field
 
@@ -20,6 +22,18 @@ view_name = "dataset_view"
 hf_api = HfApi()
 conn = duckdb.connect()
 
+gpu_layers = int(os.environ.get("GPU_LAYERS", 81))
+draft_pred_tokens = int(os.environ.get("DRAFT_PRED_TOKENS", 2))
+
+repo_id = os.getenv("MODEL_REPO_ID", "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF")
+model_file_name = os.getenv("MODEL_FILE_NAME", "Hermes-2-Pro-Llama-3-8B-Q8_0.gguf")
+
+hf_hub_download(
+    repo_id=repo_id,
+    filename=model_file_name,
+    local_dir="./models",
+)
+
 
 class OutputTypes(str, enum.Enum):
     TABLE = "table"
@@ -75,10 +89,10 @@ CREATE TABLE {} (
 @spaces.GPU(duration=120)
 def generate_query(ddl: str, query: str) -> dict:
     llama = llama_cpp.Llama(
-        model_path="Hermes-2-Pro-Llama-3-8B-Q8_0.gguf",
-        n_gpu_layers=50,
+        model_path=f"models/{model_file_name}",
+        n_gpu_layers=gpu_layers,
         chat_format="chatml",
-        draft_model=LlamaPromptLookupDecoding(num_pred_tokens=2),
+        draft_model=LlamaPromptLookupDecoding(num_pred_tokens=draft_pred_tokens),
         logits_all=True,
         n_ctx=2048,
         verbose=True,
@@ -94,16 +108,13 @@ def generate_query(ddl: str, query: str) -> dict:
     You are an expert SQL assistant with access to the following PostgreSQL Table:
 
     ```sql
-    {ddl}
+    {ddl.strip()}
     ```
 
     Please assist the user by writing a SQL query that answers the user's question.
-
-    Use Label Key as the column name for the x-axis and Data Key as the column name for the y-axis for chart responses. The
-    label key and data key must be present in the SQL output.
     """
 
-    print("Calling LLM with system prompt: ", system_prompt)
+    print("Calling LLM with system prompt: ", system_prompt, query)
 
     resp: SQLResponse = create(
         model="Hermes-2-Pro-Llama-3-8B",
@@ -135,6 +146,7 @@ def query_dataset(dataset_id: str, query: str) -> Tuple[pd.DataFrame, str, plt.F
     data_key = response.get("data_key")
     viz_type = response.get("visualization_type")
     sql = response.get("sql")
+    markdown_output = f"""```sql\n{sql}\n```"""
 
     # handle incorrect data and label keys
     if label_key and label_key not in df.columns:
@@ -142,6 +154,9 @@ def query_dataset(dataset_id: str, query: str) -> Tuple[pd.DataFrame, str, plt.F
     if data_key and data_key not in df.columns:
         data_key = None
 
+    if df.empty:
+        return df, f"```sql\n{sql}\n```", plot
+
     if viz_type == OutputTypes.LINECHART:
         plot = df.plot(kind="line", x=label_key, y=data_key).get_figure()
         plt.xticks(rotation=45, ha="right")
@@ -151,7 +166,6 @@ def query_dataset(dataset_id: str, query: str) -> Tuple[pd.DataFrame, str, plt.F
         plt.xticks(rotation=45, ha="right")
         plt.tight_layout()
 
-    markdown_output = f"""```sql\n{sql}\n```"""
     return df, markdown_output, plot
 
 
@@ -167,8 +181,8 @@ with gr.Blocks() as demo:
     examples = [
         ["Show me a preview of the data"],
        ["Show me something interesting"],
-        ["What is the largest length of sql query context?"],
-        ["show me counts by sql_query_type in a bar chart"],
+        ["Which row has longest description length?"],
+        ["find the average length of sql query context"],
     ]
     gr.Examples(examples=examples, inputs=[user_query], outputs=[])
 
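
Taken together, the new startup path downloads the GGUF from the Hub once and points llama.cpp at the local copy. A minimal, self-contained sketch of that flow under the same defaults (not part of the commit; note that hf_hub_download returns the local file path, so its return value could be passed to Llama directly instead of rebuilding the path as f"models/{model_file_name}"):

import os

import llama_cpp
from huggingface_hub import hf_hub_download
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding

repo_id = os.getenv("MODEL_REPO_ID", "NousResearch/Hermes-2-Pro-Llama-3-8B-GGUF")
model_file_name = os.getenv("MODEL_FILE_NAME", "Hermes-2-Pro-Llama-3-8B-Q8_0.gguf")

# hf_hub_download returns the path of the file it placed under local_dir.
model_path = hf_hub_download(repo_id=repo_id, filename=model_file_name, local_dir="./models")

llama = llama_cpp.Llama(
    model_path=model_path,
    n_gpu_layers=int(os.environ.get("GPU_LAYERS", 81)),
    chat_format="chatml",
    draft_model=LlamaPromptLookupDecoding(
        num_pred_tokens=int(os.environ.get("DRAFT_PRED_TOKENS", 2))
    ),
    logits_all=True,
    n_ctx=2048,
    verbose=True,
)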