import re
import difflib
import sqlite3

import pandas as pd
import torch
import gradio as gr
from transformers import BartForConditionalGeneration, BartTokenizerFast

# --- Configuration ---
CHECKPOINT = "./checkpoint/checkpoint"  # ✅ Path relative to your repo
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Load model & tokenizer ---
tokenizer = BartTokenizerFast.from_pretrained(CHECKPOINT)
model = BartForConditionalGeneration.from_pretrained(CHECKPOINT).to(DEVICE)
model.eval()


# --- Entity extractor stub ---
def extract_entities(q: str):
    """Placeholder: return a list of entities mentioned in the question."""
    return []


# --- Prompt builder ---
def build_input(question: str, schema_txt: str) -> str:
    ents = extract_entities(question)
    et = ";".join(ents) if ents else "NONE"
    return f"[ENT]{et}[/ENT][SCHEMA]{schema_txt}[/SCHEMA]Question: {question}"


# --- Load CSVs into SQLite ---
def load_csvs_to_sqlite(files):
    """Load each uploaded CSV into an in-memory SQLite table (t1, t2, ...)
    and return the connection plus a textual schema description."""
    conn = sqlite3.connect(":memory:")
    table_defs = []
    for idx, file in enumerate(files):
        tbl = f"t{idx + 1}"
        df = pd.read_csv(file.name)
        df.to_sql(tbl, conn, index=False, if_exists="replace")
        cols = ",".join(df.columns)
        table_defs.append(f"{tbl}({cols})")
    schema = " | ".join(table_defs)
    return conn, schema


# --- Main logic ---
def query_multi_csv(files, question):
    try:
        conn, schema = load_csvs_to_sqlite(files)
        prompt = build_input(question, schema)
        inputs = tokenizer(
            prompt, return_tensors="pt", truncation=True, padding="longest"
        ).to(DEVICE)
        out = model.generate(
            **inputs, num_beams=5, max_length=256, early_stopping=True
        )
        sql = tokenizer.decode(out[0], skip_special_tokens=True)

        # Post-process: repair common detokenisation glitches in the generated SQL.
        sql = re.sub(r"\bFROMt(\d+)", r"FROM t\1", sql, flags=re.IGNORECASE)
        sql = re.sub(r"\bFROM\s+1\b", "FROM t1", sql, flags=re.IGNORECASE)
        sql = re.sub(r"WHERE(?=[A-Za-z_])", "WHERE ", sql, flags=re.IGNORECASE)

        try:
            df = pd.read_sql_query(sql, conn)
        except Exception:
            df = pd.DataFrame()

        # Fallback 1: retry an exact-match WHERE clause as a case-insensitive LIKE.
        m = None
        if df.empty:
            m = re.match(
                r'SELECT\s+(?P<proj>.+?)\s+FROM\s+(?P<table>t\d+)\s+'
                r'WHERE\s+(?P<col>\w+)\s*=\s*["\'](?P<val>.+?)["\']',
                sql,
                re.IGNORECASE,
            )
            if m:
                proj, table, col, val = m.group("proj", "table", "col", "val")
                like_sql = (
                    f"SELECT {proj} FROM {table} "
                    f"WHERE lower({col}) LIKE '%{val.lower()}%'"
                )
                try:
                    df = pd.read_sql_query(like_sql, conn)
                    sql = like_sql
                except Exception:
                    pass

        # Fallback 2: fuzzy-match the literal against the column's distinct values.
        if df.empty and m:
            distinct = (
                pd.read_sql_query(
                    f"SELECT DISTINCT {m.group('col')} FROM {m.group('table')}",
                    conn,
                )[m.group("col")]
                .astype(str)
                .tolist()
            )
            close = difflib.get_close_matches(
                m.group("val"), distinct, n=1, cutoff=0.6
            )
            if close:
                corrected = close[0]
                fuzzy_sql = (
                    f"SELECT {m.group('proj')} FROM {m.group('table')} "
                    f"WHERE {m.group('col')} = '{corrected}'"
                )
                try:
                    df = pd.read_sql_query(fuzzy_sql, conn)
                    sql = fuzzy_sql
                except Exception:
                    pass

        return sql, df
    except Exception as e:
        return f"ERROR: {type(e).__name__}: {e}", pd.DataFrame()


# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("## Text-to-SQL over Multiple CSVs")
    with gr.Row():
        csv_inputs = gr.File(
            label="Upload one or more CSVs",
            file_types=[".csv"],
            file_count="multiple",
        )
        question = gr.Textbox(
            label="Question",
            placeholder="e.g. Who is the author of The Catcher in the Rye?",
        )
    submit = gr.Button("Submit")
    sql_out = gr.Textbox(label="Generated SQL")
    results = gr.Dataframe(label="Query Results")
    submit.click(
        fn=query_multi_csv, inputs=[csv_inputs, question], outputs=[sql_out, results]
    )

demo.launch()