Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Functions used in several different places. This file should not import from any other non-lib files to prevent | |
circular dependencies. | |
""" | |
import json | |
import logging | |
from copy import copy | |
from typing import Any, Callable, Dict, Optional, Tuple, Union | |
# Set of identifier strings; it is not referenced by any function visible in this part of the file.
# NOTE(review): presumably the top-level keys expected in the generated JSON - confirm against callers.
TOP_LEVEL_IDENTIFIERS = {"description", "links", "properties"}
def _strip_code_fences(model_output: str) -> Tuple[str, bool]:
    """
    Extracts the text that is expected to hold JSON from a model output.

    Args:
        model_output: Raw text, optionally wrapping its JSON in Markdown code fences
            (```json ... ``` or ``` ... ```).

    Returns:
        A tuple of:
            - the content of the first fenced block without a leading "json" language tag,
              or the input unchanged when it contains no fence;
            - True when a fence was found, False otherwise.
    """
    fence_parts = model_output.split("```")
    if len(fence_parts) == 1:
        return model_output, False
    fenced = fence_parts[1]
    # The language tag is optional: a bare ``` fence must work as well as ```json.
    if fenced.startswith("json"):
        fenced = fenced[len("json"):]
    return fenced, True


def get_json_from_model_output(input_generated_json: str):
    """
    Parses a string, potentially containing Markdown code fences, into a JSON object.

    The text is first parsed as-is (after removing ```json ... ``` or ``` ... ``` fences).
    When that fails, yields an empty result, or `check_contents_valid` flags the parsed
    data, a second attempt is made: unfenced output is first retried with
    `balance_braces`, and otherwise `_get_valid_json_from_string` is used to close
    strings/brackets that were cut off. Progress is reported through debug log messages.

    Only JSON objects and arrays are accepted; a parse that produces a bare scalar
    (number, string, boolean, null) is treated as a failed attempt instead of being
    passed on to code that expects a container.

    Args:
        input_generated_json: A string potentially containing a JSON object.

    Returns:
        A tuple containing:
            - The parsed JSON data (dict or list), or an empty dictionary if parsing failed.
            - An integer that is 1 when `check_contents_valid` flagged both the first parse
              and the recovered parse, and 0 otherwise. A plain parse failure does not
              increment it.
    """
    originally_invalid_json_count = 0

    # Attempt 1: parse the (de-fenced) text exactly as the model produced it.
    try:
        json_text, _ = _strip_code_fences(input_generated_json)
        generated_json_attempt_1 = json.loads(json_text)
        if not isinstance(generated_json_attempt_1, (dict, list)):
            raise ValueError(
                f"expected a JSON object or array, got {type(generated_json_attempt_1).__name__}"
            )
    except Exception as exc:
        logging.debug(f"could not parse AI model generated output as JSON. Exc: {exc}.")
        generated_json_attempt_1 = {}

    some_value_in_attempt_1_is_not_a_dict = bool(
        check_contents_valid(generated_json_attempt_1)
    )
    attempt_1_failed = (
        not generated_json_attempt_1 or some_value_in_attempt_1_is_not_a_dict
    )

    # Attempt 2: only when attempt 1 gave nothing usable, try to repair the text.
    generated_json_attempt_2 = {}
    if attempt_1_failed:
        logging.debug(
            "Attempting to make output valid to obtain better metrics (this works in limited cases where "
            "the model output was simply cut off)"
        )
        try:
            json_text, was_fenced = _strip_code_fences(input_generated_json)
            if was_fenced:
                generated_json_attempt_2 = json.loads(
                    _get_valid_json_from_string(json_text)
                )
            else:
                # Cheap repair first: just even out the number of curly braces.
                balance_outcome = attempt(json.loads, (balance_braces(json_text),))
                if isinstance(balance_outcome, dict) and "error" in balance_outcome:
                    generated_json_attempt_2 = json.loads(
                        _get_valid_json_from_string(json_text)
                    )
                else:
                    generated_json_attempt_2 = balance_outcome
            if not isinstance(generated_json_attempt_2, (dict, list)):
                raise ValueError("recovered JSON is not an object or array")
            logging.debug(
                "Success! Reconstructed valid JSON from unparseable model output. Continuing metrics comparison..."
            )
        except Exception:
            logging.debug(
                "Failed. Setting model output as empty JSON to enable metrics comparison."
            )
            generated_json_attempt_2 = {}

    some_value_in_attempt_2_is_not_a_dict = (
        attempt_1_failed
        and isinstance(generated_json_attempt_2, dict)
        and bool(check_contents_valid(generated_json_attempt_2))
    )
    if some_value_in_attempt_1_is_not_a_dict and some_value_in_attempt_2_is_not_a_dict:
        logging.debug("Could not recover model output json, aborting!")
        originally_invalid_json_count += 1

    generated_json = (
        generated_json_attempt_2 if attempt_1_failed else generated_json_attempt_1
    )
    return generated_json, originally_invalid_json_count
def check_contents_valid(generated_json_attempt_1: Union[list, dict]):
    """
    Reports whether the checked entries of the parsed JSON contain anything that is not a dict.

    Which entries are checked depends on the input:
        - list: every element of the list;
        - dict with a "nodes" key: every element of the "nodes" value (other keys are ignored);
        - any other dict: every value of the dict.

    Note that, despite the name, a truthy result means the contents are NOT valid.

    Args:
        generated_json_attempt_1 (Union[list, dict]): data to check

    Returns:
        True when at least one checked entry is not a dict, when the input itself is
        neither a list nor a dict, or when "nodes" cannot be iterated; False otherwise
        (including for empty containers).
    """
    if isinstance(generated_json_attempt_1, list):
        entries = generated_json_attempt_1
    elif isinstance(generated_json_attempt_1, dict):
        if "nodes" in generated_json_attempt_1:
            entries = generated_json_attempt_1["nodes"]
        else:
            entries = generated_json_attempt_1.values()
    else:
        # A bare scalar (number, string, bool, None) has no dict entries at all.
        return True

    try:
        # Return a real boolean: returning the offending entry itself would make falsy
        # offenders (0, "", None, [], False) look like "nothing wrong".
        return any(not isinstance(entry, dict) for entry in entries)
    except TypeError:
        # "nodes" held something that cannot be iterated (e.g. null or a number).
        return True
def _get_valid_json_from_string(s): | |
""" | |
Given a JSON string with potentially unclosed strings, arrays, or objects, close those things | |
to hopefully be able to parse as valid JSON | |
""" | |
double_quotes = 0 | |
single_quotes = 0 | |
brackets = [] | |
for i, c in enumerate(s): | |
if c == '"': | |
double_quotes = 1 - double_quotes # Toggle between 0 and 1 | |
elif c == "'": | |
single_quotes = 1 - single_quotes # Toggle between 0 and 1 | |
elif c in "{[": | |
brackets.append((i, c)) | |
elif c in "}]": | |
if double_quotes == 0 and single_quotes == 0: | |
if brackets: | |
last_opened = brackets.pop() | |
if (c == "}" and last_opened[1] != "{") or ( | |
c == "]" and last_opened[1] != "[" | |
): | |
raise ValueError( | |
f"Mismatched brackets/quotes found: opened {last_opened[1]} @ {last_opened[0]} " | |
f"but closed {c} @ {i}" | |
) | |
else: | |
# If no matching opening bracket, it's an error, but we can skip this for the task | |
pass | |
# Remove trailing comma if it exists | |
if s.strip().endswith(","): | |
logging.debug("Removing ending ,") | |
s = s.strip().rstrip(",") | |
closing_chars = "" | |
# Adding closing quotes if there are missing ones | |
if double_quotes > 0: | |
closing_chars += '"' | |
if single_quotes > 0: | |
closing_chars += "'" | |
# Add closing brackets for any unmatched opening brackets | |
while brackets: | |
last_opened = brackets.pop() | |
if last_opened[1] == "{": | |
closing_chars += "}" | |
elif last_opened[1] == "[": | |
closing_chars += "]" | |
logging.debug(f"closing_chars: {closing_chars}") | |
output_string = s + closing_chars | |
try: | |
json.loads(output_string) | |
except Exception: | |
logging.debug( | |
"JSON string still fails to be parseable, attempting another modification..." | |
) | |
# it's possible the closing quotes were on a property that didn't have a value, let's | |
# fix that and see if it works | |
new_closing_chars = "" | |
found_first_double_quote = False | |
for char in closing_chars: | |
if not found_first_double_quote and char == '"': | |
# for keys in objects with no value, append an empty value | |
# | |
# For example: | |
# ``` | |
# { | |
# "properties": { | |
# "annotation | |
# ``` | |
new_closing_chars += '": ""' | |
else: | |
new_closing_chars += char | |
logging.debug(f"new closing_chars: {new_closing_chars}") | |
output_string = s + new_closing_chars | |
return output_string | |
def on_fail(
    outcome: Union[Any, Dict[str, str]],
    fallback: Union[Any, Callable] = None,
):
    """
    Substitutes a fallback for an outcome that represents a failure.

    An outcome counts as failed when it is a dict that has an "error" key.

    Args:
        outcome: the result to inspect.
        fallback: replacement for a failed outcome. A callable is invoked with the
            failed outcome and its result is used; any other value is used as-is.

    Returns:
        `outcome` itself when it is not a failure, otherwise the fallback value.
    """
    failed = isinstance(outcome, dict) and "error" in outcome
    if not failed:
        return outcome
    # Callables get the chance to inspect the failure; plain values are returned directly.
    return fallback(outcome) if callable(fallback) else fallback
def attempt(
    func: Callable,
    args: Tuple[Any, ...] = (),
    kwargs: Optional[Dict[str, Any]] = None,
) -> Union[Any, Dict[str, str]]:
    """
    Calls a function and converts any raised exception into an error dict.

    Args:
        func (Callable): The function to execute.
        args (Tuple[Any, ...], optional): Positional arguments passed to `func`.
        kwargs (Optional[Dict[str, Any]], optional): Keyword arguments passed to `func`;
            a missing or empty value means no keyword arguments.

    Returns:
        Whatever `func` returns, or {"error": <exception message>} when it raised.
    """
    call_kwargs = kwargs if kwargs else {}
    try:
        result = func(*args, **call_kwargs)
    except Exception as failure:
        return {"error": str(failure)}
    return result
def balance_braces(s: str) -> str:
    """
    Evens out the number of '{' and '}' characters in a string.

    This is a plain character count: braces are counted wherever they occur and no
    attempt is made to check nesting or position.

    Args:
        s(str): string to balance braces on.

    Returns:
        The string with the missing '}' appended, or the missing '{' prepended;
        unchanged when both counts already match.
    """
    surplus_open = s.count("{") - s.count("}")
    if surplus_open > 0:
        return s + "}" * surplus_open
    if surplus_open < 0:
        return "{" * (-surplus_open) + s
    return s
def flatten_list(coll):
    """
    Flattens one level of nesting.

    Args:
        coll: an iterable whose items are themselves iterable (lists, sets, tuples, ...).

    Returns:
        A new list holding the elements of every item of `coll`, in iteration order.
    """
    # A single pass; repeatedly concatenating lists would copy the result on every step.
    return [element for group in coll for element in group]
def keep_errors(collection):
    """
    Given a set of outcomes, keeps any that resulted in an error.

    An outcome is an error when it is a dict with an "error" key. Any other outcome
    (None, numbers, strings, lists, dicts without that key) is dropped, so results of
    successful calls can be mixed into the collection safely.

    Args:
        collection (Collection): collection of outcomes to filter.

    Returns:
        A list of all instances of the collection that contain an error response.
    """
    return [
        instance
        for instance in collection
        if isinstance(instance, dict) and "error" in instance
    ]