|
import re |
|
import difflib |
|
import pandas as pd |
|
import sqlite3 |
|
import torch |
|
import gradio as gr |
|
from pathlib import Path |
|
from transformers import BartForConditionalGeneration, BartTokenizerFast |
|
|
|
|
|
# Path to the fine-tuned BART checkpoint directory (tokenizer files + model weights).
CHECKPOINT = "./checkpoint/checkpoint"

# Run on GPU when available; model weights and tokenized inputs are moved here.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")





# Load tokenizer and seq2seq model once at import time so every request reuses them.
tokenizer = BartTokenizerFast.from_pretrained(CHECKPOINT)

model = BartForConditionalGeneration.from_pretrained(CHECKPOINT).to(DEVICE)

# Inference only: disable dropout etc.
model.eval()
|
|
|
|
|
def extract_entities(q: str):
    """Placeholder entity extractor.

    Currently returns an empty list for every question; `build_input` then
    falls back to the "NONE" entity tag in the prompt.
    """
    entities: list = []
    return entities
|
|
|
|
|
def build_input(question: str, schema_txt: str) -> str:
    """Compose the model prompt: entity tags, then schema text, then the question.

    The entity section holds a semicolon-joined list of extracted entities,
    or the literal "NONE" when no entities were found.
    """
    entities = extract_entities(question)
    if entities:
        entity_tag = ";".join(entities)
    else:
        entity_tag = "NONE"
    return f"[ENT]{entity_tag}[/ENT][SCHEMA]{schema_txt}[/SCHEMA]Question: {question}"
|
|
|
|
|
def load_csvs_to_sqlite(files):
    """Load uploaded CSVs into a fresh in-memory SQLite database.

    Each file becomes table ``t1``, ``t2``, ... in upload order. Accepts
    either objects exposing the temp-file path via a ``.name`` attribute
    (gradio uploads) or plain path strings / ``Path`` objects.

    Args:
        files: iterable of uploaded file objects or CSV paths.

    Returns:
        tuple: ``(conn, schema)`` — an open :memory: sqlite3 connection
        holding the tables, and a schema summary string like
        ``"t1(col1,col2) | t2(colA)"`` that is embedded in the model prompt.
    """
    conn = sqlite3.connect(":memory:")
    table_defs = []
    for idx, file in enumerate(files):
        tbl = f"t{idx + 1}"
        # gradio file objects carry the temp path in .name; also allow raw paths
        path = getattr(file, "name", file)
        df = pd.read_csv(path)
        df.to_sql(tbl, conn, index=False, if_exists="replace")
        cols = ",".join(df.columns)
        table_defs.append(f"{tbl}({cols})")
    schema = " | ".join(table_defs)
    return conn, schema
|
|
|
|
|
def query_multi_csv(files, question):
    """End-to-end pipeline: CSV uploads -> schema -> generated SQL -> rows.

    Returns a ``(sql_text, dataframe)`` pair for the UI. On any unexpected
    failure the first element is an ``"ERROR: ..."`` string and the second
    is an empty DataFrame. When the model's query runs but yields no rows,
    two fallbacks are tried in turn on simple single-condition SELECTs:
    1. a case-insensitive substring (LIKE) match on the same column;
    2. a fuzzy match of the literal against the column's distinct values.
    """
    try:
        conn, schema = load_csvs_to_sqlite(files)
        prompt = build_input(question, schema)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding="longest").to(DEVICE)
        out = model.generate(**inputs, num_beams=5, max_length=256, early_stopping=True)
        sql = tokenizer.decode(out[0], skip_special_tokens=True)

        # Repair common decoding artifacts: keywords fused to the table name,
        # a table alias decoded as a bare digit, and a missing space after WHERE.
        sql = re.sub(r"\bFROMt(\d+)", r"FROM t\1", sql, flags=re.IGNORECASE)
        sql = re.sub(r"\bFROM\s+1\b", "FROM t1", sql, flags=re.IGNORECASE)
        sql = re.sub(r"WHERE(?=[A-Za-z_])", "WHERE ", sql, flags=re.IGNORECASE)

        try:
            df = pd.read_sql_query(sql, conn)
        except Exception:
            # Generated SQL may simply be invalid; fall through to fallbacks.
            df = pd.DataFrame()

        # Fallback 1: exact equality matched nothing -> retry as a
        # case-insensitive substring match on the same column.
        m = None
        if df.empty:
            m = re.match(
                r'SELECT\s+(?P<proj>.+?)\s+FROM\s+(?P<table>t\d+)\s+WHERE\s+(?P<col>\w+)\s*=\s*["\'](?P<val>.+?)["\']',
                sql, re.IGNORECASE
            )
            if m:
                proj, table, col, val = m.group("proj", "table", "col", "val")
                # Double single quotes so values like "O'Brien" don't break the SQL.
                safe_val = val.lower().replace("'", "''")
                like_sql = f"SELECT {proj} FROM {table} WHERE lower({col}) LIKE '%{safe_val}%' COLLATE NOCASE"
                try:
                    df = pd.read_sql_query(like_sql, conn)
                    sql = like_sql
                except Exception:
                    pass

        # Fallback 2: still nothing -> fuzzy-match the literal against the
        # column's distinct values and retry with the closest candidate.
        if df.empty and m:
            distinct = pd.read_sql_query(
                f"SELECT DISTINCT {m.group('col')} FROM {m.group('table')}", conn
            )[m.group('col')].astype(str).tolist()
            close = difflib.get_close_matches(m.group("val"), distinct, n=1, cutoff=0.6)
            if close:
                # Escape quotes in the corrected value pulled from the data.
                corrected = close[0].replace("'", "''")
                fuzzy_sql = (
                    f"SELECT {m.group('proj')} FROM {m.group('table')} "
                    f"WHERE {m.group('col')} = '{corrected}'"
                )
                try:
                    df = pd.read_sql_query(fuzzy_sql, conn)
                    sql = fuzzy_sql
                except Exception:
                    pass

        return sql, df

    except Exception as e:
        return f"ERROR: {type(e).__name__}: {e}", pd.DataFrame()
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:

    gr.Markdown("## Text-to-SQL over Multiple CSVs")

    with gr.Row():

        # Inputs: one or more CSV uploads plus a natural-language question.
        csv_inputs = gr.File(label="Upload one or more CSVs", file_types=['.csv'], file_count="multiple")

        question = gr.Textbox(label="Question", placeholder="e.g. Who is author of The Catcher in the Rye?")

    submit = gr.Button("Submit")

    # Outputs: the (possibly repaired/fallback) SQL string and the result rows.
    sql_out = gr.Textbox(label="Generated SQL")

    results = gr.Dataframe(label="Query Results")

    # Wire the button to the full pipeline defined above.
    submit.click(fn=query_multi_csv, inputs=[csv_inputs, question], outputs=[sql_out, results])



# Blocking call: starts the web server when the script is run.
demo.launch()
|
|