Fucius's picture
Upload 422 files
df6c67d verified
from abc import ABCMeta, abstractmethod
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Set, Tuple, Union
from pydantic import (
BaseModel,
ConfigDict,
Field,
NonNegativeInt,
PositiveInt,
confloat,
field_validator,
)
from inference.enterprise.workflows.entities.base import GraphNone
from inference.enterprise.workflows.entities.validators import (
get_last_selector_chunk,
is_selector,
validate_field_has_given_type,
validate_field_is_empty_or_selector_or_list_of_string,
validate_field_is_in_range_zero_one_or_empty_or_selector,
validate_field_is_list_of_selectors,
validate_field_is_list_of_string,
validate_field_is_one_of_selected_values,
validate_field_is_selector_or_has_given_type,
validate_field_is_selector_or_one_of_values,
validate_image_biding,
validate_image_is_valid_selector,
validate_selector_holds_detections,
validate_selector_holds_image,
validate_selector_is_inference_parameter,
validate_value_is_empty_or_number_in_range_zero_one,
validate_value_is_empty_or_positive_number,
validate_value_is_empty_or_selector_or_positive_number,
)
from inference.enterprise.workflows.errors import (
ExecutionGraphError,
InvalidStepInputDetected,
VariableTypeError,
)
class StepInterface(GraphNone, metaclass=ABCMeta):
    """
    Abstract contract implemented by every workflow step definition.

    Concrete steps declare which of their fields are inputs / outputs and how
    selectors and runtime values bound to those fields are validated.
    """

    @abstractmethod
    def get_input_names(self) -> Set[str]:
        """
        Names of all fields that are meant to represent step inputs.
        """
        ...

    @abstractmethod
    def get_output_names(self) -> Set[str]:
        """
        Names of all outputs of this step that other steps may refer to.
        """
        ...

    @abstractmethod
    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """
        Validate that the node referred to by the selector in `field_name` is of
        an acceptable kind. `index` is used by fields holding several selectors.
        """
        ...

    @abstractmethod
    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """
        Validate a value about to be bound to `field_name` during graph execution
        (either supplied by the client at invocation time or produced by other steps).
        """
        ...
