Muhammad Mustehson commited on
Commit
a360e3c
·
1 Parent(s): 902da82

Update Old Code

Browse files
Files changed (8) hide show
  1. .gitignore +5 -1
  2. app.py +167 -157
  3. requirements.txt +9 -10
  4. src/__init__.py +0 -0
  5. src/client.py +131 -0
  6. src/models.py +6 -0
  7. src/pipelines.py +98 -0
  8. src/prompts.py +22 -0
.gitignore CHANGED
@@ -1 +1,5 @@
1
- app2.py
 
 
 
 
 
1
+ .env
2
+ .venv
3
+ __pycache__/
4
+ *.pyc
5
+ *.pyo
app.py CHANGED
@@ -1,103 +1,88 @@
 
1
  import os
2
- import torch
 
3
  import duckdb
4
- import spaces
5
- import lancedb
6
  import gradio as gr
 
7
  import pandas as pd
8
  import pyarrow as pa
9
- from langchain import hub
10
- from langsmith import traceable
11
- from sentence_transformers import SentenceTransformer
12
- from langchain_huggingface.llms import HuggingFacePipeline
13
- from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
14
 
15
- # Height of the Tabs Text Area
16
- TAB_LINES = 8
17
 
 
18
 
19
- #----------CONNECT TO DATABASE----------
20
- md_token = os.getenv('MD_TOKEN')
21
- conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
22
- #---------------------------------------
23
 
24
- if torch.cuda.is_available():
25
- device = torch.device("cuda")
26
- print(f"Using GPU: {torch.cuda.get_device_name(device)}")
27
- else:
28
- device = torch.device("cpu")
29
- print("Using CPU")
30
- #---------------------------------------
 
 
31
 
32
- #--------------LanceDB-------------
33
 
34
- lance_db = lancedb.connect(
35
- uri=os.getenv('lancedb_uri'),
36
- api_key=os.getenv('lancedb_api_key'),
37
- region=os.getenv('lancedb_region')
38
- )
39
 
40
- lance_schema = pa.schema([
41
- pa.field("vector", pa.list_(pa.float32())),
42
- pa.field("sql-query", pa.utf8())
43
- ])
44
 
45
- try:
46
- table = lance_db.create_table(name="SQL-Queries", schema=lance_schema)
47
- except:
48
- table = lance_db.open_table(name="SQL-Queries")
49
- #---------------------------------------
50
 
51
- #-------LOAD HUGGINGFACE PIPELINE-------
52
- tokenizer = AutoTokenizer.from_pretrained("defog/llama-3-sqlcoder-8b")
53
 
54
- quantization_config = BitsAndBytesConfig(
55
- load_in_4bit=True,
56
- bnb_4bit_compute_dtype=torch.bfloat16,
57
- bnb_4bit_use_double_quant=True,
58
- bnb_4bit_quant_type= "nf4")
59
-
60
- model = AutoModelForCausalLM.from_pretrained("defog/llama-3-sqlcoder-8b", quantization_config=quantization_config,
61
- device_map="auto", torch_dtype=torch.bfloat16)
 
 
 
 
 
 
62
 
63
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1024, return_full_text=False)
64
- hf = HuggingFacePipeline(pipeline=pipe)
65
- #---------------------------------------
66
 
67
- #-----LOAD PROMPT FROM LANCHAIN HUB-----
68
- prompt = hub.pull("sql-agent-prompt")
69
- #---------------------------------------
70
 
71
- #-----LOAD EMBEDDING MODEL-----
72
- embedding_model = SentenceTransformer("all-MiniLM-L6-v2", device=device)
73
- #---------------------------------------
74
 
75
- #--------------ALL UTILS----------------
76
- # Get Databases
77
- def get_schemas():
78
  schemas = conn.execute("""
79
  SELECT DISTINCT schema_name
80
  FROM information_schema.schemata
81
  WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
82
  """).fetchall()
83
- return [item[0] for item in schemas]
 
84
 
