File size: 2,990 Bytes
37f526a
63e6a46
 
 
fb85d2e
63e6a46
 
 
fb85d2e
63e6a46
 
fb85d2e
 
63e6a46
 
 
 
fb85d2e
 
63e6a46
 
fb85d2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63e6a46
fb85d2e
 
f248584
fb85d2e
 
 
 
 
d29a670
fb85d2e
 
 
 
e6e05c7
fb85d2e
 
 
 
 
 
 
 
 
 
 
 
 
63e6a46
fb85d2e
63e6a46
fb85d2e
 
 
 
63e6a46
fb85d2e
 
63e6a46
fb85d2e
63e6a46
fb85d2e
63e6a46
 
 
fb85d2e
 
63e6a46
 
fb85d2e
 
63e6a46
fb85d2e
 
63e6a46
 
fb85d2e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import re
import sqlite3
from pathlib import Path

import gradio as gr
import pandas as pd
import torch
from langchain_community.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Load the SQLCoder checkpoint once at import time and expose it to the rest
# of the app as a LangChain LLM (`sqlcoder_llm`).
model_id = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
)
# Text-generation pipeline: at most 256 new tokens, greedy decoding (no sampling).
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    do_sample=False,
)
sqlcoder_llm = HuggingFacePipeline(pipeline=pipe)

# Markdown code fences the model may wrap its answer in.
_FENCE_RE = re.compile(r"```(?:sql)?", re.IGNORECASE)


def _open_read_only(path):
    """Open the SQLite file at *path* read-only.

    The SQL executed later is model output, so the connection is opened with
    ``mode=ro``: statements that would modify the uploaded file fail with an
    ``sqlite3.Error`` instead of changing it.
    """
    uri = Path(path).resolve().as_uri() + "?mode=ro"
    return sqlite3.connect(uri, uri=True)


def _describe_schema(cursor):
    """Return one ``table(col1, col2, ...)`` line per user table, newline-joined.

    Tables whose names start with ``sqlite_`` (names SQLite reserves for its
    internal tables, e.g. ``sqlite_sequence``) are skipped. Column names are
    read via the ``pragma_table_info`` table-valued function with the table
    name bound as a parameter, so names with spaces, quotes or keywords work.
    """
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = [name for (name,) in cursor.fetchall() if not name.startswith("sqlite_")]
    schema_lines = []
    for table in tables:
        cursor.execute("SELECT name FROM pragma_table_info(?);", (table,))
        columns = [column for (column,) in cursor.fetchall()]
        schema_lines.append(f"{table}({', '.join(columns)})")
    return "\n".join(schema_lines)


def _build_prompt(schema_text, question):
    """Return the text-to-SQL prompt for *question* over *schema_text*."""
    return f"""You are a SQL expert. Generate ONLY the SQL query, nothing else.

Database Schema:
{schema_text}

Instructions:
- Use the EXACT table and column names from the schema above
- If user asks about concepts (like "sales", "customers", "products"), find the most relevant table
- Return ONLY valid SQL with semicolon
- No explanations, no markdown, just SQL

Question: {question}
SQL:"""


def _extract_sql(raw):
    """Return the first SQL statement found in the model output, or ``""``.

    Markdown fences are removed *before* looking for the statement (taking the
    first line first would leave only the fence). The statement ends at the
    first ``;`` for which ``sqlite3.complete_statement`` is true, so multi-line
    queries and semicolons inside string literals are preserved. If no such
    ``;`` exists, the first paragraph is used and a ``;`` is appended.
    """
    text = str(raw)
    if "SQL:" in text:
        # In case the prompt is echoed back, keep only what follows the last "SQL:".
        text = text.split("SQL:")[-1]
    text = _FENCE_RE.sub("", text).strip()
    for end, char in enumerate(text, start=1):
        if char == ";" and sqlite3.complete_statement(text[:end]):
            return text[:end].strip()
    head = re.split(r"\n\s*\n", text, maxsplit=1)[0].strip()
    if head and not head.endswith(";"):
        head += ";"
    return head


def ask_question(user_db, question, llm=None):
    """Answer a natural-language *question* by generating and running SQL.

    Args:
        user_db: The uploaded database: either an object whose ``.name`` is
            the file path (what ``gr.File`` passes) or the path string itself.
            Falsy means nothing was uploaded.
        question: Natural-language question about the data.
        llm: Object whose ``invoke(prompt)`` returns the model completion.
            Defaults to the module-level ``sqlcoder_llm``.

    Returns:
        ``(status, df)``: a status message (including the SQL that was run)
        and a ``pandas.DataFrame`` of results; ``df`` is ``None`` whenever
        something fails, and ``status`` then explains what.
    """
    if not user_db:
        return "❌ Upload database", None
    if not (question or "").strip():
        return "❌ Ask a question", None
    model = sqlcoder_llm if llm is None else llm
    db_path = getattr(user_db, "name", user_db)

    try:
        conn = _open_read_only(db_path)
    except sqlite3.Error as e:
        return f"❌ Could not open database: {e}", None

    # try/finally closes the connection on every path, including exceptions
    # raised by the model call.
    try:
        cursor = conn.cursor()
        try:
            schema_text = _describe_schema(cursor)
        except sqlite3.Error as e:
            # For example, the uploaded file is not an SQLite database.
            return f"❌ Could not read database: {e}", None

        sql = _extract_sql(model.invoke(_build_prompt(schema_text, question)))
        if not sql:
            return "❌ Model returned no SQL", None

        try:
            cursor.execute(sql)
            rows = cursor.fetchall()
        except (sqlite3.Error, sqlite3.Warning) as e:
            # Older Pythons raise sqlite3.Warning (not an sqlite3.Error) for
            # multi-statement input; catch it so the user still sees a message.
            return (
                f"❌ Error: {e}\n\nSQL tried:\n{sql}\n\n💡 Available tables:\n{schema_text}",
                None,
            )
        if cursor.description:
            df = pd.DataFrame(rows, columns=[d[0] for d in cursor.description])
        else:
            df = pd.DataFrame()
        return f"✅ SQL:\n{sql}\n\n📊 {len(df)} rows", df
    finally:
        conn.close()

# UI: the two inputs are passed positionally to ask_question(user_db, question);
# its returned (status, DataFrame) tuple fills the two outputs in order.
# Labels use the real emoji characters (previously mis-decoded mojibake).
demo = gr.Interface(
    fn=ask_question,
    inputs=[
        gr.File(label="📁 Upload Database (.db)"),
        gr.Textbox(label="❓ Ask Question", placeholder="e.g., show all data, highest value, total count"),
    ],
    outputs=[
        gr.Textbox(label="🤖 SQL & Status", lines=6),
        gr.Dataframe(label="📊 Results"),
    ],
    title="🔮 Universal Text-to-SQL",
    description="Upload ANY SQLite database and ask questions. The AI will figure out the right tables!",
)

demo.launch()