# Shared base for steps backed by a Roboflow model: holds the image input, the
# model id and the active-learning switch, together with their validation rules.
class RoboflowModel(BaseModel, StepInterface, metaclass=ABCMeta):
    model_config = ConfigDict(protected_namespaces=())
    type: Literal["RoboflowModel"]
    name: str
    image: Union[str, List[str]]
    model_id: str
    disable_active_learning: Union[Optional[bool], str] = Field(default=False)

    @field_validator("image")
    @classmethod
    def validate_image(cls, value: Any) -> Union[str, List[str]]:
        """Reject `image` values that are not valid image selector(s)."""
        validate_image_is_valid_selector(value=value)
        return value

    @field_validator("model_id")
    @classmethod
    def model_id_must_be_selector_or_str(cls, value: Any) -> str:
        """Accept `model_id` only as a selector or a plain string."""
        validate_field_is_selector_or_has_given_type(
            value=value, field_name="model_id", allowed_types=[str]
        )
        return value

    @field_validator("disable_active_learning")
    @classmethod
    def disable_active_learning_must_be_selector_or_bool(
        cls, value: Any
    ) -> Union[Optional[bool], str]:
        """Accept `disable_active_learning` as a selector, a bool or None."""
        validate_field_is_selector_or_has_given_type(
            value=value,
            field_name="disable_active_learning",
            allowed_types=[type(None), bool],
        )
        return value

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        """Fields that may be fed from workflow inputs or other steps."""
        return {"image", "model_id", "disable_active_learning"}

    def get_output_names(self) -> Set[str]:
        """Outputs common to every Roboflow model step."""
        return {"prediction_type"}

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """
        Check that the selector stored in `field_name` references a compatible node.

        Raises:
            ExecutionGraphError: if the field does not hold a selector at all.
        """
        if not is_selector(selector_or_value=getattr(self, field_name)):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        shared_kwargs = {
            "step_type": self.type,
            "field_name": field_name,
            "input_step": input_step,
        }
        validate_selector_holds_image(**shared_kwargs)
        validate_selector_is_inference_parameter(
            **shared_kwargs,
            applicable_fields={"model_id", "disable_active_learning"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """
        Check a runtime value bound to `field_name`.

        Raises:
            VariableTypeError: when `model_id` is not a str or
                `disable_active_learning` is not a bool.
        """
        if field_name == "image":
            validate_image_biding(value=value)
            return None
        expected_types = {
            "model_id": [str],
            "disable_active_learning": [bool],
        }.get(field_name)
        if expected_types is None:
            return None
        validate_field_has_given_type(
            field_name=field_name,
            allowed_types=expected_types,
            value=value,
            error=VariableTypeError,
        )
# Single-label classification step; extends the Roboflow model step with a
# `confidence` input.
class ClassificationModel(RoboflowModel):
    type: Literal["ClassificationModel"]
    confidence: Union[Optional[float], str] = Field(default=0.4)

    @field_validator("confidence")
    @classmethod
    def confidence_must_be_selector_or_number(
        cls, value: Any
    ) -> Union[Optional[float], str]:
        """Delegates to validate_field_is_in_range_zero_one_or_empty_or_selector."""
        validate_field_is_in_range_zero_one_or_empty_or_selector(value=value)
        return value

    def get_input_names(self) -> Set[str]:
        """Base model inputs plus `confidence`."""
        return super().get_input_names() | {"confidence"}

    def get_output_names(self) -> Set[str]:
        """Base model outputs plus classification-specific ones."""
        return super().get_output_names() | {
            "predictions",
            "top",
            "confidence",
            "parent_id",
        }

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Run base selector checks, then allow `confidence` to be an inference parameter."""
        super().validate_field_selector(field_name=field_name, input_step=input_step)
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"confidence"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """
        Run base binding checks; `confidence` must additionally be non-None.

        Raises:
            VariableTypeError: if `confidence` is None or fails the range check.
        """
        super().validate_field_binding(field_name=field_name, value=value)
        if field_name != "confidence":
            return None
        if value is None:
            raise VariableTypeError("Parameter `confidence` cannot be None")
        validate_value_is_empty_or_number_in_range_zero_one(
            value=value, error=VariableTypeError
        )
# Multi-label classification step; extends the Roboflow model step with a
# `confidence` input.
class MultiLabelClassificationModel(RoboflowModel):
    type: Literal["MultiLabelClassificationModel"]
    confidence: Union[Optional[float], str] = Field(default=0.4)

    @field_validator("confidence")
    @classmethod
    def confidence_must_be_selector_or_number(
        cls, value: Any
    ) -> Union[Optional[float], str]:
        """Delegates to validate_field_is_in_range_zero_one_or_empty_or_selector."""
        validate_field_is_in_range_zero_one_or_empty_or_selector(value=value)
        return value

    def get_input_names(self) -> Set[str]:
        """Base model inputs plus `confidence`."""
        return super().get_input_names() | {"confidence"}

    def get_output_names(self) -> Set[str]:
        """Base model outputs plus multi-label specific ones."""
        return super().get_output_names() | {
            "predictions",
            "predicted_classes",
            "parent_id",
        }

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Run base selector checks, then allow `confidence` to be an inference parameter."""
        super().validate_field_selector(field_name=field_name, input_step=input_step)
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"confidence"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """
        Run base binding checks; `confidence` must additionally be non-None.

        Raises:
            VariableTypeError: if `confidence` is None or fails the range check.
        """
        super().validate_field_binding(field_name=field_name, value=value)
        if field_name != "confidence":
            return None
        if value is None:
            raise VariableTypeError("Parameter `confidence` cannot be None")
        validate_value_is_empty_or_number_in_range_zero_one(
            value=value, error=VariableTypeError
        )
# Object-detection step; extends the Roboflow model step with NMS / filtering
# parameters, each of which may be a literal or a selector.
class ObjectDetectionModel(RoboflowModel):
    type: Literal["ObjectDetectionModel"]
    class_agnostic_nms: Union[Optional[bool], str] = Field(default=False)
    class_filter: Union[Optional[List[str]], str] = Field(default=None)
    confidence: Union[Optional[float], str] = Field(default=0.4)
    iou_threshold: Union[Optional[float], str] = Field(default=0.3)
    max_detections: Union[Optional[int], str] = Field(default=300)
    max_candidates: Union[Optional[int], str] = Field(default=3000)

    @field_validator("class_agnostic_nms")
    @classmethod
    def class_agnostic_nms_must_be_selector_or_bool(
        cls, value: Any
    ) -> Union[Optional[bool], str]:
        """Accept `class_agnostic_nms` as a selector, a bool or None."""
        validate_field_is_selector_or_has_given_type(
            field_name="class_agnostic_nms",
            allowed_types=[type(None), bool],
            value=value,
        )
        return value

    @field_validator("class_filter")
    @classmethod
    def class_filter_must_be_selector_or_list_of_string(
        cls, value: Any
    ) -> Union[Optional[List[str]], str]:
        """Accept `class_filter` as None, a selector or a list of strings."""
        validate_field_is_empty_or_selector_or_list_of_string(
            value=value, field_name="class_filter"
        )
        return value

    @field_validator("confidence", "iou_threshold")
    @classmethod
    def field_must_be_selector_or_number_from_zero_to_one(
        cls, value: Any
    ) -> Union[Optional[float], str]:
        """Delegates to validate_field_is_in_range_zero_one_or_empty_or_selector."""
        validate_field_is_in_range_zero_one_or_empty_or_selector(
            value=value, field_name="confidence | iou_threshold"
        )
        return value

    @field_validator("max_detections", "max_candidates")
    @classmethod
    def field_must_be_selector_or_positive_number(
        cls, value: Any
    ) -> Union[Optional[int], str]:
        """Delegates to validate_value_is_empty_or_selector_or_positive_number."""
        validate_value_is_empty_or_selector_or_positive_number(
            value=value,
            field_name="max_detections | max_candidates",
        )
        return value

    def get_input_names(self) -> Set[str]:
        """Base model inputs plus all detection parameters."""
        inputs = super().get_input_names()
        inputs.update(
            [
                "class_agnostic_nms",
                "class_filter",
                "confidence",
                "iou_threshold",
                "max_detections",
                "max_candidates",
            ]
        )
        return inputs

    def get_output_names(self) -> Set[str]:
        """Base model outputs plus detection outputs."""
        outputs = super().get_output_names()
        outputs.update(["predictions", "parent_id", "image"])
        return outputs

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Run base selector checks, then allow detection parameters as inference parameters."""
        super().validate_field_selector(field_name=field_name, input_step=input_step)
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={
                "class_agnostic_nms",
                "class_filter",
                "confidence",
                "iou_threshold",
                "max_detections",
                "max_candidates",
            },
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """
        Validate a runtime value bound to `field_name`.

        `class_filter` accepts None (its declared default) or a list of strings.
        Every other field rejects None.

        Raises:
            VariableTypeError: on None (except for `class_filter`) or on a value
                of wrong type / range.
        """
        super().validate_field_binding(field_name=field_name, value=value)
        if field_name == "class_filter":
            # Handled before the generic None check below. Previously that
            # check ran first and made `class_filter=None` impossible to bind.
            if value is None:
                return None
            validate_field_is_list_of_string(
                value=value, field_name=field_name, error=VariableTypeError
            )
            return None
        if value is None:
            raise VariableTypeError(f"Parameter `{field_name}` cannot be None")
        if field_name == "class_agnostic_nms":
            validate_field_has_given_type(
                field_name=field_name,
                allowed_types=[bool],
                value=value,
                error=VariableTypeError,
            )
        elif field_name in {"confidence", "iou_threshold"}:
            validate_value_is_empty_or_number_in_range_zero_one(
                value=value,
                field_name=field_name,
                error=VariableTypeError,
            )
        elif field_name in {"max_detections", "max_candidates"}:
            validate_value_is_empty_or_positive_number(
                value=value,
                field_name=field_name,
                error=VariableTypeError,
            )
# Keypoint-detection step; object-detection parameters plus a separate
# `keypoint_confidence` input.
class KeypointsDetectionModel(ObjectDetectionModel):
    type: Literal["KeypointsDetectionModel"]
    keypoint_confidence: Union[Optional[float], str] = Field(default=0.0)

    @field_validator("keypoint_confidence")
    @classmethod
    def keypoint_confidence_field_must_be_selector_or_number_from_zero_to_one(
        cls, value: Any
    ) -> Union[Optional[float], str]:
        """Delegates to validate_field_is_in_range_zero_one_or_empty_or_selector."""
        validate_field_is_in_range_zero_one_or_empty_or_selector(
            value=value, field_name="keypoint_confidence"
        )
        return value

    def get_input_names(self) -> Set[str]:
        """Object-detection inputs plus `keypoint_confidence`."""
        return super().get_input_names() | {"keypoint_confidence"}

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Run parent selector checks, then allow `keypoint_confidence` as an inference parameter."""
        super().validate_field_selector(field_name=field_name, input_step=input_step)
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"keypoint_confidence"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Run parent binding checks, then range-check `keypoint_confidence`."""
        super().validate_field_binding(field_name=field_name, value=value)
        if field_name != "keypoint_confidence":
            return None
        validate_value_is_empty_or_number_in_range_zero_one(
            value=value,
            field_name=field_name,
            error=VariableTypeError,
        )
# Values accepted for `InstanceSegmentationModel.mask_decode_mode`.
DECODE_MODES = {"accurate", "tradeoff", "fast"}


# Instance-segmentation step; object-detection parameters plus mask-decoding
# controls.
class InstanceSegmentationModel(ObjectDetectionModel):
    type: Literal["InstanceSegmentationModel"]
    mask_decode_mode: Optional[str] = Field(default="accurate")
    tradeoff_factor: Union[Optional[float], str] = Field(default=0.0)

    @field_validator("mask_decode_mode")
    @classmethod
    def mask_decode_mode_must_be_selector_or_one_of_allowed_values(
        cls, value: Any
    ) -> Optional[str]:
        """Accept a selector or one of DECODE_MODES."""
        validate_field_is_selector_or_one_of_values(
            value=value,
            field_name="mask_decode_mode",
            selected_values=DECODE_MODES,
        )
        return value

    # The name must differ from every parent-class validator. Pydantic v2 keys
    # inherited validators by attribute name, so reusing
    # `field_must_be_selector_or_number_from_zero_to_one` (ObjectDetectionModel's
    # validator for `confidence` / `iou_threshold`) silently replaced it and left
    # those two fields unvalidated on this model.
    @field_validator("tradeoff_factor")
    @classmethod
    def tradeoff_factor_must_be_selector_or_number_from_zero_to_one(
        cls, value: Any
    ) -> Union[Optional[float], str]:
        """Delegates to validate_field_is_in_range_zero_one_or_empty_or_selector."""
        validate_field_is_in_range_zero_one_or_empty_or_selector(
            value=value, field_name="tradeoff_factor"
        )
        return value

    def get_input_names(self) -> Set[str]:
        """Object-detection inputs plus mask-decoding inputs."""
        inputs = super().get_input_names()
        inputs.update(["mask_decode_mode", "tradeoff_factor"])
        return inputs

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Run parent selector checks, then allow mask-decoding fields as inference parameters."""
        super().validate_field_selector(field_name=field_name, input_step=input_step)
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"mask_decode_mode", "tradeoff_factor"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """
        Run parent binding checks, then validate mask-decoding values.

        Raises:
            VariableTypeError: if `mask_decode_mode` is not in DECODE_MODES or
                `tradeoff_factor` fails the range check.
        """
        super().validate_field_binding(field_name=field_name, value=value)
        if field_name == "mask_decode_mode":
            validate_field_is_one_of_selected_values(
                value=value,
                field_name=field_name,
                selected_values=DECODE_MODES,
                error=VariableTypeError,
            )
        elif field_name == "tradeoff_factor":
            validate_value_is_empty_or_number_in_range_zero_one(
                value=value,
                field_name=field_name,
                error=VariableTypeError,
            )
# OCR step: takes only an image input.
class OCRModel(BaseModel, StepInterface):
    type: Literal["OCRModel"]
    name: str
    image: Union[str, List[str]]

    @field_validator("image")
    @classmethod
    def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]:
        """Reject `image` values that are not valid image selector(s)."""
        validate_image_is_valid_selector(value=value)
        return value

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        return {"image"}

    def get_output_names(self) -> Set[str]:
        return {"result", "parent_id", "prediction_type"}

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """
        Check that the selector in `field_name` references an image-producing node.

        Raises:
            ExecutionGraphError: if the field does not hold a selector.
        """
        referenced = getattr(self, field_name)
        if not is_selector(selector_or_value=referenced):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        validate_selector_holds_image(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Only the `image` binding is validated."""
        if field_name != "image":
            return None
        validate_image_biding(value=value)
# Step taking an `image` and a `detections` selector; outputs `crops`.
class Crop(BaseModel, StepInterface):
    type: Literal["Crop"]
    name: str
    image: Union[str, List[str]]
    detections: str

    @field_validator("image")
    @classmethod
    def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]:
        """Reject `image` values that are not valid image selector(s)."""
        validate_image_is_valid_selector(value=value)
        return value

    @field_validator("detections")
    @classmethod
    def detections_must_hold_selector(cls, value: Any) -> str:
        """`detections` must always be a selector, never a literal."""
        if is_selector(selector_or_value=value):
            return value
        raise ValueError("`detections` field can only contain selector values")

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        return {"image", "detections"}

    def get_output_names(self) -> Set[str]:
        return {"crops", "parent_id"}

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """
        Check the selector in `field_name` with both the image and detections checks.

        Raises:
            ExecutionGraphError: if the field does not hold a selector.
        """
        referenced = getattr(self, field_name)
        if not is_selector(selector_or_value=referenced):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        validate_selector_holds_image(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
        )
        validate_selector_holds_detections(
            step_name=self.name,
            image_selector=self.image,
            detections_selector=self.detections,
            field_name=field_name,
            input_step=input_step,
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Only the `image` binding is validated."""
        if field_name != "image":
            return None
        validate_image_biding(value=value)
# Comparison operators, used as `Condition.operator` and
# `DetectionFilterDefinition.operator`. The string values are what workflow
# definitions spell out. (Kept as comments rather than a docstring because
# pydantic would copy an enum docstring into the JSON schema description.)
class Operator(Enum):
    EQUAL = "equal"
    NOT_EQUAL = "not_equal"
    LOWER_THAN = "lower_than"
    GREATER_THAN = "greater_than"
    LOWER_OR_EQUAL_THAN = "lower_or_equal_than"
    GREATER_OR_EQUAL_THAN = "greater_or_equal_than"
    # presumably membership of the left operand in the right one — semantics
    # are implemented elsewhere; TODO confirm
    IN = "in"
# Branching step: compares `left` with `right` using `operator` and holds the
# names of the steps referenced by `step_if_true` / `step_if_false`.
class Condition(BaseModel, StepInterface):
    type: Literal["Condition"]
    name: str
    left: Union[float, int, bool, str, list, set]
    operator: Operator
    right: Union[float, int, bool, str, list, set]
    step_if_true: str
    step_if_false: str

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        return {"left", "right"}

    def get_output_names(self) -> Set[str]:
        """A condition step exposes no outputs."""
        return set()

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """
        Operands may reference anything except an `InferenceImage` input.

        Raises:
            ExecutionGraphError: if the field does not hold a selector.
            InvalidStepInputDetected: if `left` / `right` refers to an image input.
        """
        if not is_selector(selector_or_value=getattr(self, field_name)):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        input_type = input_step.get_type()
        is_operand = field_name in {"left", "right"}
        if not is_operand or input_type != "InferenceImage":
            return None
        raise InvalidStepInputDetected(
            f"Field {field_name} of step {self.type} comes from invalid input type: {input_type}. "
            f"Expected: anything else than `InferenceImage`"
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """No runtime bindings require validation for this step."""
        return None
# Logical connective joining the `left` and `right` filters of a
# `CompoundDetectionFilterDefinition`.
class BinaryOperator(Enum):
    OR = "or"
    AND = "and"
# Single filter rule used by `DetectionFilter`. It is selected via the `type`
# discriminator.
# NOTE(review): presumably each detection's `field_name` attribute is compared
# against `reference_value` with `operator`; the evaluation lives elsewhere —
# confirm.
class DetectionFilterDefinition(BaseModel):
    type: Literal["DetectionFilterDefinition"]
    field_name: str
    operator: Operator
    reference_value: Union[float, int, bool, str, list, set]
# Two single filter rules joined with a `BinaryOperator`. Only one level of
# nesting is expressible: both sides are plain `DetectionFilterDefinition`s.
class CompoundDetectionFilterDefinition(BaseModel):
    type: Literal["CompoundDetectionFilterDefinition"]
    left: DetectionFilterDefinition
    operator: BinaryOperator
    right: DetectionFilterDefinition
# Step applying `filter_definition` (a single or compound rule, picked by its
# `type` discriminator) to the detections referenced by `predictions`.
class DetectionFilter(BaseModel, StepInterface):
    type: Literal["DetectionFilter"]
    name: str
    predictions: str
    filter_definition: Annotated[
        Union[DetectionFilterDefinition, CompoundDetectionFilterDefinition],
        Field(discriminator="type"),
    ]

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        return {"predictions"}

    def get_output_names(self) -> Set[str]:
        return {"predictions", "parent_id", "image", "prediction_type"}

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """
        Check that `predictions` references a detections-producing node.

        Raises:
            ExecutionGraphError: if the field does not hold a selector.
        """
        referenced = getattr(self, field_name)
        if not is_selector(selector_or_value=referenced):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        validate_selector_holds_detections(
            step_name=self.name,
            image_selector=None,
            detections_selector=self.predictions,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"predictions"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """No runtime bindings require validation for this step."""
        return None
# Step taking the detections referenced by `predictions` together with integer
# `offset_x` / `offset_y` parameters (literal ints or selectors).
class DetectionOffset(BaseModel, StepInterface):
    type: Literal["DetectionOffset"]
    name: str
    predictions: str
    offset_x: Union[int, str]
    offset_y: Union[int, str]

    @field_validator("offset_x", "offset_y")
    @classmethod
    def offsets_must_be_selector_or_integer(cls, value: Any) -> Union[int, str]:
        """
        Accept offsets only as selectors or integers.

        Without this, `Union[int, str]` accepted any arbitrary string (e.g. "abc")
        at parse time. Only `validate_field_binding` would later enforce `int`.
        """
        validate_field_is_selector_or_has_given_type(
            value=value, field_name="offset_x | offset_y", allowed_types=[int]
        )
        return value

    def get_input_names(self) -> Set[str]:
        return {"predictions", "offset_x", "offset_y"}

    def get_output_names(self) -> Set[str]:
        return {"predictions", "parent_id", "image", "prediction_type"}

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """
        Check `predictions` against detections-producing nodes and the offsets
        against inference parameters.

        Raises:
            ExecutionGraphError: if the field does not hold a selector.
        """
        if not is_selector(selector_or_value=getattr(self, field_name)):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        validate_selector_holds_detections(
            step_name=self.name,
            image_selector=None,
            detections_selector=self.predictions,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"predictions"},
        )
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"offset_x", "offset_y"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """
        Offsets bound at runtime must be integers.

        Raises:
            VariableTypeError: if `offset_x` / `offset_y` is not an int.
        """
        if field_name in {"offset_x", "offset_y"}:
            validate_field_has_given_type(
                field_name=field_name,
                value=value,
                allowed_types=[int],
                error=VariableTypeError,
            )

    def get_type(self) -> str:
        return self.type
# Static crop step whose box (`x_center`, `y_center`, `width`, `height`) is
# given as integers or selectors.
class AbsoluteStaticCrop(BaseModel, StepInterface):
    type: Literal["AbsoluteStaticCrop"]
    name: str
    image: Union[str, List[str]]
    x_center: Union[int, str]
    y_center: Union[int, str]
    width: Union[int, str]
    height: Union[int, str]

    @field_validator("image")
    @classmethod
    def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]:
        """Reject `image` values that are not valid image selector(s)."""
        validate_image_is_valid_selector(value=value)
        return value

    @field_validator("x_center", "y_center", "width", "height")
    @classmethod
    def validate_crops_coordinates(cls, value: Any) -> Union[int, str]:
        """Delegates to validate_value_is_empty_or_selector_or_positive_number."""
        validate_value_is_empty_or_selector_or_positive_number(
            value=value, field_name="x_center | y_center | width | height"
        )
        return value

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        return {"image", "x_center", "y_center", "width", "height"}

    def get_output_names(self) -> Set[str]:
        return {"crops", "parent_id"}

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """
        Check the selector in `field_name` against image and inference-parameter rules.

        Raises:
            ExecutionGraphError: if the field does not hold a selector.
        """
        if not is_selector(selector_or_value=getattr(self, field_name)):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        validate_selector_holds_image(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
        )
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"x_center", "y_center", "width", "height"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """
        Validate runtime bindings; coordinates must be integral numbers.

        Integral floats such as 10.0 are accepted, as before.

        Raises:
            VariableTypeError: if a coordinate is not an int / integral float.
        """
        if field_name == "image":
            validate_image_biding(value=value)
        if field_name in {"x_center", "y_center", "width", "height"}:
            # `float.is_integer()` is False for inf/NaN. The previous
            # `value != round(value)` raised OverflowError/ValueError for those
            # instead of VariableTypeError.
            if not isinstance(value, (int, float)) or (
                isinstance(value, float) and not value.is_integer()
            ):
                raise VariableTypeError(
                    f"Field {field_name} of step {self.type} must be integer"
                )
# Static crop step whose box (`x_center`, `y_center`, `width`, `height`) is
# given as floats or selectors.
class RelativeStaticCrop(BaseModel, StepInterface):
    type: Literal["RelativeStaticCrop"]
    name: str
    image: Union[str, List[str]]
    x_center: Union[float, str]
    y_center: Union[float, str]
    width: Union[float, str]
    height: Union[float, str]

    @field_validator("image")
    @classmethod
    def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]:
        """Reject `image` values that are not valid image selector(s)."""
        validate_image_is_valid_selector(value=value)
        return value

    @field_validator("x_center", "y_center", "width", "height")
    @classmethod
    def detections_must_hold_selector(cls, value: Any) -> Union[float, str]:
        """
        Accept crop coordinates only as a float or a valid selector.

        The method name is historical. It validates coordinate fields, not detections.

        Raises:
            ValueError: for non-selector strings or non-float values.
        """
        if issubclass(type(value), str):
            if not is_selector(selector_or_value=value):
                raise ValueError("Field must be either float or valid selector")
        elif not issubclass(type(value), float):
            raise ValueError("Field must be either float or valid selector")
        return value

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        return {"image", "x_center", "y_center", "width", "height"}

    def get_output_names(self) -> Set[str]:
        return {"crops", "parent_id"}

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """
        Check the selector in `field_name` against image and inference-parameter rules.

        Raises:
            ExecutionGraphError: if the field does not hold a selector.
        """
        if not is_selector(selector_or_value=getattr(self, field_name)):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        validate_selector_holds_image(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
        )
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"x_center", "y_center", "width", "height"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """
        Validate runtime bindings; coordinates must be floats.

        Raises:
            VariableTypeError: if a coordinate is not a float.
        """
        if field_name == "image":
            validate_image_biding(value=value)
        if field_name in {"x_center", "y_center", "width", "height"}:
            validate_field_has_given_type(
                field_name=field_name,
                value=value,
                allowed_types=[float],
                error=VariableTypeError,
            )
# Step comparing `image` with `text` (a string, a list of strings or a
# selector); outputs `similarity`.
class ClipComparison(BaseModel, StepInterface):
    type: Literal["ClipComparison"]
    name: str
    image: Union[str, List[str]]
    text: Union[str, List[str]]

    @field_validator("image")
    @classmethod
    def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]:
        """Reject `image` values that are not valid image selector(s)."""
        validate_image_is_valid_selector(value=value)
        return value

    @field_validator("text")
    @classmethod
    def text_must_be_valid(cls, value: Any) -> Union[str, List[str]]:
        """Accept a selector, a plain string or a list of strings."""
        if is_selector(selector_or_value=value):
            return value
        if isinstance(value, list):
            validate_field_is_list_of_string(value=value, field_name="text")
            return value
        if isinstance(value, str):
            return value
        raise ValueError("`text` field given must be string or list of strings")

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        return {"image", "text"}

    def get_output_names(self) -> Set[str]:
        # NOTE(review): other steps in this module declare "prediction_type"
        # (singular); confirm which name the executor actually emits here.
        return {"similarity", "parent_id", "predictions_type"}

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """
        Check the selector in `field_name` against image and inference-parameter rules.

        Raises:
            ExecutionGraphError: if the field does not hold a selector.
        """
        referenced = getattr(self, field_name)
        if not is_selector(selector_or_value=referenced):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        validate_selector_holds_image(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
        )
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"text"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """
        Validate runtime bindings for `image` and `text`.

        Raises:
            VariableTypeError: if `text` is neither a string nor a list of strings.
        """
        if field_name == "image":
            validate_image_biding(value=value)
        if field_name != "text":
            return None
        if isinstance(value, list):
            validate_field_is_list_of_string(
                value=value, field_name=field_name, error=VariableTypeError
            )
        elif not isinstance(value, str):
            validate_field_has_given_type(
                value=value,
                field_name=field_name,
                allowed_types=[str],
                error=VariableTypeError,
            )