85
- # Get Tables
86
- def get_tables(schema_name):
87
- tables = conn.execute(f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{schema_name}'").fetchall()
 
88
  return [table[0] for table in tables]
89
 
90
- # Update Tables
91
- def update_tables(schema_name):
92
  tables = get_tables(schema_name)
93
  return gr.update(choices=tables)
94
 
95
- # Get Schema
96
- def get_table_schema(table):
97
- result = conn.sql(f"SELECT sql, database_name, schema_name FROM duckdb_tables() where table_name ='{table}';").df()
98
- ddl_create = result.iloc[0,0]
99
- parent_database = result.iloc[0,1]
100
- schema_name = result.iloc[0,2]
 
 
101
  full_path = f"{parent_database}.{schema_name}.{table}"
102
  if schema_name != "main":
103
  old_path = f"{schema_name}.{table}"
@@ -106,85 +91,81 @@ def get_table_schema(table):
106
  ddl_create = ddl_create.replace(old_path, full_path)
107
  return ddl_create
108
 
109
- # Get Prompt
110
- def get_prompt(schema, query_input):
111
- return prompt.format(schema=schema, query_input=query_input)
112
-
113
- @spaces.GPU(duration=60)
114
- @traceable()
115
- def generate_sql(prompt):
116
- result = hf.invoke(prompt)
117
- return result.strip()
118
- @spaces.GPU(duration=10)
119
- def embed_query(sql_query):
120
- print(f'Creating Emebeddings {sql_query}')
121
- if sql_query is not None:
122
- embeddings = embedding_model.encode(sql_query, normalize_embeddings=True).tolist()
123
- return embeddings
124
-
125
- def log2lancedb(embeddings, sql_query):
126
- data = [{
127
- "sql-query": sql_query,
128
- "vector": embeddings
129
- }]
130
- table.add(data)
131
- print(f'Added to Lance DB.')
132
- #---------------------------------------
133
-
134
-
135
- # Generate SQL
136
- def text2sql(table, query_input):
137
  if table is None:
138
- return {
139
- table_schema: "",
140
- input_prompt: "",
141
- generated_query: "",
142
- result_output:pd.DataFrame([{"error": "❌ Please Select Table, Schema.}"}])
143
- }
144
-
145
- schema = get_table_schema(table)
146
- print(f'Schema Generated...')
147
- prompt = get_prompt(schema, query_input)
148
- print(f'Prompt Generated...')
149
-
150
- try:
151
- print(f'Generating SQL... {model.device}')
152
- result = generate_sql(prompt)
153
- print('SQL Generated...')
154
- except Exception as e:
155
- return {
156
- table_schema: schema,
157
- input_prompt: prompt,
158
- generated_query: "",
159
- result_output:pd.DataFrame([{"error": f"❌ Unable to get the SQL query based on the text. {e}"}])
160
- }
161
-
162
  try:
163
- embeddings = embed_query(result)
164
- log2lancedb(embeddings, result)
165
- except Exception as e:
166
- print("Error Generating and Logging Embeddings...")
167
- print(e)
168
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  try:
170
- query_result = conn.sql(result).df()
171
-
172
- except Exception as e:
173
- return {
174
- table_schema: schema,
175
- input_prompt: prompt,
176
- generated_query: result,
177
- result_output:pd.DataFrame([{"error": f"❌ Unable to get the SQL query based on the text. {e}"}])
178
- }
179
-
 
 
 
 
 
 
 
 
 
 
 
180
  return {
181
  table_schema: schema,
182
- input_prompt: prompt,
183
- generated_query: result,
184
- result_output:query_result
185
  }
186
 
187
- # Custom CSS styling
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  custom_css = """
189
  .gradio-container {
190
  background-color: #f0f4f8;
@@ -202,9 +183,11 @@ custom_css = """
202
  }
203
  """
204
 
205
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css) as demo:
 
 
206
  gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
207
-
208
  gr.Markdown("""
209
  <div style='text-align: center;'>
210
  <strong style='font-size: 36px;'>Datajoi SQL Agent</strong>
