# avantol's picture
# feat(app): more examples, better parsing and error handling
# cc69c66
"""
This file handles reshaping raw results data into a list of nodes that match
what we expect the model output to be. That is to say, it handles parsing any raw data.
"""
import json
from collections import defaultdict
from functools import reduce
from typing import Any, Dict, Hashable, Optional, Tuple, Union
from models import DataModel
from shared import (
TOP_LEVEL_IDENTIFIERS,
attempt,
get_json_from_model_output,
keep_errors,
on_fail,
)
def handle_parsing_schema_files(expected_location: str, actual_location: str):
    """
    Read the expected and actual schema files and derive the nodes of each.

    Args:
        expected_location: Path of the reference (expected) JSON file.
        actual_location: Path of the generated (actual model output) file.

    Returns:
        A tuple ``(generated_nodes_to_content, reference_nodes_to_content)``.
        Note the order: generated first, reference second.

    Raises:
        ValueError: If ``keep_errors`` reports an error for the raw file
            contents, or for the nodes parsed from either file.
    """
    raw_reference, raw_generated = read_in_expected_and_actual_json(
        expected_location, actual_location
    )
    read_errors = keep_errors((raw_reference, raw_generated))
    if len(read_errors) > 0:
        raise ValueError(f"Could not ingest raw data: {read_errors}")

    generated_nodes_to_content = try_parsing_actual_model_output(raw_generated)

    parsed_reference = parse_json(raw_reference)
    reference_parse_failed = (
        isinstance(parsed_reference, dict) and "error" in parsed_reference
    )
    # parse_json signals failure with an {"error": ...} dict. Deriving nodes
    # from that dict produces {} and would hide a malformed reference file, so
    # the error dict is forwarded unchanged to the keep_errors check below.
    # NOTE(review): assumes keep_errors recognises {"error": ...} dicts, the
    # same shape try_parsing_actual_model_output returns on failure - confirm.
    reference_nodes_to_content = (
        parsed_reference
        if reference_parse_failed
        else derive_nodes_from_actual_json_output(parsed_reference)
    )

    errors = keep_errors((reference_nodes_to_content, generated_nodes_to_content))
    if len(errors) > 0:
        raise ValueError(f"Error parsing files: {errors}")
    return generated_nodes_to_content, reference_nodes_to_content
def read_in_expected_and_actual_json(
    expected_json_location: str, actual_json_location: str
):
    """
    Load the raw text of the expected file and then of the actual file.

    Returns a pair ``(expected, actual)``. On success both items are the file
    contents as text. If the expected file cannot be read the pair is
    ``({"error": ...}, None)``; if the actual file cannot be read it is
    ``(None, {"error": ...})``. The actual file is only read after the
    expected file has been read successfully.
    """
    expected_text, expected_failure = read_json_as_text(expected_json_location)
    if expected_failure:
        message = f"Could not read in expected file. e: {expected_failure}"
        return {"error": message}, None

    actual_text, actual_failure = read_json_as_text(actual_json_location)
    if actual_failure:
        return None, {"error": str(actual_failure)}

    return expected_text, actual_text
def read_json_as_text(file_path: str) -> Tuple[Optional[str], Optional[Exception]]:
    """
    Return the UTF-8 text of ``file_path`` without parsing it as JSON.

    The result is a ``(text, error)`` pair: ``(contents, None)`` when the file
    was read, or ``(None, exception)`` when opening or reading raised. The
    exception is returned rather than raised so the caller decides how to
    report it.
    """
    try:
        with open(file_path, mode="r", encoding="utf-8") as handle:
            text = handle.read()
    except Exception as read_error:
        return None, read_error
    return text, None
