Muhammad Mustehson committed on
Commit
1a436de
·
1 Parent(s): 9f8e201

Initial draft

Browse files
Files changed (16) hide show
  1. .gitignore +18 -0
  2. app.py +430 -52
  3. audits/.gitkeep +0 -0
  4. config.yaml +19 -0
  5. database/.gitkeep +0 -0
  6. logo.png +0 -0
  7. macros/.gitkeep +0 -0
  8. models/.gitkeep +0 -0
  9. pytest.ini +4 -0
  10. requirements.txt +280 -0
  11. seeds/.gitkeep +0 -0
  12. src/__init__.py +0 -0
  13. src/client.py +119 -0
  14. src/models.py +12 -0
  15. src/pipelines.py +128 -0
  16. src/prompts.py +71 -0
.gitignore ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
11
+ .cache
12
+ .pytest_cache
13
+ .env
14
+ pyproject.toml
15
+ uv.lock
16
+
17
+
18
+ *.duckdb
app.py CHANGED
@@ -1,70 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
 
3
 
 
 
4
 
5
- def respond(
6
- message,
7
- history: list[dict[str, str]],
8
- system_message,
9
- max_tokens,
10
- temperature,
11
- top_p,
12
- hf_token: gr.OAuthToken,
13
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  """
17
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
 
19
- messages = [{"role": "system", "content": system_message}]
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- messages.extend(history)
22
 
23
- messages.append({"role": "user", "content": message})
 
 
 
24
 
25
- response = ""
26
 
27
- for message in client.chat_completion(
28
- messages,
29
- max_tokens=max_tokens,
30
- stream=True,
31
- temperature=temperature,
32
- top_p=top_p,
33
- ):
34
- choices = message.choices
35
- token = ""
36
- if len(choices) and choices[0].delta.content:
37
- token = choices[0].delta.content
 
 
 
 
38
 
39
- response += token
40
- yield response
41
 
 
 
 
 
 
 
 
 
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- type="messages",
49
- additional_inputs=[
50
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
51
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
52
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
53
- gr.Slider(
54
- minimum=0.1,
55
- maximum=1.0,
56
- value=0.95,
57
- step=0.05,
58
- label="Top-p (nucleus sampling)",
59
- ),
60
- ],
61
- )
62
 
63
- with gr.Blocks() as demo:
64
- with gr.Sidebar():
65
- gr.LoginButton()
66
- chatbot.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  if __name__ == "__main__":
70
- demo.launch()
 
1
+ import io
2
+ import logging
3
+ import os
4
+ import shutil
5
+ import sys
6
+ import tempfile
7
+ import uuid
8
+ from pathlib import Path
9
+ from typing import Tuple
10
+
11
+ import duckdb
12
  import gradio as gr
13
+ import pandas as pd
14
+ import pytest
15
+ import requests
16
+ from dotenv import load_dotenv
17
 
18
+ from src.client import LLMChain
19
+ from src.pipelines import Query2Schema
20
 
21
+ load_dotenv()
22
+
23
+
24
+ LEVEL = "INFO" if not os.getenv("ENV") == "PROD" else "WARNING"
25
+ logging.basicConfig(
26
+ level=getattr(logging, LEVEL, logging.INFO),
27
+ format="%(asctime)s %(levelname)s %(name)s: %(message)s",
28
+ )
29
+ logger = logging.getLogger(__name__)
30
+
31
+ if not Path("/tmp").exists():
32
+ os.mkdir("/tmp")
33
+
34
+
35
def create_conn(url: str, save_path: str):
    """Download a DuckDB database file from ``url`` and open a connection to it.

    Args:
        url: HTTP(S) location of the ``.duckdb`` file.
        save_path: Local path the database file is written to.

    Returns:
        An open ``duckdb`` connection to the downloaded file.

    Raises:
        Exception: Re-raises any download or connection error after logging.
    """
    try:
        # Stream the body so large database files are never held fully in
        # memory; the timeout keeps the app from hanging on a dead endpoint,
        # and the context manager guarantees the connection is released.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(save_path, "wb") as out_file:
                shutil.copyfileobj(response.raw, out_file)

        return duckdb.connect(database=save_path)
    except Exception as e:
        logger.error(f"Error downloading database: {e}")
        raise
47
+
48
+
49
# Download the sample Chinook database on first run; reconnect to the cached
# copy on later runs.  Previously `conn` was only assigned inside the
# not-exists branch, so a pre-existing database file caused a NameError at
# the `Query2Schema(duckdb=conn, ...)` line below.
DB_PATH = "database/chinook.duckdb"
if not Path(DB_PATH).exists():
    conn = create_conn(
        url="https://raw.githubusercontent.com/RandomFractals/duckdb-sql-tools/main/data/chinook/duckdb/chinook.duckdb",
        save_path=DB_PATH,
    )
else:
    conn = duckdb.connect(database=DB_PATH)

# Shared pipeline instance used by all Gradio callbacks below.
pipe = Query2Schema(duckdb=conn, chain=LLMChain())
56
+
57
+
58
def get_tables_names(schema_name):
    """Return every table name in the connected database.

    ``schema_name`` is accepted for dropdown-callback compatibility but is
    not currently used to filter the result — TODO confirm this is intended.
    """
    rows = conn.execute("SELECT table_name FROM information_schema.tables").fetchall()
    names = []
    for row in rows:
        names.append(row[0])
    return names
61
+
62
+
63
def update_table_names(schema_name):
    """Refresh the tables dropdown, defaulting to the first available table."""
    available = get_tables_names(schema_name)
    if available:
        default = available[0]
    else:
        default = None
    return gr.update(choices=available, value=default)
66
+
67
+
68
def update_column_names(table_name):
    """Return a one-column DataFrame listing the columns of ``table_name``.

    The table name is passed as a bound parameter instead of being
    interpolated into the SQL text, which closes the SQL-injection hole the
    previous f-string query opened through the dropdown value.
    """
    rows = conn.execute(
        "SELECT column_name FROM information_schema.columns WHERE table_name = ?",
        [table_name],
    ).fetchall()
    column_names = [row[0] for row in rows]
    return pd.DataFrame(column_names, columns=["Column Names"])
79
+
80
+
81
def get_ddl(table: str) -> str:
    """Return the CREATE TABLE DDL for ``table`` with a fully qualified name.

    The table name is bound as a query parameter (not f-string interpolated)
    to avoid SQL injection, and the table reference inside the emitted DDL is
    rewritten to ``database.schema.table`` so downstream SQLMesh models can
    resolve it unambiguously.
    """
    result = conn.execute(
        "SELECT sql, database_name, schema_name FROM duckdb_tables() WHERE table_name = ?",
        [table],
    ).df()
    ddl_create = result.iloc[0, 0]
    parent_database = result.iloc[0, 1]
    schema_name = result.iloc[0, 2]
    full_path = f"{parent_database}.{schema_name}.{table}"
    # DuckDB emits either "schema.table" or a bare "table" depending on the
    # schema, so pick the matching form before rewriting.
    if schema_name != "main":
        old_path = f"{schema_name}.{table}"
    else:
        old_path = table
    ddl_create = ddl_create.replace(old_path, full_path)
    return ddl_create
95
+
96
+
97
def run_pipeline(table: str, query_input: str) -> Tuple[str, pd.DataFrame]:
    """Generate SQL for a natural-language question and execute it.

    Returns the generated SQL string and the resulting DataFrame.
    Re-raises (after logging) when DDL lookup or SQL generation fails.
    """
    try:
        ddl_context = get_ddl(table=table)
    except Exception as e:
        logger.error(f"Failed to fetch DDL for table {table}: {e}")
        raise
    try:
        sql, df = pipe.try_sql_with_retries(
            user_question=query_input,
            context=ddl_context,
        )
        # The chain may hand back a structured dict; unwrap the query text.
        if isinstance(sql, dict):
            sql = sql.get("sql_query")
        if not sql:
            raise ValueError("SQL generation returned None")
        return sql, df
    except Exception as e:
        logger.error(f"Error generating SQL for table {table}: {e}")
        raise
115
+
116
+
117
def create_mesh_model(sql: str, db_name: str = "chinook") -> Tuple[str, str, str]:
    """Write a SQLMesh FULL model file wrapping ``sql``.

    Args:
        sql: SELECT statement produced by the text-to-SQL pipeline.
        db_name: Catalog the model is created in.

    Returns:
        (model file text, path of the written ``.sql`` file, fully qualified
        model name as ``catalog.model_name``).

    Raises:
        Exception: Re-raises any filesystem error after logging.
    """
    model_name = f"model_{uuid.uuid4().hex[:8]}"

    # Use catalog.model_name format so SQLMesh places the model correctly.
    full_model_name = f"{db_name}.{model_name}"

    MODEL_HEADER = f"""MODEL (
    name {full_model_name},
    kind FULL
);
"""
    try:
        model_dir = Path("models/")
        model_dir.mkdir(parents=True, exist_ok=True)

        model_path = model_dir / f"{model_name}.sql"
        # Strip the fully qualified prefix of the selected catalog.  This was
        # previously hard-coded to "chinook.main." which silently broke any
        # other db_name; SQLMesh resolves unqualified references itself.
        model_text = MODEL_HEADER + "\n" + sql.replace(f"{db_name}.main.", "")
        model_path.write_text(model_text)

        return model_text, str(model_path), full_model_name
    except Exception as e:
        logger.error(f"Error creating SQL Mesh model: {e}")
        raise
140
+
141
+
142
def create_pandera_schema(
    sql: str, user_instruction: str, model_name: str
) -> Tuple[str, str]:
    """Generate a Pandera schema for ``sql`` and write a standalone pytest file.

    The generated file is assembled from four parts: an import header, the
    LLM-produced Pandera ``DataFrameModel`` class, SQLMesh fixtures plus a
    backfill smoke test, and a ``test_schema`` test that validates the model
    output against the schema.

    Args:
        sql: SQL query the schema should validate.
        user_instruction: Natural-language data-quality requirements.
        model_name: Fully qualified SQLMesh model name to test.

    Returns:
        (schema source text, path of the generated test file).

    Raises:
        Exception: Re-raises any generation/IO error after logging.
    """
    # Top-of-file imports for the generated test module; kept unindented so
    # they are valid at module level in the emitted file.
    SCRIPT_HEADER = """
import pandas as pd
import pandera.pandas as pa
from pandera.typing import *

import pytest
from sqlmesh import Context
from datetime import date
from pathlib import Path
import shutil
import duckdb

"""

    # Session-scoped SQLMesh context plus a backfill smoke test that plans,
    # runs, and samples the model for today's date.
    MESH_STR = f"""
@pytest.fixture(scope="session")
def mesh_context():

    context = Context(paths=".", gateway="duckdb")
    yield context

@pytest.fixture
def today_str():
    return date.today().isoformat()

def test_back_fill(mesh_context, today_str):
    mesh_context.plan(skip_backfill=False, auto_apply=True)
    mesh_context.run(start=today_str, end=today_str)

    df = mesh_context.fetchdf("SELECT * FROM {model_name} LIMIT 10")
    assert not df.empty
"""
    try:
        schema = pipe.generate_pandera_schema(
            sql_query=sql, user_instruction=user_instruction
        )
        # The class-name extraction below assumes the generated source starts
        # with "class <Name>(...)" — the second whitespace token up to the
        # first "(" is the class name.  TODO confirm the LLM output always
        # matches this shape.
        test_schema = f"""

def test_schema(mesh_context, today_str):
    df = mesh_context.evaluate(
        "{model_name}",
        start=today_str,
        end=today_str,
        execution_time=today_str,
    )
    {schema.split()[1].split("(")[0].strip()}.validate(df)
"""
        print(schema)

        # delete=False: pytest must reopen the file after this block closes;
        # run_tests() is responsible for removing it afterwards.
        with tempfile.NamedTemporaryFile(
            mode="w",
            prefix="test_",
            suffix=".py",
            delete=False,
            encoding="utf-8",
        ) as f:
            f.write(SCRIPT_HEADER)
            f.write("\n\n")
            f.write(schema)
            f.write("\n\n")
            f.write(MESH_STR)
            f.write("\n\n")
            f.write(test_schema)

        file_path = Path(f.name)

        return schema, str(file_path)
    except Exception as e:
        logger.error(f"Error creating Pandera schema: {e}")
        raise
215
 
 
216
 
217
def create_test_file(
    table_name: str, db_name: str, sql_instruction: str, user_instruction: str
) -> Tuple[str, str, pd.DataFrame, str, str]:
    """Run the full generation chain: SQL -> SQLMesh model -> Pandera test file.

    Returns (test file path, model file path, result DataFrame, model text,
    schema text); re-raises after logging on any failure.
    """
    try:
        generated_sql, result_df = run_pipeline(
            table=table_name, query_input=sql_instruction
        )
        model_text, model_file, model_name = create_mesh_model(
            sql=generated_sql, db_name=db_name
        )
        schema, test_file = create_pandera_schema(
            sql=generated_sql,
            user_instruction=user_instruction,
            model_name=model_name,
        )
    except Exception as e:
        logger.error(f"Error creating test file for table {table_name}: {e}")
        raise
    return test_file, model_file, result_df, model_text, schema
232
 
 
 
233
 
234
def run_tests(
    table_name: str, db_name: str, sql_instruction: str, user_instruction: str
):
    """End-to-end Gradio callback: generate SQL, build a SQLMesh model and a
    Pandera schema, run the generated pytest file, and return the results.

    Returns:
        (pytest output text, result DataFrame, model file text, schema text)
        — always a 4-tuple so it matches the four Gradio output components.
        (The previous error path returned only 2 values, which broke the UI.)
    """
    test_file, model_file, df, model_text, schema = create_test_file(
        table_name=table_name,
        db_name=db_name,
        sql_instruction=sql_instruction,
        user_instruction=user_instruction,
    )

    capture_out = io.StringIO()
    capture_err = io.StringIO()

    old_out = sys.stdout
    old_err = sys.stderr

    sys.stdout = capture_out
    sys.stderr = capture_err

    error_message = None
    try:
        # Exit code is intentionally ignored: the captured report text is
        # what the UI displays, pass or fail.
        pytest.main(
            [
                test_file,
                "-s",
                "--tb=short",
                "--disable-warnings",
                "-o",
                "cache_dir=/tmp",
            ]
        )
    except Exception as e:
        error_message = f"Error running tests: {str(e)}"
    finally:
        # Always restore the real streams, even if pytest itself blows up,
        # and always remove the generated files so reruns don't accumulate.
        sys.stdout = old_out
        sys.stderr = old_err
        for generated in [test_file, model_file]:
            try:
                os.remove(generated)
            except FileNotFoundError:
                pass

    if error_message is not None:
        return error_message, df, model_text, schema

    output = capture_out.getvalue() + "\n" + capture_err.getvalue()

    return output, df, model_text, schema
281
+
282
+
283
# Custom CSS applied to the Gradio app: page background, logo sizing,
# fixed-size buttons, and a scrollable monospace log box (#logs).
custom_css = """
/* --- Overall container --- */
.gradio-container {
    background-color: #f0f4f8; /* light background */
    font-family: 'Arial', sans-serif;
}

/* --- Logo --- */
.logo {
    max-width: 200px;
    margin: 20px auto;
    display: block;
}

/* --- Buttons --- */
.gr-button {
    background-color: #4a90e2 !important; /* primary color */
    font-size: 14px; /* fixed font size */
    padding: 6px 12px !important; /* fixed padding */
    height: 36px !important; /* fixed height */
    min-width: 120px !important; /* fixed width */
}
.gr-button:hover {
    background-color: #3a7bc8 !important;
}

/* --- Logs Textbox --- */
#logs textarea {
    overflow-y: scroll;
    resize: none;
    height: 400px;
    width: 100%;
    font-family: monospace;
    font-size: 13px;
    line-height: 1.4;
}

