Text-to-SQL / app.py
Shodnotantelope2's picture
Update app.py
67ee210 verified
raw
history blame
4.4 kB
import streamlit as st
import sqlite3
import pandas as pd
import plotly.express as px
import os
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load AI model for SQL generation.
MODEL_NAME = "microsoft/phi-2"


@st.cache_resource
def _load_model(model_name):
    """Load the tokenizer and causal language model for ``model_name``.

    Streamlit re-executes this script from top to bottom on every widget
    interaction. Caching the loader as a resource keeps one tokenizer/model
    pair alive for the whole server process instead of reloading the
    weights on every rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", torch_dtype="auto"
    )
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_model(MODEL_NAME)
def generate_sql(nl_query, schema, max_new_tokens=128):
    """Generate an SQL query from a natural-language question.

    A prompt is built from the schema and the question, the module-level
    ``model`` continues it, and the cleaned continuation is returned.

    Args:
        nl_query: The user's question in plain language.
        schema: Textual schema description placed at the top of the prompt.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt. The prompt length does not count against it, so a
            large schema no longer leaves the model without room to answer.

    Returns:
        The generated text with Markdown code fences removed and everything
        from the first ``###`` marker onwards cut off. The string may be
        empty if the model produced nothing usable.
    """
    prompt = (
        "### Database Schema:\n"
        f"{schema}\n"
        "### Convert the following question into an SQL query:\n"
        f"{nl_query}\n"
        "SQL Query:\n"
    )

    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_ids = encoded["input_ids"]
    prompt_length = input_ids.shape[-1]

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            attention_mask=encoded["attention_mask"],
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the tokens produced after the prompt. Searching the full
    # decoded text for "SQL Query:" is fragile: the marker may also occur in
    # the schema or the question, and a failed search silently yields a
    # wrong slice offset.
    generated_ids = output_ids[0][prompt_length:]
    output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Strip Markdown fences and drop any follow-up "###" section the model
    # may have started after the query.
    sql_query = re.sub(r"```sql|```", "", output_text).split("###")[0].strip()
    return sql_query
def execute_sql(sql_query, db_path):
    """Execute the generated SQL query on the provided database.

    Args:
        sql_query: SQL text to run.
        db_path: Path of the SQLite database file.

    Returns:
        A ``pandas.DataFrame`` with the query result on success, or the
        error message as a ``str`` when connecting or executing fails.
        Callers distinguish the two cases with ``isinstance``.
    """
    try:
        conn = sqlite3.connect(db_path)
        try:
            return pd.read_sql_query(sql_query, conn)
        finally:
            # Close the connection even when the query raises; otherwise
            # every failed query leaks an open handle on the database file.
            conn.close()
    except Exception as e:
        # Best-effort boundary: the UI shows the message instead of crashing.
        return str(e)
def get_schema(db_path):
    """Extract a textual schema from a SQLite database file.

    Every table listed in ``sqlite_master`` is rendered as::

        TABLE <name> (
         <column> <type>,
         ...
        );

    followed by a blank line.

    Args:
        db_path: Path of the SQLite database file.

    Returns:
        The concatenated table descriptions, or an empty string when the
        database contains no tables.

    Raises:
        sqlite3.Error: If the file cannot be read as a SQLite database.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]

        sections = []
        for table_name in table_names:
            # Quote the identifier: table names come from an uploaded file
            # and may contain spaces, keywords, or hostile SQL fragments.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted_name});")
            columns = cursor.fetchall()
            # In a table_info row, index 1 is the column name and index 2
            # its declared type.
            column_lines = ",\n".join(f" {col[1]} {col[2]}" for col in columns)
            sections.append(f"TABLE {table_name} (\n{column_lines}\n);\n\n")
        return "".join(sections)
    finally:
        # Release the file handle even if a query above fails.
        conn.close()
# --- Streamlit UI ---
# Page flow: choose a schema source (uploaded SQLite file or pasted text),
# show the schema, take a natural-language question, generate SQL, and, when
# a database file is available, run the query and chart the result.
st.title("AI-Powered Text-to-SQL Generator")
st.write("Convert natural language questions into SQL queries and execute them.")

# Database selection
db_option = st.radio(
    "How do you want to provide your database?",
    ["Upload .db file", "Enter schema manually"],
)
db_path = None
schema = ""

if db_option == "Upload .db file":
    uploaded_file = st.file_uploader("Upload a SQLite `.db` file", type=["db"])
    if uploaded_file:
        db_path = "uploaded_database.db"
        with open(db_path, "wb") as f:
            # getvalue() returns the full content regardless of the current
            # stream position, so script reruns never write an empty file.
            f.write(uploaded_file.getvalue())
        try:
            schema = get_schema(db_path)
        except sqlite3.Error as exc:
            # The upload is not a readable SQLite database; report it
            # instead of crashing the page, and do not offer execution.
            db_path = None
            st.error(f"❌ Could not read the uploaded database: {exc}")
        else:
            st.success("✅ Database uploaded successfully!")
elif db_option == "Enter schema manually":
    st.write("Example schema format:")
    st.code(
        """TABLE employees (
employee_id INT PRIMARY KEY,
first_name TEXT,
last_name TEXT,
salary INT
);""",
        language="sql",
    )
    schema = st.text_area("Enter your schema:")

if schema:
    st.subheader("Extracted/Provided Schema:")
    st.code(schema, language="sql")

# Query input
user_query = st.text_area("📝 Enter your natural language query:")

if st.button("Generate SQL Query"):
    if not schema:
        st.error("❌ Please provide a database or schema first.")
    elif not user_query.strip():
        # Do not spend a model call on a blank question.
        st.error("❌ Please enter a question first.")
    else:
        sql_query = generate_sql(user_query, schema)
        st.subheader("Generated SQL Query:")
        st.code(sql_query, language="sql")

        # Execute SQL if database exists
        if db_path:
            result = execute_sql(sql_query, db_path)
            # execute_sql returns a DataFrame on success and the error
            # message as a string on failure.
            if isinstance(result, pd.DataFrame):
                st.subheader("📊 Query Results:")
                st.dataframe(result)

                # Visualization: the bar chart uses the first column as x
                # and the second as y, so it needs at least two columns
                # (a single-value result such as COUNT(*) has only one).
                if not result.empty and len(result.columns) >= 2:
                    st.subheader("📈 Data Visualization")
                    fig = px.bar(result, x=result.columns[0], y=result.columns[1])
                    st.plotly_chart(fig)
            else:
                st.error(f"❌ SQL Execution Error: {result}")
        else:
            st.info("No database provided, only SQL query was generated.")