# Aggregation choices for the `*_aggregation` fields of `DetectionsConsensus`.
# presumably applied when merging values across the voting predictions —
# the merge logic lives elsewhere; TODO confirm.
class AggregationMode(Enum):
    AVERAGE = "average"
    MAX = "max"
    MIN = "min"
# Step combining several `predictions` selectors by voting; controlled by
# `required_votes`, `class_aware`, thresholds and aggregation modes.
class DetectionsConsensus(BaseModel, StepInterface):
    type: Literal["DetectionsConsensus"]
    name: str
    predictions: List[str]
    required_votes: Union[int, str]
    class_aware: Union[bool, str] = Field(default=True)
    iou_threshold: Union[float, str] = Field(default=0.3)
    confidence: Union[float, str] = Field(default=0.0)
    classes_to_consider: Optional[Union[List[str], str]] = Field(default=None)
    required_objects: Optional[Union[int, Dict[str, int], str]] = Field(default=None)
    presence_confidence_aggregation: AggregationMode = Field(
        default=AggregationMode.MAX
    )
    detections_merge_confidence_aggregation: AggregationMode = Field(
        default=AggregationMode.AVERAGE
    )
    detections_merge_coordinates_aggregation: AggregationMode = Field(
        default=AggregationMode.AVERAGE
    )

    @field_validator("predictions")
    @classmethod
    def predictions_must_be_list_of_selectors(cls, value: Any) -> List[str]:
        """`predictions` must be a non-empty list of selectors."""
        validate_field_is_list_of_selectors(value=value, field_name="predictions")
        if len(value) < 1:
            raise ValueError(
                "There must be at least 1 `predictions` selectors in consensus step"
            )
        return value

    @field_validator("required_votes")
    @classmethod
    def required_votes_must_be_selector_or_positive_integer(
        cls, value: Any
    ) -> Union[str, int]:
        """`required_votes` is mandatory: a selector or a positive number."""
        if value is None:
            raise ValueError("Field `required_votes` is required.")
        validate_value_is_empty_or_selector_or_positive_number(
            value=value, field_name="required_votes"
        )
        return value

    @field_validator("class_aware")
    @classmethod
    def class_aware_must_be_selector_or_boolean(cls, value: Any) -> Union[str, bool]:
        """`class_aware` must be a selector or a bool."""
        validate_field_is_selector_or_has_given_type(
            value=value, field_name="class_aware", allowed_types=[bool]
        )
        return value

    @field_validator("iou_threshold", "confidence")
    @classmethod
    def field_must_be_selector_or_number_from_zero_to_one(
        cls, value: Any
    ) -> Union[str, float]:
        """Thresholds must be non-None selectors or numbers passing the [0, 1] check."""
        if value is None:
            raise ValueError("Fields `iou_threshold` and `confidence` cannot be None")
        validate_field_is_in_range_zero_one_or_empty_or_selector(
            value=value, field_name="iou_threshold | confidence"
        )
        return value

    @field_validator("classes_to_consider")
    @classmethod
    def classes_to_consider_must_be_empty_or_selector_or_list_of_strings(
        cls, value: Any
    ) -> Optional[Union[str, List[str]]]:
        """`classes_to_consider` may be None, a selector or a list of strings."""
        validate_field_is_empty_or_selector_or_list_of_string(
            value=value, field_name="classes_to_consider"
        )
        return value

    @field_validator("required_objects")
    @classmethod
    def required_objects_field_must_be_valid(
        cls, value: Any
    ) -> Optional[Union[str, int, Dict[str, int]]]:
        """
        `required_objects` may be None, a selector, a positive number or a
        mapping whose values are all non-None positive numbers.
        """
        if value is None:
            return value
        validate_field_is_selector_or_has_given_type(
            value=value, field_name="required_objects", allowed_types=[int, dict]
        )
        if issubclass(type(value), int):
            validate_value_is_empty_or_positive_number(
                value=value, field_name="required_objects"
            )
        elif issubclass(type(value), dict):
            for k, v in value.items():
                if v is None:
                    raise ValueError(f"Field `required_objects[{k}]` must not be None.")
                validate_value_is_empty_or_positive_number(
                    value=v, field_name=f"required_objects[{k}]"
                )
        return value

    def get_input_names(self) -> Set[str]:
        return {
            "predictions",
            "required_votes",
            "class_aware",
            "iou_threshold",
            "confidence",
            "classes_to_consider",
            "required_objects",
        }

    def get_output_names(self) -> Set[str]:
        return {
            "parent_id",
            "predictions",
            "image",
            "object_present",
            "presence_confidence",
            "predictions_type",
        }

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """
        Check the selector held by `field_name` against the referenced node.

        `predictions` holds a list of selectors, so `index` picks the one being
        validated. Every other field holds a single selector.

        Raises:
            ExecutionGraphError: if the field (or `predictions[index]`) is not a
                selector, or if `index` is missing / out of range for `predictions`.
        """
        if field_name == "predictions":
            if index is None:
                raise ExecutionGraphError(
                    f"Attempted to validate selector value for field {field_name}, which requires multiple inputs, "
                    f"but `index` not provided."
                )
            try:
                # EAFP replaces the former `index > len(...)` check, which let
                # `index == len(self.predictions)` escape as a raw IndexError.
                prediction_selector = self.predictions[index]
            except IndexError:
                raise ExecutionGraphError(
                    f"Attempted to validate selector value for field {field_name}[{index}], "
                    f"but only {len(self.predictions)} selectors are defined."
                ) from None
            if not is_selector(selector_or_value=prediction_selector):
                raise ExecutionGraphError(
                    f"Attempted to validate selector value for field {field_name}[{index}], but field is not selector."
                )
            validate_selector_holds_detections(
                step_name=self.name,
                image_selector=None,
                detections_selector=prediction_selector,
                field_name=field_name,
                input_step=input_step,
                applicable_fields={"predictions"},
            )
            return None
        if not is_selector(selector_or_value=getattr(self, field_name)):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={
                "required_votes",
                "class_aware",
                "iou_threshold",
                "confidence",
                "classes_to_consider",
                "required_objects",
            },
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """
        Validate a runtime value bound to `field_name`.

        Raises:
            VariableTypeError: on None for mandatory fields or on wrong type / range.
        """
        if field_name == "required_votes":
            if value is None:
                raise VariableTypeError("Field `required_votes` cannot be None.")
            validate_value_is_empty_or_positive_number(
                value=value, field_name="required_votes", error=VariableTypeError
            )
        elif field_name == "class_aware":
            validate_field_has_given_type(
                field_name=field_name,
                allowed_types=[bool],
                value=value,
                error=VariableTypeError,
            )
        elif field_name in {"iou_threshold", "confidence"}:
            if value is None:
                raise VariableTypeError(f"Fields `{field_name}` cannot be None.")
            validate_value_is_empty_or_number_in_range_zero_one(
                value=value,
                field_name=field_name,
                error=VariableTypeError,
            )
        elif field_name == "classes_to_consider":
            if value is None:
                return None
            validate_field_is_list_of_string(
                value=value,
                field_name=field_name,
                error=VariableTypeError,
            )
        elif field_name == "required_objects":
            self._validate_required_objects_binding(value=value)
        return None

    def get_type(self) -> str:
        return self.type

    def _validate_required_objects_binding(self, value: Any) -> None:
        """Runtime counterpart of `required_objects_field_must_be_valid`."""
        if value is None:
            return value
        validate_field_has_given_type(
            value=value,
            field_name="required_objects",
            allowed_types=[int, dict],
            error=VariableTypeError,
        )
        if issubclass(type(value), int):
            validate_value_is_empty_or_positive_number(
                value=value,
                field_name="required_objects",
                error=VariableTypeError,
            )
            return None
        for k, v in value.items():
            if v is None:
                raise VariableTypeError(
                    f"Field `required_objects[{k}]` must not be None."
                )
            validate_value_is_empty_or_positive_number(
                value=v,
                field_name=f"required_objects[{k}]",
                error=VariableTypeError,
            )
