"""
Functions used in several different places.

This file should not import from any other non-lib files to prevent circular dependencies.
"""

import json
import logging
from copy import copy
from typing import Any, Callable, Dict, Optional, Tuple, Union

TOP_LEVEL_IDENTIFIERS = {"description", "links", "properties"}

# Maps each opening JSON bracket to the character that closes it.
_CLOSING_BRACKET = {"{": "}", "[": "]"}


def _extract_fenced_json(text: str) -> str:
    """
    Return the JSON-bearing part of model output that may be wrapped in Markdown fences.

    If ``text`` contains a ``` fence, the content of the first fenced section is
    returned. That is everything between the first and second fence, or up to the
    end of the text when the closing fence is missing (cut-off output). An optional
    leading ``json`` language tag is removed. Otherwise ``text`` is returned
    unchanged.

    Raises:
        AttributeError: if ``text`` is not a string (callers treat this as "unparseable").
    """
    sections = text.split("```")
    if len(sections) == 1:
        return text
    body = sections[1]
    # Accept both ```json and bare ``` fences; only the language tag is dropped,
    # never the JSON itself.
    if body[:4].lower() == "json":
        body = body[4:]
    return body


def _recover_json(text: str) -> Any:
    """
    Best-effort parse of possibly truncated JSON text.

    First tries naive ``{}`` balancing via ``balance_braces``. If that is still not
    valid JSON, it falls back to the string-aware ``_get_valid_json_from_string``.

    Raises:
        ValueError: if neither strategy yields parseable JSON. This includes the
            mismatched-bracket ValueError from ``_get_valid_json_from_string``.
    """
    try:
        return json.loads(balance_braces(text))
    except ValueError:
        return json.loads(_get_valid_json_from_string(text))


def get_json_from_model_output(input_generated_json: str):
    """
    Parse a string, potentially containing Markdown code fences, into a JSON object.

    Typically used on the output of a language model. Both ```json ... ``` and
    ``` ... ``` fences are supported. If the initial parse fails, or the parsed
    value is empty or has non-dict children (see ``check_contents_valid``), a
    recovery pass tries to close truncated output before parsing again.

    Args:
        input_generated_json: A string potentially containing a JSON object.

    Returns:
        A tuple containing:
            - The parsed JSON (a dict or list). Returns an empty dict if nothing
              usable could be parsed, including when the output is a bare scalar.
            - An integer counter. It is incremented only when the output parsed
              but had invalid (non-dict) contents both before and after the
              recovery pass.
    """
    originally_invalid_json_count = 0

    # Attempt 1: parse the (un-fenced) output as-is.
    try:
        generated_json_attempt_1 = json.loads(_extract_fenced_json(input_generated_json))
    except Exception as exc:  # best-effort: any failure just means "not parseable as-is"
        logging.debug(f"could not parse AI model generated output as JSON. Exc: {exc}.")
        generated_json_attempt_1 = {}

    attempt_1_has_bad_contents = check_contents_valid(generated_json_attempt_1)
    if generated_json_attempt_1 and not attempt_1_has_bad_contents:
        return generated_json_attempt_1, originally_invalid_json_count

    # Attempt 2: try to repair output that was cut off mid-JSON.
    logging.debug(
        "Attempting to make output valid to obtain better metrics (this works in limited cases where "
        "the model output was simply cut off)"
    )
    try:
        generated_json_attempt_2 = _recover_json(_extract_fenced_json(input_generated_json))
        logging.debug(
            "Success! Reconstructed valid JSON from unparseable model output. Continuing metrics comparison..."
        )
    except Exception as exc:  # best-effort: fall back to an empty result
        logging.debug(
            "Failed. Setting model output as empty JSON to enable metrics comparison."
        )
        logging.debug(f"Recovery error: {exc}")
        generated_json_attempt_2 = {}

    attempt_2_has_bad_contents = isinstance(
        generated_json_attempt_2, dict
    ) and check_contents_valid(generated_json_attempt_2)
    if attempt_1_has_bad_contents and attempt_2_has_bad_contents:
        logging.debug("Could not recover model output json, aborting!")
        # NOTE(review): output that never parsed at all (attempt 1 -> {}) is not
        # counted here. Confirm this matches how callers interpret the counter.
        originally_invalid_json_count += 1

    if not isinstance(generated_json_attempt_2, (dict, list)):
        # A bare scalar (number, string, bool, null) is not a usable JSON object.
        generated_json_attempt_2 = {}
    return generated_json_attempt_2, originally_invalid_json_count


def check_contents_valid(generated_json_attempt_1: Any) -> bool:
    """
    Report whether parsed model output contains an invalid (non-dict) sub-node.

    Despite its name, this returns True when the contents are INVALID. Which
    children are checked depends on the input type:
        - list: every element.
        - dict with a "nodes" key: every element of that list. A "nodes" value
          that is not a list is invalid.
        - any other dict: every value.
        - anything else (scalar JSON): always invalid.

    Args:
        generated_json_attempt_1: parsed JSON to check.

    Returns:
        True if any checked child is not a dict (or the input is not a container),
        otherwise False. Empty containers are considered valid.
    """
    if isinstance(generated_json_attempt_1, list):
        children = generated_json_attempt_1
    elif isinstance(generated_json_attempt_1, dict):
        if "nodes" in generated_json_attempt_1:
            children = generated_json_attempt_1["nodes"]
            if not isinstance(children, list):
                return True
        else:
            children = generated_json_attempt_1.values()
    else:
        return True
    # Return a real bool. The offending child itself may be falsy (0, None, ""),
    # and returning it would make invalid contents look valid.
    return any(not isinstance(child, dict) for child in children)


