Fucius's picture
Upload 422 files
df6c67d verified
from typing import Any, List, Optional, Set, Type
from pydantic import ValidationError
from inference.core.entities.requests.inference import InferenceRequestImage
from inference.enterprise.workflows.entities.base import GraphNone
from inference.enterprise.workflows.errors import (
InvalidStepInputDetected,
VariableTypeError,
)
# Step types accepted as the source of an image-bound field; used by
# ``validate_selector_holds_image`` to check where an ``image`` selector points.
STEPS_WITH_IMAGE = {
    "InferenceImage",
    "Crop",
    "AbsoluteStaticCrop",
    "RelativeStaticCrop",
}
def validate_image_is_valid_selector(value: Any, field_name: str = "image") -> None:
    """Require ``value`` to be a selector, or a list consisting only of selectors.

    Selector-ness is decided by ``is_selector``. An empty list is accepted.

    Args:
        value: A single value or a list of values to check.
        field_name: Field name used in the error message.

    Raises:
        ValueError: If ``value`` (or any element of a list ``value``) is not a
            selector.
    """
    # Treat a lone value as a one-element list so both shapes share one check.
    candidates = value if issubclass(type(value), list) else [value]
    if not all(is_selector(selector_or_value=candidate) for candidate in candidates):
        raise ValueError(f"`{field_name}` field can only contain selector values")
def validate_field_is_in_range_zero_one_or_empty_or_selector(
    value: Any, field_name: str = "confidence"
) -> None:
    """Accept ``None``, a selector, or a number in ``[0.0, 1.0]``.

    Selectors and ``None`` are accepted without further checks. Any other value
    is delegated to ``validate_value_is_empty_or_number_in_range_zero_one``.

    Args:
        value: Value to validate.
        field_name: Field name used in error messages.

    Raises:
        ValueError: Propagated from the range/type check for non-selector,
            non-``None`` values.
    """
    if value is not None and not is_selector(selector_or_value=value):
        validate_value_is_empty_or_number_in_range_zero_one(
            value=value, field_name=field_name
        )
    return None
def validate_value_is_empty_or_number_in_range_zero_one(
    value: Any, field_name: str = "confidence", error: Type[Exception] = ValueError
) -> None:
    """Check that ``value`` is ``None`` or an ``int``/``float`` within ``[0, 1]``.

    Args:
        value: Value to validate.
        field_name: Field name used in error messages.
        error: Exception class raised on failure.

    Raises:
        error: If ``value`` is not ``None``/``int``/``float``, or is a number
            outside the closed interval ``[0, 1]``. NaN fails the chained
            comparison and is therefore rejected.
    """
    validate_field_has_given_type(
        value=value,
        field_name=field_name,
        allowed_types=[type(None), int, float],
        error=error,
    )
    in_range = value is None or 0 <= value <= 1
    if not in_range:
        raise error(f"Parameter `{field_name}` must be in range [0.0, 1.0]")
def validate_value_is_empty_or_selector_or_positive_number(
    value: Any, field_name: str
) -> None:
    """Accept a selector, ``None``, or a positive number.

    Selectors are accepted as-is. Everything else is delegated to
    ``validate_value_is_empty_or_positive_number``.

    Args:
        value: Value to validate.
        field_name: Field name used in error messages.

    Raises:
        ValueError: Propagated from the positive-number check.
    """
    if not is_selector(selector_or_value=value):
        validate_value_is_empty_or_positive_number(value=value, field_name=field_name)
    return None
def validate_value_is_empty_or_positive_number(
    value: Any, field_name: str, error: Type[Exception] = ValueError
) -> None:
    """Check that ``value`` is ``None`` or a strictly positive ``int``/``float``.

    Args:
        value: Value to validate.
        field_name: Field name used in error messages.
        error: Exception class raised on failure.

    Raises:
        error: If ``value`` is not ``None``/``int``/``float``, or is a number
            that is not strictly greater than zero (NaN included).
    """
    validate_field_has_given_type(
        field_name=field_name,
        allowed_types=[type(None), int, float],
        value=value,
        error=error,
    )
    if value is None:
        return None
    # Written as ``not value > 0`` rather than ``value <= 0``: every comparison
    # with NaN is False, so ``value <= 0`` would let NaN through as "positive".
    if not value > 0:
        raise error(f"Parameter `{field_name}` must be positive (> 0)")
def validate_field_is_list_of_selectors(
    value: Any, field_name: str, error: Type[Exception] = ValueError
) -> None:
    """Require ``value`` to be a list whose every element is a selector.

    An empty list is accepted.

    Args:
        value: Value to validate.
        field_name: Field name used in error messages.
        error: Exception class raised on failure.

    Raises:
        error: If ``value`` is not a list, or contains a non-selector element.
    """
    if not issubclass(type(value), list):
        raise error(f"`{field_name}` field must be list")
    every_item_is_selector = all(is_selector(selector_or_value=item) for item in value)
    if not every_item_is_selector:
        raise error(f"Parameter `{field_name}` must be a list of selectors")
