import re
import json
import records
from typing import List, Dict
from sqlalchemy.exc import SQLAlchemyError
from utils.sql.all_keywords import ALL_KEY_WORDS


class WTQDBEngine:
    """Small wrapper around a SQLite database opened through ``records``."""

    def __init__(self, fdb):
        """Open the SQLite file at ``fdb`` and keep one connection to it."""
        self.db = records.Database('sqlite:///{}'.format(fdb))
        self.conn = self.db.get_connection()

    def execute_wtq_query(self, sql_query: str):
        """Run ``sql_query`` and return all result cells as one flat list.

        The values of every returned row are concatenated in row order, so a
        result of N rows with M columns yields a list of N * M values.
        """
        merged_results = []
        for record in self.conn.query(sql_query).all():
            merged_results.extend(record.values())
        return merged_results

    def delete_rows(self, row_indices: List[int]):
        """Delete from table ``w`` every row whose ``id`` is in ``row_indices``."""
        for row in row_indices:
            self.conn.query("delete from w where id == {}".format(row))


def process_table_structure(_wtq_table_content: Dict, _add_all_column: bool = False):
    """Convert a column-oriented WTQ table description into a row-oriented one.

    Reads the ``"headers"``, ``"types"``, ``"contents"`` and ``"is_list"``
    entries of ``_wtq_table_content``. The first two entries of ``"headers"``,
    ``"types"`` and ``"contents"`` are skipped (the original code describes
    them as the id and agg columns).

    :param _wtq_table_content: table description with the keys listed above.
        Each element of ``"contents"`` is a list of column variants, each
        variant being a dict with ``"col"``, ``"type"`` and ``"data"``.
    :param _add_all_column: when false, only the first variant of every column
        is kept and the lower-cased original headers / types are returned.
        When true, every variant whose alias does not contain ``"_number"`` is
        kept, and headers / types are derived from the variants.
    :return: dict with ``"header"``, ``"rows"``, ``"types"`` and ``"alias"``
        (the keys of ``_wtq_table_content["is_list"]``).
    """

    def _normalize_cells(cells):
        # Cells are stringified, flattened to a single line and lower-cased.
        return [str(cell).replace("\n", " ").lower() for cell in cells]

    headers = [
        name.replace("\n", " ").lower()
        for name in _wtq_table_content["headers"][2:]
    ]
    # Aliases are 1-based: "c1" is the first kept header.
    header_map = {
        "c" + str(position): name
        for position, name in enumerate(headers, start=1)
    }
    header_types = _wtq_table_content["types"][2:]

    all_headers = []
    all_header_types = []
    vertical_content = []
    for column_content in _wtq_table_content["contents"][2:]:
        if not _add_all_column:
            # Only the first variant of the column is used.
            vertical_content.append(_normalize_cells(column_content[0]["data"]))
            continue

        for variant in column_content:
            column_alias = variant["col"]
            # Numbered variants are not added.
            if "_number" in column_alias:
                continue
            vertical_content.append(_normalize_cells(variant["data"]))

            if "_" in column_alias:
                # "c4_some_suffix" -> "<header of c4> some suffix"
                base_alias, _, suffix = column_alias.partition("_")
                column_name = header_map[base_alias] + " " + suffix.replace("_", " ")
            else:
                column_name = header_map[column_alias]
            all_headers.append(column_name)
            all_header_types.append("text" if variant["type"] == "TEXT" else "number")

    # Transpose the column lists into rows; zip stops at the shortest column.
    row_content = list(map(list, zip(*vertical_content)))

    if _add_all_column:
        ret_header = all_headers
        ret_types = all_header_types
    else:
        ret_header = headers
        ret_types = header_types
    return {
        "header": ret_header,
        "rows": row_content,
        "types": ret_types,
        "alias": list(_wtq_table_content["is_list"].keys())
    }


def retrieve_wtq_query_answer(_engine, _table_content, _sql_struct: List):
    """Flatten a structured SQL query, execute it and normalize the answers.

    :param _engine: object exposing ``execute_wtq_query(sql)``.
    :param _table_content: dict whose ``"header"`` entry lists the column
        names; ``cN`` aliases are resolved against it (1-based).
    :param _sql_struct: sequence of token entries whose element at index 1 is
        the token text, e.g. ``["Keyword", "select", []]``.
    :return: tuple ``(encoded_sql, answers, executed_sql)`` where
        ``encoded_sql`` has column aliases replaced by back-quoted header
        names, ``answers`` is the list of stringified non-``None`` result
        values (empty when the query fails with ``SQLAlchemyError`` or when a
        value equals ``"none"``), and ``executed_sql`` is the string that was
        sent to the engine.
    """
    headers = _table_content["header"]

    def flatten_sql(_ex_sql_struct: List):
        """Return ``(executable_sql, encoded_sql)`` for the token list."""
        _encode_sql = []
        _execute_sql = []
        for _ex_tuple in _ex_sql_struct:
            keyword = str(_ex_tuple[1])
            # Known SQL keywords are upper-cased.
            if keyword in ALL_KEY_WORDS:
                keyword = keyword.upper()

            if re.fullmatch(r"c\d+(_.+)?", keyword):
                # Column alias: only the "cN" part selects the header.
                index_key = int(keyword.split("_")[0][1:]) - 1
                # Wrapped in back-quotes so the encoded SQL stays well-formed.
                _encode_sql.append("`{}`".format(headers[index_key]))
            else:
                # Everything else (keywords, the table name "w", values) is
                # copied unchanged.
                _encode_sql.append(keyword)

            # "c4_list" / "c4_address" are executed against the base column.
            # Tokens that merely contain those substrings but no "cN" part
            # (e.g. a string value) are left untouched instead of raising
            # IndexError.
            if "_address" in keyword or "_list" in keyword:
                base_columns = re.findall(r"c\d+", keyword)
                if base_columns:
                    keyword = base_columns[0]

            _execute_sql.append(keyword)

        return " ".join(_execute_sql), " ".join(_encode_sql)

    _exec_sql_str, _encode_sql_str = flatten_sql(_sql_struct)
    try:
        _sql_answers = _engine.execute_wtq_query(_exec_sql_str)
    except SQLAlchemyError:
        # Best effort: a query the database rejects simply has no answer.
        _sql_answers = []
    _norm_sql_answers = [
        str(answer).replace("\n", " ")
        for answer in _sql_answers
        if answer is not None
    ]
    # A literal "none" value invalidates the whole answer list.
    if "none" in _norm_sql_answers:
        _norm_sql_answers = []
    return _encode_sql_str, _norm_sql_answers, _exec_sql_str


def _load_table_w_page(table_path, page_title_path=None) -> dict:
    """
    attention: the table_path must be the .tsv path.
    Load the WikiTableQuestion table from file and attach its page title.
    Result is the dict returned by ``_load_table`` (documented as
    {"header": [header1, header2,...], "rows": [[row11, row12, ...], [row21,...]... [...rownm]]})
    with an extra "page_title" entry.

    :param table_path: path of the table file.
    :param page_title_path: path of the JSON file holding the page ``title``;
        when empty it is derived from ``table_path`` by replacing every
        "csv" with "page" and ".tsv" with ".json".
    """
    from utils.utils import _load_table

    table_item = _load_table(table_path)

    # Load page title
    if not page_title_path:
        page_title_path = table_path.replace("csv", "page").replace(".tsv", ".json")
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(page_title_path, "r", encoding="utf-8") as f:
        page_title = json.load(f)['title']
    table_item['page_title'] = page_title

    return table_item