/* Optional: small spacing between rows */
.gr-row {
    gap: 10px;
}
"""

# UI layout: inputs (schema/table pickers + instructions) on the left,
# tabbed results (logs / model / schema / data) below.
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css
) as demo:
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)

    gr.Markdown(
        """
<div style='text-align: center;'>
<strong style='font-size: 36px;'>SQL Test Suite</strong>
<br>
<span style='font-size: 20px;'>Automated testing and schema validation for SQL models with LLM.</span>
</div>
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Schema/table pickers; the column list is shown read-only.
            schema_dropdown = gr.Dropdown(
                choices=["chinook", "northwind"],
                value="chinook",
                label="Select Schema",
                interactive=True,
            )
            tables_dropdown = gr.Dropdown(
                choices=[], label="Available Tables", value=None, interactive=True
            )
            # columns_dropdown = gr.Dropdown(choices=[], label="Available Columns", value=None, interactive=True)
            columns_df = gr.DataFrame(label="Columns", value=[], interactive=False)
            # with gr.Row():
            # generate_result = gr.Button("Run Tests", variant="primary")

        with gr.Column(scale=3):
            with gr.Row():
                sql_instruction = gr.Textbox(
                    lines=3,
                    label="Business Metric Query (Plain English)",
                    placeholder=(
                        "Describe the business question you want to answer.\n"
                        "Example: 'Show me the average sales per month.'\n"
                        "Example: 'Total revenue by product category for last year.'"
                    ),
                )
            with gr.Row():
                user_instruction = gr.Textbox(
                    lines=5,
                    label="Define Data Quality Level",
                    placeholder=(
                        "Describe the validation rule and how strict it should be.\n"
                        "Example: Validate that the incident_zip column contains valid 5-digit ZIP codes.\n"
                    ),
                )
            with gr.Row():
                # The empty wide column pushes the button to the right edge.
                with gr.Column(scale=7):
                    pass
                with gr.Column(scale=1):
                    run_tests_btn = gr.Button("▶ Run Tests", variant="primary")

    with gr.Row():
        with gr.Column():
            with gr.Tabs():
                with gr.Tab("Test Logs"):
                    with gr.Row():
                        with gr.Column():
                            test_logs = gr.Textbox(
                                label="Test Logs",
                                lines=20,
                                max_lines=20,
                                interactive=False,
                                elem_id="logs",
                            )

                with gr.Tab("SQL Model"):
                    with gr.Row():
                        with gr.Column():
                            sql_model = gr.Textbox(
                                label="SQL Model",
                                lines=20,
                                max_lines=20,
                                interactive=False,
                                elem_id="sql_model",
                            )

                with gr.Tab("Schema"):
                    with gr.Row():
                        with gr.Column():
                            result_schema = gr.Textbox(
                                label="Validation Schema",
                                lines=20,
                                max_lines=20,
                                interactive=False,
                            )

                with gr.Tab("Data"):
                    with gr.Row():
                        with gr.Column():
                            result_data = gr.DataFrame(
                                label="Query Result",
                                value=[],
                                interactive=False,
                            )

    # Event wiring: changing the schema refreshes the tables dropdown,
    # choosing a table refreshes the column list, and the button runs the
    # full generate-model-test pipeline.
    schema_dropdown.change(
        update_table_names, inputs=schema_dropdown, outputs=tables_dropdown
    )
    tables_dropdown.change(
        update_column_names, inputs=tables_dropdown, outputs=columns_df
    )
    run_tests_btn.click(
        run_tests,
        inputs=[
            tables_dropdown,
            schema_dropdown,
            sql_instruction,
            user_instruction,
        ],
        outputs=[test_logs, result_data, sql_model, result_schema],
    )
    # Populate the tables dropdown on initial page load.
    demo.load(
        fn=update_table_names, inputs=schema_dropdown, outputs=tables_dropdown
    )