@@ -214,13 +197,18 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
214
  """)
215
 
216
  with gr.Row():
217
-
218
- with gr.Column(scale=1, variant='panel'):
219
- schema_dropdown = gr.Dropdown(choices=get_schemas(), label="Select Schema", interactive=True)
220
- tables_dropdown = gr.Dropdown(choices=[], label="Available Tables", value=None)
 
 
 
221
 
222
  with gr.Column(scale=2):
223
- query_input = gr.Textbox(lines=5, label="Text Query", placeholder="Enter your text query here...")
 
 
224
  with gr.Row():
225
  with gr.Column(scale=7):
226
  pass
@@ -229,17 +217,39 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
229
 
230
  with gr.Tabs():
231
  with gr.Tab("Result"):
232
- result_output = gr.DataFrame(label="Query Results", value=[], interactive=False)
 
 
233
  with gr.Tab("SQL Query"):
234
- generated_query = gr.Textbox(lines=TAB_LINES, label="Generated SQL Query", value="", interactive=False)
 
 
 
 
 
235
  with gr.Tab("Prompt"):
236
- input_prompt = gr.Textbox(lines=TAB_LINES, label="Input Prompt", value="", interactive=False)
 
 
 
 
 
237
  with gr.Tab("Schema"):
238
- table_schema = gr.Textbox(lines=TAB_LINES, label="Table Schema", value="", interactive=False)
239
-
240
- schema_dropdown.change(update_tables, inputs=schema_dropdown, outputs=tables_dropdown)
241
- generate_query_button.click(text2sql, inputs=[tables_dropdown, query_input], outputs=[table_schema, input_prompt, generated_query, result_output])
 
 
 
 
 
 
 
 
 
 
 
242
 
243
  if __name__ == "__main__":
244
  demo.launch()
245
-
 
1
+ import logging
2
  import os
3
+ from typing import Any, Dict, List
4
+
5
  import duckdb
 
 
6
  import gradio as gr
7
+ import lancedb
8
  import pandas as pd
9
  import pyarrow as pa
10
+ from dotenv import load_dotenv
 
 
 
 
11
 
12
+ from src.client import LLMChain, embed_client
13
+ from src.pipelines import SQLPipeline
14
 
15
+ load_dotenv()
16
 
 
 
 
 
17
 
18
+ # ========ENV's========
19
+ MD_TOKEN = os.getenv("MD_TOKEN")
20
+ HF_TOKEN = os.getenv("HF_TOKEN")
21
+ conn = duckdb.connect(f"md:my_db?motherduck_token={MD_TOKEN}", read_only=True)
22
+ LEVEL = "INFO" if not os.getenv("ENV") == "PROD" else "WARNING"
23
+ EMB_URL = os.getenv("EMB_URL")
24
+ EMB_MODEL = os.getenv("EMB_MODEL")
25
+ TAB_LINES = 8
26
+ # =====================
27
 
 
28
 
29
+ logging.basicConfig(
30
+ level=getattr(logging, LEVEL, logging.INFO),
31
+ format="%(asctime)s %(levelname)s %(name)s: %(message)s",
32
+ )
33
+ logger = logging.getLogger(__name__)
34
 
 
 
 
 
35
 
36
+ pipe = SQLPipeline(duckdb=conn, chain=LLMChain())
 
 
 
 
37
 
 
 
38
 
39
+ def _setup_lancedb() -> lancedb.table.Table:
40
+ lance_db = lancedb.connect(
41
+ uri=os.getenv("lancedb_uri"),
42
+ api_key=os.getenv("lancedb_api_key"),
43
+ region=os.getenv("lancedb_region"),
44
+ )
45
+ lance_schema = pa.schema(
46
+ [pa.field("vector", pa.list_(pa.float32())), pa.field("sql-query", pa.utf8())]
47
+ )
48
+ try:
49
+ table = lance_db.create_table(name="SQL-Queries", schema=lance_schema)
50
+ except Exception:
51
+ table = lance_db.open_table(name="SQL-Queries")
52
+ return table
53
 
 
 
 
54
 
55
+ lance_table = _setup_lancedb()
 
 
56
 
 
 
 
57
 
58
+ def get_schemas() -> List:
 
 
59
  schemas = conn.execute("""
60
  SELECT DISTINCT schema_name
61
  FROM information_schema.schemata
62
  WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
