OMG / inference /enterprise /workflows /complier /runtime_input_validator.py
Fucius's picture
Upload 422 files
df6c67d verified
from typing import Any, Dict, Optional, Set, Union
import numpy as np
from networkx import DiGraph
from inference.core.utils.image_utils import ImageType
from inference.enterprise.workflows.complier.steps_executors.constants import (
IMAGE_TYPE_KEY,
IMAGE_VALUE_KEY,
PARENT_ID_KEY,
)
from inference.enterprise.workflows.complier.utils import (
get_nodes_of_specific_kind,
is_input_selector,
)
from inference.enterprise.workflows.constants import INPUT_NODE_KIND, STEP_NODE_KIND
from inference.enterprise.workflows.entities.validators import get_last_selector_chunk
from inference.enterprise.workflows.errors import (
InvalidStepInputDetected,
RuntimeParameterMissingError,
)
def prepare_runtime_parameters(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> Dict[str, Any]:
    """Check, complete and normalise the parameters supplied for a run.

    The stages run in a fixed order: the required-input check, merging of
    default values, wrapping of image inputs, and finally validation of the
    step fields bound to inputs (performed on the already assembled values).

    Args:
        execution_graph: Graph containing the input and step nodes.
        runtime_parameters: Caller-supplied values keyed by input name.

    Returns:
        The parameters dictionary produced after defaults were merged and
        image inputs were assembled.

    Raises:
        RuntimeParameterMissingError: Propagated from the required-input
            check when an input without a default is not supplied.
        InvalidStepInputDetected: Propagated from image assembly when an
            image input has an unsupported type.
    """
    ensure_all_parameters_filled(
        execution_graph=execution_graph,
        runtime_parameters=runtime_parameters,
    )
    with_defaults = fill_runtime_parameters_with_defaults(
        execution_graph=execution_graph,
        runtime_parameters=runtime_parameters,
    )
    with_images = assembly_input_images(
        execution_graph=execution_graph,
        runtime_parameters=with_defaults,
    )
    # Bindings are validated last so they see defaults and wrapped images.
    validate_inputs_binding(
        execution_graph=execution_graph,
        runtime_parameters=with_images,
    )
    return with_images
def ensure_all_parameters_filled(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> None:
    """Fail when a required input is absent from ``runtime_parameters``.

    Required names are those reported by
    ``get_input_parameters_without_default_values`` for the graph.

    Args:
        execution_graph: Graph whose input nodes define the required names.
        runtime_parameters: Caller-supplied values keyed by input name.

    Raises:
        RuntimeParameterMissingError: If at least one required name is not a
            key of ``runtime_parameters``; the message lists all of them.
    """
    required_names = get_input_parameters_without_default_values(
        execution_graph=execution_graph,
    )
    missing_parameters = [
        name for name in required_names if name not in runtime_parameters
    ]
    if not missing_parameters:
        return None
    raise RuntimeParameterMissingError(
        f"Parameters passed to execution runtime do not define required inputs: {missing_parameters}"
    )
def get_input_parameters_without_default_values(execution_graph: DiGraph) -> Set[str]:
    """Collect names of inputs that must be supplied at runtime.

    An input counts as required when its definition type is
    ``"InferenceImage"``, or when it is ``"InferenceParameter"`` and its
    ``default_value`` is ``None``.

    Args:
        execution_graph: Graph whose nodes of kind ``INPUT_NODE_KIND`` are
            inspected.

    Returns:
        Set with the ``name`` of every required input definition.
    """

    def _is_required(definition: Any) -> bool:
        # Images are always required; parameters only when no default is set.
        if definition.type == "InferenceImage":
            return True
        return (
            definition.type == "InferenceParameter"
            and definition.default_value is None
        )

    definitions = (
        execution_graph.nodes[node]["definition"]
        for node in get_nodes_of_specific_kind(
            execution_graph=execution_graph,
            kind=INPUT_NODE_KIND,
        )
    )
    return {
        definition.name for definition in definitions if _is_required(definition)
    }
def fill_runtime_parameters_with_defaults(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> Dict[str, Any]:
    """Merge graph-defined default values with caller-supplied parameters.

    Args:
        execution_graph: Graph whose input definitions provide the defaults.
        runtime_parameters: Caller-supplied values keyed by input name.

    Returns:
        The dictionary of defaults, updated with ``runtime_parameters`` so
        that explicitly supplied values take precedence.
    """
    merged = get_input_parameters_default_values(execution_graph=execution_graph)
    # Supplied values overwrite defaults that share the same name.
    merged.update(runtime_parameters)
    return merged
def get_input_parameters_default_values(execution_graph: DiGraph) -> Dict[str, Any]:
    """Map input names to their default values.

    Only definitions of type ``"InferenceParameter"`` whose ``default_value``
    is not ``None`` contribute an entry.

    Args:
        execution_graph: Graph whose nodes of kind ``INPUT_NODE_KIND`` are
            inspected.

    Returns:
        Dictionary ``{definition.name: definition.default_value}``.
    """
    definitions = (
        execution_graph.nodes[node]["definition"]
        for node in get_nodes_of_specific_kind(
            execution_graph=execution_graph,
            kind=INPUT_NODE_KIND,
        )
    )
    return {
        definition.name: definition.default_value
        for definition in definitions
        if definition.type == "InferenceParameter"
        and definition.default_value is not None
    }
def assembly_input_images(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> Dict[str, Any]:
    """Wrap every ``InferenceImage`` input value into a list of image entries.

    For each input node whose definition type is ``"InferenceImage"`` the
    value stored under the definition name is replaced: a list is assembled
    element by element (each element receives its index as identifier), any
    other value is assembled once and placed in a single-element list.

    Args:
        execution_graph: Graph whose nodes of kind ``INPUT_NODE_KIND`` are
            inspected.
        runtime_parameters: Parameters keyed by input name; updated in place.

    Returns:
        The same ``runtime_parameters`` object that was passed in.
    """
    for node_name in get_nodes_of_specific_kind(
        execution_graph=execution_graph,
        kind=INPUT_NODE_KIND,
    ):
        spec = execution_graph.nodes[node_name]["definition"]
        if spec.type == "InferenceImage":
            provided = runtime_parameters[spec.name]
            if issubclass(type(provided), list):
                # Batch input: every element is tagged with its position.
                wrapped = [
                    assembly_input_image(
                        parameter=node_name,
                        image=element,
                        identifier=index,
                    )
                    for index, element in enumerate(provided)
                ]
            else:
                # Single input: normalised to a one-element batch.
                wrapped = [assembly_input_image(parameter=node_name, image=provided)]
            runtime_parameters[spec.name] = wrapped
    return runtime_parameters
def assembly_input_image(
    parameter: str, image: Any, identifier: Optional[int] = None
) -> Dict[str, Union[str, np.ndarray]]:
    """Build the dictionary representation of a single image input.

    Args:
        parameter: Name used as the parent id of the image (and reported in
            the error message).
        image: Either a dictionary or a ``numpy.ndarray``.
        identifier: Optional position of the image within a batch; when
            given, the parent id becomes ``"{parameter}.[{identifier}]"``.

    Returns:
        For a dictionary input, a shallow copy of it with ``PARENT_ID_KEY``
        set to the parent id. For an array input, a new dictionary holding
        the numpy image type marker, the array itself and the parent id.

    Raises:
        InvalidStepInputDetected: If ``image`` is neither a dictionary nor a
            ``numpy.ndarray``.
    """
    parent = parameter if identifier is None else f"{parameter}.[{identifier}]"
    if isinstance(image, dict):
        # Copy instead of writing into the caller's dictionary: mutating it
        # leaks the parent id into user data and, when the same dict object
        # is supplied several times in a batch, would leave every entry
        # carrying the parent id of the last element.
        return {**image, PARENT_ID_KEY: parent}
    if isinstance(image, np.ndarray):
        return {
            IMAGE_TYPE_KEY: ImageType.NUMPY_OBJECT.value,
            IMAGE_VALUE_KEY: image,
            PARENT_ID_KEY: parent,
        }
    raise InvalidStepInputDetected(
        f"Detected runtime parameter `{parameter}` defined as `InferenceImage` with type {type(image)} that is invalid."
    )
def validate_inputs_binding(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> None:
    """Validate input bindings of every step node in the graph.

    Delegates to ``validate_step_input_bindings`` once per node of kind
    ``STEP_NODE_KIND``.

    Args:
        execution_graph: Graph containing the step nodes.
        runtime_parameters: Parameters keyed by input name.
    """
    for step_name in get_nodes_of_specific_kind(
        execution_graph=execution_graph,
        kind=STEP_NODE_KIND,
    ):
        validate_step_input_bindings(
            step=step_name,
            execution_graph=execution_graph,
            runtime_parameters=runtime_parameters,
        )
def validate_step_input_bindings(
    step: str,
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> None:
    """Validate the runtime values bound to one step's input fields.

    For every name returned by the step definition's ``get_input_names()``,
    the field value is read; when it is an input selector, the runtime value
    stored under the selector's last chunk is passed to the definition's
    ``validate_field_binding``. Fields holding anything else are skipped.

    Args:
        step: Name of the step node in ``execution_graph``.
        execution_graph: Graph holding the step definition.
        runtime_parameters: Parameters keyed by input name.
    """
    definition = execution_graph.nodes[step]["definition"]
    for field_name in definition.get_input_names():
        candidate = getattr(definition, field_name)
        if is_input_selector(selector_or_value=candidate):
            # The last selector chunk is the key into runtime_parameters.
            referenced_input = get_last_selector_chunk(selector=candidate)
            bound_value = runtime_parameters[referenced_input]
            definition.validate_field_binding(
                field_name=field_name, value=bound_value
            )