# Step type -> name of the output that, presumably, the active-learning data
# collector may consume from that step (each visible entry matches an output
# declared by that step's `get_output_names`). Insertion order is preserved.
ACTIVE_LEARNING_DATA_COLLECTOR_ELIGIBLE_SELECTORS = dict(
    ObjectDetectionModel="predictions",
    KeypointsDetectionModel="predictions",
    InstanceSegmentationModel="predictions",
    DetectionFilter="predictions",
    DetectionsConsensus="predictions",
    DetectionOffset="predictions",
    YoloWorld="predictions",
    ClassificationModel="top",
)
# Active-learning configuration whose only valid form is `enabled=False`.
class DisabledActiveLearningConfiguration(BaseModel):
    enabled: bool

    @field_validator("enabled")
    @classmethod
    def ensure_only_false_is_valid(cls, value: Any) -> bool:
        """Accept exclusively `False`; any other (coerced) value is rejected."""
        if value is False:
            return value
        raise ValueError(
            "One can only specify enabled=False in `DisabledActiveLearningConfiguration`"
        )
class LimitDefinition(BaseModel):
    """A per-time-window limit attached to a sampling strategy.

    Presumably caps how many datapoints the strategy may collect in each
    window - confirm against the active learning runtime.
    """

    # Length of the window the limit applies to.
    type: Literal["minutely", "hourly", "daily"]
    # Limit value; must be strictly positive.
    value: PositiveInt