if __name__ == "__main__":
    demo.launch(debug=True)
audits/.gitkeep ADDED
File without changes
config.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gateways:
2
+ duckdb:
3
+ connection:
4
+ type: duckdb
5
+ catalogs:
6
+ local: 'database/chinook.duckdb'
7
+
8
+ default_gateway: duckdb
9
+
10
+ cache_dir: /tmp
11
+ model_defaults:
12
+ dialect: duckdb
13
+
14
+ linter:
15
+ enabled: true
16
+ rules:
17
+ - ambiguousorinvalidcolumn
18
+ - invalidselectstarexpansion
19
+ - noambiguousprojections
database/.gitkeep ADDED
File without changes
logo.png ADDED
macros/.gitkeep ADDED
File without changes
models/.gitkeep ADDED
File without changes
pytest.ini ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [pytest]
2
+ filterwarnings =
3
+ ignore::DeprecationWarning
4
+ ignore::PendingDeprecationWarning
requirements.txt ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file was autogenerated by uv via the following command:
2
+ # uv pip compile pyproject.toml -o requirements.txt
3
+ aiofiles==24.1.0
4
+ # via gradio
5
+ annotated-doc==0.0.4
6
+ # via fastapi
7
+ annotated-types==0.7.0
8
+ # via pydantic
9
+ anyio==4.12.1
10
+ # via
11
+ # gradio
12
+ # httpx
13
+ # starlette
14
+ astor==0.8.1
15
+ # via sqlmesh
16
+ asttokens==3.0.1
17
+ # via stack-data
18
+ audioop-lts==0.2.2
19
+ # via gradio
20
+ brotli==1.2.0
21
+ # via gradio
22
+ certifi==2026.1.4
23
+ # via
24
+ # httpcore
25
+ # httpx
26
+ # requests
27
+ charset-normalizer==3.4.4
28
+ # via requests
29
+ click==8.3.1
30
+ # via
31
+ # sqlmesh
32
+ # typer
33
+ # typer-slim
34
+ # uvicorn
35
+ colorama==0.4.6
36
+ # via
37
+ # click
38
+ # ipython
39
+ # pytest
40
+ # tqdm
41
+ comm==0.2.3
42
+ # via ipywidgets
43
+ croniter==6.0.0
44
+ # via sqlmesh
45
+ dateparser==1.2.1
46
+ # via sqlmesh
47
+ decorator==5.2.1
48
+ # via ipython
49
+ duckdb==1.4.3
50
+ # via sqlmesh
51
+ executing==2.2.1
52
+ # via stack-data
53
+ fastapi==0.128.0
54
+ # via gradio
55
+ ffmpy==1.0.0
56
+ # via gradio
57
+ filelock==3.20.3
58
+ # via huggingface-hub
59
+ fsspec==2026.1.0
60
+ # via
61
+ # gradio-client
62
+ # huggingface-hub
63
+ gradio==6.3.0
64
+ # via data-test-demo (pyproject.toml)
65
+ gradio-client==2.0.3
66
+ # via gradio
67
+ groovy==0.1.2
68
+ # via gradio
69
+ h11==0.16.0
70
+ # via
71
+ # httpcore
72
+ # uvicorn
73
+ hf-xet==1.2.0
74
+ # via huggingface-hub
75
+ httpcore==1.0.9
76
+ # via httpx
77
+ httpx==0.28.1
78
+ # via
79
+ # gradio
80
+ # gradio-client
81
+ # huggingface-hub
82
+ # safehttpx
83
+ huggingface-hub==1.3.2
84
+ # via
85
+ # gradio
86
+ # gradio-client
87
+ humanize==4.15.0
88
+ # via sqlmesh
89
+ hyperscript==0.3.0
90
+ # via sqlmesh
91
+ idna==3.11
92
+ # via
93
+ # anyio
94
+ # httpx
95
+ # requests
96
+ iniconfig==2.3.0
97
+ # via pytest
98
+ ipython==9.9.0
99
+ # via ipywidgets
100
+ ipython-pygments-lexers==1.1.1
101
+ # via ipython
102
+ ipywidgets==8.1.8
103
+ # via
104
+ # rich
105
+ # sqlmesh
106
+ jedi==0.19.2
107
+ # via ipython
108
+ jinja2==3.1.6
109
+ # via
110
+ # gradio
111
+ # sqlmesh
112
+ json-stream==2.4.1
113
+ # via sqlmesh
114
+ json-stream-rs-tokenizer==0.5.0
115
+ # via json-stream
116
+ jupyterlab-widgets==3.0.16
117
+ # via ipywidgets
118
+ markdown-it-py==4.0.0
119
+ # via rich
120
+ markupsafe==3.0.3
121
+ # via
122
+ # gradio
123
+ # jinja2
124
+ matplotlib-inline==0.2.1
125
+ # via ipython
126
+ mdurl==0.1.2
127
+ # via markdown-it-py
128
+ mypy-extensions==1.1.0
129
+ # via typing-inspect
130
+ numpy==2.4.1
131
+ # via
132
+ # gradio
133
+ # pandas
134
+ orjson==3.11.5
135
+ # via gradio
136
+ packaging==25.0
137
+ # via
138
+ # gradio
139
+ # gradio-client
140
+ # huggingface-hub
141
+ # pandera
142
+ # pytest
143
+ # sqlmesh
144
+ pandas==2.3.3
145
+ # via
146
+ # gradio
147
+ # sqlmesh
148
+ pandera==0.28.1
149
+ # via data-test-demo (pyproject.toml)
150
+ parso==0.8.5
151
+ # via jedi
152
+ pillow==12.1.0
153
+ # via gradio
154
+ pluggy==1.6.0
155
+ # via pytest
156
+ prompt-toolkit==3.0.52
157
+ # via ipython
158
+ pure-eval==0.2.3
159
+ # via stack-data
160
+ pydantic==2.12.5
161
+ # via
162
+ # fastapi
163
+ # gradio
164
+ # pandera
165
+ # sqlmesh
166
+ pydantic-core==2.41.5
167
+ # via pydantic
168
+ pydub==0.25.1
169
+ # via gradio
170
+ pygments==2.19.2
171
+ # via
172
+ # ipython
173
+ # ipython-pygments-lexers
174
+ # pytest
175
+ # rich
176
+ pymysql==1.1.2
177
+ # via sqlmesh
178
+ pytest==9.0.2
179
+ # via data-test-demo (pyproject.toml)
180
+ python-dateutil==2.9.0.post0
181
+ # via
182
+ # croniter
183
+ # dateparser
184
+ # pandas
185
+ python-dotenv==1.2.1
186
+ # via sqlmesh
187
+ python-multipart==0.0.21
188
+ # via gradio
189
+ pytz==2025.2
190
+ # via
191
+ # croniter
192
+ # dateparser
193
+ # pandas
194
+ pyyaml==6.0.3
195
+ # via
196
+ # gradio
197
+ # huggingface-hub
198
+ regex==2026.1.15
199
+ # via dateparser
200
+ requests==2.32.5
201
+ # via sqlmesh
202
+ rich==14.2.0
203
+ # via
204
+ # sqlmesh
205
+ # typer
206
+ ruamel-yaml==0.19.1
207
+ # via sqlmesh
208
+ safehttpx==0.1.7
209
+ # via gradio
210
+ semantic-version==2.10.0
211
+ # via gradio
212
+ shellingham==1.5.4
213
+ # via
214
+ # huggingface-hub
215
+ # typer
216
+ six==1.17.0
217
+ # via python-dateutil
218
+ sqlglot==27.28.1
219
+ # via sqlmesh
220
+ sqlglotrs==0.7.3
221
+ # via sqlglot
222
+ sqlmesh==0.228.4
223
+ # via data-test-demo (pyproject.toml)
224
+ stack-data==0.6.3
225
+ # via ipython
226
+ starlette==0.50.0
227
+ # via
228
+ # fastapi
229
+ # gradio
230
+ tenacity==9.1.2
231
+ # via sqlmesh
232
+ time-machine==3.2.0
233
+ # via sqlmesh
234
+ tomlkit==0.13.3
235
+ # via gradio
236
+ tqdm==4.67.1
237
+ # via huggingface-hub
238
+ traitlets==5.14.3
239
+ # via
240
+ # ipython
241
+ # ipywidgets
242
+ # matplotlib-inline
243
+ typeguard==4.4.4
244
+ # via pandera
245
+ typer==0.21.1
246
+ # via gradio
247
+ typer-slim==0.21.1
248
+ # via huggingface-hub
249
+ typing-extensions==4.15.0
250
+ # via
251
+ # fastapi
252
+ # gradio
253
+ # gradio-client
254
+ # huggingface-hub
255
+ # pandera
256
+ # pydantic
257
+ # pydantic-core
258
+ # typeguard
259
+ # typer
260
+ # typer-slim
261
+ # typing-inspect
262
+ # typing-inspection
263
+ typing-inspect==0.9.0
264
+ # via pandera
265
+ typing-inspection==0.4.2
266
+ # via pydantic
267
+ tzdata==2025.3
268
+ # via
269
+ # pandas
270
+ # tzlocal
271
+ tzlocal==5.3.1
272
+ # via dateparser
273
+ urllib3==2.6.3
274
+ # via requests
275
+ uvicorn==0.40.0
276
+ # via gradio
277
+ wcwidth==0.2.14
278
+ # via prompt-toolkit
279
+ widgetsnbextension==4.0.15
280
+ # via ipywidgets
seeds/.gitkeep ADDED
File without changes
src/__init__.py ADDED
File without changes
src/client.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+
5
+ from dotenv import load_dotenv
6
+ from huggingface_hub import InferenceClient
7
+ from pydantic import BaseModel
8
+
9
+ load_dotenv()
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
# Generation limits applied to every chat completion.
MAX_RESPONSE_TOKENS = 2048
TEMPERATURE = 0.9