def validate_field_is_empty_or_selector_or_list_of_string(
    value: Any, field_name: str
) -> None:
    """Accept ``None``, a selector, or a list of strings; reject anything else.

    Args:
        value: Value to validate.
        field_name: Field name used in error messages.

    Raises:
        ValueError: If ``value`` is neither ``None`` nor a selector, and is not
            a list containing only strings.
    """
    if is_selector(selector_or_value=value) or value is None:
        # Return None (not ``value``) so every path honours the ``-> None``
        # annotation; the list path already returns None.
        return None
    validate_field_is_list_of_string(value=value, field_name=field_name)
def validate_field_is_list_of_string(
    value: Any, field_name: str, error: Type[Exception] = ValueError
) -> None:
    """Require ``value`` to be a list whose every element is a string.

    An empty list is accepted. ``str`` subclasses count as strings.

    Args:
        value: Value to validate.
        field_name: Field name used in error messages.
        error: Exception class raised on failure.

    Raises:
        error: If ``value`` is not a list, or contains a non-string element.
    """
    if issubclass(type(value), list):
        if all(issubclass(type(item), str) for item in value):
            return None
        raise error(f"Parameter `{field_name}` must be a list of string")
    raise error(f"`{field_name}` field must be list")
def validate_field_is_selector_or_one_of_values(
    value: Any, field_name: str, selected_values: set
) -> None:
    """Accept a selector, ``None``, or one of ``selected_values``.

    Note that ``None`` is accepted even though the name does not mention it.

    Args:
        value: Value to validate.
        field_name: Field name used in error messages.
        selected_values: Allowed literal values.

    Raises:
        ValueError: If ``value`` is not a selector, not ``None``, and not in
            ``selected_values``.
    """
    if is_selector(selector_or_value=value) or value is None:
        # Return None (not ``value``) so every path honours the ``-> None``
        # annotation; the membership-check path already returns None.
        return None
    validate_field_is_one_of_selected_values(
        value=value, field_name=field_name, selected_values=selected_values
    )
def validate_field_is_one_of_selected_values(
    value: Any,
    field_name: str,
    selected_values: set,
    error: Type[Exception] = ValueError,
) -> None:
    """Require ``value`` to be a member of ``selected_values``.

    Args:
        value: Value to validate; may be unhashable (e.g. a list or dict).
        field_name: Field name used in the error message.
        selected_values: Allowed values.
        error: Exception class raised on failure.

    Raises:
        error: If ``value`` is not one of ``selected_values``.
    """
    try:
        is_allowed = value in selected_values
    except TypeError:
        # Hash-based containers raise TypeError for unhashable values. Fall
        # back to an equality scan so the caller gets ``error`` instead of a
        # stray TypeError. If ``selected_values`` itself is not iterable, this
        # still raises TypeError and surfaces the caller's mistake.
        is_allowed = any(value == allowed for allowed in selected_values)
    if not is_allowed:
        raise error(
            f"Value of field `{field_name}` must be in {selected_values}. Found: {value}"
        )
def validate_field_is_selector_or_has_given_type(
    value: Any, field_name: str, allowed_types: List[type]
) -> None:
    """Accept a selector, or a value whose type is one of ``allowed_types``.

    Args:
        value: Value to validate.
        field_name: Field name used in error messages.
        allowed_types: Types (subclasses included) accepted for non-selectors.

    Raises:
        ValueError: Propagated from ``validate_field_has_given_type`` when a
            non-selector value has a disallowed type.
    """
    if not is_selector(selector_or_value=value):
        validate_field_has_given_type(
            value=value, field_name=field_name, allowed_types=allowed_types
        )
    return None
def validate_field_has_given_type(
    value: Any,
    field_name: str,
    allowed_types: List[type],
    error: Type[Exception] = ValueError,
) -> None:
    """Require ``type(value)`` to be a subclass of at least one allowed type.

    The check uses ``issubclass(type(value), ...)``, so subclasses pass, e.g.
    ``bool`` against ``int``. An empty ``allowed_types`` rejects every value.

    Args:
        value: Value to validate.
        field_name: Field name used in the error message.
        allowed_types: Accepted types.
        error: Exception class raised on failure.

    Raises:
        error: If ``type(value)`` matches none of ``allowed_types``.
    """
    value_type = type(value)
    if any(issubclass(value_type, candidate) for candidate in allowed_types):
        return None
    raise error(
        f"`{field_name}` field type must be one of {allowed_types}. Detected: {value}"
    )
def validate_image_biding(value: Any, field_name: str = "image") -> None:
    """Check that ``value`` (or each element of a list) fits ``InferenceRequestImage``.

    The function name is kept as-is for compatibility; it presumably means
    "binding".

    Args:
        value: A single image payload or a list of them.
        field_name: Field name used in the error message.

    Raises:
        VariableTypeError: If any payload fails
            ``InferenceRequestImage.model_validate``. The original
            ``ValueError``/``ValidationError`` is chained as the cause.
    """
    images = value if issubclass(type(value), list) else [value]
    try:
        for image in images:
            InferenceRequestImage.model_validate(image)
    except (ValueError, ValidationError) as validation_error:
        raise VariableTypeError(
            f"Parameter `{field_name}` must be compatible with `InferenceRequestImage`"
        ) from validation_error