class RandomSamplingConfig(BaseModel):
    """Sampling strategy selected by ``type="random"``.

    ``traffic_percentage`` is a fraction in [0, 1]; ``tags`` and ``limits``
    default to empty lists.
    """

    type: Literal["random"]
    name: str
    traffic_percentage: confloat(ge=0.0, le=1.0)
    tags: List[str] = Field(default_factory=list)
    limits: List[LimitDefinition] = Field(default_factory=list)
class CloseToThresholdSampling(BaseModel):
    """Sampling strategy selected by ``type="close_to_threshold"``.

    ``probability``, ``threshold`` and ``epsilon`` are all constrained to
    [0, 1]. Optional knobs default to: no batch image cap, top classes only,
    at least one object near the threshold, no class filter, no tags and no
    limits.
    """

    type: Literal["close_to_threshold"]
    name: str
    probability: confloat(ge=0.0, le=1.0)
    threshold: confloat(ge=0.0, le=1.0)
    epsilon: confloat(ge=0.0, le=1.0)
    max_batch_images: Optional[int] = None
    only_top_classes: bool = True
    minimum_objects_close_to_threshold: int = 1
    selected_class_names: Optional[List[str]] = None
    tags: List[str] = Field(default_factory=list)
    limits: List[LimitDefinition] = Field(default_factory=list)
