Spaces:
Running
on
Zero
Running
on
Zero
import sqlite3 | |
def convert_type(type):
    """
    Returns SQL type for given AI generated type

    This function takes AI generated type and returns SQL type.
    For simplified Data Dictionary enums are converted to text data type, and
    arrays are converted in text arrays. Any type that is not a known string
    (including non-string values such as a list of alternative types) is
    mapped to TEXT.

    Parameters:
        type (str): AI generated type

    Returns:
        sql_type (str): SQL type
    """
    sql_match = {
        "string": "TEXT",
        "integer": "INTEGER",
        "number": "REAL",
        "boolean": "BOOLEAN",
        "array": "TEXT[]",
        "enum": "TEXT",
    }
    # A dict lookup with an unhashable key (e.g. the list ["string", "null"])
    # raises TypeError, so anything that is not a plain string falls back to
    # the default type before the lookup is attempted.
    if not isinstance(type, str):
        return "TEXT"
    return sql_match.get(type, "TEXT")
def get_pk_field(node):
    """
    Returns the primary key field of an AI generated node

    The primary key is the first property whose name is exactly
    "<node name>.id".

    Parameters:
        node (dict): AI generated node dictionary with "name" and "properties"

    Returns:
        pk_field (str): Primary key field, or None when no property matches
    """
    matches = (
        candidate["name"]
        for candidate in node["properties"]
        if candidate["name"] == f"{node['name']}.id"
    )
    # No "<table>.id" property means the node has no primary key
    return next(matches, None)
def get_all_columns(node):
    """
    Returns the column names of an AI generated node

    Collects the "name" of every entry in the node's "properties", in order.

    Parameters:
        node (dict): AI generated node dictionary

    Returns:
        columns (list): List of column names
    """
    column_names = []
    for entry in node["properties"]:
        column_names.append(entry["name"])
    return column_names
def as_sql_col(prop_name):
    """
    Returns the property name as a SQL column name

    Every "." in the AI generated DD property name is replaced with "__",
    because a dot in the field name may cause issues during SQL table creation.

    Parameters:
        prop_name (str): property name

    Returns:
        col_name (str): Column name with "." replaced with "__"
    """
    segments = prop_name.split(".")
    return "__".join(segments)
def get_foreign_table_and_field(prop_name, node_name):
    """
    Returns the foreign table and field referenced by a property

    A property is treated as a foreign key when its name ends with ".id" and
    it does not start with "<node_name>.", e.g. "project.id" on a node that
    is not "project". The foreign table is the text before the first ".".

    Parameters:
        prop_name (str): property name
        node_name (str): node name

    Returns:
        foreign_table (str): Foreign table name, or None if not a foreign key
        foreign_field (str): Foreign field name, or None if not a foreign key
    """
    # Guard: only "<something>.id" names can be foreign keys
    if not prop_name.endswith(".id"):
        return None, None
    # Guard: the node's own id field is its primary key, not a foreign key
    if prop_name.startswith(node_name + "."):
        return None, None
    foreign_table, _, _ = prop_name.partition(".")
    return foreign_table, prop_name
def transform_dd(dd):
    """
    Returns transformed DD

    This function takes AI generated DD and ensures all required fields are
    present in properties and properties are dictionaries. The nodes of the
    given DD are updated in place.

    Each property is normalised on its own: a string becomes
    {"name": <string>, "description": "", "type": "string"}, and a dictionary
    with a "name" is kept as is. Entries that are neither (or dictionaries
    without a "name") cannot become a column and are skipped, so a list that
    mixes strings and dictionaries keeps all of its usable properties.

    Parameters:
        dd (dict): AI generated DD

    Returns:
        dd (dict): Transformed DD
    """
    for node in dd.get("nodes", []):
        props = []
        prop_names = set()
        for entry in node.get("properties") or []:
            if isinstance(entry, str):
                # Upgrade a bare property name to a property dictionary
                props.append({"name": entry, "description": "", "type": "string"})
                prop_names.add(entry)
            elif isinstance(entry, dict) and "name" in entry:
                props.append(entry)
                prop_names.add(entry["name"])
            # Anything else has no usable name and is dropped

        # Ensure each required field is present in properties
        for req in node.get("required", []):
            if req not in prop_names:
                props.append({"name": req, "description": "", "type": "string"})
                prop_names.add(req)
        node["properties"] = props
    return dd
