File size: 8,000 Bytes
df6c67d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
from typing import Any, List, Optional, Set, Type

from pydantic import ValidationError

from inference.core.entities.requests.inference import InferenceRequestImage
from inference.enterprise.workflows.entities.base import GraphNone
from inference.enterprise.workflows.errors import (
    InvalidStepInputDetected,
    VariableTypeError,
)

# Step types accepted as the source of an image field; consulted by
# ``validate_selector_holds_image`` when checking the field's input step.
STEPS_WITH_IMAGE = set(
    ("InferenceImage", "Crop", "AbsoluteStaticCrop", "RelativeStaticCrop")
)


def validate_image_is_valid_selector(value: Any, field_name: str = "image") -> None:
    """Ensure ``value`` is a selector, or a list made only of selectors.

    An empty list is accepted.

    Raises:
        ValueError: if ``value`` (or any element of a list ``value``) is not a
            selector string.
    """
    candidates = value if issubclass(type(value), list) else [value]
    if not all(is_selector(selector_or_value=candidate) for candidate in candidates):
        raise ValueError(f"`{field_name}` field can only contain selector values")


def validate_field_is_in_range_zero_one_or_empty_or_selector(
    value: Any, field_name: str = "confidence"
) -> None:
    """Accept a selector, ``None``, or a number within [0, 1].

    Selectors and ``None`` are accepted without further checks; any other
    value is delegated to ``validate_value_is_empty_or_number_in_range_zero_one``
    (which raises ``ValueError`` on failure).
    """
    needs_range_check = value is not None and not is_selector(selector_or_value=value)
    if needs_range_check:
        validate_value_is_empty_or_number_in_range_zero_one(
            value=value, field_name=field_name
        )


def validate_value_is_empty_or_number_in_range_zero_one(
    value: Any, field_name: str = "confidence", error: Type[Exception] = ValueError
) -> None:
    """Check that ``value`` is ``None`` or an ``int``/``float`` within [0, 1].

    Raises:
        error: if the type is not ``None``/``int``/``float``, or the number lies
            outside [0.0, 1.0]. ``float("nan")`` also fails the range check,
            because every comparison with NaN is False.
    """
    validate_field_has_given_type(
        value=value,
        field_name=field_name,
        allowed_types=[type(None), int, float],
        error=error,
    )
    # `not a <= b <= c` parses as `not (a <= b <= c)`.
    if value is not None and not 0 <= value <= 1:
        raise error(f"Parameter `{field_name}` must be in range [0.0, 1.0]")


def validate_value_is_empty_or_selector_or_positive_number(
    value: Any, field_name: str
) -> None:
    """Accept a selector, ``None``, or a number strictly greater than zero.

    Selectors skip validation; every other value is handed to
    ``validate_value_is_empty_or_positive_number`` (raises ``ValueError``).
    """
    if not is_selector(selector_or_value=value):
        validate_value_is_empty_or_positive_number(value=value, field_name=field_name)


def validate_value_is_empty_or_positive_number(
    value: Any, field_name: str, error: Type[Exception] = ValueError
) -> None:
    """Check that ``value`` is ``None`` or an ``int``/``float`` strictly above zero.

    Raises:
        error: if the type is not ``None``/``int``/``float``, or the number is
            not > 0. ``float("nan")`` is rejected as well.
    """
    validate_field_has_given_type(
        field_name=field_name,
        allowed_types=[type(None), int, float],
        value=value,
        error=error,
    )
    if value is None:
        return None
    # Written as `not value > 0` rather than `value <= 0`: every comparison
    # with NaN is False, so `nan <= 0` would let NaN through as "positive".
    if not value > 0:
        raise error(f"Parameter `{field_name}` must be positive (> 0)")


def validate_field_is_list_of_selectors(
    value: Any, field_name: str, error: Type[Exception] = ValueError
) -> None:
    """Require ``value`` to be a list whose every element is a selector.

    Raises:
        error: if ``value`` is not a list, or contains a non-selector element.
    """
    if not issubclass(type(value), list):
        raise error(f"`{field_name}` field must be list")
    for element in value:
        if not is_selector(selector_or_value=element):
            raise error(f"Parameter `{field_name}` must be a list of selectors")


