# Text-to-SQL-RAG / app.py
# alokik29's picture
# Update app.py
# fb85d2e verified
import torch
import sqlite3
import pandas as pd
import gradio as gr
import re
from langchain_community.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Load model
# Everything below runs once at import time: the tokenizer and model are
# fetched for `model_id`, combined into a text-generation pipeline, and then
# wrapped so the rest of the file can call `sqlcoder_llm.invoke(prompt)`.
model_id = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
)
# Sampling is disabled and each generation is capped at 256 new tokens.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    do_sample=False,
)
sqlcoder_llm = HuggingFacePipeline(pipeline=pipe)
def _quote_identifier(name):
    """Return *name* as a double-quoted SQLite identifier."""
    return '"' + str(name).replace('"', '""') + '"'


def _describe_schema(cursor):
    """Return one ``table(col1, col2, ...)`` line per table in the database."""
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = [row[0] for row in cursor.fetchall()]
    schema_info = []
    for table in tables:
        # Quote the name so tables containing spaces, punctuation or SQL
        # keywords can still be inspected.
        cursor.execute(f"PRAGMA table_info({_quote_identifier(table)});")
        columns = [col[1] for col in cursor.fetchall()]
        schema_info.append(f"{table}({', '.join(columns)})")
    return "\n".join(schema_info)


def _build_prompt(schema_text, question):
    """Return the instruction prompt sent to the model."""
    return (
        "You are a SQL expert. Generate ONLY the SQL query, nothing else.\n"
        "Database Schema:\n"
        f"{schema_text}\n"
        "Instructions:\n"
        "- Use the EXACT table and column names from the schema above\n"
        '- If user asks about concepts (like "sales", "customers", "products"), find the most relevant table\n'
        "- Return ONLY valid SQL with semicolon\n"
        "- No explanations, no markdown, just SQL\n"
        f"Question: {question}\n"
        "SQL:"
    )


def _extract_sql(response):
    """Return the first SQL statement found in the model's *response*.

    The text after the last ``SQL:`` marker is used (the response may echo
    the prompt), Markdown code fences are removed, and the result is cut at
    the end of the first complete statement. The returned string always ends
    with a semicolon.
    """
    text = str(response).strip()
    if "SQL:" in text:
        text = text.split("SQL:")[-1]
    # Strip fences BEFORE locating the statement; otherwise a reply that
    # starts with a fence line yields no SQL at all.
    text = text.replace("```sql", "").replace("```", "")
    text = text.strip().lstrip("; \t\r\n")

    # Grow the candidate one semicolon at a time so that a ';' inside a
    # string literal does not end the statement prematurely.
    *terminated, _tail = text.split(";")
    candidate = ""
    for piece in terminated:
        candidate += piece + ";"
        if sqlite3.complete_statement(candidate):
            return candidate.strip()

    # No terminated statement: keep the first paragraph and terminate it.
    first_paragraph = text.split("\n\n", 1)[0].strip()
    if first_paragraph.endswith(";"):
        return first_paragraph
    return first_paragraph + ";"


def ask_question(user_db, question):
    """Translate *question* into SQL with the model and run it on *user_db*.

    Args:
        user_db: The uploaded database. Either an object exposing the file
            path as ``.name`` or the path itself; a falsy value means nothing
            was uploaded.
        question: The natural-language question to answer.

    Returns:
        A ``(status, dataframe)`` tuple. ``status`` is a human-readable
        message containing the SQL that was run (or the error). ``dataframe``
        holds the result rows, or is ``None`` when nothing was uploaded or
        SQLite reported an error.
    """
    if not user_db:
        return "โŒ Upload database", None

    # Accept both a file wrapper (path in `.name`) and a plain path string.
    db_path = getattr(user_db, "name", user_db)
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Get full schema with columns. A file that is not a SQLite database
        # is only detected on the first query, so report it here.
        try:
            schema_text = _describe_schema(cursor)
        except sqlite3.Error as e:
            return f"โŒ Error: {e}", None

        # Smart prompt - let model figure out the right table
        prompt = _build_prompt(schema_text, question)

        # Generate SQL
        response = sqlcoder_llm.invoke(prompt)
        sql = _extract_sql(response)

        # Execute
        # NOTE(review): the statement is model-generated and runs with write
        # access to the uploaded file; consider a read-only connection.
        try:
            cursor.execute(sql)
            rows = cursor.fetchall()
        except sqlite3.Error as e:
            return (
                f"โŒ Error: {e}\n\nSQL tried:\n{sql}\n\n๐Ÿ’ก Available tables:\n{schema_text}",
                None,
            )

        if cursor.description:
            df = pd.DataFrame(rows, columns=[d[0] for d in cursor.description])
        else:
            # Statements that produce no result set have no description.
            df = pd.DataFrame()
        return f"โœ… SQL:\n{sql}\n\n๐Ÿ“Š {len(df)} rows", df
    finally:
        # Always release the connection, including on unexpected exceptions.
        conn.close()
# UI
# Create the widgets first, then wire them into a single Interface whose
# callback is `ask_question`: two inputs (database file, question) and two
# outputs (status text, result table).
db_upload = gr.File(label="๐Ÿ“ Upload Database (.db)")
question_box = gr.Textbox(
    label="โ“ Ask Question",
    placeholder="e.g., show all data, highest value, total count",
)
status_box = gr.Textbox(label="๐Ÿค– SQL & Status", lines=6)
results_table = gr.Dataframe(label="๐Ÿ“Š Results")

demo = gr.Interface(
    fn=ask_question,
    inputs=[db_upload, question_box],
    outputs=[status_box, results_table],
    title="๐Ÿ”ฎ Universal Text-to-SQL",
    description="Upload ANY SQLite database and ask questions. The AI will figure out the right tables!",
)
demo.launch()