Fucius's picture
Upload 422 files
df6c67d verified
raw
history blame
52.6 kB
from abc import ABCMeta, abstractmethod
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Set, Tuple, Union
from pydantic import (
BaseModel,
ConfigDict,
Field,
NonNegativeInt,
PositiveInt,
confloat,
field_validator,
)
from inference.enterprise.workflows.entities.base import GraphNone
from inference.enterprise.workflows.entities.validators import (
get_last_selector_chunk,
is_selector,
validate_field_has_given_type,
validate_field_is_empty_or_selector_or_list_of_string,
validate_field_is_in_range_zero_one_or_empty_or_selector,
validate_field_is_list_of_selectors,
validate_field_is_list_of_string,
validate_field_is_one_of_selected_values,
validate_field_is_selector_or_has_given_type,
validate_field_is_selector_or_one_of_values,
validate_image_biding,
validate_image_is_valid_selector,
validate_selector_holds_detections,
validate_selector_holds_image,
validate_selector_is_inference_parameter,
validate_value_is_empty_or_number_in_range_zero_one,
validate_value_is_empty_or_positive_number,
validate_value_is_empty_or_selector_or_positive_number,
)
from inference.enterprise.workflows.errors import (
ExecutionGraphError,
InvalidStepInputDetected,
VariableTypeError,
)
class StepInterface(GraphNone, metaclass=ABCMeta):
    """Contract that every workflow step definition implements.

    A step declares which of its fields act as inputs and which outputs other
    steps may reference. It also knows how to validate both the selectors wired
    into its fields and the concrete values bound to them at runtime.
    """

    @abstractmethod
    def get_input_names(self) -> Set[str]:
        """Return the names of every field meant to represent a step input."""

    @abstractmethod
    def get_output_names(self) -> Set[str]:
        """Return the names of every output that other steps may refer to."""

    @abstractmethod
    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Check that the node referenced by the selector in `field_name` is acceptable.

        `index` selects the element when the field holds a list of selectors.
        """

    @abstractmethod
    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Check a value about to be bound to `field_name` during graph execution.

        This covers values passed by the client at invocation time as well as
        values produced while the graph is being executed.
        """
class RoboflowModel(BaseModel, StepInterface, metaclass=ABCMeta):
    """Common base for workflow steps that run a Roboflow model.

    Concrete model steps inherit the `image` / `model_id` /
    `disable_active_learning` fields from here. They extend the input names,
    output names and validation hooks with their own parameters.
    """

    # Allows the `model_id` field name, which starts with the `model_` prefix
    # that pydantic protects by default.
    model_config = ConfigDict(protected_namespaces=())

    type: Literal["RoboflowModel"]
    name: str
    image: Union[str, List[str]]
    model_id: str
    disable_active_learning: Union[Optional[bool], str] = Field(default=False)

    @field_validator("image")
    @classmethod
    def validate_image(cls, value: Any) -> Union[str, List[str]]:
        """Delegate `image` checks to `validate_image_is_valid_selector`."""
        validate_image_is_valid_selector(value=value)
        return value

    @field_validator("model_id")
    @classmethod
    def model_id_must_be_selector_or_str(cls, value: Any) -> str:
        """Accept a selector or a literal `str` as `model_id`."""
        validate_field_is_selector_or_has_given_type(
            field_name="model_id",
            allowed_types=[str],
            value=value,
        )
        return value

    @field_validator("disable_active_learning")
    @classmethod
    def disable_active_learning_must_be_selector_or_bool(
        cls, value: Any
    ) -> Union[Optional[bool], str]:
        """Accept a selector, a `bool` or `None` as `disable_active_learning`."""
        validate_field_is_selector_or_has_given_type(
            value=value,
            field_name="disable_active_learning",
            allowed_types=[type(None), bool],
        )
        return value

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        # A fresh set each call: subclasses add their own names to the result.
        return set(("image", "model_id", "disable_active_learning"))

    def get_output_names(self) -> Set[str]:
        return set(("prediction_type",))

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Validate the node referenced by the selector stored in `field_name`.

        Raises:
            ExecutionGraphError: when `field_name` does not currently hold a selector.
        """
        current_value = getattr(self, field_name)
        if not is_selector(selector_or_value=current_value):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        validate_selector_holds_image(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
        )
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"model_id", "disable_active_learning"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Validate a runtime value bound to one of this base class's fields."""
        if field_name == "image":
            validate_image_biding(value=value)
            return None
        required_types = {
            "model_id": [str],
            "disable_active_learning": [bool],
        }.get(field_name)
        if required_types is None:
            return None
        validate_field_has_given_type(
            field_name=field_name,
            allowed_types=required_types,
            value=value,
            error=VariableTypeError,
        )
class ClassificationModel(RoboflowModel):
    """Workflow step running a Roboflow classification model.

    Adds a `confidence` parameter (literal number in [0, 1] or selector) to
    the fields inherited from `RoboflowModel`.
    """

    type: Literal["ClassificationModel"]
    confidence: Union[Optional[float], str] = Field(default=0.4)

    @field_validator("confidence")
    @classmethod
    def confidence_must_be_selector_or_number(
        cls, value: Any
    ) -> Union[Optional[float], str]:
        # Pass the field name explicitly, like every other step in this module,
        # so the validator can name the offending field (presumably in its error).
        validate_field_is_in_range_zero_one_or_empty_or_selector(
            value=value, field_name="confidence"
        )
        return value

    def get_input_names(self) -> Set[str]:
        inputs = super().get_input_names()
        inputs.add("confidence")
        return inputs

    def get_output_names(self) -> Set[str]:
        outputs = super().get_output_names()
        outputs.update(["predictions", "top", "confidence", "parent_id"])
        return outputs

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Run the base selector checks, then the `confidence` parameter check."""
        super().validate_field_selector(field_name=field_name, input_step=input_step)
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"confidence"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Validate a runtime value bound to `field_name`.

        Raises:
            VariableTypeError: when `confidence` is None or outside the checked range.
        """
        super().validate_field_binding(field_name=field_name, value=value)
        if field_name == "confidence":
            if value is None:
                raise VariableTypeError("Parameter `confidence` cannot be None")
            validate_value_is_empty_or_number_in_range_zero_one(
                value=value, field_name=field_name, error=VariableTypeError
            )
