# evaluate_with_db.py — execution-accuracy evaluation for gaussalgo/T5-LM-Large-text2sql-spider
import json
import os
import sqlite3
from pathlib import Path
from typing import List, Optional

from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Path to the serialized Spider schemas and to the model on the Hugging Face Hub.
db_schemas_path = "db_schemas.json"
model_path = "gaussalgo/T5-LM-Large-text2sql-spider"

# The original script uses DB_PATH and SQLITE_SUFFIX without defining them;
# the values below are assumptions based on the standard Spider layout
# (a "database" directory with one folder and one .sqlite file per database).
DB_PATH = "database"
SQLITE_SUFFIX = ".sqlite"

model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
def query_db(query: str, db_path: str) -> Optional[str]:
    """Execute a SQL query against a SQLite database.

    Returns the fetched rows serialized as a JSON string, or None if the
    query fails to execute.
    """
    try:
        con = sqlite3.connect(db_path)
        cur = con.cursor()
        cur.execute(query)
        data = cur.fetchall()
        con.close()
        return json.dumps(data)
    except Exception as e:
        print(query, " ", e)
        return None
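
# Quick sanity check for query_db; the database path below is hypothetical,
# any Spider .sqlite file would do:
#   print(query_db("SELECT count(*) FROM singer",
#                  "database/concert_singer/concert_singer.sqlite"))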
def evaluate(eval_dataset: List[dict]):
    reference = []
    gen_queries = []
    with open(db_schemas_path, "r") as schemas:
        db_schema_dict = json.load(schemas)
    for data in tqdm(eval_dataset, total=len(eval_dataset), desc="Executing queries"):
        question = data["question"]
        schema = data["db_id"]
        # Each Spider database folder is expected to hold a single .sqlite file.
        filenames = [
            i for i in os.listdir(Path(DB_PATH, schema)) if i.endswith(SQLITE_SUFFIX)
        ]
        path_to_db = Path(DB_PATH, schema, filenames[0])
        input_text = " ".join(
            ["Question: ", question, "Schema:", db_schema_dict[schema]]
        )
        model_inputs = tokenizer(input_text, return_tensors="pt")
        outputs = model.generate(**model_inputs, max_length=512)
        output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        # Compare execution results rather than query strings: the gold query
        # and the generated query are both run against the same database.
        reference.append(query_db(data["query"], path_to_db))
        gen_queries.append(query_db(output_text, path_to_db))
    equal_results = [ref == q for ref, q in zip(reference, gen_queries)]
    # Restrict to examples whose gold query executed successfully.
    eq_results_when_reference_works = [
        ref == q for ref, q in zip(reference, gen_queries) if ref is not None
    ]
    num_of_working_ref = len([ref for ref in reference if ref is not None])
    print("Length of eval dataset: ", len(eval_dataset))
    print("Working references: ", num_of_working_ref)
    print("Share of gold queries that execute: ", num_of_working_ref / len(eval_dataset))
    print("Accuracy on the whole dataset: ", sum(equal_results) / len(eval_dataset))
    print(
        "Accuracy on examples with working references: ",
        sum(eq_results_when_reference_works) / num_of_working_ref,
    )
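

# Minimal usage sketch, assuming the Spider dev split has been downloaded as
# "dev.json" next to this file, in the standard Spider format with "question",
# "query", and "db_id" fields; the filename is an assumption, not part of the
# original script.
if __name__ == "__main__":
    with open("dev.json", "r") as f:
        dev_set = json.load(f)
    evaluate(dev_set)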