# Candidate models/providers come from JSON arrays in the environment, e.g.
# MODEL_NAMES='["org/model-a", "org/model-b"]'.  Fail fast with a clear
# message instead of the opaque TypeError that json.loads(None) raised
# before when a variable was unset.
_models_env = os.getenv("MODEL_NAMES")
_providers_env = os.getenv("PROVIDERS")
if _models_env is None or _providers_env is None:
    raise RuntimeError(
        "MODEL_NAMES and PROVIDERS environment variables must be set to JSON lists."
    )
models = json.loads(_models_env)
providers = json.loads(_providers_env)
18
+
19
+
20
def _engine_working(engine: InferenceClient) -> bool:
    """Probe the inference endpoint with a 1-token request; True on success."""
    probe = [{"role": "user", "content": "ping"}]
    try:
        engine.chat_completion(probe, max_tokens=1)
    except Exception as e:
        logger.exception(f"Engine is not working: {e}")
        return False
    logger.info("Engine is Working.")
    return True
28
+
29
+
30
def _load_llm_client() -> InferenceClient:
    """Return a client for the first (model, provider) pair whose endpoint responds.

    Walks the configured model/provider grid in order, skipping non-string
    model entries, and returns an ``InferenceClient`` for the first working
    combination.  Raises ``Exception`` listing the collected errors when no
    combination is available.
    """
    logger.warning("Loading Model...")
    errors = []
    for model in models:
        for provider in providers:
            if not isinstance(model, str):
                continue
            try:
                logger.info(f"Checking model: {model} provider: {provider}")
                candidate = InferenceClient(
                    model=model,
                    timeout=15,
                    provider=provider,
                )
                if _engine_working(candidate):
                    logger.info(
                        f"The model is loaded : {model} , provider: {provider}"
                    )
                    return candidate
            except Exception as e:
                logger.error(
                    f"Error loading model {model} provider {provider}: {e}"
                )
                errors.append(str(e))
    raise Exception(f"Unable to load any provided model: {errors}.")