def try_parsing_actual_model_output(model_output: str):
    """
    Turn raw model output into a mapping of nodes, recovering where possible.

    The text is first parsed as plain JSON. A top-level list is rejected
    outright. If the parse failed, or produced something other than a JSON
    object, ``get_json_from_model_output`` is asked to recover JSON from the
    raw text. Nodes are then derived from whichever result was kept.

    Args:
        model_output: Raw text produced by the model.

    Returns:
        The derived nodes on success, otherwise a dict with a single
        ``"error"`` key describing why no nodes could be produced.
    """
    first_parse_json = parse_json(model_output)
    if isinstance(first_parse_json, list):
        return {
            "error": "Could not parse json file. Model output should not be a list."
        }

    # Valid JSON can also decode to None, a string, a number or a bool. Those
    # cannot hold nodes, and `"error" in <value>` would raise TypeError (or do
    # a substring test on strings), so treat every non-dict as a failed parse.
    first_pass_failed = (
        not isinstance(first_parse_json, dict) or "error" in first_parse_json
    )

    if first_pass_failed:
        # NOTE(review): get_json_from_model_output is assumed to return a
        # (json, error_count) pair, as the comparison below implies - confirm.
        recovered_json, errors = get_json_from_model_output(model_output)
        if errors > 0:
            return {"error": "Could not parse json file, no metrics to calculate"}
        parsed_json = recovered_json
    else:
        parsed_json = first_parse_json

    # An empty or failed derivation is reported as an error by the check below.
    node_derivation_outcome = on_fail(
        attempt(derive_nodes_from_actual_json_output, (parsed_json,)), []
    )
    if not node_derivation_outcome:
        return {"error": f"Could not derive nodes. Parsed json: {parsed_json}"}
    return node_derivation_outcome
def find_all_nodes(name_and_contents) -> Dict:
    """
    Collect every node reachable from a ``(name, contents)`` pair.

    ``contents`` counts as a node when at least one of its keys is in
    ``TOP_LEVEL_IDENTIFIERS``; it is then returned as ``{name: contents}`` and
    not searched any further. Otherwise each dict-valued entry of ``contents``
    is searched recursively and the results are merged, a later entry
    replacing an earlier one that has the same name.
    """
    name, contents = name_and_contents
    if set(contents.keys()) & TOP_LEVEL_IDENTIFIERS:
        return {name: contents}

    found = {}
    for child_name, child_contents in contents.items():
        if isinstance(child_contents, dict):
            found.update(find_all_nodes((child_name, child_contents)))
    return found
def assign_to_key(key: Hashable):
    """
    Build a reducer that indexes mappings by the value they hold under ``key``.

    The returned function takes an accumulator mapping and a mapping to add,
    stores the latter in the accumulator under ``mapping_to_add[key]``
    (replacing any earlier entry with that id) and returns the same
    accumulator, which makes it usable with ``functools.reduce``. A
    ``KeyError`` propagates when ``mapping_to_add`` has no ``key``.
    """

    def add_at_key(
        assignment_mapping: Dict[Hashable, Any], mapping_to_add: Dict[Hashable, Any]
    ):
        assignment_mapping[mapping_to_add[key]] = mapping_to_add
        return assignment_mapping

    return add_at_key
def key_exists(key):
    """Return a predicate that tells whether ``key`` is contained in its argument."""

    def contains_key(mapping):
        return key in mapping

    return contains_key
def handle_property_correction(all_nodes):
    """
    Return a copy of ``all_nodes`` in which list-shaped properties are corrected.

    Nodes whose ``"properties"`` entry is a list are passed, as
    ``(name, data)`` pairs, to ``correct_properties_for_node``; the pairs it
    returns replace the matching entries in the result. All other nodes are
    carried over untouched.
    """
    # Select every node that needs fixing before correcting any of them.
    needs_correction = [
        (node_name, node_data)
        for node_name, node_data in all_nodes.items()
        if "properties" in node_data
        and isinstance(node_data.get("properties", {}), list)
    ]

    corrected_nodes = dict(all_nodes)
    for node in needs_correction:
        fixed_name, fixed_data = correct_properties_for_node(node)
        corrected_nodes[fixed_name] = fixed_data
    return corrected_nodes