def _get_valid_json_from_string(s):
    """
    Close any unterminated string, array, or object at the end of ``s``.

    Intended for model output that was cut off mid-JSON. The text is scanned once.
    The scan tracks whether it is inside a double-quoted JSON string (honouring
    backslash escapes) and which brackets are still open. Quotes and brackets that
    appear inside strings are ignored, so a value like ``"don't [x"`` does not
    corrupt the result.

    Args:
        s (str): possibly truncated JSON text.

    Returns:
        str: ``s`` with closing characters appended. This is NOT guaranteed to be
        valid JSON; callers must still parse it.

    Raises:
        ValueError: if a closing bracket does not match the most recently opened one.
    """
    in_string = False
    escape_pending = False
    open_brackets = []  # (index, char) of every "{" / "[" not yet closed

    for i, c in enumerate(s):
        if in_string:
            if escape_pending:
                escape_pending = False
            elif c == "\\":
                escape_pending = True
            elif c == '"':
                in_string = False
        elif c == '"':
            in_string = True
        elif c in _CLOSING_BRACKET:
            open_brackets.append((i, c))
        elif c in "}]":
            if not open_brackets:
                # A closer with no opener is malformed, but this is best-effort: ignore it.
                continue
            opened_at, opener = open_brackets.pop()
            if _CLOSING_BRACKET[opener] != c:
                raise ValueError(
                    f"Mismatched brackets/quotes found: opened {opener} @ {opened_at} "
                    f"but closed {c} @ {i}"
                )

    if escape_pending:
        # Output was cut off right after a backslash inside a string. That lone
        # backslash would escape the closing quote we add, so drop it.
        s = s[:-1]
    elif not in_string and s.strip().endswith(","):
        logging.debug("Removing ending ,")
        s = s.strip().rstrip(",")

    # Close the open string first (if any), then brackets innermost-first.
    closing_chars = ('"' if in_string else "") + "".join(
        _CLOSING_BRACKET[opener] for _, opener in reversed(open_brackets)
    )
    logging.debug(f"closing_chars: {closing_chars}")
    output_string = s + closing_chars

    try:
        json.loads(output_string)
    except (ValueError, RecursionError):
        logging.debug(
            "JSON string still fails to be parseable, attempting another modification..."
        )
        # The unterminated string may have been an object key with no value, e.g.
        #   {
        #     "properties": {
        #       "annotation
        # so give that key an empty value: '"' -> '": ""'.
        new_closing_chars = closing_chars.replace('"', '": ""', 1)
        logging.debug(f"new closing_chars: {new_closing_chars}")
        output_string = s + new_closing_chars
    return output_string


def _is_error_outcome(outcome: Any) -> bool:
    """Return True if ``outcome`` is an error record of the form produced by ``attempt``."""
    return isinstance(outcome, dict) and "error" in outcome


def on_fail(
    outcome: Union[Any, Dict[str, str]],
    fallback: Union[Any, Callable] = None,
):
    """
    Recover from a failed ``attempt`` outcome.

    Args:
        outcome: value returned by ``attempt``. This is either the function's
            result or an ``{"error": <message>}`` dict.
        fallback: the value to return on failure. If it is callable, it is called
            with the error dict and its return value is used instead.

    Returns:
        ``outcome`` unchanged when it is not an error record. Otherwise returns
        ``fallback(outcome)`` if ``fallback`` is callable, else ``fallback`` itself.
    """
    if not _is_error_outcome(outcome):
        return outcome
    return fallback(outcome) if callable(fallback) else fallback


def attempt(
    func: Callable,
    args: Tuple[Any, ...] = (),
    kwargs: Optional[Dict[str, Any]] = None,
) -> Union[Any, Dict[str, str]]:
    """
    Execute a function, converting any raised exception into an error record.

    Args:
        func (Callable): The function to execute.
        args (Tuple[Any, ...], optional): Positional arguments for the function.
        kwargs (Optional[Dict[str, Any]], optional): Keyword arguments for the function.

    Returns:
        The function's result, or ``{"error": str(exc)}`` if it raised an Exception.
    """
    kwargs = kwargs or {}
    try:
        return func(*args, **kwargs)
    except Exception as exc:
        return {"error": str(exc)}


def balance_braces(s: str) -> str:
    """
    Primitive function that just tries to add '{}' style braces to recover the model string.

    Braces are counted naively, including any inside JSON strings. Missing closers
    are appended and missing openers are prepended. Square brackets are not touched.

    Args:
        s (str): string to balance braces on.

    Returns:
        The provided string with equal numbers of '{' and '}'.
    """
    open_count = s.count("{")
    close_count = s.count("}")
    if open_count > close_count:
        s += "}" * (open_count - close_count)
    elif close_count > open_count:
        s = "{" * (close_count - open_count) + s
    return s


def flatten_list(coll):
    """
    Concatenate an iterable of iterables into a single list, preserving order.

    Args:
        coll: iterable whose elements are themselves iterables (lists, sets, ...).

    Returns:
        list: all items of every element of ``coll``, in iteration order.
    """
    flattened_data = []
    for set_list in coll:
        # extend() appends in place; rebuilding with `+` each time was O(n^2).
        flattened_data.extend(set_list)
    return flattened_data


def keep_errors(collection):
    """
    Given a set of outcomes, keep any that resulted in an error.

    An outcome counts as an error when it is an error record as produced by
    ``attempt`` (a dict with an "error" key). This is the same test ``on_fail`` uses.

    Args:
        collection (Collection): collection of outcomes to filter.

    Returns:
        list: all outcomes in ``collection`` that are error records, in order.
    """
    return [instance for instance in collection if _is_error_outcome(instance)]