63
  """).fetchall()
64
+ return [item[0] for item in schemas]
65
+
66
 
67
+ def get_tables(schema_name: str) -> List:
68
+ tables = conn.execute(
69
+ f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{schema_name}'"
70
+ ).fetchall()
71
  return [table[0] for table in tables]
72
 
73
+
74
+ def update_tables(schema_name: str):
75
  tables = get_tables(schema_name)
76
  return gr.update(choices=tables)
77
 
78
+
79
+ def get_table_schema(table: str) -> str:
80
+ result = conn.sql(
81
+ f"SELECT sql, database_name, schema_name FROM duckdb_tables() where table_name ='{table}';"
82
+ ).df()
83
+ ddl_create = result.iloc[0, 0]
84
+ parent_database = result.iloc[0, 1]
85
+ schema_name = result.iloc[0, 2]
86
  full_path = f"{parent_database}.{schema_name}.{table}"
87
  if schema_name != "main":
88
  old_path = f"{schema_name}.{table}"
 
91
  ddl_create = ddl_create.replace(old_path, full_path)
92
  return ddl_create
93
 
94
+
95
+ def run_pipeline(table: str, query_input: str) -> Dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  if table is None:
97
+ return _error_response(
98
+ query_input=query_input, message="❌ Please select a table/schema."
99
+ )
100
+
101
+ schema = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  try:
103
+ schema = get_table_schema(table=table)
104
+
105
+ sql, df = pipe.try_sql_with_retries(
106
+ user_question=query_input,
107
+ context=schema,
108
+ )
109
+
110
+ if not sql or df is None:
111
+ return _error_response(
112
+ query_input=query_input,
113
+ schema=schema,
114
+ message="❌ Unable to generate SQL from the input text.",
115
+ )
116
+
117
+ except Exception as exc:
118
+ logger.exception("Pipeline execution failed")
119
+ return _error_response(
120
+ query_input=query_input, schema=schema, message=f"❌ Pipeline error: {exc}"
121
+ )
122
+
123
  try:
124
+ sql_str = f"{query_input}\n{sql.get('sql_query', '')}"
125
+ embeddings = embed_query(sql_str)
126
+ log2lancedb(embeddings, sql_str)
127
+
128
+ except Exception as exc:
129
+ logger.warning("Embedding/logging failed: %s", exc)
130
+
131
+ return {
132
+ table_schema: schema,
133
+ input_prompt: query_input,
134
+ generated_query: sql.get("sql_query", ""),
135
+ result_output: df,
136
+ }
137
+
138
+
139
+ def _error_response(
140
+ *,
141
+ query_input: str,
142
+ message: str,
143
+ schema: str = "",
144
+ ) -> Dict[str, Any]:
145
  return {
146
  table_schema: schema,
147
+ input_prompt: query_input,
148
+ generated_query: "",
149
+ result_output: pd.DataFrame([{"error": message}]),
150
  }
151
 
152
+
153
+ def embed_query(data: str) -> List:
154
+ logger.info(f"Creating Emebeddings {data}")
155
+ try:
156
+ results = embed_client.feature_extraction(text=data, model=EMB_MODEL)
157
+ return results.tolist()
158
+ except Exception as e:
159
+ logger.error(f"Unable to Generate embedding for the given query: {e}")
160
+ return []
161
+
162
+
163
+ def log2lancedb(embeddings: List, sql_query: str) -> None:
164
+ data = [{"sql-query": sql_query, "vector": embeddings}]
165
+ lance_table.add(data)
166
+ logger.info("Added to Lance DB.")
167
+
168
+
169
  custom_css = """
170
  .gradio-container {
171
  background-color: #f0f4f8;
 
183
  }
184
  """
185
 
186
+ with gr.Blocks(
187
+ theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css
188
+ ) as demo:
189
  gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
190
+
191
  gr.Markdown("""
192
  <div style='text-align: center;'>
193
  <strong style='font-size: 36px;'>Datajoi SQL Agent</strong>
 