def correct_properties_for_node(node):
    """
    Re-key a node's list of properties by each property's ``"name"``.

    Args:
        node: A ``(node_name, node_data)`` pair where ``node_data["properties"]``
            is an iterable of property entries.

    Returns:
        The same ``node`` pair. ``node_data["properties"]`` is replaced in
        place by a dict mapping each property name to its property dict. When
        two properties share a name the later one wins.

    Entries that are not dicts, that have no ``"name"`` key, or whose name is
    not hashable are skipped. Model output is not guaranteed to be well
    formed, and one malformed entry should not abort the whole node.
    """
    node_name, node_data = node
    properties_by_name = {}
    for prop in node_data["properties"]:
        # Guard the type first: `"name" in prop` raises on numbers and does a
        # substring test on strings.
        if not isinstance(prop, dict) or "name" not in prop:
            continue
        prop_name = prop["name"]
        if not isinstance(prop_name, Hashable):
            # e.g. a list used as a name cannot be a dict key.
            continue
        properties_by_name[prop_name] = prop
    node_data["properties"] = properties_by_name
    return node
def derive_nodes_from_actual_json_output(json_data: Union[dict, list]):
    """
    Find nodes from non-deterministic AI output.

    Args:
        json_data: Decoded JSON. Only a dict (JSON object) can yield nodes.

    Returns:
        A dict mapping node names to node contents. It is empty when
        ``json_data`` is not a dict. When ``json_data`` has a non-null
        ``"nodes"`` entry, the flattened nodes are additionally passed through
        ``handle_property_correction``.
    """
    # Lists were already rejected here. Other non-dict JSON values (None,
    # strings, numbers, bools) have no `.get` and would raise AttributeError,
    # so they are handled the same way.
    if not isinstance(json_data, dict):
        return {}

    all_nodes = flatten_all_nodes(json_data)
    if json_data.get("nodes", None) is None:
        return all_nodes
    return handle_property_correction(all_nodes)
def flatten_all_nodes(json_data) -> Dict[Hashable, Any]:
    """
    Extract every node from model output, including nested ones.

    Two layouts are handled. If ``json_data["nodes"]`` is present and not
    ``None``, each dict entry in it that has a non-null ``"name"`` is searched
    under that name. Otherwise every dict-valued top-level entry of
    ``json_data`` is searched under its key. The results of
    ``find_all_nodes`` are merged into one dict, later names replacing
    earlier duplicates.
    """
    listed_nodes = json_data.get("nodes", None)
    if listed_nodes is not None:
        candidates = (
            (entry["name"], entry)
            for entry in listed_nodes
            if isinstance(entry, dict) and entry.get("name") is not None
        )
    else:
        candidates = (
            (name, contents)
            for name, contents in json_data.items()
            if isinstance(contents, dict)
        )

    flattened = {}
    for candidate in candidates:
        flattened.update(find_all_nodes(candidate))
    return flattened
def aggregate_desc(acc, node):
    """
    Reducer step that records a node's name under its description.

    ``node`` is a ``(node_name, node_info)`` pair. When
    ``node_info["description"]`` is a string, ``node_name`` is added to the
    set stored at ``acc[description]``; any other value is ignored. ``acc`` is
    expected to create missing sets on access (e.g. ``defaultdict(set)``) and
    is returned after being updated in place.
    """
    node_name, node_info = node
    description = node_info.get("description", None)
    if isinstance(description, str):
        acc[description].add(node_name)
    return acc
def reform_links(outer_acc, node):
    """
    Reducer step that records which node names refer to each link.

    ``node`` is a ``(node_name, node_info)`` pair. Every item of
    ``node_info["links"]`` (default: no links) is passed with ``node_name`` to
    ``upsert_set``, threading the accumulator through each call. The final
    accumulator is returned.
    """
    node_name, node_info = node
    links_to_node_names = outer_acc
    for link in node_info.get("links", []):
        links_to_node_names = upsert_set(links_to_node_names, (link, node_name))
    return links_to_node_names
def lens(key, default=None):
    """Return a getter that reads ``key`` from a dict, falling back to ``default``."""

    def read_from(d):
        return d.get(key, default)

    return read_from
