"""
This file handles reshaping raw results data into a list of nodes that match
what we expect the model output to be. That is to say, it handles parsing any raw data.
"""
import json
from collections import defaultdict
from functools import reduce
from typing import Any, Dict, Hashable, Optional, Tuple, Union

from models import DataModel
from shared import (
    TOP_LEVEL_IDENTIFIERS,
    attempt,
    get_json_from_model_output,
    keep_errors,
    on_fail,
)
def handle_parsing_schema_files(expected_location: str, actual_location: str):
    """Read, parse, and node-ify the expected and actual schema files.

    Returns a (generated_nodes, reference_nodes) pair of name->content
    mappings. Raises ValueError when either file cannot be read, or when
    either side fails to parse into nodes.
    """
    expected_text, actual_text = read_in_expected_and_actual_json(
        expected_location, actual_location
    )
    ingest_errors = keep_errors((expected_text, actual_text))
    if ingest_errors:
        raise ValueError(f"Could not ingest raw data: {ingest_errors}")
    generated_nodes_to_content = try_parsing_actual_model_output(actual_text)
    reference_nodes_to_content = derive_nodes_from_actual_json_output(
        parse_json(expected_text)
    )
    parse_errors = keep_errors(
        (reference_nodes_to_content, generated_nodes_to_content)
    )
    if parse_errors:
        raise ValueError(f"Error parsing files: {parse_errors}")
    return generated_nodes_to_content, reference_nodes_to_content
def read_in_expected_and_actual_json(
    expected_json_location: str, actual_json_location: str
):
    """Read both schema files as raw text.

    Returns an (expected_text, actual_text) pair. On a read failure the
    failing side becomes an {"error": ...} dict (the other side is None)
    so callers can collect failures with keep_errors.
    """
    unparsed_expected, err = read_json_as_text(expected_json_location)
    if err:
        return {"error": f"Could not read in expected file. e: {err}"}, None
    unparsed_actual, err = read_json_as_text(actual_json_location)
    if err:
        # Mirror the expected-file branch so both error payloads share the
        # same descriptive shape (previously this was just str(err)).
        return None, {"error": f"Could not read in actual file. e: {err}"}
    return unparsed_expected, unparsed_actual
