# avantol's picture
# feat(app): more examples, better parsing and error handling
# cc69c66
"""
Functions used in several different places. This file should not import from any other non-lib files to prevent
circular dependencies.
"""
import json
import logging
from copy import copy
from typing import Any, Callable, Dict, Optional, Tuple, Union
# Presumably the key names that may appear at the top level of a model-generated JSON document.
# NOTE(review): not referenced anywhere in this chunk — confirm how callers use it.
TOP_LEVEL_IDENTIFIERS = {"description", "links", "properties"}
def _extract_json_text(model_output: str) -> str:
    """
    Return the candidate JSON text from a model response.

    If the response contains a Markdown code fence, the text between the first opening fence and
    the next fence (or the end of the string, for a truncated response) is returned, minus an
    optional ``json`` language tag in any letter case. Without a fence the response is returned
    unchanged.
    """
    fence_parts = model_output.split("```")
    if len(fence_parts) < 2:
        # no code fence at all: the whole response is the candidate JSON
        return model_output
    fenced_body = fence_parts[1]
    # valid JSON can never start with the letters "json", so this only strips a language tag
    if fenced_body[:4].lower() == "json":
        fenced_body = fenced_body[4:]
    return fenced_body


def get_json_from_model_output(
    input_generated_json: str,
) -> Tuple[Union[dict, list], int]:
    """
    Parses a string, potentially containing Markdown code fences, into a JSON object.

    Two attempts are made:

    1. Parse the text inside the first code fence (```json ... ``` or ``` ... ```), or the whole
       string when there is no fence, with ``json.loads``.
    2. Only if attempt 1 produced nothing usable (a parse error, a scalar instead of an
       object/array, an empty result, or entries flagged by ``check_contents_valid``), repair the
       text for truncated output and parse again. Unfenced text is tried first with
       ``balance_braces`` and then with ``_get_valid_json_from_string``. Fenced text is tried only
       with the latter.

    Debug messages are logged for each outcome. If attempt 2 cannot parse either, an empty dict is
    used as the result so that downstream metrics can still be computed.

    Args:
        input_generated_json: raw model output, expected to be a string.

    Returns:
        A tuple ``(generated_json, originally_invalid_json_count)``:
        - the attempt 1 result if it was usable, otherwise the attempt 2 result (``{}`` when
          attempt 2 could not parse either). It is always a dict or a list.
        - ``1`` when both attempts parsed but ``check_contents_valid`` flagged both results,
          otherwise ``0``. Output that cannot be parsed at all is not counted.
    """
    originally_invalid_json_count = 0

    # Attempt 1: parse the candidate text exactly as the model produced it.
    try:
        generated_json_attempt_1 = json.loads(_extract_json_text(input_generated_json))
        if not isinstance(generated_json_attempt_1, (dict, list)):
            # scalars such as `null`, `5` or `"text"` cannot be inspected by
            # check_contents_valid, so treat them like a parse failure
            raise ValueError("parsed model output is not a JSON object or array")
    except Exception as exc:
        # broad on purpose: the input may not even be a string; this is best-effort parsing
        logging.debug(f"could not parse AI model generated output as JSON. Exc: {exc}.")
        generated_json_attempt_1 = {}

    some_value_in_attempt_1_is_not_a_dict = bool(
        check_contents_valid(generated_json_attempt_1)
    )
    attempt_1_failed = (
        not generated_json_attempt_1 or some_value_in_attempt_1_is_not_a_dict
    )

    # Attempt 2: only when attempt 1 gave nothing usable, repair truncated output and retry.
    generated_json_attempt_2 = {}
    if attempt_1_failed:
        logging.debug(
            "Attempting to make output valid to obtain better metrics (this works in limited cases where "
            "the model output was simply cut off)"
        )
        try:
            json_text = _extract_json_text(input_generated_json)
            if "```" in input_generated_json:
                parsed = json.loads(_get_valid_json_from_string(json_text))
            else:
                try:
                    parsed = json.loads(balance_braces(json_text))
                except ValueError:
                    # naive brace balancing was not enough (e.g. an unterminated string)
                    parsed = json.loads(_get_valid_json_from_string(json_text))
            if not isinstance(parsed, (dict, list)):
                raise ValueError("parsed model output is not a JSON object or array")
            generated_json_attempt_2 = parsed
            logging.debug(
                "Success! Reconstructed valid JSON from unparseable model output. Continuing metrics comparison..."
            )
        except Exception:
            logging.debug(
                "Failed. Setting model output as empty JSON to enable metrics comparison."
            )
            generated_json_attempt_2 = {}

    # attempt 2 is always a dict or list here ({} when it was skipped or failed)
    some_value_in_attempt_2_is_not_a_dict = bool(
        check_contents_valid(generated_json_attempt_2)
    )
    if some_value_in_attempt_1_is_not_a_dict and some_value_in_attempt_2_is_not_a_dict:
        logging.debug("Could not recover model output json, aborting!")
        originally_invalid_json_count += 1

    generated_json = (
        generated_json_attempt_2 if attempt_1_failed else generated_json_attempt_1
    )
    return generated_json, originally_invalid_json_count