def aggregate_properties(outer_acc, node):
    """
    Reducer step that records which node names declare each property name.

    ``node`` is a ``(node_name, node_info)`` pair. ``node_info["properties"]``
    may be a list of property dicts, in which case each entry's ``"name"`` is
    used (``None`` when absent), or a mapping, in which case its keys are
    used. Each property name is passed with ``node_name`` to ``upsert_set``,
    and the final accumulator is returned.
    """
    node_name, node_info = node
    properties = node_info.get("properties", {})

    if isinstance(properties, list):
        name_of = lens("name")
        property_names = [name_of(prop) for prop in properties]
    else:
        property_names = list(properties.keys())

    names_to_nodes = outer_acc
    for prop_name in property_names:
        names_to_nodes = upsert_set(names_to_nodes, (prop_name, node_name))
    return names_to_nodes
def conform_node_to_expected_schema(name_to_data_model):
    """
    Validate one node against ``DataModel`` and dump it back to plain data.

    ``name_to_data_model`` is a ``(name, raw_node)`` pair. Returns
    ``((name, dumped_node), {})`` when validation succeeds, where the dump
    omits ``None`` and unset fields, or ``((name, {}), error)`` when the
    result of ``attempt`` contains ``"error"``.

    NOTE(review): presumably ``attempt`` returns the validated model on
    success and an ``{"error": ...}`` mapping on failure - confirm in shared.
    """
    name, dm = name_to_data_model
    conform_result = attempt(DataModel.model_validate, (dm,))

    if "error" in conform_result:
        return (name, {}), conform_result

    model_outcome = conform_result.model_dump(exclude_none=True, exclude_unset=True)
    return (name, model_outcome), {}
def aggregate_parsed_file(nodes: dict):
    """
    Conform every node to the expected schema and index the results.

    Nodes for which ``conform_node_to_expected_schema`` reports errors are
    dropped. The returned dict has four entries:

    - ``"node_names"``: node name -> conformed node data
    - ``"links"``: accumulator built by ``reform_links``
    - ``"properties"``: accumulator built by ``aggregate_properties``
    - ``"description"``: accumulator built by ``aggregate_desc``

    Each accumulator starts as an empty ``defaultdict(set)``.
    """
    conformed_nodes = []
    for node in nodes.items():
        conformed, errors = conform_node_to_expected_schema(node)
        if not errors:
            conformed_nodes.append(conformed)

    aggregated_links = defaultdict(set)
    for conformed in conformed_nodes:
        aggregated_links = reform_links(aggregated_links, conformed)

    aggregated_properties = defaultdict(set)
    for conformed in conformed_nodes:
        aggregated_properties = aggregate_properties(aggregated_properties, conformed)

    aggregated_descriptions = defaultdict(set)
    for conformed in conformed_nodes:
        aggregated_descriptions = aggregate_desc(aggregated_descriptions, conformed)

    return {
        "node_names": dict(conformed_nodes),
        "links": aggregated_links,
        "properties": aggregated_properties,
        "description": aggregated_descriptions,
    }
def upsert_set(accumulator, kvp):
    """
    Add ``value`` to the set stored under ``key`` and return the accumulator.

    ``kvp`` is a ``(key, value)`` pair. Keys that are not ``Hashable`` are
    ignored, leaving the accumulator unchanged. ``accumulator[key]`` must
    already yield a set for unseen keys (e.g. a ``defaultdict(set)``).
    """
    key, value = kvp
    if not isinstance(key, Hashable):
        return accumulator
    accumulator[key].add(value)
    return accumulator
def parse_json(json_string: str) -> Optional[Union[dict, list]]:
    """
    Decode a JSON string without raising on malformed input.

    Args:
        json_string (str): The JSON text to decode.

    Returns:
        The decoded JSON value when decoding succeeds. When ``json.loads``
        raises ``json.JSONDecodeError`` or ``TypeError``, a dict of the form
        ``{"error": "<message>"}`` is returned instead.
    """
    try:
        parsed = json.loads(json_string)
    except (json.JSONDecodeError, TypeError) as parse_error:
        return {"error": str(parse_error)}
    return parsed