Muhammad Mustehson committed
Commit a360e3c · Parent(s): 902da82

Update Old Code
Browse files
- .gitignore +5 -1
- app.py +167 -157
- requirements.txt +9 -10
- src/__init__.py +0 -0
- src/client.py +131 -0
- src/models.py +6 -0
- src/pipelines.py +98 -0
- src/prompts.py +22 -0
.gitignore CHANGED
@@ -1 +1,5 @@
-
+.env
+.venv
+__pycache__/
+*.pyc
+*.pyo
app.py CHANGED
@@ -1,103 +1,88 @@
+import logging
 import os
-import
+from typing import Any, Dict, List
+
 import duckdb
-import spaces
-import lancedb
 import gradio as gr
+import lancedb
 import pandas as pd
 import pyarrow as pa
-from
-from langsmith import traceable
-from sentence_transformers import SentenceTransformer
-from langchain_huggingface.llms import HuggingFacePipeline
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
+from dotenv import load_dotenv
 
+from src.client import LLMChain, embed_client
+from src.pipelines import SQLPipeline
 
+load_dotenv()
 
-#----------CONNECT TO DATABASE----------
-md_token = os.getenv('MD_TOKEN')
-conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)
-#---------------------------------------
+# ======== ENVs ========
+MD_TOKEN = os.getenv("MD_TOKEN")
+HF_TOKEN = os.getenv("HF_TOKEN")
+conn = duckdb.connect(f"md:my_db?motherduck_token={MD_TOKEN}", read_only=True)
+LEVEL = "INFO" if not os.getenv("ENV") == "PROD" else "WARNING"
+EMB_URL = os.getenv("EMB_URL")
+EMB_MODEL = os.getenv("EMB_MODEL")
+TAB_LINES = 8
+# ======================
 
+logging.basicConfig(
+    level=getattr(logging, LEVEL, logging.INFO),
+    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
+)
+logger = logging.getLogger(__name__)
 
-#--------------LanceDB-------------
+pipe = SQLPipeline(duckdb=conn, chain=LLMChain())
 
-lance_schema = pa.schema([
-    pa.field("vector", pa.list_(pa.float32())),
-    pa.field("sql-query", pa.utf8())
-])
-try:
-    table = lance_db.create_table(name="SQL-Queries", schema=lance_schema)
-except:
-    table = lance_db.open_table(name="SQL-Queries")
-#---------------------------------------
+def _setup_lancedb() -> lancedb.table.Table:
+    lance_db = lancedb.connect(
+        uri=os.getenv("lancedb_uri"),
+        api_key=os.getenv("lancedb_api_key"),
+        region=os.getenv("lancedb_region"),
+    )
+    lance_schema = pa.schema(
+        [pa.field("vector", pa.list_(pa.float32())), pa.field("sql-query", pa.utf8())]
+    )
+    try:
+        table = lance_db.create_table(name="SQL-Queries", schema=lance_schema)
+    except Exception:
+        table = lance_db.open_table(name="SQL-Queries")
+    return table
 
-#-------LOAD HUGGINGFACE PIPELINE-------
-tokenizer = AutoTokenizer.from_pretrained("defog/llama-3-sqlcoder-8b")
-pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1024, return_full_text=False)
-hf = HuggingFacePipeline(pipeline=pipe)
-#---------------------------------------
+lance_table = _setup_lancedb()
 
-prompt = hub.pull("sql-agent-prompt")
-#---------------------------------------
 
-#-----LOAD EMBEDDING MODEL-----
-embedding_model = SentenceTransformer("all-MiniLM-L6-v2", device=device)
-#---------------------------------------
-
-# Get Databases
-def get_schemas():
+def get_schemas() -> List:
     schemas = conn.execute("""
         SELECT DISTINCT schema_name
         FROM information_schema.schemata
        WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
    """).fetchall()
-    return [item[0] for item in schemas]
+    return [item[0] for item in schemas]
 
+
+def get_tables(schema_name: str) -> List:
+    tables = conn.execute(
+        f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{schema_name}'"
+    ).fetchall()
     return [table[0] for table in tables]
 
-def update_tables(schema_name):
+
+def update_tables(schema_name: str):
     tables = get_tables(schema_name)
     return gr.update(choices=tables)
 
-def get_table_schema(table):
-    result = conn.sql(
+
+def get_table_schema(table: str) -> str:
+    result = conn.sql(
+        f"SELECT sql, database_name, schema_name FROM duckdb_tables() where table_name ='{table}';"
+    ).df()
+    ddl_create = result.iloc[0, 0]
+    parent_database = result.iloc[0, 1]
+    schema_name = result.iloc[0, 2]
     full_path = f"{parent_database}.{schema_name}.{table}"
     if schema_name != "main":
         old_path = f"{schema_name}.{table}"
@@ -106,85 +91,81 @@ def get_table_schema(table):
         ddl_create = ddl_create.replace(old_path, full_path)
     return ddl_create
 
-def get_prompt(schema, query_input):
-    return prompt.format(schema=schema, query_input=query_input)
-
-@spaces.GPU(duration=60)
-@traceable()
-def generate_sql(prompt):
-    result = hf.invoke(prompt)
-    return result.strip()
-
-@spaces.GPU(duration=10)
-def embed_query(sql_query):
-    print(f'Creating Embeddings {sql_query}')
-    if sql_query is not None:
-        embeddings = embedding_model.encode(sql_query, normalize_embeddings=True).tolist()
-        return embeddings
-
-def log2lancedb(embeddings, sql_query):
-    data = [{
-        "sql-query": sql_query,
-        "vector": embeddings
-    }]
-    table.add(data)
-    print(f'Added to Lance DB.')
-#---------------------------------------
-
-# Generate SQL
-def text2sql(table, query_input):
+
+def run_pipeline(table: str, query_input: str) -> Dict[str, Any]:
     if table is None:
-        return
-    }
+        return _error_response(
+            query_input=query_input, message="❌ Please select a table/schema."
+        )
-
-    schema = get_table_schema(table)
-    print(f'Schema Generated...')
-    prompt = get_prompt(schema, query_input)
-    print(f'Prompt Generated...')
-
-    try:
-        print(f'Generating SQL... {model.device}')
-        result = generate_sql(prompt)
-        print('SQL Generated...')
-    except Exception as e:
-        return {
-            table_schema: schema,
-            input_prompt: prompt,
-            generated_query: "",
-            result_output: pd.DataFrame([{"error": f"❌ Unable to get the SQL query based on the text. {e}"}])
-        }
+
+    schema = ""
     try:
+        schema = get_table_schema(table=table)
+
+        sql, df = pipe.try_sql_with_retries(
+            user_question=query_input,
+            context=schema,
+        )
+
+        if not sql or df is None:
+            return _error_response(
+                query_input=query_input,
+                schema=schema,
+                message="❌ Unable to generate SQL from the input text.",
+            )
+
+    except Exception as exc:
+        logger.exception("Pipeline execution failed")
+        return _error_response(
+            query_input=query_input, schema=schema, message=f"❌ Pipeline error: {exc}"
+        )
+
     try:
+        sql_str = f"{query_input}\n{sql.get('sql_query', '')}"
+        embeddings = embed_query(sql_str)
+        log2lancedb(embeddings, sql_str)
+
+    except Exception as exc:
+        logger.warning("Embedding/logging failed: %s", exc)
+
+    return {
+        table_schema: schema,
+        input_prompt: query_input,
+        generated_query: sql.get("sql_query", ""),
+        result_output: df,
+    }
+
+
+def _error_response(
+    *,
+    query_input: str,
+    message: str,
+    schema: str = "",
+) -> Dict[str, Any]:
     return {
         table_schema: schema,
-        input_prompt:
-        generated_query:
-        result_output:
+        input_prompt: query_input,
+        generated_query: "",
+        result_output: pd.DataFrame([{"error": message}]),
     }
 
+
+def embed_query(data: str) -> List:
+    logger.info(f"Creating Embeddings {data}")
+    try:
+        results = embed_client.feature_extraction(text=data, model=EMB_MODEL)
+        return results.tolist()
+    except Exception as e:
+        logger.error(f"Unable to generate embedding for the given query: {e}")
+        return []
+
+
+def log2lancedb(embeddings: List, sql_query: str) -> None:
+    data = [{"sql-query": sql_query, "vector": embeddings}]
+    lance_table.add(data)
+    logger.info("Added to Lance DB.")
+
+
 custom_css = """
 .gradio-container {
     background-color: #f0f4f8;
@@ -202,9 +183,11 @@ custom_css = """
 }
 """
 
-with gr.Blocks(
+with gr.Blocks(
+    theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css
+) as demo:
     gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
 
     gr.Markdown("""
    <div style='text-align: center;'>
        <strong style='font-size: 36px;'>Datajoi SQL Agent</strong>
@@ -214,13 +197,18 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
     """)
 
     with gr.Row():
+        with gr.Column(scale=1, variant="panel"):
+            schema_dropdown = gr.Dropdown(
+                choices=get_schemas(), label="Select Schema", interactive=True
+            )
+            tables_dropdown = gr.Dropdown(
+                choices=[], label="Available Tables", value=None
+            )
 
         with gr.Column(scale=2):
-            query_input = gr.Textbox(
+            query_input = gr.Textbox(
+                lines=5, label="Text Query", placeholder="Enter your text query here..."
+            )
     with gr.Row():
         with gr.Column(scale=7):
             pass
@@ -229,17 +217,39 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"
 
     with gr.Tabs():
        with gr.Tab("Result"):
-            result_output = gr.DataFrame(
+            result_output = gr.DataFrame(
+                label="Query Results", value=[], interactive=False
+            )
        with gr.Tab("SQL Query"):
-            generated_query = gr.Textbox(
+            generated_query = gr.Textbox(
+                lines=TAB_LINES,
+                label="Generated SQL Query",
+                value="",
+                interactive=False,
+            )
        with gr.Tab("Prompt"):
-            input_prompt = gr.Textbox(
+            input_prompt = gr.Textbox(
+                lines=TAB_LINES,
+                label="Input Prompt",
+                value="",
+                interactive=False,
+            )
        with gr.Tab("Schema"):
-            table_schema = gr.Textbox(
+            table_schema = gr.Textbox(
+                lines=TAB_LINES,
+                label="Table Schema",
+                value="",
+                interactive=False,
+            )
+
+    schema_dropdown.change(
+        update_tables, inputs=schema_dropdown, outputs=tables_dropdown
+    )
+    generate_query_button.click(
+        run_pipeline,
+        inputs=[tables_dropdown, query_input],
+        outputs=[table_schema, input_prompt, generated_query, result_output],
+    )
 
 if __name__ == "__main__":
     demo.launch()
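A note on the wiring above: `run_pipeline` returns a dict keyed by Gradio component objects rather than a positional tuple, so one callback can update the schema, prompt, query, and result views at once. A minimal, self-contained sketch of that pattern (component and function names here are illustrative, not taken from the app):

import gradio as gr

with gr.Blocks() as demo:
    question = gr.Textbox(label="Question")
    sql_box = gr.Textbox(label="SQL")
    status = gr.Textbox(label="Status")

    def answer(q):
        # Returning a dict keyed by components updates each listed output
        # from its own entry, like run_pipeline's return value above.
        return {sql_box: f"SELECT 1 -- for: {q}", status: "ok"}

    question.submit(answer, inputs=question, outputs=[sql_box, status])

if __name__ == "__main__":
    demo.launch()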
requirements.txt CHANGED
@@ -1,10 +1,9 @@
-
-
-
-
-
-
-
-
-
-langchain-huggingface
+huggingface-hub==0.35.0
+duckdb==1.3.2
+pandas==2.3.1
+numpy==2.3.2
+pydantic
+python-dotenv
+gradio
+pyarrow
+lancedb

src/__init__.py ADDED
File without changes
src/client.py ADDED
@@ -0,0 +1,131 @@
+import json
+import logging
+import os
+
+from dotenv import load_dotenv
+from huggingface_hub import InferenceClient
+from pydantic import BaseModel
+
+load_dotenv()
+
+logger = logging.getLogger(__name__)
+
+MAX_RESPONSE_TOKENS = 2048
+TEMPERATURE = 0.9
+
+models = json.loads(os.getenv("MODEL_NAMES"))
+providers = json.loads(os.getenv("PROVIDERS"))
+EMB_MODEL = os.getenv("EMB_MODEL")
+
+
+def _engine_working(engine: InferenceClient) -> bool:
+    try:
+        engine.chat_completion([{"role": "user", "content": "ping"}], max_tokens=1)
+        logger.info("Engine is working.")
+        return True
+    except Exception as e:
+        logger.exception(f"Engine is not working: {e}")
+        return False
+
+
+def _load_llm_client() -> InferenceClient:
+    """
+    Attempts to load the provided model from the Hugging Face endpoint.
+
+    Returns an InferenceClient if successful.
+    Raises an Exception if no model is available.
+    """
+    logger.warning("Loading model...")
+    errors = []
+    for model in models:
+        for provider in providers:
+            if isinstance(model, str):
+                try:
+                    logger.info(f"Checking model: {model} provider: {provider}")
+                    client = InferenceClient(
+                        model=model,
+                        timeout=15,
+                        provider=provider,
+                    )
+                    if _engine_working(client):
+                        logger.info(
+                            f"The model is loaded: {model}, provider: {provider}"
+                        )
+                        return client
+                except Exception as e:
+                    logger.error(
+                        f"Error loading model {model} provider {provider}: {e}"
+                    )
+                    errors.append(str(e))
+    raise Exception(f"Unable to load any provided model: {errors}.")
+
+
+def _load_embedding_client() -> InferenceClient:
+    logger.warning("Loading embedding model...")
+    try:
+        emb_client = InferenceClient(timeout=15, model=EMB_MODEL)
+        return emb_client
+    except Exception as e:
+        logger.error(f"Error loading model {EMB_MODEL}: {e}")
+        raise Exception("Unable to load the embedding model.")
+
+
+_default_client = _load_llm_client()
+embed_client = _load_embedding_client()
+
+
+class LLMChain:
+    def __init__(self, client: InferenceClient = _default_client):
+        self.client = client
+        self.total_tokens = 0
+
+    def run(
+        self,
+        system_prompt: str | None = None,
+        user_prompt: str | None = None,
+        messages: list[dict] | None = None,
+        format_name: str | None = None,
+        response_format: type[BaseModel] | None = None,
+    ) -> str | dict[str, str | int | float | None] | list[str] | None:
+        try:
+            if system_prompt and user_prompt:
+                messages = [
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": user_prompt},
+                ]
+            elif not messages:
+                raise ValueError(
+                    "Either system_prompt and user_prompt or messages must be provided."
+                )
+
+            llm_response = self.client.chat_completion(
+                messages=messages,
+                max_tokens=MAX_RESPONSE_TOKENS,
+                temperature=TEMPERATURE,
+                response_format=(
+                    {
+                        "type": "json_schema",
+                        "json_schema": {
+                            "name": format_name,
+                            "schema": response_format.model_json_schema(),
+                            "strict": True,
+                        },
+                    }
+                    if format_name and response_format
+                    else None
+                ),
+            )
+            self.total_tokens += llm_response.usage.total_tokens
+            analysis = llm_response.choices[0].message.content
+            if response_format:
+                analysis = json.loads(analysis)
+                fields = list(response_format.model_fields.keys())
+                if len(fields) == 1:
+                    return analysis.get(fields[0])
+                return {field: analysis.get(field) for field in fields}
+
+            return analysis
+
+        except Exception as e:
+            logger.error(f"Error during LLM calls: {e}")
+            return None
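`LLMChain.run` returns raw chat text, or, when `format_name` and a pydantic `response_format` are supplied, requests strict JSON-schema output and unpacks the parsed fields; a model with a single field collapses to that field's value. A hypothetical usage sketch, assuming `MODEL_NAMES`, `PROVIDERS`, and any required Hugging Face credentials are set in the environment:

from pydantic import BaseModel, Field

from src.client import LLMChain


class Answer(BaseModel):
    # Single-field model: run() returns just this field's value.
    text: str = Field(..., description="Short answer.")


chain = LLMChain()
result = chain.run(
    system_prompt="Answer in one sentence.",
    user_prompt="What does DuckDB's DATE_TRUNC do?",
    format_name="answer",
    response_format=Answer,
)
print(result)              # a plain string, because Answer has exactly one field
print(chain.total_tokens)  # cumulative token usage across run() calls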
src/models.py ADDED
@@ -0,0 +1,6 @@
+from pydantic import BaseModel, Field
+
+
+class SQLQueryModel(BaseModel):
+    sql_query: str = Field(..., description="SQL query to execute.")
+    explanation: str = Field(..., description="Short explanation of the SQL query.")
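`SQLQueryModel` is the schema that `LLMChain.run` embeds under the strict `json_schema` response format. A quick way to inspect what the endpoint is asked to conform to:

import json

from src.models import SQLQueryModel

# model_json_schema() is exactly what client.py passes as the "schema" entry;
# both fields are required, so a conforming reply always carries
# "sql_query" and "explanation" keys.
print(json.dumps(SQLQueryModel.model_json_schema(), indent=2))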
src/pipelines.py ADDED
@@ -0,0 +1,98 @@
+import logging
+import os
+
+import pandas as pd
+from duckdb import DuckDBPyConnection
+
+from src.models import SQLQueryModel
+from src.prompts import SQL_PROMPT, USER_PROMPT
+
+logger = logging.getLogger(__name__)
+
+
+SQL_GENERATION_RETRIES = int(os.getenv("SQL_GENERATION_RETRIES", "5"))
+
+
+class SQLPipeline:
+    def __init__(
+        self,
+        duckdb: DuckDBPyConnection,
+        chain,
+    ) -> None:
+        self._duckdb = duckdb
+        self.chain = chain
+
+    def generate_sql(
+        self, user_question: str, context: str, errors: str | None = None
+    ) -> str | dict[str, str | int | float | None] | list[str] | None:
+        """Generate SQL + description."""
+        user_prompt_formatted = USER_PROMPT.format(
+            question=user_question, context=context
+        )
+        if errors:
+            user_prompt_formatted += (
+                " Carefully review the previous error or exception and rewrite the"
+                " SQL so that the error does not occur again. Try a different"
+                f" approach or rewrite the SQL if needed. Last error: {errors}"
+            )
+
+        sql = self.chain.run(
+            system_prompt=SQL_PROMPT,
+            user_prompt=user_prompt_formatted,
+            format_name="sql_query",
+            response_format=SQLQueryModel,
+        )
+        logger.info(f"SQL generated successfully: {sql}")
+        return sql
+
+    def run_query(self, sql_query: str) -> pd.DataFrame | None:
+        """Execute SQL and return a dataframe."""
+        logger.info("Query execution started.")
+        return self._duckdb.query(sql_query).df()
+
+    def try_sql_with_retries(
+        self,
+        user_question: str,
+        context: str,
+        max_retries: int = SQL_GENERATION_RETRIES,
+    ) -> tuple[
+        str | dict[str, str | int | float | None] | list[str] | None,
+        pd.DataFrame | None,
+    ]:
+        """Try SQL generation + execution with retries."""
+        last_error = None
+        all_errors = ""
+
+        # max_retries + 1 total attempts: the first attempt is normal and is
+        # not counted as a retry.
+        for attempt in range(1, max_retries + 2):
+            try:
+                if attempt > 1 and last_error:
+                    logger.info(f"Retrying: {attempt - 1}")
+                # Generate SQL, feeding back any accumulated errors so the
+                # model can correct itself.
+                sql = self.generate_sql(
+                    user_question, context, errors=all_errors or None
+                )
+                if not sql:
+                    return None, None
+
+                # Try executing the query.
+                sql_query_str = sql.get("sql_query") if isinstance(sql, dict) else sql
+                if not isinstance(sql_query_str, str):
+                    raise ValueError(
+                        f"Expected SQL query to be a string, got {type(sql_query_str).__name__}"
+                    )
+                query_df = self.run_query(sql_query_str)
+
+                # Stop retrying once execution succeeds and the result is non-empty.
+                if query_df is not None and not query_df.empty:
+                    return sql, query_df
+
+            except Exception as e:
+                last_error = f"\n[Attempt {attempt - 1}] {type(e).__name__}: {e}"
+                logger.error(f"Error during SQL generation or execution: {last_error}")
+                all_errors += last_error
+
+        logger.error(f"Failed after {max_retries} attempts. Last error: {all_errors}")
+        return None, None
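Because the retry loop feeds every accumulated error back into `generate_sql`, the pipeline can be exercised offline by swapping the chain for a stub. A sketch with an in-memory DuckDB connection (the stub class is hypothetical, not part of the app):

import duckdb

from src.pipelines import SQLPipeline


class StubChain:
    # Stands in for LLMChain: always "generates" the same query.
    def run(self, **kwargs):
        return {"sql_query": "SELECT id, name FROM users", "explanation": "All users."}


conn = duckdb.connect()  # in-memory database
conn.execute("CREATE TABLE users (id INTEGER, name TEXT)")
conn.execute("INSERT INTO users VALUES (1, 'ada'), (2, 'lin')")

pipe = SQLPipeline(duckdb=conn, chain=StubChain())
sql, df = pipe.try_sql_with_retries(
    user_question="List all users",
    context="CREATE TABLE users (id INTEGER, name TEXT)",
)
print(sql["sql_query"])
print(df)  # two rows; a failing query would have been retried with the error text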
src/prompts.py ADDED
@@ -0,0 +1,22 @@
+USER_PROMPT = """User's Text Question:
+{question}
+
+Provided table context information:
+{context}"""
+
+
+SQL_PROMPT = """You are an expert Text-to-SQL assistant. Convert the user's natural-language request into a single, read-only, syntactically valid DuckDB SQL SELECT statement that runs against the provided schema (the schema will be supplied as CREATE TABLE DDL). Use the exact table and column names from the schema.
+
+Return two things:
+1. The SQL statement.
+2. A short natural-language description (1-2 sentences) of what the query returns.
+
+Rules:
+1. Output MUST be a single SELECT query. JOINs, subqueries, aggregations, GROUP BY, ORDER BY, and LIMIT are allowed.
+2. Do NOT generate any DML/DDL (INSERT, UPDATE, DELETE, DROP, etc.) or non-read operations.
+3. Use DuckDB SQL functions and syntax. For date/time grouping, use DATE_TRUNC('unit', column) (e.g., 'month', 'day', 'year').
+4. Prefer explicit column lists. Use SELECT * only if the user explicitly requests all columns.
+5. Make the query robust and maintainable, so it can be reused or adapted for similar analyses.
+6. After execution in the downstream pipeline, if an error occurs (available as `Last Error` with a short description), analyze that error and rewrite the SQL to resolve it while preserving the user's intent. The rewritten query must still be valid DuckDB SQL.
+7. If the user requests a distribution/histogram, return SQL that selects a single numeric column only, so binning can be performed downstream.
+"""
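Rule 3's DATE_TRUNC convention, shown as a runnable DuckDB example of the shape of query the prompt asks for (table and column names are made up for illustration):

import duckdb

conn = duckdb.connect()  # in-memory database
conn.execute("CREATE TABLE orders (order_date DATE, amount DOUBLE)")
conn.execute("INSERT INTO orders VALUES ('2024-01-05', 10.0), ('2024-02-11', 7.5)")

# Read-only, explicit column list, DATE_TRUNC for monthly grouping --
# the style the system prompt instructs the model to produce.
monthly = conn.query("""
    SELECT DATE_TRUNC('month', order_date) AS month,
           SUM(amount) AS total_amount
    FROM orders
    GROUP BY month
    ORDER BY month
""").df()
print(monthly)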