60
+
61
+
62
# Eagerly resolve a working client at import time; a failure here aborts
# startup rather than surfacing mid-request.
_default_client = _load_llm_client()
63
+
64
+
65
class LLMChain:
    """Thin wrapper around ``InferenceClient.chat_completion`` with optional
    JSON-schema-constrained output and cumulative token accounting."""

    def __init__(self, client: InferenceClient = _default_client):
        # Note: the default client is resolved once at module import time.
        self.client = client
        # Running total of tokens consumed across all run() calls.
        self.total_tokens = 0

    def run(
        self,
        system_prompt: str | None = None,
        user_prompt: str | None = None,
        messages: list[dict] | None = None,
        format_name: str | None = None,
        response_format: type[BaseModel] | None = None,
    ) -> str | dict[str, str | int | float | None] | list[str] | None:
        """Execute one chat completion.

        Either both ``system_prompt`` and ``user_prompt`` or a prebuilt
        ``messages`` list must be supplied.  When ``format_name`` and
        ``response_format`` (a pydantic model) are given, the completion is
        constrained to that JSON schema and the parsed field value(s) are
        returned; otherwise the raw text content is returned.  Returns None
        on any error (errors are logged, not raised).
        """
        try:
            if system_prompt and user_prompt:
                # Explicit prompts take precedence over any messages arg.
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ]
            elif not messages:
                raise ValueError(
                    "Either system_prompt and user_prompt or messages must be provided."
                )

            llm_response = self.client.chat_completion(
                messages=messages,
                max_tokens=MAX_RESPONSE_TOKENS,
                temperature=TEMPERATURE,
                # Constrain output to the pydantic model's JSON schema only
                # when both a name and a model are supplied.
                response_format=(
                    {
                        "type": "json_schema",
                        "json_schema": {
                            "name": format_name,
                            "schema": response_format.model_json_schema(),
                            "strict": True,
                        },
                    }
                    if format_name and response_format
                    else None
                ),
            )
            self.total_tokens += llm_response.usage.total_tokens
            analysis = llm_response.choices[0].message.content
            if response_format:
                analysis = json.loads(analysis)
                fields = list(response_format.model_fields.keys())
                # Single-field models are unwrapped to the bare value;
                # multi-field models come back as a dict.
                if len(fields) == 1:
                    return analysis.get(fields[0])
                return {field: analysis.get(field) for field in fields}

            return analysis

        except Exception as e:
            # Deliberate best-effort contract: callers must handle None.
            logger.error(f"Error during LLM calls: {e}")
            return None