def validate_field_is_empty_or_selector_or_list_of_string(
    value: Any, field_name: str
) -> None:
    """Accept ``None``, a selector, or a list of strings.

    Note: despite the ``-> None`` annotation, ``None`` and selector inputs are
    returned unchanged. The list path validates via
    ``validate_field_is_list_of_string`` (raises ``ValueError``) and returns
    ``None``.
    """
    if value is None or is_selector(selector_or_value=value):
        return value
    validate_field_is_list_of_string(value=value, field_name=field_name)


def validate_field_is_list_of_string(
    value: Any, field_name: str, error: Type[Exception] = ValueError
) -> None:
    """Require ``value`` to be a list containing only ``str`` instances.

    Raises:
        error: if ``value`` is not a list, or any element is not a string.
    """
    if not issubclass(type(value), list):
        raise error(f"`{field_name}` field must be list")
    all_strings = all(issubclass(type(item), str) for item in value)
    if not all_strings:
        raise error(f"Parameter `{field_name}` must be a list of string")


def validate_field_is_selector_or_one_of_values(
    value: Any, field_name: str, selected_values: set
) -> None:
    """Accept ``None``, a selector, or a member of ``selected_values``.

    Note: despite the ``-> None`` annotation, ``None`` and selector inputs are
    returned unchanged. Other values are checked with
    ``validate_field_is_one_of_selected_values`` (raises ``ValueError``), and
    ``None`` is returned.
    """
    if value is not None and not is_selector(selector_or_value=value):
        validate_field_is_one_of_selected_values(
            value=value, field_name=field_name, selected_values=selected_values
        )
        return None
    return value


def validate_field_is_one_of_selected_values(
    value: Any,
    field_name: str,
    selected_values: set,
    error: Type[Exception] = ValueError,
) -> None:
    """Require ``value`` to be a member of ``selected_values``.

    Raises:
        error: if ``value`` is not in ``selected_values``. Unhashable values
            (e.g. lists or dicts) are reported the same way, instead of leaking
            a ``TypeError`` from the set lookup.
    """
    try:
        is_allowed = value in selected_values
    except TypeError:
        # `[...] in some_set` raises "unhashable type". Such a value can never
        # be a set member, so treat it as an ordinary validation failure.
        is_allowed = False
    if not is_allowed:
        raise error(
            f"Value of field `{field_name}` must be in {selected_values}. Found: {value}"
        )


def validate_field_is_selector_or_has_given_type(
    value: Any, field_name: str, allowed_types: List[type]
) -> None:
    """Accept a selector, or any value whose type is one of ``allowed_types``.

    ``None`` is only accepted if ``type(None)`` is listed in ``allowed_types``.

    Raises:
        ValueError: (from ``validate_field_has_given_type``) when ``value`` is
            neither a selector nor an instance of an allowed type.
    """
    if not is_selector(selector_or_value=value):
        validate_field_has_given_type(
            value=value, field_name=field_name, allowed_types=allowed_types
        )


def validate_field_has_given_type(
    value: Any,
    field_name: str,
    allowed_types: List[type],
    error: Type[Exception] = ValueError,
) -> None:
    """Require ``type(value)`` to be a subclass of at least one allowed type.

    An empty ``allowed_types`` list rejects every value.

    Raises:
        error: when no entry of ``allowed_types`` matches ``type(value)``.
    """
    value_type = type(value)
    matches_any = any(issubclass(value_type, candidate) for candidate in allowed_types)
    if not matches_any:
        raise error(
            f"`{field_name}` field type must be one of {allowed_types}. Detected: {value}"
        )


def validate_image_biding(value: Any, field_name: str = "image") -> None:
    """Check that ``value`` (or every element of a list ``value``) parses as
    ``InferenceRequestImage``.

    Raises:
        VariableTypeError: when ``InferenceRequestImage.model_validate`` raises
            ``ValueError`` or pydantic ``ValidationError`` for any element; the
            original exception is chained as the cause.
    """
    images = value if issubclass(type(value), list) else [value]
    for image in images:
        try:
            InferenceRequestImage.model_validate(image)
        except (ValueError, ValidationError) as validation_error:
            raise VariableTypeError(
                f"Parameter `{field_name}` must be compatible with `InferenceRequestImage`"
            ) from validation_error