class ClassesBasedSampling(BaseModel):
    """Sampling strategy selected by ``type="classes_based"``.

    ``selected_class_names`` is mandatory; ``probability`` lies in [0, 1].
    """

    type: Literal["classes_based"]
    name: str
    probability: confloat(ge=0.0, le=1.0)
    selected_class_names: List[str]
    tags: List[str] = Field(default_factory=list)
    limits: List[LimitDefinition] = Field(default_factory=list)
class DetectionsBasedSampling(BaseModel):
    """Sampling strategy selected by ``type="detections_number_based"``.

    ``more_than`` / ``less_than`` are optional non-negative bounds, presumably
    on the number of detections (confirm against the sampling implementation).
    ``selected_class_names`` optionally narrows which classes are considered.
    """

    type: Literal["detections_number_based"]
    name: str
    probability: confloat(ge=0.0, le=1.0)
    # In pydantic v2 an ``Optional[...]`` annotation without a default makes
    # the key mandatory (merely nullable). Give both bounds an explicit None
    # default, like every other optional field of the sampling configs, so
    # either bound can be omitted.
    more_than: Optional[NonNegativeInt] = Field(default=None)
    less_than: Optional[NonNegativeInt] = Field(default=None)
    selected_class_names: Optional[List[str]] = Field(default=None)
    tags: List[str] = Field(default_factory=list)
    limits: List[LimitDefinition] = Field(default_factory=list)