src/models.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+
3
+
4
class SQLQueryModel(BaseModel):
    """Structured LLM output: a SQL query plus a short explanation."""

    sql_query: str = Field(..., description="SQL query to execute.")
    explanation: str = Field(..., description="Short explanation of the SQL query.")
7
+
8
+
9
class PanderaSchemaModel(BaseModel):
    """Structured LLM output: Pandera schema source code for validating data."""

    schema_name: str = Field(
        ..., description="Only Pandera schema to validate the data."
    )
src/pipelines.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import logging
3
+ import os
4
+
5
+ import pandas as pd
6
+ from dotenv import load_dotenv
7
+ from duckdb import DuckDBPyConnection
8
+
9
+ from src.models import PanderaSchemaModel, SQLQueryModel
10
+
11
load_dotenv()

logger = logging.getLogger(__name__)


# Number of extra attempts after the first SQL generation (env-overridable).
SQL_GENERATION_RETRIES = int(os.getenv("SQL_GENERATION_RETRIES", "5"))
# Prompt templates are injected via environment variables (.env); each is
# None when unset — TODO confirm callers always run with a populated .env,
# since .format()/chain calls on None would fail at request time.
PANDERA_PROMPT = os.getenv("PANDERA_PROMPT")
PANDERA_USER_PROMPT = os.getenv("PANDERA_USER_PROMPT")
SQL_PROMPT = os.getenv("SQL_PROMPT")
USER_PROMPT = os.getenv("USER_PROMPT")
21
+
22
+
23
class Query2Schema:
    """Text-to-SQL + validation-schema pipeline backed by DuckDB and an LLM chain."""

    def __init__(
        self,
        duckdb: DuckDBPyConnection,
        chain,
    ) -> None:
        # chain must expose run(system_prompt=, user_prompt=, format_name=,
        # response_format=) — see LLMChain in src/client.py.
        self._duckdb = duckdb
        self.chain = chain

    def generate_sql(
        self, user_question: str, context: str, errors: str | None = None
    ) -> str | dict[str, str | int | float | None] | list[str] | None:
        """Generate SQL + description.

        When ``errors`` is supplied (retry path), the accumulated error text
        is appended to the prompt so the model can correct its last attempt.
        Returns whatever the chain produced (dict/str) or None on failure.
        """
        user_prompt_formatted = USER_PROMPT.format(
            question=user_question, context=context
        )
        if errors:
            user_prompt_formatted += f"Carefully review the previous error or\
            exception and rewrite the SQL so that the error does not occur again.\
            Try a different approach or rewrite SQL if needed. Last error: {errors}"

        sql = self.chain.run(
            system_prompt=SQL_PROMPT,
            user_prompt=user_prompt_formatted,
            format_name="sql_query",
            response_format=SQLQueryModel,
        )
        logger.info(f"SQL Generated Successfully: {sql}")
        return sql

    def run_query(self, sql_query: str) -> pd.DataFrame | None:
        """Execute SQL and return the result as a DataFrame."""
        logger.info("Query Execution Started.")
        return self._duckdb.query(sql_query).df()

    def try_sql_with_retries(
        self,
        user_question: str,
        context: str,
        max_retries: int = SQL_GENERATION_RETRIES,
    ) -> tuple[
        str | dict[str, str | int | float | None] | list[str] | None,
        pd.DataFrame | None,
    ]:
        """Try SQL generation + execution with retries.

        The first attempt is not counted as a retry, hence the ``+ 2`` bound.
        Returns (sql, dataframe) on the first attempt that executes and
        yields a non-empty result, or (None, None) when generation fails or
        every attempt errors out.
        """
        last_error = None
        all_errors = ""

        for attempt in range(1, max_retries + 2):
            try:
                # The two branches previously duplicated the generate/check
                # logic; only the error feedback differs between them.
                if attempt > 1 and last_error:
                    logger.info(f"Retrying: {attempt - 1}")
                    sql = self.generate_sql(user_question, context, errors=all_errors)
                else:
                    sql = self.generate_sql(user_question, context)
                if not sql:
                    return None, None

                # Try executing query
                sql_query_str = sql.get("sql_query") if isinstance(sql, dict) else sql
                if not isinstance(sql_query_str, str):
                    raise ValueError(
                        f"Expected SQL query to be a string, got {type(sql_query_str).__name__}"
                    )
                query_df = self.run_query(sql_query_str)

                # Stop retrying as soon as execution succeeds with data.
                if query_df is not None and not query_df.empty:
                    return sql, query_df

            except Exception as e:
                # Attempt 0 is the initial (non-retry) pass.  Bracket typo
                # fixed: the message previously opened with no "[".
                last_error = f"\n[Attempt {attempt - 1}] {type(e).__name__}: {e}"
                logger.error(f"Error during SQL generation or execution: {last_error}")
                all_errors += last_error

        logger.error(f"Failed after {max_retries} attempts. Last error: {all_errors}")
        return None, None

    def generate_pandera_schema(self, sql_query: str, user_instruction: str) -> str:
        """Generate Pandera schema source for ``sql_query``.

        Only the class definitions are kept from the chain output (any
        surrounding prose/imports the LLM emits are stripped via the AST).
        """
        class_lines = []

        schema_str = self.chain.run(
            system_prompt=PANDERA_PROMPT,
            user_prompt=PANDERA_USER_PROMPT.format(
                sql_query=sql_query, instructions=user_instruction
            ),
            format_name="pandera_schema",
            response_format=PanderaSchemaModel,
        )

        # chain.run returns None on failure; ast.parse(None) would raise an
        # opaque TypeError, so fail with an explicit message instead.
        if not isinstance(schema_str, str):
            raise ValueError("Pandera schema generation returned no schema text.")

        parsed = ast.parse(schema_str)

        original_lines = schema_str.splitlines()
        for node in parsed.body:
            if isinstance(node, ast.ClassDef):
                start, end = node.lineno - 1, node.end_lineno
                class_lines.extend(original_lines[start:end])

        return "\n".join(class_lines)
