from typing import Any, List, Optional, Set, Type

from pydantic import ValidationError

from inference.core.entities.requests.inference import InferenceRequestImage
from inference.enterprise.workflows.entities.base import GraphNone
from inference.enterprise.workflows.errors import (
    InvalidStepInputDetected,
    VariableTypeError,
)

# Step types accepted as the source of a selector bound to an image field
# (see `validate_selector_holds_image`).
STEPS_WITH_IMAGE = {
    "InferenceImage",
    "Crop",
    "AbsoluteStaticCrop",
    "RelativeStaticCrop",
}

# Step types accepted as the source of a selector bound to a detections field
# (see `validate_selector_holds_detections`). Hoisted to module level so the
# set is not rebuilt on every validation call.
STEPS_WITH_DETECTIONS = {
    "ObjectDetectionModel",
    "KeypointsDetectionModel",
    "InstanceSegmentationModel",
    "DetectionFilter",
    "DetectionsConsensus",
    "DetectionOffset",
    "YoloWorld",
}


def validate_image_is_valid_selector(value: Any, field_name: str = "image") -> None:
    """Ensure `value` is a selector, or a list consisting only of selectors.

    Raises:
        ValueError: if `value` (or any element, when `value` is a list) is not
            a selector as defined by `is_selector`.
    """
    if isinstance(value, list):
        if any(not is_selector(selector_or_value=e) for e in value):
            raise ValueError(f"`{field_name}` field can only contain selector values")
    elif not is_selector(selector_or_value=value):
        raise ValueError(f"`{field_name}` field can only contain selector values")


def validate_field_is_in_range_zero_one_or_empty_or_selector(
    value: Any, field_name: str = "confidence"
) -> None:
    """Accept a selector, `None`, or a number in [0, 1].

    Raises:
        ValueError: if `value` is neither a selector nor `None`, and is not an
            int/float within [0.0, 1.0].
    """
    if is_selector(selector_or_value=value) or value is None:
        return None
    validate_value_is_empty_or_number_in_range_zero_one(
        value=value, field_name=field_name
    )


def validate_value_is_empty_or_number_in_range_zero_one(
    value: Any, field_name: str = "confidence", error: Type[Exception] = ValueError
) -> None:
    """Accept `None` or an int/float in the closed range [0, 1].

    Raises:
        error: if `value` is not None/int/float, or lies outside [0.0, 1.0].
            The chained comparison below is False for NaN, so NaN is rejected.
    """
    validate_field_has_given_type(
        field_name=field_name,
        allowed_types=[type(None), int, float],
        value=value,
        error=error,
    )
    if value is None:
        return None
    if not (0 <= value <= 1):
        raise error(f"Parameter `{field_name}` must be in range [0.0, 1.0]")


def validate_value_is_empty_or_selector_or_positive_number(
    value: Any, field_name: str
) -> None:
    """Accept a selector, `None`, or a strictly positive int/float.

    Raises:
        ValueError: if `value` is not a selector and fails
            `validate_value_is_empty_or_positive_number`.
    """
    if is_selector(selector_or_value=value):
        return None
    validate_value_is_empty_or_positive_number(value=value, field_name=field_name)


def validate_value_is_empty_or_positive_number(
    value: Any, field_name: str, error: Type[Exception] = ValueError
) -> None:
    """Accept `None` or a strictly positive int/float.

    Raises:
        error: if `value` is not None/int/float, or is not > 0 (NaN included).
    """
    validate_field_has_given_type(
        field_name=field_name,
        allowed_types=[type(None), int, float],
        value=value,
        error=error,
    )
    if value is None:
        return None
    # `not value > 0` rather than `value <= 0`: every ordered comparison with
    # NaN is False, so the `<=` form would silently accept NaN as "positive".
    if not value > 0:
        raise error(f"Parameter `{field_name}` must be positive (> 0)")


