import copy
import os
import sqlite3
import uuid
from typing import Dict, List

import pandas as pd
import records
import sqlalchemy

from utils.normalizer import convert_df_type, prepare_df_for_neuraldb_from_table
from utils.mmqa.image_stuff import get_caption


def check_in_and_return(key: str, source: dict):
    # A key wrapped in `` is matched as a whole.
    if key.startswith("`") and key.endswith("`"):
        key = key[1:-1]
    if key in source.keys():
        return source[key]
    else:
        # Fall back to a case-insensitive match before giving up.
        for _k, _v in source.items():
            if _k.lower() == key.lower():
                return _v
        raise ValueError("{} not in {}".format(key, source))


class NeuralDB(object):
    def __init__(self, tables: List[Dict[str, Dict]], passages=None, images=None):
        self.raw_tables = copy.deepcopy(tables)
        self.passages = {}
        self.images = {}
        self.image_captions = {}
        self.passage_linker = {}  # The links from cell value to passage
        self.image_linker = {}  # The links from cell value to images

        # Get passages
        if passages:
            for passage in passages:
                title, passage_content = passage['title'], passage['text']
                self.passages[title] = passage_content

        # Get images
        if images:
            for image in images:
                _id, title, picture = image['id'], image['title'], image['pic']
                self.images[title] = picture
                self.image_captions[title] = get_caption(_id)

        # Link grounding resources from other modalities (passages, images).
        if self.raw_tables[0]['table'].get('rows_with_links', None):
            rows = self.raw_tables[0]['table']['rows']
            rows_with_links = self.raw_tables[0]['table']['rows_with_links']
            link_title2cell_map = {}
            for row_id in range(len(rows)):
                for col_id in range(len(rows[row_id])):
                    cell = rows_with_links[row_id][col_id]
                    for text, title, url in zip(cell[0], cell[1], cell[2]):
                        text = text.lower().strip()
                        link_title2cell_map[title] = text

            # Link passages
            for passage in passages:
                title, passage_content = passage['title'], passage['text']
                linked_cell = link_title2cell_map.get(title, None)
                if linked_cell:
                    self.passage_linker[linked_cell] = title

            # Link images
            for image in images:
                title, picture = image['title'], image['pic']
                linked_cell = link_title2cell_map.get(title, None)
                if linked_cell:
                    self.image_linker[linked_cell] = title

        for table_info in tables:
            table_info['table'] = prepare_df_for_neuraldb_from_table(table_info['table'])

        self.tables = tables

        # Connect to SQLite database
        self.tmp_path = "tmp"
        os.makedirs(self.tmp_path, exist_ok=True)
        # self.db_path = os.path.join(self.tmp_path, '{}.db'.format(hash(time.time())))
        self.db_path = os.path.join(self.tmp_path, '{}.db'.format(uuid.uuid4()))
        self.sqlite_conn = sqlite3.connect(self.db_path)

        # Create DB
        assert len(tables) >= 1, "DB has no table inside"
        table_0 = tables[0]
        if len(tables) > 1:
            raise ValueError("More than one table is not supported yet.")
        else:
            table_0["table"].to_sql("w", self.sqlite_conn)
            self.table_name = "w"
            self.table_title = table_0.get('title', None)

        # Records connection
        self.db = records.Database('sqlite:///{}'.format(self.db_path))
        self.records_conn = self.db.get_connection()

    def __str__(self):
        return str(self.execute_query("SELECT * FROM {}".format(self.table_name)))

    def get_table(self, table_name=None):
        table_name = self.table_name if not table_name else table_name
        sql_query = "SELECT * FROM {}".format(table_name)
        _table = self.execute_query(sql_query)
        return _table

    def get_header(self, table_name=None):
        _table = self.get_table(table_name)
        return _table['header']

    def get_rows(self, table_name):
        _table = self.get_table(table_name)
        return _table['rows']

    def get_table_df(self):
        return self.tables[0]['table']

    def get_table_raw(self):
        return self.raw_tables[0]['table']

    def get_table_title(self):
        return self.tables[0]['title']

    def get_passages_titles(self):
        return list(self.passages.keys())

    def get_images_titles(self):
        return list(self.images.keys())

    def get_passage_by_title(self, title: str):
        return check_in_and_return(title, self.passages)

    def get_image_by_title(self, title):
        return check_in_and_return(title, self.images)

    def get_image_caption_by_title(self, title):
        return check_in_and_return(title, self.image_captions)

    def get_image_linker(self):
        return copy.deepcopy(self.image_linker)

    def get_passage_linker(self):
        return copy.deepcopy(self.passage_linker)

    def execute_query(self, sql_query: str):
        """
        Basic operation. Execute the sql query on the database we hold.
        @param sql_query:
        @return:
        """
        # When the sql query is a column name (@deprecated: or a certain value surrounded by '' or "").
        if len(sql_query.split(' ')) == 1 or (sql_query.startswith('`') and sql_query.endswith('`')):
            col_name = sql_query
            new_sql_query = r"SELECT row_id, {} FROM {}".format(col_name, self.table_name)
            # Here we rely on a hack: when a value is surrounded by '' or "", the sql returns a column of that
            # literal value, while a bare identifier (no ''/"") queries for the column of that name.
            out = self.records_conn.query(new_sql_query)
        # When the sql query wants all columns or col_id, there is no need for us to add 'row_id'.
        elif sql_query.lower().startswith("select *") or sql_query.startswith("select col_id"):
            out = self.records_conn.query(sql_query)
        else:
            try:
                # SELECT row_id in addition, needed to align the result with the old table.
                new_sql_query = "SELECT row_id, " + sql_query[7:]
                out = self.records_conn.query(new_sql_query)
            except sqlalchemy.exc.OperationalError as e:
                # Execute the SQL as-is; in this case the row_id is not needed.
                out = self.records_conn.query(sql_query)

        results = out.all()
        unmerged_results = []
        merged_results = []

        headers = out.dataset.headers
        for i in range(len(results)):
            unmerged_results.append(list(results[i].values()))
            merged_results.extend(results[i].values())

        return {"header": headers, "rows": unmerged_results}

    def add_sub_table(self, sub_table, table_name=None, verbose=True):
        """
        Add sub_table into the table.
        @return:
        """
        table_name = self.table_name if not table_name else table_name
        sql_query = "SELECT * FROM {}".format(table_name)
        orig_table = self.execute_query(sql_query)
        old_table = pd.DataFrame(orig_table["rows"], columns=orig_table["header"])
        # Concat the new column(s) onto the old table via a left join on row_id.
        sub_table_df_normed = convert_df_type(pd.DataFrame(data=sub_table['rows'], columns=sub_table['header']))
        new_table = old_table.merge(sub_table_df_normed, how='left', on='row_id')
        new_table.to_sql(table_name, self.sqlite_conn, if_exists='replace', index=False)
        if verbose:
            print("Insert column(s) {} (dtypes: {}) into table.\n".format(
                ', '.join(sub_table['header']), sub_table_df_normed.dtypes))
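

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes that
# prepare_df_for_neuraldb_from_table accepts a table dict with 'header' and
# 'rows' keys, as suggested by how __init__ reads the raw table; the table
# contents below are made up purely for demonstration, and running this
# requires the repo's utils package to be importable.
if __name__ == "__main__":
    demo_table = {
        'title': 'demo',
        'table': {
            'header': ['name', 'year'],
            'rows': [['alpha', '2001'], ['beta', '2002']],
        },
    }
    db = NeuralDB(tables=[demo_table])
    print(db.get_header())                          # header of the materialized table "w"
    print(db.execute_query("SELECT name FROM w"))   # row_id is prepended automatically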