class MultiLabelClassificationModel(RoboflowModel):
    """Workflow step running a Roboflow multi-label classification model.

    Adds a `confidence` parameter (literal number in [0, 1] or selector) to
    the fields inherited from `RoboflowModel`.
    """

    type: Literal["MultiLabelClassificationModel"]
    confidence: Union[Optional[float], str] = Field(default=0.4)

    @field_validator("confidence")
    @classmethod
    def confidence_must_be_selector_or_number(
        cls, value: Any
    ) -> Union[Optional[float], str]:
        # Pass the field name explicitly, like every other step in this module,
        # so the validator can name the offending field (presumably in its error).
        validate_field_is_in_range_zero_one_or_empty_or_selector(
            value=value, field_name="confidence"
        )
        return value

    def get_input_names(self) -> Set[str]:
        inputs = super().get_input_names()
        inputs.add("confidence")
        return inputs

    def get_output_names(self) -> Set[str]:
        outputs = super().get_output_names()
        outputs.update(["predictions", "predicted_classes", "parent_id"])
        return outputs

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Run the base selector checks, then the `confidence` parameter check."""
        super().validate_field_selector(field_name=field_name, input_step=input_step)
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"confidence"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Validate a runtime value bound to `field_name`.

        Raises:
            VariableTypeError: when `confidence` is None or outside the checked range.
        """
        super().validate_field_binding(field_name=field_name, value=value)
        if field_name == "confidence":
            if value is None:
                raise VariableTypeError("Parameter `confidence` cannot be None")
            validate_value_is_empty_or_number_in_range_zero_one(
                value=value, field_name=field_name, error=VariableTypeError
            )
class ObjectDetectionModel(RoboflowModel):
    """Workflow step running a Roboflow object-detection model.

    On top of the `RoboflowModel` fields, it exposes `class_agnostic_nms`,
    `class_filter`, `confidence`, `iou_threshold`, `max_detections` and
    `max_candidates`. Each of them may be a literal value or a selector
    resolved at runtime.
    """

    type: Literal["ObjectDetectionModel"]
    class_agnostic_nms: Union[Optional[bool], str] = Field(default=False)
    class_filter: Union[Optional[List[str]], str] = Field(default=None)
    confidence: Union[Optional[float], str] = Field(default=0.4)
    iou_threshold: Union[Optional[float], str] = Field(default=0.3)
    max_detections: Union[Optional[int], str] = Field(default=300)
    max_candidates: Union[Optional[int], str] = Field(default=3000)

    @field_validator("class_agnostic_nms")
    @classmethod
    def class_agnostic_nms_must_be_selector_or_bool(
        cls, value: Any
    ) -> Union[Optional[bool], str]:
        validate_field_is_selector_or_has_given_type(
            field_name="class_agnostic_nms",
            allowed_types=[type(None), bool],
            value=value,
        )
        return value

    @field_validator("class_filter")
    @classmethod
    def class_filter_must_be_selector_or_list_of_string(
        cls, value: Any
    ) -> Union[Optional[List[str]], str]:
        validate_field_is_empty_or_selector_or_list_of_string(
            value=value, field_name="class_filter"
        )
        return value

    @field_validator("confidence", "iou_threshold")
    @classmethod
    def field_must_be_selector_or_number_from_zero_to_one(
        cls, value: Any
    ) -> Union[Optional[float], str]:
        validate_field_is_in_range_zero_one_or_empty_or_selector(
            value=value, field_name="confidence | iou_threshold"
        )
        return value

    @field_validator("max_detections", "max_candidates")
    @classmethod
    def field_must_be_selector_or_positive_number(
        cls, value: Any
    ) -> Union[Optional[int], str]:
        validate_value_is_empty_or_selector_or_positive_number(
            value=value,
            field_name="max_detections | max_candidates",
        )
        return value

    def get_input_names(self) -> Set[str]:
        inputs = super().get_input_names()
        inputs.update(
            [
                "class_agnostic_nms",
                "class_filter",
                "confidence",
                "iou_threshold",
                "max_detections",
                "max_candidates",
            ]
        )
        return inputs

    def get_output_names(self) -> Set[str]:
        outputs = super().get_output_names()
        outputs.update(["predictions", "parent_id", "image"])
        return outputs

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Run the base selector checks, then the detection-parameter check."""
        super().validate_field_selector(field_name=field_name, input_step=input_step)
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={
                "class_agnostic_nms",
                "class_filter",
                "confidence",
                "iou_threshold",
                "max_detections",
                "max_candidates",
            },
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Validate a runtime value bound to `field_name`.

        Raises:
            VariableTypeError: when the value is None (except for `class_filter`,
                whose declared default is None) or fails the type/range check
                for its field.
        """
        super().validate_field_binding(field_name=field_name, value=value)
        if field_name == "class_filter" and value is None:
            # Must run before the generic None guard below. Otherwise a None
            # `class_filter` (this field's own default) is always rejected.
            return None
        if value is None:
            raise VariableTypeError(f"Parameter `{field_name}` cannot be None")
        if field_name == "class_agnostic_nms":
            validate_field_has_given_type(
                field_name=field_name,
                allowed_types=[bool],
                value=value,
                error=VariableTypeError,
            )
        elif field_name == "class_filter":
            validate_field_is_list_of_string(
                value=value, field_name=field_name, error=VariableTypeError
            )
        elif field_name == "confidence" or field_name == "iou_threshold":
            validate_value_is_empty_or_number_in_range_zero_one(
                value=value,
                field_name=field_name,
                error=VariableTypeError,
            )
        elif field_name == "max_detections" or field_name == "max_candidates":
            validate_value_is_empty_or_positive_number(
                value=value,
                field_name=field_name,
                error=VariableTypeError,
            )
class KeypointsDetectionModel(ObjectDetectionModel):
    """Object-detection step variant with an extra `keypoint_confidence` parameter."""

    type: Literal["KeypointsDetectionModel"]
    keypoint_confidence: Union[Optional[float], str] = Field(default=0.0)

    @field_validator("keypoint_confidence")
    @classmethod
    def keypoint_confidence_field_must_be_selector_or_number_from_zero_to_one(
        cls, value: Any
    ) -> Union[Optional[float], str]:
        validate_field_is_in_range_zero_one_or_empty_or_selector(
            value=value, field_name="keypoint_confidence"
        )
        return value

    def get_input_names(self) -> Set[str]:
        return super().get_input_names() | {"keypoint_confidence"}

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Run the object-detection selector checks plus the keypoint parameter check."""
        super().validate_field_selector(field_name=field_name, input_step=input_step)
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"keypoint_confidence"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Validate a runtime value; `keypoint_confidence` gets a [0, 1] range check."""
        super().validate_field_binding(field_name=field_name, value=value)
        if field_name != "keypoint_confidence":
            return None
        validate_value_is_empty_or_number_in_range_zero_one(
            value=value,
            field_name=field_name,
            error=VariableTypeError,
        )
