Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import sqlite3 | |
| import pandas as pd | |
| import plotly.express as px | |
| import os | |
| import re | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| # Load AI model for SQL generation | |
| MODEL_NAME = "microsoft/phi-2" | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) | |
| model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto", torch_dtype="auto") | |
| def generate_sql(nl_query, schema): | |
| """Generate SQL query from a natural language query using AI.""" | |
| prompt = f"""### Database Schema: | |
| {schema} | |
| ### Convert the following question into an SQL query: | |
| {nl_query} | |
| SQL Query: | |
| """ | |
| input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device, dtype=torch.long) # Ensure Long dtype | |
| with torch.no_grad(): | |
| output_ids = model.generate(input_ids, max_length=128, pad_token_id=tokenizer.eos_token_id) | |
| # If model outputs in float16 or bfloat16, convert back to long/int | |
| if output_ids.dtype in [torch.float16, torch.bfloat16]: | |
| output_ids = output_ids.to(dtype=torch.long) | |
| # Decode and clean the output | |
| output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) | |
| sql_start = output_text.find("SQL Query:") + len("SQL Query:") | |
| sql_query = output_text[sql_start:].strip() | |
| # Clean SQL output | |
| sql_query = re.sub(r"```sql|```", "", sql_query).split("###")[0].strip() | |
| return sql_query | |
| def execute_sql(sql_query, db_path): | |
| """Execute the generated SQL query on the provided database.""" | |
| try: | |
| conn = sqlite3.connect(db_path) | |
| df = pd.read_sql_query(sql_query, conn) | |
| conn.close() | |
| return df | |
| except Exception as e: | |
| return str(e) | |
| def get_schema(db_path): | |
| """Extract schema from the uploaded database.""" | |
| conn = sqlite3.connect(db_path) | |
| cursor = conn.cursor() | |
| cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") | |
| tables = cursor.fetchall() | |
| schema = "" | |
| for table in tables: | |
| table_name = table[0] | |
| cursor.execute(f"PRAGMA table_info({table_name});") | |
| columns = cursor.fetchall() | |
| schema += f"TABLE {table_name} (\n" | |
| schema += ",\n".join([f" {col[1]} {col[2]}" for col in columns]) | |
| schema += "\n);\n\n" | |
| conn.close() | |
| return schema | |
| # --- Streamlit UI --- | |
| st.title("AI-Powered Text-to-SQL Generator") | |
| st.write("Convert natural language questions into SQL queries and execute them.") | |
| # Database selection | |
| db_option = st.radio("How do you want to provide your database?", ["Upload .db file", "Enter schema manually"]) | |
| db_path = None | |
| schema = "" | |
| if db_option == "Upload .db file": | |
| uploaded_file = st.file_uploader("Upload a SQLite `.db` file", type=["db"]) | |
| if uploaded_file: | |
| db_path = "uploaded_database.db" | |
| with open(db_path, "wb") as f: | |
| f.write(uploaded_file.read()) | |
| schema = get_schema(db_path) | |
| st.success("✅ Database uploaded successfully!") | |
| elif db_option == "Enter schema manually": | |
| st.write("Example schema format:") | |
| st.code( | |
| """TABLE employees ( | |
| employee_id INT PRIMARY KEY, | |
| first_name TEXT, | |
| last_name TEXT, | |
| salary INT | |
| );""", | |
| language="sql", | |
| ) | |
| schema = st.text_area("Enter your schema:") | |
| if schema: | |
| st.subheader("Extracted/Provided Schema:") | |
| st.code(schema, language="sql") | |
| # Query input | |
| user_query = st.text_area("📝 Enter your natural language query:") | |
| if st.button("Generate SQL Query"): | |
| if not schema: | |
| st.error("❌ Please provide a database or schema first.") | |
| else: | |
| sql_query = generate_sql(user_query, schema) | |
| st.subheader("Generated SQL Query:") | |
| st.code(sql_query, language="sql") | |
| # Execute SQL if database exists | |
| if db_path: | |
| result = execute_sql(sql_query, db_path) | |
| if isinstance(result, pd.DataFrame): | |
| st.subheader("📊 Query Results:") | |
| st.dataframe(result) | |
| # Visualization | |
| if not result.empty: | |
| st.subheader("📈 Data Visualization") | |
| fig = px.bar(result, x=result.columns[0], y=result.columns[1]) | |
| st.plotly_chart(fig) | |
| else: | |
| st.error(f"❌ SQL Execution Error: {result}") | |
| else: | |
| st.info("No database provided, only SQL query was generated.") | |