def validate_field_is_list_of_selectors(
    value: Any, field_name: str, error: Type[Exception] = ValueError
) -> None:
    """Ensure `value` is a list whose every element is a selector.

    Raises:
        error: if `value` is not a list, or contains a non-selector element.
    """
    if not isinstance(value, list):
        raise error(f"`{field_name}` field must be list")
    if any(not is_selector(selector_or_value=e) for e in value):
        raise error(f"Parameter `{field_name}` must be a list of selectors")


def validate_field_is_empty_or_selector_or_list_of_string(
    value: Any, field_name: str
) -> None:
    """Accept a selector, `None`, or a list of strings.

    Raises:
        ValueError: if `value` is none of the accepted forms.
    """
    if is_selector(selector_or_value=value) or value is None:
        # NOTE(review): returns `value` despite the `-> None` annotation; kept
        # as-is in case a caller relies on the returned value.
        return value
    validate_field_is_list_of_string(value=value, field_name=field_name)


def validate_field_is_list_of_string(
    value: Any, field_name: str, error: Type[Exception] = ValueError
) -> None:
    """Ensure `value` is a list whose every element is a `str`.

    Raises:
        error: if `value` is not a list, or contains a non-string element.
    """
    if not isinstance(value, list):
        raise error(f"`{field_name}` field must be list")
    if any(not isinstance(e, str) for e in value):
        raise error(f"Parameter `{field_name}` must be a list of string")


def validate_field_is_selector_or_one_of_values(
    value: Any, field_name: str, selected_values: set
) -> None:
    """Accept a selector, `None`, or a member of `selected_values`.

    Note that `None` is accepted even when it is not in `selected_values`.

    Raises:
        ValueError: if `value` is not a selector, not `None`, and not in
            `selected_values`.
    """
    if is_selector(selector_or_value=value) or value is None:
        # NOTE(review): returns `value` despite the `-> None` annotation; kept
        # as-is in case a caller relies on the returned value.
        return value
    validate_field_is_one_of_selected_values(
        value=value, field_name=field_name, selected_values=selected_values
    )


def validate_field_is_one_of_selected_values(
    value: Any,
    field_name: str,
    selected_values: set,
    error: Type[Exception] = ValueError,
) -> None:
    """Ensure `value` is a member of `selected_values`.

    Raises:
        error: if `value not in selected_values`.
    """
    if value not in selected_values:
        raise error(
            f"Value of field `{field_name}` must be in {selected_values}. Found: {value}"
        )


def validate_field_is_selector_or_has_given_type(
    value: Any, field_name: str, allowed_types: List[type]
) -> None:
    """Accept a selector, or a value that is an instance of one of `allowed_types`.

    Raises:
        ValueError: if `value` is not a selector and matches none of
            `allowed_types`.
    """
    if is_selector(selector_or_value=value):
        return None
    validate_field_has_given_type(
        field_name=field_name, allowed_types=allowed_types, value=value
    )
    return None


def validate_field_has_given_type(
    value: Any,
    field_name: str,
    allowed_types: List[type],
    error: Type[Exception] = ValueError,
) -> None:
    """Ensure `value` is an instance of at least one type in `allowed_types`.

    An empty `allowed_types` always raises.

    Raises:
        error: if `value` matches none of `allowed_types`.
    """
    if not isinstance(value, tuple(allowed_types)):
        raise error(
            f"`{field_name}` field type must be one of {allowed_types}. Detected: {value}"
        )


def validate_image_biding(value: Any, field_name: str = "image") -> None:
    """Ensure `value` (or each element, if a list) validates as `InferenceRequestImage`.

    The public name keeps its historical spelling ("biding") for compatibility.

    Raises:
        VariableTypeError: if `InferenceRequestImage.model_validate` raises
            `ValueError` or `ValidationError` for any element; the original
            exception is chained as the cause.
    """
    elements = value if isinstance(value, list) else [value]
    try:
        for element in elements:
            InferenceRequestImage.model_validate(element)
    except (ValueError, ValidationError) as error:
        raise VariableTypeError(
            f"Parameter `{field_name}` must be compatible with `InferenceRequestImage`"
        ) from error


