# File: text2sql-demo/src/predict.py
# Page residue from the hosting site (not code): tjhalanigrid's picture / Add src folder / dc59b01
import torch
import sqlite3
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# --------------------------------------------------
# PATH
# --------------------------------------------------
# Directory of the fine-tuned checkpoint; the tokenizer and the model are
# both loaded from this same location.
MODEL_PATH = "outputs/model"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

print("Loading fine-tuned model...")
# eval() returns the module itself, so loading and switching to evaluation
# mode can be written as a single expression.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH).eval()
# --------------------------------------------------
# CONNECT DATABASE
# --------------------------------------------------
print("Connecting to database...")
# Relative path, resolved against the current working directory.
DB_PATH = "data/database/department_management/department_management.sqlite"
# Open through a read-only URI instead of the default read-write-create mode:
#   * a wrong path / working directory now fails right here instead of
#     silently creating an empty database file and reporting success;
#   * SQL produced by the model is executed verbatim further down, so the
#     connection must not be able to modify or drop the data.
conn = sqlite3.connect(f"file:{DB_PATH}?mode=ro", uri=True)
# Single module-level cursor shared by every query that is executed later.
cursor = conn.cursor()
print("Database connected ✔")
# --------------------------------------------------
# BUILD PROMPT
# --------------------------------------------------
def build_prompt(question):
    """Build the model input for *question*.

    The prompt is the fixed task prefix, followed by a textual description of
    the three tables (department, head, management), followed by the question.

    Args:
        question: The natural-language question to embed in the prompt.

    Returns:
        The complete prompt string.
    """
    # The schema text starts and ends with a newline; together with the
    # spaces in the template below this fixes the exact prompt layout, so the
    # whitespace is kept as is.
    schema_lines = (
        "",
        "Table department columns = Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees.",
        "Table head columns = head_ID, name, born_state, age.",
        "Table management columns = department_ID, head_ID, temporary_acting.",
        "",
    )
    schema_text = "\n".join(schema_lines)
    return "translate English to SQL: {} question: {}".format(schema_text, question)
# --------------------------------------------------
# GENERATE SQL
# --------------------------------------------------
def generate_sql(question):
    """Generate a SQL string for *question* with the loaded model.

    The question is wrapped by ``build_prompt``, tokenized (truncated to at
    most 256 tokens), and passed to ``model.generate`` using 5 beams and a
    maximum length of 256. The first returned sequence is decoded with
    special tokens removed.

    Args:
        question: The natural-language question.

    Returns:
        The decoded model output with surrounding whitespace stripped.
    """
    model_inputs = tokenizer(
        build_prompt(question),
        max_length=256,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )
    generation_kwargs = dict(
        input_ids=model_inputs["input_ids"],
        attention_mask=model_inputs["attention_mask"],
        max_length=256,
        num_beams=5,
        early_stopping=True,
    )

    # No gradients are needed for generation.
    with torch.no_grad():
        sequences = model.generate(**generation_kwargs)

    # Only the first returned sequence is used.
    first_sequence = sequences[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True).strip()
# --------------------------------------------------
# EVALUATE SQL (REWARD FUNCTION)
# --------------------------------------------------
def evaluate_sql(sql):
    """Run *sql* on the module-level cursor and score the outcome.

    Args:
        sql: The SQL text to execute; it is passed to the cursor unchanged.

    Returns:
        A ``(reward, payload)`` tuple:

        * ``(1.0, rows)``  - the statement ran and returned at least one row;
        * ``(-0.2, rows)`` - the statement ran but returned no rows
          (``rows`` is then an empty list);
        * ``(-1.0, message)`` - execution raised; ``message`` is the
          exception text.
    """
    try:
        cursor.execute(sql)
        fetched = cursor.fetchall()
    except Exception as err:
        # Any failure while executing or fetching counts as invalid SQL.
        return -1.0, str(err)

    if fetched:
        return 1.0, fetched
    # Executed fine, but produced nothing useful.
    return -0.2, fetched
# --------------------------------------------------
# INTERACTIVE LOOP
# --------------------------------------------------
# Read questions until the user types "exit", closes stdin (EOF) or presses
# Ctrl-C at the prompt. For every question: generate SQL, execute it, and
# print the reward together with either the rows or the error text.
try:
    while True:
        try:
            q = input("\nAsk question (type exit to quit): ").strip()
        except (EOFError, KeyboardInterrupt):
            # End of input / Ctrl-C at the prompt: leave quietly instead of
            # terminating with a traceback.
            print()
            break

        if not q:
            # Nothing was typed; do not run the model on an empty question.
            continue
        if q.lower() == "exit":
            break

        sql = generate_sql(q)
        print("\nPredicted SQL:")
        print(sql)

        # ---------------- RUN SQL + REWARD ----------------
        reward, output = evaluate_sql(sql)
        print("\nReward:", reward)

        # The reward values are the exact constants returned by evaluate_sql:
        # -1.0 = execution failed, -0.2 = no rows, anything else = rows found.
        if reward == -1.0:
            print("SQL Error:", output)
        elif reward == -0.2:
            print("No results found")
        else:
            print("\nAnswer:")
            for r in output:
                print(r)
finally:
    # Release the database handle however the loop ends.
    conn.close()