from abc import ABCMeta, abstractmethod from enum import Enum from typing import Annotated, Any, Dict, List, Literal, Optional, Set, Tuple, Union from pydantic import ( BaseModel, ConfigDict, Field, NonNegativeInt, PositiveInt, confloat, field_validator, ) from inference.enterprise.workflows.entities.base import GraphNone from inference.enterprise.workflows.entities.validators import ( get_last_selector_chunk, is_selector, validate_field_has_given_type, validate_field_is_empty_or_selector_or_list_of_string, validate_field_is_in_range_zero_one_or_empty_or_selector, validate_field_is_list_of_selectors, validate_field_is_list_of_string, validate_field_is_one_of_selected_values, validate_field_is_selector_or_has_given_type, validate_field_is_selector_or_one_of_values, validate_image_biding, validate_image_is_valid_selector, validate_selector_holds_detections, validate_selector_holds_image, validate_selector_is_inference_parameter, validate_value_is_empty_or_number_in_range_zero_one, validate_value_is_empty_or_positive_number, validate_value_is_empty_or_selector_or_positive_number, ) from inference.enterprise.workflows.errors import ( ExecutionGraphError, InvalidStepInputDetected, VariableTypeError, ) class StepInterface(GraphNone, metaclass=ABCMeta): @abstractmethod def get_input_names(self) -> Set[str]: """ Supposed to give the name of all fields expected to represent inputs """ pass @abstractmethod def get_output_names(self) -> Set[str]: """ Supposed to give the name of all fields expected to represent outputs to be referred by other steps """ @abstractmethod def validate_field_selector( self, field_name: str, input_step: GraphNone, index: Optional[int] = None ) -> None: """ Supposed to validate the type of input is referred """ pass @abstractmethod def validate_field_binding(self, field_name: str, value: Any) -> None: """ Supposed to validate the type of value that is to be bounded with field as a result of graph execution (values passed by client to invocation, as well as constructed during graph execution) """ pass class RoboflowModel(BaseModel, StepInterface, metaclass=ABCMeta): model_config = ConfigDict(protected_namespaces=()) type: Literal["RoboflowModel"] name: str image: Union[str, List[str]] model_id: str disable_active_learning: Union[Optional[bool], str] = Field(default=False) @field_validator("image") @classmethod def validate_image(cls, value: Any) -> Union[str, List[str]]: validate_image_is_valid_selector(value=value) return value @field_validator("model_id") @classmethod def model_id_must_be_selector_or_str(cls, value: Any) -> str: validate_field_is_selector_or_has_given_type( value=value, field_name="model_id", allowed_types=[str] ) return value @field_validator("disable_active_learning") @classmethod def disable_active_learning_must_be_selector_or_bool( cls, value: Any ) -> Union[Optional[bool], str]: validate_field_is_selector_or_has_given_type( field_name="disable_active_learning", allowed_types=[type(None), bool], value=value, ) return value def get_type(self) -> str: return self.type def get_input_names(self) -> Set[str]: return {"image", "model_id", "disable_active_learning"} def get_output_names(self) -> Set[str]: return {"prediction_type"} def validate_field_selector( self, field_name: str, input_step: GraphNone, index: Optional[int] = None ) -> None: if not is_selector(selector_or_value=getattr(self, field_name)): raise ExecutionGraphError( f"Attempted to validate selector value for field {field_name}, but field is not selector." ) validate_selector_holds_image( step_type=self.type, field_name=field_name, input_step=input_step, ) validate_selector_is_inference_parameter( step_type=self.type, field_name=field_name, input_step=input_step, applicable_fields={"model_id", "disable_active_learning"}, ) def validate_field_binding(self, field_name: str, value: Any) -> None: if field_name == "image": validate_image_biding(value=value) elif field_name == "model_id": validate_field_has_given_type( field_name=field_name, allowed_types=[str], value=value, error=VariableTypeError, ) elif field_name == "disable_active_learning": validate_field_has_given_type( field_name=field_name, allowed_types=[bool], value=value, error=VariableTypeError, ) class ClassificationModel(RoboflowModel): type: Literal["ClassificationModel"] confidence: Union[Optional[float], str] = Field(default=0.4) @field_validator("confidence") @classmethod def confidence_must_be_selector_or_number( cls, value: Any ) -> Union[Optional[float], str]: validate_field_is_in_range_zero_one_or_empty_or_selector(value=value) return value def get_input_names(self) -> Set[str]: inputs = super().get_input_names() inputs.add("confidence") return inputs def get_output_names(self) -> Set[str]: outputs = super().get_output_names() outputs.update(["predictions", "top", "confidence", "parent_id"]) return outputs def validate_field_selector( self, field_name: str, input_step: GraphNone, index: Optional[int] = None ) -> None: super().validate_field_selector(field_name=field_name, input_step=input_step) validate_selector_is_inference_parameter( step_type=self.type, field_name=field_name, input_step=input_step, applicable_fields={"confidence"}, ) def validate_field_binding(self, field_name: str, value: Any) -> None: super().validate_field_binding(field_name=field_name, value=value) if field_name == "confidence": if value is None: raise VariableTypeError("Parameter `confidence` cannot be None") validate_value_is_empty_or_number_in_range_zero_one( value=value, error=VariableTypeError ) class MultiLabelClassificationModel(RoboflowModel): type: Literal["MultiLabelClassificationModel"] confidence: Union[Optional[float], str] = Field(default=0.4) @field_validator("confidence") @classmethod def confidence_must_be_selector_or_number( cls, value: Any ) -> Union[Optional[float], str]: validate_field_is_in_range_zero_one_or_empty_or_selector(value=value) return value def get_input_names(self) -> Set[str]: inputs = super().get_input_names() inputs.add("confidence") return inputs def get_output_names(self) -> Set[str]: outputs = super().get_output_names() outputs.update(["predictions", "predicted_classes", "parent_id"]) return outputs def validate_field_selector( self, field_name: str, input_step: GraphNone, index: Optional[int] = None ) -> None: super().validate_field_selector(field_name=field_name, input_step=input_step) validate_selector_is_inference_parameter( step_type=self.type, field_name=field_name, input_step=input_step, applicable_fields={"confidence"}, ) def validate_field_binding(self, field_name: str, value: Any) -> None: super().validate_field_binding(field_name=field_name, value=value) if field_name == "confidence": if value is None: raise VariableTypeError("Parameter `confidence` cannot be None") validate_value_is_empty_or_number_in_range_zero_one( value=value, error=VariableTypeError ) class ObjectDetectionModel(RoboflowModel): type: Literal["ObjectDetectionModel"] class_agnostic_nms: Union[Optional[bool], str] = Field(default=False) class_filter: Union[Optional[List[str]], str] = Field(default=None) confidence: Union[Optional[float], str] = Field(default=0.4) iou_threshold: Union[Optional[float], str] = Field(default=0.3) max_detections: Union[Optional[int], str] = Field(default=300) max_candidates: Union[Optional[int], str] = Field(default=3000) @field_validator("class_agnostic_nms") @classmethod def class_agnostic_nms_must_be_selector_or_bool( cls, value: Any ) -> Union[Optional[bool], str]: validate_field_is_selector_or_has_given_type( field_name="class_agnostic_nms", allowed_types=[type(None), bool], value=value, ) return value @field_validator("class_filter") @classmethod def class_filter_must_be_selector_or_list_of_string( cls, value: Any ) -> Union[Optional[List[str]], str]: validate_field_is_empty_or_selector_or_list_of_string( value=value, field_name="class_filter" ) return value @field_validator("confidence", "iou_threshold") @classmethod def field_must_be_selector_or_number_from_zero_to_one( cls, value: Any ) -> Union[Optional[float], str]: validate_field_is_in_range_zero_one_or_empty_or_selector( value=value, field_name="confidence | iou_threshold" ) return value @field_validator("max_detections", "max_candidates") @classmethod def field_must_be_selector_or_positive_number( cls, value: Any ) -> Union[Optional[int], str]: validate_value_is_empty_or_selector_or_positive_number( value=value, field_name="max_detections | max_candidates", ) return value def get_input_names(self) -> Set[str]: inputs = super().get_input_names() inputs.update( [ "class_agnostic_nms", "class_filter", "confidence", "iou_threshold", "max_detections", "max_candidates", ] ) return inputs def get_output_names(self) -> Set[str]: outputs = super().get_output_names() outputs.update(["predictions", "parent_id", "image"]) return outputs def validate_field_selector( self, field_name: str, input_step: GraphNone, index: Optional[int] = None ) -> None: super().validate_field_selector(field_name=field_name, input_step=input_step) validate_selector_is_inference_parameter( step_type=self.type, field_name=field_name, input_step=input_step, applicable_fields={ "class_agnostic_nms", "class_filter", "confidence", "iou_threshold", "max_detections", "max_candidates", }, ) def validate_field_binding(self, field_name: str, value: Any) -> None: super().validate_field_binding(field_name=field_name, value=value) if value is None: raise VariableTypeError(f"Parameter `{field_name}` cannot be None") if field_name == "class_agnostic_nms": validate_field_has_given_type( field_name=field_name, allowed_types=[bool], value=value, error=VariableTypeError, ) elif field_name == "class_filter": if value is None: return None validate_field_is_list_of_string( value=value, field_name=field_name, error=VariableTypeError ) elif field_name == "confidence" or field_name == "iou_threshold": validate_value_is_empty_or_number_in_range_zero_one( value=value, field_name=field_name, error=VariableTypeError, ) elif field_name == "max_detections" or field_name == "max_candidates": validate_value_is_empty_or_positive_number( value=value, field_name=field_name, error=VariableTypeError, ) class KeypointsDetectionModel(ObjectDetectionModel): type: Literal["KeypointsDetectionModel"] keypoint_confidence: Union[Optional[float], str] = Field(default=0.0) @field_validator("keypoint_confidence") @classmethod def keypoint_confidence_field_must_be_selector_or_number_from_zero_to_one( cls, value: Any ) -> Union[Optional[float], str]: validate_field_is_in_range_zero_one_or_empty_or_selector( value=value, field_name="keypoint_confidence" ) return value def get_input_names(self) -> Set[str]: inputs = super().get_input_names() inputs.add("keypoint_confidence") return inputs def validate_field_selector( self, field_name: str, input_step: GraphNone, index: Optional[int] = None ) -> None: super().validate_field_selector(field_name=field_name, input_step=input_step) validate_selector_is_inference_parameter( step_type=self.type, field_name=field_name, input_step=input_step, applicable_fields={"keypoint_confidence"}, ) def validate_field_binding(self, field_name: str, value: Any) -> None: super().validate_field_binding(field_name=field_name, value=value) if field_name == "keypoint_confidence": validate_value_is_empty_or_number_in_range_zero_one( value=value, field_name=field_name, error=VariableTypeError, ) DECODE_MODES = {"accurate", "tradeoff", "fast"} class InstanceSegmentationModel(ObjectDetectionModel): type: Literal["InstanceSegmentationModel"] mask_decode_mode: Optional[str] = Field(default="accurate") tradeoff_factor: Union[Optional[float], str] = Field(default=0.0) @field_validator("mask_decode_mode") @classmethod def mask_decode_mode_must_be_selector_or_one_of_allowed_values( cls, value: Any ) -> Optional[str]: validate_field_is_selector_or_one_of_values( value=value, field_name="mask_decode_mode", selected_values=DECODE_MODES, ) return value @field_validator("tradeoff_factor") @classmethod def field_must_be_selector_or_number_from_zero_to_one( cls, value: Any ) -> Union[Optional[float], str]: validate_field_is_in_range_zero_one_or_empty_or_selector( value=value, field_name="tradeoff_factor" ) return value def get_input_names(self) -> Set[str]: inputs = super().get_input_names() inputs.update(["mask_decode_mode", "tradeoff_factor"]) return inputs def validate_field_selector( self, field_name: str, input_step: GraphNone, index: Optional[int] = None ) -> None: super().validate_field_selector(field_name=field_name, input_step=input_step) validate_selector_is_inference_parameter( step_type=self.type, field_name=field_name, input_step=input_step, applicable_fields={"mask_decode_mode", "tradeoff_factor"}, ) def validate_field_binding(self, field_name: str, value: Any) -> None: super().validate_field_binding(field_name=field_name, value=value) if field_name == "mask_decode_mode": validate_field_is_one_of_selected_values( value=value, field_name=field_name, selected_values=DECODE_MODES, error=VariableTypeError, ) elif field_name == "tradeoff_factor": validate_value_is_empty_or_number_in_range_zero_one( value=value, field_name=field_name, error=VariableTypeError, ) class OCRModel(BaseModel, StepInterface): type: Literal["OCRModel"] name: str image: Union[str, List[str]] @field_validator("image") @classmethod def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]: validate_image_is_valid_selector(value=value) return value def validate_field_selector( self, field_name: str, input_step: GraphNone, index: Optional[int] = None ) -> None: if not is_selector(selector_or_value=getattr(self, field_name)): raise ExecutionGraphError( f"Attempted to validate selector value for field {field_name}, but field is not selector." ) validate_selector_holds_image( step_type=self.type, field_name=field_name, input_step=input_step, ) def validate_field_binding(self, field_name: str, value: Any) -> None: if field_name == "image": validate_image_biding(value=value) def get_type(self) -> str: return self.type def get_input_names(self) -> Set[str]: return {"image"} def get_output_names(self) -> Set[str]: return {"result", "parent_id", "prediction_type"} class Crop(BaseModel, StepInterface): type: Literal["Crop"] name: str image: Union[str, List[str]] detections: str @field_validator("image") @classmethod def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]: validate_image_is_valid_selector(value=value) return value @field_validator("detections") @classmethod def detections_must_hold_selector(cls, value: Any) -> str: if not is_selector(selector_or_value=value): raise ValueError("`detections` field can only contain selector values") return value def get_type(self) -> str: return self.type def get_input_names(self) -> Set[str]: return {"image", "detections"} def get_output_names(self) -> Set[str]: return {"crops", "parent_id"} def validate_field_selector( self, field_name: str, input_step: GraphNone, index: Optional[int] = None ) -> None: if not is_selector(selector_or_value=getattr(self, field_name)): raise ExecutionGraphError( f"Attempted to validate selector value for field {field_name}, but field is not selector." ) validate_selector_holds_image( step_type=self.type, field_name=field_name, input_step=input_step, ) validate_selector_holds_detections( step_name=self.name, image_selector=self.image, detections_selector=self.detections, field_name=field_name, input_step=input_step, ) def validate_field_binding(self, field_name: str, value: Any) -> None: if field_name == "image": validate_image_biding(value=value) class Operator(Enum): EQUAL = "equal" NOT_EQUAL = "not_equal" LOWER_THAN = "lower_than" GREATER_THAN = "greater_than" LOWER_OR_EQUAL_THAN = "lower_or_equal_than" GREATER_OR_EQUAL_THAN = "greater_or_equal_than" IN = "in" class Condition(BaseModel, StepInterface): type: Literal["Condition"] name: str left: Union[float, int, bool, str, list, set] operator: Operator right: Union[float, int, bool, str, list, set] step_if_true: str step_if_false: str def get_type(self) -> str: return self.type def get_input_names(self) -> Set[str]: return {"left", "right"} def get_output_names(self) -> Set[str]: return set() def validate_field_selector( self, field_name: str, input_step: GraphNone, index: Optional[int] = None ) -> None: if not is_selector(selector_or_value=getattr(self, field_name)): raise ExecutionGraphError( f"Attempted to validate selector value for field {field_name}, but field is not selector." ) input_type = input_step.get_type() if field_name in {"left", "right"}: if input_type == "InferenceImage": raise InvalidStepInputDetected( f"Field {field_name} of step {self.type} comes from invalid input type: {input_type}. " f"Expected: anything else than `InferenceImage`" ) def validate_field_binding(self, field_name: str, value: Any) -> None: pass class BinaryOperator(Enum): OR = "or" AND = "and" class DetectionFilterDefinition(BaseModel): type: Literal["DetectionFilterDefinition"] field_name: str operator: Operator reference_value: Union[float, int, bool, str, list, set] class CompoundDetectionFilterDefinition(BaseModel): type: Literal["CompoundDetectionFilterDefinition"] left: DetectionFilterDefinition operator: BinaryOperator right: DetectionFilterDefinition class DetectionFilter(BaseModel, StepInterface): type: Literal["DetectionFilter"] name: str predictions: str filter_definition: Annotated[ Union[DetectionFilterDefinition, CompoundDetectionFilterDefinition], Field(discriminator="type"), ] def get_input_names(self) -> Set[str]: return {"predictions"} def get_output_names(self) -> Set[str]: return {"predictions", "parent_id", "image", "prediction_type"} def validate_field_selector( self, field_name: str, input_step: GraphNone, index: Optional[int] = None ) -> None: if not is_selector(selector_or_value=getattr(self, field_name)): raise ExecutionGraphError( f"Attempted to validate selector value for field {field_name}, but field is not selector." ) validate_selector_holds_detections( step_name=self.name, image_selector=None, detections_selector=self.predictions, field_name=field_name, input_step=input_step, applicable_fields={"predictions"}, ) def validate_field_binding(self, field_name: str, value: Any) -> None: pass def get_type(self) -> str: return self.type class DetectionOffset(BaseModel, StepInterface): type: Literal["DetectionOffset"] name: str predictions: str offset_x: Union[int, str] offset_y: Union[int, str] def get_input_names(self) -> Set[str]: return {"predictions", "offset_x", "offset_y"} def get_output_names(self) -> Set[str]: return {"predictions", "parent_id", "image", "prediction_type"} def validate_field_selector( self, field_name: str, input_step: GraphNone, index: Optional[int] = None ) -> None: if not is_selector(selector_or_value=getattr(self, field_name)): raise ExecutionGraphError( f"Attempted to validate selector value for field {field_name}, but field is not selector." ) validate_selector_holds_detections( step_name=self.name, image_selector=None, detections_selector=self.predictions, field_name=field_name, input_step=input_step, applicable_fields={"predictions"}, ) validate_selector_is_inference_parameter( step_type=self.type, field_name=field_name, input_step=input_step, applicable_fields={"offset_x", "offset_y"}, ) def validate_field_binding(self, field_name: str, value: Any) -> None: if field_name in {"offset_x", "offset_y"}: validate_field_has_given_type( field_name=field_name, value=value, allowed_types=[int], error=VariableTypeError, ) def get_type(self) -> str: return self.type class AbsoluteStaticCrop(BaseModel, StepInterface): type: Literal["AbsoluteStaticCrop"] name: str image: Union[str, List[str]] x_center: Union[int, str] y_center: Union[int, str] width: Union[int, str] height: Union[int, str] @field_validator("image") @classmethod def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]: validate_image_is_valid_selector(value=value) return value @field_validator("x_center", "y_center", "width", "height") @classmethod def validate_crops_coordinates(cls, value: Any) -> str: validate_value_is_empty_or_selector_or_positive_number( value=value, field_name="x_center | y_center | width | height" ) return value def get_type(self) -> str: return self.type def get_input_names(self) -> Set[str]: return {"image", "x_center", "y_center", "width", "height"} def get_output_names(self) -> Set[str]: return {"crops", "parent_id"} def validate_field_selector( self, field_name: str, input_step: GraphNone, index: Optional[int] = None ) -> None: if not is_selector(selector_or_value=getattr(self, field_name)): raise ExecutionGraphError( f"Attempted to validate selector value for field {field_name}, but field is not selector." ) validate_selector_holds_image( step_type=self.type, field_name=field_name, input_step=input_step, ) validate_selector_is_inference_parameter( step_type=self.type, field_name=field_name, input_step=input_step, applicable_fields={"x_center", "y_center", "width", "height"}, ) def validate_field_binding(self, field_name: str, value: Any) -> None: if field_name == "image": validate_image_biding(value=value) if field_name in {"x_center", "y_center", "width", "height"}: if ( not issubclass(type(value), int) and not issubclass(type(value), float) ) or value != round(value): raise VariableTypeError( f"Field {field_name} of step {self.type} must be integer" ) class RelativeStaticCrop(BaseModel, StepInterface): type: Literal["RelativeStaticCrop"] name: str image: Union[str, List[str]] x_center: Union[float, str] y_center: Union[float, str] width: Union[float, str] height: Union[float, str] @field_validator("image") @classmethod def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]: validate_image_is_valid_selector(value=value) return value @field_validator("x_center", "y_center", "width", "height") @classmethod def detections_must_hold_selector(cls, value: Any) -> str: if issubclass(type(value), str): if not is_selector(selector_or_value=value): raise ValueError("Field must be either float of valid selector") elif not issubclass(type(value), float): raise ValueError("Field must be either float of valid selector") return value def get_type(self) -> str: return self.type def get_input_names(self) -> Set[str]: return {"image", "x_center", "y_center", "width", "height"} def get_output_names(self) -> Set[str]: return {"crops", "parent_id"} def validate_field_selector( self, field_name: str, input_step: GraphNone, index: Optional[int] = None ) -> None: if not is_selector(selector_or_value=getattr(self, field_name)): raise ExecutionGraphError( f"Attempted to validate selector value for field {field_name}, but field is not selector." ) validate_selector_holds_image( step_type=self.type, field_name=field_name, input_step=input_step, ) validate_selector_is_inference_parameter( step_type=self.type, field_name=field_name, input_step=input_step, applicable_fields={"x_center", "y_center", "width", "height"}, ) def validate_field_binding(self, field_name: str, value: Any) -> None: if field_name == "image": validate_image_biding(value=value) if field_name in {"x_center", "y_center", "width", "height"}: validate_field_has_given_type( field_name=field_name, value=value, allowed_types=[float], error=VariableTypeError, ) class ClipComparison(BaseModel, StepInterface): type: Literal["ClipComparison"] name: str image: Union[str, List[str]] text: Union[str, List[str]] @field_validator("image") @classmethod def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]: validate_image_is_valid_selector(value=value) return value @field_validator("text") @classmethod def text_must_be_valid(cls, value: Any) -> Union[str, List[str]]: if is_selector(selector_or_value=value): return value if issubclass(type(value), list): validate_field_is_list_of_string(value=value, field_name="text") elif not issubclass(type(value), str): raise ValueError("`text` field given must be string or list of strings") return value def validate_field_selector( self, field_name: str, input_step: GraphNone, index: Optional[int] = None ) -> None: if not is_selector(selector_or_value=getattr(self, field_name)): raise ExecutionGraphError( f"Attempted to validate selector value for field {field_name}, but field is not selector." ) validate_selector_holds_image( step_type=self.type, field_name=field_name, input_step=input_step, ) validate_selector_is_inference_parameter( step_type=self.type, field_name=field_name, input_step=input_step, applicable_fields={"text"}, ) def validate_field_binding(self, field_name: str, value: Any) -> None: if field_name == "image": validate_image_biding(value=value) if field_name == "text": if issubclass(type(value), list): validate_field_is_list_of_string( value=value, field_name=field_name, error=VariableTypeError ) elif not issubclass(type(value), str): validate_field_has_given_type( value=value, field_name=field_name, allowed_types=[str], error=VariableTypeError, ) def get_type(self) -> str: return self.type def get_input_names(self) -> Set[str]: return {"image", "text"} def get_output_names(self) -> Set[str]: return {"similarity", "parent_id", "predictions_type"} class AggregationMode(Enum): AVERAGE = "average" MAX = "max" MIN = "min" class DetectionsConsensus(BaseModel, StepInterface): type: Literal["DetectionsConsensus"] name: str predictions: List[str] required_votes: Union[int, str] class_aware: Union[bool, str] = Field(default=True) iou_threshold: Union[float, str] = Field(default=0.3) confidence: Union[float, str] = Field(default=0.0) classes_to_consider: Optional[Union[List[str], str]] = Field(default=None) required_objects: Optional[Union[int, Dict[str, int], str]] = Field(default=None) presence_confidence_aggregation: AggregationMode = Field( default=AggregationMode.MAX ) detections_merge_confidence_aggregation: AggregationMode = Field( default=AggregationMode.AVERAGE ) detections_merge_coordinates_aggregation: AggregationMode = Field( default=AggregationMode.AVERAGE ) @field_validator("predictions") @classmethod def predictions_must_be_list_of_selectors(cls, value: Any) -> List[str]: validate_field_is_list_of_selectors(value=value, field_name="predictions") if len(value) < 1: raise ValueError( "There must be at least 1 `predictions` selectors in consensus step" ) return value @field_validator("required_votes") @classmethod def required_votes_must_be_selector_or_positive_integer( cls, value: Any ) -> Union[str, int]: if value is None: raise ValueError("Field `required_votes` is required.") validate_value_is_empty_or_selector_or_positive_number( value=value, field_name="required_votes" ) return value @field_validator("class_aware") @classmethod def class_aware_must_be_selector_or_boolean(cls, value: Any) -> Union[str, bool]: validate_field_is_selector_or_has_given_type( value=value, field_name="class_aware", allowed_types=[bool] ) return value @field_validator("iou_threshold", "confidence") @classmethod def field_must_be_selector_or_number_from_zero_to_one( cls, value: Any ) -> Union[str, float]: if value is None: raise ValueError("Fields `iou_threshold` and `confidence` cannot be None") validate_field_is_in_range_zero_one_or_empty_or_selector( value=value, field_name="iou_threshold | confidence" ) return value @field_validator("classes_to_consider") @classmethod def classes_to_consider_must_be_empty_or_selector_or_list_of_strings( cls, value: Any ) -> Optional[Union[str, List[str]]]: validate_field_is_empty_or_selector_or_list_of_string( value=value, field_name="classes_to_consider" ) return value @field_validator("required_objects") @classmethod def required_objects_field_must_be_valid( cls, value: Any ) -> Optional[Union[str, int, Dict[str, int]]]: if value is None: return value validate_field_is_selector_or_has_given_type( value=value, field_name="required_objects", allowed_types=[int, dict] ) if issubclass(type(value), int): validate_value_is_empty_or_positive_number( value=value, field_name="required_objects" ) return value elif issubclass(type(value), dict): for k, v in value.items(): if v is None: raise ValueError(f"Field `required_objects[{k}]` must not be None.") validate_value_is_empty_or_positive_number( value=v, field_name=f"required_objects[{k}]" ) return value def get_input_names(self) -> Set[str]: return { "predictions", "required_votes", "class_aware", "iou_threshold", "confidence", "classes_to_consider", "required_objects", } def get_output_names(self) -> Set[str]: return { "parent_id", "predictions", "image", "object_present", "presence_confidence", "predictions_type", } def validate_field_selector( self, field_name: str, input_step: GraphNone, index: Optional[int] = None ) -> None: if field_name != "predictions" and not is_selector( selector_or_value=getattr(self, field_name) ): raise ExecutionGraphError( f"Attempted to validate selector value for field {field_name}, but field is not selector." ) if field_name == "predictions": if index is None or index > len(self.predictions): raise ExecutionGraphError( f"Attempted to validate selector value for field {field_name}, which requires multiple inputs, " f"but `index` not provided." ) if not is_selector( selector_or_value=self.predictions[index], ): raise ExecutionGraphError( f"Attempted to validate selector value for field {field_name}[{index}], but field is not selector." ) validate_selector_holds_detections( step_name=self.name, image_selector=None, detections_selector=self.predictions[index], field_name=field_name, input_step=input_step, applicable_fields={"predictions"}, ) return None validate_selector_is_inference_parameter( step_type=self.type, field_name=field_name, input_step=input_step, applicable_fields={ "required_votes", "class_aware", "iou_threshold", "confidence", "classes_to_consider", "required_objects", }, ) def validate_field_binding(self, field_name: str, value: Any) -> None: if field_name == "required_votes": if value is None: raise VariableTypeError("Field `required_votes` cannot be None.") validate_value_is_empty_or_positive_number( value=value, field_name="required_votes", error=VariableTypeError ) elif field_name == "class_aware": validate_field_has_given_type( field_name=field_name, allowed_types=[bool], value=value, error=VariableTypeError, ) elif field_name in {"iou_threshold", "confidence"}: if value is None: raise VariableTypeError(f"Fields `{field_name}` cannot be None.") validate_value_is_empty_or_number_in_range_zero_one( value=value, field_name=field_name, error=VariableTypeError, ) elif field_name == "classes_to_consider": if value is None: return None validate_field_is_list_of_string( value=value, field_name=field_name, error=VariableTypeError, ) elif field_name == "required_objects": self._validate_required_objects_binding(value=value) return None def get_type(self) -> str: return self.type def _validate_required_objects_binding(self, value: Any) -> None: if value is None: return value validate_field_has_given_type( value=value, field_name="required_objects", allowed_types=[int, dict], error=VariableTypeError, ) if issubclass(type(value), int): validate_value_is_empty_or_positive_number( value=value, field_name="required_objects", error=VariableTypeError, ) return None for k, v in value.items(): if v is None: raise VariableTypeError( f"Field `required_objects[{k}]` must not be None." ) validate_value_is_empty_or_positive_number( value=v, field_name=f"required_objects[{k}]", error=VariableTypeError, ) ACTIVE_LEARNING_DATA_COLLECTOR_ELIGIBLE_SELECTORS = { "ObjectDetectionModel": "predictions", "KeypointsDetectionModel": "predictions", "InstanceSegmentationModel": "predictions", "DetectionFilter": "predictions", "DetectionsConsensus": "predictions", "DetectionOffset": "predictions", "YoloWorld": "predictions", "ClassificationModel": "top", } class DisabledActiveLearningConfiguration(BaseModel): enabled: bool @field_validator("enabled") @classmethod def ensure_only_false_is_valid(cls, value: Any) -> bool: if value is not False: raise ValueError( "One can only specify enabled=False in `DisabledActiveLearningConfiguration`" ) return value class LimitDefinition(BaseModel): type: Literal["minutely", "hourly", "daily"] value: PositiveInt class RandomSamplingConfig(BaseModel): type: Literal["random"] name: str traffic_percentage: confloat(ge=0.0, le=1.0) tags: List[str] = Field(default_factory=lambda: []) limits: List[LimitDefinition] = Field(default_factory=lambda: []) class CloseToThresholdSampling(BaseModel): type: Literal["close_to_threshold"] name: str probability: confloat(ge=0.0, le=1.0) threshold: confloat(ge=0.0, le=1.0) epsilon: confloat(ge=0.0, le=1.0) max_batch_images: Optional[int] = Field(default=None) only_top_classes: bool = Field(default=True) minimum_objects_close_to_threshold: int = Field(default=1) selected_class_names: Optional[List[str]] = Field(default=None) tags: List[str] = Field(default_factory=lambda: []) limits: List[LimitDefinition] = Field(default_factory=lambda: []) class ClassesBasedSampling(BaseModel): type: Literal["classes_based"] name: str probability: confloat(ge=0.0, le=1.0) selected_class_names: List[str] tags: List[str] = Field(default_factory=lambda: []) limits: List[LimitDefinition] = Field(default_factory=lambda: []) class DetectionsBasedSampling(BaseModel): type: Literal["detections_number_based"] name: str probability: confloat(ge=0.0, le=1.0) more_than: Optional[NonNegativeInt] less_than: Optional[NonNegativeInt] selected_class_names: Optional[List[str]] = Field(default=None) tags: List[str] = Field(default_factory=lambda: []) limits: List[LimitDefinition] = Field(default_factory=lambda: []) class ActiveLearningBatchingStrategy(BaseModel): batches_name_prefix: str recreation_interval: Literal["never", "daily", "weekly", "monthly"] max_batch_images: Optional[int] = Field(default=None) ActiveLearningStrategyType = Annotated[ Union[ RandomSamplingConfig, CloseToThresholdSampling, ClassesBasedSampling, DetectionsBasedSampling, ], Field(discriminator="type"), ] class EnabledActiveLearningConfiguration(BaseModel): enabled: bool persist_predictions: bool sampling_strategies: List[ActiveLearningStrategyType] batching_strategy: ActiveLearningBatchingStrategy tags: List[str] = Field(default_factory=lambda: []) max_image_size: Optional[Tuple[PositiveInt, PositiveInt]] = Field(default=None) jpeg_compression_level: int = Field(default=95) @field_validator("jpeg_compression_level") @classmethod def validate_json_compression_level(cls, value: Any) -> int: validate_field_has_given_type( field_name="jpeg_compression_level", allowed_types=[int], value=value ) if value <= 0 or value > 100: raise ValueError("`jpeg_compression_level` must be in range [1, 100]") return value class ActiveLearningDataCollector(BaseModel, StepInterface): type: Literal["ActiveLearningDataCollector"] name: str image: str predictions: str target_dataset: str target_dataset_api_key: Optional[str] = Field(default=None) disable_active_learning: Union[bool, str] = Field(default=False) active_learning_configuration: Optional[ Union[EnabledActiveLearningConfiguration, DisabledActiveLearningConfiguration] ] = Field(default=None) @field_validator("image") @classmethod def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]: validate_image_is_valid_selector(value=value) return value @field_validator("predictions") @classmethod def predictions_must_hold_selector(cls, value: Any) -> str: if not is_selector(selector_or_value=value): raise ValueError("`predictions` field can only contain selector values") return value @field_validator("target_dataset") @classmethod def validate_target_dataset_field(cls, value: Any) -> str: validate_field_is_selector_or_has_given_type( value=value, field_name="target_dataset", allowed_types=[str] ) return value @field_validator("target_dataset_api_key") @classmethod def validate_target_dataset_api_key_field(cls, value: Any) -> Union[str, bool]: validate_field_is_selector_or_has_given_type( value=value, field_name="target_dataset_api_key", allowed_types=[bool, type(None)], ) return value @field_validator("disable_active_learning") @classmethod def validate_boolean_flags_or_selectors(cls, value: Any) -> Union[str, bool]: validate_field_is_selector_or_has_given_type( value=value, field_name="disable_active_learning", allowed_types=[bool] ) return value def get_type(self) -> str: return self.type def get_input_names(self) -> Set[str]: return { "image", "predictions", "target_dataset", "target_dataset_api_key", "disable_active_learning", } def get_output_names(self) -> Set[str]: return set() def validate_field_selector( self, field_name: str, input_step: GraphNone, index: Optional[int] = None ) -> None: selector = getattr(self, field_name) if not is_selector(selector_or_value=selector): raise ExecutionGraphError( f"Attempted to validate selector value for field {field_name}, but field is not selector." ) if field_name == "predictions": input_step_type = input_step.get_type() expected_last_selector_chunk = ( ACTIVE_LEARNING_DATA_COLLECTOR_ELIGIBLE_SELECTORS.get(input_step_type) ) if expected_last_selector_chunk is None: raise ExecutionGraphError( f"Attempted to validate predictions selector of {self.name} step, but input step of type: " f"{input_step_type} does match by type." ) if get_last_selector_chunk(selector) != expected_last_selector_chunk: raise ExecutionGraphError( f"It is only allowed to refer to {input_step_type} step output named {expected_last_selector_chunk}. " f"Reference that was found: {selector}" ) input_step_image = getattr(input_step, "image", self.image) if input_step_image != self.image: raise ExecutionGraphError( f"ActiveLearningDataCollector step refers to input step that uses reference to different image. " f"ActiveLearningDataCollector step image: {self.image}. Input step (of type {input_step_image}) " f"uses {input_step_image}." ) validate_selector_holds_image( step_type=self.type, field_name=field_name, input_step=input_step, ) validate_selector_is_inference_parameter( step_type=self.type, field_name=field_name, input_step=input_step, applicable_fields={ "target_dataset", "target_dataset_api_key", "disable_active_learning", }, ) def validate_field_binding(self, field_name: str, value: Any) -> None: if field_name == "image": validate_image_biding(value=value) elif field_name in {"disable_active_learning"}: validate_field_has_given_type( field_name=field_name, allowed_types=[bool], value=value, error=VariableTypeError, ) elif field_name in {"target_dataset"}: validate_field_has_given_type( field_name=field_name, allowed_types=[str], value=value, error=VariableTypeError, ) elif field_name in {"target_dataset_api_key"}: validate_field_has_given_type( field_name=field_name, allowed_types=[str], value=value, error=VariableTypeError, ) class YoloWorld(BaseModel, StepInterface): type: Literal["YoloWorld"] name: str image: str class_names: Union[str, List[str]] version: Optional[str] = Field(default="l") confidence: Union[Optional[float], str] = Field(default=0.4) @field_validator("image") @classmethod def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]: validate_image_is_valid_selector(value=value) return value @field_validator("class_names") @classmethod def validate_class_names(cls, value: Any) -> Union[str, List[str]]: if is_selector(selector_or_value=value): return value if issubclass(type(value), list): validate_field_is_list_of_string(value=value, field_name="class_names") return value raise ValueError( "`class_names` field given must be selector or list of strings" ) @field_validator("version") @classmethod def validate_model_version(cls, value: Any) -> Optional[str]: validate_field_is_selector_or_one_of_values( value=value, selected_values={None, "s", "m", "l"}, field_name="version", ) return value @field_validator("confidence") @classmethod def field_must_be_selector_or_number_from_zero_to_one( cls, value: Any ) -> Union[Optional[float], str]: if value is None: return None validate_field_is_in_range_zero_one_or_empty_or_selector( value=value, field_name="confidence" ) return value def get_input_names(self) -> Set[str]: return {"image", "class_names", "version", "confidence"} def get_output_names(self) -> Set[str]: return {"predictions", "parent_id", "image", "prediction_type"} def validate_field_selector( self, field_name: str, input_step: GraphNone, index: Optional[int] = None ) -> None: selector = getattr(self, field_name) if not is_selector(selector_or_value=selector): raise ExecutionGraphError( f"Attempted to validate selector value for field {field_name}, but field is not selector." ) validate_selector_holds_image( step_type=self.type, field_name=field_name, input_step=input_step, ) validate_selector_is_inference_parameter( step_type=self.type, field_name=field_name, input_step=input_step, applicable_fields={"class_names", "version", "confidence"}, ) def validate_field_binding(self, field_name: str, value: Any) -> None: if field_name == "image": validate_image_biding(value=value) elif field_name == "class_names": validate_field_is_list_of_string( value=value, field_name=field_name, error=VariableTypeError, ) elif field_name == "version": validate_field_is_one_of_selected_values( value=value, field_name=field_name, selected_values={None, "s", "m", "l"}, error=VariableTypeError, ) elif field_name == "confidence": validate_value_is_empty_or_number_in_range_zero_one( value=value, field_name=field_name, error=VariableTypeError, ) def get_type(self) -> str: return self.type