def generate_create_table(node, table_lookup):
    """
    Returns SQL for the given AI generated node

    This function takes AI generated node dictionary and returns a
    CREATE TABLE statement for the node. Columns come first, then the primary
    key constraint, then the foreign key constraints. A foreign key whose
    parent table does not list the referenced field is not emitted; a
    "-- WARNING" comment is written instead, after all definitions.

    Parameters:
        node (dict): AI generated node dictionary
        table_lookup (dict): Dictionary of tables and their columns

    Returns:
        sql (str): SQL for the node
    """
    col_lines = []
    fk_constraints = []
    fk_warnings = []
    pk_fields = []
    pk_field = get_pk_field(node)
    required = node.get("required", [])

    for prop in node["properties"]:
        col = prop["name"]
        # A property without a declared type is treated like a plain string
        coltype = convert_type(prop.get("type", "string"))
        sql_col = as_sql_col(col)
        line = f' "{sql_col}" {coltype}'
        is_pk = bool(pk_field) and col == pk_field
        if is_pk:
            pk_fields.append(sql_col)
        if is_pk or col in required:
            line += " NOT NULL"
        col_lines.append(line)

        # Foreign Keys
        parent, parent_field = get_foreign_table_and_field(col, node["name"])
        if parent:
            ref_col = as_sql_col(parent_field)
            parent_cols = table_lookup.get(parent, {})
            if parent_field in parent_cols:
                fk_constraints.append(
                    f' FOREIGN KEY ("{sql_col}") REFERENCES "{parent}"("{ref_col}")'
                )
            else:
                fk_warnings.append(
                    f" -- WARNING: {parent} does not have field {parent_field}"
                )

    # Primary Keys
    constraints = []
    if pk_fields:
        constraint_sql = ", ".join(f'"{c}"' for c in pk_fields)
        constraints.append(f" PRIMARY KEY ({constraint_sql})")

    # Only real definitions are comma-separated. Comment lines are appended
    # afterwards: if a comment were the last comma-joined item, the definition
    # before it would keep a trailing comma and the statement would not parse.
    body = ",\n".join(col_lines + constraints + fk_constraints)
    if fk_warnings:
        body += "\n" + "\n".join(fk_warnings)
    return f'CREATE TABLE "{node["name"]}" (\n' + body + "\n);"
def validate_sql(sql, node_name):
    """
    Returns validation result for the given SQL

    The SQL is executed against a fresh in-memory SQLite database, which is
    closed again before returning. A sqlite3.Error is reported in the result
    text rather than raised.

    Parameters:
        sql (str): SQL
        node_name (str): Node name

    Returns:
        validation_result (str): Validation result
    """
    connection = sqlite3.connect(":memory:")
    try:
        connection.execute(sql)
    except sqlite3.Error as err:
        outcome = f'Invalid SQL for table "{node_name}":\n{err}\n'
    else:
        outcome = f'Valid SQL for table "{node_name}"\n'
    finally:
        connection.close()
    return outcome
def dd_to_sql(dd):
    """
    Returns SQL for the given AI generated DD

    This function takes AI generated DD, normalises it with transform_dd and
    returns one CREATE TABLE statement per node together with a validation
    note per statement. A DD without "nodes" yields empty SQL and only the
    validation header.

    Parameters:
        dd (dict): AI generated DD

    Returns:
        sql (str): SQL
        validation (str): Validation result
    """
    dd = transform_dd(dd)
    # transform_dd accepts a DD without "nodes", so do the same here
    nodes = dd.get("nodes", [])

    # Build a lookup for table columns in all nodes
    table_lookup = {node["name"]: get_all_columns(node) for node in nodes}

    # Generate SQL
    statements = []
    notes = ["Validation notes:\n"]
    for node in nodes:
        sql = generate_create_table(node, table_lookup) + "\n\n"
        notes.append(validate_sql(sql, node["name"]))
        statements.append(sql)
    return "".join(statements), "".join(notes)