def validate_selector_is_inference_parameter(
    step_type: str,
    field_name: str,
    input_step: GraphNone,
    applicable_fields: Set[str],
) -> None:
    """Ensure a selector in one of `applicable_fields` points at an `InferenceParameter`.

    Does nothing when `field_name` is not in `applicable_fields`.

    Raises:
        InvalidStepInputDetected: if `input_step.get_type()` is not
            "InferenceParameter".
    """
    if field_name not in applicable_fields:
        return None
    input_step_type = input_step.get_type()
    if input_step_type != "InferenceParameter":
        raise InvalidStepInputDetected(
            f"Field {field_name} of step {step_type} comes from invalid input type: {input_step_type}. "
            f"Expected: `InferenceParameter`"
        )


def validate_selector_holds_image(
    step_type: str,
    field_name: str,
    input_step: GraphNone,
    applicable_fields: Optional[Set[str]] = None,
) -> None:
    """Ensure a selector in an image field points at a step listed in `STEPS_WITH_IMAGE`.

    `applicable_fields` defaults to {"image"}; other fields are ignored.

    Raises:
        InvalidStepInputDetected: if `input_step.get_type()` is not in
            `STEPS_WITH_IMAGE`.
    """
    if applicable_fields is None:
        applicable_fields = {"image"}
    if field_name not in applicable_fields:
        return None
    if input_step.get_type() not in STEPS_WITH_IMAGE:
        raise InvalidStepInputDetected(
            f"Field {field_name} of step {step_type} comes from invalid input type: {input_step.get_type()}. "
            f"Expected: {STEPS_WITH_IMAGE}"
        )


def validate_selector_holds_detections(
    step_name: str,
    image_selector: Optional[str],
    detections_selector: str,
    field_name: str,
    input_step: GraphNone,
    applicable_fields: Optional[Set[str]] = None,
) -> None:
    """Ensure a detections selector points at the `predictions` output of a detection step.

    `applicable_fields` defaults to {"detections"}; other fields are ignored.
    When `input_step` has an `image` attribute and `image_selector` is given,
    the two image references must also match.

    Raises:
        InvalidStepInputDetected: if the input step type is not in
            `STEPS_WITH_DETECTIONS`, if the last chunk of `detections_selector`
            is not "predictions", or if the image references differ.
    """
    if applicable_fields is None:
        applicable_fields = {"detections"}
    if field_name not in applicable_fields:
        return None
    if input_step.get_type() not in STEPS_WITH_DETECTIONS:
        raise InvalidStepInputDetected(
            f"Step with name {step_name} cannot take as an input predictions from {input_step.get_type()}. "
            f"Step requires detection-based output."
        )
    if get_last_selector_chunk(detections_selector) != "predictions":
        raise InvalidStepInputDetected(
            f"Step with name {step_name} must take as input step output of name `predictions`"
        )
    if not hasattr(input_step, "image") or image_selector is None:
        # Here, filter do not hold the reference to image, we skip the check in this case
        return None
    input_step_image_reference = input_step.image
    if image_selector != input_step_image_reference:
        raise InvalidStepInputDetected(
            f"Step with name {step_name} was given detections reference that is bound to different image: "
            f"step.image: {image_selector}, detections step image: {input_step_image_reference}"
        )


def is_selector(selector_or_value: Any) -> bool:
    """Return True iff `selector_or_value` is a `str` starting with "$"."""
    return isinstance(selector_or_value, str) and selector_or_value.startswith("$")


def get_last_selector_chunk(selector: str) -> str:
    """Return the part of `selector` after its last "." (the whole string if none)."""
    return selector.split(".")[-1]