|
import json
import os
import sqlite3
from pathlib import Path
from typing import List, Optional

from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
# JSON file mapping each db_id to the schema text fed to the model.
db_schemas_path = "db_schemas.json"
# Hugging Face checkpoint used for text-to-SQL generation.
model_path = "gaussalgo/T5-LM-Large-text2sql-spider"

# Both objects are loaded once at import time and shared by evaluate().
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
|
|
|
|
|
def query_db(question: str, db_path: str) -> Optional[str]:
    """Execute a SQL statement read-only against a SQLite database.

    Args:
        question: The SQL statement to run. Despite the name, this is SQL,
            not the natural-language question.
        db_path: Path (``str`` or ``os.PathLike``) to the SQLite database file.

    Returns:
        The fetched rows serialized with ``json.dumps`` (row tuples become
        JSON arrays). Returns ``None`` if opening the database, executing,
        fetching or serializing fails; the statement and error are printed.
    """
    con = None
    try:
        # Open read-only. Generated SQL is untrusted model output, and a stray
        # DROP/DELETE must not modify the evaluation database. "mode=ro" also
        # keeps sqlite3 from silently creating an empty file for a bad path.
        db_uri = Path(db_path).resolve().as_uri() + "?mode=ro"
        con = sqlite3.connect(db_uri, uri=True)
        data = con.execute(question).fetchall()
        return json.dumps(data)
    except Exception as e:
        # Best effort by design: a failing query is reported as None so the
        # caller can count it, instead of aborting the whole evaluation.
        print(question, " ", e)
        return None
    finally:
        # Always release the connection, even when the query fails.
        if con is not None:
            con.close()
|
|
|
|
|
def _find_sqlite_file(db_dir: str, db_id: str, sqlite_suffix: str) -> Path:
    """Return the database file for *db_id* inside ``db_dir/db_id``."""
    db_folder = Path(db_dir, db_id)
    # Sort so the choice is deterministic if several files match.
    candidates = sorted(
        name for name in os.listdir(db_folder) if name.endswith(sqlite_suffix)
    )
    if not candidates:
        raise FileNotFoundError(f"No '*{sqlite_suffix}' file found in {db_folder}")
    return db_folder / candidates[0]


def _generate_sql(question: str, schema_text: str) -> str:
    """Translate *question* into SQL using the module-level model and tokenizer."""
    # Prompt format kept exactly as before, including the double space after
    # "Question:". Presumably it mirrors the checkpoint's expected input (verify).
    input_text = " ".join(["Question: ", question, "Schema:", schema_text])
    model_inputs = tokenizer(input_text, return_tensors="pt")
    outputs = model.generate(**model_inputs, max_length=512)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]


def _ratio(numerator: int, denominator: int) -> float:
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


def evaluate(
    eval_dataset: List[dict],
    db_dir: Optional[str] = None,
    sqlite_suffix: Optional[str] = None,
):
    """Measure the model's execution accuracy and print a summary report.

    For every example, the gold query (``data["query"]``) and the model's
    generated query are both executed with ``query_db`` against the example's
    SQLite database. Their serialized result sets are then compared for
    equality. The comparison is on JSON strings, so row order matters.

    Args:
        eval_dataset: Examples with ``"question"``, ``"db_id"`` and ``"query"``
            keys.
        db_dir: Root directory holding one sub-directory per ``db_id``.
            Defaults to the module-level ``DB_PATH``.
        sqlite_suffix: Suffix of the database file inside each sub-directory.
            Defaults to the module-level ``SQLITE_SUFFIX``.

    Raises:
        FileNotFoundError: If a database directory has no matching file.
        KeyError: If ``db_schemas_path`` lacks a schema for an example's db_id.
    """
    # Resolve at call time, so callers that define the module-level
    # constants keep the original behavior.
    if db_dir is None:
        db_dir = DB_PATH
    if sqlite_suffix is None:
        sqlite_suffix = SQLITE_SUFFIX

    with open(db_schemas_path, "r", encoding="utf-8") as schemas:
        db_schema_dict = json.load(schemas)

    reference = []
    gen_queries = []
    db_files = {}  # db_id -> database file, so each directory is listed once

    for data in tqdm(eval_dataset, total=len(eval_dataset), desc="Executing queries"):
        db_id = data["db_id"]
        if db_id not in db_files:
            db_files[db_id] = _find_sqlite_file(db_dir, db_id, sqlite_suffix)
        path_to_db = db_files[db_id]

        output_text = _generate_sql(data["question"], db_schema_dict[db_id])
        reference.append(query_db(data["query"], path_to_db))
        gen_queries.append(query_db(output_text, path_to_db))

    # A prediction is correct only if the reference actually ran and matched.
    # Two failed executions (None == None) must not count as a match.
    num_correct = sum(
        ref is not None and ref == gen for ref, gen in zip(reference, gen_queries)
    )
    num_of_working_ref = sum(ref is not None for ref in reference)
    total = len(eval_dataset)

    print("Length of eval dataset: ", total)
    print("Working references: ", num_of_working_ref)
    print("Correct queries in labels: ", _ratio(num_of_working_ref, total))
    print("Accuracy with whole dataset: ", _ratio(num_correct, total))
    print(
        "Accuracy with only working references: ",
        _ratio(num_correct, num_of_working_ref),
    )
|
|