# Spaces: Running
| import torch | |
| import sqlite3 | |
| import pandas as pd | |
| import gradio as gr | |
| import re | |
| from langchain_community.llms import HuggingFacePipeline | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
# Load the tokenizer and causal-LM weights for the checkpoint named by
# ``model_id`` once, at import time; dtype and device placement are left to
# the library ("auto").
model_id = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
)

# Text-generation pipeline with sampling disabled and at most 256 newly
# generated tokens per call.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    do_sample=False,
)

# LangChain wrapper; ask_question() calls ``sqlcoder_llm.invoke(prompt)``.
sqlcoder_llm = HuggingFacePipeline(pipeline=pipe)
def ask_question(user_db, question, *, llm=None):
    """Answer a natural-language question about an uploaded SQLite database.

    The database schema is read, embedded in a prompt, and sent to the
    language model; the first SQL statement found in the reply is executed
    against the database.

    Args:
        user_db: The uploaded database. Either an object exposing the file
            path as ``.name`` or the path itself as a string. A falsy value
            means nothing was uploaded.
        question: The user's question, inserted verbatim into the prompt.
        llm: Optional object with an ``invoke(prompt)`` method. Defaults to
            the module-level ``sqlcoder_llm``.

    Returns:
        A ``(status, frame)`` tuple. ``status`` is the text shown to the
        user (the SQL that ran and the row count, or an error report that
        includes the schema). ``frame`` is a ``pandas.DataFrame`` with the
        result rows, or ``None`` when nothing was uploaded or SQLite raised
        an error.
    """
    # NOTE(review): the leading glyphs in the status strings below look like
    # mis-decoded emoji; they are kept exactly as found -- confirm the
    # intended characters.
    if not user_db:
        return "โ Upload database", None

    if llm is None:
        llm = sqlcoder_llm
    # Accept both an upload object carrying ``.name`` and a bare path string.
    db_path = getattr(user_db, "name", user_db)

    # Pre-initialised so the error report can be built even when the failure
    # happens before the schema or the SQL is available.
    schema_text = ""
    sql = ""

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        schema_text = _describe_schema(cursor)
        sql = _extract_sql(llm.invoke(_build_prompt(schema_text, question)))

        # NOTE(review): the model-generated statement is executed as-is;
        # statements that modify the uploaded database are not blocked here.
        cursor.execute(sql)
        rows = cursor.fetchall()
        if cursor.description:
            df = pd.DataFrame(rows, columns=[d[0] for d in cursor.description])
        else:
            # Statements that produce no result set have no description.
            df = pd.DataFrame()
        return f"โ SQL:\n{sql}\n\n๐ {len(df)} rows", df
    except sqlite3.Error as e:
        return (
            f"โ Error: {e}\n\nSQL tried:\n{sql}\n\n"
            f"๐ก Available tables:\n{schema_text}"
        ), None
    finally:
        # Runs on every exit path, including exceptions raised by the model
        # call, so the connection is never left open.
        conn.close()


def _quote_identifier(name):
    """Return *name* as a double-quoted SQLite identifier."""
    return '"' + str(name).replace('"', '""') + '"'


def _describe_schema(cursor):
    """Return one ``table(col1, col2, ...)`` line per table in the database."""
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = [row[0] for row in cursor.fetchall()]
    schema_info = []
    for table in tables:
        # Table names come from the uploaded file and may contain spaces,
        # keywords or quote characters, so they are quoted before use.
        cursor.execute(f"PRAGMA table_info({_quote_identifier(table)});")
        columns = [col[1] for col in cursor.fetchall()]
        schema_info.append(f"{table}({', '.join(columns)})")
    return "\n".join(schema_info)


def _build_prompt(schema_text, question):
    """Return the text-to-SQL prompt for *question* over *schema_text*."""
    return (
        "You are a SQL expert. Generate ONLY the SQL query, nothing else.\n"
        "Database Schema:\n"
        f"{schema_text}\n"
        "Instructions:\n"
        "- Use the EXACT table and column names from the schema above\n"
        '- If user asks about concepts (like "sales", "customers", "products"), find the most relevant table\n'
        "- Return ONLY valid SQL with semicolon\n"
        "- No explanations, no markdown, just SQL\n"
        f"Question: {question}\n"
        "SQL:"
    )


def _extract_sql(response):
    """Return the first SQL statement in the model reply, ending in ``;``."""
    text = str(response)
    # The reply may echo the prompt; keep only what follows the last marker.
    if "SQL:" in text:
        text = text.split("SQL:")[-1]
    # Drop markdown fences before locating the statement, otherwise a reply
    # that opens with a fence yields the fence instead of the query.
    text = text.replace("```sql", "").replace("```", "").strip()

    # Take everything up to the first semicolon that really terminates a
    # statement (one inside a string literal or comment does not), so that
    # multi-line queries survive and trailing commentary is discarded.
    end = text.find(";")
    while end != -1:
        candidate = text[: end + 1]
        if sqlite3.complete_statement(candidate):
            return candidate.strip()
        end = text.find(";", end + 1)

    # No terminated statement: use the first paragraph and terminate it.
    statement = text.partition("\n\n")[0].strip()
    if not statement.endswith(";"):
        statement += ";"
    return statement
# UI: the two inputs are passed to ask_question() in order, and its
# (status, frame) return value fills the two outputs in order.
db_upload = gr.File(label="๐ Upload Database (.db)")
question_box = gr.Textbox(
    label="โ Ask Question",
    placeholder="e.g., show all data, highest value, total count",
)
status_box = gr.Textbox(label="๐ค SQL & Status", lines=6)
results_table = gr.Dataframe(label="๐ Results")

demo = gr.Interface(
    fn=ask_question,
    inputs=[db_upload, question_box],
    outputs=[status_box, results_table],
    title="๐ฎ Universal Text-to-SQL",
    description="Upload ANY SQLite database and ask questions. The AI will figure out the right tables!",
)

# Start the web app as soon as the module is executed.
demo.launch()