197
  """)
198
 
199
  with gr.Row():
200
+ with gr.Column(scale=1, variant="panel"):
201
+ schema_dropdown = gr.Dropdown(
202
+ choices=get_schemas(), label="Select Schema", interactive=True
203
+ )
204
+ tables_dropdown = gr.Dropdown(
205
+ choices=[], label="Available Tables", value=None
206
+ )
207
 
208
  with gr.Column(scale=2):
209
+ query_input = gr.Textbox(
210
+ lines=5, label="Text Query", placeholder="Enter your text query here..."
211
+ )
212
  with gr.Row():
213
  with gr.Column(scale=7):
214
  pass
 
217
 
218
  with gr.Tabs():
219
  with gr.Tab("Result"):
220
+ result_output = gr.DataFrame(
221
+ label="Query Results", value=[], interactive=False
222
+ )
223
  with gr.Tab("SQL Query"):
224
+ generated_query = gr.Textbox(
225
+ lines=TAB_LINES,
226
+ label="Generated SQL Query",
227
+ value="",
228
+ interactive=False,
229
+ )
230
  with gr.Tab("Prompt"):
231
+ input_prompt = gr.Textbox(
232
+ lines=TAB_LINES,
233
+ label="Input Prompt",
234
+ value="",
235
+ interactive=False,
236
+ )
237
  with gr.Tab("Schema"):
238
+ table_schema = gr.Textbox(
239
+ lines=TAB_LINES,
240
+ label="Table Schema",
241
+ value="",
242
+ interactive=False,
243
+ )
244
+
245
+ schema_dropdown.change(
246
+ update_tables, inputs=schema_dropdown, outputs=tables_dropdown
247
+ )
248
+ generate_query_button.click(
249
+ run_pipeline,
250
+ inputs=[tables_dropdown, query_input],
251
+ outputs=[table_schema, input_prompt, generated_query, result_output],
252
+ )
253
 
254
  if __name__ == "__main__":
255
  demo.launch()
 
requirements.txt CHANGED
@@ -1,10 +1,9 @@
1
- accelerate==0.34.2
2
- bitsandbytes==0.44.1
3
- transformers==4.44.2
4
- duckdb==1.1.1
5
- langsmith==0.1.135
6
- langchain==0.3.4
7
- lancedb==0.15.0
8
- sentence-transformers==3.2.1
9
- pyarrow==17.0.0
10
- langchain-huggingface
 
1
+ huggingface-hub==0.35.0
2
+ duckdb==1.3.2
3
+ pandas==2.3.1
4
+ numpy==2.3.2
5
+ pydantic
6
+ python-dotenv
7
+ gradio
8
+ pyarrow
9
+ lancedb
 
src/__init__.py ADDED
File without changes
src/client.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+
5
+ from dotenv import load_dotenv
6
+ from huggingface_hub import InferenceClient
7
+ from pydantic import BaseModel
8
+
9
+ load_dotenv()
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ MAX_RESPONSE_TOKENS = 2048
14
+ TEMPERATURE = 0.9
15
+
16
+ models = json.loads(os.getenv("MODEL_NAMES"))
17
+ providers = json.loads(os.getenv("PROVIDERS"))
18
+ EMB_MODEL = os.getenv("EMB_MODEL")
19
+
20
+
21
+ def _engine_working(engine: InferenceClient) -> bool:
22
+ try:
23
+ engine.chat_completion([{"role": "user", "content": "ping"}], max_tokens=1)
24
+ logger.info("Engine is Working.")
25
+ return True
26
+ except Exception as e:
27
+ logger.exception(f"Engine is not working: {e}")
28
+ return False
29
+
30
+
31
+ def _load_llm_client() -> InferenceClient:
32
+ """
33
+ Attempts to load the provided model from the huggingface endpoint.
34
+
35
+ Returns InferenceClient if successful.
36
+ Raises Exception if no model is available.
37
+ """
38
+ logger.warning("Loading Model...")
39
+ errors = []
40
+ for model in models:
41
+ for provider in providers:
42
+ if isinstance(model, str):
43
+ try:
44
+ logger.info(f"Checking model: {model} provider: {provider}")
45
+ client = InferenceClient(
46
+ model=model,
47
+ timeout=15,
48
+ provider=provider,
49
+ )
50
+ if _engine_working(client):
51
+ logger.info(
52
+ f"The model is loaded : {model} , provider: {provider}"
53
+ )
54
+ return client
55
+ except Exception as e:
56
+ logger.error(
57
+ f"Error loading model {model} provider {provider}: {e}"
58
+ )
59
+ errors.append(str(e))
60
+ raise Exception(f"Unable to load any provided model: {errors}.")
61
+
62
+
63
+ def _load_embedding_client() -> InferenceClient:
64
+ logger.warning("Loading Embedding Model...")
65
+ try:
66
+ emb_client = InferenceClient(timeout=15, model=EMB_MODEL)
67
+ return emb_client
68
+ except Exception as e:
69
+ logger.error(f"Error loading model {EMB_MODEL}: {e}")
70
+ raise Exception("Unable to load the embedding model.")
71
+
72
+
73
+ _default_client = _load_llm_client()
74
+ embed_client = _load_embedding_client()
75
+
76
+
77
+ class LLMChain:
78
+ def __init__(self, client: InferenceClient = _default_client):
79
+ self.client = client
80
+ self.total_tokens = 0
81
+
82
+ def run(
83
+ self,
84
+ system_prompt: str | None = None,
85
+ user_prompt: str | None = None,
86
+ messages: list[dict] | None = None,
87
+ format_name: str | None = None,
88
+ response_format: type[BaseModel] | None = None,
89
+ ) -> str | dict[str, str | int | float | None] | list[str] | None:
90
+ try:
91
+ if system_prompt and user_prompt:
92
+ messages = [
93
+ {"role": "system", "content": system_prompt},
94
+ {"role": "user", "content": user_prompt},
95
+ ]
96
+ elif not messages:
97
+ raise ValueError(
98
+ "Either system_prompt and user_prompt or messages must be provided."
99
+ )
100
+
101
+ llm_response = self.client.chat_completion(
102
+ messages=messages,
103
+ max_tokens=MAX_RESPONSE_TOKENS,
104
+ temperature=TEMPERATURE,
105
+ response_format=(
106
+ {
107
+ "type": "json_schema",
108
+ "json_schema": {
109
+ "name": format_name,
110
+ "schema": response_format.model_json_schema(),
111
+ "strict": True,
112
+ },
113
+ }
114
+ if format_name and response_format
115
+ else None
116
+ ),
117
+ )
118
+ self.total_tokens += llm_response.usage.total_tokens
119
+ analysis = llm_response.choices[0].message.content
120
+ if response_format:
121
+ analysis = json.loads(analysis)
122
+ fields = list(response_format.model_fields.keys())
123
+ if len(fields) == 1:
124
+ return analysis.get(fields[0])
125
+ return {field: analysis.get(field) for field in fields}
126
+
127
+ return analysis
128
+
129
+ except Exception as e:
130
+ logger.error(f"Error during LLM calls: {e}")
131
+ return None
src/models.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+
3
+
4
+ class SQLQueryModel(BaseModel):
5
+ sql_query: str = Field(..., description="SQL query to execute.")
6
+ explanation: str = Field(..., description="Short explanation of the SQL query.")
src/pipelines.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+
4
+ import pandas as pd
5
+ from duckdb import DuckDBPyConnection
6
+
7
+ from src.models import SQLQueryModel
8
+ from src.prompts import SQL_PROMPT, USER_PROMPT
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ SQL_GENERATION_RETRIES = int(os.getenv("SQL_GENERATION_RETRIES", "5"))
14
+
15
+
16
+ class SQLPipeline:
17
+ def __init__(
18
+ self,
19
+ duckdb: DuckDBPyConnection,
20
+ chain,
21
+ ) -> None:
22
+ self._duckdb = duckdb
23
+ self.chain = chain
24
+
25
+ def generate_sql(
26
+ self, user_question: str, context: str, errors: str | None = None
27
+ ) -> str | dict[str, str | int | float | None] | list[str] | None:
28
+ """Generate SQL + description."""
29
+ user_prompt_formatted = USER_PROMPT.format(
30
+ question=user_question, context=context
31
+ )
32
+ if errors:
33
+ user_prompt_formatted += f"Carefully review the previous error or\
34
+ exception and rewrite the SQL so that the error does not occur again.\
35
+ Try a different approach or rewrite SQL if needed. Last error: {errors}"
36
+
37
+ sql = self.chain.run(
38
+ system_prompt=SQL_PROMPT,
39
+ user_prompt=user_prompt_formatted,
40
+ format_name="sql_query",
41
+ response_format=SQLQueryModel,
42
+ )
43
+ logger.info(f"SQL Generated Successfully: {sql}")
44
+ return sql
45
+
46
+ def run_query(self, sql_query: str) -> pd.DataFrame | None:
47
+ """Execute SQL and return dataframe."""
48
+ logger.info("Query Execution Started.")
49
+ return self._duckdb.query(sql_query).df()
50
+
51
+ def try_sql_with_retries(
52
+ self,
53
+ user_question: str,
54
+ context: str,
55
+ max_retries: int = SQL_GENERATION_RETRIES,
56
+ ) -> tuple[
57
+ str | dict[str, str | int | float | None] | list[str] | None,
58
+ pd.DataFrame | None,
59
+ ]:
60
+ """Try SQL generation + execution with retries."""
61
+ last_error = None
62
+ all_errors = ""
63
+
64
+ for attempt in range(
65
+ 1, max_retries + 2
66
+ ): # @ Since the first is normal and not consider in retries
67
+ try:
68
+ if attempt > 1 and last_error:
69
+ logger.info(f"Retrying: {attempt - 1}")
70
+ # Generate SQL
71
+ sql = self.generate_sql(user_question, context, errors=all_errors)
72
+ if not sql:
73
+ return None, None
74
+ else:
75
+ # Generate SQL
76
+ sql = self.generate_sql(user_question, context)
77
+ if not sql:
78
+ return None, None
79
+
80
+ # Try executing query
81
+ sql_query_str = sql.get("sql_query") if isinstance(sql, dict) else sql
82
+ if not isinstance(sql_query_str, str):
83
+ raise ValueError(
84
+ f"Expected SQL query to be a string, got {type(sql_query_str).__name__}"
85
+ )
86
+ query_df = self.run_query(sql_query_str)
87
+
88
+ # If execution succeeds, stop retrying or if df is not empty
89
+ if query_df is not None and not query_df.empty:
90
+ return sql, query_df
91
+
92
+ except Exception as e:
93
+ last_error = f"\nAttempt {attempt - 1}] {type(e).__name__}: {e}"
94
+ logger.error(f"Error during SQL generation or execution: {last_error}")
95
+ all_errors += last_error
96
+
97
+ logger.error(f"Failed after {max_retries} attempts. Last error: {all_errors}")
98
+ return None, None
src/prompts.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ USER_PROMPT = """User's Text Question:
2
+ {question}
3
+
4
+ Provided table context information:
5
+ {context}"""
6
+
7
+
8
+ SQL_PROMPT = """You are an expert Text-to-SQL assistant. Convert the user's natural-language request into a single, read-only, syntactically valid DuckDB SQL SELECT statement that runs against the provided schema (the schema will be supplied as CREATE TABLE DDL). Use the exact table and column names from the schema.
9
+
10
+ Return two things:
11
+ 1. The SQL statement.
12
+ 2. A short natural-language description (1-2 sentences) of what the query returns.
13
+
14
+ Rules:
15
+ 1. Output MUST be a single SELECT query. JOINs, subqueries, aggregations, GROUP BY, ORDER BY, and LIMIT are allowed.
16
+ 2. Do NOT generate any DML/DDL (INSERT, UPDATE, DELETE, DROP, etc.) or non-read operations.
17
+ 3. Use DuckDB SQL functions and syntax. For date/time grouping, use DATE_TRUNC('unit', column) (e.g., 'month', 'day', 'year').
18
+ 4. Prefer explicit column lists. Use SELECT * only if the user explicitly requests all columns.
19
+ 5. Make the query robust and maintainable, so it can be reused or adapted for similar analyses.
20
+ 6. After execution in the downstream pipeline, if an error occurs (available as `Last Error` with a short description), analyze that error and rewrite the SQL to resolve it while preserving the user's intent. The rewritten query must still be valid DuckDB SQL.
21
+ 7. If the user requests a distribution/histogram, return SQL that selects a single numeric column only, so binning can be performed downstream.
22
+ """