def read_json_as_text(file_path: str) -> Tuple[Optional[str], Optional[Exception]]:
    """Return (contents, None) on success or (None, exception) on failure.

    The file is read as UTF-8 text; no JSON parsing is attempted here.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as source:
            contents = source.read()
    except Exception as caught:
        return None, caught
    return contents, None
def try_parsing_actual_model_output(model_output: str):
    """Parse possibly-malformed model output into a name->node mapping.

    Tries a direct JSON parse first; on failure falls back to
    get_json_from_model_output. Returns an {"error": ...} dict when the
    output is a list, cannot be recovered, or yields no nodes.
    """
    direct_parse = parse_json(model_output)
    if isinstance(direct_parse, list):
        return {
            "error": "Could not parse json file. Model output should not be a list."
        }
    needs_recovery = "error" in direct_parse
    if needs_recovery:
        recovered_json, recovery_errors = get_json_from_model_output(model_output)
        if recovery_errors > 0:
            return {"error": "Could not parse json file, no metrics to calculate"}
        parsed_json = recovered_json
    else:
        parsed_json = direct_parse
    derived_nodes = on_fail(
        attempt(derive_nodes_from_actual_json_output, (parsed_json,)), []
    )
    if not derived_nodes:
        return {"error": f"Could not derive nodes. Parsed json: {parsed_json}"}
    return derived_nodes
def find_all_nodes(name_and_contents) -> Dict:
    """Recursively collect every dict that looks like a node.

    A dict counts as a node when any of its top-level keys appears in
    TOP_LEVEL_IDENTIFIERS; otherwise we recurse into its dict-valued
    children and merge whatever they yield.
    """
    name, contents = name_and_contents
    if set(contents.keys()) & TOP_LEVEL_IDENTIFIERS:
        return {name: contents}
    collected = {}
    for child_name, child_contents in contents.items():
        if isinstance(child_contents, dict):
            collected.update(find_all_nodes((child_name, child_contents)))
    return collected
def assign_to_key(key: Hashable):
    """Build a reducer that indexes each mapping by the value at `key`.

    The returned function mutates the accumulator in place, storing
    mapping_to_add under mapping_to_add[key], and returns it.
    """

    def add_at_key(
        assignment_mapping: Dict[Hashable, Any], mapping_to_add: Dict[Hashable, Any]
    ):
        assignment_mapping[mapping_to_add[key]] = mapping_to_add
        return assignment_mapping

    return add_at_key
def key_exists(key):
    """Return a predicate that reports whether `key` is present in a mapping.

    Replaces the previous lambda-bound-to-a-name form (PEP 8 E731) with a
    proper def; the callable's behavior is unchanged.
    """
    return lambda mapping: key in mapping
def handle_property_correction(all_nodes):
    """Normalize any node whose "properties" field is a list into a dict.

    Nodes already holding a dict of properties (or no properties at all)
    pass through untouched; the corrected entries are merged over the
    originals.
    """

    def properties_are_a_list(item):
        _, node_data = item
        return "properties" in node_data and isinstance(
            node_data.get("properties", {}), list
        )

    corrected = dict(
        correct_properties_for_node(item)
        for item in all_nodes.items()
        if properties_are_a_list(item)
    )
    return {**all_nodes, **corrected}
def correct_properties_for_node(node):
    """Convert a node's list-of-property-dicts into a name->property mapping.

    Property entries lacking a "name" key are dropped. Mutates the node's
    data dict in place and returns the same (name, data) tuple.
    """
    node_name, node_data = node
    properties_by_name = {}
    for prop in node_data["properties"]:
        if "name" in prop:
            properties_by_name[prop["name"]] = prop
    node_data["properties"] = properties_by_name
    return node
def derive_nodes_from_actual_json_output(json_data: Union[dict, list]):
    """Extract a name->node mapping from non-deterministic AI output.

    A list payload yields no nodes. Property-shape correction is applied
    only when the payload carries a top-level "nodes" key.
    """
    if isinstance(json_data, list):
        return {}
    flattened = flatten_all_nodes(json_data)
    if json_data.get("nodes") is None:
        # No "nodes" wrapper: hand back the flattened mapping as-is.
        return flattened
    return handle_property_correction(flattened)
def flatten_all_nodes(json_data) -> Dict[Hashable, Any]:
    """Collapse possibly-nested node definitions into one flat mapping.

    Supports both payload shapes: a top-level "nodes" list of named node
    dicts, and a free-form dict whose values are node-like dicts.
    """
    declared_nodes = json_data.get("nodes")
    if declared_nodes is None:
        candidates = [
            (name, contents)
            for name, contents in json_data.items()
            if isinstance(contents, dict)
        ]
    else:
        candidates = [
            (node["name"], node)
            for node in declared_nodes
            if isinstance(node, dict) and node.get("name") is not None
        ]
    flattened = {}
    for candidate in candidates:
        flattened.update(find_all_nodes(candidate))
    return flattened
def aggregate_desc(acc, node):
    """Reducer: map each description string to the set of node names bearing it.

    acc maps description -> set of node names (e.g. a defaultdict(set));
    nodes whose description is missing or not a string are skipped.
    """
    node_name, node_info = node
    desc = node_info.get("description")
    # isinstance(desc, str) already excludes None, so the previous separate
    # `desc is not None` check was redundant.
    if isinstance(desc, str):
        acc[desc].add(node_name)
    return acc
def reform_links(outer_acc, node):
    """Reducer: invert a node's links into a link -> set-of-node-names map.

    outer_acc maps link -> set of node names; unhashable link values are
    skipped (matching upsert_set's guard).
    """
    node_name, node_info = node
    for link in node_info.get("links", []):
        if isinstance(link, Hashable):
            outer_acc[link].add(node_name)
    return outer_acc
def lens(key, default=None):
    """Return a getter that reads `key` from a dict, falling back to `default`."""

    def view(mapping):
        return mapping.get(key, default)

    return view
def aggregate_properties(outer_acc, node):
    """Reducer: invert a node's properties into property-name -> node-names.

    Handles both property shapes seen in model output: a list of
    {"name": ...} dicts and a name->property mapping. outer_acc maps
    property name -> set of node names.
    """
    node_name, node_info = node
    properties = node_info.get("properties", {})
    if isinstance(properties, list):
        # List shape: each entry should carry a "name" key (None otherwise).
        property_names = list(map(lens("name"), properties))
    else:
        property_names = list(properties.keys())
    # Renamed the result variable: it previously shadowed this function's
    # own name, which is confusing and blocks any future recursive use.
    aggregated = reduce(
        lambda inner_acc, prop_name: upsert_set(inner_acc, (prop_name, node_name)),
        property_names,
        outer_acc,
    )
    return aggregated
def conform_node_to_expected_schema(name_to_data_model):
    """Validate one (name, raw-node) pair against the DataModel schema.

    Returns ((name, dumped_model_or_empty_dict), errors), where errors is
    the attempt() failure payload, or {} when validation succeeded.
    NOTE(review): assumes attempt() returns a dict containing "error" on
    failure and a DataModel instance on success — confirm in shared.attempt.
    """
    name, raw_model = name_to_data_model
    validation = attempt(DataModel.model_validate, (raw_model,))
    if "error" in validation:
        return (name, {}), validation
    dumped = validation.model_dump(exclude_none=True, exclude_unset=True)
    return (name, dumped), {}
def aggregate_parsed_file(nodes: dict):
    """Validate every node and build inverted indexes over the survivors.

    Returns a dict with the schema-conformed nodes plus link, property,
    and description aggregations (each mapping a value to the set of node
    names that carry it). Nodes failing validation are dropped.
    """
    validation_results = [
        conform_node_to_expected_schema(item) for item in nodes.items()
    ]
    valid_nodes = [pair for pair, errors in validation_results if not errors]
    return {
        "node_names": dict(valid_nodes),
        "links": reduce(reform_links, valid_nodes, defaultdict(set)),
        "properties": reduce(aggregate_properties, valid_nodes, defaultdict(set)),
        "description": reduce(aggregate_desc, valid_nodes, defaultdict(set)),
    }
def upsert_set(accumulator, kvp):
    """Add `value` to the set stored under `key`, skipping unhashable keys.

    accumulator is expected to map key -> set (e.g. a defaultdict(set));
    it is mutated in place and returned.
    """
    key, value = kvp
    if not isinstance(key, Hashable):
        return accumulator
    accumulator[key].add(value)
    return accumulator
def parse_json(json_string: str) -> Optional[Union[dict, list]]:
    """
    Safely parses a JSON string into a Python dictionary or list.

    Args:
        json_string (str): The JSON string to parse

    Returns:
        dict/list: Parsed JSON data if successful
        dict["error"]: If parsing fails, provides error as string in response
    """
    try:
        return json.loads(json_string)
    # Both failure modes (malformed JSON, non-string input) collapse to the
    # same error-dict shape the rest of this module checks for, so a single
    # tuple clause replaces the two duplicated handlers.
    except (json.JSONDecodeError, TypeError) as e:
        return {"error": str(e)}