"""Preparation and validation of runtime parameters for an execution graph."""

from typing import Any, Dict, Optional, Set, Union

import numpy as np
from networkx import DiGraph

from inference.core.utils.image_utils import ImageType
from inference.enterprise.workflows.complier.steps_executors.constants import (
    IMAGE_TYPE_KEY,
    IMAGE_VALUE_KEY,
    PARENT_ID_KEY,
)
from inference.enterprise.workflows.complier.utils import (
    get_nodes_of_specific_kind,
    is_input_selector,
)
from inference.enterprise.workflows.constants import INPUT_NODE_KIND, STEP_NODE_KIND
from inference.enterprise.workflows.entities.validators import get_last_selector_chunk
from inference.enterprise.workflows.errors import (
    InvalidStepInputDetected,
    RuntimeParameterMissingError,
)


def prepare_runtime_parameters(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> Dict[str, Any]:
    """Validate and normalise the parameters supplied for a graph execution.

    The pipeline runs four stages in order: check that every required input
    was supplied, merge in declared default values, wrap image inputs into
    their dictionary representation, and finally let every step validate the
    values bound to its input selectors.

    Args:
        execution_graph: Graph whose nodes carry a ``"definition"`` attribute.
        runtime_parameters: Values supplied by the caller, keyed by input name.

    Returns:
        A new dictionary holding defaults, caller-supplied values and the
        assembled image inputs.

    Raises:
        RuntimeParameterMissingError: When a required input is not supplied.
        InvalidStepInputDetected: When an image input has an unsupported type.
    """
    ensure_all_parameters_filled(
        execution_graph=execution_graph,
        runtime_parameters=runtime_parameters,
    )
    runtime_parameters = fill_runtime_parameters_with_defaults(
        execution_graph=execution_graph,
        runtime_parameters=runtime_parameters,
    )
    runtime_parameters = assembly_input_images(
        execution_graph=execution_graph,
        runtime_parameters=runtime_parameters,
    )
    validate_inputs_binding(
        execution_graph=execution_graph,
        runtime_parameters=runtime_parameters,
    )
    return runtime_parameters


def ensure_all_parameters_filled(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> None:
    """Check that every input lacking a default value was supplied.

    Raises:
        RuntimeParameterMissingError: When at least one required input name is
            absent from ``runtime_parameters``; the message lists all of them.
    """
    parameters_without_default_values = get_input_parameters_without_default_values(
        execution_graph=execution_graph,
    )
    # Sorted so that the error message does not depend on set iteration order.
    missing_parameters = sorted(
        name
        for name in parameters_without_default_values
        if name not in runtime_parameters
    )
    if missing_parameters:
        raise RuntimeParameterMissingError(
            f"Parameters passed to execution runtime do not define required inputs: {missing_parameters}"
        )


def get_input_parameters_without_default_values(execution_graph: DiGraph) -> Set[str]:
    """Collect the names of inputs that must be supplied at runtime.

    An input is required when its definition type is ``"InferenceImage"``, or
    when it is an ``"InferenceParameter"`` whose ``default_value`` is ``None``.
    """
    input_nodes = get_nodes_of_specific_kind(
        execution_graph=execution_graph,
        kind=INPUT_NODE_KIND,
    )
    result = set()
    for input_node in input_nodes:
        definition = execution_graph.nodes[input_node]["definition"]
        is_image = definition.type == "InferenceImage"
        if is_image or (
            definition.type == "InferenceParameter"
            and definition.default_value is None
        ):
            result.add(definition.name)
    return result


def fill_runtime_parameters_with_defaults(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> Dict[str, Any]:
    """Merge declared default values with the caller-supplied parameters.

    Caller-supplied values win over defaults. The input dictionary is not
    modified; a new dictionary is returned.
    """
    default_values_parameters = get_input_parameters_default_values(
        execution_graph=execution_graph
    )
    default_values_parameters.update(runtime_parameters)
    return default_values_parameters


def get_input_parameters_default_values(execution_graph: DiGraph) -> Dict[str, Any]:
    """Map each ``"InferenceParameter"`` input name to its non-``None`` default."""
    input_nodes = get_nodes_of_specific_kind(
        execution_graph=execution_graph,
        kind=INPUT_NODE_KIND,
    )
    result = {}
    for input_node in input_nodes:
        definition = execution_graph.nodes[input_node]["definition"]
        if (
            definition.type == "InferenceParameter"
            and definition.default_value is not None
        ):
            result[definition.name] = definition.default_value
    return result


def assembly_input_images(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> Dict[str, Any]:
    """Replace every ``"InferenceImage"`` input value with a list of image dicts.

    A list value is converted element by element, with the element index used
    as the identifier; any other value is wrapped in a single-element list.
    ``runtime_parameters`` is updated in place and also returned.

    Raises:
        KeyError: When an image input name is absent from ``runtime_parameters``.
        InvalidStepInputDetected: When an image value has an unsupported type.
    """
    input_nodes = get_nodes_of_specific_kind(
        execution_graph=execution_graph,
        kind=INPUT_NODE_KIND,
    )
    for input_node in input_nodes:
        definition = execution_graph.nodes[input_node]["definition"]
        if definition.type != "InferenceImage":
            continue
        value = runtime_parameters[definition.name]
        if isinstance(value, list):
            runtime_parameters[definition.name] = [
                assembly_input_image(
                    parameter=input_node,
                    image=image,
                    identifier=i,
                )
                for i, image in enumerate(value)
            ]
        else:
            runtime_parameters[definition.name] = [
                assembly_input_image(parameter=input_node, image=value)
            ]
    return runtime_parameters


def assembly_input_image(
    parameter: str, image: Any, identifier: Optional[int] = None
) -> Dict[str, Union[str, np.ndarray]]:
    """Build the dictionary representation of one input image.

    Args:
        parameter: Identifier used as the parent id of the image.
        image: Either a dictionary or a ``numpy`` array.
        identifier: Optional position of the image within a batch; when given,
            the parent id becomes ``"{parameter}.[{identifier}]"``.

    Returns:
        For a dictionary input, a shallow copy with ``PARENT_ID_KEY`` set. For
        an array input, a dictionary with the type, value and parent id keys.

    Raises:
        InvalidStepInputDetected: When ``image`` is neither a dict nor an array.
    """
    parent = parameter
    if identifier is not None:
        parent = f"{parent}.[{identifier}]"
    if isinstance(image, dict):
        # Copy instead of mutating the caller's dict: when one dict object is
        # supplied several times, writing in place would leave every
        # occurrence carrying the parent id of the last one processed.
        return {**image, PARENT_ID_KEY: parent}
    if isinstance(image, np.ndarray):
        return {
            IMAGE_TYPE_KEY: ImageType.NUMPY_OBJECT.value,
            IMAGE_VALUE_KEY: image,
            PARENT_ID_KEY: parent,
        }
    raise InvalidStepInputDetected(
        f"Detected runtime parameter `{parameter}` defined as `InferenceImage` with type {type(image)} that is invalid."
    )


def validate_inputs_binding(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> None:
    """Run input-binding validation for every step node of the graph."""
    step_nodes = get_nodes_of_specific_kind(
        execution_graph=execution_graph,
        kind=STEP_NODE_KIND,
    )
    for step in step_nodes:
        validate_step_input_bindings(
            step=step,
            execution_graph=execution_graph,
            runtime_parameters=runtime_parameters,
        )


def validate_step_input_bindings(
    step: str,
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> None:
    """Let one step definition validate the runtime values bound to its inputs.

    Fields whose value is not recognised by ``is_input_selector`` are skipped.
    For the others, the runtime value is looked up under the name returned by
    ``get_last_selector_chunk`` and handed to the step definition's
    ``validate_field_binding``.

    Raises:
        KeyError: When the referenced name is absent from ``runtime_parameters``.
    """
    step_definition = execution_graph.nodes[step]["definition"]
    for input_name in step_definition.get_input_names():
        selector_or_value = getattr(step_definition, input_name)
        if not is_input_selector(selector_or_value=selector_or_value):
            continue
        input_parameter_name = get_last_selector_chunk(selector=selector_or_value)
        parameter_value = runtime_parameters[input_parameter_name]
        step_definition.validate_field_binding(
            field_name=input_name, value=parameter_value
        )