"""Evaluate a text-to-SQL seq2seq model by comparing query execution results.

For every evaluation example the model generates a SQL query from the
question and the database schema.  Both the reference query and the generated
query are executed against the example's SQLite database, and the serialized
result sets are compared.

The model and tokenizer are loaded when this module is imported.
"""
import json
import os
import sqlite3
from contextlib import closing
from pathlib import Path
from typing import List, Optional, Union

from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

db_schemas_path = "db_schemas.json"
model_path = "gaussalgo/T5-LM-Large-text2sql-spider"

# NOTE(review): `DB_PATH` and `SQLITE_SUFFIX` are read by `evaluate` but are
# not defined in this file.  They must be provided at module level (root
# directory of the per-database folders, and the database file suffix --
# presumably ".sqlite") before `evaluate` is called; confirm where they are
# meant to come from.

model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)


def query_db(question: str, db_path: Union[str, Path]) -> Optional[str]:
    """Execute a SQL statement and return its rows serialized as JSON.

    Args:
        question: The SQL statement to execute (despite the parameter name,
            this is SQL text, not a natural-language question).
        db_path: Path of the SQLite database file to run the statement on.

    Returns:
        ``json.dumps`` of the list of fetched rows, or ``None`` when
        connecting, executing, or serializing fails for any reason.  The
        failing statement and the error are printed in that case.
    """
    try:
        # `closing` guarantees the handles are released; a sqlite3 connection
        # used directly as a context manager only manages the transaction and
        # does not close the connection.
        with closing(sqlite3.connect(db_path)) as con:
            with closing(con.cursor()) as cur:
                cur.execute(question)
                data = cur.fetchall()
        return json.dumps(data)
    except Exception as e:
        # Deliberately broad: model-generated SQL can fail in many ways, and a
        # failure must be recorded as "no result" instead of aborting the run.
        print(question, " ", e)
        return None


def _ratio(numerator: int, denominator: int) -> float:
    """Return ``numerator / denominator``, or NaN when the denominator is 0."""
    if denominator == 0:
        return float("nan")
    return numerator / denominator


def evaluate(eval_dataset: List[dict]):
    """Generate SQL for each example and print execution-accuracy statistics.

    Args:
        eval_dataset: Examples, each a dict with the keys ``"question"``
            (natural-language question), ``"db_id"`` (database identifier,
            used both as the schema lookup key and as the database directory
            name) and ``"query"`` (reference SQL).

    Raises:
        FileNotFoundError: If a database directory is missing or contains no
            file ending in ``SQLITE_SUFFIX``.
        KeyError: If an example lacks a required key or its ``"db_id"`` is
            absent from the schema file.
    """
    reference = []
    gen_queries = []

    with open(db_schemas_path, "r", encoding="utf-8") as schemas:
        db_schema_dict = json.load(schemas)

    for data in tqdm(eval_dataset, total=len(eval_dataset), desc="Executing queries"):
        question = data["question"]
        schema = data["db_id"]

        db_dir = Path(DB_PATH, schema)
        # Sorted so the chosen file does not depend on directory listing order.
        filenames = sorted(
            name for name in os.listdir(db_dir) if name.endswith(SQLITE_SUFFIX)
        )
        if not filenames:
            raise FileNotFoundError(
                f"No file ending in {SQLITE_SUFFIX!r} found in {db_dir}"
            )
        path_to_db = db_dir / filenames[0]

        input_text = " ".join(
            ["Question: ", question, "Schema:", db_schema_dict[schema]]
        )
        model_inputs = tokenizer(input_text, return_tensors="pt")
        outputs = model.generate(**model_inputs, max_length=512)
        output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

        reference.append(query_db(data["query"], path_to_db))
        gen_queries.append(query_db(output_text, path_to_db))

    # NOTE(review): when both the reference and the generated query fail, both
    # entries are None and the pair counts as a match here.  The second metric
    # below excludes those items.
    equal_results = [ref == q for ref, q in zip(reference, gen_queries)]
    eq_results_when_reference_works = [
        ref == q for ref, q in zip(reference, gen_queries) if ref is not None
    ]
    num_of_working_ref = sum(ref is not None for ref in reference)
    dataset_size = len(eval_dataset)

    print("Length of eval dataset: ", dataset_size)
    print("Working references: ", num_of_working_ref)
    # Ratios print as nan instead of raising ZeroDivisionError when the
    # dataset is empty or no reference query could be executed.
    print("Correct queries in labels: ", _ratio(num_of_working_ref, dataset_size))
    print("Accuracy with whole dataset: ", _ratio(sum(equal_results), dataset_size))
    print(
        "Accuracy with only working references: ",
        _ratio(sum(eq_results_when_reference_works), num_of_working_ref),
    )