class ActiveLearningBatchingStrategy(BaseModel):
    """How collected datapoints are grouped into batches.

    ``recreation_interval`` is one of never/daily/weekly/monthly;
    ``max_batch_images`` is optional and unset by default.
    """

    batches_name_prefix: str
    recreation_interval: Literal["never", "daily", "weekly", "monthly"]
    max_batch_images: Optional[int] = None
# Any supported sampling strategy. The models are told apart by their `type`
# literal, which pydantic uses as the discriminator when parsing.
ActiveLearningStrategyType = Annotated[
    Union[
        RandomSamplingConfig, CloseToThresholdSampling,
        ClassesBasedSampling, DetectionsBasedSampling,
    ],
    Field(discriminator="type"),
]
class EnabledActiveLearningConfiguration(BaseModel):
    """Complete Active Learning configuration that can be supplied inline.

    ``jpeg_compression_level`` defaults to 95 and must lie in [1, 100];
    ``max_image_size`` is an optional pair of positive ints.
    """

    enabled: bool
    persist_predictions: bool
    sampling_strategies: List[ActiveLearningStrategyType]
    batching_strategy: ActiveLearningBatchingStrategy
    tags: List[str] = Field(default_factory=list)
    max_image_size: Optional[Tuple[PositiveInt, PositiveInt]] = None
    jpeg_compression_level: int = 95

    @field_validator("jpeg_compression_level")
    @classmethod
    def validate_json_compression_level(cls, value: Any) -> int:
        # NOTE: despite "json" in the name, this validates the JPEG
        # compression level; the name is kept for compatibility.
        validate_field_has_given_type(
            field_name="jpeg_compression_level", allowed_types=[int], value=value
        )
        if not 0 < value <= 100:
            raise ValueError("`jpeg_compression_level` must be in range [1, 100]")
        return value