# Literal values accepted for `InstanceSegmentationModel.mask_decode_mode`.
DECODE_MODES = set(("fast", "tradeoff", "accurate"))
class InstanceSegmentationModel(ObjectDetectionModel):
    """Object-detection step variant with mask decoding parameters.

    Adds `mask_decode_mode` (one of `DECODE_MODES` or a selector) and
    `tradeoff_factor` (number in [0, 1] or a selector).
    """

    type: Literal["InstanceSegmentationModel"]
    mask_decode_mode: Optional[str] = Field(default="accurate")
    tradeoff_factor: Union[Optional[float], str] = Field(default=0.0)

    @field_validator("mask_decode_mode")
    @classmethod
    def mask_decode_mode_must_be_selector_or_one_of_allowed_values(
        cls, value: Any
    ) -> Optional[str]:
        validate_field_is_selector_or_one_of_values(
            value=value,
            field_name="mask_decode_mode",
            selected_values=DECODE_MODES,
        )
        return value

    @field_validator("tradeoff_factor")
    @classmethod
    def tradeoff_factor_must_be_selector_or_number_from_zero_to_one(
        cls, value: Any
    ) -> Union[Optional[float], str]:
        # Deliberately NOT named `field_must_be_selector_or_number_from_zero_to_one`.
        # In pydantic, a subclass validator with the same name as a parent's
        # replaces it, which would silently drop the inherited
        # `confidence` / `iou_threshold` validation from `ObjectDetectionModel`.
        validate_field_is_in_range_zero_one_or_empty_or_selector(
            value=value, field_name="tradeoff_factor"
        )
        return value

    def get_input_names(self) -> Set[str]:
        inputs = super().get_input_names()
        inputs.update(["mask_decode_mode", "tradeoff_factor"])
        return inputs

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Run the object-detection selector checks plus the mask-parameter check."""
        super().validate_field_selector(field_name=field_name, input_step=input_step)
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"mask_decode_mode", "tradeoff_factor"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Validate a runtime value bound to `field_name`.

        Raises:
            VariableTypeError: when the mask parameters fail their checks.
        """
        super().validate_field_binding(field_name=field_name, value=value)
        if field_name == "mask_decode_mode":
            validate_field_is_one_of_selected_values(
                value=value,
                field_name=field_name,
                selected_values=DECODE_MODES,
                error=VariableTypeError,
            )
        elif field_name == "tradeoff_factor":
            validate_value_is_empty_or_number_in_range_zero_one(
                value=value,
                field_name=field_name,
                error=VariableTypeError,
            )
class OCRModel(BaseModel, StepInterface):
    """Workflow step definition for OCR over the image(s) referenced by `image`."""

    type: Literal["OCRModel"]
    name: str
    image: Union[str, List[str]]

    @field_validator("image")
    @classmethod
    def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]:
        validate_image_is_valid_selector(value=value)
        return value

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        return set(("image",))

    def get_output_names(self) -> Set[str]:
        return set(("result", "parent_id", "prediction_type"))

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Ensure `field_name` holds a selector and check the node it points at.

        Raises:
            ExecutionGraphError: when the field does not hold a selector.
        """
        referenced = getattr(self, field_name)
        if not is_selector(selector_or_value=referenced):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        validate_selector_holds_image(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Only `image` bindings are checked; other fields are accepted as-is."""
        if field_name != "image":
            return None
        validate_image_biding(value=value)
class Crop(BaseModel, StepInterface):
    """Step producing `crops` from the `image` input and a `detections` selector."""

    type: Literal["Crop"]
    name: str
    image: Union[str, List[str]]
    detections: str

    @field_validator("image")
    @classmethod
    def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]:
        validate_image_is_valid_selector(value=value)
        return value

    @field_validator("detections")
    @classmethod
    def detections_must_hold_selector(cls, value: Any) -> str:
        """`detections` may only be given as a selector, never as a literal."""
        if is_selector(selector_or_value=value):
            return value
        raise ValueError("`detections` field can only contain selector values")

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        return set(("image", "detections"))

    def get_output_names(self) -> Set[str]:
        return set(("crops", "parent_id"))

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Ensure `field_name` holds a selector and run both referenced-node checks.

        Raises:
            ExecutionGraphError: when the field does not hold a selector.
        """
        referenced = getattr(self, field_name)
        if not is_selector(selector_or_value=referenced):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        validate_selector_holds_image(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
        )
        validate_selector_holds_detections(
            step_name=self.name,
            image_selector=self.image,
            detections_selector=self.detections,
            field_name=field_name,
            input_step=input_step,
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Only `image` bindings are checked; other fields are accepted as-is."""
        if field_name != "image":
            return None
        validate_image_biding(value=value)
class Operator(Enum):
    """Comparison operators usable in `Condition` steps and detection filters."""

    # Equality checks.
    EQUAL = "equal"
    NOT_EQUAL = "not_equal"
    # Ordering checks.
    LOWER_THAN = "lower_than"
    GREATER_THAN = "greater_than"
    LOWER_OR_EQUAL_THAN = "lower_or_equal_than"
    GREATER_OR_EQUAL_THAN = "greater_or_equal_than"
    # Membership check.
    IN = "in"
class Condition(BaseModel, StepInterface):
    """Branching step comparing `left` with `right` using `operator`.

    `step_if_true` / `step_if_false` name the steps presumably selected
    depending on the comparison outcome (the executor is not visible here).
    """

    type: Literal["Condition"]
    name: str
    left: Union[float, int, bool, str, list, set]
    operator: Operator
    right: Union[float, int, bool, str, list, set]
    step_if_true: str
    step_if_false: str

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        return set(("left", "right"))

    def get_output_names(self) -> Set[str]:
        # A condition exposes no outputs for other steps to reference.
        return set()

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Reject non-selector fields and operands wired to an `InferenceImage`.

        Raises:
            ExecutionGraphError: when the field does not hold a selector.
            InvalidStepInputDetected: when `left`/`right` points at an `InferenceImage`.
        """
        if not is_selector(selector_or_value=getattr(self, field_name)):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        input_type = input_step.get_type()
        is_operand = field_name in {"left", "right"}
        if not is_operand or input_type != "InferenceImage":
            return None
        raise InvalidStepInputDetected(
            f"Field {field_name} of step {self.type} comes from invalid input type: {input_type}. "
            f"Expected: anything else than `InferenceImage`"
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Runtime values of a condition are not validated."""
        return None
