import asyncio
from copy import deepcopy
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
from uuid import uuid4

from inference.core.entities.requests.clip import ClipCompareRequest
from inference.core.entities.requests.doctr import DoctrOCRInferenceRequest
from inference.core.entities.requests.inference import (
    ClassificationInferenceRequest,
    InstanceSegmentationInferenceRequest,
    KeypointsDetectionInferenceRequest,
    ObjectDetectionInferenceRequest,
)
from inference.core.entities.requests.yolo_world import YOLOWorldInferenceRequest
from inference.core.env import (
    HOSTED_CLASSIFICATION_URL,
    HOSTED_CORE_MODEL_URL,
    HOSTED_DETECT_URL,
    HOSTED_INSTANCE_SEGMENTATION_URL,
    LOCAL_INFERENCE_API_URL,
    WORKFLOWS_REMOTE_API_TARGET,
    WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE,
    WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
)
from inference.core.managers.base import ModelManager
from inference.enterprise.workflows.complier.entities import StepExecutionMode
from inference.enterprise.workflows.complier.steps_executors.constants import (
    CENTER_X_KEY,
    CENTER_Y_KEY,
    ORIGIN_COORDINATES_KEY,
    ORIGIN_SIZE_KEY,
    PARENT_COORDINATES_SUFFIX,
)
from inference.enterprise.workflows.complier.steps_executors.types import (
    NextStepReference,
    OutputsLookup,
)
from inference.enterprise.workflows.complier.steps_executors.utils import (
    get_image,
    make_batches,
    resolve_parameter,
)
from inference.enterprise.workflows.complier.utils import construct_step_selector
from inference.enterprise.workflows.entities.steps import (
    ClassificationModel,
    ClipComparison,
    InstanceSegmentationModel,
    KeypointsDetectionModel,
    MultiLabelClassificationModel,
    ObjectDetectionModel,
    OCRModel,
    RoboflowModel,
    StepInterface,
    YoloWorld,
)
from inference_sdk import InferenceConfiguration, InferenceHTTPClient

MODEL_TYPE2PREDICTION_TYPE = {
    "ClassificationModel": "classification",
    "MultiLabelClassificationModel": "classification",
    "ObjectDetectionModel": "object-detection",
    "InstanceSegmentationModel": "instance-segmentation",
    "KeypointsDetectionModel": "keypoint-detection",
}


async def run_roboflow_model_step(
    step: RoboflowModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    model_manager: ModelManager,
    api_key: Optional[str],
    step_execution_mode: StepExecutionMode,
) -> Tuple[NextStepReference, OutputsLookup]:
    model_id = resolve_parameter(
        selector_or_value=step.model_id,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    image = get_image(
        step=step,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    if step_execution_mode is StepExecutionMode.LOCAL:
        serialised_result = await get_roboflow_model_predictions_locally(
            image=image,
            model_id=model_id,
            step=step,
            runtime_parameters=runtime_parameters,
            outputs_lookup=outputs_lookup,
            model_manager=model_manager,
            api_key=api_key,
        )
    else:
        serialised_result = await get_roboflow_model_predictions_from_remote_api(
            image=image,
            model_id=model_id,
            step=step,
            runtime_parameters=runtime_parameters,
            outputs_lookup=outputs_lookup,
            api_key=api_key,
        )
    serialised_result = attach_prediction_type_info(
        results=serialised_result,
        prediction_type=MODEL_TYPE2PREDICTION_TYPE[step.get_type()],
    )
    if step.type in {"ClassificationModel", "MultiLabelClassificationModel"}:
        serialised_result = attach_parent_info(
            image=image, results=serialised_result, nested_key=None
        )
    else:
        serialised_result = attach_parent_info(image=image, results=serialised_result)
        serialised_result = anchor_detections_in_parent_coordinates(
            image=image,
            serialised_result=serialised_result,
        )
    outputs_lookup[construct_step_selector(step_name=step.name)] = serialised_result
    return None, outputs_lookup
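
# Illustrative usage sketch (not part of the module): how a compiled workflow
# might invoke this handler. `object_detection_step` and `model_manager` are
# hypothetical objects assumed to be supplied by the surrounding compiler.
#
#     _, outputs_lookup = asyncio.run(
#         run_roboflow_model_step(
#             step=object_detection_step,
#             runtime_parameters={"image": {"type": "url", "value": "https://..."}},
#             outputs_lookup={},
#             model_manager=model_manager,
#             api_key=None,
#             step_execution_mode=StepExecutionMode.LOCAL,
#         )
#     )
#     predictions = outputs_lookup[
#         construct_step_selector(step_name=object_detection_step.name)
#     ]
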
async def get_roboflow_model_predictions_locally(
    image: List[dict],
    model_id: str,
    step: RoboflowModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    model_manager: ModelManager,
    api_key: Optional[str],
) -> List[dict]:
    request_constructor = MODEL_TYPE2REQUEST_CONSTRUCTOR[step.type]
    request = request_constructor(
        step=step,
        image=image,
        api_key=api_key,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    model_manager.add_model(
        model_id=model_id,
        api_key=api_key,
    )
    result = await model_manager.infer_from_request(model_id=model_id, request=request)
    if issubclass(type(result), list):
        serialised_result = [e.dict(by_alias=True, exclude_none=True) for e in result]
    else:
        serialised_result = [result.dict(by_alias=True, exclude_none=True)]
    return serialised_result


def construct_classification_request(
    step: Union[ClassificationModel, MultiLabelClassificationModel],
    image: Any,
    api_key: Optional[str],
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> ClassificationInferenceRequest:
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    return ClassificationInferenceRequest(
        api_key=api_key,
        model_id=resolve(step.model_id),
        image=image,
        confidence=resolve(step.confidence),
        disable_active_learning=resolve(step.disable_active_learning),
    )


def construct_object_detection_request(
    step: ObjectDetectionModel,
    image: Any,
    api_key: Optional[str],
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> ObjectDetectionInferenceRequest:
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    return ObjectDetectionInferenceRequest(
        api_key=api_key,
        model_id=resolve(step.model_id),
        image=image,
        disable_active_learning=resolve(step.disable_active_learning),
        class_agnostic_nms=resolve(step.class_agnostic_nms),
        class_filter=resolve(step.class_filter),
        confidence=resolve(step.confidence),
        iou_threshold=resolve(step.iou_threshold),
        max_detections=resolve(step.max_detections),
        max_candidates=resolve(step.max_candidates),
    )


def construct_instance_segmentation_request(
    step: InstanceSegmentationModel,
    image: Any,
    api_key: Optional[str],
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> InstanceSegmentationInferenceRequest:
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    return InstanceSegmentationInferenceRequest(
        api_key=api_key,
        model_id=resolve(step.model_id),
        image=image,
        disable_active_learning=resolve(step.disable_active_learning),
        class_agnostic_nms=resolve(step.class_agnostic_nms),
        class_filter=resolve(step.class_filter),
        confidence=resolve(step.confidence),
        iou_threshold=resolve(step.iou_threshold),
        max_detections=resolve(step.max_detections),
        max_candidates=resolve(step.max_candidates),
        mask_decode_mode=resolve(step.mask_decode_mode),
        tradeoff_factor=resolve(step.tradeoff_factor),
    )


def construct_keypoints_detection_request(
    step: KeypointsDetectionModel,
    image: Any,
    api_key: Optional[str],
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> KeypointsDetectionInferenceRequest:
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    return KeypointsDetectionInferenceRequest(
        api_key=api_key,
        model_id=resolve(step.model_id),
        image=image,
        disable_active_learning=resolve(step.disable_active_learning),
        class_agnostic_nms=resolve(step.class_agnostic_nms),
        class_filter=resolve(step.class_filter),
        confidence=resolve(step.confidence),
        iou_threshold=resolve(step.iou_threshold),
        max_detections=resolve(step.max_detections),
        max_candidates=resolve(step.max_candidates),
        keypoint_confidence=resolve(step.keypoint_confidence),
    )
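
# The request constructors above share one idiom: `resolve` accepts either a
# literal value or a selector string and returns the concrete value. A hedged
# sketch (values are illustrative, assuming `step.confidence` holds the
# selector "$inputs.conf" resolved against runtime_parameters):
#
#     request = construct_object_detection_request(
#         step=step,
#         image=image,
#         api_key=None,
#         runtime_parameters={"conf": 0.5},
#         outputs_lookup={},
#     )
#     # request.confidence == 0.5
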
MODEL_TYPE2REQUEST_CONSTRUCTOR = {
    "ClassificationModel": construct_classification_request,
    "MultiLabelClassificationModel": construct_classification_request,
    "ObjectDetectionModel": construct_object_detection_request,
    "InstanceSegmentationModel": construct_instance_segmentation_request,
    "KeypointsDetectionModel": construct_keypoints_detection_request,
}


async def get_roboflow_model_predictions_from_remote_api(
    image: List[dict],
    model_id: str,
    step: RoboflowModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    api_key: Optional[str],
) -> List[dict]:
    api_url = resolve_model_api_url(step=step)
    client = InferenceHTTPClient(
        api_url=api_url,
        api_key=api_key,
    )
    if WORKFLOWS_REMOTE_API_TARGET == "hosted":
        client.select_api_v0()
    configuration = MODEL_TYPE2HTTP_CLIENT_CONSTRUCTOR[step.type](
        step=step,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    client.configure(inference_configuration=configuration)
    inference_input = [i["value"] for i in image]
    results = await client.infer_async(
        inference_input=inference_input,
        model_id=model_id,
    )
    if not issubclass(type(results), list):
        return [results]
    return results


def construct_http_client_configuration_for_classification_step(
    step: Union[ClassificationModel, MultiLabelClassificationModel],
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> InferenceConfiguration:
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    return InferenceConfiguration(
        confidence_threshold=resolve(step.confidence),
        disable_active_learning=resolve(step.disable_active_learning),
        max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE,
        max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
    )


def construct_http_client_configuration_for_detection_step(
    step: ObjectDetectionModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> InferenceConfiguration:
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    return InferenceConfiguration(
        disable_active_learning=resolve(step.disable_active_learning),
        class_agnostic_nms=resolve(step.class_agnostic_nms),
        class_filter=resolve(step.class_filter),
        confidence_threshold=resolve(step.confidence),
        iou_threshold=resolve(step.iou_threshold),
        max_detections=resolve(step.max_detections),
        max_candidates=resolve(step.max_candidates),
        max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE,
        max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
    )
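
# Note: every remote configuration in this family caps fan-out with the same
# env-driven limits, so a step that receives many crops never exceeds those
# bounds regardless of model type. A hedged sketch:
#
#     config = construct_http_client_configuration_for_detection_step(
#         step=step, runtime_parameters={}, outputs_lookup={},
#     )
#     # config.max_batch_size == WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE
#     # config.max_concurrent_requests ==
#     #     WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS
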
def construct_http_client_configuration_for_segmentation_step(
    step: InstanceSegmentationModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> InferenceConfiguration:
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    return InferenceConfiguration(
        disable_active_learning=resolve(step.disable_active_learning),
        class_agnostic_nms=resolve(step.class_agnostic_nms),
        class_filter=resolve(step.class_filter),
        confidence_threshold=resolve(step.confidence),
        iou_threshold=resolve(step.iou_threshold),
        max_detections=resolve(step.max_detections),
        max_candidates=resolve(step.max_candidates),
        mask_decode_mode=resolve(step.mask_decode_mode),
        tradeoff_factor=resolve(step.tradeoff_factor),
        max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE,
        max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
    )


def construct_http_client_configuration_for_keypoints_detection_step(
    step: KeypointsDetectionModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> InferenceConfiguration:
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    return InferenceConfiguration(
        disable_active_learning=resolve(step.disable_active_learning),
        class_agnostic_nms=resolve(step.class_agnostic_nms),
        class_filter=resolve(step.class_filter),
        confidence_threshold=resolve(step.confidence),
        iou_threshold=resolve(step.iou_threshold),
        max_detections=resolve(step.max_detections),
        max_candidates=resolve(step.max_candidates),
        keypoint_confidence_threshold=resolve(step.keypoint_confidence),
        max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE,
        max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
    )


MODEL_TYPE2HTTP_CLIENT_CONSTRUCTOR = {
    "ClassificationModel": construct_http_client_configuration_for_classification_step,
    "MultiLabelClassificationModel": construct_http_client_configuration_for_classification_step,
    "ObjectDetectionModel": construct_http_client_configuration_for_detection_step,
    "InstanceSegmentationModel": construct_http_client_configuration_for_segmentation_step,
    "KeypointsDetectionModel": construct_http_client_configuration_for_keypoints_detection_step,
}


async def run_yolo_world_model_step(
    step: YoloWorld,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    model_manager: ModelManager,
    api_key: Optional[str],
    step_execution_mode: StepExecutionMode,
) -> Tuple[NextStepReference, OutputsLookup]:
    image = get_image(
        step=step,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    class_names = resolve_parameter(
        selector_or_value=step.class_names,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    model_version = resolve_parameter(
        selector_or_value=step.version,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    confidence = resolve_parameter(
        selector_or_value=step.confidence,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    if step_execution_mode is StepExecutionMode.LOCAL:
        serialised_result = await get_yolo_world_predictions_locally(
            image=image,
            class_names=class_names,
            model_version=model_version,
            confidence=confidence,
            model_manager=model_manager,
            api_key=api_key,
        )
    else:
        serialised_result = await get_yolo_world_predictions_from_remote_api(
            image=image,
            class_names=class_names,
            model_version=model_version,
            confidence=confidence,
            step=step,
            api_key=api_key,
        )
    serialised_result = attach_prediction_type_info(
        results=serialised_result,
        prediction_type="object-detection",
    )
    serialised_result = attach_parent_info(image=image, results=serialised_result)
    serialised_result = anchor_detections_in_parent_coordinates(
        image=image,
        serialised_result=serialised_result,
    )
    outputs_lookup[construct_step_selector(step_name=step.name)] = serialised_result
    return None, outputs_lookup
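
# Hedged sketch of a YOLO-World step definition (field values are illustrative;
# `class_names`, `version` and `confidence` may also be given as selectors,
# which the resolve_parameter calls above turn into concrete values):
#
#     step = YoloWorld(
#         type="YoloWorld",
#         name="open_vocab_detector",
#         image="$inputs.image",
#         class_names=["person", "car"],
#         version="l",
#         confidence=0.05,
#     )
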
async def get_yolo_world_predictions_locally(
    image: List[dict],
    class_names: List[str],
    model_version: Optional[str],
    confidence: Optional[float],
    model_manager: ModelManager,
    api_key: Optional[str],
) -> List[dict]:
    serialised_result = []
    for single_image in image:
        inference_request = YOLOWorldInferenceRequest(
            image=single_image,
            yolo_world_version_id=model_version,
            confidence=confidence,
            text=class_names,
        )
        yolo_world_model_id = load_core_model(
            model_manager=model_manager,
            inference_request=inference_request,
            core_model="yolo_world",
            api_key=api_key,
        )
        result = await model_manager.infer_from_request(
            yolo_world_model_id, inference_request
        )
        serialised_result.append(result.dict())
    return serialised_result


async def get_yolo_world_predictions_from_remote_api(
    image: List[dict],
    class_names: List[str],
    model_version: Optional[str],
    confidence: Optional[float],
    step: YoloWorld,
    api_key: Optional[str],
) -> List[dict]:
    api_url = resolve_model_api_url(step=step)
    client = InferenceHTTPClient(
        api_url=api_url,
        api_key=api_key,
    )
    configuration = InferenceConfiguration(
        max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
    )
    client.configure(inference_configuration=configuration)
    if WORKFLOWS_REMOTE_API_TARGET == "hosted":
        client.select_api_v0()
    # Batches are sized to the concurrency limit, so each batch can be
    # dispatched in full before the next one starts.
    image_batches = list(
        make_batches(
            iterable=image,
            batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
        )
    )
    serialised_result = []
    for single_batch in image_batches:
        batch_results = await client.infer_from_yolo_world_async(
            inference_input=[i["value"] for i in single_batch],
            class_names=class_names,
            model_version=model_version,
            confidence=confidence,
        )
        serialised_result.extend(batch_results)
    return serialised_result


async def run_ocr_model_step(
    step: OCRModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    model_manager: ModelManager,
    api_key: Optional[str],
    step_execution_mode: StepExecutionMode,
) -> Tuple[NextStepReference, OutputsLookup]:
    image = get_image(
        step=step,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    if step_execution_mode is StepExecutionMode.LOCAL:
        serialised_result = await get_ocr_predictions_locally(
            image=image,
            model_manager=model_manager,
            api_key=api_key,
        )
    else:
        serialised_result = await get_ocr_predictions_from_remote_api(
            step=step,
            image=image,
            api_key=api_key,
        )
    serialised_result = attach_parent_info(
        image=image,
        results=serialised_result,
        nested_key=None,
    )
    serialised_result = attach_prediction_type_info(
        results=serialised_result,
        prediction_type="ocr",
    )
    outputs_lookup[construct_step_selector(step_name=step.name)] = serialised_result
    return None, outputs_lookup


async def get_ocr_predictions_locally(
    image: List[dict],
    model_manager: ModelManager,
    api_key: Optional[str],
) -> List[dict]:
    serialised_result = []
    for single_image in image:
        inference_request = DoctrOCRInferenceRequest(
            image=single_image,
        )
        doctr_model_id = load_core_model(
            model_manager=model_manager,
            inference_request=inference_request,
            core_model="doctr",
            api_key=api_key,
        )
        result = await model_manager.infer_from_request(
            doctr_model_id, inference_request
        )
        serialised_result.append(result.dict())
    return serialised_result
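
# `load_core_model` (defined below) derives the registry id from the request's
# `<core_model>_version_id` field. A hedged sketch for doctr:
#
#     request = DoctrOCRInferenceRequest(image=single_image)
#     model_id = load_core_model(
#         model_manager=model_manager,
#         inference_request=request,
#         core_model="doctr",
#     )
#     # model_id == f"doctr/{request.doctr_version_id}"
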
async def get_ocr_predictions_from_remote_api(
    step: OCRModel,
    image: List[dict],
    api_key: Optional[str],
) -> List[dict]:
    api_url = resolve_model_api_url(step=step)
    client = InferenceHTTPClient(
        api_url=api_url,
        api_key=api_key,
    )
    if WORKFLOWS_REMOTE_API_TARGET == "hosted":
        client.select_api_v0()
    configuration = InferenceConfiguration(
        max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE,
        max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
    )
    client.configure(inference_configuration=configuration)
    result = await client.ocr_image_async(
        inference_input=[i["value"] for i in image],
    )
    if len(image) == 1:
        return [result]
    return result


async def run_clip_comparison_step(
    step: ClipComparison,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    model_manager: ModelManager,
    api_key: Optional[str],
    step_execution_mode: StepExecutionMode,
) -> Tuple[NextStepReference, OutputsLookup]:
    image = get_image(
        step=step,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    text = resolve_parameter(
        selector_or_value=step.text,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    if step_execution_mode is StepExecutionMode.LOCAL:
        serialised_result = await get_clip_comparison_locally(
            image=image,
            text=text,
            model_manager=model_manager,
            api_key=api_key,
        )
    else:
        serialised_result = await get_clip_comparison_from_remote_api(
            step=step,
            image=image,
            text=text,
            api_key=api_key,
        )
    serialised_result = attach_parent_info(
        image=image,
        results=serialised_result,
        nested_key=None,
    )
    serialised_result = attach_prediction_type_info(
        results=serialised_result,
        prediction_type="embeddings-comparison",
    )
    outputs_lookup[construct_step_selector(step_name=step.name)] = serialised_result
    return None, outputs_lookup


async def get_clip_comparison_locally(
    image: List[dict],
    text: str,
    model_manager: ModelManager,
    api_key: Optional[str],
) -> List[dict]:
    serialised_result = []
    for single_image in image:
        inference_request = ClipCompareRequest(
            subject=single_image,
            subject_type="image",
            prompt=text,
            prompt_type="text",
        )
        clip_model_id = load_core_model(
            model_manager=model_manager,
            inference_request=inference_request,
            core_model="clip",
            api_key=api_key,
        )
        result = await model_manager.infer_from_request(
            clip_model_id, inference_request
        )
        serialised_result.append(result.dict())
    return serialised_result


async def get_clip_comparison_from_remote_api(
    step: ClipComparison,
    image: List[dict],
    text: str,
    api_key: Optional[str],
) -> List[dict]:
    api_url = resolve_model_api_url(step=step)
    client = InferenceHTTPClient(
        api_url=api_url,
        api_key=api_key,
    )
    if WORKFLOWS_REMOTE_API_TARGET == "hosted":
        client.select_api_v0()
    image_batches = list(
        make_batches(
            iterable=image,
            batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
        )
    )
    serialised_result = []
    for single_batch in image_batches:
        coroutines = []
        for single_image in single_batch:
            coroutine = client.clip_compare_async(
                subject=single_image["value"],
                prompt=text,
            )
            coroutines.append(coroutine)
        batch_results = list(await asyncio.gather(*coroutines))
        serialised_result.extend(batch_results)
    return serialised_result


def load_core_model(
    model_manager: ModelManager,
    inference_request: Union[DoctrOCRInferenceRequest, ClipCompareRequest],
    core_model: str,
    api_key: Optional[str] = None,
) -> str:
    if api_key:
        inference_request.api_key = api_key
    version_id_field = f"{core_model}_version_id"
    core_model_id = (
        f"{core_model}/{inference_request.__getattribute__(version_id_field)}"
    )
    model_manager.add_model(core_model_id, inference_request.api_key)
    return core_model_id


def attach_prediction_type_info(
    results: List[Dict[str, Any]],
    prediction_type: str,
    key: str = "prediction_type",
) -> List[Dict[str, Any]]:
    for result in results:
        result[key] = prediction_type
    return results
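
# Example: attach_prediction_type_info mutates each result dict in place and
# returns the same list.
#
#     attach_prediction_type_info([{"predictions": []}], "object-detection")
#     # -> [{"predictions": [], "prediction_type": "object-detection"}]
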
def attach_parent_info(
    image: List[Dict[str, Any]],
    results: List[Dict[str, Any]],
    nested_key: Optional[str] = "predictions",
) -> List[Dict[str, Any]]:
    return [
        attach_parent_info_to_image_detections(
            image=i, predictions=p, nested_key=nested_key
        )
        for i, p in zip(image, results)
    ]


def attach_parent_info_to_image_detections(
    image: Dict[str, Any],
    predictions: Dict[str, Any],
    nested_key: Optional[str],
) -> Dict[str, Any]:
    predictions["parent_id"] = image["parent_id"]
    if nested_key is None:
        return predictions
    for prediction in predictions[nested_key]:
        prediction["parent_id"] = image["parent_id"]
    return predictions


def anchor_detections_in_parent_coordinates(
    image: List[Dict[str, Any]],
    serialised_result: List[Dict[str, Any]],
    image_metadata_key: str = "image",
    detections_key: str = "predictions",
) -> List[Dict[str, Any]]:
    return [
        anchor_image_detections_in_parent_coordinates(
            image=i,
            serialised_result=d,
            image_metadata_key=image_metadata_key,
            detections_key=detections_key,
        )
        for i, d in zip(image, serialised_result)
    ]


def anchor_image_detections_in_parent_coordinates(
    image: Dict[str, Any],
    serialised_result: Dict[str, Any],
    image_metadata_key: str = "image",
    detections_key: str = "predictions",
) -> Dict[str, Any]:
    serialised_result[f"{detections_key}{PARENT_COORDINATES_SUFFIX}"] = deepcopy(
        serialised_result[detections_key]
    )
    serialised_result[f"{image_metadata_key}{PARENT_COORDINATES_SUFFIX}"] = deepcopy(
        serialised_result[image_metadata_key]
    )
    if ORIGIN_COORDINATES_KEY not in image:
        return serialised_result
    shift_x, shift_y = (
        image[ORIGIN_COORDINATES_KEY][CENTER_X_KEY],
        image[ORIGIN_COORDINATES_KEY][CENTER_Y_KEY],
    )
    for detection in serialised_result[f"{detections_key}{PARENT_COORDINATES_SUFFIX}"]:
        detection["x"] += shift_x
        detection["y"] += shift_y
    serialised_result[f"{image_metadata_key}{PARENT_COORDINATES_SUFFIX}"] = image[
        ORIGIN_COORDINATES_KEY
    ][ORIGIN_SIZE_KEY]
    return serialised_result


ROBOFLOW_MODEL2HOSTED_ENDPOINT = {
    "ClassificationModel": HOSTED_CLASSIFICATION_URL,
    "MultiLabelClassificationModel": HOSTED_CLASSIFICATION_URL,
    "ObjectDetectionModel": HOSTED_DETECT_URL,
    "KeypointsDetectionModel": HOSTED_DETECT_URL,
    "InstanceSegmentationModel": HOSTED_INSTANCE_SEGMENTATION_URL,
    "OCRModel": HOSTED_CORE_MODEL_URL,
    "ClipComparison": HOSTED_CORE_MODEL_URL,
}


def resolve_model_api_url(step: StepInterface) -> str:
    if WORKFLOWS_REMOTE_API_TARGET != "hosted":
        return LOCAL_INFERENCE_API_URL
    return ROBOFLOW_MODEL2HOSTED_ENDPOINT[step.get_type()]
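

def _example_parent_coordinates_shift() -> Dict[str, Any]:
    """Illustrative sketch only (not used by the executor): demonstrates how a
    detection made on a crop is re-anchored into the parent image frame. All
    values below are made up for the example."""
    image = {
        "parent_id": "parent_image",
        ORIGIN_COORDINATES_KEY: {
            CENTER_X_KEY: 100,
            CENTER_Y_KEY: 50,
            ORIGIN_SIZE_KEY: {"width": 640, "height": 480},
        },
    }
    serialised_result = {
        "image": {"width": 64, "height": 48},
        "predictions": [{"x": 10, "y": 20, "width": 4, "height": 4}],
    }
    # The shifted copy lands under f"predictions{PARENT_COORDINATES_SUFFIX}";
    # its single detection ends up at x=110, y=70.
    return anchor_image_detections_in_parent_coordinates(
        image=image,
        serialised_result=serialised_result,
    )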