src/prompts.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# User-turn template for the Text-to-SQL chain: carries the natural-language
# question plus the table DDL/context gathered upstream. Placeholders are
# filled via str.format(question=..., context=...).
USER_PROMPT = """User's Text Question:
{question}

Provided table context information:
{context}"""
6
+
7
+
8
# System prompt for SQL generation: constrains the model to a single read-only
# DuckDB SELECT and tells it how to react to a "Last Error" fed back in by the
# retry loop in src/pipelines.py.
SQL_PROMPT = """You are an expert Text-to-SQL assistant. Convert the user's natural-language request into a single, read-only, syntactically valid DuckDB SQL SELECT statement that runs against the provided schema (the schema will be supplied as CREATE TABLE DDL). Use the exact table and column names from the schema.

Return two things:
1. The SQL statement.
2. A short natural-language description (1-2 sentences) of what the query returns.

Rules:
1. Output MUST be a single SELECT query. JOINs, subqueries, aggregations, GROUP BY, ORDER BY, and LIMIT are allowed.
2. Do NOT generate any DML/DDL (INSERT, UPDATE, DELETE, DROP, etc.) or non-read operations.
3. Use DuckDB SQL functions and syntax. For date/time grouping, use DATE_TRUNC('unit', column) (e.g., 'month', 'day', 'year').
4. Prefer explicit column lists. Use SELECT * only if the user explicitly requests all columns.
5. Make the query robust and maintainable, so it can be reused or adapted for similar analyses.
6. After execution in the downstream pipeline, if an error occurs (available as `Last Error` with a short description), analyze that error and rewrite the SQL to resolve it while preserving the user's intent. The rewritten query must still be valid DuckDB SQL.
"""
22
+
23
# System prompt for schema generation: asks the model for a single Pandera
# class (bare Python, no prose) so generate_pandera_schema() can ast-parse the
# reply and extract the ClassDef. The embedded example doubles as the expected
# output shape.
PANDERA_PROMPT = """You are provided with a SQL query which is used to fetch data from a database. Your task is to generate a valid Pandera SchemaModel class that can be used to validate the resulting data from the query.
The generated schema should be **general and simple**, not overly complex. Only validate basic aspects like column types, nullability, and simple value constraints (like positive integers, string patterns, or ranges) since you only have the SQL query and the resulting column names/types.

Follow these guidelines:
1. **Use Pandera SchemaModel**:
   - Each column should have a type hint using `Series[Type]`.
   - Use `pa.Field` to define simple validations.
2. **Validation rules should be simple and reasonable**:
   - `nullable` for optional columns
   - `unique` for IDs if obvious
   - `gt`/`ge`/`lt`/`le` for numeric ranges if reasonable
   - `str_matches`, `str_length` or `str_contains` for string patterns (like ZIP codes or emails)
   - Avoid complex cross-column or statistical checks
3. **Add Config class**:
   - Set `coerce = True` to cast data types automatically
4. **Add optional metadata**:
   - Include `description` for columns if possible
   - Include `title` for columns if it helps
5. **Output only valid Python code**:
   - The output should be a **single Python class definition**.
   - Do not include any explanations, comments, or extra text.
6. **Example Output**:


import pandas as pd
import pandera as pa
from pandera.typing import Series

class CustomerSchema(pa.DataFrameModel):
    customer_id: Series[int] = pa.Field(gt=0, unique=True, nullable=False, description="Unique customer identifier")
    first_name: Series[str] = pa.Field(nullable=False, str_length=(1, 50), description="Customer first name")
    last_name: Series[str] = pa.Field(nullable=False, str_length=(1, 50), description="Customer last name")
    email: Series[str] = pa.Field(nullable=False, str_matches=r"^[\\w\\.-]+@[\\w\\.-]+\\.\\w+$", description="Customer email address")
    age: Series[int] = pa.Field(ge=0, le=120, nullable=True, description="Customer age in years")

    class Config:
        coerce = True

Additional notes:
If the SQL query uses JOIN, only include columns that appear in the SELECT statement.
You may infer basic constraints from column names (e.g., columns ending with _id are likely unique integers).
Avoid domain-specific logic unless it is obvious from the column names or SQL query.
Keep the schema robust but simple, suitable for automated ETL validation."""
66
+
67
# User-turn template for schema generation: supplies the SQL that will produce
# the DataFrame plus the user's extra validation instructions. Filled via
# str.format(sql_query=..., instructions=...).
PANDERA_USER_PROMPT = """SQL Query:
{sql_query}

User Instructions:
{instructions}"""