def validate_selector_is_inference_parameter(
    step_type: str,
    field_name: str,
    input_step: GraphNone,
    applicable_fields: Set[str],
) -> None:
    """Require an applicable field to be fed by an ``InferenceParameter`` step.

    Fields outside ``applicable_fields`` are ignored.

    Raises:
        InvalidStepInputDetected: when ``field_name`` is applicable but
            ``input_step.get_type()`` is not ``"InferenceParameter"``.
    """
    if field_name in applicable_fields:
        source_type = input_step.get_type()
        if source_type != "InferenceParameter":
            raise InvalidStepInputDetected(
                f"Field {field_name} of step {step_type} comes from invalid input type: {source_type}. "
                f"Expected: `InferenceParameter`"
            )


def validate_selector_holds_image(
    step_type: str,
    field_name: str,
    input_step: GraphNone,
    applicable_fields: Optional[Set[str]] = None,
) -> None:
    """Require an image field to be fed by a step listed in ``STEPS_WITH_IMAGE``.

    ``applicable_fields`` defaults to ``{"image"}``; other fields are ignored.

    Raises:
        InvalidStepInputDetected: when ``field_name`` is applicable and the
            input step's type is not in ``STEPS_WITH_IMAGE``.
    """
    checked_fields = {"image"} if applicable_fields is None else applicable_fields
    if field_name not in checked_fields:
        return None
    source_type = input_step.get_type()
    if source_type in STEPS_WITH_IMAGE:
        return None
    raise InvalidStepInputDetected(
        f"Field {field_name} of step {step_type} comes from invalid input type: {source_type}. "
        f"Expected: {STEPS_WITH_IMAGE}"
    )


def validate_selector_holds_detections(
    step_name: str,
    image_selector: Optional[str],
    detections_selector: str,
    field_name: str,
    input_step: GraphNone,
    applicable_fields: Optional[Set[str]] = None,
) -> None:
    """Validate that a detections field is fed by a detection-producing step.

    Checks run only when ``field_name`` is in ``applicable_fields`` (default
    ``{"detections"}``):

    1. the input step's type must be one of the detection-based step types;
    2. ``detections_selector`` must end with the ``predictions`` chunk;
    3. if ``image_selector`` is given and the input step has an ``image``
       attribute, the two must be equal (same image binding).

    Raises:
        InvalidStepInputDetected: when any of the checks above fails.
    """
    if applicable_fields is None:
        applicable_fields = {"detections"}
    if field_name not in applicable_fields:
        return None
    input_step_type = input_step.get_type()
    if input_step_type not in {
        "ObjectDetectionModel",
        "KeypointsDetectionModel",
        "InstanceSegmentationModel",
        "DetectionFilter",
        "DetectionsConsensus",
        "DetectionOffset",
        "YoloWorld",
    }:
        raise InvalidStepInputDetected(
            f"Step with name {step_name} cannot take as an input predictions from {input_step_type}. "
            f"Step requires detection-based output."
        )
    if get_last_selector_chunk(detections_selector) != "predictions":
        raise InvalidStepInputDetected(
            f"Step with name {step_name} must take as input step output of name `predictions`"
        )
    if not hasattr(input_step, "image") or image_selector is None:
        # Some input steps (e.g. filters) hold no image reference, or the caller
        # supplied none: there is nothing to compare, so skip the check.
        return None
    input_step_image_reference = input_step.image
    if image_selector != input_step_image_reference:
        raise InvalidStepInputDetected(
            f"Step with name {step_name} was given detections reference that is bound to different image: "
            f"step.image: {image_selector}, detections step image: {input_step_image_reference}"
        )


def is_selector(selector_or_value: Any) -> bool:
    """Tell whether ``selector_or_value`` is a selector reference.

    A selector is any ``str`` (or ``str`` subclass) that starts with ``$``.
    """
    return issubclass(type(selector_or_value), str) and selector_or_value.startswith("$")


def get_last_selector_chunk(selector: str) -> str:
    """Return the part of ``selector`` after its final ``.``.

    The whole string is returned when it contains no ``.``.
    """
    _, _, last_chunk = selector.rpartition(".")
    return last_chunk