Spaces:
Sleeping
Sleeping
Muhammad Mustehson
commited on
Commit
·
1a436de
1
Parent(s):
9f8e201
initial Draft
Browse files- .gitignore +18 -0
- app.py +430 -52
- audits/.gitkeep +0 -0
- config.yaml +19 -0
- database/.gitkeep +0 -0
- logo.png +0 -0
- macros/.gitkeep +0 -0
- models/.gitkeep +0 -0
- pytest.ini +4 -0
- requirements.txt +280 -0
- seeds/.gitkeep +0 -0
- src/__init__.py +0 -0
- src/client.py +119 -0
- src/models.py +12 -0
- src/pipelines.py +128 -0
- src/prompts.py +71 -0
.gitignore
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python-generated files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[oc]
|
| 4 |
+
build/
|
| 5 |
+
dist/
|
| 6 |
+
wheels/
|
| 7 |
+
*.egg-info
|
| 8 |
+
|
| 9 |
+
# Virtual environments
|
| 10 |
+
.venv
|
| 11 |
+
.cache
|
| 12 |
+
.pytest_cache
|
| 13 |
+
.env
|
| 14 |
+
pyproject.toml
|
| 15 |
+
uv.lock
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
*.duckdb
|
app.py
CHANGED
|
@@ -1,70 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
| 3 |
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
"""
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
"""
|
| 17 |
-
|
| 18 |
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
-
|
| 22 |
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
response = ""
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
response += token
|
| 40 |
-
yield response
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
"""
|
| 44 |
-
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
|
| 45 |
-
"""
|
| 46 |
-
chatbot = gr.ChatInterface(
|
| 47 |
-
respond,
|
| 48 |
-
type="messages",
|
| 49 |
-
additional_inputs=[
|
| 50 |
-
gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
|
| 51 |
-
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
| 52 |
-
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
| 53 |
-
gr.Slider(
|
| 54 |
-
minimum=0.1,
|
| 55 |
-
maximum=1.0,
|
| 56 |
-
value=0.95,
|
| 57 |
-
step=0.05,
|
| 58 |
-
label="Top-p (nucleus sampling)",
|
| 59 |
-
),
|
| 60 |
-
],
|
| 61 |
-
)
|
| 62 |
|
| 63 |
-
with gr.Blocks(
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
if __name__ == "__main__":
|
| 70 |
-
demo.launch()
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import shutil
|
| 5 |
+
import sys
|
| 6 |
+
import tempfile
|
| 7 |
+
import uuid
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Tuple
|
| 10 |
+
|
| 11 |
+
import duckdb
|
| 12 |
import gradio as gr
|
| 13 |
+
import pandas as pd
|
| 14 |
+
import pytest
|
| 15 |
+
import requests
|
| 16 |
+
from dotenv import load_dotenv
|
| 17 |
|
| 18 |
+
from src.client import LLMChain
|
| 19 |
+
from src.pipelines import Query2Schema
|
| 20 |
|
| 21 |
+
load_dotenv()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
LEVEL = "INFO" if not os.getenv("ENV") == "PROD" else "WARNING"
|
| 25 |
+
logging.basicConfig(
|
| 26 |
+
level=getattr(logging, LEVEL, logging.INFO),
|
| 27 |
+
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
| 28 |
+
)
|
| 29 |
+
logger = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
if not Path("/tmp").exists():
|
| 32 |
+
os.mkdir("/tmp")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def create_conn(url: str, save_path: str):
|
| 36 |
+
try:
|
| 37 |
+
response = requests.get(url, stream=True)
|
| 38 |
+
response.raise_for_status()
|
| 39 |
+
|
| 40 |
+
with open(save_path, "wb") as out_file:
|
| 41 |
+
shutil.copyfileobj(response.raw, out_file)
|
| 42 |
+
|
| 43 |
+
return duckdb.connect(database=save_path)
|
| 44 |
+
except Exception as e:
|
| 45 |
+
logger.error(f"Error downloading database: {e}")
|
| 46 |
+
raise
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if not Path("database/chinook.duckdb").exists():
|
| 50 |
+
conn = create_conn(
|
| 51 |
+
url="https://raw.githubusercontent.com/RandomFractals/duckdb-sql-tools/main/data/chinook/duckdb/chinook.duckdb",
|
| 52 |
+
save_path="database/chinook.duckdb",
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
pipe = Query2Schema(duckdb=conn, chain=LLMChain())
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def get_tables_names(schema_name):
|
| 59 |
+
tables = conn.execute("SELECT table_name FROM information_schema.tables").fetchall()
|
| 60 |
+
return [table[0] for table in tables]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def update_table_names(schema_name):
|
| 64 |
+
tables = get_tables_names(schema_name)
|
| 65 |
+
return gr.update(choices=tables, value=tables[0] if tables else None)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def update_column_names(table_name):
|
| 69 |
+
columns = conn.execute(
|
| 70 |
+
f"SELECT column_name FROM information_schema.columns WHERE table_name = '{table_name}' "
|
| 71 |
+
).fetchall()
|
| 72 |
+
columns = [column[0] for column in columns]
|
| 73 |
+
df = pd.DataFrame(columns, columns=["Column Names"])
|
| 74 |
+
# return gr.update(
|
| 75 |
+
# choices=columns,
|
| 76 |
+
# value=columns[0] if columns else None
|
| 77 |
+
# )
|
| 78 |
+
return df
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_ddl(table: str) -> str:
|
| 82 |
+
result = conn.sql(
|
| 83 |
+
f"SELECT sql, database_name, schema_name FROM duckdb_tables() where table_name ='{table}';"
|
| 84 |
+
).df()
|
| 85 |
+
ddl_create = result.iloc[0, 0]
|
| 86 |
+
parent_database = result.iloc[0, 1]
|
| 87 |
+
schema_name = result.iloc[0, 2]
|
| 88 |
+
full_path = f"{parent_database}.{schema_name}.{table}"
|
| 89 |
+
if schema_name != "main":
|
| 90 |
+
old_path = f"{schema_name}.{table}"
|
| 91 |
+
else:
|
| 92 |
+
old_path = table
|
| 93 |
+
ddl_create = ddl_create.replace(old_path, full_path)
|
| 94 |
+
return ddl_create
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def run_pipeline(table: str, query_input: str) -> Tuple[str, pd.DataFrame]:
|
| 98 |
+
try:
|
| 99 |
+
schema = get_ddl(table=table)
|
| 100 |
+
except Exception as e:
|
| 101 |
+
logger.error(f"Failed to fetch DDL for table {table}: {e}")
|
| 102 |
+
raise
|
| 103 |
+
try:
|
| 104 |
+
sql, df = pipe.try_sql_with_retries(
|
| 105 |
+
user_question=query_input,
|
| 106 |
+
context=schema,
|
| 107 |
+
)
|
| 108 |
+
sql = sql.get("sql_query") if isinstance(sql, dict) else sql
|
| 109 |
+
if not sql:
|
| 110 |
+
raise ValueError("SQL generation returned None")
|
| 111 |
+
return sql, df
|
| 112 |
+
except Exception as e:
|
| 113 |
+
logger.error(f"Error generating SQL for table {table}: {e}")
|
| 114 |
+
raise
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def create_mesh_model(sql: str, db_name: str = "chinook") -> Tuple[str, str, str]:
|
| 118 |
+
model_name = f"model_{uuid.uuid4().hex[:8]}"
|
| 119 |
+
|
| 120 |
+
# Use catalog.schema.model_name format
|
| 121 |
+
full_model_name = f"{db_name}.{model_name}"
|
| 122 |
+
|
| 123 |
+
MODEL_HEADER = f"""MODEL (
|
| 124 |
+
name {full_model_name},
|
| 125 |
+
kind FULL
|
| 126 |
+
);
|
| 127 |
+
"""
|
| 128 |
+
try:
|
| 129 |
+
model_dir = Path("models/")
|
| 130 |
+
model_dir.mkdir(parents=True, exist_ok=True)
|
| 131 |
+
|
| 132 |
+
model_path = model_dir / f"{model_name}.sql"
|
| 133 |
+
model_text = MODEL_HEADER + "\n" + sql.replace("chinook.main.", "")
|
| 134 |
+
model_path.write_text(model_text)
|
| 135 |
+
|
| 136 |
+
return model_text, str(model_path), full_model_name
|
| 137 |
+
except Exception as e:
|
| 138 |
+
logger.error(f"Error creating SQL Mesh model: {e}")
|
| 139 |
+
raise
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def create_pandera_schema(
|
| 143 |
+
sql: str, user_instruction: str, model_name: str
|
| 144 |
+
) -> Tuple[str, str]:
|
| 145 |
+
SCRIPT_HEADER = """
|
| 146 |
+
import pandas as pd
|
| 147 |
+
import pandera.pandas as pa
|
| 148 |
+
from pandera.typing import *
|
| 149 |
+
|
| 150 |
+
import pytest
|
| 151 |
+
from sqlmesh import Context
|
| 152 |
+
from datetime import date
|
| 153 |
+
from pathlib import Path
|
| 154 |
+
import shutil
|
| 155 |
+
import duckdb
|
| 156 |
+
|
| 157 |
+
"""
|
| 158 |
+
|
| 159 |
+
MESH_STR = f"""
|
| 160 |
+
@pytest.fixture(scope="session")
|
| 161 |
+
def mesh_context():
|
| 162 |
+
|
| 163 |
+
context = Context(paths=".", gateway="duckdb")
|
| 164 |
+
yield context
|
| 165 |
+
|
| 166 |
+
@pytest.fixture
|
| 167 |
+
def today_str():
|
| 168 |
+
return date.today().isoformat()
|
| 169 |
+
|
| 170 |
+
def test_back_fill(mesh_context, today_str):
|
| 171 |
+
mesh_context.plan(skip_backfill=False, auto_apply=True)
|
| 172 |
+
mesh_context.run(start=today_str, end=today_str)
|
| 173 |
+
|
| 174 |
+
df = mesh_context.fetchdf("SELECT * FROM {model_name} LIMIT 10")
|
| 175 |
+
assert not df.empty
|
| 176 |
"""
|
| 177 |
+
try:
|
| 178 |
+
schema = pipe.generate_pandera_schema(
|
| 179 |
+
sql_query=sql, user_instruction=user_instruction
|
| 180 |
+
)
|
| 181 |
+
test_schema = f"""
|
| 182 |
+
|
| 183 |
+
def test_schema(mesh_context, today_str):
|
| 184 |
+
df = mesh_context.evaluate(
|
| 185 |
+
"{model_name}",
|
| 186 |
+
start=today_str,
|
| 187 |
+
end=today_str,
|
| 188 |
+
execution_time=today_str,
|
| 189 |
+
)
|
| 190 |
+
{schema.split()[1].split("(")[0].strip()}.validate(df)
|
| 191 |
"""
|
| 192 |
+
print(schema)
|
| 193 |
|
| 194 |
+
with tempfile.NamedTemporaryFile(
|
| 195 |
+
mode="w",
|
| 196 |
+
prefix="test_",
|
| 197 |
+
suffix=".py",
|
| 198 |
+
delete=False,
|
| 199 |
+
encoding="utf-8",
|
| 200 |
+
) as f:
|
| 201 |
+
f.write(SCRIPT_HEADER)
|
| 202 |
+
f.write("\n\n")
|
| 203 |
+
f.write(schema)
|
| 204 |
+
f.write("\n\n")
|
| 205 |
+
f.write(MESH_STR)
|
| 206 |
+
f.write("\n\n")
|
| 207 |
+
f.write(test_schema)
|
| 208 |
|
| 209 |
+
file_path = Path(f.name)
|
| 210 |
|
| 211 |
+
return schema, str(file_path)
|
| 212 |
+
except Exception as e:
|
| 213 |
+
logger.error(f"Error creating Pandera schema: {e}")
|
| 214 |
+
raise
|
| 215 |
|
|
|
|
| 216 |
|
| 217 |
+
def create_test_file(
|
| 218 |
+
table_name: str, db_name: str, sql_instruction: str, user_instruction: str
|
| 219 |
+
) -> Tuple[str, str, pd.DataFrame, str, str]:
|
| 220 |
+
try:
|
| 221 |
+
sql, df = run_pipeline(table=table_name, query_input=sql_instruction)
|
| 222 |
+
model_text, model_file, model_name = create_mesh_model(sql=sql, db_name=db_name)
|
| 223 |
+
schema, test_file = create_pandera_schema(
|
| 224 |
+
sql=sql,
|
| 225 |
+
user_instruction=user_instruction,
|
| 226 |
+
model_name=model_name,
|
| 227 |
+
)
|
| 228 |
+
return test_file, model_file, df, model_text, schema
|
| 229 |
+
except Exception as e:
|
| 230 |
+
logger.error(f"Error creating test file for table {table_name}: {e}")
|
| 231 |
+
raise
|
| 232 |
|
|
|
|
|
|
|
| 233 |
|
| 234 |
+
def run_tests(
|
| 235 |
+
table_name: str, db_name: str, sql_instruction: str, user_instruction: str
|
| 236 |
+
):
|
| 237 |
+
test_file, model_file, df, model_text, schema = create_test_file(
|
| 238 |
+
table_name=table_name,
|
| 239 |
+
db_name=db_name,
|
| 240 |
+
sql_instruction=sql_instruction,
|
| 241 |
+
user_instruction=user_instruction,
|
| 242 |
+
)
|
| 243 |
|
| 244 |
+
capture_out = io.StringIO()
|
| 245 |
+
capture_err = io.StringIO()
|
| 246 |
+
|
| 247 |
+
old_out = sys.stdout
|
| 248 |
+
old_err = sys.stderr
|
| 249 |
+
|
| 250 |
+
sys.stdout = capture_out
|
| 251 |
+
sys.stderr = capture_err
|
| 252 |
+
|
| 253 |
+
try:
|
| 254 |
+
retcode = pytest.main(
|
| 255 |
+
[
|
| 256 |
+
test_file,
|
| 257 |
+
"-s",
|
| 258 |
+
"--tb=short",
|
| 259 |
+
"--disable-warnings",
|
| 260 |
+
"-o",
|
| 261 |
+
"cache_dir=/tmp",
|
| 262 |
+
]
|
| 263 |
+
)
|
| 264 |
+
except Exception as e:
|
| 265 |
+
sys.stdout = old_out
|
| 266 |
+
sys.stderr = old_err
|
| 267 |
+
return f"Error running tests: {str(e)}", ""
|
| 268 |
+
|
| 269 |
+
sys.stdout = old_out
|
| 270 |
+
sys.stderr = old_err
|
| 271 |
+
|
| 272 |
+
output = capture_out.getvalue() + "\n" + capture_err.getvalue()
|
| 273 |
+
|
| 274 |
+
for f in [test_file, model_file]:
|
| 275 |
+
try:
|
| 276 |
+
os.remove(f)
|
| 277 |
+
except FileNotFoundError:
|
| 278 |
+
pass
|
| 279 |
+
|
| 280 |
+
return output, df, model_text, schema
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
custom_css = """
|
| 284 |
+
/* --- Overall container --- */
|
| 285 |
+
.gradio-container {
|
| 286 |
+
background-color: #f0f4f8; /* light background */
|
| 287 |
+
font-family: 'Arial', sans-serif;
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
/* --- Logo --- */
|
| 291 |
+
.logo {
|
| 292 |
+
max-width: 200px;
|
| 293 |
+
margin: 20px auto;
|
| 294 |
+
display: block;
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
/* --- Buttons --- */
|
| 298 |
+
.gr-button {
|
| 299 |
+
background-color: #4a90e2 !important; /* primary color */
|
| 300 |
+
font-size: 14px; /* fixed font size */
|
| 301 |
+
padding: 6px 12px !important; /* fixed padding */
|
| 302 |
+
height: 36px !important; /* fixed height */
|
| 303 |
+
min-width: 120px !important; /* fixed width */
|
| 304 |
+
}
|
| 305 |
+
.gr-button:hover {
|
| 306 |
+
background-color: #3a7bc8 !important;
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
/* --- Logs Textbox --- */
|
| 310 |
+
#logs textarea {
|
| 311 |
+
overflow-y: scroll;
|
| 312 |
+
resize: none;
|
| 313 |
+
height: 400px;
|
| 314 |
+
width: 100%;
|
| 315 |
+
font-family: monospace;
|
| 316 |
+
font-size: 13px;
|
| 317 |
+
line-height: 1.4;
|
| 318 |
+
}
|
| 319 |
+
|
| 320 |
+
/* Optional: small spacing between rows */
|
| 321 |
+
.gr-row {
|
| 322 |
+
gap: 10px;
|
| 323 |
+
}
|
| 324 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
|
| 326 |
+
with gr.Blocks(
|
| 327 |
+
theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css
|
| 328 |
+
) as demo:
|
| 329 |
+
gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
|
| 330 |
+
|
| 331 |
+
gr.Markdown(
|
| 332 |
+
"""
|
| 333 |
+
<div style='text-align: center;'>
|
| 334 |
+
<strong style='font-size: 36px;'>SQL Test Suite</strong>
|
| 335 |
+
<br>
|
| 336 |
+
<span style='font-size: 20px;'>Automated testing and schema validation for SQL models with LLM.</span>
|
| 337 |
+
</div>
|
| 338 |
+
"""
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
with gr.Row():
|
| 342 |
+
with gr.Column(scale=1):
|
| 343 |
+
schema_dropdown = gr.Dropdown(
|
| 344 |
+
choices=["chinook", "northwind"],
|
| 345 |
+
value="chinook",
|
| 346 |
+
label="Select Schema",
|
| 347 |
+
interactive=True,
|
| 348 |
+
)
|
| 349 |
+
tables_dropdown = gr.Dropdown(
|
| 350 |
+
choices=[], label="Available Tables", value=None, interactive=True
|
| 351 |
+
)
|
| 352 |
+
# columns_dropdown = gr.Dropdown(choices=[], label="Available Columns", value=None, interactive=True)
|
| 353 |
+
columns_df = gr.DataFrame(label="Columns", value=[], interactive=False)
|
| 354 |
+
# with gr.Row():
|
| 355 |
+
# generate_result = gr.Button("Run Tests", variant="primary")
|
| 356 |
+
|
| 357 |
+
with gr.Column(scale=3):
|
| 358 |
+
with gr.Row():
|
| 359 |
+
sql_instruction = gr.Textbox(
|
| 360 |
+
lines=3,
|
| 361 |
+
label="Business Metric Query (Plain English)",
|
| 362 |
+
placeholder=(
|
| 363 |
+
"Describe the business question you want to answer.\n"
|
| 364 |
+
"Example: 'Show me the average sales per month.'\n"
|
| 365 |
+
"Example: 'Total revenue by product category for last year.'"
|
| 366 |
+
),
|
| 367 |
+
)
|
| 368 |
+
with gr.Row():
|
| 369 |
+
user_instruction = gr.Textbox(
|
| 370 |
+
lines=5,
|
| 371 |
+
label="Define Data Quality Level",
|
| 372 |
+
placeholder=(
|
| 373 |
+
"Describe the validation rule and how strict it should be.\n"
|
| 374 |
+
"Example: Validate that the incident_zip column contains valid 5-digit ZIP codes.\n"
|
| 375 |
+
),
|
| 376 |
+
)
|
| 377 |
+
with gr.Row():
|
| 378 |
+
with gr.Column(scale=7):
|
| 379 |
+
pass
|
| 380 |
+
with gr.Column(scale=1):
|
| 381 |
+
run_tests_btn = gr.Button("▶ Run Tests", variant="primary")
|
| 382 |
+
|
| 383 |
+
with gr.Row():
|
| 384 |
+
with gr.Column():
|
| 385 |
+
with gr.Tabs():
|
| 386 |
+
with gr.Tab("Test Logs"):
|
| 387 |
+
with gr.Row():
|
| 388 |
+
with gr.Column():
|
| 389 |
+
test_logs = gr.Textbox(
|
| 390 |
+
label="Test Logs",
|
| 391 |
+
lines=20,
|
| 392 |
+
max_lines=20,
|
| 393 |
+
interactive=False,
|
| 394 |
+
elem_id="logs",
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
with gr.Tab("SQL Model"):
|
| 398 |
+
with gr.Row():
|
| 399 |
+
with gr.Column():
|
| 400 |
+
sql_model = gr.Textbox(
|
| 401 |
+
label="SQL Model",
|
| 402 |
+
lines=20,
|
| 403 |
+
max_lines=20,
|
| 404 |
+
interactive=False,
|
| 405 |
+
elem_id="sql_model",
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
with gr.Tab("Schema"):
|
| 409 |
+
with gr.Row():
|
| 410 |
+
with gr.Column():
|
| 411 |
+
result_schema = gr.Textbox(
|
| 412 |
+
label="Validation Schema",
|
| 413 |
+
lines=20,
|
| 414 |
+
max_lines=20,
|
| 415 |
+
interactive=False,
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
with gr.Tab("Data"):
|
| 419 |
+
with gr.Row():
|
| 420 |
+
with gr.Column():
|
| 421 |
+
result_data = gr.DataFrame(
|
| 422 |
+
label="Query Result",
|
| 423 |
+
value=[],
|
| 424 |
+
interactive=False,
|
| 425 |
+
)
|
| 426 |
|
| 427 |
+
schema_dropdown.change(
|
| 428 |
+
update_table_names, inputs=schema_dropdown, outputs=tables_dropdown
|
| 429 |
+
)
|
| 430 |
+
tables_dropdown.change(
|
| 431 |
+
update_column_names, inputs=tables_dropdown, outputs=columns_df
|
| 432 |
+
)
|
| 433 |
+
run_tests_btn.click(
|
| 434 |
+
run_tests,
|
| 435 |
+
inputs=[
|
| 436 |
+
tables_dropdown,
|
| 437 |
+
schema_dropdown,
|
| 438 |
+
sql_instruction,
|
| 439 |
+
user_instruction,
|
| 440 |
+
],
|
| 441 |
+
outputs=[test_logs, result_data, sql_model, result_schema],
|
| 442 |
+
)
|
| 443 |
+
demo.load(
|
| 444 |
+
fn=update_table_names, inputs=schema_dropdown, outputs=tables_dropdown
|
| 445 |
+
)
|
| 446 |
|
| 447 |
if __name__ == "__main__":
|
| 448 |
+
demo.launch(debug=True)
|
audits/.gitkeep
ADDED
|
File without changes
|
config.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gateways:
|
| 2 |
+
duckdb:
|
| 3 |
+
connection:
|
| 4 |
+
type: duckdb
|
| 5 |
+
catalogs:
|
| 6 |
+
local: 'database/chinook.duckdb'
|
| 7 |
+
|
| 8 |
+
default_gateway: duckdb
|
| 9 |
+
|
| 10 |
+
cache_dir: /tmp
|
| 11 |
+
model_defaults:
|
| 12 |
+
dialect: duckdb
|
| 13 |
+
|
| 14 |
+
linter:
|
| 15 |
+
enabled: true
|
| 16 |
+
rules:
|
| 17 |
+
- ambiguousorinvalidcolumn
|
| 18 |
+
- invalidselectstarexpansion
|
| 19 |
+
- noambiguousprojections
|
database/.gitkeep
ADDED
|
File without changes
|
logo.png
ADDED
|
macros/.gitkeep
ADDED
|
File without changes
|
models/.gitkeep
ADDED
|
File without changes
|
pytest.ini
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[pytest]
|
| 2 |
+
filterwarnings =
|
| 3 |
+
ignore::DeprecationWarning
|
| 4 |
+
ignore::PendingDeprecationWarning
|
requirements.txt
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file was autogenerated by uv via the following command:
|
| 2 |
+
# uv pip compile pyproject.toml -o requirements.txt
|
| 3 |
+
aiofiles==24.1.0
|
| 4 |
+
# via gradio
|
| 5 |
+
annotated-doc==0.0.4
|
| 6 |
+
# via fastapi
|
| 7 |
+
annotated-types==0.7.0
|
| 8 |
+
# via pydantic
|
| 9 |
+
anyio==4.12.1
|
| 10 |
+
# via
|
| 11 |
+
# gradio
|
| 12 |
+
# httpx
|
| 13 |
+
# starlette
|
| 14 |
+
astor==0.8.1
|
| 15 |
+
# via sqlmesh
|
| 16 |
+
asttokens==3.0.1
|
| 17 |
+
# via stack-data
|
| 18 |
+
audioop-lts==0.2.2
|
| 19 |
+
# via gradio
|
| 20 |
+
brotli==1.2.0
|
| 21 |
+
# via gradio
|
| 22 |
+
certifi==2026.1.4
|
| 23 |
+
# via
|
| 24 |
+
# httpcore
|
| 25 |
+
# httpx
|
| 26 |
+
# requests
|
| 27 |
+
charset-normalizer==3.4.4
|
| 28 |
+
# via requests
|
| 29 |
+
click==8.3.1
|
| 30 |
+
# via
|
| 31 |
+
# sqlmesh
|
| 32 |
+
# typer
|
| 33 |
+
# typer-slim
|
| 34 |
+
# uvicorn
|
| 35 |
+
colorama==0.4.6
|
| 36 |
+
# via
|
| 37 |
+
# click
|
| 38 |
+
# ipython
|
| 39 |
+
# pytest
|
| 40 |
+
# tqdm
|
| 41 |
+
comm==0.2.3
|
| 42 |
+
# via ipywidgets
|
| 43 |
+
croniter==6.0.0
|
| 44 |
+
# via sqlmesh
|
| 45 |
+
dateparser==1.2.1
|
| 46 |
+
# via sqlmesh
|
| 47 |
+
decorator==5.2.1
|
| 48 |
+
# via ipython
|
| 49 |
+
duckdb==1.4.3
|
| 50 |
+
# via sqlmesh
|
| 51 |
+
executing==2.2.1
|
| 52 |
+
# via stack-data
|
| 53 |
+
fastapi==0.128.0
|
| 54 |
+
# via gradio
|
| 55 |
+
ffmpy==1.0.0
|
| 56 |
+
# via gradio
|
| 57 |
+
filelock==3.20.3
|
| 58 |
+
# via huggingface-hub
|
| 59 |
+
fsspec==2026.1.0
|
| 60 |
+
# via
|
| 61 |
+
# gradio-client
|
| 62 |
+
# huggingface-hub
|
| 63 |
+
gradio==6.3.0
|
| 64 |
+
# via data-test-demo (pyproject.toml)
|
| 65 |
+
gradio-client==2.0.3
|
| 66 |
+
# via gradio
|
| 67 |
+
groovy==0.1.2
|
| 68 |
+
# via gradio
|
| 69 |
+
h11==0.16.0
|
| 70 |
+
# via
|
| 71 |
+
# httpcore
|
| 72 |
+
# uvicorn
|
| 73 |
+
hf-xet==1.2.0
|
| 74 |
+
# via huggingface-hub
|
| 75 |
+
httpcore==1.0.9
|
| 76 |
+
# via httpx
|
| 77 |
+
httpx==0.28.1
|
| 78 |
+
# via
|
| 79 |
+
# gradio
|
| 80 |
+
# gradio-client
|
| 81 |
+
# huggingface-hub
|
| 82 |
+
# safehttpx
|
| 83 |
+
huggingface-hub==1.3.2
|
| 84 |
+
# via
|
| 85 |
+
# gradio
|
| 86 |
+
# gradio-client
|
| 87 |
+
humanize==4.15.0
|
| 88 |
+
# via sqlmesh
|
| 89 |
+
hyperscript==0.3.0
|
| 90 |
+
# via sqlmesh
|
| 91 |
+
idna==3.11
|
| 92 |
+
# via
|
| 93 |
+
# anyio
|
| 94 |
+
# httpx
|
| 95 |
+
# requests
|
| 96 |
+
iniconfig==2.3.0
|
| 97 |
+
# via pytest
|
| 98 |
+
ipython==9.9.0
|
| 99 |
+
# via ipywidgets
|
| 100 |
+
ipython-pygments-lexers==1.1.1
|
| 101 |
+
# via ipython
|
| 102 |
+
ipywidgets==8.1.8
|
| 103 |
+
# via
|
| 104 |
+
# rich
|
| 105 |
+
# sqlmesh
|
| 106 |
+
jedi==0.19.2
|
| 107 |
+
# via ipython
|
| 108 |
+
jinja2==3.1.6
|
| 109 |
+
# via
|
| 110 |
+
# gradio
|
| 111 |
+
# sqlmesh
|
| 112 |
+
json-stream==2.4.1
|
| 113 |
+
# via sqlmesh
|
| 114 |
+
json-stream-rs-tokenizer==0.5.0
|
| 115 |
+
# via json-stream
|
| 116 |
+
jupyterlab-widgets==3.0.16
|
| 117 |
+
# via ipywidgets
|
| 118 |
+
markdown-it-py==4.0.0
|
| 119 |
+
# via rich
|
| 120 |
+
markupsafe==3.0.3
|
| 121 |
+
# via
|
| 122 |
+
# gradio
|
| 123 |
+
# jinja2
|
| 124 |
+
matplotlib-inline==0.2.1
|
| 125 |
+
# via ipython
|
| 126 |
+
mdurl==0.1.2
|
| 127 |
+
# via markdown-it-py
|
| 128 |
+
mypy-extensions==1.1.0
|
| 129 |
+
# via typing-inspect
|
| 130 |
+
numpy==2.4.1
|
| 131 |
+
# via
|
| 132 |
+
# gradio
|
| 133 |
+
# pandas
|
| 134 |
+
orjson==3.11.5
|
| 135 |
+
# via gradio
|
| 136 |
+
packaging==25.0
|
| 137 |
+
# via
|
| 138 |
+
# gradio
|
| 139 |
+
# gradio-client
|
| 140 |
+
# huggingface-hub
|
| 141 |
+
# pandera
|
| 142 |
+
# pytest
|
| 143 |
+
# sqlmesh
|
| 144 |
+
pandas==2.3.3
|
| 145 |
+
# via
|
| 146 |
+
# gradio
|
| 147 |
+
# sqlmesh
|
| 148 |
+
pandera==0.28.1
|
| 149 |
+
# via data-test-demo (pyproject.toml)
|
| 150 |
+
parso==0.8.5
|
| 151 |
+
# via jedi
|
| 152 |
+
pillow==12.1.0
|
| 153 |
+
# via gradio
|
| 154 |
+
pluggy==1.6.0
|
| 155 |
+
# via pytest
|
| 156 |
+
prompt-toolkit==3.0.52
|
| 157 |
+
# via ipython
|
| 158 |
+
pure-eval==0.2.3
|
| 159 |
+
# via stack-data
|
| 160 |
+
pydantic==2.12.5
|
| 161 |
+
# via
|
| 162 |
+
# fastapi
|
| 163 |
+
# gradio
|
| 164 |
+
# pandera
|
| 165 |
+
# sqlmesh
|
| 166 |
+
pydantic-core==2.41.5
|
| 167 |
+
# via pydantic
|
| 168 |
+
pydub==0.25.1
|
| 169 |
+
# via gradio
|
| 170 |
+
pygments==2.19.2
|
| 171 |
+
# via
|
| 172 |
+
# ipython
|
| 173 |
+
# ipython-pygments-lexers
|
| 174 |
+
# pytest
|
| 175 |
+
# rich
|
| 176 |
+
pymysql==1.1.2
|
| 177 |
+
# via sqlmesh
|
| 178 |
+
pytest==9.0.2
|
| 179 |
+
# via data-test-demo (pyproject.toml)
|
| 180 |
+
python-dateutil==2.9.0.post0
|
| 181 |
+
# via
|
| 182 |
+
# croniter
|
| 183 |
+
# dateparser
|
| 184 |
+
# pandas
|
| 185 |
+
python-dotenv==1.2.1
|
| 186 |
+
# via sqlmesh
|
| 187 |
+
python-multipart==0.0.21
|
| 188 |
+
# via gradio
|
| 189 |
+
pytz==2025.2
|
| 190 |
+
# via
|
| 191 |
+
# croniter
|
| 192 |
+
# dateparser
|
| 193 |
+
# pandas
|
| 194 |
+
pyyaml==6.0.3
|
| 195 |
+
# via
|
| 196 |
+
# gradio
|
| 197 |
+
# huggingface-hub
|
| 198 |
+
regex==2026.1.15
|
| 199 |
+
# via dateparser
|
| 200 |
+
requests==2.32.5
|
| 201 |
+
# via sqlmesh
|
| 202 |
+
rich==14.2.0
|
| 203 |
+
# via
|
| 204 |
+
# sqlmesh
|
| 205 |
+
# typer
|
| 206 |
+
ruamel-yaml==0.19.1
|
| 207 |
+
# via sqlmesh
|
| 208 |
+
safehttpx==0.1.7
|
| 209 |
+
# via gradio
|
| 210 |
+
semantic-version==2.10.0
|
| 211 |
+
# via gradio
|
| 212 |
+
shellingham==1.5.4
|
| 213 |
+
# via
|
| 214 |
+
# huggingface-hub
|
| 215 |
+
# typer
|
| 216 |
+
six==1.17.0
|
| 217 |
+
# via python-dateutil
|
| 218 |
+
sqlglot==27.28.1
|
| 219 |
+
# via sqlmesh
|
| 220 |
+
sqlglotrs==0.7.3
|
| 221 |
+
# via sqlglot
|
| 222 |
+
sqlmesh==0.228.4
|
| 223 |
+
# via data-test-demo (pyproject.toml)
|
| 224 |
+
stack-data==0.6.3
|
| 225 |
+
# via ipython
|
| 226 |
+
starlette==0.50.0
|
| 227 |
+
# via
|
| 228 |
+
# fastapi
|
| 229 |
+
# gradio
|
| 230 |
+
tenacity==9.1.2
|
| 231 |
+
# via sqlmesh
|
| 232 |
+
time-machine==3.2.0
|
| 233 |
+
# via sqlmesh
|
| 234 |
+
tomlkit==0.13.3
|
| 235 |
+
# via gradio
|
| 236 |
+
tqdm==4.67.1
|
| 237 |
+
# via huggingface-hub
|
| 238 |
+
traitlets==5.14.3
|
| 239 |
+
# via
|
| 240 |
+
# ipython
|
| 241 |
+
# ipywidgets
|
| 242 |
+
# matplotlib-inline
|
| 243 |
+
typeguard==4.4.4
|
| 244 |
+
# via pandera
|
| 245 |
+
typer==0.21.1
|
| 246 |
+
# via gradio
|
| 247 |
+
typer-slim==0.21.1
|
| 248 |
+
# via huggingface-hub
|
| 249 |
+
typing-extensions==4.15.0
|
| 250 |
+
# via
|
| 251 |
+
# fastapi
|
| 252 |
+
# gradio
|
| 253 |
+
# gradio-client
|
| 254 |
+
# huggingface-hub
|
| 255 |
+
# pandera
|
| 256 |
+
# pydantic
|
| 257 |
+
# pydantic-core
|
| 258 |
+
# typeguard
|
| 259 |
+
# typer
|
| 260 |
+
# typer-slim
|
| 261 |
+
# typing-inspect
|
| 262 |
+
# typing-inspection
|
| 263 |
+
typing-inspect==0.9.0
|
| 264 |
+
# via pandera
|
| 265 |
+
typing-inspection==0.4.2
|
| 266 |
+
# via pydantic
|
| 267 |
+
tzdata==2025.3
|
| 268 |
+
# via
|
| 269 |
+
# pandas
|
| 270 |
+
# tzlocal
|
| 271 |
+
tzlocal==5.3.1
|
| 272 |
+
# via dateparser
|
| 273 |
+
urllib3==2.6.3
|
| 274 |
+
# via requests
|
| 275 |
+
uvicorn==0.40.0
|
| 276 |
+
# via gradio
|
| 277 |
+
wcwidth==0.2.14
|
| 278 |
+
# via prompt-toolkit
|
| 279 |
+
widgetsnbextension==4.0.15
|
| 280 |
+
# via ipywidgets
|
seeds/.gitkeep
ADDED
|
File without changes
|
src/__init__.py
ADDED
|
File without changes
|
src/client.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
from huggingface_hub import InferenceClient
|
| 7 |
+
from pydantic import BaseModel
|
| 8 |
+
|
| 9 |
+
load_dotenv()
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
MAX_RESPONSE_TOKENS = 2048
|
| 14 |
+
TEMPERATURE = 0.9
|
| 15 |
+
|
| 16 |
+
models = json.loads(os.getenv("MODEL_NAMES"))
|
| 17 |
+
providers = json.loads(os.getenv("PROVIDERS"))
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _engine_working(engine: InferenceClient) -> bool:
|
| 21 |
+
try:
|
| 22 |
+
engine.chat_completion([{"role": "user", "content": "ping"}], max_tokens=1)
|
| 23 |
+
logger.info("Engine is Working.")
|
| 24 |
+
return True
|
| 25 |
+
except Exception as e:
|
| 26 |
+
logger.exception(f"Engine is not working: {e}")
|
| 27 |
+
return False
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _load_llm_client() -> InferenceClient:
|
| 31 |
+
"""
|
| 32 |
+
Attempts to load the provided model from the huggingface endpoint.
|
| 33 |
+
|
| 34 |
+
Returns InferenceClient if successful.
|
| 35 |
+
Raises Exception if no model is available.
|
| 36 |
+
"""
|
| 37 |
+
logger.warning("Loading Model...")
|
| 38 |
+
errors = []
|
| 39 |
+
for model in models:
|
| 40 |
+
for provider in providers:
|
| 41 |
+
if isinstance(model, str):
|
| 42 |
+
try:
|
| 43 |
+
logger.info(f"Checking model: {model} provider: {provider}")
|
| 44 |
+
client = InferenceClient(
|
| 45 |
+
model=model,
|
| 46 |
+
timeout=15,
|
| 47 |
+
provider=provider,
|
| 48 |
+
)
|
| 49 |
+
if _engine_working(client):
|
| 50 |
+
logger.info(
|
| 51 |
+
f"The model is loaded : {model} , provider: {provider}"
|
| 52 |
+
)
|
| 53 |
+
return client
|
| 54 |
+
except Exception as e:
|
| 55 |
+
logger.error(
|
| 56 |
+
f"Error loading model {model} provider {provider}: {e}"
|
| 57 |
+
)
|
| 58 |
+
errors.append(str(e))
|
| 59 |
+
raise Exception(f"Unable to load any provided model: {errors}.")
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
_default_client = _load_llm_client()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class LLMChain:
|
| 66 |
+
def __init__(self, client: InferenceClient = _default_client):
|
| 67 |
+
self.client = client
|
| 68 |
+
self.total_tokens = 0
|
| 69 |
+
|
| 70 |
+
def run(
|
| 71 |
+
self,
|
| 72 |
+
system_prompt: str | None = None,
|
| 73 |
+
user_prompt: str | None = None,
|
| 74 |
+
messages: list[dict] | None = None,
|
| 75 |
+
format_name: str | None = None,
|
| 76 |
+
response_format: type[BaseModel] | None = None,
|
| 77 |
+
) -> str | dict[str, str | int | float | None] | list[str] | None:
|
| 78 |
+
try:
|
| 79 |
+
if system_prompt and user_prompt:
|
| 80 |
+
messages = [
|
| 81 |
+
{"role": "system", "content": system_prompt},
|
| 82 |
+
{"role": "user", "content": user_prompt},
|
| 83 |
+
]
|
| 84 |
+
elif not messages:
|
| 85 |
+
raise ValueError(
|
| 86 |
+
"Either system_prompt and user_prompt or messages must be provided."
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
llm_response = self.client.chat_completion(
|
| 90 |
+
messages=messages,
|
| 91 |
+
max_tokens=MAX_RESPONSE_TOKENS,
|
| 92 |
+
temperature=TEMPERATURE,
|
| 93 |
+
response_format=(
|
| 94 |
+
{
|
| 95 |
+
"type": "json_schema",
|
| 96 |
+
"json_schema": {
|
| 97 |
+
"name": format_name,
|
| 98 |
+
"schema": response_format.model_json_schema(),
|
| 99 |
+
"strict": True,
|
| 100 |
+
},
|
| 101 |
+
}
|
| 102 |
+
if format_name and response_format
|
| 103 |
+
else None
|
| 104 |
+
),
|
| 105 |
+
)
|
| 106 |
+
self.total_tokens += llm_response.usage.total_tokens
|
| 107 |
+
analysis = llm_response.choices[0].message.content
|
| 108 |
+
if response_format:
|
| 109 |
+
analysis = json.loads(analysis)
|
| 110 |
+
fields = list(response_format.model_fields.keys())
|
| 111 |
+
if len(fields) == 1:
|
| 112 |
+
return analysis.get(fields[0])
|
| 113 |
+
return {field: analysis.get(field) for field in fields}
|
| 114 |
+
|
| 115 |
+
return analysis
|
| 116 |
+
|
| 117 |
+
except Exception as e:
|
| 118 |
+
logger.error(f"Error during LLM calls: {e}")
|
| 119 |
+
return None
|
src/models.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class SQLQueryModel(BaseModel):
|
| 5 |
+
sql_query: str = Field(..., description="SQL query to execute.")
|
| 6 |
+
explanation: str = Field(..., description="Short explanation of the SQL query.")
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PanderaSchemaModel(BaseModel):
|
| 10 |
+
schema_name: str = Field(
|
| 11 |
+
..., description="Only Pandera schema to validate the data."
|
| 12 |
+
)
|
src/pipelines.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ast
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
from duckdb import DuckDBPyConnection
|
| 8 |
+
|
| 9 |
+
from src.models import PanderaSchemaModel, SQLQueryModel
|
| 10 |
+
|
| 11 |
+
load_dotenv()
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
SQL_GENERATION_RETRIES = int(os.getenv("SQL_GENERATION_RETRIES", "5"))
|
| 17 |
+
PANDERA_PROMPT = os.getenv("PANDERA_PROMPT")
|
| 18 |
+
PANDERA_USER_PROMPT = os.getenv("PANDERA_USER_PROMPT")
|
| 19 |
+
SQL_PROMPT = os.getenv("SQL_PROMPT")
|
| 20 |
+
USER_PROMPT = os.getenv("USER_PROMPT")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Query2Schema:
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
duckdb: DuckDBPyConnection,
|
| 27 |
+
chain,
|
| 28 |
+
) -> None:
|
| 29 |
+
self._duckdb = duckdb
|
| 30 |
+
self.chain = chain
|
| 31 |
+
|
| 32 |
+
def generate_sql(
|
| 33 |
+
self, user_question: str, context: str, errors: str | None = None
|
| 34 |
+
) -> str | dict[str, str | int | float | None] | list[str] | None:
|
| 35 |
+
"""Generate SQL + description."""
|
| 36 |
+
user_prompt_formatted = USER_PROMPT.format(
|
| 37 |
+
question=user_question, context=context
|
| 38 |
+
)
|
| 39 |
+
if errors:
|
| 40 |
+
user_prompt_formatted += f"Carefully review the previous error or\
|
| 41 |
+
exception and rewrite the SQL so that the error does not occur again.\
|
| 42 |
+
Try a different approach or rewrite SQL if needed. Last error: {errors}"
|
| 43 |
+
|
| 44 |
+
sql = self.chain.run(
|
| 45 |
+
system_prompt=SQL_PROMPT,
|
| 46 |
+
user_prompt=user_prompt_formatted,
|
| 47 |
+
format_name="sql_query",
|
| 48 |
+
response_format=SQLQueryModel,
|
| 49 |
+
)
|
| 50 |
+
logger.info(f"SQL Generated Successfully: {sql}")
|
| 51 |
+
return sql
|
| 52 |
+
|
| 53 |
+
def run_query(self, sql_query: str) -> pd.DataFrame | None:
|
| 54 |
+
"""Execute SQL and return dataframe."""
|
| 55 |
+
logger.info("Query Execution Started.")
|
| 56 |
+
return self._duckdb.query(sql_query).df()
|
| 57 |
+
|
| 58 |
+
def try_sql_with_retries(
|
| 59 |
+
self,
|
| 60 |
+
user_question: str,
|
| 61 |
+
context: str,
|
| 62 |
+
max_retries: int = SQL_GENERATION_RETRIES,
|
| 63 |
+
) -> tuple[
|
| 64 |
+
str | dict[str, str | int | float | None] | list[str] | None,
|
| 65 |
+
pd.DataFrame | None,
|
| 66 |
+
]:
|
| 67 |
+
"""Try SQL generation + execution with retries."""
|
| 68 |
+
last_error = None
|
| 69 |
+
all_errors = ""
|
| 70 |
+
|
| 71 |
+
for attempt in range(
|
| 72 |
+
1, max_retries + 2
|
| 73 |
+
): # @ Since the first is normal and not consider in retries
|
| 74 |
+
try:
|
| 75 |
+
if attempt > 1 and last_error:
|
| 76 |
+
logger.info(f"Retrying: {attempt - 1}")
|
| 77 |
+
# Generate SQL
|
| 78 |
+
sql = self.generate_sql(user_question, context, errors=all_errors)
|
| 79 |
+
if not sql:
|
| 80 |
+
return None, None
|
| 81 |
+
else:
|
| 82 |
+
# Generate SQL
|
| 83 |
+
sql = self.generate_sql(user_question, context)
|
| 84 |
+
if not sql:
|
| 85 |
+
return None, None
|
| 86 |
+
|
| 87 |
+
# Try executing query
|
| 88 |
+
sql_query_str = sql.get("sql_query") if isinstance(sql, dict) else sql
|
| 89 |
+
if not isinstance(sql_query_str, str):
|
| 90 |
+
raise ValueError(
|
| 91 |
+
f"Expected SQL query to be a string, got {type(sql_query_str).__name__}"
|
| 92 |
+
)
|
| 93 |
+
query_df = self.run_query(sql_query_str)
|
| 94 |
+
|
| 95 |
+
# If execution succeeds, stop retrying or if df is not empty
|
| 96 |
+
if query_df is not None and not query_df.empty:
|
| 97 |
+
return sql, query_df
|
| 98 |
+
|
| 99 |
+
except Exception as e:
|
| 100 |
+
last_error = f"\nAttempt {attempt - 1}] {type(e).__name__}: {e}"
|
| 101 |
+
logger.error(f"Error during SQL generation or execution: {last_error}")
|
| 102 |
+
all_errors += last_error
|
| 103 |
+
|
| 104 |
+
logger.error(f"Failed after {max_retries} attempts. Last error: {all_errors}")
|
| 105 |
+
return None, None
|
| 106 |
+
|
| 107 |
+
def generate_pandera_schema(self, sql_query: str, user_instruction: str) -> str:
|
| 108 |
+
"""Generate Pandera schema."""
|
| 109 |
+
class_lines = []
|
| 110 |
+
|
| 111 |
+
schema_str = self.chain.run(
|
| 112 |
+
system_prompt=PANDERA_PROMPT,
|
| 113 |
+
user_prompt=PANDERA_USER_PROMPT.format(
|
| 114 |
+
sql_query=sql_query, instructions=user_instruction
|
| 115 |
+
),
|
| 116 |
+
format_name="pandera_schema",
|
| 117 |
+
response_format=PanderaSchemaModel,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
parsed = ast.parse(schema_str)
|
| 121 |
+
|
| 122 |
+
original_lines = schema_str.splitlines()
|
| 123 |
+
for node in parsed.body:
|
| 124 |
+
if isinstance(node, ast.ClassDef):
|
| 125 |
+
start, end = node.lineno - 1, node.end_lineno
|
| 126 |
+
class_lines.extend(original_lines[start:end])
|
| 127 |
+
|
| 128 |
+
return "\n".join(class_lines)
|
src/prompts.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
USER_PROMPT = """User's Text Question:
|
| 2 |
+
{question}
|
| 3 |
+
|
| 4 |
+
Provided table context information:
|
| 5 |
+
{context}"""
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
SQL_PROMPT = """You are an expert Text-to-SQL assistant. Convert the user's natural-language request into a single, read-only, syntactically valid DuckDB SQL SELECT statement that runs against the provided schema (the schema will be supplied as CREATE TABLE DDL). Use the exact table and column names from the schema.
|
| 9 |
+
|
| 10 |
+
Return two things:
|
| 11 |
+
1. The SQL statement.
|
| 12 |
+
2. A short natural-language description (1-2 sentences) of what the query returns.
|
| 13 |
+
|
| 14 |
+
Rules:
|
| 15 |
+
1. Output MUST be a single SELECT query. JOINs, subqueries, aggregations, GROUP BY, ORDER BY, and LIMIT are allowed.
|
| 16 |
+
2. Do NOT generate any DML/DDL (INSERT, UPDATE, DELETE, DROP, etc.) or non-read operations.
|
| 17 |
+
3. Use DuckDB SQL functions and syntax. For date/time grouping, use DATE_TRUNC('unit', column) (e.g., 'month', 'day', 'year').
|
| 18 |
+
4. Prefer explicit column lists. Use SELECT * only if the user explicitly requests all columns.
|
| 19 |
+
5. Make the query robust and maintainable, so it can be reused or adapted for similar analyses.
|
| 20 |
+
6. After execution in the downstream pipeline, if an error occurs (available as `Last Error` with a short description), analyze that error and rewrite the SQL to resolve it while preserving the user's intent. The rewritten query must still be valid DuckDB SQL.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
PANDERA_PROMPT = """You are provided with a SQL query which is used to fetch data from a database. Your task is to generate a valid Pandera SchemaModel class that can be used to validate the resulting data from the query.
|
| 24 |
+
The generated schema should be **general and simple**, not overly complex. Only validate basic aspects like column types, nullability, and simple value constraints (like positive integers, string patterns, or ranges) since you only have the SQL query and the resulting column names/types.
|
| 25 |
+
|
| 26 |
+
Follow these guidelines:
|
| 27 |
+
1. **Use Pandera SchemaModel**:
|
| 28 |
+
- Each column should have a type hint using `Series[Type]`.
|
| 29 |
+
- Use `pa.Field` to define simple validations.
|
| 30 |
+
2. **Validation rules should be simple and reasonable**:
|
| 31 |
+
- `nullable` for optional columns
|
| 32 |
+
- `unique` for IDs if obvious
|
| 33 |
+
- `gt`/`ge`/`lt`/`le` for numeric ranges if reasonable
|
| 34 |
+
- `str_matches`, `str_length` or `str_contains` for string patterns (like ZIP codes or emails)
|
| 35 |
+
- Avoid complex cross-column or statistical checks
|
| 36 |
+
3. **Add Config class**:
|
| 37 |
+
- Set `coerce = True` to cast data types automatically
|
| 38 |
+
4. **Add optional metadata**:
|
| 39 |
+
- Include `description` for columns if possible
|
| 40 |
+
- Include `title` for columns if it helps
|
| 41 |
+
5. **Output only valid Python code**:
|
| 42 |
+
- The output should be a **single Python class definition**.
|
| 43 |
+
- Do not include any explanations, comments, or extra text.
|
| 44 |
+
6. **Example Output**:
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
import pandas as pd
|
| 48 |
+
import pandera as pa
|
| 49 |
+
from pandera.typing import Series
|
| 50 |
+
|
| 51 |
+
class CustomerSchema(pa.DataFrameModel):
|
| 52 |
+
customer_id: Series[int] = pa.Field(gt=0, unique=True, nullable=False, description="Unique customer identifier")
|
| 53 |
+
first_name: Series[str] = pa.Field(nullable=False, str_length=(1, 50), description="Customer first name")
|
| 54 |
+
last_name: Series[str] = pa.Field(nullable=False, str_length=(1, 50), description="Customer last name")
|
| 55 |
+
email: Series[str] = pa.Field(nullable=False, str_matches=r"^[\\w\\.-]+@[\\w\\.-]+\\.\\w+$", description="Customer email address")
|
| 56 |
+
age: Series[int] = pa.Field(ge=0, le=120, nullable=True, description="Customer age in years")
|
| 57 |
+
|
| 58 |
+
class Config:
|
| 59 |
+
coerce = True
|
| 60 |
+
|
| 61 |
+
Additional notes:
|
| 62 |
+
If the SQL query uses JOIN, only include columns that appear in the SELECT statement.
|
| 63 |
+
You may infer basic constraints from column names (e.g., columns ending with _id are likely unique integers).
|
| 64 |
+
Avoid domain-specific logic unless it is obvious from the column names or SQL query.
|
| 65 |
+
Keep the schema robust but simple, suitable for automated ETL validation."""
|
| 66 |
+
|
| 67 |
+
PANDERA_USER_PROMPT = """SQL Query:
|
| 68 |
+
{sql_query}
|
| 69 |
+
|
| 70 |
+
User Instructions:
|
| 71 |
+
{instructions}"""
|