def validate_selector_is_inference_parameter(
    step_type: str,
    field_name: str,
    input_step: GraphNone,
    applicable_fields: Set[str],
) -> None:
    """Require a selector for one of ``applicable_fields`` to come from an ``InferenceParameter``.

    Fields outside ``applicable_fields`` are ignored.

    Args:
        step_type: Type of the step being validated (used in the message).
        field_name: Name of the field being validated.
        input_step: Node whose ``get_type()`` is checked.
        applicable_fields: Fields this rule applies to.

    Raises:
        InvalidStepInputDetected: If ``input_step.get_type()`` is not
            ``"InferenceParameter"``.
    """
    if field_name not in applicable_fields:
        return None
    input_step_type = input_step.get_type()
    if input_step_type in {"InferenceParameter"}:
        return None
    raise InvalidStepInputDetected(
        f"Field {field_name} of step {step_type} comes from invalid input type: {input_step_type}. "
        f"Expected: `InferenceParameter`"
    )
def validate_selector_holds_image(
    step_type: str,
    field_name: str,
    input_step: GraphNone,
    applicable_fields: Optional[Set[str]] = None,
) -> None:
    """Require an image selector to point at a step listed in ``STEPS_WITH_IMAGE``.

    Args:
        step_type: Type of the step being validated (used in the message).
        field_name: Name of the field being validated.
        input_step: Node whose ``get_type()`` is checked.
        applicable_fields: Fields this rule applies to; ``{"image"}`` when
            ``None``.

    Raises:
        InvalidStepInputDetected: If ``field_name`` is applicable and
            ``input_step.get_type()`` is not in ``STEPS_WITH_IMAGE``.
    """
    fields_to_check = {"image"} if applicable_fields is None else applicable_fields
    if field_name not in fields_to_check:
        return None
    input_step_type = input_step.get_type()
    if input_step_type in STEPS_WITH_IMAGE:
        return None
    raise InvalidStepInputDetected(
        f"Field {field_name} of step {step_type} comes from invalid input type: {input_step_type}. "
        f"Expected: {STEPS_WITH_IMAGE}"
    )
def validate_selector_holds_detections(
    step_name: str,
    image_selector: Optional[str],
    detections_selector: str,
    field_name: str,
    input_step: GraphNone,
    applicable_fields: Optional[Set[str]] = None,
) -> None:
    """Check that a detections selector points at a detection-producing step output.

    The rule only applies when ``field_name`` is in ``applicable_fields``
    (``{"detections"}`` by default). It then requires that:

    * ``input_step.get_type()`` is one of the detection-producing types below,
    * ``detections_selector`` ends with the ``predictions`` chunk, and
    * when ``input_step`` has an ``image`` attribute and ``image_selector`` is
      not ``None``, the two image references are equal.

    Args:
        step_name: Name of the step being validated (used in messages).
        image_selector: Image reference of the step being validated; ``None``
            skips the image-consistency check.
        detections_selector: Selector given for the detections field.
        field_name: Name of the field being validated.
        input_step: Node providing the detections.
        applicable_fields: Fields this rule applies to.

    Raises:
        InvalidStepInputDetected: If any of the conditions above is violated.
    """
    if applicable_fields is None:
        applicable_fields = {"detections"}
    if field_name not in applicable_fields:
        return None
    input_step_type = input_step.get_type()
    if input_step_type not in {
        "ObjectDetectionModel",
        "KeypointsDetectionModel",
        "InstanceSegmentationModel",
        "DetectionFilter",
        "DetectionsConsensus",
        "DetectionOffset",
        "YoloWorld",
    }:
        raise InvalidStepInputDetected(
            f"Step with name {step_name} cannot take as an input predictions from {input_step_type}. "
            f"Step requires detection-based output."
        )
    if get_last_selector_chunk(detections_selector) != "predictions":
        raise InvalidStepInputDetected(
            f"Step with name {step_name} must take as input step output of name `predictions`"
        )
    if not hasattr(input_step, "image") or image_selector is None:
        # The input step carries no image reference, or no image selector was
        # supplied. There is nothing to compare, so the image check is skipped.
        return None
    input_step_image_reference = input_step.image
    if image_selector != input_step_image_reference:
        raise InvalidStepInputDetected(
            f"Step with name {step_name} was given detections reference that is bound to different image: "
            f"step.image: {image_selector}, detections step image: {input_step_image_reference}"
        )
def is_selector(selector_or_value: Any) -> bool:
    """Tell whether ``selector_or_value`` is a selector reference.

    A selector is a string (``str`` subclasses included) that starts with
    ``$``. Non-string values are never selectors.
    """
    return issubclass(type(selector_or_value), str) and selector_or_value.startswith("$")
def get_last_selector_chunk(selector: str) -> str:
    """Return the text after the last ``.`` in ``selector``.

    The whole string is returned when it contains no ``.``.
    """
    _, _, last_chunk = selector.rpartition(".")
    return last_chunk