# Source: Hugging Face Space app.py by Rushikesh-S-Ware (commit 48dca97, verified).
import re
import difflib
import pandas as pd
import sqlite3
import torch
import gradio as gr
from pathlib import Path
from transformers import BartForConditionalGeneration, BartTokenizerFast
# --- Configuration ---
# Path to the fine-tuned BART checkpoint, relative to the repo root.
CHECKPOINT = "./checkpoint/checkpoint" # ✅ Path relative to your repo
# Prefer GPU when available; everything below runs on this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- Load model & tokenizer ---
# Seq2seq BART fine-tuned for text-to-SQL generation (loaded once at startup).
tokenizer = BartTokenizerFast.from_pretrained(CHECKPOINT)
model = BartForConditionalGeneration.from_pretrained(CHECKPOINT).to(DEVICE)
# Inference only: disable dropout and other training-mode behavior.
model.eval()
# --- Entity extractor stub ---
def extract_entities(q: str) -> list:
    """Placeholder entity extractor: always reports no entities.

    Intended as the hook for a real NER step whose output would be
    serialized into the [ENT]...[/ENT] section of the model prompt.
    """
    return []
# --- Prompt builder ---
def build_input(question: str, schema_txt: str) -> str:
    """Assemble the model prompt: entity tags, then schema tags, then the question."""
    entities = extract_entities(question)
    entity_txt = ";".join(entities) if entities else "NONE"
    parts = [
        f"[ENT]{entity_txt}[/ENT]",
        f"[SCHEMA]{schema_txt}[/SCHEMA]",
        f"Question: {question}",
    ]
    return "".join(parts)
# --- Load CSVs into SQLite ---
def load_csvs_to_sqlite(files):
    """Load uploaded CSV files into an in-memory SQLite database.

    Args:
        files: iterable of uploaded file objects exposing a ``.name`` path
            attribute (as Gradio provides), or plain path strings/Paths.

    Returns:
        (conn, schema): an open sqlite3 connection holding one table per
        CSV (named t1, t2, ...) and a schema description string like
        ``"t1(col_a,col_b) | t2(x,y)"`` used to prompt the model.
    """
    conn = sqlite3.connect(":memory:")
    table_defs = []
    for idx, file in enumerate(files):
        tbl = f"t{idx + 1}"
        # Gradio file objects carry the temp path in .name; fall back to
        # treating the item itself as a path so the helper is reusable.
        path = getattr(file, "name", file)
        df = pd.read_csv(path)
        df.to_sql(tbl, conn, index=False, if_exists="replace")
        cols = ",".join(df.columns)
        table_defs.append(f"{tbl}({cols})")
    schema = " | ".join(table_defs)
    return conn, schema
# --- Main logic ---
def query_multi_csv(files, question):
    """Answer a natural-language question over uploaded CSVs via generated SQL.

    Pipeline: load CSVs into in-memory SQLite -> build prompt -> BART
    generates SQL -> clean up tokenizer artifacts -> execute. If the exact
    query returns nothing, retry with a case-insensitive LIKE, then with a
    fuzzy-matched value.

    Args:
        files: uploaded CSV files (Gradio file objects).
        question: natural-language question about the data.

    Returns:
        (sql, df): the SQL string that was finally executed (or an
        "ERROR: ..." message) and the result DataFrame (empty on failure).
    """
    conn = None
    try:
        conn, schema = load_csvs_to_sqlite(files)
        prompt = build_input(question, schema)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding="longest").to(DEVICE)
        out = model.generate(**inputs, num_beams=5, max_length=256, early_stopping=True)
        sql = tokenizer.decode(out[0], skip_special_tokens=True)

        # Post-process common decoding artifacts: "FROMt1", a bare "FROM 1",
        # and "WHERE" fused to the next identifier.
        sql = re.sub(r"\bFROMt(\d+)", r"FROM t\1", sql, flags=re.IGNORECASE)
        sql = re.sub(r"\bFROM\s+1\b", "FROM t1", sql, flags=re.IGNORECASE)
        sql = re.sub(r"WHERE(?=[A-Za-z_])", "WHERE ", sql, flags=re.IGNORECASE)

        try:
            df = pd.read_sql_query(sql, conn)
        except Exception:
            # Generated SQL may be malformed; treat as "no result" and let
            # the fallbacks below attempt a recovery.
            df = pd.DataFrame()

        # m is set only when fallback 1 can parse the query; initialize it
        # explicitly (the original code probed `'m' in locals()`).
        m = None

        # Fallback 1: rewrite an exact-match WHERE as a case-insensitive LIKE.
        if df.empty:
            m = re.match(
                r'SELECT\s+(?P<proj>.+?)\s+FROM\s+(?P<table>t\d+)\s+WHERE\s+(?P<col>\w+)\s*=\s*["\'](?P<val>.+?)["\']',
                sql, re.IGNORECASE
            )
            if m:
                proj, table, col, val = m.group("proj", "table", "col", "val")
                # Escape single quotes: the literal comes from model output
                # and would otherwise break (or inject into) the statement.
                safe_val = val.lower().replace("'", "''")
                like_sql = f"SELECT {proj} FROM {table} WHERE lower({col}) LIKE '%{safe_val}%' COLLATE NOCASE"
                try:
                    df = pd.read_sql_query(like_sql, conn)
                    sql = like_sql
                except Exception:
                    pass

        # Fallback 2: fuzzy-match the literal against the column's distinct values.
        if df.empty and m:
            col, table = m.group("col"), m.group("table")
            distinct = pd.read_sql_query(
                f"SELECT DISTINCT {col} FROM {table}", conn
            )[col].astype(str).tolist()
            close = difflib.get_close_matches(m.group("val"), distinct, n=1, cutoff=0.6)
            if close:
                corrected = close[0].replace("'", "''")  # escape quotes from data
                fuzzy_sql = f"SELECT {m.group('proj')} FROM {table} WHERE {col} = '{corrected}'"
                try:
                    df = pd.read_sql_query(fuzzy_sql, conn)
                    sql = fuzzy_sql
                except Exception:
                    pass

        return sql, df
    except Exception as e:
        # Top-level boundary for the UI: surface the failure as text.
        return f"ERROR: {type(e).__name__}: {e}", pd.DataFrame()
    finally:
        # The original leaked the in-memory connection on every call.
        if conn is not None:
            conn.close()
# --- Gradio UI ---
# Simple Blocks layout: file upload + question on one row, then the
# generated SQL and the result table below.
with gr.Blocks() as demo:
    gr.Markdown("## Text-to-SQL over Multiple CSVs")
    with gr.Row():
        csv_inputs = gr.File(label="Upload one or more CSVs", file_types=['.csv'], file_count="multiple")
        question = gr.Textbox(label="Question", placeholder="e.g. Who is author of The Catcher in the Rye?")
    submit = gr.Button("Submit")
    sql_out = gr.Textbox(label="Generated SQL")
    results = gr.Dataframe(label="Query Results")
    # Wire the button: (files, question) -> (sql string, result DataFrame).
    submit.click(fn=query_multi_csv, inputs=[csv_inputs, question], outputs=[sql_out, results])
demo.launch()