class BinaryOperator(Enum):
    """Logical connective joining two detection filter definitions."""

    OR = "or"  # either side may match
    AND = "and"  # both sides must match
class DetectionFilterDefinition(BaseModel):
    """Single filter rule: presumably compares detection attribute `field_name`
    against `reference_value` using `operator` (evaluation lives elsewhere)."""

    type: Literal["DetectionFilterDefinition"]
    field_name: str
    operator: Operator
    reference_value: Union[float, int, bool, str, list, set]
class CompoundDetectionFilterDefinition(BaseModel):
    """Two simple filter rules joined with a logical `BinaryOperator`."""

    type: Literal["CompoundDetectionFilterDefinition"]
    left: DetectionFilterDefinition
    operator: BinaryOperator
    right: DetectionFilterDefinition
class DetectionFilter(BaseModel, StepInterface):
    """Step applying `filter_definition` to detections referenced by `predictions`.

    `filter_definition` is a tagged union discriminated by its `type` field.
    """

    type: Literal["DetectionFilter"]
    name: str
    predictions: str
    filter_definition: Annotated[
        Union[DetectionFilterDefinition, CompoundDetectionFilterDefinition],
        Field(discriminator="type"),
    ]

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        return set(("predictions",))

    def get_output_names(self) -> Set[str]:
        return set(("predictions", "parent_id", "image", "prediction_type"))

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Ensure `field_name` holds a selector pointing at detections.

        Raises:
            ExecutionGraphError: when the field does not hold a selector.
        """
        referenced = getattr(self, field_name)
        if not is_selector(selector_or_value=referenced):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        validate_selector_holds_detections(
            step_name=self.name,
            image_selector=None,
            detections_selector=self.predictions,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"predictions"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """No runtime bindings are validated for this step."""
        return None
class DetectionOffset(BaseModel, StepInterface):
    """Step taking `predictions` plus integer (or selector) `offset_x` / `offset_y`."""

    type: Literal["DetectionOffset"]
    name: str
    predictions: str
    offset_x: Union[int, str]
    offset_y: Union[int, str]

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        return set(("predictions", "offset_x", "offset_y"))

    def get_output_names(self) -> Set[str]:
        return set(("predictions", "parent_id", "image", "prediction_type"))

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Ensure `field_name` holds a selector, then run detection/parameter checks.

        Raises:
            ExecutionGraphError: when the field does not hold a selector.
        """
        referenced = getattr(self, field_name)
        if not is_selector(selector_or_value=referenced):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        validate_selector_holds_detections(
            step_name=self.name,
            image_selector=None,
            detections_selector=self.predictions,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"predictions"},
        )
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"offset_x", "offset_y"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Offsets bound at runtime must be `int`; other fields are not checked."""
        if field_name not in {"offset_x", "offset_y"}:
            return None
        validate_field_has_given_type(
            field_name=field_name,
            value=value,
            allowed_types=[int],
            error=VariableTypeError,
        )
class AbsoluteStaticCrop(BaseModel, StepInterface):
    """Step cropping a fixed region, given in absolute integer coordinates.

    `x_center`, `y_center`, `width` and `height` are each an integer or a
    selector resolved at runtime.
    """

    type: Literal["AbsoluteStaticCrop"]
    name: str
    image: Union[str, List[str]]
    x_center: Union[int, str]
    y_center: Union[int, str]
    width: Union[int, str]
    height: Union[int, str]

    @field_validator("image")
    @classmethod
    def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]:
        validate_image_is_valid_selector(value=value)
        return value

    @field_validator("x_center", "y_center", "width", "height")
    @classmethod
    def validate_crops_coordinates(cls, value: Any) -> Union[int, str]:
        validate_value_is_empty_or_selector_or_positive_number(
            value=value, field_name="x_center | y_center | width | height"
        )
        return value

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        return {"image", "x_center", "y_center", "width", "height"}

    def get_output_names(self) -> Set[str]:
        return {"crops", "parent_id"}

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Ensure `field_name` holds a selector, then run image/parameter checks.

        Raises:
            ExecutionGraphError: when the field does not hold a selector.
        """
        if not is_selector(selector_or_value=getattr(self, field_name)):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        validate_selector_holds_image(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
        )
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"x_center", "y_center", "width", "height"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Validate a runtime value bound to `field_name`.

        Coordinates must be ints, or floats holding an integral value.

        Raises:
            VariableTypeError: when a coordinate is not integral. This includes
                NaN and infinities, which used to crash `round()` with
                ValueError/OverflowError instead.
        """
        if field_name == "image":
            validate_image_biding(value=value)
        if field_name in {"x_center", "y_center", "width", "height"}:
            # `bool` is an `int` subclass and is accepted, as before.
            # `float.is_integer()` matches `value == round(value)` for finite
            # floats and is False (not an exception) for NaN/inf.
            is_integral = issubclass(type(value), int) or (
                issubclass(type(value), float) and value.is_integer()
            )
            if not is_integral:
                raise VariableTypeError(
                    f"Field {field_name} of step {self.type} must be integer"
                )
class RelativeStaticCrop(BaseModel, StepInterface):
    """Step cropping a fixed region, given in float (relative) coordinates.

    `x_center`, `y_center`, `width` and `height` are each a `float` or a
    selector resolved at runtime.
    """

    type: Literal["RelativeStaticCrop"]
    name: str
    image: Union[str, List[str]]
    x_center: Union[float, str]
    y_center: Union[float, str]
    width: Union[float, str]
    height: Union[float, str]

    @field_validator("image")
    @classmethod
    def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]:
        validate_image_is_valid_selector(value=value)
        return value

    @field_validator("x_center", "y_center", "width", "height")
    @classmethod
    def detections_must_hold_selector(cls, value: Any) -> Union[float, str]:
        """Accept a float or a valid selector for each crop coordinate.

        The method name is kept as-is for backward compatibility even though it
        validates crop coordinates, not detections.
        """
        if issubclass(type(value), str):
            is_valid = is_selector(selector_or_value=value)
        else:
            is_valid = issubclass(type(value), float)
        if not is_valid:
            # Fixed typo: message used to read "float of valid selector".
            raise ValueError("Field must be either float or valid selector")
        return value

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        return {"image", "x_center", "y_center", "width", "height"}

    def get_output_names(self) -> Set[str]:
        return {"crops", "parent_id"}

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Ensure `field_name` holds a selector, then run image/parameter checks.

        Raises:
            ExecutionGraphError: when the field does not hold a selector.
        """
        if not is_selector(selector_or_value=getattr(self, field_name)):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        validate_selector_holds_image(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
        )
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"x_center", "y_center", "width", "height"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Validate `image` bindings and require `float` for coordinate bindings."""
        if field_name == "image":
            validate_image_biding(value=value)
        if field_name in {"x_center", "y_center", "width", "height"}:
            validate_field_has_given_type(
                field_name=field_name,
                value=value,
                allowed_types=[float],
                error=VariableTypeError,
            )
class ClipComparison(BaseModel, StepInterface):
    """Step producing `similarity` between `image` and `text` (a string or list of strings)."""

    type: Literal["ClipComparison"]
    name: str
    image: Union[str, List[str]]
    text: Union[str, List[str]]

    @field_validator("image")
    @classmethod
    def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]:
        validate_image_is_valid_selector(value=value)
        return value

    @field_validator("text")
    @classmethod
    def text_must_be_valid(cls, value: Any) -> Union[str, List[str]]:
        """Accept a selector, a plain string or a list of strings."""
        if is_selector(selector_or_value=value):
            return value
        if issubclass(type(value), list):
            validate_field_is_list_of_string(value=value, field_name="text")
            return value
        if issubclass(type(value), str):
            return value
        raise ValueError("`text` field given must be string or list of strings")

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Ensure `field_name` holds a selector, then run image/parameter checks.

        Raises:
            ExecutionGraphError: when the field does not hold a selector.
        """
        referenced = getattr(self, field_name)
        if not is_selector(selector_or_value=referenced):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        validate_selector_holds_image(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
        )
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={"text"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Validate `image` bindings; `text` must be a string or list of strings."""
        if field_name == "image":
            validate_image_biding(value=value)
        if field_name != "text":
            return None
        if issubclass(type(value), list):
            validate_field_is_list_of_string(
                value=value, field_name=field_name, error=VariableTypeError
            )
        elif not issubclass(type(value), str):
            validate_field_has_given_type(
                value=value,
                field_name=field_name,
                allowed_types=[str],
                error=VariableTypeError,
            )

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        return set(("image", "text"))

    def get_output_names(self) -> Set[str]:
        return set(("similarity", "parent_id", "predictions_type"))
class AggregationMode(Enum):
    """How `DetectionsConsensus` combines values from several sources."""

    AVERAGE = "average"  # mean of the values
    MAX = "max"  # largest value
    MIN = "min"  # smallest value
class DetectionsConsensus(BaseModel, StepInterface):
    """Step combining detections coming from several upstream steps.

    `predictions` holds one selector per contributing step (at least one).
    Other parameters may be literal values or selectors resolved at runtime.
    The `*_aggregation` fields pick an `AggregationMode`.
    """

    type: Literal["DetectionsConsensus"]
    name: str
    predictions: List[str]
    required_votes: Union[int, str]
    class_aware: Union[bool, str] = Field(default=True)
    iou_threshold: Union[float, str] = Field(default=0.3)
    confidence: Union[float, str] = Field(default=0.0)
    classes_to_consider: Optional[Union[List[str], str]] = Field(default=None)
    required_objects: Optional[Union[int, Dict[str, int], str]] = Field(default=None)
    presence_confidence_aggregation: AggregationMode = Field(
        default=AggregationMode.MAX
    )
    detections_merge_confidence_aggregation: AggregationMode = Field(
        default=AggregationMode.AVERAGE
    )
    detections_merge_coordinates_aggregation: AggregationMode = Field(
        default=AggregationMode.AVERAGE
    )

    @field_validator("predictions")
    @classmethod
    def predictions_must_be_list_of_selectors(cls, value: Any) -> List[str]:
        validate_field_is_list_of_selectors(value=value, field_name="predictions")
        if len(value) < 1:
            raise ValueError(
                "There must be at least 1 `predictions` selectors in consensus step"
            )
        return value

    @field_validator("required_votes")
    @classmethod
    def required_votes_must_be_selector_or_positive_integer(
        cls, value: Any
    ) -> Union[str, int]:
        if value is None:
            raise ValueError("Field `required_votes` is required.")
        validate_value_is_empty_or_selector_or_positive_number(
            value=value, field_name="required_votes"
        )
        return value

    @field_validator("class_aware")
    @classmethod
    def class_aware_must_be_selector_or_boolean(cls, value: Any) -> Union[str, bool]:
        validate_field_is_selector_or_has_given_type(
            value=value, field_name="class_aware", allowed_types=[bool]
        )
        return value

    @field_validator("iou_threshold", "confidence")
    @classmethod
    def field_must_be_selector_or_number_from_zero_to_one(
        cls, value: Any
    ) -> Union[str, float]:
        if value is None:
            raise ValueError("Fields `iou_threshold` and `confidence` cannot be None")
        validate_field_is_in_range_zero_one_or_empty_or_selector(
            value=value, field_name="iou_threshold | confidence"
        )
        return value

    @field_validator("classes_to_consider")
    @classmethod
    def classes_to_consider_must_be_empty_or_selector_or_list_of_strings(
        cls, value: Any
    ) -> Optional[Union[str, List[str]]]:
        validate_field_is_empty_or_selector_or_list_of_string(
            value=value, field_name="classes_to_consider"
        )
        return value

    @field_validator("required_objects")
    @classmethod
    def required_objects_field_must_be_valid(
        cls, value: Any
    ) -> Optional[Union[str, int, Dict[str, int]]]:
        """Accept None, a selector, a positive int or a dict of positive ints."""
        if value is None:
            return value
        validate_field_is_selector_or_has_given_type(
            value=value, field_name="required_objects", allowed_types=[int, dict]
        )
        if issubclass(type(value), int):
            validate_value_is_empty_or_positive_number(
                value=value, field_name="required_objects"
            )
        elif issubclass(type(value), dict):
            for k, v in value.items():
                if v is None:
                    raise ValueError(f"Field `required_objects[{k}]` must not be None.")
                validate_value_is_empty_or_positive_number(
                    value=v, field_name=f"required_objects[{k}]"
                )
        # Returned for every accepted shape, selector strings included.
        return value

    def get_input_names(self) -> Set[str]:
        return {
            "predictions",
            "required_votes",
            "class_aware",
            "iou_threshold",
            "confidence",
            "classes_to_consider",
            "required_objects",
        }

    def get_output_names(self) -> Set[str]:
        return {
            "parent_id",
            "predictions",
            "image",
            "object_present",
            "presence_confidence",
            "predictions_type",
        }

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Validate the node referenced by a selector of this step.

        For `predictions`, `index` selects which of the listed selectors to check.

        Raises:
            ExecutionGraphError: when the field (or `predictions[index]`) is not
                a selector, or when `index` is missing or out of range.
        """
        if field_name != "predictions" and not is_selector(
            selector_or_value=getattr(self, field_name)
        ):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        if field_name == "predictions":
            if index is None:
                raise ExecutionGraphError(
                    f"Attempted to validate selector value for field {field_name}, which requires multiple inputs, "
                    f"but `index` not provided."
                )
            if index >= len(self.predictions):
                # `>=`, not `>`: index == len(...) would otherwise raise IndexError below.
                raise ExecutionGraphError(
                    f"Attempted to validate selector value for field {field_name}[{index}], "
                    f"but step defines only {len(self.predictions)} `{field_name}` selectors."
                )
            if not is_selector(
                selector_or_value=self.predictions[index],
            ):
                raise ExecutionGraphError(
                    f"Attempted to validate selector value for field {field_name}[{index}], but field is not selector."
                )
            validate_selector_holds_detections(
                step_name=self.name,
                image_selector=None,
                detections_selector=self.predictions[index],
                field_name=field_name,
                input_step=input_step,
                applicable_fields={"predictions"},
            )
            return None
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={
                "required_votes",
                "class_aware",
                "iou_threshold",
                "confidence",
                "classes_to_consider",
                "required_objects",
            },
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Validate a runtime value bound to one of the consensus parameters.

        Raises:
            VariableTypeError: when the value has the wrong type or range.
        """
        if field_name == "required_votes":
            if value is None:
                raise VariableTypeError("Field `required_votes` cannot be None.")
            validate_value_is_empty_or_positive_number(
                value=value, field_name="required_votes", error=VariableTypeError
            )
        elif field_name == "class_aware":
            validate_field_has_given_type(
                field_name=field_name,
                allowed_types=[bool],
                value=value,
                error=VariableTypeError,
            )
        elif field_name in {"iou_threshold", "confidence"}:
            if value is None:
                raise VariableTypeError(f"Fields `{field_name}` cannot be None.")
            validate_value_is_empty_or_number_in_range_zero_one(
                value=value,
                field_name=field_name,
                error=VariableTypeError,
            )
        elif field_name == "classes_to_consider":
            if value is None:
                return None
            validate_field_is_list_of_string(
                value=value,
                field_name=field_name,
                error=VariableTypeError,
            )
        elif field_name == "required_objects":
            self._validate_required_objects_binding(value=value)
        return None

    def get_type(self) -> str:
        return self.type

    def _validate_required_objects_binding(self, value: Any) -> None:
        """Runtime check for `required_objects`: None, positive int, or dict of positive ints."""
        if value is None:
            return value
        validate_field_has_given_type(
            value=value,
            field_name="required_objects",
            allowed_types=[int, dict],
            error=VariableTypeError,
        )
        if issubclass(type(value), int):
            validate_value_is_empty_or_positive_number(
                value=value,
                field_name="required_objects",
                error=VariableTypeError,
            )
            return None
        for k, v in value.items():
            if v is None:
                raise VariableTypeError(
                    f"Field `required_objects[{k}]` must not be None."
                )
            validate_value_is_empty_or_positive_number(
                value=v,
                field_name=f"required_objects[{k}]",
                error=VariableTypeError,
            )
# Maps a step type to the name of its output that the active-learning data
# collector may reference.
ACTIVE_LEARNING_DATA_COLLECTOR_ELIGIBLE_SELECTORS = {
    **dict.fromkeys(
        (
            "ObjectDetectionModel",
            "KeypointsDetectionModel",
            "InstanceSegmentationModel",
            "DetectionFilter",
            "DetectionsConsensus",
            "DetectionOffset",
            "YoloWorld",
        ),
        "predictions",
    ),
    "ClassificationModel": "top",
}
class DisabledActiveLearningConfiguration(BaseModel):
    """Active-learning configuration whose only legal content is `enabled=False`."""

    enabled: bool

    @field_validator("enabled")
    @classmethod
    def ensure_only_false_is_valid(cls, value: Any) -> bool:
        if value is False:
            return value
        raise ValueError(
            "One can only specify enabled=False in `DisabledActiveLearningConfiguration`"
        )
class LimitDefinition(BaseModel):
    # One limit entry of a sampling strategy: `type` names the time window and
    # `value` is a strictly positive integer for that window (presumably the
    # maximum number of samples per window - enforced outside this module).
    type: Literal["minutely", "hourly", "daily"]
    value: PositiveInt
class RandomSamplingConfig(BaseModel):
    # Sampling strategy selected by `type: random`.
    # NOTE: `traffic_percentage` is a fraction constrained to [0, 1], not 0-100.
    type: Literal["random"]
    name: str
    traffic_percentage: confloat(ge=0.0, le=1.0)
    # Fresh empty list per instance (no shared mutable default).
    tags: List[str] = Field(default_factory=list)
    limits: List[LimitDefinition] = Field(default_factory=list)
class CloseToThresholdSampling(BaseModel):
    # Strategy selected by `type: close_to_threshold`. `probability`,
    # `threshold` and `epsilon` are all constrained to [0, 1]. Presumably a
    # prediction qualifies when its confidence lies within `epsilon` of
    # `threshold` - the sampling logic itself lives outside this module.
    type: Literal["close_to_threshold"]
    name: str
    probability: confloat(ge=0.0, le=1.0)
    threshold: confloat(ge=0.0, le=1.0)
    epsilon: confloat(ge=0.0, le=1.0)
    max_batch_images: Optional[int] = None
    only_top_classes: bool = True
    minimum_objects_close_to_threshold: int = 1
    # None means "no class filter".
    selected_class_names: Optional[List[str]] = None
    tags: List[str] = Field(default_factory=list)
    limits: List[LimitDefinition] = Field(default_factory=list)
class ClassesBasedSampling(BaseModel):
    # Strategy selected by `type: classes_based`. Unlike the other strategies,
    # `selected_class_names` is mandatory here.
    type: Literal["classes_based"]
    name: str
    probability: confloat(ge=0.0, le=1.0)
    selected_class_names: List[str]
    tags: List[str] = Field(default_factory=list)
    limits: List[LimitDefinition] = Field(default_factory=list)
class DetectionsBasedSampling(BaseModel):
    # Strategy selected by `type: detections_number_based`. `more_than` /
    # `less_than` are non-negative integers - presumably bounds on the number
    # of detections (optionally restricted to `selected_class_names`); the
    # sampling logic itself lives outside this module.
    type: Literal["detections_number_based"]
    name: str
    probability: confloat(ge=0.0, le=1.0)
    # Explicit `default=None`: under pydantic v2 an `Optional[...]` annotation
    # alone makes the field required (null allowed, omission not), which forced
    # configs to spell out both bounds. This matches the other Optional fields.
    more_than: Optional[NonNegativeInt] = Field(default=None)
    less_than: Optional[NonNegativeInt] = Field(default=None)
    selected_class_names: Optional[List[str]] = Field(default=None)
    tags: List[str] = Field(default_factory=list)
    limits: List[LimitDefinition] = Field(default_factory=list)
class ActiveLearningBatchingStrategy(BaseModel):
    # Batching settings for collected data: a name prefix, how often a new
    # batch is started (`recreation_interval`), and an optional size cap
    # (None = no cap declared here).
    batches_name_prefix: str
    recreation_interval: Literal["never", "daily", "weekly", "monthly"]
    max_batch_images: Optional[int] = None
# Any one sampling strategy config. `Field(discriminator="type")` makes this a
# pydantic discriminated union: the member model is chosen by the value of its
# `type` Literal field ("random", "close_to_threshold", "classes_based",
# "detections_number_based").
ActiveLearningStrategyType = Annotated[
    Union[
        RandomSamplingConfig,
        CloseToThresholdSampling,
        ClassesBasedSampling,
        DetectionsBasedSampling,
    ],
    Field(discriminator="type"),
]
class EnabledActiveLearningConfiguration(BaseModel):
    # Full active-learning setup: sampling strategies (discriminated by their
    # `type`), batching, tags and optional image size / JPEG quality settings.
    # NOTE(review): `enabled` is not restricted to True by this model.
    enabled: bool
    persist_predictions: bool
    sampling_strategies: List[ActiveLearningStrategyType]
    batching_strategy: ActiveLearningBatchingStrategy
    tags: List[str] = Field(default_factory=list)
    max_image_size: Optional[Tuple[PositiveInt, PositiveInt]] = None
    jpeg_compression_level: int = 95

    @field_validator("jpeg_compression_level")
    @classmethod
    def validate_json_compression_level(cls, value: Any) -> int:
        """Ensure ``jpeg_compression_level`` is an int within [1, 100].

        The method name says "json" but it validates the JPEG level; the name
        is kept unchanged for compatibility.

        Raises:
            ValueError: when the level is outside [1, 100].
        """
        validate_field_has_given_type(
            field_name="jpeg_compression_level", allowed_types=[int], value=value
        )
        if not 0 < value <= 100:
            raise ValueError("`jpeg_compression_level` must be in range [1, 100]")
        return value
class ActiveLearningDataCollector(BaseModel, StepInterface):
    # Workflow step declaring an `image` selector, a `predictions` selector and
    # a `target_dataset` for active-learning data collection. It exposes no
    # outputs; this model only declares and validates the step's inputs - the
    # collection itself is executed elsewhere.
    type: Literal["ActiveLearningDataCollector"]
    name: str
    image: str
    predictions: str
    target_dataset: str
    target_dataset_api_key: Optional[str] = Field(default=None)
    disable_active_learning: Union[bool, str] = Field(default=False)
    active_learning_configuration: Optional[
        Union[EnabledActiveLearningConfiguration, DisabledActiveLearningConfiguration]
    ] = Field(default=None)

    @field_validator("image")
    @classmethod
    def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]:
        """Reject anything that is not a valid image selector."""
        validate_image_is_valid_selector(value=value)
        return value

    @field_validator("predictions")
    @classmethod
    def predictions_must_hold_selector(cls, value: Any) -> str:
        """`predictions` must reference another step's output, never a literal."""
        if not is_selector(selector_or_value=value):
            raise ValueError("`predictions` field can only contain selector values")
        return value

    @field_validator("target_dataset")
    @classmethod
    def validate_target_dataset_field(cls, value: Any) -> str:
        """Accept a selector or a literal dataset name (``str``)."""
        validate_field_is_selector_or_has_given_type(
            value=value, field_name="target_dataset", allowed_types=[str]
        )
        return value

    @field_validator("target_dataset_api_key")
    @classmethod
    def validate_target_dataset_api_key_field(cls, value: Any) -> Optional[str]:
        """Accept a selector, a literal API key (``str``) or ``None``."""
        # `str` (not `bool`): the field is declared Optional[str] and the runtime
        # binding check below expects `str`; a bool can never be a valid key.
        validate_field_is_selector_or_has_given_type(
            value=value,
            field_name="target_dataset_api_key",
            allowed_types=[str, type(None)],
        )
        return value

    @field_validator("disable_active_learning")
    @classmethod
    def validate_boolean_flags_or_selectors(cls, value: Any) -> Union[str, bool]:
        """Accept a selector or a literal ``bool``."""
        validate_field_is_selector_or_has_given_type(
            value=value, field_name="disable_active_learning", allowed_types=[bool]
        )
        return value

    def get_type(self) -> str:
        return self.type

    def get_input_names(self) -> Set[str]:
        return {
            "image",
            "predictions",
            "target_dataset",
            "target_dataset_api_key",
            "disable_active_learning",
        }

    def get_output_names(self) -> Set[str]:
        # The step produces no outputs other steps can refer to.
        return set()

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Check that the selector in ``field_name`` points at a compatible node.

        For ``predictions`` the input step type must be listed in
        ``ACTIVE_LEARNING_DATA_COLLECTOR_ELIGIBLE_SELECTORS``, the selector must
        end with that step's eligible output name, and the step must work on
        the same image as this one.

        Raises:
            ExecutionGraphError: on any mismatch (or when the field holds no selector).
        """
        selector = getattr(self, field_name)
        if not is_selector(selector_or_value=selector):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        if field_name == "predictions":
            input_step_type = input_step.get_type()
            expected_last_selector_chunk = (
                ACTIVE_LEARNING_DATA_COLLECTOR_ELIGIBLE_SELECTORS.get(input_step_type)
            )
            if expected_last_selector_chunk is None:
                raise ExecutionGraphError(
                    f"Attempted to validate predictions selector of {self.name} step, but input step of type: "
                    f"{input_step_type} does not match by type."
                )
            if get_last_selector_chunk(selector) != expected_last_selector_chunk:
                raise ExecutionGraphError(
                    f"It is only allowed to refer to {input_step_type} step output named {expected_last_selector_chunk}. "
                    f"Reference that was found: {selector}"
                )
            # Steps without an `image` attribute default to our own image, so
            # they always pass this check.
            input_step_image = getattr(input_step, "image", self.image)
            if input_step_image != self.image:
                raise ExecutionGraphError(
                    f"ActiveLearningDataCollector step refers to input step that uses reference to different image. "
                    f"ActiveLearningDataCollector step image: {self.image}. Input step (of type {input_step_type}) "
                    f"uses {input_step_image}."
                )
        validate_selector_holds_image(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
        )
        validate_selector_is_inference_parameter(
            step_type=self.type,
            field_name=field_name,
            input_step=input_step,
            applicable_fields={
                "target_dataset",
                "target_dataset_api_key",
                "disable_active_learning",
            },
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Type-check a runtime value bound to a selector-backed field.

        Raises:
            VariableTypeError: when the bound value has the wrong type.
        """
        if field_name == "image":
            validate_image_biding(value=value)
        elif field_name == "disable_active_learning":
            validate_field_has_given_type(
                field_name=field_name,
                allowed_types=[bool],
                value=value,
                error=VariableTypeError,
            )
        elif field_name == "target_dataset":
            validate_field_has_given_type(
                field_name=field_name,
                allowed_types=[str],
                value=value,
                error=VariableTypeError,
            )
        elif field_name == "target_dataset_api_key":
            # None is accepted so a bound parameter can mean the same as the
            # field's literal default (no explicit key).
            validate_field_has_given_type(
                field_name=field_name,
                allowed_types=[str, type(None)],
                value=value,
                error=VariableTypeError,
            )
class YoloWorld(BaseModel, StepInterface):
    # Workflow step running a YOLO-World model on `image` for the requested
    # `class_names`. `image` must be a selector; `class_names`, `version` and
    # `confidence` may be literals or selectors of inference parameters.
    # Outputs: predictions, parent_id, image, prediction_type.
    type: Literal["YoloWorld"]
    name: str
    image: str
    class_names: Union[str, List[str]]
    version: Optional[str] = Field(default="l")
    confidence: Union[Optional[float], str] = Field(default=0.4)

    def get_type(self) -> str:
        return self.type

    @field_validator("image")
    @classmethod
    def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]:
        """Reject anything that is not a valid image selector."""
        validate_image_is_valid_selector(value=value)
        return value

    @field_validator("class_names")
    @classmethod
    def validate_class_names(cls, value: Any) -> Union[str, List[str]]:
        """Accept a selector or a literal list of strings."""
        # Selectors are resolved at runtime and type-checked in validate_field_binding.
        if is_selector(selector_or_value=value):
            return value
        if not isinstance(value, list):
            raise ValueError(
                "`class_names` field given must be selector or list of strings"
            )
        validate_field_is_list_of_string(value=value, field_name="class_names")
        return value

    @field_validator("version")
    @classmethod
    def validate_model_version(cls, value: Any) -> Optional[str]:
        """Accept a selector, None, or one of the model sizes "s" / "m" / "l"."""
        validate_field_is_selector_or_one_of_values(
            value=value,
            selected_values={None, "s", "m", "l"},
            field_name="version",
        )
        return value

    @field_validator("confidence")
    @classmethod
    def field_must_be_selector_or_number_from_zero_to_one(
        cls, value: Any
    ) -> Union[Optional[float], str]:
        """Accept None, a selector, or a number in [0, 1]."""
        if value is not None:
            validate_field_is_in_range_zero_one_or_empty_or_selector(
                value=value, field_name="confidence"
            )
        return value

    def get_input_names(self) -> Set[str]:
        return {"image", "class_names", "version", "confidence"}

    def get_output_names(self) -> Set[str]:
        return {"predictions", "parent_id", "image", "prediction_type"}

    def validate_field_selector(
        self, field_name: str, input_step: GraphNone, index: Optional[int] = None
    ) -> None:
        """Check that the selector held in ``field_name`` points at a compatible node.

        Raises:
            ExecutionGraphError: when the field does not hold a selector.
        """
        referenced = getattr(self, field_name)
        if not is_selector(selector_or_value=referenced):
            raise ExecutionGraphError(
                f"Attempted to validate selector value for field {field_name}, but field is not selector."
            )
        shared_arguments = {
            "step_type": self.type,
            "field_name": field_name,
            "input_step": input_step,
        }
        validate_selector_holds_image(**shared_arguments)
        validate_selector_is_inference_parameter(
            **shared_arguments,
            applicable_fields={"class_names", "version", "confidence"},
        )

    def validate_field_binding(self, field_name: str, value: Any) -> None:
        """Type-check a runtime value bound to a selector-backed field.

        Fields not listed below are accepted without checks.

        Raises:
            VariableTypeError: when the bound value has the wrong type.
        """
        if field_name == "image":
            validate_image_biding(value=value)
            return None
        if field_name == "class_names":
            validate_field_is_list_of_string(
                value=value,
                field_name=field_name,
                error=VariableTypeError,
            )
            return None
        if field_name == "version":
            validate_field_is_one_of_selected_values(
                value=value,
                field_name=field_name,
                selected_values={None, "s", "m", "l"},
                error=VariableTypeError,
            )
            return None
        if field_name == "confidence":
            validate_value_is_empty_or_number_in_range_zero_one(
                value=value,
                field_name=field_name,
                error=VariableTypeError,
            )
        return None