class ActiveLearningDataCollector(BaseModel, StepInterface):
    """Workflow step definition for Active Learning data collection.

    The step takes an ``image`` selector and a ``predictions`` selector and
    declares no outputs. ``predictions`` must point at the output listed for the
    referenced step's type in ``ACTIVE_LEARNING_DATA_COLLECTOR_ELIGIBLE_SELECTORS``,
    and that step must use the same ``image`` reference as this one.
    ``target_dataset``, ``target_dataset_api_key`` and ``disable_active_learning``
    may be literals or selectors of inference parameters.
    """

    type: Literal["ActiveLearningDataCollector"]
    name: str
    image: str
    predictions: str
    target_dataset: str
    target_dataset_api_key: Optional[str] = Field(default=None)
    disable_active_learning: Union[bool, str] = Field(default=False)
    active_learning_configuration: Optional[
        Union[EnabledActiveLearningConfiguration, DisabledActiveLearningConfiguration]
    ] = Field(default=None)

    @field_validator("image")
    @classmethod
    def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]:
        validate_image_is_valid_selector(value=value)
        return value

    @field_validator("predictions")
    @classmethod
    def predictions_must_hold_selector(cls, value: Any) -> str:
        if not is_selector(selector_or_value=value):
            raise ValueError("`predictions` field can only contain selector values")
        return value

    @field_validator("target_dataset")
    @classmethod
    def validate_target_dataset_field(cls, value: Any) -> str:
        validate_field_is_selector_or_has_given_type(
            value=value, field_name="target_dataset", allowed_types=[str]
        )
        return value

    @field_validator("target_dataset_api_key")
    @classmethod
    def validate_target_dataset_api_key_field(cls, value: Any) -> Optional[str]:
        # The field is declared Optional[str] and the runtime binding expects a
        # str, so a literal key string (or None) must be accepted here as well
        # as a selector.
        validate_field_is_selector_or_has_given_type(
            value=value,
            field_name="target_dataset_api_key",
            allowed_types=[str, type(None)],
        )
        return value

    @field_validator("disable_active_learning")
    @classmethod
    def validate_boolean_flags_or_selectors(cls, value: Any) -> Union[str, bool]:
        validate_field_is_selector_or_has_given_type(
            value=value, field_name="disable_active_learning", allowed_types=[bool]
        )
        return value

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        return {
            "image",
            "predictions",
            "target_dataset",
            "target_dataset_api_key",
            "disable_active_learning",
        }

    def get_output_names(self) -> Set[str]:
        # The step exposes nothing that other steps could reference.
        return set()

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Check that the selector in ``field_name`` may refer to ``input_step``.

        Raises:
            ExecutionGraphError: if the field does not hold a selector, or if a
                ``predictions`` selector refers to an ineligible step, a wrong
                output of that step, or a step working on a different image.
        """
        selector = getattr(self, field_name)
        if not is_selector(selector_or_value=selector):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        if field_name == "predictions":
            input_step_type = input_step.get_type()
            expected_last_selector_chunk = (
                ACTIVE_LEARNING_DATA_COLLECTOR_ELIGIBLE_SELECTORS.get(input_step_type)
            )
            if expected_last_selector_chunk is None:
                raise ExecutionGraphError(
                    f"Attempted to validate predictions selector of {self.name} step, but input step of type: "
                    f"{input_step_type} does not match by type."
                )
            if get_last_selector_chunk(selector) != expected_last_selector_chunk:
                raise ExecutionGraphError(
                    f"It is only allowed to refer to {input_step_type} step output named {expected_last_selector_chunk}. "
                    f"Reference that was found: {selector}"
                )
            # Steps without an `image` attribute are treated as matching.
            input_step_image = getattr(input_step, "image", self.image)
            if input_step_image != self.image:
                raise ExecutionGraphError(
                    f"ActiveLearningDataCollector step refers to input step that uses reference to different image. "
                    f"ActiveLearningDataCollector step image: {self.image}. Input step (of type {input_step_type}) "
                    f"uses {input_step_image}."
                )
        validate_selector_holds_image(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
        )
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={
                "target_dataset",
                "target_dataset_api_key",
                "disable_active_learning",
            },
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Validate a runtime value bound to ``field_name``.

        Raises:
            VariableTypeError: the error type passed to the type-checking
                helpers when the bound value has the wrong type.
        """
        if field_name == "image":
            validate_image_biding(value=value)
        elif field_name == "disable_active_learning":
            validate_field_has_given_type(
                field_name=field_name,
                allowed_types=[bool],
                value=value,
                error=VariableTypeError,
            )
        elif field_name == "target_dataset":
            validate_field_has_given_type(
                field_name=field_name,
                allowed_types=[str],
                value=value,
                error=VariableTypeError,
            )
        elif field_name == "target_dataset_api_key":
            # Declared Optional[str] (default None), so None is a valid binding.
            if value is None:
                return None
            validate_field_has_given_type(
                field_name=field_name,
                allowed_types=[str],
                value=value,
                error=VariableTypeError,
            )
class YoloWorld(BaseModel, StepInterface):
    """Workflow step definition for the YOLO-World model.

    Fields:
      * ``class_names`` - a list of strings, or a selector;
      * ``version`` - one of "s", "m", "l" or None (default "l"), or a selector;
      * ``confidence`` - a number in [0, 1] or None (default 0.4), or a selector.

    Outputs that other steps may reference: ``predictions``, ``parent_id``,
    ``image`` and ``prediction_type``.
    """

    type: Literal["YoloWorld"]
    name: str
    image: str
    class_names: Union[str, List[str]]
    version: Optional[str] = "l"
    confidence: Union[Optional[float], str] = 0.4

    @field_validator("image")
    @classmethod
    def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]:
        validate_image_is_valid_selector(value=value)
        return value

    @field_validator("class_names")
    @classmethod
    def validate_class_names(cls, value: Any) -> Union[str, List[str]]:
        if is_selector(selector_or_value=value):
            return value
        if not isinstance(value, list):
            raise ValueError(
                "`class_names` field given must be selector or list of strings"
            )
        validate_field_is_list_of_string(value=value, field_name="class_names")
        return value

    @field_validator("version")
    @classmethod
    def validate_model_version(cls, value: Any) -> Optional[str]:
        validate_field_is_selector_or_one_of_values(
            field_name="version",
            selected_values={"s", "m", "l", None},
            value=value,
        )
        return value

    @field_validator("confidence")
    @classmethod
    def field_must_be_selector_or_number_from_zero_to_one(
        cls, value: Any
    ) -> Union[Optional[float], str]:
        # None is passed through untouched; everything else is range-checked.
        if value is not None:
            validate_field_is_in_range_zero_one_or_empty_or_selector(
                value=value, field_name="confidence"
            )
        return value

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        return {"confidence", "version", "class_names", "image"}

    def get_output_names(self) -> Set[str]:
        return {"prediction_type", "image", "parent_id", "predictions"}

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        referenced = getattr(self, field_name)
        if not is_selector(selector_or_value=referenced):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        validate_selector_holds_image(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
        )
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"class_names", "version", "confidence"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        # Fields not listed below need no runtime type check.
        if field_name == "image":
            validate_image_biding(value=value)
            return None
        if field_name == "class_names":
            validate_field_is_list_of_string(
                value=value,
                field_name=field_name,
                error=VariableTypeError,
            )
            return None
        if field_name == "version":
            validate_field_is_one_of_selected_values(
                value=value,
                field_name=field_name,
                selected_values={"s", "m", "l", None},
                error=VariableTypeError,
            )
            return None
        if field_name == "confidence":
            validate_value_is_empty_or_number_in_range_zero_one(
                value=value,
                field_name=field_name,
                error=VariableTypeError,
            )
        return None