def check_contents_valid(generated_json_attempt_1: Union[list, dict]) -> bool:
    """
    Report whether parsed model output contains an entry that is not a JSON object.

    Despite the name, a *truthy* result means the contents are NOT valid. Callers use it as
    "some value is not a dict".

    The entries inspected depend on the shape of the input:
    - list: every element;
    - dict with a ``"nodes"`` key: every element of ``data["nodes"]``. A ``"nodes"`` value that is
      not a list is itself treated as invalid;
    - any other dict: every value.
    Input that is neither a list nor a dict is treated as invalid.

    Args:
        generated_json_attempt_1 (Union[list, dict]): data to check.

    Returns:
        bool: ``True`` if some inspected entry is not a dict (or the shape is unusable), ``False``
        otherwise. Empty containers yield ``False``.
    """
    data = generated_json_attempt_1
    if isinstance(data, list):
        entries = data
    elif isinstance(data, dict):
        if "nodes" in data:
            entries = data["nodes"]
            if not isinstance(entries, list):
                return True
        else:
            entries = data.values()
    else:
        # scalars / strings / None have no sub nodes to inspect
        return True
    # return a real bool: falsy offenders such as 0, "" or None must still count as invalid
    return any(not isinstance(entry, dict) for entry in entries)
def _get_valid_json_from_string(s):
"""
Given a JSON string with potentially unclosed strings, arrays, or objects, close those things
to hopefully be able to parse as valid JSON
"""
double_quotes = 0
single_quotes = 0
brackets = []
for i, c in enumerate(s):
if c == '"':
double_quotes = 1 - double_quotes # Toggle between 0 and 1
elif c == "'":
single_quotes = 1 - single_quotes # Toggle between 0 and 1
elif c in "{[":
brackets.append((i, c))
elif c in "}]":
if double_quotes == 0 and single_quotes == 0:
if brackets:
last_opened = brackets.pop()
if (c == "}" and last_opened[1] != "{") or (
c == "]" and last_opened[1] != "["
):
raise ValueError(
f"Mismatched brackets/quotes found: opened {last_opened[1]} @ {last_opened[0]} "
f"but closed {c} @ {i}"
)
else:
# If no matching opening bracket, it's an error, but we can skip this for the task
pass
# Remove trailing comma if it exists
if s.strip().endswith(","):
logging.debug("Removing ending ,")
s = s.strip().rstrip(",")
closing_chars = ""
# Adding closing quotes if there are missing ones
if double_quotes > 0:
closing_chars += '"'
if single_quotes > 0:
closing_chars += "'"
# Add closing brackets for any unmatched opening brackets
while brackets:
last_opened = brackets.pop()
if last_opened[1] == "{":
closing_chars += "}"
elif last_opened[1] == "[":
closing_chars += "]"
logging.debug(f"closing_chars: {closing_chars}")
output_string = s + closing_chars
try:
json.loads(output_string)
except Exception:
logging.debug(
"JSON string still fails to be parseable, attempting another modification..."
)
# it's possible the closing quotes were on a property that didn't have a value, let's
# fix that and see if it works
new_closing_chars = ""
found_first_double_quote = False
for char in closing_chars:
if not found_first_double_quote and char == '"':
# for keys in objects with no value, append an empty value
#
# For example:
# ```
# {
# "properties": {
# "annotation
# ```
new_closing_chars += '": ""'
else:
new_closing_chars += char
logging.debug(f"new closing_chars: {new_closing_chars}")
output_string = s + new_closing_chars
return output_string
def on_fail(
    outcome: Union[Any, Dict[str, str]],
    fallback: Union[Any, Callable] = None,
):
    """
    Substitute a fallback for a failed outcome, passing successful outcomes through untouched.

    An outcome is considered failed when it is a dict containing an ``"error"`` key, which is the
    shape ``attempt`` returns on exceptions.

    Args:
        outcome: the value to inspect.
        fallback: used only when ``outcome`` failed. If it is callable, it is called with the
            failed outcome and its return value is used. Otherwise it is returned as-is
            (default ``None``).

    Returns:
        ``outcome`` if it did not fail, otherwise the fallback value or the fallback's result.
    """
    failed = isinstance(outcome, dict) and "error" in outcome
    if not failed:
        return outcome
    if isinstance(fallback, Callable):
        return fallback(outcome)
    return fallback
