# T5-LM-Large-text2sql-spider / sqlite_parser.py
# Uploaded by NGrov ("Upload sqlite_parser.py"), revision 698d7e3.
import sqlite3
import json
import os
from typing import Union, List, Dict
from pathlib import Path
from itertools import chain
from simple_ddl_parser import parse_from_file
class DBParser:
    """Turn per-database DDL files into flat, single-line schema prompts.

    ``db_path`` is expected to contain one sub-directory per database, each
    holding a ``.sql`` DDL file. That file is parsed with
    ``simple_ddl_parser.parse_from_file`` and serialized by
    :meth:`parse_schema`.
    """

    def __init__(self, db_path: Union[str, Path]) -> None:
        """Remember the root directory and the prompt-formatting tokens.

        Args:
            db_path: Directory whose sub-directories each hold one database.
        """
        self.db_path = db_path
        # Extension of the DDL file looked up inside every database directory.
        self.suffix: str = ".sql"
        # Marker tokens embedded in the generated prompt.
        # NOTE(review): the spellings differ ("primary key:" vs "foreign_key:").
        # They are left unchanged because the prompt text presumably has to
        # match what the downstream model was trained on -- confirm before
        # normalizing.
        self.primary_key_token: str = "primary key:"
        self.foreign_key_token: str = "foreign_key:"
        # Inserted between the serialized tables of one database.
        self.separator: str = " [SEP] "

    @staticmethod
    def dump_sqlite_to_sql(path_to_sqlite: Union[str, Path], output_path: Union[str, Path]) -> None:
        """Write the SQL text dump of a SQLite database to ``output_path``.

        Args:
            path_to_sqlite: Path (str or ``Path``) whose name ends in ``.sqlite``.
            output_path: Destination text file; overwritten if it exists.

        Raises:
            ValueError: If ``path_to_sqlite`` does not end in ``.sqlite``.
        """
        # str() so that pathlib.Path arguments (allowed by the annotation) work.
        # An explicit raise rather than assert, so the check survives ``python -O``.
        if not str(path_to_sqlite).endswith('.sqlite'):
            raise ValueError(f"Expected a '.sqlite' file, got: {path_to_sqlite}")
        con = sqlite3.connect(path_to_sqlite)
        try:
            # Explicit UTF-8: dumped rows may contain non-ASCII text, and the
            # platform default encoding is not guaranteed to represent it.
            with open(output_path, 'w', encoding='utf-8') as f:
                for line in con.iterdump():
                    f.write(f"{line}\n")
        finally:
            # The sqlite3 connection context manager only commits/rolls back;
            # it never closes the connection, so close it explicitly.
            con.close()

    def parse_table(self, table_schema: dict) -> str:
        """Serialize one parsed table into a single-line prompt fragment.

        Layout (space separated)::

            <table> <col> <type> , ... foreign_key: <col> <type> from <ref_table> <ref_col> , ... primary key: <pk> ...

        Non-referencing columns and referencing columns each keep their
        original relative order.

        Args:
            table_schema: One table entry from ``simple_ddl_parser`` output.
                The keys read here are ``table_name``, ``columns`` (each with
                ``name``, ``type`` and ``references``) and ``primary_key``.

        Returns:
            The serialized fragment for this table.
        """
        plain_parts: List[str] = []
        foreign_parts: List[str] = []
        for column in table_schema["columns"]:
            references = column["references"]
            if references is None:
                plain_parts.extend((column["name"], column["type"], ","))
            else:
                foreign_parts.extend((
                    column["name"], column["type"], "from",
                    references["table"], references["column"], ",",
                ))
        return " ".join([
            table_schema["table_name"],
            " ".join(plain_parts),
            self.foreign_key_token,
            " ".join(foreign_parts),
            self.primary_key_token,
            " ".join(table_schema["primary_key"]),
        ])

    def parse_schema(self, schema: List[dict]) -> str:
        """Serialize all tables of a parsed DDL file, joined by ``self.separator``.

        Entries without a ``columns`` key are skipped, because only table
        entries carry columns.

        Args:
            schema: The list returned by ``simple_ddl_parser.parse_from_file``.

        Returns:
            The schema prompt for the whole database (empty if it has no tables).
        """
        return self.separator.join(
            self.parse_table(table) for table in schema if 'columns' in table
        )

    def create_db_prompt_dict(self, output_file: str = 'db_schemas.json') -> Dict[str, str]:
        """Build the schema prompt of every database under ``db_path`` and save it as JSON.

        Entries of ``db_path`` that are not directories, and directories that
        contain no ``self.suffix`` file, are skipped.

        Args:
            output_file: Path of the JSON file to write, shaped ``{db_name: prompt}``.

        Returns:
            Mapping from database directory name to its schema prompt.
        """
        db_schema_dict: Dict[str, str] = {}
        # Sorted because os.listdir order is arbitrary; this keeps the output
        # (and the JSON key order) reproducible across runs and machines.
        for db_name in sorted(os.listdir(self.db_path)):
            db_dir = Path(self.db_path, db_name)
            # Stray files beside the database folders (e.g. .DS_Store) would
            # otherwise make the directory listing below raise.
            if not db_dir.is_dir():
                continue
            print("Processing database: ", db_name)
            schema_file = self._find_schema_file(db_dir)
            if schema_file is None:
                print(f"Skipping database {db_name}: no '{self.suffix}' file found in {db_dir}")
                continue
            schema = parse_from_file(schema_file)
            db_schema_dict[db_name] = self.parse_schema(schema)
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(db_schema_dict))
        return db_schema_dict

    def _find_schema_file(self, db_dir: Path) -> Union[Path, None]:
        """Return the first (alphabetically) ``self.suffix`` file in ``db_dir``, or None."""
        candidates = sorted(name for name in os.listdir(db_dir) if name.endswith(self.suffix))
        return db_dir / candidates[0] if candidates else None
# Usage:
# db_parser = DBParser(<PATH_TO_DATABASES>)
# db_parser.create_db_prompt_dict()