File size: 4,077 Bytes
3adc1c7
 
 
 
 
 
 
 
 
 
48dca97
3adc1c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import re
import difflib
import pandas as pd
import sqlite3
import torch
import gradio as gr
from pathlib import Path
from transformers import BartForConditionalGeneration, BartTokenizerFast

# --- Configuration ---
# Checkpoint directory holding both the fine-tuned BART weights and tokenizer files.
CHECKPOINT = "./checkpoint/checkpoint"  # ✅ Path relative to your repo
DEVICE     = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # prefer GPU when available

# --- Load model & tokenizer ---
# Loaded once at import time and shared by every Gradio request below.
tokenizer = BartTokenizerFast.from_pretrained(CHECKPOINT)
model     = BartForConditionalGeneration.from_pretrained(CHECKPOINT).to(DEVICE)
model.eval()  # inference only: disables dropout and other training-mode behavior

# --- Entity extractor stub ---
def extract_entities(q: str):
    """Placeholder entity extractor.

    A real implementation would run NER over *q*; for now no entities
    are ever reported.
    """
    entities: list = []
    return entities

# --- Prompt builder ---
def build_input(question: str, schema_txt: str) -> str:
    """Build the model prompt: entity tags, then schema tags, then the question."""
    entities = extract_entities(question)
    entity_field = ";".join(entities) if entities else "NONE"
    return (
        "[ENT]" + entity_field + "[/ENT]"
        "[SCHEMA]" + schema_txt + "[/SCHEMA]"
        "Question: " + question
    )

# --- Load CSVs into SQLite ---
def load_csvs_to_sqlite(files):
    """Load uploaded CSV files into an in-memory SQLite database.

    Each file becomes a table named t1, t2, ... in upload order.

    Args:
        files: iterable of uploads; each item is either an object with a
               ``.name`` filesystem path (older Gradio tempfile wrappers)
               or a plain path string (newer Gradio versions).

    Returns:
        (conn, schema): the open SQLite connection (caller must close it)
        and a schema summary string like ``"t1(colA,colB) | t2(x,y)"``.

    Raises:
        ValueError: if no files were provided.
    """
    if not files:
        # Fail early with a readable message instead of a TypeError downstream.
        raise ValueError("No CSV files were uploaded.")
    conn = sqlite3.connect(":memory:")
    table_defs = []
    for idx, file in enumerate(files):
        # Accept both tempfile-like objects (.name) and raw path strings.
        path = getattr(file, "name", file)
        tbl = f"t{idx+1}"
        df = pd.read_csv(path)
        df.to_sql(tbl, conn, index=False, if_exists="replace")
        table_defs.append(f"{tbl}({','.join(df.columns)})")
    schema = " | ".join(table_defs)
    return conn, schema

# --- Main logic ---
def query_multi_csv(files, question):
    """Answer *question* over the uploaded CSVs via model-generated SQL.

    Returns a tuple ``(sql, dataframe)``: the SQL string that was ultimately
    executed (possibly a LIKE or fuzzy-matched fallback) and its result set.
    On any unrecoverable error, returns ``("ERROR: ...", empty DataFrame)``.
    """
    conn = None
    try:
        conn, schema = load_csvs_to_sqlite(files)
        prompt = build_input(question, schema)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding="longest").to(DEVICE)
        out    = model.generate(**inputs, num_beams=5, max_length=256, early_stopping=True)
        sql    = tokenizer.decode(out[0], skip_special_tokens=True)

        # Post-process common decoding artifacts (missing spaces, bare table "1").
        sql = re.sub(r"\bFROMt(\d+)", r"FROM t\1", sql, flags=re.IGNORECASE)
        sql = re.sub(r"\bFROM\s+1\b", "FROM t1", sql, flags=re.IGNORECASE)
        sql = re.sub(r"WHERE(?=[A-Za-z_])", "WHERE ", sql, flags=re.IGNORECASE)

        try:
            df = pd.read_sql_query(sql, conn)
        except Exception:
            # Generated SQL may simply be invalid; fall through to the fallbacks.
            df = pd.DataFrame()

        # Fallback 1: retry an exact-equality WHERE as a case-insensitive LIKE.
        m = None  # parsed WHERE clause; also reused by fallback 2
        if df.empty:
            m = re.match(
                r'SELECT\s+(?P<proj>.+?)\s+FROM\s+(?P<table>t\d+)\s+WHERE\s+(?P<col>\w+)\s*=\s*["\'](?P<val>.+?)["\']',
                sql, re.IGNORECASE
            )
            if m:
                proj, table, col, val = m.group("proj", "table", "col", "val")
                # Double up single quotes so values containing ' don't break the SQL.
                safe_val = val.lower().replace("'", "''")
                like_sql = f"SELECT {proj} FROM {table} WHERE lower({col}) LIKE '%{safe_val}%' COLLATE NOCASE"
                try:
                    df = pd.read_sql_query(like_sql, conn)
                    sql = like_sql
                except Exception:
                    pass

        # Fallback 2: fuzzy-match the literal against the column's distinct values.
        if df.empty and m:
            try:
                distinct = (
                    pd.read_sql_query(
                        f"SELECT DISTINCT {m.group('col')} FROM {m.group('table')}", conn
                    )[m.group('col')].astype(str).tolist()
                )
                close = difflib.get_close_matches(m.group('val'), distinct, n=1, cutoff=0.6)
                if close:
                    corrected = close[0].replace("'", "''")  # escape quotes again
                    fuzzy_sql = (
                        f"SELECT {m.group('proj')} FROM {m.group('table')} "
                        f"WHERE {m.group('col')} = '{corrected}'"
                    )
                    df = pd.read_sql_query(fuzzy_sql, conn)
                    sql = fuzzy_sql
            except Exception:
                # Best-effort only: a failed fuzzy lookup must not abort the request.
                pass

        return sql, df

    except Exception as e:
        return f"ERROR: {type(e).__name__}: {e}", pd.DataFrame()
    finally:
        # Always release the in-memory SQLite connection (was leaked before).
        if conn is not None:
            conn.close()

# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("## Text-to-SQL over Multiple CSVs")
    with gr.Row():
        uploaded_csvs = gr.File(
            label="Upload one or more CSVs",
            file_types=['.csv'],
            file_count="multiple",
        )
        question_box = gr.Textbox(
            label="Question",
            placeholder="e.g. Who is author of The Catcher in the Rye?",
        )
    run_button    = gr.Button("Submit")
    generated_sql = gr.Textbox(label="Generated SQL")
    result_table  = gr.Dataframe(label="Query Results")
    # Wire the button to the main pipeline: (files, question) -> (sql, results).
    run_button.click(
        fn=query_multi_csv,
        inputs=[uploaded_csvs, question_box],
        outputs=[generated_sql, result_table],
    )

demo.launch()