def attempt(
    func: Callable,
    args: Tuple[Any, ...] = (),
    kwargs: Optional[Dict[str, Any]] = None,
) -> Union[Any, Dict[str, str]]:
    """
    Call ``func(*args, **kwargs)``, converting any raised ``Exception`` into a return value.

    Args:
        func (Callable): the function to call.
        args (Tuple[Any, ...], optional): positional arguments for ``func``.
        kwargs (Optional[Dict[str, Any]], optional): keyword arguments for ``func``. A falsy value
            means no keyword arguments.

    Returns:
        Whatever ``func`` returns, or ``{"error": str(exception)}`` if it raised. A successful
        result that is itself a dict with an ``"error"`` key cannot be told apart from a failure.
    """
    call_kwargs = kwargs or {}
    try:
        result = func(*args, **call_kwargs)
    except Exception as err:
        result = {"error": str(err)}
    return result
def balance_braces(s: str) -> str:
    """
    Naively even out ``{`` and ``}`` counts so that truncated model output has a chance to parse.

    Braces are counted over the whole string, including any that appear inside string values.
    Missing closers are appended at the end; missing openers are prepended at the start.

    Args:
        s (str): text to balance.

    Returns:
        str: ``s`` with enough ``}`` appended or ``{`` prepended to make the counts equal, or
        ``s`` unchanged when they already match.
    """
    surplus_openers = s.count("{") - s.count("}")
    if surplus_openers > 0:
        return s + "}" * surplus_openers
    if surplus_openers < 0:
        return "{" * -surplus_openers + s
    return s
def flatten_list(coll):
    """
    Concatenate the items of every iterable in ``coll`` into one list, preserving order.

    Args:
        coll (Iterable[Iterable]): outer collection whose elements are themselves iterable
            (lists, tuples, sets, ...). Strings are iterated character by character.

    Returns:
        list: all inner items, in the iteration order of ``coll`` and then of each inner iterable.
    """
    # A single pass keeps this linear in the total number of items; repeatedly doing
    # `acc = acc + list(group)` would re-copy the accumulated list on every group.
    return [item for group in coll for item in group]
def keep_errors(collection):
    """
    Given a set of outcomes, keep only the ones that resulted in an error.

    An outcome counts as an error when it is a dict containing an ``"error"`` key. This is the
    shape ``attempt`` returns on exceptions, and the same test ``on_fail`` applies. Everything else
    is dropped: ``None``, successful results, strings that merely contain the word "error", and
    non-iterable values such as ints.

    Args:
        collection (Iterable): outcomes to filter.

    Returns:
        list: the error outcomes, in their original order.
    """
    return [
        instance
        for instance in collection
        if isinstance(instance, dict) and "error" in instance
    ]