File size: 6,361 Bytes
df6c67d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
from typing import Any, Dict, Optional, Set, Union

import numpy as np
from networkx import DiGraph

from inference.core.utils.image_utils import ImageType
from inference.enterprise.workflows.complier.steps_executors.constants import (
    IMAGE_TYPE_KEY,
    IMAGE_VALUE_KEY,
    PARENT_ID_KEY,
)
from inference.enterprise.workflows.complier.utils import (
    get_nodes_of_specific_kind,
    is_input_selector,
)
from inference.enterprise.workflows.constants import INPUT_NODE_KIND, STEP_NODE_KIND
from inference.enterprise.workflows.entities.validators import get_last_selector_chunk
from inference.enterprise.workflows.errors import (
    InvalidStepInputDetected,
    RuntimeParameterMissingError,
)


def prepare_runtime_parameters(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> Dict[str, Any]:
    """Validate and normalise the runtime parameters for a workflow execution.

    The stages run in this order:

    1. Fail if any required input is absent. Required inputs are every
       ``InferenceImage`` and every ``InferenceParameter`` without a default.
    2. Build a new dict in which declared defaults are overridden by the
       supplied values.
    3. Wrap every ``InferenceImage`` input into a list of image dicts tagged
       with a parent id.
    4. Check each step's selector-bound inputs against the resolved values.

    Args:
        execution_graph: Compiled workflow graph. Its nodes carry a
            ``"definition"`` attribute.
        runtime_parameters: Values supplied by the caller, keyed by input name.

    Returns:
        The prepared parameters. This is a dict object distinct from
        ``runtime_parameters``.

    Raises:
        RuntimeParameterMissingError: A required input was not supplied.
        InvalidStepInputDetected: An image input has an unsupported type.
        Exceptions raised by a step definition's ``validate_field_binding``
        also propagate.
    """
    ensure_all_parameters_filled(
        execution_graph=execution_graph,
        runtime_parameters=runtime_parameters,
    )
    with_defaults = fill_runtime_parameters_with_defaults(
        execution_graph=execution_graph,
        runtime_parameters=runtime_parameters,
    )
    prepared = assembly_input_images(
        execution_graph=execution_graph,
        runtime_parameters=with_defaults,
    )
    validate_inputs_binding(
        execution_graph=execution_graph,
        runtime_parameters=prepared,
    )
    return prepared


def ensure_all_parameters_filled(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> None:
    """Raise if any input that lacks a default value was not supplied.

    Args:
        execution_graph: Compiled workflow graph whose input nodes are inspected.
        runtime_parameters: Values supplied by the caller, keyed by input name.

    Raises:
        RuntimeParameterMissingError: Listing every required input name that is
            absent from ``runtime_parameters``.
    """
    required_names = get_input_parameters_without_default_values(
        execution_graph=execution_graph,
    )
    missing_parameters = [
        name for name in required_names if name not in runtime_parameters
    ]
    if missing_parameters:
        raise RuntimeParameterMissingError(
            f"Parameters passed to execution runtime do not define required inputs: {missing_parameters}"
        )


def get_input_parameters_without_default_values(execution_graph: DiGraph) -> Set[str]:
    """Return the names of inputs that must be supplied at runtime.

    An input is required when it is an ``InferenceImage``, or when it is an
    ``InferenceParameter`` whose ``default_value`` is ``None``.

    Args:
        execution_graph: Compiled workflow graph. Input nodes carry a
            ``"definition"`` attribute.

    Returns:
        Set of required input names.
    """
    input_node_names = get_nodes_of_specific_kind(
        execution_graph=execution_graph,
        kind=INPUT_NODE_KIND,
    )
    definitions = (
        execution_graph.nodes[node_name]["definition"] for node_name in input_node_names
    )
    return {
        definition.name
        for definition in definitions
        if definition.type == "InferenceImage"
        or (
            definition.type == "InferenceParameter"
            and definition.default_value is None
        )
    }


def fill_runtime_parameters_with_defaults(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> Dict[str, Any]:
    """Merge declared parameter defaults with the supplied runtime values.

    Supplied values take precedence over defaults. The result is a new dict,
    and ``runtime_parameters`` itself is not modified.

    Args:
        execution_graph: Compiled workflow graph providing the default values.
        runtime_parameters: Values supplied by the caller, keyed by input name.

    Returns:
        New dict containing the defaults overridden by the supplied values.
    """
    defaults = get_input_parameters_default_values(execution_graph=execution_graph)
    return {**defaults, **runtime_parameters}


def get_input_parameters_default_values(execution_graph: DiGraph) -> Dict[str, Any]:
    """Map each ``InferenceParameter`` input name to its non-``None`` default.

    Parameters whose ``default_value`` is ``None`` are omitted, as are inputs
    of any other type.

    Args:
        execution_graph: Compiled workflow graph. Input nodes carry a
            ``"definition"`` attribute.

    Returns:
        Dict mapping parameter name to default value.
    """
    definitions = (
        execution_graph.nodes[node_name]["definition"]
        for node_name in get_nodes_of_specific_kind(
            execution_graph=execution_graph,
            kind=INPUT_NODE_KIND,
        )
    )
    return {
        definition.name: definition.default_value
        for definition in definitions
        if definition.type == "InferenceParameter"
        and definition.default_value is not None
    }


def assembly_input_images(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> Dict[str, Any]:
    """Replace every ``InferenceImage`` value with a list of assembled image dicts.

    A list value is converted element by element, and each element's index is
    used as its identifier. Any other value is converted on its own and wrapped
    in a one-element list. ``runtime_parameters`` is updated in place and also
    returned.

    Args:
        execution_graph: Compiled workflow graph whose input nodes are inspected.
        runtime_parameters: Parameter values keyed by input name. Every
            ``InferenceImage`` name must be present, otherwise ``KeyError`` is
            raised.

    Returns:
        The same ``runtime_parameters`` dict, with image entries replaced.

    Raises:
        InvalidStepInputDetected: An image value has an unsupported type.
    """
    input_node_names = get_nodes_of_specific_kind(
        execution_graph=execution_graph,
        kind=INPUT_NODE_KIND,
    )
    for node_name in input_node_names:
        definition = execution_graph.nodes[node_name]["definition"]
        if definition.type != "InferenceImage":
            continue
        raw_value = runtime_parameters[definition.name]
        if issubclass(type(raw_value), list):
            assembled = [
                assembly_input_image(
                    parameter=node_name,
                    image=single_image,
                    identifier=position,
                )
                for position, single_image in enumerate(raw_value)
            ]
        else:
            assembled = [assembly_input_image(parameter=node_name, image=raw_value)]
        runtime_parameters[definition.name] = assembled
    return runtime_parameters


def assembly_input_image(
    parameter: str, image: Any, identifier: Optional[int] = None
) -> Dict[str, Union[str, np.ndarray]]:
    """Normalise one runtime image value into an image dict tagged with its parent id.

    Args:
        parameter: Name of the input node the image came from. It is used as
            the parent id.
        image: Either a dict describing the image, or a ``np.ndarray``.
        identifier: Position of the image within a batch list, if any. When it
            is given, the parent id becomes ``"<parameter>.[<identifier>]"``.

    Returns:
        For a dict input, a shallow copy of it with ``PARENT_ID_KEY`` set. For
        an ndarray input, a dict holding the array under ``IMAGE_VALUE_KEY``,
        tagged as a numpy object, with the parent id set.

    Raises:
        InvalidStepInputDetected: ``image`` is neither a dict nor an ndarray.
    """
    parent = parameter if identifier is None else f"{parameter}.[{identifier}]"
    if isinstance(image, dict):
        # Copy rather than mutate. The same dict object can appear several
        # times in a batch, and tagging it in place would leave every
        # occurrence with the parent id of the last one processed. It would
        # also modify the caller's payload.
        return {**image, PARENT_ID_KEY: parent}
    if isinstance(image, np.ndarray):
        return {
            IMAGE_TYPE_KEY: ImageType.NUMPY_OBJECT.value,
            IMAGE_VALUE_KEY: image,
            PARENT_ID_KEY: parent,
        }
    raise InvalidStepInputDetected(
        f"Detected runtime parameter `{parameter}` defined as `InferenceImage` with type {type(image)} that is invalid."
    )


def validate_inputs_binding(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> None:
    """Validate the selector-bound inputs of every step node in the graph.

    Args:
        execution_graph: Compiled workflow graph whose step nodes are checked.
        runtime_parameters: Resolved runtime values keyed by input name.

    Raises:
        Whatever ``validate_step_input_bindings`` raises for the first failing
        step.
    """
    for step_name in get_nodes_of_specific_kind(
        execution_graph=execution_graph,
        kind=STEP_NODE_KIND,
    ):
        validate_step_input_bindings(
            step=step_name,
            execution_graph=execution_graph,
            runtime_parameters=runtime_parameters,
        )


def validate_step_input_bindings(
    step: str,
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> None:
    """Check each selector-bound input field of one step against its runtime value.

    Fields whose value is not an input selector (as judged by
    ``is_input_selector``) are skipped. For each selector field, the runtime
    value is looked up by the selector's last chunk. The value is then handed
    to the step definition's ``validate_field_binding``.

    Args:
        step: Name of the step node in ``execution_graph``.
        execution_graph: Compiled workflow graph. The node carries a
            ``"definition"`` attribute.
        runtime_parameters: Resolved runtime values keyed by input name.

    Raises:
        KeyError: A selector names an input absent from ``runtime_parameters``.
        Whatever ``validate_field_binding`` raises on an invalid binding.
    """
    definition = execution_graph.nodes[step]["definition"]
    for field_name in definition.get_input_names():
        field_value = getattr(definition, field_name)
        if is_input_selector(selector_or_value=field_value):
            # NOTE(review): presumably selectors look like "$inputs.<name>",
            # so the last chunk is the runtime parameter name. Confirm.
            bound_name = get_last_selector_chunk(selector=field_value)
            definition.validate_field_binding(
                field_name=field_name,
                value=runtime_parameters[bound_name],
            )